Diffstat (limited to 'pkg')
-rw-r--r-- pkg/abi/linux/BUILD | 1
-rw-r--r-- pkg/abi/linux/rseq.go | 130
-rw-r--r-- pkg/fspath/BUILD | 2
-rw-r--r-- pkg/fspath/fspath.go | 24
-rw-r--r-- pkg/fspath/fspath_test.go | 25
-rw-r--r-- pkg/sentry/arch/arch.go | 6
-rw-r--r-- pkg/sentry/arch/arch_amd64.go | 4
-rw-r--r-- pkg/sentry/fs/file.go | 7
-rw-r--r-- pkg/sentry/fs/fsutil/BUILD | 2
-rw-r--r-- pkg/sentry/fs/fsutil/file_range_set.go | 14
-rw-r--r-- pkg/sentry/fs/proc/task.go | 4
-rw-r--r-- pkg/sentry/fs/splice.go | 5
-rw-r--r-- pkg/sentry/fs/tty/master.go | 2
-rw-r--r-- pkg/sentry/fsimpl/ext/BUILD | 1
-rw-r--r-- pkg/sentry/fsimpl/ext/benchmark/BUILD | 1
-rw-r--r-- pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go | 15
-rw-r--r-- pkg/sentry/fsimpl/ext/disklayout/extent.go | 10
-rw-r--r-- pkg/sentry/fsimpl/ext/disklayout/extent_test.go | 6
-rw-r--r-- pkg/sentry/fsimpl/ext/ext_test.go | 13
-rw-r--r-- pkg/sentry/fsimpl/ext/extent_file.go | 8
-rw-r--r-- pkg/sentry/fsimpl/ext/file_description.go | 19
-rw-r--r-- pkg/sentry/fsimpl/ext/filesystem.go | 48
-rw-r--r-- pkg/sentry/fsimpl/ext/inode.go | 15
-rw-r--r-- pkg/sentry/fsimpl/kernfs/BUILD | 2
-rw-r--r-- pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go | 38
-rw-r--r-- pkg/sentry/fsimpl/kernfs/fd_impl_util.go | 23
-rw-r--r-- pkg/sentry/fsimpl/kernfs/filesystem.go | 190
-rw-r--r-- pkg/sentry/fsimpl/kernfs/inode_impl_util.go | 20
-rw-r--r-- pkg/sentry/fsimpl/kernfs/kernfs.go | 21
-rw-r--r-- pkg/sentry/fsimpl/kernfs/kernfs_test.go | 13
-rw-r--r-- pkg/sentry/fsimpl/kernfs/symlink.go | 45
-rw-r--r-- pkg/sentry/fsimpl/memfs/filesystem.go | 592
-rw-r--r-- pkg/sentry/fsimpl/memfs/regular_file.go | 154
-rw-r--r-- pkg/sentry/fsimpl/proc/BUILD | 33
-rw-r--r-- pkg/sentry/fsimpl/proc/boot_test.go | 149
-rw-r--r-- pkg/sentry/fsimpl/proc/filesystem.go | 69
-rw-r--r-- pkg/sentry/fsimpl/proc/loadavg.go | 8
-rw-r--r-- pkg/sentry/fsimpl/proc/meminfo.go | 6
-rw-r--r-- pkg/sentry/fsimpl/proc/stat.go | 6
-rw-r--r-- pkg/sentry/fsimpl/proc/task.go | 337
-rw-r--r-- pkg/sentry/fsimpl/proc/task_files.go | 272
-rw-r--r-- pkg/sentry/fsimpl/proc/tasks.go | 218
-rw-r--r-- pkg/sentry/fsimpl/proc/tasks_files.go | 92
-rw-r--r-- pkg/sentry/fsimpl/proc/tasks_test.go | 555
-rw-r--r-- pkg/sentry/fsimpl/proc/version.go | 6
-rw-r--r-- pkg/sentry/fsimpl/tmpfs/BUILD (renamed from pkg/sentry/fsimpl/memfs/BUILD) | 34
-rw-r--r-- pkg/sentry/fsimpl/tmpfs/benchmark_test.go (renamed from pkg/sentry/fsimpl/memfs/benchmark_test.go) | 43
-rw-r--r-- pkg/sentry/fsimpl/tmpfs/directory.go (renamed from pkg/sentry/fsimpl/memfs/directory.go) | 2
-rw-r--r-- pkg/sentry/fsimpl/tmpfs/filesystem.go | 698
-rw-r--r-- pkg/sentry/fsimpl/tmpfs/named_pipe.go (renamed from pkg/sentry/fsimpl/memfs/named_pipe.go) | 6
-rw-r--r-- pkg/sentry/fsimpl/tmpfs/pipe_test.go (renamed from pkg/sentry/fsimpl/memfs/pipe_test.go) | 26
-rw-r--r-- pkg/sentry/fsimpl/tmpfs/regular_file.go | 357
-rw-r--r-- pkg/sentry/fsimpl/tmpfs/regular_file_test.go | 224
-rw-r--r-- pkg/sentry/fsimpl/tmpfs/symlink.go (renamed from pkg/sentry/fsimpl/memfs/symlink.go) | 2
-rw-r--r-- pkg/sentry/fsimpl/tmpfs/tmpfs.go (renamed from pkg/sentry/fsimpl/memfs/memfs.go) | 89
-rw-r--r-- pkg/sentry/kernel/kernel.go | 7
-rw-r--r-- pkg/sentry/kernel/rseq.go | 383
-rw-r--r-- pkg/sentry/kernel/shm/shm.go | 85
-rw-r--r-- pkg/sentry/kernel/task.go | 43
-rw-r--r-- pkg/sentry/kernel/task_clone.go | 9
-rw-r--r-- pkg/sentry/kernel/task_exec.go | 6
-rw-r--r-- pkg/sentry/kernel/task_run.go | 16
-rw-r--r-- pkg/sentry/kernel/task_start.go | 10
-rw-r--r-- pkg/sentry/kernel/thread_group.go | 26
-rw-r--r-- pkg/sentry/mm/procfs.go | 12
-rw-r--r-- pkg/sentry/platform/kvm/BUILD | 2
-rw-r--r-- pkg/sentry/platform/kvm/physical_map.go | 12
-rw-r--r-- pkg/sentry/platform/kvm/physical_map_amd64.go (renamed from pkg/sentry/fsimpl/proc/filesystems.go) | 17
-rw-r--r-- pkg/sentry/platform/kvm/physical_map_arm64.go (renamed from pkg/sentry/fsimpl/proc/proc.go) | 7
-rw-r--r-- pkg/sentry/platform/ptrace/stub_amd64.s | 29
-rw-r--r-- pkg/sentry/platform/ptrace/stub_arm64.s | 30
-rw-r--r-- pkg/sentry/platform/ptrace/subprocess.go | 25
-rw-r--r-- pkg/sentry/platform/ptrace/subprocess_amd64.go | 4
-rw-r--r-- pkg/sentry/platform/ptrace/subprocess_arm64.go | 2
-rw-r--r-- pkg/sentry/platform/ptrace/subprocess_linux.go | 2
-rw-r--r-- pkg/sentry/platform/ring0/BUILD | 1
-rw-r--r-- pkg/sentry/platform/ring0/defs_arm64.go | 3
-rw-r--r-- pkg/sentry/platform/ring0/entry_arm64.s | 97
-rw-r--r-- pkg/sentry/platform/ring0/lib_arm64.go | 26
-rw-r--r-- pkg/sentry/platform/ring0/lib_arm64.s | 121
-rw-r--r-- pkg/sentry/platform/ring0/offsets_arm64.go | 1
-rw-r--r-- pkg/sentry/socket/control/control.go | 6
-rw-r--r-- pkg/sentry/socket/netlink/BUILD | 1
-rw-r--r-- pkg/sentry/socket/netlink/socket.go | 29
-rw-r--r-- pkg/sentry/socket/netstack/netstack.go | 54
-rw-r--r-- pkg/sentry/socket/unix/io.go | 13
-rw-r--r-- pkg/sentry/socket/unix/transport/unix.go | 24
-rw-r--r-- pkg/sentry/socket/unix/unix.go | 23
-rw-r--r-- pkg/sentry/strace/linux64_amd64.go | 4
-rw-r--r-- pkg/sentry/strace/select.go | 3
-rw-r--r-- pkg/sentry/syscalls/linux/BUILD | 1
-rw-r--r-- pkg/sentry/syscalls/linux/linux64_amd64.go | 2
-rw-r--r-- pkg/sentry/syscalls/linux/linux64_arm64.go | 2
-rw-r--r-- pkg/sentry/syscalls/linux/sys_futex.go | 10
-rw-r--r-- pkg/sentry/syscalls/linux/sys_rseq.go | 48
-rw-r--r-- pkg/sentry/syscalls/linux/sys_shm.go | 7
-rw-r--r-- pkg/sentry/vfs/BUILD | 1
-rw-r--r-- pkg/sentry/vfs/dentry.go | 47
-rw-r--r-- pkg/sentry/vfs/device.go | 100
-rw-r--r-- pkg/sentry/vfs/file_description.go | 302
-rw-r--r-- pkg/sentry/vfs/file_description_impl_util.go | 51
-rw-r--r-- pkg/sentry/vfs/file_description_impl_util_test.go | 4
-rw-r--r-- pkg/sentry/vfs/filesystem.go | 288
-rw-r--r-- pkg/sentry/vfs/filesystem_type.go | 55
-rw-r--r-- pkg/sentry/vfs/mount.go | 15
-rw-r--r-- pkg/sentry/vfs/options.go | 21
-rw-r--r-- pkg/sentry/vfs/resolving_path.go | 86
-rw-r--r-- pkg/sentry/vfs/testutil.go | 27
-rw-r--r-- pkg/sentry/vfs/vfs.go | 332
-rw-r--r-- pkg/syserror/syserror.go | 1
-rw-r--r-- pkg/tcpip/BUILD | 8
-rw-r--r-- pkg/tcpip/header/BUILD | 1
-rw-r--r-- pkg/tcpip/header/ipv6.go | 45
-rw-r--r-- pkg/tcpip/header/ipv6_test.go | 163
-rw-r--r-- pkg/tcpip/network/arp/arp.go | 6
-rw-r--r-- pkg/tcpip/network/ip_test.go | 4
-rw-r--r-- pkg/tcpip/network/ipv4/ipv4.go | 18
-rw-r--r-- pkg/tcpip/network/ipv6/ipv6.go | 16
-rw-r--r-- pkg/tcpip/stack/BUILD | 1
-rw-r--r-- pkg/tcpip/stack/ndp.go | 600
-rw-r--r-- pkg/tcpip/stack/ndp_test.go | 1700
-rw-r--r-- pkg/tcpip/stack/nic.go | 152
-rw-r--r-- pkg/tcpip/stack/registration.go | 6
-rw-r--r-- pkg/tcpip/stack/route.go | 6
-rw-r--r-- pkg/tcpip/stack/stack.go | 176
-rw-r--r-- pkg/tcpip/stack/stack_test.go | 249
-rw-r--r-- pkg/tcpip/stack/transport_demuxer_test.go | 92
-rw-r--r-- pkg/tcpip/stack/transport_test.go | 14
-rw-r--r-- pkg/tcpip/tcpip.go | 33
-rw-r--r-- pkg/tcpip/timer.go | 161
-rw-r--r-- pkg/tcpip/timer_test.go | 236
-rw-r--r-- pkg/tcpip/transport/icmp/endpoint.go | 14
-rw-r--r-- pkg/tcpip/transport/packet/endpoint.go | 22
-rw-r--r-- pkg/tcpip/transport/raw/endpoint.go | 44
-rw-r--r-- pkg/tcpip/transport/tcp/dual_stack_test.go | 8
-rw-r--r-- pkg/tcpip/transport/tcp/endpoint.go | 119
-rw-r--r-- pkg/tcpip/transport/tcp/tcp_test.go | 49
-rw-r--r-- pkg/tcpip/transport/tcp/testing/context/context.go | 16
-rw-r--r-- pkg/tcpip/transport/udp/endpoint.go | 89
-rw-r--r-- pkg/tcpip/transport/udp/udp_test.go | 38
140 files changed, 8454 insertions, 2857 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index 9553f164d..716ff22d2 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -41,6 +41,7 @@ go_library(
"poll.go",
"prctl.go",
"ptrace.go",
+ "rseq.go",
"rusage.go",
"sched.go",
"seccomp.go",
diff --git a/pkg/abi/linux/rseq.go b/pkg/abi/linux/rseq.go
new file mode 100644
index 000000000..76253ba30
--- /dev/null
+++ b/pkg/abi/linux/rseq.go
@@ -0,0 +1,130 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Flags passed to rseq(2).
+//
+// Defined in include/uapi/linux/rseq.h.
+const (
+ // RSEQ_FLAG_UNREGISTER unregisters the current thread.
+ RSEQ_FLAG_UNREGISTER = 1 << 0
+)
+
+// Critical section flags used in RSeqCriticalSection.Flags and RSeq.Flags.
+//
+// Defined in include/uapi/linux/rseq.h.
+const (
+ // RSEQ_CS_FLAG_NO_RESTART_ON_PREEMPT inhibits restart on preemption.
+ RSEQ_CS_FLAG_NO_RESTART_ON_PREEMPT = 1 << 0
+
+ // RSEQ_CS_FLAG_NO_RESTART_ON_SIGNAL inhibits restart on signal
+ // delivery.
+ RSEQ_CS_FLAG_NO_RESTART_ON_SIGNAL = 1 << 1
+
+ // RSEQ_CS_FLAG_NO_RESTART_ON_MIGRATE inhibits restart on CPU
+ // migration.
+ RSEQ_CS_FLAG_NO_RESTART_ON_MIGRATE = 1 << 2
+)
+
+// RSeqCriticalSection describes a restartable sequences critical section. It
+// is equivalent to struct rseq_cs, defined in include/uapi/linux/rseq.h.
+//
+// In userspace, this structure is always aligned to 32 bytes.
+//
+// +marshal
+type RSeqCriticalSection struct {
+ // Version is the version of this structure. Version 0 is defined here.
+ Version uint32
+
+ // Flags are the critical section flags, defined above.
+ Flags uint32
+
+ // Start is the start address of the critical section.
+ Start uint64
+
+ // PostCommitOffset is the offset from Start of the first instruction
+ // outside of the critical section.
+ PostCommitOffset uint64
+
+ // Abort is the abort address. It must be outside the critical section,
+ // and the 4 bytes prior must match the abort signature.
+ Abort uint64
+}
+
+const (
+ // SizeOfRSeqCriticalSection is the size of RSeqCriticalSection.
+ SizeOfRSeqCriticalSection = 32
+
+ // SizeOfRSeqSignature is the size of the signature immediately
+ // preceding RSeqCriticalSection.Abort.
+ SizeOfRSeqSignature = 4
+)
+
+// Special values for RSeq.CPUID, defined in include/uapi/linux/rseq.h.
+const (
+ // RSEQ_CPU_ID_UNINITIALIZED indicates that this thread has not
+ // performed rseq initialization.
+ RSEQ_CPU_ID_UNINITIALIZED = ^uint32(0) // -1
+
+ // RSEQ_CPU_ID_REGISTRATION_FAILED indicates that rseq initialization
+ // failed.
+ RSEQ_CPU_ID_REGISTRATION_FAILED = ^uint32(1) // -2
+)
+
+// RSeq is the thread-local restartable sequences config/status. It
+// is equivalent to struct rseq, defined in include/uapi/linux/rseq.h.
+//
+// In userspace, this structure is always aligned to 32 bytes.
+type RSeq struct {
+ // CPUIDStart contains the current CPU ID if rseq is initialized.
+ //
+ // This field should only be read by the thread which registered this
+ // structure, and must be read atomically.
+ CPUIDStart uint32
+
+ // CPUID contains the current CPU ID or one of the CPU ID special
+ // values defined above.
+ //
+ // This field should only be read by the thread which registered this
+ // structure, and must be read atomically.
+ CPUID uint32
+
+ // RSeqCriticalSection is a pointer to the current RSeqCriticalSection
+ // block, or NULL. It is reset to NULL by the kernel on restart or
+ // non-restarting preempt/signal.
+ //
+ // This field should only be written by the thread which registered
+ // this structure, and must be written atomically.
+ RSeqCriticalSection uint64
+
+ // Flags are the critical section flags that apply to all critical
+ // sections on this thread, defined above.
+ Flags uint32
+}
+
+const (
+ // SizeOfRSeq is the size of RSeq.
+ //
+ // Note that RSeq is naively 24 bytes. However, it has 32-byte
+ // alignment, which in C increases sizeof to 32. That is the size that
+ // the Linux kernel uses.
+ SizeOfRSeq = 32
+
+ // AlignOfRSeq is the standard alignment of RSeq.
+ AlignOfRSeq = 32
+
+ // OffsetOfRSeqCriticalSection is the offset of RSeqCriticalSection in RSeq.
+ OffsetOfRSeqCriticalSection = 8
+)
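
The struct above mirrors Linux's struct rseq_cs. The following is a minimal sketch, not part of this change, of how sentry-side code might test whether an interrupted instruction pointer falls inside a registered critical section; the shouldRestartRSeq helper is illustrative and assumes it lives in this linux package.

// Sketch only: shouldRestartRSeq is a hypothetical helper. It reports
// whether a thread interrupted at ip must be redirected to cs.Abort, i.e.
// whether ip lies in [cs.Start, cs.Start+cs.PostCommitOffset). Real sentry
// logic additionally validates cs.Version, cs.Flags and the 4-byte abort
// signature preceding cs.Abort.
func shouldRestartRSeq(cs *RSeqCriticalSection, ip uint64) bool {
	if cs == nil || cs.PostCommitOffset == 0 {
		return false
	}
	return ip >= cs.Start && ip-cs.Start < cs.PostCommitOffset
}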
diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD
index 0c5f50397..ca540363c 100644
--- a/pkg/fspath/BUILD
+++ b/pkg/fspath/BUILD
@@ -14,7 +14,6 @@ go_library(
"fspath.go",
],
importpath = "gvisor.dev/gvisor/pkg/fspath",
- deps = ["//pkg/syserror"],
)
go_test(
@@ -25,5 +24,4 @@ go_test(
"fspath_test.go",
],
embed = [":fspath"],
- deps = ["//pkg/syserror"],
)
diff --git a/pkg/fspath/fspath.go b/pkg/fspath/fspath.go
index f68752560..9fb3fee24 100644
--- a/pkg/fspath/fspath.go
+++ b/pkg/fspath/fspath.go
@@ -18,19 +18,17 @@ package fspath
import (
"strings"
-
- "gvisor.dev/gvisor/pkg/syserror"
)
const pathSep = '/'
-// Parse parses a pathname as described by path_resolution(7).
-func Parse(pathname string) (Path, error) {
+// Parse parses a pathname as described by path_resolution(7), except that
+// empty pathnames will be parsed successfully to a Path for which
+// Path.Absolute == Path.Dir == Path.HasComponents() == false. (This is
+// necessary to support AT_EMPTY_PATH.)
+func Parse(pathname string) Path {
if len(pathname) == 0 {
- // "... POSIX decrees that an empty pathname must not be resolved
- // successfully. Linux returns ENOENT in this case." -
- // path_resolution(7)
- return Path{}, syserror.ENOENT
+ return Path{}
}
// Skip leading path separators.
i := 0
@@ -41,7 +39,7 @@ func Parse(pathname string) (Path, error) {
return Path{
Absolute: true,
Dir: true,
- }, nil
+ }
}
}
// Skip trailing path separators. This is required by Iterator.Next. This
@@ -64,7 +62,7 @@ func Parse(pathname string) (Path, error) {
},
Absolute: i != 0,
Dir: j != len(pathname)-1,
- }, nil
+ }
}
// Path contains the information contained in a pathname string.
@@ -111,6 +109,12 @@ func (p Path) String() string {
return b.String()
}
+// HasComponents returns true if p contains a non-zero number of path
+// components.
+func (p Path) HasComponents() bool {
+ return p.Begin.Ok()
+}
+
// An Iterator represents either a path component in a Path or a terminal
// iterator indicating that the end of the path has been reached.
//
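
Since Parse can no longer fail, callers that previously handled ENOENT now detect an empty pathname from the returned Path itself. A minimal sketch of the new contract, assuming the gvisor.dev/gvisor/pkg/fspath import; the isEmptyPath helper is illustrative.

// Sketch only: Parse("") now succeeds and returns a Path with no
// components, which is what AT_EMPTY_PATH handling needs to branch on.
func isEmptyPath(pathname string) bool {
	p := fspath.Parse(pathname)
	return !p.HasComponents() && !p.Absolute && !p.Dir
}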
diff --git a/pkg/fspath/fspath_test.go b/pkg/fspath/fspath_test.go
index 215b35622..d5e9a549a 100644
--- a/pkg/fspath/fspath_test.go
+++ b/pkg/fspath/fspath_test.go
@@ -18,15 +18,10 @@ import (
"reflect"
"strings"
"testing"
-
- "gvisor.dev/gvisor/pkg/syserror"
)
func TestParseIteratorPartialPathnames(t *testing.T) {
- path, err := Parse("/foo//bar///baz////")
- if err != nil {
- t.Fatalf("Parse failed: %v", err)
- }
+ path := Parse("/foo//bar///baz////")
// Parse strips leading slashes, and records their presence as
// Path.Absolute.
if !path.Absolute {
@@ -71,6 +66,12 @@ func TestParse(t *testing.T) {
}
tests := []testCase{
{
+ pathname: "",
+ relpath: []string{},
+ abs: false,
+ dir: false,
+ },
+ {
pathname: "/",
relpath: []string{},
abs: true,
@@ -113,10 +114,7 @@ func TestParse(t *testing.T) {
for _, test := range tests {
t.Run(test.pathname, func(t *testing.T) {
- p, err := Parse(test.pathname)
- if err != nil {
- t.Fatalf("failed to parse pathname %q: %v", test.pathname, err)
- }
+ p := Parse(test.pathname)
t.Logf("pathname %q => path %q", test.pathname, p)
if p.Absolute != test.abs {
t.Errorf("path absoluteness: got %v, wanted %v", p.Absolute, test.abs)
@@ -134,10 +132,3 @@ func TestParse(t *testing.T) {
})
}
}
-
-func TestParseEmptyPathname(t *testing.T) {
- p, err := Parse("")
- if err != syserror.ENOENT {
- t.Errorf("parsing empty pathname: got (%v, %v), wanted (<unspecified>, ENOENT)", p, err)
- }
-}
diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go
index 498ca4669..81ec98a77 100644
--- a/pkg/sentry/arch/arch.go
+++ b/pkg/sentry/arch/arch.go
@@ -125,9 +125,9 @@ type Context interface {
// SetTLS sets the current TLS pointer. Returns false if value is invalid.
SetTLS(value uintptr) bool
- // SetRSEQInterruptedIP sets the register that contains the old IP when a
- // restartable sequence is interrupted.
- SetRSEQInterruptedIP(value uintptr)
+ // SetOldRSeqInterruptedIP sets the register that contains the old IP
+ // when an "old rseq" restartable sequence is interrupted.
+ SetOldRSeqInterruptedIP(value uintptr)
// StateData returns a pointer to underlying architecture state.
StateData() *State
diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go
index 67daa6c24..2aa08b1a9 100644
--- a/pkg/sentry/arch/arch_amd64.go
+++ b/pkg/sentry/arch/arch_amd64.go
@@ -174,8 +174,8 @@ func (c *context64) SetTLS(value uintptr) bool {
return true
}
-// SetRSEQInterruptedIP implements Context.SetRSEQInterruptedIP.
-func (c *context64) SetRSEQInterruptedIP(value uintptr) {
+// SetOldRSeqInterruptedIP implements Context.SetOldRSeqInterruptedIP.
+func (c *context64) SetOldRSeqInterruptedIP(value uintptr) {
c.Regs.R10 = uint64(value)
}
diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go
index c0a6e884b..a2f966cb6 100644
--- a/pkg/sentry/fs/file.go
+++ b/pkg/sentry/fs/file.go
@@ -555,6 +555,10 @@ type lockedWriter struct {
//
// This applies only to Write, not WriteAt.
Offset int64
+
+ // Err contains the first error encountered while copying. This is
+ // useful to determine whether Writer or Reader failed during io.Copy.
+ Err error
}
// Write implements io.Writer.Write.
@@ -590,5 +594,8 @@ func (w *lockedWriter) WriteAt(buf []byte, offset int64) (int, error) {
break
}
}
+ if w.Err == nil {
+ w.Err = err
+ }
return written, err
}
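
The new Err field records the destination's first write error so that callers driving io.Copy through a lockedWriter (such as Splice in splice.go below) can tell a destination failure apart from a source failure. A minimal sketch of that pattern with a hypothetical helper name, assuming the standard io import:

// Sketch only: after copying through a lockedWriter, a nil w.Err alongside a
// non-nil copy error means the reader failed; the destination accepted every
// byte it was offered, so progress can be reported without data loss.
func copyThrough(w *lockedWriter, r io.Reader) (int64, error) {
	n, err := io.Copy(w, r)
	if err != nil && w.Err == nil {
		err = nil // source-side error only; no destination data was lost
	}
	return n, err
}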
diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD
index b2e8d9c77..9ca695a95 100644
--- a/pkg/sentry/fs/fsutil/BUILD
+++ b/pkg/sentry/fs/fsutil/BUILD
@@ -53,7 +53,7 @@ go_template_instance(
"Key": "uint64",
"Range": "memmap.MappableRange",
"Value": "uint64",
- "Functions": "fileRangeSetFunctions",
+ "Functions": "FileRangeSetFunctions",
},
)
diff --git a/pkg/sentry/fs/fsutil/file_range_set.go b/pkg/sentry/fs/fsutil/file_range_set.go
index 0a5466b0a..f52d712e3 100644
--- a/pkg/sentry/fs/fsutil/file_range_set.go
+++ b/pkg/sentry/fs/fsutil/file_range_set.go
@@ -34,25 +34,25 @@ import (
//
// type FileRangeSet <generated by go_generics>
-// fileRangeSetFunctions implements segment.Functions for FileRangeSet.
-type fileRangeSetFunctions struct{}
+// FileRangeSetFunctions implements segment.Functions for FileRangeSet.
+type FileRangeSetFunctions struct{}
// MinKey implements segment.Functions.MinKey.
-func (fileRangeSetFunctions) MinKey() uint64 {
+func (FileRangeSetFunctions) MinKey() uint64 {
return 0
}
// MaxKey implements segment.Functions.MaxKey.
-func (fileRangeSetFunctions) MaxKey() uint64 {
+func (FileRangeSetFunctions) MaxKey() uint64 {
return math.MaxUint64
}
// ClearValue implements segment.Functions.ClearValue.
-func (fileRangeSetFunctions) ClearValue(_ *uint64) {
+func (FileRangeSetFunctions) ClearValue(_ *uint64) {
}
// Merge implements segment.Functions.Merge.
-func (fileRangeSetFunctions) Merge(mr1 memmap.MappableRange, frstart1 uint64, _ memmap.MappableRange, frstart2 uint64) (uint64, bool) {
+func (FileRangeSetFunctions) Merge(mr1 memmap.MappableRange, frstart1 uint64, _ memmap.MappableRange, frstart2 uint64) (uint64, bool) {
if frstart1+mr1.Length() != frstart2 {
return 0, false
}
@@ -60,7 +60,7 @@ func (fileRangeSetFunctions) Merge(mr1 memmap.MappableRange, frstart1 uint64, _
}
// Split implements segment.Functions.Split.
-func (fileRangeSetFunctions) Split(mr memmap.MappableRange, frstart uint64, split uint64) (uint64, uint64) {
+func (FileRangeSetFunctions) Split(mr memmap.MappableRange, frstart uint64, split uint64) (uint64, uint64) {
return frstart, frstart + (split - mr.Start)
}
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index 0e46c5fb7..9bf4b4527 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -604,6 +604,10 @@ func (s *statusData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) (
fmt.Fprintf(&buf, "CapEff:\t%016x\n", creds.EffectiveCaps)
fmt.Fprintf(&buf, "CapBnd:\t%016x\n", creds.BoundingCaps)
fmt.Fprintf(&buf, "Seccomp:\t%d\n", s.t.SeccompMode())
+ // We unconditionally report a single NUMA node. See
+ // pkg/sentry/syscalls/linux/sys_mempolicy.go.
+ fmt.Fprintf(&buf, "Mems_allowed:\t1\n")
+ fmt.Fprintf(&buf, "Mems_allowed_list:\t0\n")
return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*statusData)(nil)}}, 0
}
diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go
index 311798811..389c330a0 100644
--- a/pkg/sentry/fs/splice.go
+++ b/pkg/sentry/fs/splice.go
@@ -167,6 +167,11 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64,
if !srcPipe && !opts.SrcOffset {
atomic.StoreInt64(&src.offset, src.offset+n)
}
+
+ // Don't report any errors if we have some progress without data loss.
+ if w.Err == nil {
+ err = nil
+ }
}
// Drop locks.
diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go
index 934828c12..6b07f6bf2 100644
--- a/pkg/sentry/fs/tty/master.go
+++ b/pkg/sentry/fs/tty/master.go
@@ -77,7 +77,7 @@ func (mi *masterInodeOperations) Release(ctx context.Context) {
}
// Truncate implements fs.InodeOperations.Truncate.
-func (masterInodeOperations) Truncate(context.Context, *fs.Inode, int64) error {
+func (*masterInodeOperations) Truncate(context.Context, *fs.Inode, int64) error {
return nil
}
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index 880b7bcd3..bc90330bc 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -74,6 +74,7 @@ go_test(
deps = [
"//pkg/abi/linux",
"//pkg/binary",
+ "//pkg/fspath",
"//pkg/sentry/context",
"//pkg/sentry/context/contexttest",
"//pkg/sentry/fsimpl/ext/disklayout",
diff --git a/pkg/sentry/fsimpl/ext/benchmark/BUILD b/pkg/sentry/fsimpl/ext/benchmark/BUILD
index bfc46dfa6..4fc8296ef 100644
--- a/pkg/sentry/fsimpl/ext/benchmark/BUILD
+++ b/pkg/sentry/fsimpl/ext/benchmark/BUILD
@@ -7,6 +7,7 @@ go_test(
size = "small",
srcs = ["benchmark_test.go"],
deps = [
+ "//pkg/fspath",
"//pkg/sentry/context",
"//pkg/sentry/context/contexttest",
"//pkg/sentry/fsimpl/ext",
diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
index 177ce2cb9..a56b03711 100644
--- a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
+++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
@@ -24,6 +24,7 @@ import (
"strings"
"testing"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/context/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext"
@@ -49,7 +50,9 @@ func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesys
// Create VFS.
vfsObj := vfs.New()
- vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{})
+ vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())})
if err != nil {
f.Close()
@@ -121,7 +124,7 @@ func BenchmarkVFS2Ext4fsStat(b *testing.B) {
stat, err := vfsfs.StatAt(ctx, creds, &vfs.PathOperation{
Root: *root,
Start: *root,
- Pathname: filePath,
+ Path: fspath.Parse(filePath),
FollowFinalSymlink: true,
}, &vfs.StatOptions{})
if err != nil {
@@ -150,9 +153,9 @@ func BenchmarkVFS2ExtfsMountStat(b *testing.B) {
creds := auth.CredentialsFromContext(ctx)
mountPointName := "/1/"
pop := vfs.PathOperation{
- Root: *root,
- Start: *root,
- Pathname: mountPointName,
+ Root: *root,
+ Start: *root,
+ Path: fspath.Parse(mountPointName),
}
// Save the mount point for later use.
@@ -181,7 +184,7 @@ func BenchmarkVFS2ExtfsMountStat(b *testing.B) {
stat, err := vfsfs.StatAt(ctx, creds, &vfs.PathOperation{
Root: *root,
Start: *root,
- Pathname: filePath,
+ Path: fspath.Parse(filePath),
FollowFinalSymlink: true,
}, &vfs.StatOptions{})
if err != nil {
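
The mechanical change in these tests is that vfs.PathOperation's Pathname string has been replaced by a pre-parsed Path field. A sketch of the new call shape, reusing ctx, creds, root, vfsfs and filePath from the setUp above:

pop := vfs.PathOperation{
	Root:               *root,
	Start:              *root,
	Path:               fspath.Parse(filePath), // previously Pathname: filePath
	FollowFinalSymlink: true,
}
if _, err := vfsfs.StatAt(ctx, creds, &pop, &vfs.StatOptions{}); err != nil {
	b.Fatalf("stat(%q) failed: %v", filePath, err)
}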
diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent.go b/pkg/sentry/fsimpl/ext/disklayout/extent.go
index 567523d32..4110649ab 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/extent.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/extent.go
@@ -29,8 +29,12 @@ package disklayout
// byte (i * sb.BlockSize()) to ((i+1) * sb.BlockSize()).
const (
- // ExtentStructsSize is the size of all the three extent on-disk structs.
- ExtentStructsSize = 12
+ // ExtentHeaderSize is the size of the header of an extent tree node.
+ ExtentHeaderSize = 12
+
+ // ExtentEntrySize is the size of an entry in an extent tree node.
+ // This size is the same for both leaf and internal nodes.
+ ExtentEntrySize = 12
// ExtentMagic is the magic number which must be present in the header.
ExtentMagic = 0xf30a
@@ -57,7 +61,7 @@ type ExtentNode struct {
Entries []ExtentEntryPair
}
-// ExtentEntry reprsents an extent tree node entry. The entry can either be
+// ExtentEntry represents an extent tree node entry. The entry can either be
// an ExtentIdx or Extent itself. This exists to simplify navigation logic.
type ExtentEntry interface {
// FileBlock returns the first file block number covered by this entry.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
index b0fad9b71..8762b90db 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
@@ -21,7 +21,7 @@ import (
// TestExtentSize tests that the extent structs are of the correct
// size.
func TestExtentSize(t *testing.T) {
- assertSize(t, ExtentHeader{}, ExtentStructsSize)
- assertSize(t, ExtentIdx{}, ExtentStructsSize)
- assertSize(t, Extent{}, ExtentStructsSize)
+ assertSize(t, ExtentHeader{}, ExtentHeaderSize)
+ assertSize(t, ExtentIdx{}, ExtentEntrySize)
+ assertSize(t, Extent{}, ExtentEntrySize)
}
diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go
index e9f756732..6c14a1e2d 100644
--- a/pkg/sentry/fsimpl/ext/ext_test.go
+++ b/pkg/sentry/fsimpl/ext/ext_test.go
@@ -25,6 +25,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/context/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
@@ -65,7 +66,9 @@ func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesys
// Create VFS.
vfsObj := vfs.New()
- vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{})
+ vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())})
if err != nil {
f.Close()
@@ -140,7 +143,7 @@ func TestSeek(t *testing.T) {
fd, err := vfsfs.OpenAt(
ctx,
auth.CredentialsFromContext(ctx),
- &vfs.PathOperation{Root: *root, Start: *root, Pathname: test.path},
+ &vfs.PathOperation{Root: *root, Start: *root, Path: fspath.Parse(test.path)},
&vfs.OpenOptions{},
)
if err != nil {
@@ -359,7 +362,7 @@ func TestStatAt(t *testing.T) {
got, err := vfsfs.StatAt(ctx,
auth.CredentialsFromContext(ctx),
- &vfs.PathOperation{Root: *root, Start: *root, Pathname: test.path},
+ &vfs.PathOperation{Root: *root, Start: *root, Path: fspath.Parse(test.path)},
&vfs.StatOptions{},
)
if err != nil {
@@ -429,7 +432,7 @@ func TestRead(t *testing.T) {
fd, err := vfsfs.OpenAt(
ctx,
auth.CredentialsFromContext(ctx),
- &vfs.PathOperation{Root: *root, Start: *root, Pathname: test.absPath},
+ &vfs.PathOperation{Root: *root, Start: *root, Path: fspath.Parse(test.absPath)},
&vfs.OpenOptions{},
)
if err != nil {
@@ -565,7 +568,7 @@ func TestIterDirents(t *testing.T) {
fd, err := vfsfs.OpenAt(
ctx,
auth.CredentialsFromContext(ctx),
- &vfs.PathOperation{Root: *root, Start: *root, Pathname: test.path},
+ &vfs.PathOperation{Root: *root, Start: *root, Path: fspath.Parse(test.path)},
&vfs.OpenOptions{},
)
if err != nil {
diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go
index 3d3ebaca6..11dcc0346 100644
--- a/pkg/sentry/fsimpl/ext/extent_file.go
+++ b/pkg/sentry/fsimpl/ext/extent_file.go
@@ -57,7 +57,7 @@ func newExtentFile(regFile regularFile) (*extentFile, error) {
func (f *extentFile) buildExtTree() error {
rootNodeData := f.regFile.inode.diskInode.Data()
- binary.Unmarshal(rootNodeData[:disklayout.ExtentStructsSize], binary.LittleEndian, &f.root.Header)
+ binary.Unmarshal(rootNodeData[:disklayout.ExtentHeaderSize], binary.LittleEndian, &f.root.Header)
// Root node can not have more than 4 entries: 60 bytes = 1 header + 4 entries.
if f.root.Header.NumEntries > 4 {
@@ -67,7 +67,7 @@ func (f *extentFile) buildExtTree() error {
}
f.root.Entries = make([]disklayout.ExtentEntryPair, f.root.Header.NumEntries)
- for i, off := uint16(0), disklayout.ExtentStructsSize; i < f.root.Header.NumEntries; i, off = i+1, off+disklayout.ExtentStructsSize {
+ for i, off := uint16(0), disklayout.ExtentEntrySize; i < f.root.Header.NumEntries; i, off = i+1, off+disklayout.ExtentEntrySize {
var curEntry disklayout.ExtentEntry
if f.root.Header.Height == 0 {
// Leaf node.
@@ -76,7 +76,7 @@ func (f *extentFile) buildExtTree() error {
// Internal node.
curEntry = &disklayout.ExtentIdx{}
}
- binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentStructsSize], binary.LittleEndian, curEntry)
+ binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentEntrySize], binary.LittleEndian, curEntry)
f.root.Entries[i].Entry = curEntry
}
@@ -105,7 +105,7 @@ func (f *extentFile) buildExtTreeFromDisk(entry disklayout.ExtentEntry) (*diskla
}
entries := make([]disklayout.ExtentEntryPair, header.NumEntries)
- for i, off := uint16(0), off+disklayout.ExtentStructsSize; i < header.NumEntries; i, off = i+1, off+disklayout.ExtentStructsSize {
+ for i, off := uint16(0), off+disklayout.ExtentEntrySize; i < header.NumEntries; i, off = i+1, off+disklayout.ExtentEntrySize {
var curEntry disklayout.ExtentEntry
if header.Height == 0 {
// Leaf node.
diff --git a/pkg/sentry/fsimpl/ext/file_description.go b/pkg/sentry/fsimpl/ext/file_description.go
index 5eca2b83f..841274daf 100644
--- a/pkg/sentry/fsimpl/ext/file_description.go
+++ b/pkg/sentry/fsimpl/ext/file_description.go
@@ -26,13 +26,6 @@ import (
type fileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
-
- // flags is the same as vfs.OpenOptions.Flags which are passed to
- // vfs.FilesystemImpl.OpenAt.
- // TODO(b/134676337): syscalls like read(2), write(2), fchmod(2), fchown(2),
- // fgetxattr(2), ioctl(2), mmap(2) should fail with EBADF if O_PATH is set.
- // Only close(2), fstat(2), fstatfs(2) should work.
- flags uint32
}
func (fd *fileDescription) filesystem() *filesystem {
@@ -43,18 +36,6 @@ func (fd *fileDescription) inode() *inode {
return fd.vfsfd.Dentry().Impl().(*dentry).inode
}
-// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags.
-func (fd *fileDescription) StatusFlags(ctx context.Context) (uint32, error) {
- return fd.flags, nil
-}
-
-// SetStatusFlags implements vfs.FileDescriptionImpl.SetStatusFlags.
-func (fd *fileDescription) SetStatusFlags(ctx context.Context, flags uint32) error {
- // None of the flags settable by fcntl(F_SETFL) are supported, so this is a
- // no-op.
- return nil
-}
-
// Stat implements vfs.FileDescriptionImpl.Stat.
func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
var stat linux.Statx
diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go
index e7aa3b41b..616fc002a 100644
--- a/pkg/sentry/fsimpl/ext/filesystem.go
+++ b/pkg/sentry/fsimpl/ext/filesystem.go
@@ -275,6 +275,16 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
return vfsd, nil
}
+// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt.
+func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
+ vfsd, inode, err := fs.walk(rp, true)
+ if err != nil {
+ return nil, err
+ }
+ inode.incRef()
+ return vfsd, nil
+}
+
// OpenAt implements vfs.FilesystemImpl.OpenAt.
func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
vfsd, inode, err := fs.walk(rp, false)
@@ -378,7 +388,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
}
// RenameAt implements vfs.FilesystemImpl.RenameAt.
-func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry, opts vfs.RenameOptions) error {
+func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
if rp.Done() {
return syserror.ENOENT
}
@@ -443,6 +453,42 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
return syserror.EROFS
}
+// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
+func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) {
+ _, _, err := fs.walk(rp, false)
+ if err != nil {
+ return nil, err
+ }
+ return nil, syserror.ENOTSUP
+}
+
+// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
+func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) {
+ _, _, err := fs.walk(rp, false)
+ if err != nil {
+ return "", err
+ }
+ return "", syserror.ENOTSUP
+}
+
+// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
+func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+ _, _, err := fs.walk(rp, false)
+ if err != nil {
+ return err
+ }
+ return syserror.ENOTSUP
+}
+
+// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
+func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+ _, _, err := fs.walk(rp, false)
+ if err != nil {
+ return err
+ }
+ return syserror.ENOTSUP
+}
+
// PrependPath implements vfs.FilesystemImpl.PrependPath.
func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
fs.mu.RLock()
diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go
index 24249525c..8608805bf 100644
--- a/pkg/sentry/fsimpl/ext/inode.go
+++ b/pkg/sentry/fsimpl/ext/inode.go
@@ -157,10 +157,7 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v
switch in.impl.(type) {
case *regularFile:
var fd regularFileFD
- fd.flags = flags
- mnt.IncRef()
- vfsd.IncRef()
- fd.vfsfd.Init(&fd, mnt, vfsd)
+ fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
return &fd.vfsfd, nil
case *directory:
// Can't open directories writably. This check is not necessary for a read
@@ -169,10 +166,7 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v
return nil, syserror.EISDIR
}
var fd directoryFD
- fd.flags = flags
- mnt.IncRef()
- vfsd.IncRef()
- fd.vfsfd.Init(&fd, mnt, vfsd)
+ fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
return &fd.vfsfd, nil
case *symlink:
if flags&linux.O_PATH == 0 {
@@ -180,10 +174,7 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v
return nil, syserror.ELOOP
}
var fd symlinkFD
- fd.flags = flags
- mnt.IncRef()
- vfsd.IncRef()
- fd.vfsfd.Init(&fd, mnt, vfsd)
+ fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
return &fd.vfsfd, nil
default:
panic(fmt.Sprintf("unknown inode type: %T", in.impl))
diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD
index 52596c090..39c03ee9d 100644
--- a/pkg/sentry/fsimpl/kernfs/BUILD
+++ b/pkg/sentry/fsimpl/kernfs/BUILD
@@ -25,6 +25,7 @@ go_library(
"inode_impl_util.go",
"kernfs.go",
"slot_list.go",
+ "symlink.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs",
visibility = ["//pkg/sentry:internal"],
@@ -49,6 +50,7 @@ go_test(
deps = [
":kernfs",
"//pkg/abi/linux",
+ "//pkg/fspath",
"//pkg/sentry/context",
"//pkg/sentry/context/contexttest",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
index 30c06baf0..606ca692d 100644
--- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -15,6 +15,8 @@
package kernfs
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -26,7 +28,10 @@ import (
// DynamicBytesFile implements kernfs.Inode and represents a read-only
// file whose contents are backed by a vfs.DynamicBytesSource.
//
-// Must be initialized with Init before first use.
+// Must be instantiated with NewDynamicBytesFile or initialized with Init
+// before first use.
+//
+// +stateify savable
type DynamicBytesFile struct {
InodeAttrs
InodeNoopRefCount
@@ -36,9 +41,14 @@ type DynamicBytesFile struct {
data vfs.DynamicBytesSource
}
-// Init intializes a dynamic bytes file.
-func (f *DynamicBytesFile) Init(creds *auth.Credentials, ino uint64, data vfs.DynamicBytesSource) {
- f.InodeAttrs.Init(creds, ino, linux.ModeRegular|0444)
+var _ Inode = (*DynamicBytesFile)(nil)
+
+// Init initializes a dynamic bytes file.
+func (f *DynamicBytesFile) Init(creds *auth.Credentials, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) {
+ if perm&^linux.PermissionsMask != 0 {
+ panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask))
+ }
+ f.InodeAttrs.Init(creds, ino, linux.ModeRegular|perm)
f.data = data
}
@@ -59,23 +69,21 @@ func (f *DynamicBytesFile) SetStat(*vfs.Filesystem, vfs.SetStatOptions) error {
// DynamicBytesFile.
//
// Must be initialized with Init before first use.
+//
+// +stateify savable
type DynamicBytesFD struct {
vfs.FileDescriptionDefaultImpl
vfs.DynamicBytesFileDescriptionImpl
vfsfd vfs.FileDescription
inode Inode
- flags uint32
}
// Init initializes a DynamicBytesFD.
func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, flags uint32) {
- m.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref.
- d.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref.
- fd.flags = flags
fd.inode = d.Impl().(*Dentry).inode
fd.SetDataSource(data)
- fd.vfsfd.Init(fd, m, d)
+ fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{})
}
// Seek implements vfs.FileDescriptionImpl.Seek.
@@ -117,15 +125,3 @@ func (fd *DynamicBytesFD) SetStat(context.Context, vfs.SetStatOptions) error {
// DynamicBytesFiles are immutable.
return syserror.EPERM
}
-
-// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags.
-func (fd *DynamicBytesFD) StatusFlags(ctx context.Context) (uint32, error) {
- return fd.flags, nil
-}
-
-// SetStatusFlags implements vfs.FileDescriptionImpl.SetStatusFlags.
-func (fd *DynamicBytesFD) SetStatusFlags(ctx context.Context, flags uint32) error {
- // None of the flags settable by fcntl(F_SETFL) are supported, so this is a
- // no-op.
- return nil
-}
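
DynamicBytesFile.Init now takes an explicit permission mask (and panics on non-permission bits), and DynamicBytesFD no longer tracks status flags itself. Below is a minimal sketch of defining such a file from outside this package, modeled on the test file type later in this change; the cpuinfoData type and newCPUInfoFile helper are illustrative, and the snippet assumes imports of bytes, this kernfs package, the sentry context package, and kernel/auth.

// Sketch only: a read-only kernfs file whose contents are generated on read.
type cpuinfoData struct {
	kernfs.DynamicBytesFile
}

// Generate implements vfs.DynamicBytesSource.Generate.
func (*cpuinfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
	_, err := buf.WriteString("processor\t: 0\n")
	return err
}

func newCPUInfoFile(creds *auth.Credentials, ino uint64) *kernfs.Dentry {
	f := &cpuinfoData{}
	// Any bits outside the permission mask now cause Init to panic.
	f.DynamicBytesFile.Init(creds, ino, f, 0444)
	d := &kernfs.Dentry{}
	d.Init(f)
	return d
}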
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
index d6c18937a..bcf069b5f 100644
--- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -39,17 +39,13 @@ type GenericDirectoryFD struct {
vfsfd vfs.FileDescription
children *OrderedChildren
- flags uint32
off int64
}
// Init initializes a GenericDirectoryFD.
func (fd *GenericDirectoryFD) Init(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, flags uint32) {
- m.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref.
- d.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref.
fd.children = children
- fd.flags = flags
- fd.vfsfd.Init(fd, m, d)
+ fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{})
}
// VFSFileDescription returns a pointer to the vfs.FileDescription representing
@@ -156,7 +152,10 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
fd.off++
}
- return nil
+ var err error
+ relOffset := fd.off - int64(len(fd.children.set)) - 2
+ fd.off, err = fd.inode().IterDirents(ctx, cb, fd.off, relOffset)
+ return err
}
// Seek implements vfs.FileDecriptionImpl.Seek.
@@ -180,18 +179,6 @@ func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int
return offset, nil
}
-// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags.
-func (fd *GenericDirectoryFD) StatusFlags(ctx context.Context) (uint32, error) {
- return fd.flags, nil
-}
-
-// SetStatusFlags implements vfs.FileDescriptionImpl.SetStatusFlags.
-func (fd *GenericDirectoryFD) SetStatusFlags(ctx context.Context, flags uint32) error {
- // None of the flags settable by fcntl(F_SETFL) are supported, so this is a
- // no-op.
- return nil
-}
-
// Stat implements vfs.FileDescriptionImpl.Stat.
func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
fs := fd.filesystem()
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index db486b6c1..79759e0fc 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -44,39 +44,37 @@ func (fs *Filesystem) stepExistingLocked(ctx context.Context, rp *vfs.ResolvingP
return nil, err
}
afterSymlink:
+ name := rp.Component()
+ // Revalidation must be skipped if name is "." or ".."; d or its parent
+ // respectively can't be expected to transition from invalidated back to
+ // valid, so detecting invalidation and retrying would loop forever. This
+ // is consistent with Linux: fs/namei.c:walk_component() => lookup_fast()
+ // calls d_revalidate(), but walk_component() => handle_dots() does not.
+ if name == "." {
+ rp.Advance()
+ return vfsd, nil
+ }
+ if name == ".." {
+ nextVFSD, err := rp.ResolveParent(vfsd)
+ if err != nil {
+ return nil, err
+ }
+ rp.Advance()
+ return nextVFSD, nil
+ }
d.dirMu.Lock()
- nextVFSD, err := rp.ResolveComponent(vfsd)
- d.dirMu.Unlock()
+ nextVFSD, err := rp.ResolveChild(vfsd, name)
if err != nil {
+ d.dirMu.Unlock()
return nil, err
}
- if nextVFSD != nil {
- // Cached dentry exists, revalidate.
- next := nextVFSD.Impl().(*Dentry)
- if !next.inode.Valid(ctx) {
- d.dirMu.Lock()
- rp.VirtualFilesystem().ForceDeleteDentry(nextVFSD)
- d.dirMu.Unlock()
- fs.deferDecRef(nextVFSD) // Reference from Lookup.
- nextVFSD = nil
- }
- }
- if nextVFSD == nil {
- // Dentry isn't cached; it either doesn't exist or failed
- // revalidation. Attempt to resolve it via Lookup.
- name := rp.Component()
- var err error
- nextVFSD, err = d.inode.Lookup(ctx, name)
- // Reference on nextVFSD dropped by a corresponding Valid.
- if err != nil {
- return nil, err
- }
- d.InsertChild(name, nextVFSD)
+ next, err := fs.revalidateChildLocked(ctx, rp.VirtualFilesystem(), d, name, nextVFSD)
+ d.dirMu.Unlock()
+ if err != nil {
+ return nil, err
}
- next := nextVFSD.Impl().(*Dentry)
-
// Resolve any symlink at current path component.
- if rp.ShouldFollowSymlink() && d.isSymlink() {
+ if rp.ShouldFollowSymlink() && next.isSymlink() {
// TODO: VFS2 needs something extra for /proc/[pid]/fd/ "magic symlinks".
target, err := next.inode.Readlink(ctx)
if err != nil {
@@ -89,7 +87,44 @@ afterSymlink:
}
rp.Advance()
- return nextVFSD, nil
+ return &next.vfsd, nil
+}
+
+// revalidateChildLocked must be called after a call to parent.vfsd.Child(name)
+// or vfs.ResolvingPath.ResolveChild(name) returns childVFSD (which may be
+// nil) to verify that the returned child (or lack thereof) is correct.
+//
+// Preconditions: Filesystem.mu must be locked for at least reading.
+// parent.dirMu must be locked. parent.isDir(). name is not "." or "..".
+//
+// Postconditions: Caller must call fs.processDeferredDecRefs*.
+func (fs *Filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *Dentry, name string, childVFSD *vfs.Dentry) (*Dentry, error) {
+ if childVFSD != nil {
+ // Cached dentry exists, revalidate.
+ child := childVFSD.Impl().(*Dentry)
+ if !child.inode.Valid(ctx) {
+ vfsObj.ForceDeleteDentry(childVFSD)
+ fs.deferDecRef(childVFSD) // Reference from Lookup.
+ childVFSD = nil
+ }
+ }
+ if childVFSD == nil {
+ // Dentry isn't cached; it either doesn't exist or failed
+ // revalidation. Attempt to resolve it via Lookup.
+ //
+ // FIXME(b/144498111): Inode.Lookup() should return *(kernfs.)Dentry,
+ // not *vfs.Dentry, since (kernfs.)Filesystem assumes that all dentries
+ // in the filesystem are (kernfs.)Dentry and performs vfs.DentryImpl
+ // casts accordingly.
+ var err error
+ childVFSD, err = parent.inode.Lookup(ctx, name)
+ if err != nil {
+ return nil, err
+ }
+ // Reference on childVFSD dropped by a corresponding Valid.
+ parent.insertChildLocked(name, childVFSD)
+ }
+ return childVFSD.Impl().(*Dentry), nil
}
// walkExistingLocked resolves rp to an existing file.
@@ -242,6 +277,19 @@ func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
return vfsd, nil
}
+// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt.
+func (fs *Filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
+ fs.mu.RLock()
+ defer fs.processDeferredDecRefs()
+ defer fs.mu.RUnlock()
+ vfsd, _, err := fs.walkParentDirLocked(ctx, rp)
+ if err != nil {
+ return nil, err
+ }
+ vfsd.IncRef() // Ownership transferred to caller.
+ return vfsd, nil
+}
+
// LinkAt implements vfs.FilesystemImpl.LinkAt.
func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
if rp.Done() {
@@ -459,40 +507,42 @@ func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st
}
// RenameAt implements vfs.FilesystemImpl.RenameAt.
-func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry, opts vfs.RenameOptions) error {
- noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0
- exchange := opts.Flags&linux.RENAME_EXCHANGE != 0
- whiteout := opts.Flags&linux.RENAME_WHITEOUT != 0
- if exchange && (noReplace || whiteout) {
- // Can't specify RENAME_NOREPLACE or RENAME_WHITEOUT with RENAME_EXCHANGE.
- return syserror.EINVAL
- }
- if exchange || whiteout {
- // Exchange and Whiteout flags are not supported on kernfs.
+func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
+ // Only RENAME_NOREPLACE is supported.
+ if opts.Flags&^linux.RENAME_NOREPLACE != 0 {
return syserror.EINVAL
}
+ noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0
fs.mu.Lock()
 defer fs.mu.Unlock()
+ // Resolve the destination directory first to verify that it's on this
+ // Mount.
+ dstDirVFSD, dstDirInode, err := fs.walkParentDirLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked()
+ if err != nil {
+ return err
+ }
mnt := rp.Mount()
- if mnt != vd.Mount() {
+ if mnt != oldParentVD.Mount() {
return syserror.EXDEV
}
-
if err := mnt.CheckBeginWrite(); err != nil {
return err
}
defer mnt.EndWrite()
- dstDirVFSD, dstDirInode, err := fs.walkParentDirLocked(ctx, rp)
+ srcDirVFSD := oldParentVD.Dentry()
+ srcDir := srcDirVFSD.Impl().(*Dentry)
+ srcDir.dirMu.Lock()
+ src, err := fs.revalidateChildLocked(ctx, rp.VirtualFilesystem(), srcDir, oldName, srcDirVFSD.Child(oldName))
+ srcDir.dirMu.Unlock()
fs.processDeferredDecRefsLocked()
if err != nil {
return err
}
-
- srcVFSD := vd.Dentry()
- srcDirVFSD := srcVFSD.Parent()
+ srcVFSD := &src.vfsd
// Can we remove the src dentry?
if err := checkDeleteLocked(rp, srcVFSD); err != nil {
@@ -683,6 +733,58 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
return nil
}
+// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
+func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) {
+ fs.mu.RLock()
+ _, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return nil, err
+ }
+ // kernfs currently does not support extended attributes.
+ return nil, syserror.ENOTSUP
+}
+
+// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
+func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) {
+ fs.mu.RLock()
+ _, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return "", err
+ }
+ // kernfs currently does not support extended attributes.
+ return "", syserror.ENOTSUP
+}
+
+// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
+func (fs *Filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+ fs.mu.RLock()
+ _, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return err
+ }
+ // kernfs currently does not support extended attributes.
+ return syserror.ENOTSUP
+}
+
+// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
+func (fs *Filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+ fs.mu.RLock()
+ _, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return err
+ }
+ // kernfs currently does not support extended attributes.
+ return syserror.ENOTSUP
+}
+
// PrependPath implements vfs.FilesystemImpl.PrependPath.
func (fs *Filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
fs.mu.RLock()
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index 7b45b702a..752e0f659 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -139,6 +139,11 @@ func (*InodeNotDirectory) Lookup(ctx context.Context, name string) (*vfs.Dentry,
panic("Lookup called on non-directory inode")
}
+// IterDirents implements Inode.IterDirents.
+func (*InodeNotDirectory) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
+ panic("IterDirents called on non-directory inode")
+}
+
// Valid implements Inode.Valid.
func (*InodeNotDirectory) Valid(context.Context) bool {
return true
@@ -156,6 +161,11 @@ func (*InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*vfs.Dent
return nil, syserror.ENOENT
}
+// IterDirents implements Inode.IterDirents.
+func (*InodeNoDynamicLookup) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+ return offset, nil
+}
+
// Valid implements Inode.Valid.
func (*InodeNoDynamicLookup) Valid(ctx context.Context) bool {
return true
@@ -490,3 +500,13 @@ func (o *OrderedChildren) nthLocked(i int64) *slot {
}
return nil
}
+
+// InodeSymlink partially implements Inode interface for symlinks.
+type InodeSymlink struct {
+ InodeNotDirectory
+}
+
+// Open implements Inode.Open.
+func (InodeSymlink) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+ return nil, syserror.ELOOP
+}
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index bb01c3d01..d69b299ae 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -239,14 +239,22 @@ func (d *Dentry) destroy() {
//
// Precondition: d must represent a directory inode.
func (d *Dentry) InsertChild(name string, child *vfs.Dentry) {
+ d.dirMu.Lock()
+ d.insertChildLocked(name, child)
+ d.dirMu.Unlock()
+}
+
+// insertChildLocked is equivalent to InsertChild, with additional
+// preconditions.
+//
+// Precondition: d.dirMu must be locked.
+func (d *Dentry) insertChildLocked(name string, child *vfs.Dentry) {
if !d.isDir() {
panic(fmt.Sprintf("InsertChild called on non-directory Dentry: %+v.", d))
}
vfsDentry := d.VFSDentry()
vfsDentry.IncRef() // DecRef in child's Dentry.destroy.
- d.dirMu.Lock()
vfsDentry.InsertChild(child, name)
- d.dirMu.Unlock()
}
// The Inode interface maps filesystem-level operations that operate on paths to
@@ -396,6 +404,15 @@ type inodeDynamicLookup interface {
// Valid should return true if this inode is still valid, or needs to
// be resolved again by a call to Lookup.
Valid(ctx context.Context) bool
+
+ // IterDirents is used to iterate over dynamically created entries. It invokes
+ // cb on each entry in the directory represented by the FileDescription.
+ // 'offset' is the offset for the entire IterDirents call, which may include
+ // results from the caller. 'relOffset' is the offset inside the entries
+ // returned by this IterDirents invocation. In other words,
+ // 'offset+relOffset+1' is the value that should be set in vfs.Dirent.NextOff,
+ // while 'relOffset' is the place where iteration should start from.
+ IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error)
}
type inodeSymlink interface {
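
The offset/relOffset split exists because GenericDirectoryFD emits ".", ".." and the static children before delegating to the inode: offset counts everything emitted so far in the whole IterDirents call, while relOffset indexes only this inode's dynamic entries. The following is a minimal sketch of an implementation honoring that contract; dynamicDir, dynamicNames and inoFor are hypothetical, and it assumes the IterDirentsCallback.Handle signature of this era (a bool result meaning "continue iterating").

// Sketch only: a hypothetical inode whose directory entries are dynamic.
func (d *dynamicDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
	names := d.dynamicNames() // hypothetical helper returning entry names
	if relOffset >= int64(len(names)) {
		return offset, nil
	}
	for _, name := range names[relOffset:] {
		ok := cb.Handle(vfs.Dirent{
			Name:    name,
			Type:    linux.DT_DIR,
			Ino:     d.inoFor(name), // hypothetical helper
			NextOff: offset + 1,     // as documented above
		})
		if !ok {
			return offset, nil
		}
		offset++
	}
	return offset, nil
}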
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
index f78bb7b04..4b6b95f5f 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -24,6 +24,7 @@ import (
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/context/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
@@ -58,7 +59,9 @@ func newTestSystem(t *testing.T, rootFn RootDentryFn) *TestSystem {
ctx := contexttest.Context(t)
creds := auth.CredentialsFromContext(ctx)
v := vfs.New()
- v.MustRegisterFilesystemType("testfs", &fsType{rootFn: rootFn})
+ v.MustRegisterFilesystemType("testfs", &fsType{rootFn: rootFn}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
mns, err := v.NewMountNamespace(ctx, creds, "", "testfs", &vfs.GetFilesystemOptions{})
if err != nil {
t.Fatalf("Failed to create testfs root mount: %v", err)
@@ -82,9 +85,9 @@ func newTestSystem(t *testing.T, rootFn RootDentryFn) *TestSystem {
// Precondition: path should be relative path.
func (s *TestSystem) PathOpAtRoot(path string) vfs.PathOperation {
return vfs.PathOperation{
- Root: s.root,
- Start: s.root,
- Pathname: path,
+ Root: s.root,
+ Start: s.root,
+ Path: fspath.Parse(path),
}
}
@@ -132,7 +135,7 @@ type file struct {
func (fs *filesystem) newFile(creds *auth.Credentials, content string) *kernfs.Dentry {
f := &file{}
f.content = content
- f.DynamicBytesFile.Init(creds, fs.NextIno(), f)
+ f.DynamicBytesFile.Init(creds, fs.NextIno(), f, 0777)
d := &kernfs.Dentry{}
d.Init(f)
diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go
new file mode 100644
index 000000000..068063f4e
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/symlink.go
@@ -0,0 +1,45 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+type staticSymlink struct {
+ InodeAttrs
+ InodeNoopRefCount
+ InodeSymlink
+
+ target string
+}
+
+var _ Inode = (*staticSymlink)(nil)
+
+// NewStaticSymlink creates a new symlink file pointing to 'target'.
+func NewStaticSymlink(creds *auth.Credentials, ino uint64, perm linux.FileMode, target string) *Dentry {
+ inode := &staticSymlink{target: target}
+ inode.Init(creds, ino, linux.ModeSymlink|perm)
+
+ d := &Dentry{}
+ d.Init(inode)
+ return d
+}
+
+func (s *staticSymlink) Readlink(_ context.Context) (string, error) {
+ return s.target, nil
+}
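
A sketch of how a kernfs-based filesystem might use the new constructor; the newSymlink helper is illustrative and reuses the Filesystem.NextIno inode allocator seen in kernfs_test.go above.

// Sketch only, from outside this package: build a symlink dentry pointing at
// target, ready to be inserted into a parent directory's children.
func (fs *filesystem) newSymlink(creds *auth.Credentials, target string) *kernfs.Dentry {
	return kernfs.NewStaticSymlink(creds, fs.NextIno(), 0777, target)
}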
diff --git a/pkg/sentry/fsimpl/memfs/filesystem.go b/pkg/sentry/fsimpl/memfs/filesystem.go
deleted file mode 100644
index 1f2a5122a..000000000
--- a/pkg/sentry/fsimpl/memfs/filesystem.go
+++ /dev/null
@@ -1,592 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package memfs
-
-import (
- "fmt"
- "sync/atomic"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/fspath"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
-)
-
-// stepLocked resolves rp.Component() in parent directory vfsd.
-//
-// stepLocked is loosely analogous to fs/namei.c:walk_component().
-//
-// Preconditions: filesystem.mu must be locked. !rp.Done(). inode ==
-// vfsd.Impl().(*dentry).inode.
-func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode) (*vfs.Dentry, *inode, error) {
- if !inode.isDir() {
- return nil, nil, syserror.ENOTDIR
- }
- if err := inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
- return nil, nil, err
- }
-afterSymlink:
- nextVFSD, err := rp.ResolveComponent(vfsd)
- if err != nil {
- return nil, nil, err
- }
- if nextVFSD == nil {
- // Since the Dentry tree is the sole source of truth for memfs, if it's
- // not in the Dentry tree, it doesn't exist.
- return nil, nil, syserror.ENOENT
- }
- nextInode := nextVFSD.Impl().(*dentry).inode
- if symlink, ok := nextInode.impl.(*symlink); ok && rp.ShouldFollowSymlink() {
- // TODO: symlink traversals update access time
- if err := rp.HandleSymlink(symlink.target); err != nil {
- return nil, nil, err
- }
- goto afterSymlink // don't check the current directory again
- }
- rp.Advance()
- return nextVFSD, nextInode, nil
-}
-
-// walkExistingLocked resolves rp to an existing file.
-//
-// walkExistingLocked is loosely analogous to Linux's
-// fs/namei.c:path_lookupat().
-//
-// Preconditions: filesystem.mu must be locked.
-func walkExistingLocked(rp *vfs.ResolvingPath) (*vfs.Dentry, *inode, error) {
- vfsd := rp.Start()
- inode := vfsd.Impl().(*dentry).inode
- for !rp.Done() {
- var err error
- vfsd, inode, err = stepLocked(rp, vfsd, inode)
- if err != nil {
- return nil, nil, err
- }
- }
- if rp.MustBeDir() && !inode.isDir() {
- return nil, nil, syserror.ENOTDIR
- }
- return vfsd, inode, nil
-}
-
-// walkParentDirLocked resolves all but the last path component of rp to an
-// existing directory. It does not check that the returned directory is
-// searchable by the provider of rp.
-//
-// walkParentDirLocked is loosely analogous to Linux's
-// fs/namei.c:path_parentat().
-//
-// Preconditions: filesystem.mu must be locked. !rp.Done().
-func walkParentDirLocked(rp *vfs.ResolvingPath) (*vfs.Dentry, *inode, error) {
- vfsd := rp.Start()
- inode := vfsd.Impl().(*dentry).inode
- for !rp.Final() {
- var err error
- vfsd, inode, err = stepLocked(rp, vfsd, inode)
- if err != nil {
- return nil, nil, err
- }
- }
- if !inode.isDir() {
- return nil, nil, syserror.ENOTDIR
- }
- return vfsd, inode, nil
-}
-
-// checkCreateLocked checks that a file named rp.Component() may be created in
-// directory parentVFSD, then returns rp.Component().
-//
-// Preconditions: filesystem.mu must be locked. parentInode ==
-// parentVFSD.Impl().(*dentry).inode. parentInode.isDir() == true.
-func checkCreateLocked(rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInode *inode) (string, error) {
- if err := parentInode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil {
- return "", err
- }
- pc := rp.Component()
- if pc == "." || pc == ".." {
- return "", syserror.EEXIST
- }
- childVFSD, err := rp.ResolveChild(parentVFSD, pc)
- if err != nil {
- return "", err
- }
- if childVFSD != nil {
- return "", syserror.EEXIST
- }
- if parentVFSD.IsDisowned() {
- return "", syserror.ENOENT
- }
- return pc, nil
-}
-
-// checkDeleteLocked checks that the file represented by vfsd may be deleted.
-func checkDeleteLocked(vfsd *vfs.Dentry) error {
- parentVFSD := vfsd.Parent()
- if parentVFSD == nil {
- return syserror.EBUSY
- }
- if parentVFSD.IsDisowned() {
- return syserror.ENOENT
- }
- return nil
-}
-
-// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
-func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
- fs.mu.RLock()
- defer fs.mu.RUnlock()
- vfsd, inode, err := walkExistingLocked(rp)
- if err != nil {
- return nil, err
- }
- if opts.CheckSearchable {
- if !inode.isDir() {
- return nil, syserror.ENOTDIR
- }
- if err := inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
- return nil, err
- }
- }
- inode.incRef()
- return vfsd, nil
-}
-
-// LinkAt implements vfs.FilesystemImpl.LinkAt.
-func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
- if rp.Done() {
- return syserror.EEXIST
- }
- fs.mu.Lock()
- defer fs.mu.Unlock()
- parentVFSD, parentInode, err := walkParentDirLocked(rp)
- if err != nil {
- return err
- }
- pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
- if err != nil {
- return err
- }
- if rp.Mount() != vd.Mount() {
- return syserror.EXDEV
- }
- if err := rp.Mount().CheckBeginWrite(); err != nil {
- return err
- }
- defer rp.Mount().EndWrite()
- d := vd.Dentry().Impl().(*dentry)
- if d.inode.isDir() {
- return syserror.EPERM
- }
- d.inode.incLinksLocked()
- child := fs.newDentry(d.inode)
- parentVFSD.InsertChild(&child.vfsd, pc)
- parentInode.impl.(*directory).childList.PushBack(child)
- return nil
-}
-
-// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
-func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
- if rp.Done() {
- return syserror.EEXIST
- }
- fs.mu.Lock()
- defer fs.mu.Unlock()
- parentVFSD, parentInode, err := walkParentDirLocked(rp)
- if err != nil {
- return err
- }
- pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
- if err != nil {
- return err
- }
- if err := rp.Mount().CheckBeginWrite(); err != nil {
- return err
- }
- defer rp.Mount().EndWrite()
- child := fs.newDentry(fs.newDirectory(rp.Credentials(), opts.Mode))
- parentVFSD.InsertChild(&child.vfsd, pc)
- parentInode.impl.(*directory).childList.PushBack(child)
- parentInode.incLinksLocked() // from child's ".."
- return nil
-}
-
-// MknodAt implements vfs.FilesystemImpl.MknodAt.
-func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
- if rp.Done() {
- return syserror.EEXIST
- }
- fs.mu.Lock()
- defer fs.mu.Unlock()
- parentVFSD, parentInode, err := walkParentDirLocked(rp)
- if err != nil {
- return err
- }
- pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
- if err != nil {
- return err
- }
- if err := rp.Mount().CheckBeginWrite(); err != nil {
- return err
- }
- defer rp.Mount().EndWrite()
-
- switch opts.Mode.FileType() {
- case 0:
- // "Zero file type is equivalent to type S_IFREG." - mknod(2)
- fallthrough
- case linux.ModeRegular:
- // TODO(b/138862511): Implement.
- return syserror.EINVAL
-
- case linux.ModeNamedPipe:
- child := fs.newDentry(fs.newNamedPipe(rp.Credentials(), opts.Mode))
- parentVFSD.InsertChild(&child.vfsd, pc)
- parentInode.impl.(*directory).childList.PushBack(child)
- return nil
-
- case linux.ModeSocket:
- // TODO(b/138862511): Implement.
- return syserror.EINVAL
-
- case linux.ModeCharacterDevice:
- fallthrough
- case linux.ModeBlockDevice:
- // TODO(b/72101894): We don't support creating block or character
- // devices at the moment.
- //
- // When we start supporting block and character devices, we'll
- // need to check for CAP_MKNOD here.
- return syserror.EPERM
-
- default:
- // "EINVAL - mode requested creation of something other than a
- // regular file, device special file, FIFO or socket." - mknod(2)
- return syserror.EINVAL
- }
-}
-
-// OpenAt implements vfs.FilesystemImpl.OpenAt.
-func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- // Filter out flags that are not supported by memfs. O_DIRECTORY and
- // O_NOFOLLOW have no effect here (they're handled by VFS by setting
- // appropriate bits in rp), but are returned by
- // FileDescriptionImpl.StatusFlags(). O_NONBLOCK is supported only by
- // pipes.
- opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NONBLOCK
-
- if opts.Flags&linux.O_CREAT == 0 {
- fs.mu.RLock()
- defer fs.mu.RUnlock()
- vfsd, inode, err := walkExistingLocked(rp)
- if err != nil {
- return nil, err
- }
- return inode.open(ctx, rp, vfsd, opts.Flags, false)
- }
-
- mustCreate := opts.Flags&linux.O_EXCL != 0
- vfsd := rp.Start()
- inode := vfsd.Impl().(*dentry).inode
- fs.mu.Lock()
- defer fs.mu.Unlock()
- if rp.Done() {
- if rp.MustBeDir() {
- return nil, syserror.EISDIR
- }
- if mustCreate {
- return nil, syserror.EEXIST
- }
- return inode.open(ctx, rp, vfsd, opts.Flags, false)
- }
-afterTrailingSymlink:
- // Walk to the parent directory of the last path component.
- for !rp.Final() {
- var err error
- vfsd, inode, err = stepLocked(rp, vfsd, inode)
- if err != nil {
- return nil, err
- }
- }
- if !inode.isDir() {
- return nil, syserror.ENOTDIR
- }
- // Check for search permission in the parent directory.
- if err := inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
- return nil, err
- }
- // Reject attempts to open directories with O_CREAT.
- if rp.MustBeDir() {
- return nil, syserror.EISDIR
- }
- pc := rp.Component()
- if pc == "." || pc == ".." {
- return nil, syserror.EISDIR
- }
- // Determine whether or not we need to create a file.
- childVFSD, err := rp.ResolveChild(vfsd, pc)
- if err != nil {
- return nil, err
- }
- if childVFSD == nil {
- // Already checked for searchability above; now check for writability.
- if err := inode.checkPermissions(rp.Credentials(), vfs.MayWrite, true); err != nil {
- return nil, err
- }
- if err := rp.Mount().CheckBeginWrite(); err != nil {
- return nil, err
- }
- defer rp.Mount().EndWrite()
- // Create and open the child.
- childInode := fs.newRegularFile(rp.Credentials(), opts.Mode)
- child := fs.newDentry(childInode)
- vfsd.InsertChild(&child.vfsd, pc)
- inode.impl.(*directory).childList.PushBack(child)
- return childInode.open(ctx, rp, &child.vfsd, opts.Flags, true)
- }
- // Open existing file or follow symlink.
- if mustCreate {
- return nil, syserror.EEXIST
- }
- childInode := childVFSD.Impl().(*dentry).inode
- if symlink, ok := childInode.impl.(*symlink); ok && rp.ShouldFollowSymlink() {
- // TODO: symlink traversals update access time
- if err := rp.HandleSymlink(symlink.target); err != nil {
- return nil, err
- }
- // rp.Final() may no longer be true since we now need to resolve the
- // symlink target.
- goto afterTrailingSymlink
- }
- return childInode.open(ctx, rp, childVFSD, opts.Flags, false)
-}
-
-func (i *inode) open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afterCreate bool) (*vfs.FileDescription, error) {
- ats := vfs.AccessTypesForOpenFlags(flags)
- if !afterCreate {
- if err := i.checkPermissions(rp.Credentials(), ats, i.isDir()); err != nil {
- return nil, err
- }
- }
- mnt := rp.Mount()
- switch impl := i.impl.(type) {
- case *regularFile:
- var fd regularFileFD
- fd.flags = flags
- fd.readable = vfs.MayReadFileWithOpenFlags(flags)
- fd.writable = vfs.MayWriteFileWithOpenFlags(flags)
- if fd.writable {
- if err := mnt.CheckBeginWrite(); err != nil {
- return nil, err
- }
- // mnt.EndWrite() is called by regularFileFD.Release().
- }
- mnt.IncRef()
- vfsd.IncRef()
- fd.vfsfd.Init(&fd, mnt, vfsd)
- if flags&linux.O_TRUNC != 0 {
- impl.mu.Lock()
- impl.data = impl.data[:0]
- atomic.StoreInt64(&impl.dataLen, 0)
- impl.mu.Unlock()
- }
- return &fd.vfsfd, nil
- case *directory:
- // Can't open directories writably.
- if ats&vfs.MayWrite != 0 {
- return nil, syserror.EISDIR
- }
- var fd directoryFD
- mnt.IncRef()
- vfsd.IncRef()
- fd.vfsfd.Init(&fd, mnt, vfsd)
- fd.flags = flags
- return &fd.vfsfd, nil
- case *symlink:
- // Can't open symlinks without O_PATH (which is unimplemented).
- return nil, syserror.ELOOP
- case *namedPipe:
- return newNamedPipeFD(ctx, impl, rp, vfsd, flags)
- default:
- panic(fmt.Sprintf("unknown inode type: %T", i.impl))
- }
-}
-
-// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
-func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
- fs.mu.RLock()
- _, inode, err := walkExistingLocked(rp)
- fs.mu.RUnlock()
- if err != nil {
- return "", err
- }
- symlink, ok := inode.impl.(*symlink)
- if !ok {
- return "", syserror.EINVAL
- }
- return symlink.target, nil
-}
-
-// RenameAt implements vfs.FilesystemImpl.RenameAt.
-func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry, opts vfs.RenameOptions) error {
- if rp.Done() {
- return syserror.ENOENT
- }
- fs.mu.Lock()
- defer fs.mu.Unlock()
- parentVFSD, parentInode, err := walkParentDirLocked(rp)
- if err != nil {
- return err
- }
- _, err = checkCreateLocked(rp, parentVFSD, parentInode)
- if err != nil {
- return err
- }
- if err := rp.Mount().CheckBeginWrite(); err != nil {
- return err
- }
- defer rp.Mount().EndWrite()
- // TODO: actually implement RenameAt
- return syserror.EPERM
-}
-
-// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
-func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
- fs.mu.Lock()
- defer fs.mu.Unlock()
- vfsd, inode, err := walkExistingLocked(rp)
- if err != nil {
- return err
- }
- if err := rp.Mount().CheckBeginWrite(); err != nil {
- return err
- }
- defer rp.Mount().EndWrite()
- if err := checkDeleteLocked(vfsd); err != nil {
- return err
- }
- if !inode.isDir() {
- return syserror.ENOTDIR
- }
- if vfsd.HasChildren() {
- return syserror.ENOTEMPTY
- }
- if err := rp.VirtualFilesystem().DeleteDentry(vfs.MountNamespaceFromContext(ctx), vfsd); err != nil {
- return err
- }
- // Remove from parent directory's childList.
- vfsd.Parent().Impl().(*dentry).inode.impl.(*directory).childList.Remove(vfsd.Impl().(*dentry))
- inode.decRef()
- return nil
-}
-
-// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
-func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
- fs.mu.RLock()
- _, _, err := walkExistingLocked(rp)
- fs.mu.RUnlock()
- if err != nil {
- return err
- }
- if opts.Stat.Mask == 0 {
- return nil
- }
- // TODO: implement inode.setStat
- return syserror.EPERM
-}
-
-// StatAt implements vfs.FilesystemImpl.StatAt.
-func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
- fs.mu.RLock()
- _, inode, err := walkExistingLocked(rp)
- fs.mu.RUnlock()
- if err != nil {
- return linux.Statx{}, err
- }
- var stat linux.Statx
- inode.statTo(&stat)
- return stat, nil
-}
-
-// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
-func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
- fs.mu.RLock()
- _, _, err := walkExistingLocked(rp)
- fs.mu.RUnlock()
- if err != nil {
- return linux.Statfs{}, err
- }
- // TODO: actually implement statfs
- return linux.Statfs{}, syserror.ENOSYS
-}
-
-// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
-func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
- if rp.Done() {
- return syserror.EEXIST
- }
- fs.mu.Lock()
- defer fs.mu.Unlock()
- parentVFSD, parentInode, err := walkParentDirLocked(rp)
- if err != nil {
- return err
- }
- pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
- if err != nil {
- return err
- }
- if err := rp.Mount().CheckBeginWrite(); err != nil {
- return err
- }
- defer rp.Mount().EndWrite()
- child := fs.newDentry(fs.newSymlink(rp.Credentials(), target))
- parentVFSD.InsertChild(&child.vfsd, pc)
- parentInode.impl.(*directory).childList.PushBack(child)
- return nil
-}
-
-// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
-func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
- fs.mu.Lock()
- defer fs.mu.Unlock()
- vfsd, inode, err := walkExistingLocked(rp)
- if err != nil {
- return err
- }
- if err := rp.Mount().CheckBeginWrite(); err != nil {
- return err
- }
- defer rp.Mount().EndWrite()
- if err := checkDeleteLocked(vfsd); err != nil {
- return err
- }
- if inode.isDir() {
- return syserror.EISDIR
- }
- if err := rp.VirtualFilesystem().DeleteDentry(vfs.MountNamespaceFromContext(ctx), vfsd); err != nil {
- return err
- }
- // Remove from parent directory's childList.
- vfsd.Parent().Impl().(*dentry).inode.impl.(*directory).childList.Remove(vfsd.Impl().(*dentry))
- inode.decLinksLocked()
- return nil
-}
-
-// PrependPath implements vfs.FilesystemImpl.PrependPath.
-func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
- fs.mu.RLock()
- defer fs.mu.RUnlock()
- return vfs.GenericPrependPath(vfsroot, vd, b)
-}
diff --git a/pkg/sentry/fsimpl/memfs/regular_file.go b/pkg/sentry/fsimpl/memfs/regular_file.go
deleted file mode 100644
index b7f4853b3..000000000
--- a/pkg/sentry/fsimpl/memfs/regular_file.go
+++ /dev/null
@@ -1,154 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package memfs
-
-import (
- "io"
- "sync"
- "sync/atomic"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
-)
-
-type regularFile struct {
- inode inode
-
- mu sync.RWMutex
- data []byte
- // dataLen is len(data), but accessed using atomic memory operations to
- // avoid locking in inode.stat().
- dataLen int64
-}
-
-func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode {
- file := &regularFile{}
- file.inode.init(file, fs, creds, mode)
- file.inode.nlink = 1 // from parent directory
- return &file.inode
-}
-
-type regularFileFD struct {
- fileDescription
-
- // These are immutable.
- readable bool
- writable bool
-
- // off is the file offset. off is accessed using atomic memory operations.
- // offMu serializes operations that may mutate off.
- off int64
- offMu sync.Mutex
-}
-
-// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *regularFileFD) Release() {
- if fd.writable {
- fd.vfsfd.VirtualDentry().Mount().EndWrite()
- }
-}
-
-// PRead implements vfs.FileDescriptionImpl.PRead.
-func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
- if !fd.readable {
- return 0, syserror.EINVAL
- }
- f := fd.inode().impl.(*regularFile)
- f.mu.RLock()
- if offset >= int64(len(f.data)) {
- f.mu.RUnlock()
- return 0, io.EOF
- }
- n, err := dst.CopyOut(ctx, f.data[offset:])
- f.mu.RUnlock()
- return int64(n), err
-}
-
-// Read implements vfs.FileDescriptionImpl.Read.
-func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
- fd.offMu.Lock()
- n, err := fd.PRead(ctx, dst, fd.off, opts)
- fd.off += n
- fd.offMu.Unlock()
- return n, err
-}
-
-// PWrite implements vfs.FileDescriptionImpl.PWrite.
-func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
- if !fd.writable {
- return 0, syserror.EINVAL
- }
- if offset < 0 {
- return 0, syserror.EINVAL
- }
- srclen := src.NumBytes()
- if srclen == 0 {
- return 0, nil
- }
- f := fd.inode().impl.(*regularFile)
- f.mu.Lock()
- end := offset + srclen
- if end < offset {
- // Overflow.
- f.mu.Unlock()
- return 0, syserror.EFBIG
- }
- if end > f.dataLen {
- f.data = append(f.data, make([]byte, end-f.dataLen)...)
- atomic.StoreInt64(&f.dataLen, end)
- }
- n, err := src.CopyIn(ctx, f.data[offset:end])
- f.mu.Unlock()
- return int64(n), err
-}
-
-// Write implements vfs.FileDescriptionImpl.Write.
-func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
- fd.offMu.Lock()
- n, err := fd.PWrite(ctx, src, fd.off, opts)
- fd.off += n
- fd.offMu.Unlock()
- return n, err
-}
-
-// Seek implements vfs.FileDescriptionImpl.Seek.
-func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
- fd.offMu.Lock()
- defer fd.offMu.Unlock()
- switch whence {
- case linux.SEEK_SET:
- // use offset as specified
- case linux.SEEK_CUR:
- offset += fd.off
- case linux.SEEK_END:
- offset += atomic.LoadInt64(&fd.inode().impl.(*regularFile).dataLen)
- default:
- return 0, syserror.EINVAL
- }
- if offset < 0 {
- return 0, syserror.EINVAL
- }
- fd.off = offset
- return offset, nil
-}
-
-// Sync implements vfs.FileDescriptionImpl.Sync.
-func (fd *regularFileFD) Sync(ctx context.Context) error {
- return nil
-}
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
index ade6ac946..1f44b3217 100644
--- a/pkg/sentry/fsimpl/proc/BUILD
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -6,15 +6,17 @@ package(licenses = ["notice"])
go_library(
name = "proc",
srcs = [
- "filesystems.go",
+ "filesystem.go",
"loadavg.go",
"meminfo.go",
"mounts.go",
"net.go",
- "proc.go",
"stat.go",
"sys.go",
"task.go",
+ "task_files.go",
+ "tasks.go",
+ "tasks_files.go",
"version.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/proc",
@@ -24,8 +26,10 @@ go_library(
"//pkg/log",
"//pkg/sentry/context",
"//pkg/sentry/fs",
+ "//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
"//pkg/sentry/limits",
"//pkg/sentry/mm",
"//pkg/sentry/socket",
@@ -34,17 +38,40 @@ go_library(
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
+ "//pkg/syserror",
],
)
go_test(
name = "proc_test",
size = "small",
- srcs = ["net_test.go"],
+ srcs = [
+ "boot_test.go",
+ "net_test.go",
+ "tasks_test.go",
+ ],
embed = [":proc"],
deps = [
"//pkg/abi/linux",
+ "//pkg/cpuid",
+ "//pkg/fspath",
+ "//pkg/memutil",
+ "//pkg/sentry/context",
"//pkg/sentry/context/contexttest",
+ "//pkg/sentry/fs",
"//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/sched",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/loader",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/platform/kvm",
+ "//pkg/sentry/platform/ptrace",
+ "//pkg/sentry/time",
+ "//pkg/sentry/usermem",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
],
)
diff --git a/pkg/sentry/fsimpl/proc/boot_test.go b/pkg/sentry/fsimpl/proc/boot_test.go
new file mode 100644
index 000000000..84a93ee56
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/boot_test.go
@@ -0,0 +1,149 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "flag"
+ "fmt"
+ "os"
+ "runtime"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/memutil"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/sched"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/loader"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/time"
+
+ // Platforms are pluggable.
+ _ "gvisor.dev/gvisor/pkg/sentry/platform/kvm"
+ _ "gvisor.dev/gvisor/pkg/sentry/platform/ptrace"
+)
+
+var (
+ platformFlag = flag.String("platform", "ptrace", "specify which platform to use")
+)
+
+// boot initializes a new bare-bones kernel for testing.
+func boot() (*kernel.Kernel, error) {
+ platformCtr, err := platform.Lookup(*platformFlag)
+ if err != nil {
+ return nil, fmt.Errorf("platform not found: %v", err)
+ }
+ deviceFile, err := platformCtr.OpenDevice()
+ if err != nil {
+ return nil, fmt.Errorf("creating platform: %v", err)
+ }
+ plat, err := platformCtr.New(deviceFile)
+ if err != nil {
+ return nil, fmt.Errorf("creating platform: %v", err)
+ }
+
+ k := &kernel.Kernel{
+ Platform: plat,
+ }
+
+ mf, err := createMemoryFile()
+ if err != nil {
+ return nil, err
+ }
+ k.SetMemoryFile(mf)
+
+ // Pass k as the platform since it is savable, unlike the actual platform.
+ vdso, err := loader.PrepareVDSO(nil, k)
+ if err != nil {
+ return nil, fmt.Errorf("creating vdso: %v", err)
+ }
+
+ // Create timekeeper.
+ tk, err := kernel.NewTimekeeper(k, vdso.ParamPage.FileRange())
+ if err != nil {
+ return nil, fmt.Errorf("creating timekeeper: %v", err)
+ }
+ tk.SetClocks(time.NewCalibratedClocks())
+
+ creds := auth.NewRootCredentials(auth.NewRootUserNamespace())
+
+ // Initialize the Kernel object, which is required by the Context passed
+ // to createVFS in order to mount (among other things) procfs.
+ if err = k.Init(kernel.InitKernelArgs{
+ ApplicationCores: uint(runtime.GOMAXPROCS(-1)),
+ FeatureSet: cpuid.HostFeatureSet(),
+ Timekeeper: tk,
+ RootUserNamespace: creds.UserNamespace,
+ Vdso: vdso,
+ RootUTSNamespace: kernel.NewUTSNamespace("hostname", "domain", creds.UserNamespace),
+ RootIPCNamespace: kernel.NewIPCNamespace(creds.UserNamespace),
+ RootAbstractSocketNamespace: kernel.NewAbstractSocketNamespace(),
+ PIDNamespace: kernel.NewRootPIDNamespace(creds.UserNamespace),
+ }); err != nil {
+ return nil, fmt.Errorf("initializing kernel: %v", err)
+ }
+
+ ctx := k.SupervisorContext()
+
+ // Create a mount namespace without a root, as that is the minimum required
+ // to create the global thread group.
+ mntns, err := fs.NewMountNamespace(ctx, nil)
+ if err != nil {
+ return nil, err
+ }
+ ls, err := limits.NewLinuxLimitSet()
+ if err != nil {
+ return nil, err
+ }
+ tg := k.NewThreadGroup(mntns, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, ls)
+ k.TestOnly_SetGlobalInit(tg)
+
+ return k, nil
+}
+
+// createTask creates a new bare-bones task for tests.
+func createTask(ctx context.Context, name string, tc *kernel.ThreadGroup) (*kernel.Task, error) {
+ k := kernel.KernelFromContext(ctx)
+ config := &kernel.TaskConfig{
+ Kernel: k,
+ ThreadGroup: tc,
+ TaskContext: &kernel.TaskContext{Name: name},
+ Credentials: auth.CredentialsFromContext(ctx),
+ AllowedCPUMask: sched.NewFullCPUSet(k.ApplicationCores()),
+ UTSNamespace: kernel.UTSNamespaceFromContext(ctx),
+ IPCNamespace: kernel.IPCNamespaceFromContext(ctx),
+ AbstractSocketNamespace: kernel.NewAbstractSocketNamespace(),
+ }
+ return k.TaskSet().NewTask(config)
+}
+
+func createMemoryFile() (*pgalloc.MemoryFile, error) {
+ const memfileName = "test-memory"
+ memfd, err := memutil.CreateMemFD(memfileName, 0)
+ if err != nil {
+ return nil, fmt.Errorf("error creating memfd: %v", err)
+ }
+ memfile := os.NewFile(uintptr(memfd), memfileName)
+ mf, err := pgalloc.NewMemoryFile(memfile, pgalloc.MemoryFileOpts{})
+ if err != nil {
+ memfile.Close()
+ return nil, fmt.Errorf("error creating pgalloc.MemoryFile: %v", err)
+ }
+ return mf, nil
+}
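A rough sketch of how a test might drive these helpers; the GlobalInit accessor is an assumption not shown in this diff, and the authoritative users are in tasks_test.go below:

    // Hypothetical test setup (t is a *testing.T).
    k, err := boot()
    if err != nil {
        t.Fatalf("boot() failed: %v", err)
    }
    ctx := k.SupervisorContext()
    // Assumes the kernel exposes the thread group installed via
    // TestOnly_SetGlobalInit through a GlobalInit() accessor.
    task, err := createTask(ctx, "first-task", k.GlobalInit())
    if err != nil {
        t.Fatalf("createTask() failed: %v", err)
    }
    _ = task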
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
new file mode 100644
index 000000000..d09182c77
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -0,0 +1,69 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package proc implements a partial in-memory file system for procfs.
+package proc
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// procFSType is the vfs.FilesystemType for procfs.
+//
+// +stateify savable
+type procFSType struct{}
+
+var _ vfs.FilesystemType = (*procFSType)(nil)
+
+// GetFilesystem implements vfs.FilesystemType.
+func (ft *procFSType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ k := kernel.KernelFromContext(ctx)
+ if k == nil {
+ return nil, nil, fmt.Errorf("procfs requires a kernel")
+ }
+ pidns := kernel.PIDNamespaceFromContext(ctx)
+ if pidns == nil {
+ return nil, nil, fmt.Errorf("procfs requires a PID namespace")
+ }
+
+ procfs := &kernfs.Filesystem{}
+ procfs.VFSFilesystem().Init(vfsObj, procfs)
+
+ _, dentry := newTasksInode(procfs, k, pidns)
+ return procfs.VFSFilesystem(), dentry.VFSDentry(), nil
+}
+
+// dynamicInode is a convenience interface for the common procfs inodes that
+// are backed by a vfs.DynamicBytesSource.
+type dynamicInode interface {
+ kernfs.Inode
+ vfs.DynamicBytesSource
+
+ Init(creds *auth.Credentials, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode)
+}
+
+func newDentry(creds *auth.Credentials, ino uint64, perm linux.FileMode, inode dynamicInode) *kernfs.Dentry {
+ inode.Init(creds, ino, inode, perm)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
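newDentry is the glue between a dynamicInode implementation and kernfs. A sketch of the expected construction pattern, taken from newTasksInode further down (the root credentials and inoGen values are assumptions from that context):

    // Sketch: wrap DynamicBytesSource-backed inodes in kernfs dentries.
    loadavg := newDentry(root, inoGen.NextIno(), 0444, &loadavgData{})
    meminfo := newDentry(root, inoGen.NextIno(), 0444, &meminfoData{k: k})
    // Both dentries are then handed to OrderedChildren.Populate by the
    // enclosing directory inode.
    _, _ = loadavg, meminfo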
diff --git a/pkg/sentry/fsimpl/proc/loadavg.go b/pkg/sentry/fsimpl/proc/loadavg.go
index 9135afef1..5351d86e8 100644
--- a/pkg/sentry/fsimpl/proc/loadavg.go
+++ b/pkg/sentry/fsimpl/proc/loadavg.go
@@ -19,15 +19,17 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
)
// loadavgData backs /proc/loadavg.
//
// +stateify savable
-type loadavgData struct{}
+type loadavgData struct {
+ kernfs.DynamicBytesFile
+}
-var _ vfs.DynamicBytesSource = (*loadavgData)(nil)
+var _ dynamicInode = (*loadavgData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
func (d *loadavgData) Generate(ctx context.Context, buf *bytes.Buffer) error {
diff --git a/pkg/sentry/fsimpl/proc/meminfo.go b/pkg/sentry/fsimpl/proc/meminfo.go
index 9a827cd66..cbdd4f3fc 100644
--- a/pkg/sentry/fsimpl/proc/meminfo.go
+++ b/pkg/sentry/fsimpl/proc/meminfo.go
@@ -19,21 +19,23 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
)
// meminfoData implements vfs.DynamicBytesSource for /proc/meminfo.
//
// +stateify savable
type meminfoData struct {
+ kernfs.DynamicBytesFile
+
// k is the owning Kernel.
k *kernel.Kernel
}
-var _ vfs.DynamicBytesSource = (*meminfoData)(nil)
+var _ dynamicInode = (*meminfoData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
func (d *meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
diff --git a/pkg/sentry/fsimpl/proc/stat.go b/pkg/sentry/fsimpl/proc/stat.go
index 720db3828..50894a534 100644
--- a/pkg/sentry/fsimpl/proc/stat.go
+++ b/pkg/sentry/fsimpl/proc/stat.go
@@ -20,8 +20,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
)
// cpuStats contains the breakdown of CPU time for /proc/stat.
@@ -66,11 +66,13 @@ func (c cpuStats) String() string {
//
// +stateify savable
type statData struct {
+ kernfs.DynamicBytesFile
+
// k is the owning Kernel.
k *kernel.Kernel
}
-var _ vfs.DynamicBytesSource = (*statData)(nil)
+var _ dynamicInode = (*statData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
func (s *statData) Generate(ctx context.Context, buf *bytes.Buffer) error {
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index c46e05c3a..11a64c777 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -15,247 +15,176 @@
package proc
import (
- "bytes"
- "fmt"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/mm"
- "gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
)
-// mapsCommon is embedded by mapsData and smapsData.
-type mapsCommon struct {
- t *kernel.Task
-}
-
-// mm gets the kernel task's MemoryManager. No additional reference is taken on
-// mm here. This is safe because MemoryManager.destroy is required to leave the
-// MemoryManager in a state where it's still usable as a DynamicBytesSource.
-func (md *mapsCommon) mm() *mm.MemoryManager {
- var tmm *mm.MemoryManager
- md.t.WithMuLocked(func(t *kernel.Task) {
- if mm := t.MemoryManager(); mm != nil {
- tmm = mm
- }
- })
- return tmm
-}
-
-// mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps.
+// taskInode represents the inode for a /proc/PID/ directory.
//
// +stateify savable
-type mapsData struct {
- mapsCommon
+type taskInode struct {
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeNoDynamicLookup
+ kernfs.InodeAttrs
+ kernfs.OrderedChildren
+
+ task *kernel.Task
}
-var _ vfs.DynamicBytesSource = (*mapsData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (md *mapsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- if mm := md.mm(); mm != nil {
- mm.ReadMapsDataInto(ctx, buf)
+var _ kernfs.Inode = (*taskInode)(nil)
+
+func newTaskInode(inoGen InoGenerator, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool) *kernfs.Dentry {
+ contents := map[string]*kernfs.Dentry{
+ //"auxv": newAuxvec(t, msrc),
+ //"cmdline": newExecArgInode(t, msrc, cmdlineExecArg),
+ //"comm": newComm(t, msrc),
+ //"environ": newExecArgInode(t, msrc, environExecArg),
+ //"exe": newExe(t, msrc),
+ //"fd": newFdDir(t, msrc),
+ //"fdinfo": newFdInfoDir(t, msrc),
+ //"gid_map": newGIDMap(t, msrc),
+ "io": newTaskOwnedFile(task, inoGen.NextIno(), defaultPermission, newIO(task, isThreadGroup)),
+ "maps": newTaskOwnedFile(task, inoGen.NextIno(), defaultPermission, &mapsData{task: task}),
+ //"mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc),
+ //"mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc),
+ //"ns": newNamespaceDir(t, msrc),
+ "smaps": newTaskOwnedFile(task, inoGen.NextIno(), defaultPermission, &smapsData{task: task}),
+ "stat": newTaskOwnedFile(task, inoGen.NextIno(), defaultPermission, &taskStatData{t: task, pidns: pidns, tgstats: isThreadGroup}),
+ "statm": newTaskOwnedFile(task, inoGen.NextIno(), defaultPermission, &statmData{t: task}),
+ "status": newTaskOwnedFile(task, inoGen.NextIno(), defaultPermission, &statusData{t: task, pidns: pidns}),
+ //"uid_map": newUIDMap(t, msrc),
}
- return nil
-}
-
-// smapsData implements vfs.DynamicBytesSource for /proc/[pid]/smaps.
-//
-// +stateify savable
-type smapsData struct {
- mapsCommon
-}
-
-var _ vfs.DynamicBytesSource = (*smapsData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (sd *smapsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- if mm := sd.mm(); mm != nil {
- mm.ReadSmapsDataInto(ctx, buf)
+ if isThreadGroup {
+ //contents["task"] = p.newSubtasks(t, msrc)
}
- return nil
-}
-
-// +stateify savable
-type taskStatData struct {
- t *kernel.Task
+ //if len(p.cgroupControllers) > 0 {
+ // contents["cgroup"] = newCGroupInode(t, msrc, p.cgroupControllers)
+ //}
- // If tgstats is true, accumulate fault stats (not implemented) and CPU
- // time across all tasks in t's thread group.
- tgstats bool
+ taskInode := &taskInode{task: task}
+ // Note: credentials are overridden by taskOwnedInode.
+ taskInode.InodeAttrs.Init(task.Credentials(), inoGen.NextIno(), linux.ModeDirectory|0555)
- // pidns is the PID namespace associated with the proc filesystem that
- // includes the file using this statData.
- pidns *kernel.PIDNamespace
-}
-
-var _ vfs.DynamicBytesSource = (*taskStatData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (s *taskStatData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- fmt.Fprintf(buf, "%d ", s.pidns.IDOfTask(s.t))
- fmt.Fprintf(buf, "(%s) ", s.t.Name())
- fmt.Fprintf(buf, "%c ", s.t.StateStatus()[0])
- ppid := kernel.ThreadID(0)
- if parent := s.t.Parent(); parent != nil {
- ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
- }
- fmt.Fprintf(buf, "%d ", ppid)
- fmt.Fprintf(buf, "%d ", s.pidns.IDOfProcessGroup(s.t.ThreadGroup().ProcessGroup()))
- fmt.Fprintf(buf, "%d ", s.pidns.IDOfSession(s.t.ThreadGroup().Session()))
- fmt.Fprintf(buf, "0 0 " /* tty_nr tpgid */)
- fmt.Fprintf(buf, "0 " /* flags */)
- fmt.Fprintf(buf, "0 0 0 0 " /* minflt cminflt majflt cmajflt */)
- var cputime usage.CPUStats
- if s.tgstats {
- cputime = s.t.ThreadGroup().CPUStats()
- } else {
- cputime = s.t.CPUStats()
- }
- fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime))
- cputime = s.t.ThreadGroup().JoinedChildCPUStats()
- fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime))
- fmt.Fprintf(buf, "%d %d ", s.t.Priority(), s.t.Niceness())
- fmt.Fprintf(buf, "%d ", s.t.ThreadGroup().Count())
+ inode := &taskOwnedInode{Inode: taskInode, owner: task}
+ dentry := &kernfs.Dentry{}
+ dentry.Init(inode)
- // itrealvalue. Since kernel 2.6.17, this field is no longer
- // maintained, and is hard coded as 0.
- fmt.Fprintf(buf, "0 ")
+ taskInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ links := taskInode.OrderedChildren.Populate(dentry, contents)
+ taskInode.IncLinks(links)
- // Start time is relative to boot time, expressed in clock ticks.
- fmt.Fprintf(buf, "%d ", linux.ClockTFromDuration(s.t.StartTime().Sub(s.t.Kernel().Timekeeper().BootTime())))
+ return dentry
+}
- var vss, rss uint64
- s.t.WithMuLocked(func(t *kernel.Task) {
- if mm := t.MemoryManager(); mm != nil {
- vss = mm.VirtualMemorySize()
- rss = mm.ResidentSetSize()
- }
- })
- fmt.Fprintf(buf, "%d %d ", vss, rss/usermem.PageSize)
+// Valid implements kernfs.inodeDynamicLookup. This inode remains valid as long
+// as the task is still running. Once it has exited, another task with the same
+// PID could replace it.
+func (i *taskInode) Valid(ctx context.Context) bool {
+ return i.task.ExitState() != kernel.TaskExitDead
+}
- // rsslim.
- fmt.Fprintf(buf, "%d ", s.t.ThreadGroup().Limits().Get(limits.Rss).Cur)
+// Open implements kernfs.Inode.
+func (i *taskInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+ fd := &kernfs.GenericDirectoryFD{}
+ fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, flags)
+ return fd.VFSFileDescription(), nil
+}
- fmt.Fprintf(buf, "0 0 0 0 0 " /* startcode endcode startstack kstkesp kstkeip */)
- fmt.Fprintf(buf, "0 0 0 0 0 " /* signal blocked sigignore sigcatch wchan */)
- fmt.Fprintf(buf, "0 0 " /* nswap cnswap */)
- terminationSignal := linux.Signal(0)
- if s.t == s.t.ThreadGroup().Leader() {
- terminationSignal = s.t.ThreadGroup().TerminationSignal()
+// SetStat implements kernfs.Inode.
+func (i *taskInode) SetStat(_ *vfs.Filesystem, opts vfs.SetStatOptions) error {
+ stat := opts.Stat
+ if stat.Mask&linux.STATX_MODE != 0 {
+ return syserror.EPERM
}
- fmt.Fprintf(buf, "%d ", terminationSignal)
- fmt.Fprintf(buf, "0 0 0 " /* processor rt_priority policy */)
- fmt.Fprintf(buf, "0 0 0 " /* delayacct_blkio_ticks guest_time cguest_time */)
- fmt.Fprintf(buf, "0 0 0 0 0 0 0 " /* start_data end_data start_brk arg_start arg_end env_start env_end */)
- fmt.Fprintf(buf, "0\n" /* exit_code */)
-
return nil
}
-// statmData implements vfs.DynamicBytesSource for /proc/[pid]/statm.
-//
-// +stateify savable
-type statmData struct {
- t *kernel.Task
+// taskOwnedInode implements kernfs.Inode and overrides the inode owner with
+// the task's effective user and group.
+type taskOwnedInode struct {
+ kernfs.Inode
+
+ // owner is the task that owns this inode.
+ owner *kernel.Task
}
-var _ vfs.DynamicBytesSource = (*statmData)(nil)
+var _ kernfs.Inode = (*taskOwnedInode)(nil)
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- var vss, rss uint64
- s.t.WithMuLocked(func(t *kernel.Task) {
- if mm := t.MemoryManager(); mm != nil {
- vss = mm.VirtualMemorySize()
- rss = mm.ResidentSetSize()
- }
- })
+func newTaskOwnedFile(task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) *kernfs.Dentry {
+ // Note: credentials are overridden by taskOwnedInode.
+ inode.Init(task.Credentials(), ino, inode, perm)
- fmt.Fprintf(buf, "%d %d 0 0 0 0 0\n", vss/usermem.PageSize, rss/usermem.PageSize)
- return nil
+ taskInode := &taskOwnedInode{Inode: inode, owner: task}
+ d := &kernfs.Dentry{}
+ d.Init(taskInode)
+ return d
}
-// statusData implements vfs.DynamicBytesSource for /proc/[pid]/status.
-//
-// +stateify savable
-type statusData struct {
- t *kernel.Task
- pidns *kernel.PIDNamespace
+// Stat implements kernfs.Inode.
+func (i *taskOwnedInode) Stat(fs *vfs.Filesystem) linux.Statx {
+ stat := i.Inode.Stat(fs)
+ uid, gid := i.getOwner(linux.FileMode(stat.Mode))
+ stat.UID = uint32(uid)
+ stat.GID = uint32(gid)
+ return stat
}
-var _ vfs.DynamicBytesSource = (*statusData)(nil)
+// CheckPermissions implements kernfs.Inode.
+func (i *taskOwnedInode) CheckPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
+ mode := i.Mode()
+ uid, gid := i.getOwner(mode)
+ return vfs.GenericCheckPermissions(
+ creds,
+ ats,
+ mode.FileType() == linux.ModeDirectory,
+ uint16(mode),
+ uid,
+ gid,
+ )
+}
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- fmt.Fprintf(buf, "Name:\t%s\n", s.t.Name())
- fmt.Fprintf(buf, "State:\t%s\n", s.t.StateStatus())
- fmt.Fprintf(buf, "Tgid:\t%d\n", s.pidns.IDOfThreadGroup(s.t.ThreadGroup()))
- fmt.Fprintf(buf, "Pid:\t%d\n", s.pidns.IDOfTask(s.t))
- ppid := kernel.ThreadID(0)
- if parent := s.t.Parent(); parent != nil {
- ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
+func (i *taskOwnedInode) getOwner(mode linux.FileMode) (auth.KUID, auth.KGID) {
+ // By default, set the task owner as the file owner.
+ creds := i.owner.Credentials()
+ uid := creds.EffectiveKUID
+ gid := creds.EffectiveKGID
+
+ // Linux doesn't apply dumpability adjustments to world readable/executable
+ // directories so that applications can stat /proc/PID to determine the
+ // effective UID of a process. See fs/proc/base.c:task_dump_owner.
+ if mode.FileType() == linux.ModeDirectory && mode.Permissions() == 0555 {
+ return uid, gid
}
- fmt.Fprintf(buf, "PPid:\t%d\n", ppid)
- tpid := kernel.ThreadID(0)
- if tracer := s.t.Tracer(); tracer != nil {
- tpid = s.pidns.IDOfTask(tracer)
+
+ // If the task is not dumpable, then root (in the preferred namespace)
+ // owns the file.
+ m := getMM(i.owner)
+ if m == nil {
+ return auth.RootKUID, auth.RootKGID
}
- fmt.Fprintf(buf, "TracerPid:\t%d\n", tpid)
- var fds int
- var vss, rss, data uint64
- s.t.WithMuLocked(func(t *kernel.Task) {
- if fdTable := t.FDTable(); fdTable != nil {
- fds = fdTable.Size()
+ if m.Dumpability() != mm.UserDumpable {
+ uid = auth.RootKUID
+ if kuid := creds.UserNamespace.MapToKUID(auth.RootUID); kuid.Ok() {
+ uid = kuid
}
- if mm := t.MemoryManager(); mm != nil {
- vss = mm.VirtualMemorySize()
- rss = mm.ResidentSetSize()
- data = mm.VirtualDataSize()
+ gid = auth.RootKGID
+ if kgid := creds.UserNamespace.MapToKGID(auth.RootGID); kgid.Ok() {
+ gid = kgid
}
- })
- fmt.Fprintf(buf, "FDSize:\t%d\n", fds)
- fmt.Fprintf(buf, "VmSize:\t%d kB\n", vss>>10)
- fmt.Fprintf(buf, "VmRSS:\t%d kB\n", rss>>10)
- fmt.Fprintf(buf, "VmData:\t%d kB\n", data>>10)
- fmt.Fprintf(buf, "Threads:\t%d\n", s.t.ThreadGroup().Count())
- creds := s.t.Credentials()
- fmt.Fprintf(buf, "CapInh:\t%016x\n", creds.InheritableCaps)
- fmt.Fprintf(buf, "CapPrm:\t%016x\n", creds.PermittedCaps)
- fmt.Fprintf(buf, "CapEff:\t%016x\n", creds.EffectiveCaps)
- fmt.Fprintf(buf, "CapBnd:\t%016x\n", creds.BoundingCaps)
- fmt.Fprintf(buf, "Seccomp:\t%d\n", s.t.SeccompMode())
- return nil
-}
-
-// ioUsage is the /proc/<pid>/io and /proc/<pid>/task/<tid>/io data provider.
-type ioUsage interface {
- // IOUsage returns the io usage data.
- IOUsage() *usage.IO
-}
-
-// +stateify savable
-type ioData struct {
- ioUsage
+ }
+ return uid, gid
}
-var _ vfs.DynamicBytesSource = (*ioData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (i *ioData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- io := usage.IO{}
- io.Accumulate(i.IOUsage())
-
- fmt.Fprintf(buf, "char: %d\n", io.CharsRead)
- fmt.Fprintf(buf, "wchar: %d\n", io.CharsWritten)
- fmt.Fprintf(buf, "syscr: %d\n", io.ReadSyscalls)
- fmt.Fprintf(buf, "syscw: %d\n", io.WriteSyscalls)
- fmt.Fprintf(buf, "read_bytes: %d\n", io.BytesRead)
- fmt.Fprintf(buf, "write_bytes: %d\n", io.BytesWritten)
- fmt.Fprintf(buf, "cancelled_write_bytes: %d\n", io.BytesWriteCancelled)
- return nil
+func newIO(t *kernel.Task, isThreadGroup bool) *ioData {
+ if isThreadGroup {
+ return &ioData{ioUsage: t.ThreadGroup()}
+ }
+ return &ioData{ioUsage: t}
}
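A sketch of the ownership override in action, using only helpers defined in this file; the task and inoGen values are assumed to come from newTaskInode above:

    // Sketch: per-task files are wrapped so Stat() reports task ownership.
    d := newTaskOwnedFile(task, inoGen.NextIno(), defaultPermission,
        &mapsData{task: task})
    // Stat() on the wrapped inode returns the task's effective UID/GID, or
    // root's when the task is not dumpable (see getOwner above).
    _ = d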
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
new file mode 100644
index 000000000..93f0e1aa8
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -0,0 +1,272 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+)
+
+// getMM gets the kernel task's MemoryManager. No additional reference is taken on
+// mm here. This is safe because MemoryManager.destroy is required to leave the
+// MemoryManager in a state where it's still usable as a DynamicBytesSource.
+func getMM(task *kernel.Task) *mm.MemoryManager {
+ var tmm *mm.MemoryManager
+ task.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ tmm = mm
+ }
+ })
+ return tmm
+}
+
+// mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps.
+//
+// +stateify savable
+type mapsData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*mapsData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *mapsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if mm := getMM(d.task); mm != nil {
+ mm.ReadMapsDataInto(ctx, buf)
+ }
+ return nil
+}
+
+// smapsData implements vfs.DynamicBytesSource for /proc/[pid]/smaps.
+//
+// +stateify savable
+type smapsData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*smapsData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *smapsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if mm := getMM(d.task); mm != nil {
+ mm.ReadSmapsDataInto(ctx, buf)
+ }
+ return nil
+}
+
+// +stateify savable
+type taskStatData struct {
+ kernfs.DynamicBytesFile
+
+ t *kernel.Task
+
+ // If tgstats is true, accumulate fault stats (not implemented) and CPU
+ // time across all tasks in t's thread group.
+ tgstats bool
+
+ // pidns is the PID namespace associated with the proc filesystem that
+ // includes the file using this statData.
+ pidns *kernel.PIDNamespace
+}
+
+var _ dynamicInode = (*taskStatData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (s *taskStatData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "%d ", s.pidns.IDOfTask(s.t))
+ fmt.Fprintf(buf, "(%s) ", s.t.Name())
+ fmt.Fprintf(buf, "%c ", s.t.StateStatus()[0])
+ ppid := kernel.ThreadID(0)
+ if parent := s.t.Parent(); parent != nil {
+ ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
+ }
+ fmt.Fprintf(buf, "%d ", ppid)
+ fmt.Fprintf(buf, "%d ", s.pidns.IDOfProcessGroup(s.t.ThreadGroup().ProcessGroup()))
+ fmt.Fprintf(buf, "%d ", s.pidns.IDOfSession(s.t.ThreadGroup().Session()))
+ fmt.Fprintf(buf, "0 0 " /* tty_nr tpgid */)
+ fmt.Fprintf(buf, "0 " /* flags */)
+ fmt.Fprintf(buf, "0 0 0 0 " /* minflt cminflt majflt cmajflt */)
+ var cputime usage.CPUStats
+ if s.tgstats {
+ cputime = s.t.ThreadGroup().CPUStats()
+ } else {
+ cputime = s.t.CPUStats()
+ }
+ fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime))
+ cputime = s.t.ThreadGroup().JoinedChildCPUStats()
+ fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime))
+ fmt.Fprintf(buf, "%d %d ", s.t.Priority(), s.t.Niceness())
+ fmt.Fprintf(buf, "%d ", s.t.ThreadGroup().Count())
+
+ // itrealvalue. Since kernel 2.6.17, this field is no longer
+ // maintained, and is hard coded as 0.
+ fmt.Fprintf(buf, "0 ")
+
+ // Start time is relative to boot time, expressed in clock ticks.
+ fmt.Fprintf(buf, "%d ", linux.ClockTFromDuration(s.t.StartTime().Sub(s.t.Kernel().Timekeeper().BootTime())))
+
+ var vss, rss uint64
+ s.t.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ }
+ })
+ fmt.Fprintf(buf, "%d %d ", vss, rss/usermem.PageSize)
+
+ // rsslim.
+ fmt.Fprintf(buf, "%d ", s.t.ThreadGroup().Limits().Get(limits.Rss).Cur)
+
+ fmt.Fprintf(buf, "0 0 0 0 0 " /* startcode endcode startstack kstkesp kstkeip */)
+ fmt.Fprintf(buf, "0 0 0 0 0 " /* signal blocked sigignore sigcatch wchan */)
+ fmt.Fprintf(buf, "0 0 " /* nswap cnswap */)
+ terminationSignal := linux.Signal(0)
+ if s.t == s.t.ThreadGroup().Leader() {
+ terminationSignal = s.t.ThreadGroup().TerminationSignal()
+ }
+ fmt.Fprintf(buf, "%d ", terminationSignal)
+ fmt.Fprintf(buf, "0 0 0 " /* processor rt_priority policy */)
+ fmt.Fprintf(buf, "0 0 0 " /* delayacct_blkio_ticks guest_time cguest_time */)
+ fmt.Fprintf(buf, "0 0 0 0 0 0 0 " /* start_data end_data start_brk arg_start arg_end env_start env_end */)
+ fmt.Fprintf(buf, "0\n" /* exit_code */)
+
+ return nil
+}
+
+// statmData implements vfs.DynamicBytesSource for /proc/[pid]/statm.
+//
+// +stateify savable
+type statmData struct {
+ kernfs.DynamicBytesFile
+
+ t *kernel.Task
+}
+
+var _ dynamicInode = (*statmData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ var vss, rss uint64
+ s.t.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ }
+ })
+
+ fmt.Fprintf(buf, "%d %d 0 0 0 0 0\n", vss/usermem.PageSize, rss/usermem.PageSize)
+ return nil
+}
+
+// statusData implements vfs.DynamicBytesSource for /proc/[pid]/status.
+//
+// +stateify savable
+type statusData struct {
+ kernfs.DynamicBytesFile
+
+ t *kernel.Task
+ pidns *kernel.PIDNamespace
+}
+
+var _ dynamicInode = (*statusData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "Name:\t%s\n", s.t.Name())
+ fmt.Fprintf(buf, "State:\t%s\n", s.t.StateStatus())
+ fmt.Fprintf(buf, "Tgid:\t%d\n", s.pidns.IDOfThreadGroup(s.t.ThreadGroup()))
+ fmt.Fprintf(buf, "Pid:\t%d\n", s.pidns.IDOfTask(s.t))
+ ppid := kernel.ThreadID(0)
+ if parent := s.t.Parent(); parent != nil {
+ ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
+ }
+ fmt.Fprintf(buf, "PPid:\t%d\n", ppid)
+ tpid := kernel.ThreadID(0)
+ if tracer := s.t.Tracer(); tracer != nil {
+ tpid = s.pidns.IDOfTask(tracer)
+ }
+ fmt.Fprintf(buf, "TracerPid:\t%d\n", tpid)
+ var fds int
+ var vss, rss, data uint64
+ s.t.WithMuLocked(func(t *kernel.Task) {
+ if fdTable := t.FDTable(); fdTable != nil {
+ fds = fdTable.Size()
+ }
+ if mm := t.MemoryManager(); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ data = mm.VirtualDataSize()
+ }
+ })
+ fmt.Fprintf(buf, "FDSize:\t%d\n", fds)
+ fmt.Fprintf(buf, "VmSize:\t%d kB\n", vss>>10)
+ fmt.Fprintf(buf, "VmRSS:\t%d kB\n", rss>>10)
+ fmt.Fprintf(buf, "VmData:\t%d kB\n", data>>10)
+ fmt.Fprintf(buf, "Threads:\t%d\n", s.t.ThreadGroup().Count())
+ creds := s.t.Credentials()
+ fmt.Fprintf(buf, "CapInh:\t%016x\n", creds.InheritableCaps)
+ fmt.Fprintf(buf, "CapPrm:\t%016x\n", creds.PermittedCaps)
+ fmt.Fprintf(buf, "CapEff:\t%016x\n", creds.EffectiveCaps)
+ fmt.Fprintf(buf, "CapBnd:\t%016x\n", creds.BoundingCaps)
+ fmt.Fprintf(buf, "Seccomp:\t%d\n", s.t.SeccompMode())
+ // We unconditionally report a single NUMA node. See
+ // pkg/sentry/syscalls/linux/sys_mempolicy.go.
+ fmt.Fprintf(buf, "Mems_allowed:\t1\n")
+ fmt.Fprintf(buf, "Mems_allowed_list:\t0\n")
+ return nil
+}
+
+// ioUsage is the /proc/<pid>/io and /proc/<pid>/task/<tid>/io data provider.
+type ioUsage interface {
+ // IOUsage returns the io usage data.
+ IOUsage() *usage.IO
+}
+
+// +stateify savable
+type ioData struct {
+ kernfs.DynamicBytesFile
+
+ ioUsage
+}
+
+var _ dynamicInode = (*ioData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (i *ioData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ io := usage.IO{}
+ io.Accumulate(i.IOUsage())
+
+ fmt.Fprintf(buf, "char: %d\n", io.CharsRead)
+ fmt.Fprintf(buf, "wchar: %d\n", io.CharsWritten)
+ fmt.Fprintf(buf, "syscr: %d\n", io.ReadSyscalls)
+ fmt.Fprintf(buf, "syscw: %d\n", io.WriteSyscalls)
+ fmt.Fprintf(buf, "read_bytes: %d\n", io.BytesRead)
+ fmt.Fprintf(buf, "write_bytes: %d\n", io.BytesWritten)
+ fmt.Fprintf(buf, "cancelled_write_bytes: %d\n", io.BytesWriteCancelled)
+ return nil
+}
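These Generate methods are plain vfs.DynamicBytesSource implementations, so they can be exercised directly without going through the VFS layer. A sketch, assuming task and ctx come from the boot_test.go helpers:

    // Sketch: render /proc/[pid]/maps content straight into a buffer.
    var buf bytes.Buffer
    d := mapsData{task: task}
    if err := d.Generate(ctx, &buf); err != nil {
        t.Fatalf("Generate() failed: %v", err)
    }
    t.Logf("maps:\n%s", buf.String())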
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
new file mode 100644
index 000000000..d8f92d52f
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -0,0 +1,218 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "sort"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const (
+ defaultPermission = 0444
+ selfName = "self"
+ threadSelfName = "thread-self"
+)
+
+// InoGenerator generates unique inode numbers for a given filesystem.
+type InoGenerator interface {
+ NextIno() uint64
+}
+
+// tasksInode represents the inode for the /proc/ directory.
+//
+// +stateify savable
+type tasksInode struct {
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeAttrs
+ kernfs.OrderedChildren
+
+ inoGen InoGenerator
+ pidns *kernel.PIDNamespace
+
+ // '/proc/self' and '/proc/thread-self' have custom directory offsets in
+ // Linux. So handle them outside of OrderedChildren.
+ selfSymlink *vfs.Dentry
+ threadSelfSymlink *vfs.Dentry
+}
+
+var _ kernfs.Inode = (*tasksInode)(nil)
+
+func newTasksInode(inoGen InoGenerator, k *kernel.Kernel, pidns *kernel.PIDNamespace) (*tasksInode, *kernfs.Dentry) {
+ root := auth.NewRootCredentials(pidns.UserNamespace())
+ contents := map[string]*kernfs.Dentry{
+ //"cpuinfo": newCPUInfo(ctx, msrc),
+ //"filesystems": seqfile.NewSeqFileInode(ctx, &filesystemsData{}, msrc),
+ "loadavg": newDentry(root, inoGen.NextIno(), defaultPermission, &loadavgData{}),
+ "meminfo": newDentry(root, inoGen.NextIno(), defaultPermission, &meminfoData{k: k}),
+ "mounts": kernfs.NewStaticSymlink(root, inoGen.NextIno(), defaultPermission, "self/mounts"),
+ "stat": newDentry(root, inoGen.NextIno(), defaultPermission, &statData{k: k}),
+ //"uptime": newUptime(ctx, msrc),
+ //"version": newVersionData(root, inoGen.NextIno(), k),
+ "version": newDentry(root, inoGen.NextIno(), defaultPermission, &versionData{k: k}),
+ }
+
+ inode := &tasksInode{
+ pidns: pidns,
+ inoGen: inoGen,
+ selfSymlink: newSelfSymlink(root, inoGen.NextIno(), 0444, pidns).VFSDentry(),
+ threadSelfSymlink: newThreadSelfSymlink(root, inoGen.NextIno(), 0444, pidns).VFSDentry(),
+ }
+ inode.InodeAttrs.Init(root, inoGen.NextIno(), linux.ModeDirectory|0555)
+
+ dentry := &kernfs.Dentry{}
+ dentry.Init(inode)
+
+ inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ links := inode.OrderedChildren.Populate(dentry, contents)
+ inode.IncLinks(links)
+
+ return inode, dentry
+}
+
+// Lookup implements kernfs.inodeDynamicLookup.
+func (i *tasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ // Try to look up a corresponding task.
+ tid, err := strconv.ParseUint(name, 10, 64)
+ if err != nil {
+ // If the name failed to parse as a TID, check whether it is one of the
+ // specially handled files.
+ switch name {
+ case selfName:
+ return i.selfSymlink, nil
+ case threadSelfName:
+ return i.threadSelfSymlink, nil
+ }
+ return nil, syserror.ENOENT
+ }
+
+ task := i.pidns.TaskWithID(kernel.ThreadID(tid))
+ if task == nil {
+ return nil, syserror.ENOENT
+ }
+
+ taskDentry := newTaskInode(i.inoGen, task, i.pidns, true)
+ return taskDentry.VFSDentry(), nil
+}
+
+// Valid implements kernfs.inodeDynamicLookup.
+func (i *tasksInode) Valid(ctx context.Context) bool {
+ return true
+}
+
+// IterDirents implements kernfs.inodeDynamicLookup.
+func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) {
+ // fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256
+ const FIRST_PROCESS_ENTRY = 256
+
+ // Use maxTaskID to shortcut searches that will result in 0 entries.
+ const maxTaskID = kernel.TasksLimit + 1
+ if offset >= maxTaskID {
+ return offset, nil
+ }
+
+ // According to Linux (fs/proc/base.c:proc_pid_readdir()), process directories
+ // start at offset FIRST_PROCESS_ENTRY with '/proc/self', followed by
+ // '/proc/thread-self' and then '/proc/[pid]'.
+ if offset < FIRST_PROCESS_ENTRY {
+ offset = FIRST_PROCESS_ENTRY
+ }
+
+ if offset == FIRST_PROCESS_ENTRY {
+ dirent := vfs.Dirent{
+ Name: selfName,
+ Type: linux.DT_LNK,
+ Ino: i.inoGen.NextIno(),
+ NextOff: offset + 1,
+ }
+ if !cb.Handle(dirent) {
+ return offset, nil
+ }
+ offset++
+ }
+ if offset == FIRST_PROCESS_ENTRY+1 {
+ dirent := vfs.Dirent{
+ Name: threadSelfName,
+ Type: linux.DT_LNK,
+ Ino: i.inoGen.NextIno(),
+ NextOff: offset + 1,
+ }
+ if !cb.Handle(dirent) {
+ return offset, nil
+ }
+ offset++
+ }
+
+ // Collect all thread groups whose TGIDs are greater than or equal to the
+ // TGID implied by the offset. Per Linux, only thread group leaders appear
+ // in directory listings, but non-leader task directories can still be
+ // reached by walking to them directly.
+ var tids []int
+ startTid := offset - FIRST_PROCESS_ENTRY - 2
+ for _, tg := range i.pidns.ThreadGroups() {
+ tid := i.pidns.IDOfThreadGroup(tg)
+ if int64(tid) < startTid {
+ continue
+ }
+ if leader := tg.Leader(); leader != nil {
+ tids = append(tids, int(tid))
+ }
+ }
+
+ if len(tids) == 0 {
+ return offset, nil
+ }
+
+ sort.Ints(tids)
+ for _, tid := range tids {
+ dirent := vfs.Dirent{
+ Name: strconv.FormatUint(uint64(tid), 10),
+ Type: linux.DT_DIR,
+ Ino: i.inoGen.NextIno(),
+ NextOff: FIRST_PROCESS_ENTRY + 2 + int64(tid) + 1,
+ }
+ if !cb.Handle(dirent) {
+ return offset, nil
+ }
+ offset++
+ }
+ return maxTaskID, nil
+}
+
+// Open implements kernfs.Inode.
+func (i *tasksInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+ fd := &kernfs.GenericDirectoryFD{}
+ fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, flags)
+ return fd.VFSFileDescription(), nil
+}
+
+func (i *tasksInode) Stat(vsfs *vfs.Filesystem) linux.Statx {
+ stat := i.InodeAttrs.Stat(vsfs)
+
+ // Add dynamic children to link count.
+ for _, tg := range i.pidns.ThreadGroups() {
+ if leader := tg.Leader(); leader != nil {
+ stat.Nlink++
+ }
+ }
+
+ return stat
+}
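
The IterDirents logic above encodes Linux's /proc readdir offset convention: offsets below 256 are reserved, 256 is 'self', 257 is 'thread-self', and each process directory resumes at 258 plus its PID. A small illustrative sketch of the resume arithmetic (the constant and function names are not part of the change):

    // startTIDForOffset mirrors the startTid computation in IterDirents above:
    // given a resume offset, it returns the first TGID that should be emitted.
    const firstProcessEntry = 256 // "self" is at 256, "thread-self" at 257.

    func startTIDForOffset(offset int64) int64 {
    	if offset < firstProcessEntry+2 {
    		return 0 // still within the "self"/"thread-self" entries
    	}
    	return offset - firstProcessEntry - 2
    }

For example, resuming at offset 261 yields a start TID of 3, so PIDs 1 and 2 are skipped; this matches the "last process" case in the tests below.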
diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go
new file mode 100644
index 000000000..91f30a798
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/tasks_files.go
@@ -0,0 +1,92 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "fmt"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+type selfSymlink struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeSymlink
+
+ pidns *kernel.PIDNamespace
+}
+
+var _ kernfs.Inode = (*selfSymlink)(nil)
+
+func newSelfSymlink(creds *auth.Credentials, ino uint64, perm linux.FileMode, pidns *kernel.PIDNamespace) *kernfs.Dentry {
+ inode := &selfSymlink{pidns: pidns}
+ inode.Init(creds, ino, linux.ModeSymlink|perm)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+func (s *selfSymlink) Readlink(ctx context.Context) (string, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ // Who is reading this link?
+ return "", syserror.EINVAL
+ }
+ tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup())
+ if tgid == 0 {
+ return "", syserror.ENOENT
+ }
+ return strconv.FormatUint(uint64(tgid), 10), nil
+}
+
+type threadSelfSymlink struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeSymlink
+
+ pidns *kernel.PIDNamespace
+}
+
+var _ kernfs.Inode = (*threadSelfSymlink)(nil)
+
+func newThreadSelfSymlink(creds *auth.Credentials, ino uint64, perm linux.FileMode, pidns *kernel.PIDNamespace) *kernfs.Dentry {
+ inode := &threadSelfSymlink{pidns: pidns}
+ inode.Init(creds, ino, linux.ModeSymlink|perm)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+func (s *threadSelfSymlink) Readlink(ctx context.Context) (string, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ // Who is reading this link?
+ return "", syserror.EINVAL
+ }
+ tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup())
+ tid := s.pidns.IDOfTask(t)
+ if tid == 0 || tgid == 0 {
+ return "", syserror.ENOENT
+ }
+ return fmt.Sprintf("%d/task/%d", tgid, tid), nil
+}
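
Both symlinks resolve relative to the reader's own task and PID namespace, so the same dentry produces different targets for different callers. A brief illustration of the expected results (the IDs are hypothetical):

    // Illustrative targets only: for a reader whose thread group ID is 42 and
    // task ID is 43 in the PID namespace owning this /proc instance,
    //
    //	selfSymlink.Readlink(ctx)       == "42"
    //	threadSelfSymlink.Readlink(ctx) == "42/task/43"
    //
    // A context with no task yields EINVAL; a task that is invisible in this
    // PID namespace (ID 0) yields ENOENT.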
diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go
new file mode 100644
index 000000000..ca8c87ec2
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/tasks_test.go
@@ -0,0 +1,555 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "fmt"
+ "math"
+ "path"
+ "strconv"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+var (
+ // By convention, 'self' and 'thread-self' occupy directory offsets 256 and
+ // 257; each dirent's NextOff is its own offset plus 1.
+ selfLink = vfs.Dirent{Type: linux.DT_LNK, NextOff: 256 + 0 + 1}
+ threadSelfLink = vfs.Dirent{Type: linux.DT_LNK, NextOff: 256 + 1 + 1}
+
+ // /proc/[pid] dirent offsets start at 256+2 (after the two symlinks above);
+ // NextOff is that base plus the PID plus 1.
+ proc1 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 1 + 1}
+ proc2 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 2 + 1}
+ proc3 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 3 + 1}
+)
+
+type testIterDirentsCallback struct {
+ dirents []vfs.Dirent
+}
+
+func (t *testIterDirentsCallback) Handle(d vfs.Dirent) bool {
+ t.dirents = append(t.dirents, d)
+ return true
+}
+
+func checkDots(dirs []vfs.Dirent) ([]vfs.Dirent, error) {
+ if got := len(dirs); got < 2 {
+ return dirs, fmt.Errorf("wrong number of dirents, want at least: 2, got: %d: %v", got, dirs)
+ }
+ for i, want := range []string{".", ".."} {
+ if got := dirs[i].Name; got != want {
+ return dirs, fmt.Errorf("wrong name, want: %s, got: %s", want, got)
+ }
+ if got := dirs[i].Type; got != linux.DT_DIR {
+ return dirs, fmt.Errorf("wrong type, want: %d, got: %d", linux.DT_DIR, got)
+ }
+ }
+ return dirs[2:], nil
+}
+
+func checkTasksStaticFiles(gots []vfs.Dirent) ([]vfs.Dirent, error) {
+ wants := map[string]vfs.Dirent{
+ "loadavg": {Type: linux.DT_REG},
+ "meminfo": {Type: linux.DT_REG},
+ "mounts": {Type: linux.DT_LNK},
+ "self": selfLink,
+ "stat": {Type: linux.DT_REG},
+ "thread-self": threadSelfLink,
+ "version": {Type: linux.DT_REG},
+ }
+ return checkFiles(gots, wants)
+}
+
+func checkTaskStaticFiles(gots []vfs.Dirent) ([]vfs.Dirent, error) {
+ wants := map[string]vfs.Dirent{
+ "io": {Type: linux.DT_REG},
+ "maps": {Type: linux.DT_REG},
+ "smaps": {Type: linux.DT_REG},
+ "stat": {Type: linux.DT_REG},
+ "statm": {Type: linux.DT_REG},
+ "status": {Type: linux.DT_REG},
+ }
+ return checkFiles(gots, wants)
+}
+
+func checkFiles(gots []vfs.Dirent, wants map[string]vfs.Dirent) ([]vfs.Dirent, error) {
+ // Go over all files; when there is a match, the file is removed from both
+ // 'gots' and 'wants'. 'wants' is expected to reach zero, since all expected
+ // files must be present. Any files remaining in 'gots' are returned to the
+ // caller, which decides whether they are valid.
+ for i := 0; i < len(gots); i++ {
+ got := gots[i]
+ want, ok := wants[got.Name]
+ if !ok {
+ continue
+ }
+ if want.Type != got.Type {
+ return gots, fmt.Errorf("wrong file type, want: %v, got: %v: %+v", want.Type, got.Type, got)
+ }
+ if want.NextOff != 0 && want.NextOff != got.NextOff {
+ return gots, fmt.Errorf("wrong dirent offset, want: %v, got: %v: %+v", want.NextOff, got.NextOff, got)
+ }
+
+ delete(wants, got.Name)
+ gots = append(gots[0:i], gots[i+1:]...)
+ i--
+ }
+ if len(wants) != 0 {
+ return gots, fmt.Errorf("not all files were found, missing: %+v", wants)
+ }
+ return gots, nil
+}
+
+func setup() (context.Context, *vfs.VirtualFilesystem, vfs.VirtualDentry, error) {
+ k, err := boot()
+ if err != nil {
+ return nil, nil, vfs.VirtualDentry{}, fmt.Errorf("creating kernel: %v", err)
+ }
+
+ ctx := k.SupervisorContext()
+ creds := auth.CredentialsFromContext(ctx)
+
+ vfsObj := vfs.New()
+ vfsObj.MustRegisterFilesystemType("procfs", &procFSType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "procfs", &vfs.GetFilesystemOptions{})
+ if err != nil {
+ return nil, nil, vfs.VirtualDentry{}, fmt.Errorf("NewMountNamespace(): %v", err)
+ }
+ return ctx, vfsObj, mntns.Root(), nil
+}
+
+func TestTasksEmpty(t *testing.T) {
+ ctx, vfsObj, root, err := setup()
+ if err != nil {
+ t.Fatalf("Setup failed: %v", err)
+ }
+ defer root.DecRef()
+
+ fd, err := vfsObj.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse("/")},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt failed: %v", err)
+ }
+
+ cb := testIterDirentsCallback{}
+ if err := fd.Impl().IterDirents(ctx, &cb); err != nil {
+ t.Fatalf("IterDirents(): %v", err)
+ }
+ cb.dirents, err = checkDots(cb.dirents)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ cb.dirents, err = checkTasksStaticFiles(cb.dirents)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ if len(cb.dirents) != 0 {
+ t.Errorf("found more files than expected: %+v", cb.dirents)
+ }
+}
+
+func TestTasks(t *testing.T) {
+ ctx, vfsObj, root, err := setup()
+ if err != nil {
+ t.Fatalf("Setup failed: %v", err)
+ }
+ defer root.DecRef()
+
+ k := kernel.KernelFromContext(ctx)
+ var tasks []*kernel.Task
+ for i := 0; i < 5; i++ {
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ task, err := createTask(ctx, fmt.Sprintf("name-%d", i), tc)
+ if err != nil {
+ t.Fatalf("CreateTask(): %v", err)
+ }
+ tasks = append(tasks, task)
+ }
+
+ fd, err := vfsObj.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse("/")},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt(/) failed: %v", err)
+ }
+
+ cb := testIterDirentsCallback{}
+ if err := fd.Impl().IterDirents(ctx, &cb); err != nil {
+ t.Fatalf("IterDirents(): %v", err)
+ }
+ cb.dirents, err = checkDots(cb.dirents)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ cb.dirents, err = checkTasksStaticFiles(cb.dirents)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ lastPid := 0
+ for _, d := range cb.dirents {
+ pid, err := strconv.Atoi(d.Name)
+ if err != nil {
+ t.Fatalf("Invalid process directory %q", d.Name)
+ }
+ if lastPid > pid {
+ t.Errorf("pids not in order: %v", cb.dirents)
+ }
+ lastPid = pid
+ found := false
+ for _, task := range tasks {
+ if k.TaskSet().Root.IDOfTask(task) == kernel.ThreadID(pid) {
+ found = true
+ }
+ }
+ if !found {
+ t.Errorf("Additional task ID %d listed: %v", pid, tasks)
+ }
+ // Next offset starts at 256+2 ('self' and 'thread-self'), then adds the
+ // PID, and adds 1 for the next offset.
+ if want := int64(256 + 2 + pid + 1); d.NextOff != want {
+ t.Errorf("Wrong dirent offset want: %d got: %d: %+v", want, d.NextOff, d)
+ }
+ }
+
+ // Test lookup.
+ for _, path := range []string{"/1", "/2"} {
+ fd, err := vfsObj.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse(path)},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt(%q) failed: %v", path, err)
+ }
+ buf := make([]byte, 1)
+ bufIOSeq := usermem.BytesIOSequence(buf)
+ if _, err := fd.Read(ctx, bufIOSeq, vfs.ReadOptions{}); err != syserror.EISDIR {
+ t.Errorf("wrong error reading directory: %v", err)
+ }
+ }
+
+ if _, err := vfsObj.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse("/9999")},
+ &vfs.OpenOptions{},
+ ); err != syserror.ENOENT {
+ t.Fatalf("wrong error from vfsfs.OpenAt(/9999): %v", err)
+ }
+}
+
+func TestTasksOffset(t *testing.T) {
+ ctx, vfsObj, root, err := setup()
+ if err != nil {
+ t.Fatalf("Setup failed: %v", err)
+ }
+ defer root.DecRef()
+
+ k := kernel.KernelFromContext(ctx)
+ for i := 0; i < 3; i++ {
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ if _, err := createTask(ctx, fmt.Sprintf("name-%d", i), tc); err != nil {
+ t.Fatalf("CreateTask(): %v", err)
+ }
+ }
+
+ for _, tc := range []struct {
+ name string
+ offset int64
+ wants map[string]vfs.Dirent
+ }{
+ {
+ name: "small offset",
+ offset: 100,
+ wants: map[string]vfs.Dirent{
+ "self": selfLink,
+ "thread-self": threadSelfLink,
+ "1": proc1,
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "offset at start",
+ offset: 256,
+ wants: map[string]vfs.Dirent{
+ "self": selfLink,
+ "thread-self": threadSelfLink,
+ "1": proc1,
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "skip /proc/self",
+ offset: 257,
+ wants: map[string]vfs.Dirent{
+ "thread-self": threadSelfLink,
+ "1": proc1,
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "skip symlinks",
+ offset: 258,
+ wants: map[string]vfs.Dirent{
+ "1": proc1,
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "skip first process",
+ offset: 260,
+ wants: map[string]vfs.Dirent{
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "last process",
+ offset: 261,
+ wants: map[string]vfs.Dirent{
+ "3": proc3,
+ },
+ },
+ {
+ name: "after last",
+ offset: 262,
+ wants: nil,
+ },
+ {
+ name: "TaskLimit+1",
+ offset: kernel.TasksLimit + 1,
+ wants: nil,
+ },
+ {
+ name: "max",
+ offset: math.MaxInt64,
+ wants: nil,
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ fd, err := vfsObj.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse("/")},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt(/) failed: %v", err)
+ }
+ if _, err := fd.Impl().Seek(ctx, tc.offset, linux.SEEK_SET); err != nil {
+ t.Fatalf("Seek(%d, SEEK_SET): %v", tc.offset, err)
+ }
+
+ cb := testIterDirentsCallback{}
+ if err := fd.Impl().IterDirents(ctx, &cb); err != nil {
+ t.Fatalf("IterDirents(): %v", err)
+ }
+ if cb.dirents, err = checkFiles(cb.dirents, tc.wants); err != nil {
+ t.Error(err.Error())
+ }
+ if len(cb.dirents) != 0 {
+ t.Errorf("found more files than expected: %+v", cb.dirents)
+ }
+ })
+ }
+}
+
+func TestTask(t *testing.T) {
+ ctx, vfsObj, root, err := setup()
+ if err != nil {
+ t.Fatalf("Setup failed: %v", err)
+ }
+ defer root.DecRef()
+
+ k := kernel.KernelFromContext(ctx)
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ _, err = createTask(ctx, "name", tc)
+ if err != nil {
+ t.Fatalf("CreateTask(): %v", err)
+ }
+
+ fd, err := vfsObj.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse("/1")},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt(/1) failed: %v", err)
+ }
+
+ cb := testIterDirentsCallback{}
+ if err := fd.Impl().IterDirents(ctx, &cb); err != nil {
+ t.Fatalf("IterDirents(): %v", err)
+ }
+ cb.dirents, err = checkDots(cb.dirents)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ cb.dirents, err = checkTaskStaticFiles(cb.dirents)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ if len(cb.dirents) != 0 {
+ t.Errorf("found more files than expected: %+v", cb.dirents)
+ }
+}
+
+func TestProcSelf(t *testing.T) {
+ ctx, vfsObj, root, err := setup()
+ if err != nil {
+ t.Fatalf("Setup failed: %v", err)
+ }
+ defer root.DecRef()
+
+ k := kernel.KernelFromContext(ctx)
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ task, err := createTask(ctx, "name", tc)
+ if err != nil {
+ t.Fatalf("CreateTask(): %v", err)
+ }
+
+ fd, err := vfsObj.OpenAt(
+ task,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse("/self/"), FollowFinalSymlink: true},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt(/self/) failed: %v", err)
+ }
+
+ cb := testIterDirentsCallback{}
+ if err := fd.Impl().IterDirents(ctx, &cb); err != nil {
+ t.Fatalf("IterDirents(): %v", err)
+ }
+ cb.dirents, err = checkDots(cb.dirents)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ cb.dirents, err = checkTaskStaticFiles(cb.dirents)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ if len(cb.dirents) != 0 {
+ t.Errorf("found more files than expected: %+v", cb.dirents)
+ }
+}
+
+func iterateDir(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, fd *vfs.FileDescription) {
+ t.Logf("Iterating: /proc%s", fd.MappedName(ctx))
+
+ cb := testIterDirentsCallback{}
+ if err := fd.Impl().IterDirents(ctx, &cb); err != nil {
+ t.Fatalf("IterDirents(): %v", err)
+ }
+ var err error
+ cb.dirents, err = checkDots(cb.dirents)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ for _, d := range cb.dirents {
+ childPath := path.Join(fd.MappedName(ctx), d.Name)
+ if d.Type == linux.DT_LNK {
+ link, err := vfsObj.ReadlinkAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse(childPath)},
+ )
+ if err != nil {
+ t.Errorf("vfsfs.ReadlinkAt(%v) failed: %v", childPath, err)
+ } else {
+ t.Logf("Skipping symlink: /proc%s => %s", childPath, link)
+ }
+ continue
+ }
+
+ t.Logf("Opening: /proc%s", childPath)
+ child, err := vfsObj.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse(childPath)},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Errorf("vfsfs.OpenAt(%v) failed: %v", childPath, err)
+ continue
+ }
+ stat, err := child.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Errorf("Stat(%v) failed: %v", childPath, err)
+ }
+ if got := linux.FileMode(stat.Mode).DirentType(); got != d.Type {
+ t.Errorf("wrong file mode, stat: %v, dirent: %v", got, d.Type)
+ }
+ if d.Type == linux.DT_DIR {
+ // Found another dir, let's do it again!
+ iterateDir(ctx, t, vfsObj, root, child)
+ }
+ }
+}
+
+// TestTree iterates all directories and stats every file.
+func TestTree(t *testing.T) {
+ uberCtx, vfsObj, root, err := setup()
+ if err != nil {
+ t.Fatalf("Setup failed: %v", err)
+ }
+ defer root.DecRef()
+
+ k := kernel.KernelFromContext(uberCtx)
+ var tasks []*kernel.Task
+ for i := 0; i < 5; i++ {
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ task, err := createTask(uberCtx, fmt.Sprintf("name-%d", i), tc)
+ if err != nil {
+ t.Fatalf("CreateTask(): %v", err)
+ }
+ tasks = append(tasks, task)
+ }
+
+ ctx := tasks[0]
+ fd, err := vfsObj.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(uberCtx),
+ &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse("/")},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt(/) failed: %v", err)
+ }
+ iterateDir(ctx, t, vfsObj, root, fd)
+}
diff --git a/pkg/sentry/fsimpl/proc/version.go b/pkg/sentry/fsimpl/proc/version.go
index e1643d4e0..367f2396b 100644
--- a/pkg/sentry/fsimpl/proc/version.go
+++ b/pkg/sentry/fsimpl/proc/version.go
@@ -19,19 +19,21 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
)
// versionData implements vfs.DynamicBytesSource for /proc/version.
//
// +stateify savable
type versionData struct {
+ kernfs.DynamicBytesFile
+
// k is the owning Kernel.
k *kernel.Kernel
}
-var _ vfs.DynamicBytesSource = (*versionData)(nil)
+var _ dynamicInode = (*versionData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
func (v *versionData) Generate(ctx context.Context, buf *bytes.Buffer) error {
diff --git a/pkg/sentry/fsimpl/memfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD
index 0cc751eb8..a5b285987 100644
--- a/pkg/sentry/fsimpl/memfs/BUILD
+++ b/pkg/sentry/fsimpl/tmpfs/BUILD
@@ -1,14 +1,13 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
go_template_instance(
name = "dentry_list",
out = "dentry_list.go",
- package = "memfs",
+ package = "tmpfs",
prefix = "dentry",
template = "//pkg/ilist:generic_list",
types = {
@@ -18,25 +17,34 @@ go_template_instance(
)
go_library(
- name = "memfs",
+ name = "tmpfs",
srcs = [
"dentry_list.go",
"directory.go",
"filesystem.go",
- "memfs.go",
"named_pipe.go",
"regular_file.go",
"symlink.go",
+ "tmpfs.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/memfs",
+ importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs",
deps = [
"//pkg/abi/linux",
"//pkg/amutex",
"//pkg/fspath",
+ "//pkg/log",
"//pkg/sentry/arch",
"//pkg/sentry/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/pipe",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/safemem",
+ "//pkg/sentry/usage",
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
"//pkg/syserror",
@@ -48,8 +56,9 @@ go_test(
size = "small",
srcs = ["benchmark_test.go"],
deps = [
- ":memfs",
+ ":tmpfs",
"//pkg/abi/linux",
+ "//pkg/fspath",
"//pkg/refs",
"//pkg/sentry/context",
"//pkg/sentry/context/contexttest",
@@ -62,15 +71,20 @@ go_test(
)
go_test(
- name = "memfs_test",
+ name = "tmpfs_test",
size = "small",
- srcs = ["pipe_test.go"],
- embed = [":memfs"],
+ srcs = [
+ "pipe_test.go",
+ "regular_file_test.go",
+ ],
+ embed = [":tmpfs"],
deps = [
"//pkg/abi/linux",
+ "//pkg/fspath",
"//pkg/sentry/context",
"//pkg/sentry/context/contexttest",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/contexttest",
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
"//pkg/syserror",
diff --git a/pkg/sentry/fsimpl/memfs/benchmark_test.go b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
index 4a7a94a52..d88c83499 100644
--- a/pkg/sentry/fsimpl/memfs/benchmark_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
@@ -21,12 +21,13 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/context/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
_ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
- "gvisor.dev/gvisor/pkg/sentry/fsimpl/memfs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
@@ -175,8 +176,10 @@ func BenchmarkVFS2MemfsStat(b *testing.B) {
// Create VFS.
vfsObj := vfs.New()
- vfsObj.MustRegisterFilesystemType("memfs", memfs.FilesystemType{})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{})
+ vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
if err != nil {
b.Fatalf("failed to create tmpfs root mount: %v", err)
}
@@ -193,9 +196,9 @@ func BenchmarkVFS2MemfsStat(b *testing.B) {
for i := depth; i > 0; i-- {
name := fmt.Sprintf("%d", i)
pop := vfs.PathOperation{
- Root: root,
- Start: vd,
- Pathname: name,
+ Root: root,
+ Start: vd,
+ Path: fspath.Parse(name),
}
if err := vfsObj.MkdirAt(ctx, creds, &pop, &vfs.MkdirOptions{
Mode: 0755,
@@ -216,7 +219,7 @@ func BenchmarkVFS2MemfsStat(b *testing.B) {
fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
Root: root,
Start: vd,
- Pathname: filename,
+ Path: fspath.Parse(filename),
FollowFinalSymlink: true,
}, &vfs.OpenOptions{
Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
@@ -237,7 +240,7 @@ func BenchmarkVFS2MemfsStat(b *testing.B) {
stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{
Root: root,
Start: root,
- Pathname: filePath,
+ Path: fspath.Parse(filePath),
FollowFinalSymlink: true,
}, &vfs.StatOptions{})
if err != nil {
@@ -364,8 +367,10 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
// Create VFS.
vfsObj := vfs.New()
- vfsObj.MustRegisterFilesystemType("memfs", memfs.FilesystemType{})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{})
+ vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
if err != nil {
b.Fatalf("failed to create tmpfs root mount: %v", err)
}
@@ -378,9 +383,9 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
root := mntns.Root()
defer root.DecRef()
pop := vfs.PathOperation{
- Root: root,
- Start: root,
- Pathname: mountPointName,
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(mountPointName),
}
if err := vfsObj.MkdirAt(ctx, creds, &pop, &vfs.MkdirOptions{
Mode: 0755,
@@ -394,7 +399,7 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
}
defer mountPoint.DecRef()
// Create and mount the submount.
- if err := vfsObj.MountAt(ctx, creds, "", &pop, "memfs", &vfs.MountOptions{}); err != nil {
+ if err := vfsObj.MountAt(ctx, creds, "", &pop, "tmpfs", &vfs.MountOptions{}); err != nil {
b.Fatalf("failed to mount tmpfs submount: %v", err)
}
filePathBuilder.WriteString(mountPointName)
@@ -408,9 +413,9 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
for i := depth; i > 0; i-- {
name := fmt.Sprintf("%d", i)
pop := vfs.PathOperation{
- Root: root,
- Start: vd,
- Pathname: name,
+ Root: root,
+ Start: vd,
+ Path: fspath.Parse(name),
}
if err := vfsObj.MkdirAt(ctx, creds, &pop, &vfs.MkdirOptions{
Mode: 0755,
@@ -438,7 +443,7 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
Root: root,
Start: vd,
- Pathname: filename,
+ Path: fspath.Parse(filename),
FollowFinalSymlink: true,
}, &vfs.OpenOptions{
Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
@@ -458,7 +463,7 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{
Root: root,
Start: root,
- Pathname: filePath,
+ Path: fspath.Parse(filePath),
FollowFinalSymlink: true,
}, &vfs.StatOptions{})
if err != nil {
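
The recurring edit in these benchmarks is the vfs.PathOperation migration from a raw Pathname string to a pre-parsed Path built with fspath.Parse, which this series threads through every call site. A sketch of the post-change call-site shape, reusing the names set up in the benchmark above (the literal path is illustrative):

    // Post-change shape: the path is parsed once and carried as fspath.Path.
    pop := vfs.PathOperation{
    	Root:  root,
    	Start: root,
    	Path:  fspath.Parse("subdir"), // previously: Pathname: "subdir"
    }
    if err := vfsObj.MkdirAt(ctx, creds, &pop, &vfs.MkdirOptions{Mode: 0755}); err != nil {
    	b.Fatalf("failed to create directory: %v", err)
    }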
diff --git a/pkg/sentry/fsimpl/memfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go
index 0bd82e480..887ca2619 100644
--- a/pkg/sentry/fsimpl/memfs/directory.go
+++ b/pkg/sentry/fsimpl/tmpfs/directory.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package memfs
+package tmpfs
import (
"gvisor.dev/gvisor/pkg/abi/linux"
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
new file mode 100644
index 000000000..26979729e
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -0,0 +1,698 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Sync implements vfs.FilesystemImpl.Sync.
+func (fs *filesystem) Sync(ctx context.Context) error {
+ // All filesystem state is in-memory.
+ return nil
+}
+
+// stepLocked resolves rp.Component() to an existing file, starting from the
+// given directory.
+//
+// stepLocked is loosely analogous to Linux's fs/namei.c:walk_component().
+//
+// Preconditions: filesystem.mu must be locked. !rp.Done().
+func stepLocked(rp *vfs.ResolvingPath, d *dentry) (*dentry, error) {
+ if !d.inode.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
+ return nil, err
+ }
+afterSymlink:
+ nextVFSD, err := rp.ResolveComponent(&d.vfsd)
+ if err != nil {
+ return nil, err
+ }
+ if nextVFSD == nil {
+ // Since the Dentry tree is the sole source of truth for tmpfs, if it's
+ // not in the Dentry tree, it doesn't exist.
+ return nil, syserror.ENOENT
+ }
+ next := nextVFSD.Impl().(*dentry)
+ if symlink, ok := next.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() {
+ // TODO: symlink traversals update access time
+ if err := rp.HandleSymlink(symlink.target); err != nil {
+ return nil, err
+ }
+ goto afterSymlink // don't check the current directory again
+ }
+ rp.Advance()
+ return next, nil
+}
+
+// walkParentDirLocked resolves all but the last path component of rp to an
+// existing directory, starting from the given directory (which is usually
+// rp.Start().Impl().(*dentry)). It does not check that the returned directory
+// is searchable by the provider of rp.
+//
+// walkParentDirLocked is loosely analogous to Linux's
+// fs/namei.c:path_parentat().
+//
+// Preconditions: filesystem.mu must be locked. !rp.Done().
+func walkParentDirLocked(rp *vfs.ResolvingPath, d *dentry) (*dentry, error) {
+ for !rp.Final() {
+ next, err := stepLocked(rp, d)
+ if err != nil {
+ return nil, err
+ }
+ d = next
+ }
+ if !d.inode.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ return d, nil
+}
+
+// resolveLocked resolves rp to an existing file.
+//
+// resolveLocked is loosely analogous to Linux's fs/namei.c:path_lookupat().
+//
+// Preconditions: filesystem.mu must be locked.
+func resolveLocked(rp *vfs.ResolvingPath) (*dentry, error) {
+ d := rp.Start().Impl().(*dentry)
+ for !rp.Done() {
+ next, err := stepLocked(rp, d)
+ if err != nil {
+ return nil, err
+ }
+ d = next
+ }
+ if rp.MustBeDir() && !d.inode.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ return d, nil
+}
+
+// doCreateAt checks that creating a file at rp is permitted, then invokes
+// create to do so.
+//
+// doCreateAt is loosely analogous to a conjunction of Linux's
+// fs/namei.c:filename_create() and done_path_create().
+//
+// Preconditions: !rp.Done(). For the final path component in rp,
+// !rp.ShouldFollowSymlink().
+func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(parent *dentry, name string) error) error {
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ parent, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry))
+ if err != nil {
+ return err
+ }
+ if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil {
+ return err
+ }
+ name := rp.Component()
+ if name == "." || name == ".." {
+ return syserror.EEXIST
+ }
+ // Call parent.vfsd.Child() instead of stepLocked() or rp.ResolveChild(),
+ // because if the child exists we want to return EEXIST immediately instead
+ // of attempting symlink/mount traversal.
+ if parent.vfsd.Child(name) != nil {
+ return syserror.EEXIST
+ }
+ if !dir && rp.MustBeDir() {
+ return syserror.ENOENT
+ }
+ // In tmpfs, the only way to cause a dentry to be disowned is by removing
+ // it from the filesystem, so this check is equivalent to checking if
+ // parent has been removed.
+ if parent.vfsd.IsDisowned() {
+ return syserror.ENOENT
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ return create(parent, name)
+}
+
+// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
+func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(rp)
+ if err != nil {
+ return nil, err
+ }
+ if opts.CheckSearchable {
+ if !d.inode.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec, true /* isDir */); err != nil {
+ return nil, err
+ }
+ }
+ d.IncRef()
+ return &d.vfsd, nil
+}
+
+// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt.
+func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry))
+ if err != nil {
+ return nil, err
+ }
+ d.IncRef()
+ return &d.vfsd, nil
+}
+
+// LinkAt implements vfs.FilesystemImpl.LinkAt.
+func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
+ return fs.doCreateAt(rp, false /* dir */, func(parent *dentry, name string) error {
+ if rp.Mount() != vd.Mount() {
+ return syserror.EXDEV
+ }
+ d := vd.Dentry().Impl().(*dentry)
+ if d.inode.isDir() {
+ return syserror.EPERM
+ }
+ if d.inode.nlink == 0 {
+ return syserror.ENOENT
+ }
+ if d.inode.nlink == maxLinks {
+ return syserror.EMLINK
+ }
+ d.inode.incLinksLocked()
+ child := fs.newDentry(d.inode)
+ parent.vfsd.InsertChild(&child.vfsd, name)
+ parent.inode.impl.(*directory).childList.PushBack(child)
+ return nil
+ })
+}
+
+// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
+func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
+ return fs.doCreateAt(rp, true /* dir */, func(parent *dentry, name string) error {
+ if parent.inode.nlink == maxLinks {
+ return syserror.EMLINK
+ }
+ parent.inode.incLinksLocked() // from child's ".."
+ child := fs.newDentry(fs.newDirectory(rp.Credentials(), opts.Mode))
+ parent.vfsd.InsertChild(&child.vfsd, name)
+ parent.inode.impl.(*directory).childList.PushBack(child)
+ return nil
+ })
+}
+
+// MknodAt implements vfs.FilesystemImpl.MknodAt.
+func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
+ return fs.doCreateAt(rp, false /* dir */, func(parent *dentry, name string) error {
+ switch opts.Mode.FileType() {
+ case 0, linux.S_IFREG:
+ child := fs.newDentry(fs.newRegularFile(rp.Credentials(), opts.Mode))
+ parent.vfsd.InsertChild(&child.vfsd, name)
+ parent.inode.impl.(*directory).childList.PushBack(child)
+ return nil
+ case linux.S_IFIFO:
+ child := fs.newDentry(fs.newNamedPipe(rp.Credentials(), opts.Mode))
+ parent.vfsd.InsertChild(&child.vfsd, name)
+ parent.inode.impl.(*directory).childList.PushBack(child)
+ return nil
+ case linux.S_IFBLK, linux.S_IFCHR, linux.S_IFSOCK:
+ // Not yet supported.
+ return syserror.EPERM
+ default:
+ return syserror.EINVAL
+ }
+ })
+}
+
+// OpenAt implements vfs.FilesystemImpl.OpenAt.
+func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ if opts.Flags&linux.O_TMPFILE != 0 {
+ // Not yet supported.
+ return nil, syserror.EOPNOTSUPP
+ }
+
+ // Handle O_CREAT and !O_CREAT separately, since in the latter case we
+ // don't need fs.mu for writing.
+ if opts.Flags&linux.O_CREAT == 0 {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(rp)
+ if err != nil {
+ return nil, err
+ }
+ return d.open(ctx, rp, opts.Flags, false /* afterCreate */)
+ }
+
+ mustCreate := opts.Flags&linux.O_EXCL != 0
+ start := rp.Start().Impl().(*dentry)
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ if rp.Done() {
+ // Reject attempts to open directories with O_CREAT.
+ if rp.MustBeDir() {
+ return nil, syserror.EISDIR
+ }
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
+ return start.open(ctx, rp, opts.Flags, false /* afterCreate */)
+ }
+afterTrailingSymlink:
+ parent, err := walkParentDirLocked(rp, start)
+ if err != nil {
+ return nil, err
+ }
+ // Check for search permission in the parent directory.
+ if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
+ return nil, err
+ }
+ // Reject attempts to open directories with O_CREAT.
+ if rp.MustBeDir() {
+ return nil, syserror.EISDIR
+ }
+ name := rp.Component()
+ if name == "." || name == ".." {
+ return nil, syserror.EISDIR
+ }
+ // Determine whether or not we need to create a file.
+ child, err := stepLocked(rp, parent)
+ if err == syserror.ENOENT {
+ // Already checked for searchability above; now check for writability.
+ if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite, true); err != nil {
+ return nil, err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return nil, err
+ }
+ defer rp.Mount().EndWrite()
+ // Create and open the child.
+ child := fs.newDentry(fs.newRegularFile(rp.Credentials(), opts.Mode))
+ parent.vfsd.InsertChild(&child.vfsd, name)
+ parent.inode.impl.(*directory).childList.PushBack(child)
+ return child.open(ctx, rp, opts.Flags, true)
+ }
+ if err != nil {
+ return nil, err
+ }
+ // Do we need to resolve a trailing symlink?
+ if !rp.Done() {
+ start = parent
+ goto afterTrailingSymlink
+ }
+ // Open existing file.
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
+ return child.open(ctx, rp, opts.Flags, false)
+}
+
+func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, flags uint32, afterCreate bool) (*vfs.FileDescription, error) {
+ ats := vfs.AccessTypesForOpenFlags(flags)
+ if !afterCreate {
+ if err := d.inode.checkPermissions(rp.Credentials(), ats, d.inode.isDir()); err != nil {
+ return nil, err
+ }
+ }
+ mnt := rp.Mount()
+ switch impl := d.inode.impl.(type) {
+ case *regularFile:
+ var fd regularFileFD
+ fd.readable = vfs.MayReadFileWithOpenFlags(flags)
+ fd.writable = vfs.MayWriteFileWithOpenFlags(flags)
+ if fd.writable {
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return nil, err
+ }
+ // mnt.EndWrite() is called by regularFileFD.Release().
+ }
+ fd.vfsfd.Init(&fd, flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{})
+ if flags&linux.O_TRUNC != 0 {
+ impl.mu.Lock()
+ impl.data.Truncate(0, impl.memFile)
+ atomic.StoreUint64(&impl.size, 0)
+ impl.mu.Unlock()
+ }
+ return &fd.vfsfd, nil
+ case *directory:
+ // Can't open directories writably.
+ if ats&vfs.MayWrite != 0 {
+ return nil, syserror.EISDIR
+ }
+ var fd directoryFD
+ fd.vfsfd.Init(&fd, flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{})
+ return &fd.vfsfd, nil
+ case *symlink:
+ // Can't open symlinks without O_PATH (which is unimplemented).
+ return nil, syserror.ELOOP
+ case *namedPipe:
+ return newNamedPipeFD(ctx, impl, rp, &d.vfsd, flags)
+ default:
+ panic(fmt.Sprintf("unknown inode type: %T", d.inode.impl))
+ }
+}
+
+// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
+func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(rp)
+ if err != nil {
+ return "", err
+ }
+ symlink, ok := d.inode.impl.(*symlink)
+ if !ok {
+ return "", syserror.EINVAL
+ }
+ return symlink.target, nil
+}
+
+// RenameAt implements vfs.FilesystemImpl.RenameAt.
+func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
+ if opts.Flags != 0 {
+ // TODO(b/145974740): Support renameat2 flags.
+ return syserror.EINVAL
+ }
+
+ // Resolve newParent first to verify that it's on this Mount.
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ newParent, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry))
+ if err != nil {
+ return err
+ }
+ newName := rp.Component()
+ if newName == "." || newName == ".." {
+ return syserror.EBUSY
+ }
+ mnt := rp.Mount()
+ if mnt != oldParentVD.Mount() {
+ return syserror.EXDEV
+ }
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+
+ oldParent := oldParentVD.Dentry().Impl().(*dentry)
+ if err := oldParent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil {
+ return err
+ }
+ // Call vfs.Dentry.Child() instead of stepLocked() or rp.ResolveChild(),
+ // because if the existing child is a symlink or mount point then we want
+ // to rename over it rather than follow it.
+ renamedVFSD := oldParent.vfsd.Child(oldName)
+ if renamedVFSD == nil {
+ return syserror.ENOENT
+ }
+ renamed := renamedVFSD.Impl().(*dentry)
+ if renamed.inode.isDir() {
+ if renamed == newParent || renamedVFSD.IsAncestorOf(&newParent.vfsd) {
+ return syserror.EINVAL
+ }
+ if oldParent != newParent {
+ // Writability is needed to change renamed's "..".
+ if err := renamed.inode.checkPermissions(rp.Credentials(), vfs.MayWrite, true /* isDir */); err != nil {
+ return err
+ }
+ }
+ } else {
+ if opts.MustBeDir || rp.MustBeDir() {
+ return syserror.ENOTDIR
+ }
+ }
+
+ if err := newParent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil {
+ return err
+ }
+ replacedVFSD := newParent.vfsd.Child(newName)
+ var replaced *dentry
+ if replacedVFSD != nil {
+ replaced = replacedVFSD.Impl().(*dentry)
+ if replaced.inode.isDir() {
+ if !renamed.inode.isDir() {
+ return syserror.EISDIR
+ }
+ if replaced.vfsd.HasChildren() {
+ return syserror.ENOTEMPTY
+ }
+ } else {
+ if rp.MustBeDir() {
+ return syserror.ENOTDIR
+ }
+ if renamed.inode.isDir() {
+ return syserror.ENOTDIR
+ }
+ }
+ } else {
+ if renamed.inode.isDir() && newParent.inode.nlink == maxLinks {
+ return syserror.EMLINK
+ }
+ }
+ if newParent.vfsd.IsDisowned() {
+ return syserror.ENOENT
+ }
+
+ // Linux places this check before some of those above; we do it here for
+ // simplicity, under the assumption that applications are not intentionally
+ // doing noop renames expecting them to succeed where non-noop renames
+ // would fail.
+ if renamedVFSD == replacedVFSD {
+ return nil
+ }
+ vfsObj := rp.VirtualFilesystem()
+ oldParentDir := oldParent.inode.impl.(*directory)
+ newParentDir := newParent.inode.impl.(*directory)
+ if err := vfsObj.PrepareRenameDentry(vfs.MountNamespaceFromContext(ctx), renamedVFSD, replacedVFSD); err != nil {
+ return err
+ }
+ if replaced != nil {
+ newParentDir.childList.Remove(replaced)
+ if replaced.inode.isDir() {
+ newParent.inode.decLinksLocked() // from replaced's ".."
+ }
+ replaced.inode.decLinksLocked()
+ }
+ oldParentDir.childList.Remove(renamed)
+ newParentDir.childList.PushBack(renamed)
+ if renamed.inode.isDir() {
+ oldParent.inode.decLinksLocked()
+ newParent.inode.incLinksLocked()
+ }
+ // TODO: update timestamps and parent directory sizes
+ vfsObj.CommitRenameReplaceDentry(renamedVFSD, &newParent.vfsd, newName, replacedVFSD)
+ return nil
+}
+
+// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
+func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ parent, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry))
+ if err != nil {
+ return err
+ }
+ if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil {
+ return err
+ }
+ name := rp.Component()
+ if name == "." {
+ return syserror.EINVAL
+ }
+ if name == ".." {
+ return syserror.ENOTEMPTY
+ }
+ childVFSD := parent.vfsd.Child(name)
+ if childVFSD == nil {
+ return syserror.ENOENT
+ }
+ child := childVFSD.Impl().(*dentry)
+ if !child.inode.isDir() {
+ return syserror.ENOTDIR
+ }
+ if childVFSD.HasChildren() {
+ return syserror.ENOTEMPTY
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ vfsObj := rp.VirtualFilesystem()
+ if err := vfsObj.PrepareDeleteDentry(vfs.MountNamespaceFromContext(ctx), childVFSD); err != nil {
+ return err
+ }
+ parent.inode.impl.(*directory).childList.Remove(child)
+ parent.inode.decLinksLocked() // from child's ".."
+ child.inode.decLinksLocked()
+ vfsObj.CommitDeleteDentry(childVFSD)
+ return nil
+}
+
+// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
+func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ _, err := resolveLocked(rp)
+ if err != nil {
+ return err
+ }
+ if opts.Stat.Mask == 0 {
+ return nil
+ }
+ // TODO: implement inode.setStat
+ return syserror.EPERM
+}
+
+// StatAt implements vfs.FilesystemImpl.StatAt.
+func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(rp)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ var stat linux.Statx
+ d.inode.statTo(&stat)
+ return stat, nil
+}
+
+// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
+func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ _, err := resolveLocked(rp)
+ if err != nil {
+ return linux.Statfs{}, err
+ }
+ // TODO: actually implement statfs
+ return linux.Statfs{}, syserror.ENOSYS
+}
+
+// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
+func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
+ return fs.doCreateAt(rp, false /* dir */, func(parent *dentry, name string) error {
+ child := fs.newDentry(fs.newSymlink(rp.Credentials(), target))
+ parent.vfsd.InsertChild(&child.vfsd, name)
+ parent.inode.impl.(*directory).childList.PushBack(child)
+ return nil
+ })
+}
+
+// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
+func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ parent, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry))
+ if err != nil {
+ return err
+ }
+ if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil {
+ return err
+ }
+ name := rp.Component()
+ if name == "." || name == ".." {
+ return syserror.EISDIR
+ }
+ childVFSD := parent.vfsd.Child(name)
+ if childVFSD == nil {
+ return syserror.ENOENT
+ }
+ child := childVFSD.Impl().(*dentry)
+ if child.inode.isDir() {
+ return syserror.EISDIR
+ }
+ if rp.MustBeDir() {
+ return syserror.ENOTDIR
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ vfsObj := rp.VirtualFilesystem()
+ if err := vfsObj.PrepareDeleteDentry(vfs.MountNamespaceFromContext(ctx), childVFSD); err != nil {
+ return err
+ }
+ parent.inode.impl.(*directory).childList.Remove(child)
+ child.inode.decLinksLocked()
+ vfsObj.CommitDeleteDentry(childVFSD)
+ return nil
+}
+
+// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
+func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ _, err := resolveLocked(rp)
+ if err != nil {
+ return nil, err
+ }
+ // TODO(b/127675828): support extended attributes
+ return nil, syserror.ENOTSUP
+}
+
+// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
+func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ _, err := resolveLocked(rp)
+ if err != nil {
+ return "", err
+ }
+ // TODO(b/127675828): support extended attributes
+ return "", syserror.ENOTSUP
+}
+
+// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
+func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ _, err := resolveLocked(rp)
+ if err != nil {
+ return err
+ }
+ // TODO(b/127675828): support extended attributes
+ return syserror.ENOTSUP
+}
+
+// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
+func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ _, err := resolveLocked(rp)
+ if err != nil {
+ return err
+ }
+ // TODO(b/127675828): support extended attributes
+ return syserror.ENOTSUP
+}
+
+// PrependPath implements vfs.FilesystemImpl.PrependPath.
+func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ return vfs.GenericPrependPath(vfsroot, vd, b)
+}
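
Two skeletons recur throughout this file: mutating operations funnel through doCreateAt, while the read-only operations (StatAt, ReadlinkAt, the xattr stubs) all take fs.mu for reading, resolve the path, and then act on the resolved dentry. A minimal sketch of that shared read-only shape, with a hypothetical readOnlyOp name and a placeholder closure standing in for the per-operation logic:

    // Sketch of the shared read-only pattern above; doSomething is a
    // placeholder for per-operation logic such as statTo or reading a
    // symlink target.
    func (fs *filesystem) readOnlyOp(rp *vfs.ResolvingPath, doSomething func(d *dentry) error) error {
    	fs.mu.RLock()
    	defer fs.mu.RUnlock()
    	d, err := resolveLocked(rp)
    	if err != nil {
    		return err
    	}
    	return doSomething(d)
    }

Creation operations differ only in that they lock fs.mu for writing, walk to the parent directory instead, and let doCreateAt perform the existence, permission, and mount-writability checks before invoking the per-operation closure.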
diff --git a/pkg/sentry/fsimpl/memfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
index 91cb4b1fc..40bde54de 100644
--- a/pkg/sentry/fsimpl/memfs/named_pipe.go
+++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package memfs
+package tmpfs
import (
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -55,8 +55,6 @@ func newNamedPipeFD(ctx context.Context, np *namedPipe, rp *vfs.ResolvingPath, v
return nil, err
}
mnt := rp.Mount()
- mnt.IncRef()
- vfsd.IncRef()
- fd.vfsfd.Init(&fd, mnt, vfsd)
+ fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
return &fd.vfsfd, nil
}
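
The two-line deletion above reflects the updated vfs.FileDescription.Init signature used throughout this change: Init now receives the open flags and a FileDescriptionOptions value and, judging by the removed IncRef calls, takes the Mount and Dentry references itself. The resulting call-site shape, as also used by the tmpfs file descriptions in this series:

    // New initialization shape; explicit mnt.IncRef()/vfsd.IncRef() calls are
    // no longer needed at call sites.
    fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{})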
diff --git a/pkg/sentry/fsimpl/memfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
index 5bf527c80..70b42a6ec 100644
--- a/pkg/sentry/fsimpl/memfs/pipe_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
@@ -12,13 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package memfs
+package tmpfs
import (
"bytes"
"testing"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/context/contexttest"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -38,7 +39,7 @@ func TestSeparateFDs(t *testing.T) {
pop := vfs.PathOperation{
Root: root,
Start: root,
- Pathname: fileName,
+ Path: fspath.Parse(fileName),
FollowFinalSymlink: true,
}
rfdchan := make(chan *vfs.FileDescription)
@@ -76,7 +77,7 @@ func TestNonblockingRead(t *testing.T) {
pop := vfs.PathOperation{
Root: root,
Start: root,
- Pathname: fileName,
+ Path: fspath.Parse(fileName),
FollowFinalSymlink: true,
}
openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY | linux.O_NONBLOCK}
@@ -108,7 +109,7 @@ func TestNonblockingWriteError(t *testing.T) {
pop := vfs.PathOperation{
Root: root,
Start: root,
- Pathname: fileName,
+ Path: fspath.Parse(fileName),
FollowFinalSymlink: true,
}
openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY | linux.O_NONBLOCK}
@@ -126,7 +127,7 @@ func TestSingleFD(t *testing.T) {
pop := vfs.PathOperation{
Root: root,
Start: root,
- Pathname: fileName,
+ Path: fspath.Parse(fileName),
FollowFinalSymlink: true,
}
openOpts := vfs.OpenOptions{Flags: linux.O_RDWR}
@@ -151,8 +152,10 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy
// Create VFS.
vfsObj := vfs.New()
- vfsObj.MustRegisterFilesystemType("memfs", FilesystemType{})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{})
+ vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
if err != nil {
t.Fatalf("failed to create tmpfs root mount: %v", err)
}
@@ -160,10 +163,9 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy
// Create the pipe.
root := mntns.Root()
pop := vfs.PathOperation{
- Root: root,
- Start: root,
- Pathname: fileName,
- FollowFinalSymlink: true,
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(fileName),
}
mknodOpts := vfs.MknodOptions{Mode: linux.ModeNamedPipe | 0644}
if err := vfsObj.MknodAt(ctx, creds, &pop, &mknodOpts); err != nil {
@@ -174,7 +176,7 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy
stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{
Root: root,
Start: root,
- Pathname: fileName,
+ Path: fspath.Parse(fileName),
FollowFinalSymlink: true,
}, &vfs.StatOptions{})
if err != nil {
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
new file mode 100644
index 000000000..f51e247a7
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -0,0 +1,357 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "io"
+ "math"
+ "sync"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+type regularFile struct {
+ inode inode
+
+ // memFile is a platform.File used to allocate pages to this regularFile.
+ memFile *pgalloc.MemoryFile
+
+ // mu protects the fields below.
+ mu sync.RWMutex
+
+ // data maps offsets into the file to offsets into memFile that store
+ // the file's data.
+ data fsutil.FileRangeSet
+
+ // size is the size of data, but accessed using atomic memory
+ // operations to avoid locking in inode.stat().
+ size uint64
+
+ // seals represents file seals on this inode.
+ seals uint32
+}
+
+func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode {
+ file := &regularFile{
+ memFile: fs.memFile,
+ }
+ file.inode.init(file, fs, creds, mode)
+ file.inode.nlink = 1 // from parent directory
+ return &file.inode
+}
+
+type regularFileFD struct {
+ fileDescription
+
+ // These are immutable.
+ readable bool
+ writable bool
+
+ // off is the file offset. off is accessed using atomic memory operations.
+ // offMu serializes operations that may mutate off.
+ off int64
+ offMu sync.Mutex
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *regularFileFD) Release() {
+ if fd.writable {
+ fd.vfsfd.VirtualDentry().Mount().EndWrite()
+ }
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ if !fd.readable {
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ f := fd.inode().impl.(*regularFile)
+ rw := getRegularFileReadWriter(f, offset)
+ n, err := dst.CopyOutFrom(ctx, rw)
+ putRegularFileReadWriter(rw)
+ return int64(n), err
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ fd.offMu.Lock()
+ n, err := fd.PRead(ctx, dst, fd.off, opts)
+ fd.off += n
+ fd.offMu.Unlock()
+ return n, err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ if !fd.writable {
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ srclen := src.NumBytes()
+ if srclen == 0 {
+ return 0, nil
+ }
+ f := fd.inode().impl.(*regularFile)
+ end := offset + srclen
+ if end < offset {
+ // Overflow.
+ return 0, syserror.EFBIG
+ }
+ rw := getRegularFileReadWriter(f, offset)
+ n, err := src.CopyInTo(ctx, rw)
+ putRegularFileReadWriter(rw)
+ return n, err
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ fd.offMu.Lock()
+ n, err := fd.PWrite(ctx, src, fd.off, opts)
+ fd.off += n
+ fd.offMu.Unlock()
+ return n, err
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.offMu.Lock()
+ defer fd.offMu.Unlock()
+ switch whence {
+ case linux.SEEK_SET:
+ // use offset as specified
+ case linux.SEEK_CUR:
+ offset += fd.off
+ case linux.SEEK_END:
+ offset += int64(atomic.LoadUint64(&fd.inode().impl.(*regularFile).size))
+ default:
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ fd.off = offset
+ return offset, nil
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync.
+func (fd *regularFileFD) Sync(ctx context.Context) error {
+ return nil
+}
+
+// regularFileReadWriter implements safemem.Reader and safemem.Writer.
+type regularFileReadWriter struct {
+ file *regularFile
+
+ // Offset into the file to read/write at. Note that this may be
+ // different from the FD offset if PRead/PWrite is used.
+ off uint64
+}
+
+var regularFileReadWriterPool = sync.Pool{
+ New: func() interface{} {
+ return &regularFileReadWriter{}
+ },
+}
+
+func getRegularFileReadWriter(file *regularFile, offset int64) *regularFileReadWriter {
+ rw := regularFileReadWriterPool.Get().(*regularFileReadWriter)
+ rw.file = file
+ rw.off = uint64(offset)
+ return rw
+}
+
+func putRegularFileReadWriter(rw *regularFileReadWriter) {
+ rw.file = nil
+ regularFileReadWriterPool.Put(rw)
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ rw.file.mu.RLock()
+
+ // Compute the range to read (limited by file size and overflow-checked).
+ if rw.off >= rw.file.size {
+ rw.file.mu.RUnlock()
+ return 0, io.EOF
+ }
+ end := rw.file.size
+ if rend := rw.off + dsts.NumBytes(); rend > rw.off && rend < end {
+ end = rend
+ }
+
+ var done uint64
+ seg, gap := rw.file.data.Find(uint64(rw.off))
+ for rw.off < end {
+ mr := memmap.MappableRange{uint64(rw.off), uint64(end)}
+ switch {
+ case seg.Ok():
+ // Get internal mappings.
+ ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read)
+ if err != nil {
+ rw.file.mu.RUnlock()
+ return done, err
+ }
+
+ // Copy from internal mappings.
+ n, err := safemem.CopySeq(dsts, ims)
+ done += n
+ rw.off += uint64(n)
+ dsts = dsts.DropFirst64(n)
+ if err != nil {
+ rw.file.mu.RUnlock()
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok():
+ // Tmpfs holes are zero-filled.
+ gapmr := gap.Range().Intersect(mr)
+ dst := dsts.TakeFirst64(gapmr.Length())
+ n, err := safemem.ZeroSeq(dst)
+ done += n
+ rw.off += uint64(n)
+ dsts = dsts.DropFirst64(n)
+ if err != nil {
+ rw.file.mu.RUnlock()
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{}
+ }
+ }
+ rw.file.mu.RUnlock()
+ return done, nil
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ rw.file.mu.Lock()
+
+ // Compute the range to write (overflow-checked).
+ end := rw.off + srcs.NumBytes()
+ if end <= rw.off {
+ end = math.MaxInt64
+ }
+
+ // Check if seals prevent either file growth or all writes.
+ switch {
+ case rw.file.seals&linux.F_SEAL_WRITE != 0: // Write sealed
+ rw.file.mu.Unlock()
+ return 0, syserror.EPERM
+ case end > rw.file.size && rw.file.seals&linux.F_SEAL_GROW != 0: // Grow sealed
+ // When growth is sealed, Linux effectively allows writes which would
+ // normally grow the file to partially succeed up to the current EOF,
+ // rounded down to the page boundary before the EOF.
+ //
+ // This happens because writes (and thus the growth check) for tmpfs
+ // files proceed page-by-page on Linux, and the final write to the page
+ // containing EOF fails, resulting in a partial write up to the start of
+ // that page.
+ //
+ // To emulate this behaviour, artificially truncate the write to the
+ // start of the page containing the current EOF.
+ //
+ // See Linux, mm/filemap.c:generic_perform_write() and
+ // mm/shmem.c:shmem_write_begin().
+ if pgstart := uint64(usermem.Addr(rw.file.size).RoundDown()); end > pgstart {
+ end = pgstart
+ }
+ if end <= rw.off {
+ // Truncation would result in no data being written.
+ rw.file.mu.Unlock()
+ return 0, syserror.EPERM
+ }
+ }
+
+ // Page-aligned mr for when we need to allocate memory. RoundUp can't
+ // overflow since end is an int64.
+ pgstartaddr := usermem.Addr(rw.off).RoundDown()
+ pgendaddr, _ := usermem.Addr(end).RoundUp()
+ pgMR := memmap.MappableRange{uint64(pgstartaddr), uint64(pgendaddr)}
+
+ var (
+ done uint64
+ retErr error
+ )
+ seg, gap := rw.file.data.Find(uint64(rw.off))
+ for rw.off < end {
+ mr := memmap.MappableRange{uint64(rw.off), uint64(end)}
+ switch {
+ case seg.Ok():
+ // Get internal mappings.
+ ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Write)
+ if err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Copy to internal mappings.
+ n, err := safemem.CopySeq(ims, srcs)
+ done += n
+ rw.off += uint64(n)
+ srcs = srcs.DropFirst64(n)
+ if err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok():
+ // Allocate memory for the write.
+ gapMR := gap.Range().Intersect(pgMR)
+ fr, err := rw.file.memFile.Allocate(gapMR.Length(), usage.Tmpfs)
+ if err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Write to that memory as usual.
+ seg, gap = rw.file.data.Insert(gap, gapMR, fr.Start), fsutil.FileRangeGapIterator{}
+ }
+ }
+exitLoop:
+ // If the write ends beyond the file's previous size, it causes the
+ // file to grow.
+ if rw.off > rw.file.size {
+ atomic.StoreUint64(&rw.file.size, rw.off)
+ }
+
+ rw.file.mu.Unlock()
+ return done, retErr
+}
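The F_SEAL_GROW branch in WriteFromBlocks above truncates a growing write to the start of the page containing the current EOF, mirroring Linux's page-by-page write path. A standalone sketch of that truncation, assuming a fixed 4 KiB page size in place of the patch's usermem.Addr.RoundDown (illustration only):

    package example

    const pageSize = 4096 // assumed page size; the patch uses usermem helpers

    // clampGrowSealedWrite restates the truncation above: with F_SEAL_GROW
    // set, a write may proceed only up to the start of the page containing
    // the current EOF; if nothing remains, the caller reports EPERM.
    func clampGrowSealedWrite(off, end, fileSize uint64) (newEnd uint64, ok bool) {
        pgstart := fileSize &^ (pageSize - 1) // round EOF down to a page boundary
        if end > pgstart {
            end = pgstart
        }
        if end <= off {
            return 0, false // no data can be written
        }
        return end, true
    }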
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
new file mode 100644
index 000000000..3731c5b6f
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
@@ -0,0 +1,224 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// newFileFD creates a new file in a new tmpfs mount, and returns the FD. If
+// the returned err is nil, cleanup should be called when the FD is no
+// longer needed.
+func newFileFD(ctx context.Context, filename string) (*vfs.FileDescription, func(), error) {
+ creds := auth.CredentialsFromContext(ctx)
+
+ vfsObj := vfs.New()
+ vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to create tmpfs root mount: %v", err)
+ }
+ root := mntns.Root()
+
+ // Create the file that will be written/read.
+ fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ FollowFinalSymlink: true,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
+ Mode: 0644,
+ })
+ if err != nil {
+ root.DecRef()
+ mntns.DecRef(vfsObj)
+ return nil, nil, fmt.Errorf("failed to create file %q: %v", filename, err)
+ }
+
+ return fd, func() {
+ root.DecRef()
+ mntns.DecRef(vfsObj)
+ }, nil
+}
+
+// Test that we can write some data to a file and read it back.
+func TestSimpleWriteRead(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, "simpleReadWrite")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ // Write.
+ data := []byte("foobarbaz")
+ n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("fd.Write failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.Write got short write length %d, want %d", n, len(data))
+ }
+ if got, want := fd.Impl().(*regularFileFD).off, int64(len(data)); got != want {
+ t.Errorf("fd.Write left offset at %d, want %d", got, want)
+ }
+
+ // Seek back to beginning.
+ if _, err := fd.Seek(ctx, 0, linux.SEEK_SET); err != nil {
+ t.Fatalf("fd.Seek failed: %v", err)
+ }
+ if got, want := fd.Impl().(*regularFileFD).off, int64(0); got != want {
+ t.Errorf("fd.Seek(0) left offset at %d, want %d", got, want)
+ }
+
+ // Read.
+ buf := make([]byte, len(data))
+ n, err = fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ if err != nil && err != io.EOF {
+ t.Fatalf("fd.Read failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.Read got short read length %d, want %d", n, len(data))
+ }
+ if got, want := string(buf), string(data); got != want {
+ t.Errorf("Read got %q want %s", got, want)
+ }
+ if got, want := fd.Impl().(*regularFileFD).off, int64(len(data)); got != want {
+ t.Errorf("fd.Write left offset at %d, want %d", got, want)
+ }
+}
+
+func TestPWrite(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, "PRead")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ // Fill file with 1k 'a's.
+ data := bytes.Repeat([]byte{'a'}, 1000)
+ n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("fd.Write failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.Write got short write length %d, want %d", n, len(data))
+ }
+
+ // Write "gVisor is awesome" at various offsets.
+ buf := []byte("gVisor is awesome")
+ offsets := []int{0, 1, 2, 10, 20, 50, 100, len(data) - 100, len(data) - 1, len(data), len(data) + 1}
+ for _, offset := range offsets {
+ name := fmt.Sprintf("PWrite offset=%d", offset)
+ t.Run(name, func(t *testing.T) {
+ n, err := fd.PWrite(ctx, usermem.BytesIOSequence(buf), int64(offset), vfs.WriteOptions{})
+ if err != nil {
+ t.Errorf("fd.PWrite got err %v want nil", err)
+ }
+ if n != int64(len(buf)) {
+ t.Errorf("fd.PWrite got %d bytes want %d", n, len(buf))
+ }
+
+ // Update data to reflect expected file contents.
+ if len(data) < offset+len(buf) {
+ data = append(data, make([]byte, (offset+len(buf))-len(data))...)
+ }
+ copy(data[offset:], buf)
+
+ // Read the whole file and compare with data.
+ readBuf := make([]byte, len(data))
+ n, err = fd.PRead(ctx, usermem.BytesIOSequence(readBuf), 0, vfs.ReadOptions{})
+ if err != nil {
+ t.Fatalf("fd.PRead failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.PRead got short read length %d, want %d", n, len(data))
+ }
+ if got, want := string(readBuf), string(data); got != want {
+ t.Errorf("PRead got %q want %s", got, want)
+ }
+
+ })
+ }
+}
+
+func TestPRead(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, "PRead")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ // Write 100 sequences of 'gVisor is awesome'.
+ data := bytes.Repeat([]byte("gVisor is awesome"), 100)
+ n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("fd.Write failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.Write got short write length %d, want %d", n, len(data))
+ }
+
+ // Read various sizes from various offsets.
+ sizes := []int{0, 1, 2, 10, 20, 50, 100, 1000}
+ offsets := []int{0, 1, 2, 10, 20, 50, 100, 1000, len(data) - 100, len(data) - 1, len(data), len(data) + 1}
+
+ for _, size := range sizes {
+ for _, offset := range offsets {
+ name := fmt.Sprintf("PRead offset=%d size=%d", offset, size)
+ t.Run(name, func(t *testing.T) {
+ var (
+ wantRead []byte
+ wantErr error
+ )
+ if offset < len(data) {
+ wantRead = data[offset:]
+ } else if size > 0 {
+ wantErr = io.EOF
+ }
+ if offset+size < len(data) {
+ wantRead = wantRead[:size]
+ }
+ buf := make([]byte, size)
+ n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), int64(offset), vfs.ReadOptions{})
+ if err != wantErr {
+ t.Errorf("fd.PRead got err %v want %v", err, wantErr)
+ }
+ if n != int64(len(wantRead)) {
+ t.Errorf("fd.PRead got %d bytes want %d", n, len(wantRead))
+ }
+ if got := string(buf[:n]); got != string(wantRead) {
+ t.Errorf("fd.PRead got %q want %q", got, string(wantRead))
+ }
+ })
+ }
+ }
+}
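The expected results in TestPRead encode pread semantics: a read that starts at or past EOF returns io.EOF when size is non-zero, and a read that straddles EOF is truncated to the remaining data. The same logic, restated as a small pure helper (illustrative only; the test inlines this):

    package example

    import "io"

    // expectedPRead mirrors the wantRead/wantErr computation in TestPRead.
    func expectedPRead(data []byte, offset, size int) ([]byte, error) {
        var want []byte
        var wantErr error
        if offset < len(data) {
            want = data[offset:]
        } else if size > 0 {
            wantErr = io.EOF
        }
        if offset+size < len(data) {
            want = want[:size]
        }
        return want, wantErr
    }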
diff --git a/pkg/sentry/fsimpl/memfs/symlink.go b/pkg/sentry/fsimpl/tmpfs/symlink.go
index b2ac2cbeb..5246aca84 100644
--- a/pkg/sentry/fsimpl/memfs/symlink.go
+++ b/pkg/sentry/fsimpl/tmpfs/symlink.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package memfs
+package tmpfs
import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
diff --git a/pkg/sentry/fsimpl/memfs/memfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index 4cb2a4e0f..7be6faa5b 100644
--- a/pkg/sentry/fsimpl/memfs/memfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -12,29 +12,27 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package memfs provides a filesystem implementation that behaves like tmpfs:
+// Package tmpfs provides a filesystem implementation that behaves like tmpfs:
// the Dentry tree is the sole source of truth for the state of the filesystem.
//
-// memfs is intended primarily to demonstrate filesystem implementation
-// patterns. Real uses cases for an in-memory filesystem should use tmpfs
-// instead.
-//
// Lock order:
//
// filesystem.mu
// regularFileFD.offMu
// regularFile.mu
// inode.mu
-package memfs
+package tmpfs
import (
"fmt"
+ "math"
"sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -46,6 +44,9 @@ type FilesystemType struct{}
type filesystem struct {
vfsfs vfs.Filesystem
+ // memFile is used to allocate pages for regular files.
+ memFile *pgalloc.MemoryFile
+
// mu serializes changes to the Dentry tree.
mu sync.RWMutex
@@ -54,7 +55,13 @@ type filesystem struct {
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
- var fs filesystem
+ memFileProvider := pgalloc.MemoryFileProviderFromContext(ctx)
+ if memFileProvider == nil {
+ panic("MemoryFileProviderFromContext returned nil")
+ }
+ fs := filesystem{
+ memFile: memFileProvider.MemoryFile(),
+ }
fs.vfsfs.Init(vfsObj, &fs)
root := fs.newDentry(fs.newDirectory(creds, 01777))
return &fs.vfsfs, &root.vfsd, nil
@@ -64,12 +71,6 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
func (fs *filesystem) Release() {
}
-// Sync implements vfs.FilesystemImpl.Sync.
-func (fs *filesystem) Sync(ctx context.Context) error {
- // All filesystem state is in-memory.
- return nil
-}
-
// dentry implements vfs.DentryImpl.
type dentry struct {
vfsd vfs.Dentry
@@ -79,11 +80,11 @@ type dentry struct {
// immutable.
inode *inode
- // memfs doesn't count references on dentries; because the dentry tree is
+ // tmpfs doesn't count references on dentries; because the dentry tree is
// the sole source of truth, it is by definition always consistent with the
// state of the filesystem. However, it does count references on inodes,
// because inode resources are released when all references are dropped.
- // (memfs doesn't really have resources to release, but we implement
+ // (tmpfs doesn't really have resources to release, but we implement
// reference counting because tmpfs regular files will.)
// dentryEntry (ugh) links dentries into their parent directory.childList.
@@ -137,6 +138,8 @@ type inode struct {
impl interface{} // immutable
}
+const maxLinks = math.MaxUint32
+
func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode linux.FileMode) {
i.refs = 1
i.mode = uint32(mode)
@@ -147,25 +150,33 @@ func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials,
i.impl = impl
}
-// Preconditions: filesystem.mu must be locked for writing.
+// incLinksLocked increments i's link count.
+//
+// Preconditions: filesystem.mu must be locked for writing. i.nlink != 0.
+// i.nlink < maxLinks.
func (i *inode) incLinksLocked() {
- if atomic.AddUint32(&i.nlink, 1) <= 1 {
- panic("memfs.inode.incLinksLocked() called with no existing links")
+ if i.nlink == 0 {
+ panic("tmpfs.inode.incLinksLocked() called with no existing links")
}
+ if i.nlink == maxLinks {
+ panic("memfs.inode.incLinksLocked() called with maximum link count")
+ }
+ atomic.AddUint32(&i.nlink, 1)
}
-// Preconditions: filesystem.mu must be locked for writing.
+// decLinksLocked decrements i's link count.
+//
+// Preconditions: filesystem.mu must be locked for writing. i.nlink != 0.
func (i *inode) decLinksLocked() {
- if nlink := atomic.AddUint32(&i.nlink, ^uint32(0)); nlink == 0 {
- i.decRef()
- } else if nlink == ^uint32(0) { // negative overflow
- panic("memfs.inode.decLinksLocked() called with no existing links")
+ if i.nlink == 0 {
+ panic("tmpfs.inode.decLinksLocked() called with no existing links")
}
+ atomic.AddUint32(&i.nlink, ^uint32(0))
}
func (i *inode) incRef() {
if atomic.AddInt64(&i.refs, 1) <= 1 {
- panic("memfs.inode.incRef() called without holding a reference")
+ panic("tmpfs.inode.incRef() called without holding a reference")
}
}
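decLinksLocked above decrements the link count with atomic.AddUint32(&i.nlink, ^uint32(0)): adding 0xFFFFFFFF to a uint32 wraps around, which is the same as subtracting one modulo 2^32. A runnable illustration (not part of the patch):

    package main

    import (
        "fmt"
        "sync/atomic"
    )

    func main() {
        var nlink uint32 = 3
        atomic.AddUint32(&nlink, ^uint32(0)) // adds 0xFFFFFFFF, i.e. decrements
        fmt.Println(nlink)                   // prints 2
    }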
@@ -184,14 +195,14 @@ func (i *inode) tryIncRef() bool {
func (i *inode) decRef() {
if refs := atomic.AddInt64(&i.refs, -1); refs == 0 {
// This is unnecessary; it's mostly to simulate what tmpfs would do.
- if regfile, ok := i.impl.(*regularFile); ok {
- regfile.mu.Lock()
- regfile.data = nil
- atomic.StoreInt64(&regfile.dataLen, 0)
- regfile.mu.Unlock()
+ if regFile, ok := i.impl.(*regularFile); ok {
+ regFile.mu.Lock()
+ regFile.data.DropAll(regFile.memFile)
+ atomic.StoreUint64(&regFile.size, 0)
+ regFile.mu.Unlock()
}
} else if refs < 0 {
- panic("memfs.inode.decRef() called without holding a reference")
+ panic("tmpfs.inode.decRef() called without holding a reference")
}
}
@@ -215,7 +226,7 @@ func (i *inode) statTo(stat *linux.Statx) {
case *regularFile:
stat.Mode |= linux.S_IFREG
stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS
- stat.Size = uint64(atomic.LoadInt64(&impl.dataLen))
+ stat.Size = uint64(atomic.LoadUint64(&impl.size))
// In tmpfs, this will be FileRangeSet.Span() / 512 (but also cached in
// a uint64 accessed using atomic memory operations to avoid taking
// locks).
@@ -256,13 +267,11 @@ func (i *inode) direntType() uint8 {
}
}
-// fileDescription is embedded by memfs implementations of
+// fileDescription is embedded by tmpfs implementations of
// vfs.FileDescriptionImpl.
type fileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
-
- flags uint32 // status flags; immutable
}
func (fd *fileDescription) filesystem() *filesystem {
@@ -273,18 +282,6 @@ func (fd *fileDescription) inode() *inode {
return fd.vfsfd.Dentry().Impl().(*dentry).inode
}
-// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags.
-func (fd *fileDescription) StatusFlags(ctx context.Context) (uint32, error) {
- return fd.flags, nil
-}
-
-// SetStatusFlags implements vfs.FileDescriptionImpl.SetStatusFlags.
-func (fd *fileDescription) SetStatusFlags(ctx context.Context, flags uint32) error {
- // None of the flags settable by fcntl(F_SETFL) are supported, so this is a
- // no-op.
- return nil
-}
-
// Stat implements vfs.FileDescriptionImpl.Stat.
func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
var stat linux.Statx
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index bd3fb4c03..8653d2f63 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -762,7 +762,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
mounts.IncRef()
}
- tg := k.newThreadGroup(mounts, args.PIDNamespace, NewSignalHandlers(), linux.SIGCHLD, args.Limits, k.monotonicClock)
+ tg := k.NewThreadGroup(mounts, args.PIDNamespace, NewSignalHandlers(), linux.SIGCHLD, args.Limits)
ctx := args.NewContext(k)
// Get the root directory from the MountNamespace.
@@ -1191,6 +1191,11 @@ func (k *Kernel) GlobalInit() *ThreadGroup {
return k.globalInit
}
+// TestOnly_SetGlobalInit sets the thread group with ID 1 in the root PID namespace.
+func (k *Kernel) TestOnly_SetGlobalInit(tg *ThreadGroup) {
+ k.globalInit = tg
+}
+
// ApplicationCores returns the number of CPUs visible to sandboxed
// applications.
func (k *Kernel) ApplicationCores() uint {
diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go
index 24ea002ba..b14429854 100644
--- a/pkg/sentry/kernel/rseq.go
+++ b/pkg/sentry/kernel/rseq.go
@@ -15,17 +15,29 @@
package kernel
import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/hostcpu"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
)
-// Restartable sequences, as described in https://lwn.net/Articles/650333/.
+// Restartable sequences.
+//
+// We support two different APIs for restartable sequences.
+//
+// 1. The upstream interface added in v4.18.
+// 2. The interface described in https://lwn.net/Articles/650333/.
+//
+// Throughout this file and other parts of the kernel, the latter is referred
+// to as "old rseq". This interface was never merged upstream, but is supported
+// for a limited set of applications that use it regardless.
-// RSEQCriticalRegion describes a restartable sequence critical region.
+// OldRSeqCriticalRegion describes an old rseq critical region.
//
// +stateify savable
-type RSEQCriticalRegion struct {
+type OldRSeqCriticalRegion struct {
// When a task in this thread group has its CPU preempted (as defined by
// platform.ErrContextCPUPreempted) or has a signal delivered to an
// application handler while its instruction pointer is in CriticalSection,
@@ -35,86 +47,359 @@ type RSEQCriticalRegion struct {
Restart usermem.Addr
}
-// RSEQAvailable returns true if t supports restartable sequences.
-func (t *Task) RSEQAvailable() bool {
+// RSeqAvailable returns true if t supports (old and new) restartable sequences.
+func (t *Task) RSeqAvailable() bool {
return t.k.useHostCores && t.k.Platform.DetectsCPUPreemption()
}
-// RSEQCriticalRegion returns a copy of t's thread group's current restartable
-// sequence.
-func (t *Task) RSEQCriticalRegion() RSEQCriticalRegion {
- return *t.tg.rscr.Load().(*RSEQCriticalRegion)
+// SetRSeq registers addr as this thread's rseq structure.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) SetRSeq(addr usermem.Addr, length, signature uint32) error {
+ if t.rseqAddr != 0 {
+ if t.rseqAddr != addr {
+ return syserror.EINVAL
+ }
+ if t.rseqSignature != signature {
+ return syserror.EINVAL
+ }
+ return syserror.EBUSY
+ }
+
+ // rseq must be aligned and correctly sized.
+ if addr&(linux.AlignOfRSeq-1) != 0 {
+ return syserror.EINVAL
+ }
+ if length != linux.SizeOfRSeq {
+ return syserror.EINVAL
+ }
+ if _, ok := t.MemoryManager().CheckIORange(addr, linux.SizeOfRSeq); !ok {
+ return syserror.EFAULT
+ }
+
+ t.rseqAddr = addr
+ t.rseqSignature = signature
+
+ // Initialize the CPUID.
+ //
+ // Linux implicitly does this on return from userspace, where failure
+ // would cause SIGSEGV.
+ if err := t.rseqUpdateCPU(); err != nil {
+ t.rseqAddr = 0
+ t.rseqSignature = 0
+
+ t.Debugf("Failed to copy CPU to %#x for rseq: %v", t.rseqAddr, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return syserror.EFAULT
+ }
+
+ return nil
}
-// SetRSEQCriticalRegion replaces t's thread group's restartable sequence.
+// ClearRSeq unregisters addr as this thread's rseq structure.
//
-// Preconditions: t.RSEQAvailable() == true.
-func (t *Task) SetRSEQCriticalRegion(rscr RSEQCriticalRegion) error {
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) ClearRSeq(addr usermem.Addr, length, signature uint32) error {
+ if t.rseqAddr == 0 {
+ return syserror.EINVAL
+ }
+ if t.rseqAddr != addr {
+ return syserror.EINVAL
+ }
+ if length != linux.SizeOfRSeq {
+ return syserror.EINVAL
+ }
+ if t.rseqSignature != signature {
+ return syserror.EPERM
+ }
+
+ if err := t.rseqClearCPU(); err != nil {
+ return err
+ }
+
+ t.rseqAddr = 0
+ t.rseqSignature = 0
+
+ if t.oldRSeqCPUAddr == 0 {
+ // rseqCPU no longer needed.
+ t.rseqCPU = -1
+ }
+
+ return nil
+}
+
+// OldRSeqCriticalRegion returns a copy of t's thread group's current
+// old restartable sequence.
+func (t *Task) OldRSeqCriticalRegion() OldRSeqCriticalRegion {
+ return *t.tg.oldRSeqCritical.Load().(*OldRSeqCriticalRegion)
+}
+
+// SetOldRSeqCriticalRegion replaces t's thread group's old restartable
+// sequence.
+//
+// Preconditions: t.RSeqAvailable() == true.
+func (t *Task) SetOldRSeqCriticalRegion(r OldRSeqCriticalRegion) error {
// These checks are somewhat more lenient than in Linux, which (bizarrely)
- // requires rscr.CriticalSection to be non-empty and rscr.Restart to be
- // outside of rscr.CriticalSection, even if rscr.CriticalSection.Start == 0
+ // requires r.CriticalSection to be non-empty and r.Restart to be
+ // outside of r.CriticalSection, even if r.CriticalSection.Start == 0
// (which disables the critical region).
- if rscr.CriticalSection.Start == 0 {
- rscr.CriticalSection.End = 0
- rscr.Restart = 0
- t.tg.rscr.Store(&rscr)
+ if r.CriticalSection.Start == 0 {
+ r.CriticalSection.End = 0
+ r.Restart = 0
+ t.tg.oldRSeqCritical.Store(&r)
return nil
}
- if rscr.CriticalSection.Start >= rscr.CriticalSection.End {
+ if r.CriticalSection.Start >= r.CriticalSection.End {
return syserror.EINVAL
}
- if rscr.CriticalSection.Contains(rscr.Restart) {
+ if r.CriticalSection.Contains(r.Restart) {
return syserror.EINVAL
}
- // TODO(jamieliu): check that rscr.CriticalSection and rscr.Restart are in
- // the application address range, for consistency with Linux
- t.tg.rscr.Store(&rscr)
+ // TODO(jamieliu): check that r.CriticalSection and r.Restart are in
+ // the application address range, for consistency with Linux.
+ t.tg.oldRSeqCritical.Store(&r)
return nil
}
-// RSEQCPUAddr returns the address that RSEQ will keep updated with t's CPU
-// number.
+// OldRSeqCPUAddr returns the address that old rseq will keep updated with t's
+// CPU number.
//
// Preconditions: The caller must be running on the task goroutine.
-func (t *Task) RSEQCPUAddr() usermem.Addr {
- return t.rseqCPUAddr
+func (t *Task) OldRSeqCPUAddr() usermem.Addr {
+ return t.oldRSeqCPUAddr
}
-// SetRSEQCPUAddr replaces the address that RSEQ will keep updated with t's CPU
-// number.
+// SetOldRSeqCPUAddr replaces the address that old rseq will keep updated with
+// t's CPU number.
//
-// Preconditions: t.RSEQAvailable() == true. The caller must be running on the
+// Preconditions: t.RSeqAvailable() == true. The caller must be running on the
// task goroutine. t's AddressSpace must be active.
-func (t *Task) SetRSEQCPUAddr(addr usermem.Addr) error {
- t.rseqCPUAddr = addr
- if addr != 0 {
- t.rseqCPU = int32(hostcpu.GetCPU())
- if err := t.rseqCopyOutCPU(); err != nil {
- t.rseqCPUAddr = 0
- t.rseqCPU = -1
- return syserror.EINVAL // yes, EINVAL, not err or EFAULT
- }
- } else {
- t.rseqCPU = -1
+func (t *Task) SetOldRSeqCPUAddr(addr usermem.Addr) error {
+ t.oldRSeqCPUAddr = addr
+
+ // Check that addr is writable.
+ //
+ // N.B. rseqUpdateCPU may fail on a bad t.rseqAddr as well. That's
+ // unfortunate, but unlikely in a correct program.
+ if err := t.rseqUpdateCPU(); err != nil {
+ t.oldRSeqCPUAddr = 0
+ return syserror.EINVAL // yes, EINVAL, not err or EFAULT
}
return nil
}
// Preconditions: The caller must be running on the task goroutine. t's
// AddressSpace must be active.
-func (t *Task) rseqCopyOutCPU() error {
+func (t *Task) rseqUpdateCPU() error {
+ if t.rseqAddr == 0 && t.oldRSeqCPUAddr == 0 {
+ t.rseqCPU = -1
+ return nil
+ }
+
+ t.rseqCPU = int32(hostcpu.GetCPU())
+
+ // Update both CPUs, even if one fails.
+ rerr := t.rseqCopyOutCPU()
+ oerr := t.oldRSeqCopyOutCPU()
+
+ if rerr != nil {
+ return rerr
+ }
+ return oerr
+}
+
+// Preconditions: The caller must be running on the task goroutine. t's
+// AddressSpace must be active.
+func (t *Task) oldRSeqCopyOutCPU() error {
+ if t.oldRSeqCPUAddr == 0 {
+ return nil
+ }
+
buf := t.CopyScratchBuffer(4)
usermem.ByteOrder.PutUint32(buf, uint32(t.rseqCPU))
- _, err := t.CopyOutBytes(t.rseqCPUAddr, buf)
+ _, err := t.CopyOutBytes(t.oldRSeqCPUAddr, buf)
+ return err
+}
+
+// Preconditions: The caller must be running on the task goroutine. t's
+// AddressSpace must be active.
+func (t *Task) rseqCopyOutCPU() error {
+ if t.rseqAddr == 0 {
+ return nil
+ }
+
+ buf := t.CopyScratchBuffer(8)
+ // CPUIDStart and CPUID are the first two fields in linux.RSeq.
+ usermem.ByteOrder.PutUint32(buf, uint32(t.rseqCPU)) // CPUIDStart
+ usermem.ByteOrder.PutUint32(buf[4:], uint32(t.rseqCPU)) // CPUID
+ // N.B. This write is not atomic, but since this occurs on the task
+ // goroutine then as long as userspace uses a single-instruction read
+ // it can't see an invalid value.
+ _, err := t.CopyOutBytes(t.rseqAddr, buf)
+ return err
+}
+
+// Preconditions: The caller must be running on the task goroutine. t's
+// AddressSpace must be active.
+func (t *Task) rseqClearCPU() error {
+ buf := t.CopyScratchBuffer(8)
+ // CPUIDStart and CPUID are the first two fields in linux.RSeq.
+ usermem.ByteOrder.PutUint32(buf, 0) // CPUIDStart
+ usermem.ByteOrder.PutUint32(buf[4:], linux.RSEQ_CPU_ID_UNINITIALIZED) // CPUID
+ // N.B. This write is not atomic, but since this occurs on the task
+ // goroutine then as long as userspace uses a single-instruction read
+ // it can't see an invalid value.
+ _, err := t.CopyOutBytes(t.rseqAddr, buf)
return err
}
+// rseqAddrInterrupt checks if IP is in a critical section, and aborts if so.
+//
+// This is a bit complex since both the RSeq and RSeqCriticalSection structs
+// are stored in userspace. So we must:
+//
+// 1. Copy in the address of RSeqCriticalSection from RSeq.
+// 2. Copy in RSeqCriticalSection itself.
+// 3. Validate critical section struct version, address range, abort address.
+// 4. Validate the abort signature (4 bytes preceding abort IP match expected
+// signature).
+// 5. Clear address of RSeqCriticalSection from RSeq.
+// 6. Finally, conditionally abort.
+//
+// See kernel/rseq.c:rseq_ip_fixup for reference.
+//
+// Preconditions: The caller must be running on the task goroutine. t's
+// AddressSpace must be active.
+func (t *Task) rseqAddrInterrupt() {
+ if t.rseqAddr == 0 {
+ return
+ }
+
+ critAddrAddr, ok := t.rseqAddr.AddLength(linux.OffsetOfRSeqCriticalSection)
+ if !ok {
+ // SetRSeq should validate this.
+ panic(fmt.Sprintf("t.rseqAddr (%#x) not large enough", t.rseqAddr))
+ }
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ t.Debugf("Only 64-bit rseq supported.")
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ buf := t.CopyScratchBuffer(8)
+ if _, err := t.CopyInBytes(critAddrAddr, buf); err != nil {
+ t.Debugf("Failed to copy critical section address from %#x for rseq: %v", critAddrAddr, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ critAddr := usermem.Addr(usermem.ByteOrder.Uint64(buf))
+ if critAddr == 0 {
+ return
+ }
+
+ buf = t.CopyScratchBuffer(linux.SizeOfRSeqCriticalSection)
+ if _, err := t.CopyInBytes(critAddr, buf); err != nil {
+ t.Debugf("Failed to copy critical section from %#x for rseq: %v", critAddr, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ // Manually marshal RSeqCriticalSection as this is in the hot path when
+ // rseq is enabled. It must be as fast as possible.
+ //
+ // TODO(b/130243041): Replace with go_marshal.
+ cs := linux.RSeqCriticalSection{
+ Version: usermem.ByteOrder.Uint32(buf[0:4]),
+ Flags: usermem.ByteOrder.Uint32(buf[4:8]),
+ Start: usermem.ByteOrder.Uint64(buf[8:16]),
+ PostCommitOffset: usermem.ByteOrder.Uint64(buf[16:24]),
+ Abort: usermem.ByteOrder.Uint64(buf[24:32]),
+ }
+
+ if cs.Version != 0 {
+ t.Debugf("Unknown version in %+v", cs)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ start := usermem.Addr(cs.Start)
+ critRange, ok := start.ToRange(cs.PostCommitOffset)
+ if !ok {
+ t.Debugf("Invalid start and offset in %+v", cs)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ abort := usermem.Addr(cs.Abort)
+ if critRange.Contains(abort) {
+ t.Debugf("Abort in critical section in %+v", cs)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ // Verify signature.
+ sigAddr := abort - linux.SizeOfRSeqSignature
+
+ buf = t.CopyScratchBuffer(linux.SizeOfRSeqSignature)
+ if _, err := t.CopyInBytes(sigAddr, buf); err != nil {
+ t.Debugf("Failed to copy critical section signature from %#x for rseq: %v", sigAddr, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ sig := usermem.ByteOrder.Uint32(buf)
+ if sig != t.rseqSignature {
+ t.Debugf("Mismatched rseq signature %d != %d", sig, t.rseqSignature)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ // Clear the critical section address.
+ //
+ // NOTE(b/143949567): We don't support any rseq flags, so we always
+ // restart if we are in the critical section, and thus *always* clear
+ // critAddrAddr.
+ if _, err := t.MemoryManager().ZeroOut(t, critAddrAddr, int64(t.Arch().Width()), usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ t.Debugf("Failed to clear critical section address from %#x for rseq: %v", critAddrAddr, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ // Finally we can actually decide whether or not to restart.
+ if !critRange.Contains(usermem.Addr(t.Arch().IP())) {
+ return
+ }
+
+ t.Arch().SetIP(uintptr(cs.Abort))
+}
+
// Preconditions: The caller must be running on the task goroutine.
-func (t *Task) rseqInterrupt() {
- rscr := t.tg.rscr.Load().(*RSEQCriticalRegion)
- if ip := t.Arch().IP(); rscr.CriticalSection.Contains(usermem.Addr(ip)) {
- t.Debugf("Interrupted RSEQ critical section at %#x; restarting at %#x", ip, rscr.Restart)
- t.Arch().SetIP(uintptr(rscr.Restart))
- t.Arch().SetRSEQInterruptedIP(ip)
+func (t *Task) oldRSeqInterrupt() {
+ r := t.tg.oldRSeqCritical.Load().(*OldRSeqCriticalRegion)
+ if ip := t.Arch().IP(); r.CriticalSection.Contains(usermem.Addr(ip)) {
+ t.Debugf("Interrupted rseq critical section at %#x; restarting at %#x", ip, r.Restart)
+ t.Arch().SetIP(uintptr(r.Restart))
+ t.Arch().SetOldRSeqInterruptedIP(ip)
}
}
+
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) rseqInterrupt() {
+ t.rseqAddrInterrupt()
+ t.oldRSeqInterrupt()
+}
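rseqAddrInterrupt unpacks linux.RSeqCriticalSection field by field because this path is hot whenever rseq is in use. For comparison, the reflection-based decode it avoids would look roughly like the sketch below, assuming the same field order and a little-endian layout (which is what usermem.ByteOrder amounts to on the supported hosts); the manual version trades this brevity for speed, as the go_marshal TODO notes.

    package example

    import (
        "bytes"
        "encoding/binary"
    )

    // criticalSection mirrors the field order of linux.RSeqCriticalSection.
    type criticalSection struct {
        Version          uint32
        Flags            uint32
        Start            uint64
        PostCommitOffset uint64
        Abort            uint64
    }

    // decodeCriticalSection is the slower, reflection-based alternative to the
    // manual ByteOrder unpacking in rseqAddrInterrupt.
    func decodeCriticalSection(buf []byte) (criticalSection, error) {
        var cs criticalSection
        err := binary.Read(bytes.NewReader(buf), binary.LittleEndian, &cs)
        return cs, err
    }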
diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go
index 5bd610f68..19034a21e 100644
--- a/pkg/sentry/kernel/shm/shm.go
+++ b/pkg/sentry/kernel/shm/shm.go
@@ -71,9 +71,20 @@ type Registry struct {
mu sync.Mutex `state:"nosave"`
// shms maps segment ids to segments.
+ //
+ // shms holds all referenced segments, which are removed on the last
+ // DecRef. Thus, it cannot itself hold a reference on the Shm.
+ //
+ // Since removal only occurs after the last (unlocked) DecRef, there
+ // exists a short window during which a Shm still exists in Shm, but is
+ // unreferenced. Users must use TryIncRef to determine if the Shm is
+ // still valid.
shms map[ID]*Shm
// keysToShms maps segment keys to segments.
+ //
+ // Shms in keysToShms are guaranteed to be referenced, as they are
+ // removed by dissociateKey before the last DecRef.
keysToShms map[Key]*Shm
// Sum of the sizes of all existing segments rounded up to page size, in
@@ -95,10 +106,18 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry {
}
// FindByID looks up a segment given an ID.
+//
+// FindByID returns a reference on Shm.
func (r *Registry) FindByID(id ID) *Shm {
r.mu.Lock()
defer r.mu.Unlock()
- return r.shms[id]
+ s := r.shms[id]
+ // Take a reference on s. If TryIncRef fails, s has reached the last
+ // DecRef, but hasn't quite been removed from r.shms yet.
+ if s != nil && s.TryIncRef() {
+ return s
+ }
+ return nil
}
// dissociateKey removes the association between a segment and its key,
@@ -119,6 +138,8 @@ func (r *Registry) dissociateKey(s *Shm) {
// FindOrCreate looks up or creates a segment in the registry. It's functionally
// analogous to open(2).
+//
+// FindOrCreate returns a reference on Shm.
func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size uint64, mode linux.FileMode, private, create, exclusive bool) (*Shm, error) {
if (create || private) && (size < linux.SHMMIN || size > linux.SHMMAX) {
// "A new segment was to be created and size is less than SHMMIN or
@@ -166,6 +187,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size ui
return nil, syserror.EEXIST
}
+ shm.IncRef()
return shm, nil
}
@@ -193,7 +215,14 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size ui
// Need to create a new segment.
creator := fs.FileOwnerFromContext(ctx)
perms := fs.FilePermsFromMode(mode)
- return r.newShm(ctx, pid, key, creator, perms, size)
+ s, err := r.newShm(ctx, pid, key, creator, perms, size)
+ if err != nil {
+ return nil, err
+ }
+ // The initial reference is held by s itself. Take another to return to
+ // the caller.
+ s.IncRef()
+ return s, nil
}
// newShm creates a new segment in the registry.
@@ -296,22 +325,26 @@ func (r *Registry) remove(s *Shm) {
// Shm represents a single shared memory segment.
//
-// Shm segment are backed directly by an allocation from platform
-// memory. Segments are always mapped as a whole, greatly simplifying how
-// mappings are tracked. However note that mremap and munmap calls may cause the
-// vma for a segment to become fragmented; which requires special care when
-// unmapping a segment. See mm/shm.go.
+// Shm segments are backed directly by an allocation from platform memory.
+// Segments are always mapped as a whole, greatly simplifying how mappings are
+// tracked. However, note that mremap and munmap calls may cause the vma for a
+// segment to become fragmented, which requires special care when unmapping a
+// segment. See mm/shm.go.
//
// Segments persist until they are explicitly marked for destruction via
-// shmctl(SHM_RMID).
+// MarkDestroyed().
//
// Shm implements memmap.Mappable and memmap.MappingIdentity.
//
// +stateify savable
type Shm struct {
- // AtomicRefCount tracks the number of references to this segment from
- // maps. A segment always holds a reference to itself, until it's marked for
+ // AtomicRefCount tracks the number of references to this segment.
+ //
+ // A segment holds a reference to itself until it is marked for
// destruction.
+ //
+ // In addition to direct users, the MemoryManager will hold references
+ // via MappingIdentity.
refs.AtomicRefCount
mfp pgalloc.MemoryFileProvider
@@ -484,9 +517,8 @@ type AttachOpts struct {
// ConfigureAttach creates an mmap configuration for the segment with the
// requested attach options.
//
-// ConfigureAttach returns with a ref on s on success. The caller should drop
-// this once the map is installed. This reference prevents s from being
-// destroyed before the returned configuration is used.
+// Postconditions: The returned MMapOpts are valid only as long as a reference
+// continues to be held on s.
func (s *Shm) ConfigureAttach(ctx context.Context, addr usermem.Addr, opts AttachOpts) (memmap.MMapOpts, error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -504,7 +536,6 @@ func (s *Shm) ConfigureAttach(ctx context.Context, addr usermem.Addr, opts Attac
// in the user namespace that governs its IPC namespace." - man shmat(2)
return memmap.MMapOpts{}, syserror.EACCES
}
- s.IncRef()
return memmap.MMapOpts{
Length: s.size,
Offset: 0,
@@ -549,10 +580,15 @@ func (s *Shm) IPCStat(ctx context.Context) (*linux.ShmidDS, error) {
}
creds := auth.CredentialsFromContext(ctx)
- nattach := uint64(s.ReadRefs())
- // Don't report the self-reference we keep prior to being marked for
- // destruction. However, also don't report a count of -1 for segments marked
- // as destroyed, with no mappings.
+ // Use the reference count as a rudimentary count of the number of
+ // attaches. We exclude:
+ //
+ // 1. The reference the caller holds.
+ // 2. The self-reference held by s prior to destruction.
+ //
+ // Note that this may still overcount by including transient references
+ // used in concurrent calls.
+ nattach := uint64(s.ReadRefs()) - 1
if !s.pendingDestruction {
nattach--
}
@@ -620,18 +656,17 @@ func (s *Shm) MarkDestroyed() {
s.registry.dissociateKey(s)
s.mu.Lock()
- // Only drop the segment's self-reference once, when destruction is
- // requested. Otherwise, repeated calls to shmctl(IPC_RMID) would force a
- // segment to be destroyed prematurely, potentially with active maps to the
- // segment's address range. Remaining references are dropped when the
- // segment is detached or unmaped.
+ defer s.mu.Unlock()
if !s.pendingDestruction {
s.pendingDestruction = true
- s.mu.Unlock() // Must release s.mu before calling s.DecRef.
+ // Drop the self-reference so destruction occurs when all
+ // external references are gone.
+ //
+ // N.B. This cannot be the final DecRef, as the caller also
+ // holds a reference.
s.DecRef()
return
}
- s.mu.Unlock()
}
// checkOwnership verifies whether a segment may be accessed by ctx as an
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index ab0c6c4aa..d25a7903b 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -489,18 +489,43 @@ type Task struct {
// netns is protected by mu. netns is owned by the task goroutine.
netns bool
- // If rseqPreempted is true, before the next call to p.Switch(), interrupt
- // RSEQ critical regions as defined by tg.rseq and write the task
- // goroutine's CPU number to rseqCPUAddr. rseqCPU is the last CPU number
- // written to rseqCPUAddr.
+ // If rseqPreempted is true, before the next call to p.Switch(),
+ // interrupt rseq critical regions as defined by rseqAddr and
+ // tg.oldRSeqCritical and write the task goroutine's CPU number to
+ // rseqAddr/oldRSeqCPUAddr.
//
- // If rseqCPUAddr is 0, rseqCPU is -1.
+ // We support two ABIs for restartable sequences:
//
- // rseqCPUAddr, rseqCPU, and rseqPreempted are exclusive to the task
- // goroutine.
+ // 1. The upstream interface added in v4.18,
+ // 2. An "old" interface never merged upstream. In the implementation,
+ // this is referred to as "old rseq".
+ //
+ // rseqPreempted is exclusive to the task goroutine.
rseqPreempted bool `state:"nosave"`
- rseqCPUAddr usermem.Addr
- rseqCPU int32
+
+ // rseqCPU is the last CPU number written to rseqAddr/oldRSeqCPUAddr.
+ //
+ // If rseq is unused, rseqCPU is -1 for convenient use in
+ // platform.Context.Switch.
+ //
+ // rseqCPU is exclusive to the task goroutine.
+ rseqCPU int32
+
+ // oldRSeqCPUAddr is a pointer to the userspace old rseq CPU variable.
+ //
+ // oldRSeqCPUAddr is exclusive to the task goroutine.
+ oldRSeqCPUAddr usermem.Addr
+
+ // rseqAddr is a pointer to the userspace linux.RSeq structure.
+ //
+ // rseqAddr is exclusive to the task goroutine.
+ rseqAddr usermem.Addr
+
+ // rseqSignature is the signature that the rseq abort IP must be signed
+ // with.
+ //
+ // rseqSignature is exclusive to the task goroutine.
+ rseqSignature uint32
// copyScratchBuffer is a buffer available to CopyIn/CopyOut
// implementations that require an intermediate buffer to copy data
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index 3eadfedb4..247bd4aba 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -236,14 +236,19 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
} else if opts.NewPIDNamespace {
pidns = pidns.NewChild(userns)
}
+
tg := t.tg
+ rseqAddr := usermem.Addr(0)
+ rseqSignature := uint32(0)
if opts.NewThreadGroup {
tg.mounts.IncRef()
sh := t.tg.signalHandlers
if opts.NewSignalHandlers {
sh = sh.Fork()
}
- tg = t.k.newThreadGroup(tg.mounts, pidns, sh, opts.TerminationSignal, tg.limits.GetCopy(), t.k.monotonicClock)
+ tg = t.k.NewThreadGroup(tg.mounts, pidns, sh, opts.TerminationSignal, tg.limits.GetCopy())
+ rseqAddr = t.rseqAddr
+ rseqSignature = t.rseqSignature
}
cfg := &TaskConfig{
@@ -260,6 +265,8 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
UTSNamespace: utsns,
IPCNamespace: ipcns,
AbstractSocketNamespace: t.abstractSockets,
+ RSeqAddr: rseqAddr,
+ RSeqSignature: rseqSignature,
ContainerID: t.ContainerID(),
}
if opts.NewThreadGroup {
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
index 90a6190f1..fa6528386 100644
--- a/pkg/sentry/kernel/task_exec.go
+++ b/pkg/sentry/kernel/task_exec.go
@@ -190,9 +190,11 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
t.updateRSSLocked()
// Restartable sequence state is discarded.
t.rseqPreempted = false
- t.rseqCPUAddr = 0
t.rseqCPU = -1
- t.tg.rscr.Store(&RSEQCriticalRegion{})
+ t.rseqAddr = 0
+ t.rseqSignature = 0
+ t.oldRSeqCPUAddr = 0
+ t.tg.oldRSeqCritical.Store(&OldRSeqCriticalRegion{})
t.tg.pidns.owner.mu.Unlock()
// Remove FDs with the CloseOnExec flag set.
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
index d97f8c189..6357273d3 100644
--- a/pkg/sentry/kernel/task_run.go
+++ b/pkg/sentry/kernel/task_run.go
@@ -169,12 +169,22 @@ func (*runApp) execute(t *Task) taskRunState {
// Apply restartable sequences.
if t.rseqPreempted {
t.rseqPreempted = false
- if t.rseqCPUAddr != 0 {
+ if t.rseqAddr != 0 || t.oldRSeqCPUAddr != 0 {
+ // Linux writes the CPU on every preemption. We only do
+ // so if it changed. Thus we may delay delivery of
+ // SIGSEGV if rseqAddr/oldRSeqCPUAddr is invalid.
cpu := int32(hostcpu.GetCPU())
if t.rseqCPU != cpu {
t.rseqCPU = cpu
if err := t.rseqCopyOutCPU(); err != nil {
- t.Warningf("Failed to copy CPU to %#x for RSEQ: %v", t.rseqCPUAddr, err)
+ t.Debugf("Failed to copy CPU to %#x for rseq: %v", t.rseqAddr, err)
+ t.forceSignal(linux.SIGSEGV, false)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ // Re-enter the task run loop for signal delivery.
+ return (*runApp)(nil)
+ }
+ if err := t.oldRSeqCopyOutCPU(); err != nil {
+ t.Debugf("Failed to copy CPU to %#x for old rseq: %v", t.oldRSeqCPUAddr, err)
t.forceSignal(linux.SIGSEGV, false)
t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
// Re-enter the task run loop for signal delivery.
@@ -320,7 +330,7 @@ func (*runApp) execute(t *Task) taskRunState {
return (*runApp)(nil)
case platform.ErrContextCPUPreempted:
- // Ensure that RSEQ critical sections are interrupted and per-thread
+ // Ensure that rseq critical sections are interrupted and per-thread
// CPU values are updated before the next platform.Context.Switch().
t.rseqPreempted = true
return (*runApp)(nil)
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index 3522a4ae5..58af16ee2 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
"gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -79,6 +80,13 @@ type TaskConfig struct {
// AbstractSocketNamespace is the AbstractSocketNamespace of the new task.
AbstractSocketNamespace *AbstractSocketNamespace
+ // RSeqAddr is a pointer to the userspace linux.RSeq structure.
+ RSeqAddr usermem.Addr
+
+ // RSeqSignature is the signature that the rseq abort IP must be signed
+ // with.
+ RSeqSignature uint32
+
// ContainerID is the container the new task belongs to.
ContainerID string
}
@@ -126,6 +134,8 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
ipcns: cfg.IPCNamespace,
abstractSockets: cfg.AbstractSocketNamespace,
rseqCPU: -1,
+ rseqAddr: cfg.RSeqAddr,
+ rseqSignature: cfg.RSeqSignature,
futexWaiter: futex.NewWaiter(),
containerID: cfg.ContainerID,
}
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index 72568d296..c0197a563 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -238,8 +238,8 @@ type ThreadGroup struct {
// execed is protected by the TaskSet mutex.
execed bool
- // rscr is the thread group's RSEQ critical region.
- rscr atomic.Value `state:".(*RSEQCriticalRegion)"`
+ // oldRSeqCritical is the thread group's old rseq critical region.
+ oldRSeqCritical atomic.Value `state:".(*OldRSeqCriticalRegion)"`
// mounts is the thread group's mount namespace. This does not really
// correspond to a "mount namespace" in Linux, but is more like a
@@ -256,35 +256,35 @@ type ThreadGroup struct {
tty *TTY
}
-// newThreadGroup returns a new, empty thread group in PID namespace ns. The
+// NewThreadGroup returns a new, empty thread group in PID namespace ns. The
// thread group leader will send its parent terminationSignal when it exits.
// The new thread group isn't visible to the system until a task has been
// created inside of it by a successful call to TaskSet.NewTask.
-func (k *Kernel) newThreadGroup(mounts *fs.MountNamespace, ns *PIDNamespace, sh *SignalHandlers, terminationSignal linux.Signal, limits *limits.LimitSet, monotonicClock *timekeeperClock) *ThreadGroup {
+func (k *Kernel) NewThreadGroup(mntns *fs.MountNamespace, pidns *PIDNamespace, sh *SignalHandlers, terminationSignal linux.Signal, limits *limits.LimitSet) *ThreadGroup {
tg := &ThreadGroup{
threadGroupNode: threadGroupNode{
- pidns: ns,
+ pidns: pidns,
},
signalHandlers: sh,
terminationSignal: terminationSignal,
ioUsage: &usage.IO{},
limits: limits,
- mounts: mounts,
+ mounts: mntns,
}
tg.itimerRealTimer = ktime.NewTimer(k.monotonicClock, &itimerRealListener{tg: tg})
tg.timers = make(map[linux.TimerID]*IntervalTimer)
- tg.rscr.Store(&RSEQCriticalRegion{})
+ tg.oldRSeqCritical.Store(&OldRSeqCriticalRegion{})
return tg
}
-// saveRscr is invoked by stateify.
-func (tg *ThreadGroup) saveRscr() *RSEQCriticalRegion {
- return tg.rscr.Load().(*RSEQCriticalRegion)
+// saveOldRSeqCritical is invoked by stateify.
+func (tg *ThreadGroup) saveOldRSeqCritical() *OldRSeqCriticalRegion {
+ return tg.oldRSeqCritical.Load().(*OldRSeqCriticalRegion)
}
-// loadRscr is invoked by stateify.
-func (tg *ThreadGroup) loadRscr(rscr *RSEQCriticalRegion) {
- tg.rscr.Store(rscr)
+// loadOldRSeqCritical is invoked by stateify.
+func (tg *ThreadGroup) loadOldRSeqCritical(r *OldRSeqCriticalRegion) {
+ tg.oldRSeqCritical.Store(r)
}
// SignalHandlers returns the signal handlers used by tg.
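ThreadGroup.oldRSeqCritical keeps an immutable *OldRSeqCriticalRegion in an atomic.Value: readers Load without locking, and writers always Store a freshly allocated region rather than mutating the published one. A minimal standalone sketch of that publish/read pattern (names here are illustrative, not gVisor APIs):

    package example

    import "sync/atomic"

    // criticalRegion stands in for OldRSeqCriticalRegion.
    type criticalRegion struct {
        start, end uintptr
    }

    // holder publishes an immutable *criticalRegion through an atomic.Value,
    // mirroring ThreadGroup.oldRSeqCritical.
    type holder struct {
        region atomic.Value // holds a *criticalRegion
    }

    func newHolder() *holder {
        h := &holder{}
        h.region.Store(&criticalRegion{}) // never leave the Value empty
        return h
    }

    func (h *holder) get() criticalRegion {
        return *h.region.Load().(*criticalRegion)
    }

    func (h *holder) set(r criticalRegion) {
        h.region.Store(&r)
    }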
diff --git a/pkg/sentry/mm/procfs.go b/pkg/sentry/mm/procfs.go
index 8c2246bb4..79610acb7 100644
--- a/pkg/sentry/mm/procfs.go
+++ b/pkg/sentry/mm/procfs.go
@@ -66,8 +66,6 @@ func (mm *MemoryManager) ReadMapsDataInto(ctx context.Context, buf *bytes.Buffer
var start usermem.Addr
for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
- // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get
- // "panic: autosave error: type usermem.Addr is not registered".
mm.appendVMAMapsEntryLocked(ctx, vseg, buf)
}
@@ -81,7 +79,6 @@ func (mm *MemoryManager) ReadMapsDataInto(ctx context.Context, buf *bytes.Buffer
//
// Artifically adjust the seqfile handle so we only output vsyscall entry once.
if start != vsyscallEnd {
- // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd.
buf.WriteString(vsyscallMapsEntry)
}
}
@@ -97,8 +94,6 @@ func (mm *MemoryManager) ReadMapsSeqFileData(ctx context.Context, handle seqfile
start = *handle.(*usermem.Addr)
}
for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
- // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get
- // "panic: autosave error: type usermem.Addr is not registered".
vmaAddr := vseg.End()
data = append(data, seqfile.SeqData{
Buf: mm.vmaMapsEntryLocked(ctx, vseg),
@@ -116,7 +111,6 @@ func (mm *MemoryManager) ReadMapsSeqFileData(ctx context.Context, handle seqfile
//
// Artifically adjust the seqfile handle so we only output vsyscall entry once.
if start != vsyscallEnd {
- // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd.
vmaAddr := vsyscallEnd
data = append(data, seqfile.SeqData{
Buf: []byte(vsyscallMapsEntry),
@@ -187,15 +181,12 @@ func (mm *MemoryManager) ReadSmapsDataInto(ctx context.Context, buf *bytes.Buffe
var start usermem.Addr
for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
- // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get
- // "panic: autosave error: type usermem.Addr is not registered".
mm.vmaSmapsEntryIntoLocked(ctx, vseg, buf)
}
// We always emulate vsyscall, so advertise it here. See
// ReadMapsSeqFileData for additional commentary.
if start != vsyscallEnd {
- // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd.
buf.WriteString(vsyscallSmapsEntry)
}
}
@@ -211,8 +202,6 @@ func (mm *MemoryManager) ReadSmapsSeqFileData(ctx context.Context, handle seqfil
start = *handle.(*usermem.Addr)
}
for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
- // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get
- // "panic: autosave error: type usermem.Addr is not registered".
vmaAddr := vseg.End()
data = append(data, seqfile.SeqData{
Buf: mm.vmaSmapsEntryLocked(ctx, vseg),
@@ -223,7 +212,6 @@ func (mm *MemoryManager) ReadSmapsSeqFileData(ctx context.Context, handle seqfil
// We always emulate vsyscall, so advertise it here. See
// ReadMapsSeqFileData for additional commentary.
if start != vsyscallEnd {
- // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd.
vmaAddr := vsyscallEnd
data = append(data, seqfile.SeqData{
Buf: []byte(vsyscallSmapsEntry),
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 5c52d4007..f3afd98da 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -34,6 +34,8 @@ go_library(
"machine_arm64_unsafe.go",
"machine_unsafe.go",
"physical_map.go",
+ "physical_map_amd64.go",
+ "physical_map_arm64.go",
"virtual_map.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/platform/kvm",
diff --git a/pkg/sentry/platform/kvm/physical_map.go b/pkg/sentry/platform/kvm/physical_map.go
index 586e91bb2..91de5dab1 100644
--- a/pkg/sentry/platform/kvm/physical_map.go
+++ b/pkg/sentry/platform/kvm/physical_map.go
@@ -24,15 +24,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/usermem"
)
-const (
- // reservedMemory is a chunk of physical memory reserved starting at
- // physical address zero. There are some special pages in this region,
- // so we just call the whole thing off.
- //
- // Other architectures may define this to be zero.
- reservedMemory = 0x100000000
-)
-
type region struct {
virtual uintptr
length uintptr
@@ -59,8 +50,7 @@ func fillAddressSpace() (excludedRegions []region) {
// We can cut vSize in half, because the kernel will be using the top
// half and we ignore it while constructing mappings. It's as if we've
// already excluded half the possible addresses.
- vSize := uintptr(1) << ring0.VirtualAddressBits()
- vSize = vSize >> 1
+ vSize := ring0.UserspaceSize
// We exclude reservedMemory below from our physical memory size, so it
// needs to be dropped here as well. Otherwise, we could end up with
diff --git a/pkg/sentry/fsimpl/proc/filesystems.go b/pkg/sentry/platform/kvm/physical_map_amd64.go
index 0e016bca5..c5adfb577 100644
--- a/pkg/sentry/fsimpl/proc/filesystems.go
+++ b/pkg/sentry/platform/kvm/physical_map_amd64.go
@@ -12,14 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package proc
+package kvm
-// filesystemsData implements vfs.DynamicBytesSource for /proc/filesystems.
-//
-// +stateify savable
-type filesystemsData struct{}
-
-// TODO(gvisor.dev/issue/1195): Implement vfs.DynamicBytesSource.Generate for
-// filesystemsData. We would need to retrive filesystem names from
-// vfs.VirtualFilesystem. Also needs vfs replacement for
-// fs.Filesystem.AllowUserList() and fs.FilesystemRequiresDev.
+const (
+ // reservedMemory is a chunk of physical memory reserved starting at
+ // physical address zero. There are some special pages in this region,
+ // so we just call the whole thing off.
+ reservedMemory = 0x100000000
+)
diff --git a/pkg/sentry/fsimpl/proc/proc.go b/pkg/sentry/platform/kvm/physical_map_arm64.go
index 31dec36de..4d8561453 100644
--- a/pkg/sentry/fsimpl/proc/proc.go
+++ b/pkg/sentry/platform/kvm/physical_map_arm64.go
@@ -12,5 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package proc implements a partial in-memory file system for procfs.
-package proc
+package kvm
+
+const (
+ reservedMemory = 0
+)
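
// Sketch of the per-architecture split used for reservedMemory above: Go
// derives an implicit build constraint from a _GOARCH filename suffix, so a
// file named reserved_amd64.go builds only on amd64 while its sibling
// reserved_arm64.go (defining reservedMemory = 0) builds only on arm64, and
// no build tags are needed. Package and file names here are illustrative.
//
// File: reserved_amd64.go (hypothetical)
package mem

// reservedMemory is a chunk of physical memory reserved starting at physical
// address zero; the arm64 sibling file defines it as zero.
const reservedMemory = 0x100000000

// usablePhysical drops the reserved region from the usable physical size,
// mirroring how fillAddressSpace excludes it.
func usablePhysical(total uintptr) uintptr {
	return total - reservedMemory
}
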
diff --git a/pkg/sentry/platform/ptrace/stub_amd64.s b/pkg/sentry/platform/ptrace/stub_amd64.s
index 64c718d21..16f9c523e 100644
--- a/pkg/sentry/platform/ptrace/stub_amd64.s
+++ b/pkg/sentry/platform/ptrace/stub_amd64.s
@@ -64,6 +64,8 @@ begin:
CMPQ AX, $0
JL error
+ MOVQ $0, BX
+
// SIGSTOP to wait for attach.
//
// The SYSCALL instruction will be used for future syscall injection by
@@ -73,23 +75,26 @@ begin:
MOVQ $SIGSTOP, SI
SYSCALL
- // The tracer may "detach" and/or allow code execution here in three cases:
- //
- // 1. New (traced) stub threads are explicitly detached by the
- // goroutine in newSubprocess. However, they are detached while in
- // group-stop, so they do not execute code here.
- //
- // 2. If a tracer thread exits, it implicitly detaches from the stub,
- // potentially allowing code execution here. However, the Go runtime
- // never exits individual threads, so this case never occurs.
- //
- // 3. subprocess.createStub clones a new stub process that is untraced,
+ // The sentry sets BX to 1 when creating a stub process.
+ CMPQ BX, $1
+ JE clone
+
+ // Notify the Sentry that the syscall exited.
+done:
+ INT $3
+ // Be paranoid.
+ JMP done
+clone:
+ // subprocess.createStub clones a new stub process that is untraced,
// thus executing this code. We setup the PDEATHSIG before SIGSTOPing
// ourselves for attach by the tracer.
//
// R15 has been updated with the expected PPID.
- JMP begin
+ CMPQ AX, $0
+ JE begin
+ // The clone syscall returned a non-zero value.
+ JMP done
error:
// Exit with -errno.
MOVQ AX, DI
diff --git a/pkg/sentry/platform/ptrace/stub_arm64.s b/pkg/sentry/platform/ptrace/stub_arm64.s
index 2c5e4d5cb..6162df02a 100644
--- a/pkg/sentry/platform/ptrace/stub_arm64.s
+++ b/pkg/sentry/platform/ptrace/stub_arm64.s
@@ -59,6 +59,8 @@ begin:
CMP $0x0, R0
BLT error
+ MOVD $0, R9
+
// SIGSTOP to wait for attach.
//
// The SYSCALL instruction will be used for future syscall injection by
@@ -66,22 +68,26 @@ begin:
MOVD $SYS_KILL, R8
MOVD $SIGSTOP, R1
SVC
- // The tracer may "detach" and/or allow code execution here in three cases:
- //
- // 1. New (traced) stub threads are explicitly detached by the
- // goroutine in newSubprocess. However, they are detached while in
- // group-stop, so they do not execute code here.
- //
- // 2. If a tracer thread exits, it implicitly detaches from the stub,
- // potentially allowing code execution here. However, the Go runtime
- // never exits individual threads, so this case never occurs.
- //
- // 3. subprocess.createStub clones a new stub process that is untraced,
+
+ // The sentry sets R9 to 1 when creating a stub process.
+ CMP $1, R9
+ BEQ clone
+
+done:
+ // Notify the Sentry that the syscall exited.
+ BRK $3
+ B done // Be paranoid.
+clone:
+ // subprocess.createStub clones a new stub process that is untraced,
// thus executing this code. We setup the PDEATHSIG before SIGSTOPing
// ourselves for attach by the tracer.
//
// R7 has been updated with the expected PPID.
- B begin
+ CMP $0, R0
+ BEQ begin
+
+ // The clone system call returned a non-zero value.
+ B done
error:
// Exit with -errno.
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index ddb1f41e3..20244fd95 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -21,6 +21,7 @@ import (
"sync"
"syscall"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/procid"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -429,13 +430,15 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
}
for {
- // Execute the syscall instruction.
- if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
+ // Execute the syscall instruction. The task has to stop on the
+ // trap instruction, which is right after the syscall
+ // instruction.
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_CONT, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
}
sig := t.wait(stopped)
- if sig == (syscallEvent | syscall.SIGTRAP) {
+ if sig == syscall.SIGTRAP {
// Reached syscall-enter-stop.
break
} else {
@@ -447,18 +450,6 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
}
}
- // Complete the actual system call.
- if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
- panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
- }
-
- // Wait for syscall-exit-stop. "[Signal-delivery-stop] never happens
- // between syscall-enter-stop and syscall-exit-stop; it happens *after*
- // syscall-exit-stop.)" - ptrace(2), "Syscall-stops"
- if sig := t.wait(stopped); sig != (syscallEvent | syscall.SIGTRAP) {
- t.dumpAndPanic(fmt.Sprintf("wait failed: expected SIGTRAP, got %v [%d]", sig, sig))
- }
-
// Grab registers.
if err := t.getRegs(regs); err != nil {
panic(fmt.Sprintf("ptrace get regs failed: %v", err))
@@ -541,14 +532,14 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
if isSingleStepping(regs) {
if _, _, errno := syscall.RawSyscall6(
syscall.SYS_PTRACE,
- syscall.PTRACE_SYSEMU_SINGLESTEP,
+ unix.PTRACE_SYSEMU_SINGLESTEP,
uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace sysemu failed: %v", errno))
}
} else {
if _, _, errno := syscall.RawSyscall6(
syscall.SYS_PTRACE,
- syscall.PTRACE_SYSEMU,
+ unix.PTRACE_SYSEMU,
uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace sysemu failed: %v", errno))
}
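
// Minimal sketch of the resume-and-wait-for-trap flow the hunk above switches
// to: instead of PTRACE_SYSCALL with separate syscall-enter and syscall-exit
// stops, the tracee is resumed with PTRACE_CONT and runs until it hits the
// trap instruction placed right after the injected syscall. The helper is
// illustrative and assumes the caller is already the tracer of pid and the
// tracee is currently stopped.
package main

import (
	"fmt"
	"syscall"
)

// resumeUntilTrap continues a stopped tracee and blocks until it stops again
// with SIGTRAP, returning an error for any other stop.
func resumeUntilTrap(pid int) error {
	if err := syscall.PtraceCont(pid, 0); err != nil {
		return fmt.Errorf("ptrace(PTRACE_CONT): %v", err)
	}
	var status syscall.WaitStatus
	if _, err := syscall.Wait4(pid, &status, 0, nil); err != nil {
		return fmt.Errorf("wait4: %v", err)
	}
	if !status.Stopped() || status.StopSignal() != syscall.SIGTRAP {
		return fmt.Errorf("unexpected stop signal: %v", status.StopSignal())
	}
	return nil
}

func main() {} // Only meaningful when called from a real tracer.
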
diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go
index 606dc2b1d..e99798c56 100644
--- a/pkg/sentry/platform/ptrace/subprocess_amd64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go
@@ -141,9 +141,11 @@ func (t *thread) adjustInitRegsRip() {
t.initRegs.Rip -= initRegsRipAdjustment
}
-// Pass the expected PPID to the child via R15 when creating stub process
+// Pass the expected PPID to the child via R15 when creating a stub process.
func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) {
initregs.R15 = uint64(ppid)
+ // Rbx has to be set to 1 when creating a stub process.
+ initregs.Rbx = 1
}
// patchSignalInfo patches the signal info to account for hitting the seccomp
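
// Sketch (linux/amd64 only) of how register-passed values like the ones set
// in initChildProcessPPID reach a stub: the tracer edits the stopped thread's
// registers and writes them back with PTRACE_SETREGS. The helper name and
// concrete values are illustrative.
package main

import "syscall"

// markStubCreator stores the expected parent PID in R15 and the
// "creating a stub" flag in RBX of a stopped tracee.
func markStubCreator(pid int, ppid int32) error {
	var regs syscall.PtraceRegs
	if err := syscall.PtraceGetRegs(pid, &regs); err != nil {
		return err
	}
	regs.R15 = uint64(ppid) // expected PPID, checked by the stub.
	regs.Rbx = 1            // tells the stub it is on the clone path.
	return syscall.PtraceSetRegs(pid, &regs)
}

func main() {}
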
diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go
index 62a686ee7..7b975137f 100644
--- a/pkg/sentry/platform/ptrace/subprocess_arm64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go
@@ -127,6 +127,8 @@ func (t *thread) adjustInitRegsRip() {
// Pass the expected PPID to the child via X7 when creating stub process
func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) {
initregs.Regs[7] = uint64(ppid)
+ // R9 has to be set to 1 when creating a stub process.
+ initregs.Regs[9] = 1
}
// patchSignalInfo patches the signal info to account for hitting the seccomp
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go
index cf13ea5e4..74968dfdf 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux.go
@@ -54,7 +54,7 @@ func probeSeccomp() bool {
for {
// Attempt an emulation.
- if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, unix.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
}
diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD
index f1af18265..87f4552b5 100644
--- a/pkg/sentry/platform/ring0/BUILD
+++ b/pkg/sentry/platform/ring0/BUILD
@@ -71,6 +71,7 @@ go_library(
"lib_amd64.go",
"lib_amd64.s",
"lib_arm64.go",
+ "lib_arm64.s",
"ring0.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ring0",
diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go
index fbfbd9bab..dc0eeec01 100644
--- a/pkg/sentry/platform/ring0/defs_arm64.go
+++ b/pkg/sentry/platform/ring0/defs_arm64.go
@@ -73,6 +73,9 @@ type CPUArchState struct {
// application context pointer.
appAddr uintptr
+
+ // lazyVFP is the value of cpacr_el1.
+ lazyVFP uintptr
}
// ErrorCode returns the last error code.
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index 29c475882..64e9c0845 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -31,6 +31,11 @@
#define RSV_REG R18_PLATFORM
#define RSV_REG_APP R9
+#define FPEN_NOTRAP 0x3
+#define FPEN_SHIFT 20
+
+#define FPEN_ENABLE (FPEN_NOTRAP << FPEN_SHIFT)
+
#define REGISTERS_SAVE(reg, offset) \
MOVD R0, offset+PTRACE_R0(reg); \
MOVD R1, offset+PTRACE_R1(reg); \
@@ -279,6 +284,16 @@
#define IRQ_DISABLE \
MSR $2, DAIFClr;
+#define VFP_ENABLE \
+ MOVD $FPEN_ENABLE, R0; \
+ WORD $0xd5181040; \ //MSR R0, CPACR_EL1
+ ISB $15;
+
+#define VFP_DISABLE \
+ MOVD $0x0, R0; \
+ WORD $0xd5181040; \ //MSR R0, CPACR_EL1
+ ISB $15;
+
#define KERNEL_ENTRY_FROM_EL0 \
SUB $16, RSP, RSP; \ // step1, save r18, r9 into kernel temporary stack.
STP (RSV_REG, RSV_REG_APP), 16*0(RSP); \
@@ -318,6 +333,11 @@ TEXT ·Halt(SB),NOSPLIT,$0
BNE mmio_exit
MOVD $0, CPU_REGISTERS+PTRACE_R9(RSV_REG)
mmio_exit:
+ // Disable fpsimd.
+ WORD $0xd5381041 // MRS CPACR_EL1, R1
+ MOVD R1, CPU_LAZY_VFP(RSV_REG)
+ VFP_DISABLE
+
// MMIO_EXIT.
MOVD $0, R9
MOVD R0, 0xffff000000001000(R9)
@@ -337,6 +357,73 @@ TEXT ·Current(SB),NOSPLIT,$0-8
#define STACK_FRAME_SIZE 16
TEXT ·kernelExitToEl0(SB),NOSPLIT,$0
+ // Step1, save sentry context into memory.
+ REGISTERS_SAVE(RSV_REG, CPU_REGISTERS)
+ MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG)
+
+ WORD $0xd5384003 // MRS SPSR_EL1, R3
+ MOVD R3, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG)
+ MOVD R30, CPU_REGISTERS+PTRACE_PC(RSV_REG)
+ MOVD RSP, R3
+ MOVD R3, CPU_REGISTERS+PTRACE_SP(RSV_REG)
+
+ MOVD CPU_REGISTERS+PTRACE_R3(RSV_REG), R3
+
+ // Step2, save SP_EL1, PSTATE into kernel temporary stack.
+ // switch to temporary stack.
+ LOAD_KERNEL_STACK(RSV_REG)
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+
+ SUB $STACK_FRAME_SIZE, RSP, RSP
+ MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R11
+ MOVD CPU_REGISTERS+PTRACE_PSTATE(RSV_REG), R12
+ STP (R11, R12), 16*0(RSP)
+
+ MOVD CPU_REGISTERS+PTRACE_R11(RSV_REG), R11
+ MOVD CPU_REGISTERS+PTRACE_R12(RSV_REG), R12
+
+ // Step3, test user pagetable.
+ // If the user pagetable is empty, this traps into el1_ia.
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ SWITCH_TO_APP_PAGETABLE(RSV_REG)
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ SWITCH_TO_KVM_PAGETABLE(RSV_REG)
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+
+ // If the pagetable is not empty, recover the kernel temporary stack.
+ ADD $STACK_FRAME_SIZE, RSP, RSP
+
+ // Step4, load app context pointer.
+ MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP
+
+ // Step5, prepare the environment for the container application.
+ // set sp_el0.
+ MOVD PTRACE_SP(RSV_REG_APP), R1
+ WORD $0xd5184101 //MSR R1, SP_EL0
+ // set pc.
+ MOVD PTRACE_PC(RSV_REG_APP), R1
+ MSR R1, ELR_EL1
+ // set pstate.
+ MOVD PTRACE_PSTATE(RSV_REG_APP), R1
+ WORD $0xd5184001 //MSR R1, SPSR_EL1
+
+ // RSV_REG & RSV_REG_APP will be loaded at the end.
+ REGISTERS_LOAD(RSV_REG_APP, 0)
+
+ // switch to user pagetable.
+ MOVD PTRACE_R18(RSV_REG_APP), RSV_REG
+ MOVD PTRACE_R9(RSV_REG_APP), RSV_REG_APP
+
+ SUB $STACK_FRAME_SIZE, RSP, RSP
+ STP (RSV_REG, RSV_REG_APP), 16*0(RSP)
+
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+
+ SWITCH_TO_APP_PAGETABLE(RSV_REG)
+
+ LDP 16*0(RSP), (RSV_REG, RSV_REG_APP)
+ ADD $STACK_FRAME_SIZE, RSP, RSP
+
ERET()
TEXT ·kernelExitToEl1(SB),NOSPLIT,$0
@@ -382,6 +469,8 @@ TEXT ·El1_sync(SB),NOSPLIT,$0
BEQ el1_svc
CMP $ESR_ELx_EC_BREAKPT_CUR, R24
BGE el1_dbg
+ CMP $ESR_ELx_EC_FP_ASIMD, R24
+ BEQ el1_fpsimd_acc
B el1_invalid
el1_da:
@@ -402,6 +491,10 @@ el1_svc:
el1_dbg:
B ·Shutdown(SB)
+el1_fpsimd_acc:
+ VFP_ENABLE
+ B ·kernelExitToEl1(SB) // Resume.
+
el1_invalid:
B ·Shutdown(SB)
@@ -528,6 +621,10 @@ TEXT ·Vectors(SB),NOSPLIT,$0
B ·El0_error_invalid(SB)
nop31Instructions()
+ // The exception vector table is required to be 11-bit (2KB) aligned.
+ // See the Linux source as a reference: arch/arm64/kernel/entry.S.
+ // For gVisor, the table is defined as 4K in length, with the 2nd 2K part
+ // filled with NOPs, so that the 1st 2K part can safely be moved to an
+ // address with 11-bit alignment.
WORD $0xd503201f //nop
nop31Instructions()
WORD $0xd503201f
diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go
index 900ee6380..8bcfe1032 100644
--- a/pkg/sentry/platform/ring0/lib_arm64.go
+++ b/pkg/sentry/platform/ring0/lib_arm64.go
@@ -16,10 +16,24 @@
package ring0
-// LoadFloatingPoint loads floating point state by the most efficient mechanism
-// available (set by Init).
-var LoadFloatingPoint func(*byte)
+// CPACREL1 returns the value of the CPACR_EL1 register.
+func CPACREL1() (value uintptr)
-// SaveFloatingPoint saves floating point state by the most efficient mechanism
-// available (set by Init).
-var SaveFloatingPoint func(*byte)
+// GetFPCR returns the value of the FPCR register.
+func GetFPCR() (value uintptr)
+
+// SetFPCR writes the FPCR value.
+func SetFPCR(value uintptr)
+
+// GetFPSR returns the value of the FPSR register.
+func GetFPSR() (value uintptr)
+
+// SetFPSR writes the FPSR value.
+func SetFPSR(value uintptr)
+
+// SaveVRegs saves V0-V31 registers.
+// V0-V31: 32 128-bit registers for floating point and SIMD.
+func SaveVRegs(*byte)
+
+// LoadVRegs loads V0-V31 registers.
+func LoadVRegs(*byte)
diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s
new file mode 100644
index 000000000..0e6a6235b
--- /dev/null
+++ b/pkg/sentry/platform/ring0/lib_arm64.s
@@ -0,0 +1,121 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "funcdata.h"
+#include "textflag.h"
+
+TEXT ·CPACREL1(SB),NOSPLIT,$0-8
+ WORD $0xd5381041 // MRS CPACR_EL1, R1
+ MOVD R1, ret+0(FP)
+ RET
+
+TEXT ·GetFPCR(SB),NOSPLIT,$0-8
+ WORD $0xd53b4201 // MRS NZCV, R1
+ MOVD R1, ret+0(FP)
+ RET
+
+TEXT ·GetFPSR(SB),NOSPLIT,$0-8
+ WORD $0xd53b4421 // MRS FPSR, R1
+ MOVD R1, ret+0(FP)
+ RET
+
+TEXT ·SetFPCR(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R1
+ WORD $0xd51b4201 // MSR R1, NZCV
+ RET
+
+TEXT ·SetFPSR(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R1
+ WORD $0xd51b4421 // MSR R1, FPSR
+ RET
+
+TEXT ·SaveVRegs(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R0
+
+ // Skip aarch64_ctx, fpsr, fpcr.
+ FMOVD F0, 16*1(R0)
+ FMOVD F1, 16*2(R0)
+ FMOVD F2, 16*3(R0)
+ FMOVD F3, 16*4(R0)
+ FMOVD F4, 16*5(R0)
+ FMOVD F5, 16*6(R0)
+ FMOVD F6, 16*7(R0)
+ FMOVD F7, 16*8(R0)
+ FMOVD F8, 16*9(R0)
+ FMOVD F9, 16*10(R0)
+ FMOVD F10, 16*11(R0)
+ FMOVD F11, 16*12(R0)
+ FMOVD F12, 16*13(R0)
+ FMOVD F13, 16*14(R0)
+ FMOVD F14, 16*15(R0)
+ FMOVD F15, 16*16(R0)
+ FMOVD F16, 16*17(R0)
+ FMOVD F17, 16*18(R0)
+ FMOVD F18, 16*19(R0)
+ FMOVD F19, 16*20(R0)
+ FMOVD F20, 16*21(R0)
+ FMOVD F21, 16*22(R0)
+ FMOVD F22, 16*23(R0)
+ FMOVD F23, 16*24(R0)
+ FMOVD F24, 16*25(R0)
+ FMOVD F25, 16*26(R0)
+ FMOVD F26, 16*27(R0)
+ FMOVD F27, 16*28(R0)
+ FMOVD F28, 16*29(R0)
+ FMOVD F29, 16*30(R0)
+ FMOVD F30, 16*31(R0)
+ FMOVD F31, 16*32(R0)
+ ISB $15
+
+ RET
+
+TEXT ·LoadVRegs(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R0
+
+ // Skip aarch64_ctx, fpsr, fpcr.
+ FMOVD 16*1(R0), F0
+ FMOVD 16*2(R0), F1
+ FMOVD 16*3(R0), F2
+ FMOVD 16*4(R0), F3
+ FMOVD 16*5(R0), F4
+ FMOVD 16*6(R0), F5
+ FMOVD 16*7(R0), F6
+ FMOVD 16*8(R0), F7
+ FMOVD 16*9(R0), F8
+ FMOVD 16*10(R0), F9
+ FMOVD 16*11(R0), F10
+ FMOVD 16*12(R0), F11
+ FMOVD 16*13(R0), F12
+ FMOVD 16*14(R0), F13
+ FMOVD 16*15(R0), F14
+ FMOVD 16*16(R0), F15
+ FMOVD 16*17(R0), F16
+ FMOVD 16*18(R0), F17
+ FMOVD 16*19(R0), F18
+ FMOVD 16*20(R0), F19
+ FMOVD 16*21(R0), F20
+ FMOVD 16*22(R0), F21
+ FMOVD 16*23(R0), F22
+ FMOVD 16*24(R0), F23
+ FMOVD 16*25(R0), F24
+ FMOVD 16*26(R0), F25
+ FMOVD 16*27(R0), F26
+ FMOVD 16*28(R0), F27
+ FMOVD 16*29(R0), F28
+ FMOVD 16*30(R0), F29
+ FMOVD 16*31(R0), F30
+ FMOVD 16*32(R0), F31
+ ISB $15
+
+ RET
diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go
index d7aa1c7cc..cd2a65f97 100644
--- a/pkg/sentry/platform/ring0/offsets_arm64.go
+++ b/pkg/sentry/platform/ring0/offsets_arm64.go
@@ -39,6 +39,7 @@ func Emit(w io.Writer) {
fmt.Fprintf(w, "#define CPU_TTBR0_APP 0x%02x\n", reflect.ValueOf(&c.ttbr0App).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "#define CPU_VECTOR_CODE 0x%02x\n", reflect.ValueOf(&c.vecCode).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "#define CPU_APP_ADDR 0x%02x\n", reflect.ValueOf(&c.appAddr).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_LAZY_VFP 0x%02x\n", reflect.ValueOf(&c.lazyVFP).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "\n// Bits.\n")
fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet)
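
// Sketch of the offset-emission technique Emit uses above: a field's offset
// is the difference between the field's address and the struct's base
// address, printed as an assembly #define so .s files can reference Go struct
// layout. The struct and macro names here are illustrative.
package main

import (
	"fmt"
	"io"
	"os"
	"reflect"
)

type cpuState struct {
	appAddr uintptr
	lazyVFP uintptr
}

func emitOffsets(w io.Writer) {
	c := &cpuState{}
	base := reflect.ValueOf(c).Pointer()
	fmt.Fprintf(w, "#define CPU_APP_ADDR 0x%02x\n", reflect.ValueOf(&c.appAddr).Pointer()-base)
	fmt.Fprintf(w, "#define CPU_LAZY_VFP 0x%02x\n", reflect.ValueOf(&c.lazyVFP).Pointer()-base)
}

func main() {
	emitOffsets(os.Stdout)
}
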
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index af1a4e95f..4301b697c 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -471,6 +471,9 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
case linux.SOL_IP:
switch h.Type {
case linux.IP_TOS:
+ if length < linux.SizeOfControlMessageTOS {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
cmsgs.IP.HasTOS = true
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS)
i += AlignUp(length, width)
@@ -481,6 +484,9 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
case linux.SOL_IPV6:
switch h.Type {
case linux.IPV6_TCLASS:
+ if length < linux.SizeOfControlMessageTClass {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
cmsgs.IP.HasTClass = true
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass)
i += AlignUp(length, width)
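
// Sketch of the length check added for IP_TOS and IPV6_TCLASS above: before
// unmarshaling a fixed-size control-message payload, the declared length must
// cover that payload; otherwise the parse fails with EINVAL instead of
// reading past the message. Sizes and the error value are illustrative
// stand-ins for the sentry's constants.
package main

import (
	"encoding/binary"
	"errors"
	"fmt"
)

var errInvalid = errors.New("EINVAL")

// parseTClass decodes a 4-byte traffic-class payload, rejecting short
// control messages.
func parseTClass(payload []byte, declaredLen int) (uint32, error) {
	const sizeOfTClass = 4
	if declaredLen < sizeOfTClass || len(payload) < sizeOfTClass {
		return 0, errInvalid
	}
	return binary.LittleEndian.Uint32(payload[:sizeOfTClass]), nil
}

func main() {
	_, err := parseTClass([]byte{1, 2}, 2)
	fmt.Println(err) // EINVAL
}
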
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index 79589e3c8..136821963 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -22,7 +22,6 @@ go_library(
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
- "//pkg/sentry/safemem",
"//pkg/sentry/socket",
"//pkg/sentry/socket/netlink/port",
"//pkg/sentry/socket/unix",
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 4a1b87a9a..d2e3644a6 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -29,7 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink/port"
"gvisor.dev/gvisor/pkg/sentry/socket/unix"
@@ -500,29 +499,29 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have
trunc := flags&linux.MSG_TRUNC != 0
r := unix.EndpointReader{
+ Ctx: t,
Endpoint: s.ep,
Peek: flags&linux.MSG_PEEK != 0,
}
+ doRead := func() (int64, error) {
+ return dst.CopyOutFrom(t, &r)
+ }
+
// If MSG_TRUNC is set with a zero byte destination then we still need
// to read the message and discard it, or in the case where MSG_PEEK is
// set, leave it be. In both cases the full message length must be
- // returned. However, the memory manager for the destination will not read
- // the endpoint if the destination is zero length.
- //
- // In order for the endpoint to be read when the destination size is zero,
- // we must cause a read of the endpoint by using a separate fake zero
- // length block sequence and calling the EndpointReader directly.
+ // returned.
if trunc && dst.Addrs.NumBytes() == 0 {
- // Perform a read to a zero byte block sequence. We can ignore the
- // original destination since it was zero bytes. The length returned by
- // ReadToBlocks is ignored and we return the full message length to comply
- // with MSG_TRUNC.
- _, err := r.ReadToBlocks(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(make([]byte, 0))))
- return int(r.MsgSize), linux.MSG_TRUNC, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
+ doRead = func() (int64, error) {
+ err := r.Truncate()
+ // Always return zero for bytes read since the destination size is
+ // zero.
+ return 0, err
+ }
}
- if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ if n, err := doRead(); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
var mflags int
if n < int64(r.MsgSize) {
mflags |= linux.MSG_TRUNC
@@ -540,7 +539,7 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have
defer s.EventUnregister(&e)
for {
- if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock {
+ if n, err := doRead(); err != syserror.ErrWouldBlock {
var mflags int
if n < int64(r.MsgSize) {
mflags |= linux.MSG_TRUNC
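
// Sketch of the doRead indirection introduced above: the receive path picks
// its strategy once (normal copy-out, or a truncating read when MSG_TRUNC is
// set with a zero-byte destination) and both the fast path and the blocking
// retry loop call the same closure. The reader type is illustrative.
package main

import "fmt"

type reader struct{ msgSize int64 }

func (r *reader) copyOut() (int64, error) { return r.msgSize, nil }
func (r *reader) truncate() error         { return nil }

func recv(r *reader, truncOnly bool) (int64, error) {
	doRead := func() (int64, error) { return r.copyOut() }
	if truncOnly {
		doRead = func() (int64, error) {
			// Consume the message but report zero bytes copied; the caller
			// still learns the full length from r.msgSize for MSG_TRUNC.
			return 0, r.truncate()
		}
	}
	return doRead()
}

func main() {
	n, _ := recv(&reader{msgSize: 128}, true)
	fmt.Println(n) // 0
}
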
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 140851c17..764f11a6b 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -222,17 +222,25 @@ type commonEndpoint interface {
// transport.Endpoint.SetSockOpt.
SetSockOpt(interface{}) *tcpip.Error
+ // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and
+ // transport.Endpoint.SetSockOptBool.
+ SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
+
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and
// transport.Endpoint.SetSockOptInt.
- SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+ SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
// GetSockOpt implements tcpip.Endpoint.GetSockOpt and
// transport.Endpoint.GetSockOpt.
GetSockOpt(interface{}) *tcpip.Error
+ // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and
+ // transport.Endpoint.GetSockOptBool.
+ GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and
// transport.Endpoint.GetSockOpt.
- GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error)
+ GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
}
// SocketOperations encapsulates all the state needed to represent a network stack
@@ -977,13 +985,23 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- if len(v) == 0 {
+ if v == 0 {
return []byte{}, nil
}
if outLen < linux.IFNAMSIZ {
return nil, syserr.ErrInvalidArgument
}
- return append([]byte(v), 0), nil
+ s := t.NetworkContext()
+ if s == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ nic, ok := s.Interfaces()[int32(v)]
+ if !ok {
+ // The NICID no longer indicates a valid interface, probably because that
+ // interface was removed.
+ return nil, syserr.ErrUnknownDevice
+ }
+ return append([]byte(nic.Name), 0), nil
case linux.SO_BROADCAST:
if outLen < sizeOfInt32 {
@@ -1213,12 +1231,15 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.V6OnlyOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.V6OnlyOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ var o uint32
+ if v {
+ o = 1
+ }
+ return int32(o), nil
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1427,7 +1448,20 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
if n == -1 {
n = len(optVal)
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n])))
+ name := string(optVal[:n])
+ if name == "" {
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(0)))
+ }
+ s := t.NetworkContext()
+ if s == nil {
+ return syserr.ErrNoDevice
+ }
+ for nicID, nic := range s.Interfaces() {
+ if nic.Name == name {
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(nicID)))
+ }
+ }
+ return syserr.ErrUnknownDevice
case linux.SO_BROADCAST:
if len(optVal) < sizeOfInt32 {
@@ -1621,7 +1655,7 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.V6OnlyOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0))
case linux.IPV6_ADD_MEMBERSHIP,
linux.IPV6_DROP_MEMBERSHIP,
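
// Loose, standalone analogy for the SO_BINDTODEVICE hunks above, which map a
// device name to a NIC ID on set and an ID back to a name on get against the
// sentry's interface table. Here the host's net package plays the role of
// that table; this is illustrative only.
package main

import (
	"fmt"
	"net"
)

// indexForName resolves a device name to an interface index, treating the
// empty name as "clear the binding" and unknown names as errors, as the
// setter above does.
func indexForName(name string) (int, error) {
	if name == "" {
		return 0, nil
	}
	ifi, err := net.InterfaceByName(name)
	if err != nil {
		return 0, fmt.Errorf("unknown device %q: %v", name, err)
	}
	return ifi.Index, nil
}

func main() {
	idx, err := indexForName("lo")
	fmt.Println(idx, err)
}
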
diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go
index 2ec1a662d..2447f24ef 100644
--- a/pkg/sentry/socket/unix/io.go
+++ b/pkg/sentry/socket/unix/io.go
@@ -83,6 +83,19 @@ type EndpointReader struct {
ControlTrunc bool
}
+// Truncate calls RecvMsg on the endpoint without writing to a destination.
+func (r *EndpointReader) Truncate() error {
+ // Ignore bytes read since it will always be zero.
+ _, ms, c, ct, err := r.Endpoint.RecvMsg(r.Ctx, [][]byte{}, r.Creds, r.NumRights, r.Peek, r.From)
+ r.Control = c
+ r.ControlTrunc = ct
+ r.MsgSize = ms
+ if err != nil {
+ return err.ToError()
+ }
+ return nil
+}
+
// ReadToBlocks implements safemem.Reader.ReadToBlocks.
func (r *EndpointReader) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
return safemem.FromVecReaderFunc{func(bufs [][]byte) (int64, error) {
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 529a7a7a9..37c7ac3c1 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -175,17 +175,25 @@ type Endpoint interface {
// types.
SetSockOpt(opt interface{}) *tcpip.Error
+ // SetSockOptBool sets a socket option for simple cases when a value has
+ // the int type.
+ SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
+
// SetSockOptInt sets a socket option for simple cases when a value has
// the int type.
- SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+ SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// tcpip.*Option types.
GetSockOpt(opt interface{}) *tcpip.Error
+ // GetSockOptBool gets a socket option for simple cases when a return
+ // value has the int type.
+ GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
+
// GetSockOptInt gets a socket option for simple cases when a return
// value has the int type.
- GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error)
+ GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
// State returns the current state of the socket, as represented by Linux in
// procfs.
@@ -851,11 +859,19 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
-func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
return nil
}
-func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ return nil
+}
+
+func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ return false, tcpip.ErrUnknownProtocolOption
+}
+
+func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
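
// Sketch of the typed option split introduced above: simple boolean and
// integer options get dedicated accessors that avoid interface{} boxing,
// while structured options keep the generic path. The option names and
// endpoint type are illustrative.
package main

import (
	"errors"
	"fmt"
)

type SockOptBool int
type SockOptInt int

const V6Only SockOptBool = iota
const ReceiveQueueSize SockOptInt = iota

var errUnknownOption = errors.New("unknown protocol option")

type endpoint struct {
	v6only bool
	rcvQ   int
}

func (e *endpoint) SetSockOptBool(opt SockOptBool, v bool) error {
	switch opt {
	case V6Only:
		e.v6only = v
		return nil
	default:
		return errUnknownOption
	}
}

func (e *endpoint) GetSockOptInt(opt SockOptInt) (int, error) {
	switch opt {
	case ReceiveQueueSize:
		return e.rcvQ, nil
	default:
		return 0, errUnknownOption
	}
}

func main() {
	e := &endpoint{}
	fmt.Println(e.SetSockOptBool(V6Only, true)) // <nil>
}
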
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 885758054..91effe89a 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -544,8 +544,27 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
if senderRequested {
r.From = &tcpip.FullAddress{}
}
+
+ doRead := func() (int64, error) {
+ return dst.CopyOutFrom(t, &r)
+ }
+
+ // If MSG_TRUNC is set with a zero byte destination then we still need
+ // to read the message and discard it, or in the case where MSG_PEEK is
+ // set, leave it be. In both cases the full message length must be
+ // returned.
+ if trunc && dst.Addrs.NumBytes() == 0 {
+ doRead = func() (int64, error) {
+ err := r.Truncate()
+ // Always return zero for bytes read since the destination size is
+ // zero.
+ return 0, err
+ }
+
+ }
+
var total int64
- if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait {
+ if n, err := doRead(); err != syserror.ErrWouldBlock || dontWait {
var from linux.SockAddr
var fromLen uint32
if r.From != nil && len([]byte(r.From.Addr)) != 0 {
@@ -580,7 +599,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
defer s.EventUnregister(&e)
for {
- if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock {
+ if n, err := doRead(); err != syserror.ErrWouldBlock {
var from linux.SockAddr
var fromLen uint32
if r.From != nil {
diff --git a/pkg/sentry/strace/linux64_amd64.go b/pkg/sentry/strace/linux64_amd64.go
index 9fa2f0e16..1e823b685 100644
--- a/pkg/sentry/strace/linux64_amd64.go
+++ b/pkg/sentry/strace/linux64_amd64.go
@@ -378,5 +378,7 @@ func init() {
syscallTable{
os: abi.Linux,
arch: arch.AMD64,
- syscalls: linuxAMD64})
+ syscalls: linuxAMD64,
+ },
+ )
}
diff --git a/pkg/sentry/strace/select.go b/pkg/sentry/strace/select.go
index dea309fda..c77d418e6 100644
--- a/pkg/sentry/strace/select.go
+++ b/pkg/sentry/strace/select.go
@@ -36,6 +36,9 @@ func fdsFromSet(t *kernel.Task, set []byte) []int {
}
func fdSet(t *kernel.Task, nfds int, addr usermem.Addr) string {
+ if nfds < 0 {
+ return fmt.Sprintf("%#x (negative nfds)", addr)
+ }
if addr == 0 {
return "null"
}
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index 6766ba587..a76975cee 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -30,6 +30,7 @@ go_library(
"sys_random.go",
"sys_read.go",
"sys_rlimit.go",
+ "sys_rseq.go",
"sys_rusage.go",
"sys_sched.go",
"sys_seccomp.go",
diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go
index 272ae9991..479c5f6ff 100644
--- a/pkg/sentry/syscalls/linux/linux64_amd64.go
+++ b/pkg/sentry/syscalls/linux/linux64_amd64.go
@@ -377,7 +377,7 @@ var AMD64 = &kernel.SyscallTable{
331: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil),
332: syscalls.Supported("statx", Statx),
333: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil),
- 334: syscalls.ErrorWithEvent("rseq", syserror.ENOSYS, "", nil),
+ 334: syscalls.PartiallySupported("rseq", RSeq, "Not supported on all platforms.", nil),
// Linux skips ahead to syscall 424 to sync numbers between arches.
424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil),
diff --git a/pkg/sentry/syscalls/linux/linux64_arm64.go b/pkg/sentry/syscalls/linux/linux64_arm64.go
index 3b584eed9..d3f61f5e8 100644
--- a/pkg/sentry/syscalls/linux/linux64_arm64.go
+++ b/pkg/sentry/syscalls/linux/linux64_arm64.go
@@ -307,7 +307,7 @@ var ARM64 = &kernel.SyscallTable{
290: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil),
291: syscalls.Supported("statx", Statx),
292: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil),
- 293: syscalls.ErrorWithEvent("rseq", syserror.ENOSYS, "", nil),
+ 293: syscalls.PartiallySupported("rseq", RSeq, "Not supported on all platforms.", nil),
// Linux skips ahead to syscall 424 to sync numbers between arches.
424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil),
diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go
index b9bd25464..bde17a767 100644
--- a/pkg/sentry/syscalls/linux/sys_futex.go
+++ b/pkg/sentry/syscalls/linux/sys_futex.go
@@ -226,6 +226,11 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if mask == 0 {
return 0, nil, syserror.EINVAL
}
+ if val <= 0 {
+ // The Linux kernel wakes one waiter even if val is
+ // non-positive.
+ val = 1
+ }
n, err := t.Futex().Wake(t, addr, private, mask, val)
return uintptr(n), nil, err
@@ -242,6 +247,11 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.FUTEX_WAKE_OP:
op := uint32(val3)
+ if val <= 0 {
+ // The Linux kernel wakes one waiter even if val is
+ // non-positive.
+ val = 1
+ }
n, err := t.Futex().WakeOp(t, addr, naddr, private, val, nreq, op)
return uintptr(n), nil, err
diff --git a/pkg/sentry/syscalls/linux/sys_rseq.go b/pkg/sentry/syscalls/linux/sys_rseq.go
new file mode 100644
index 000000000..90db10ea6
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_rseq.go
@@ -0,0 +1,48 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// RSeq implements syscall rseq(2).
+func RSeq(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].Uint()
+ flags := args[2].Int()
+ signature := args[3].Uint()
+
+ if !t.RSeqAvailable() {
+ // Event for applications that want rseq on a configuration
+ // that doesn't support it.
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, syserror.ENOSYS
+ }
+
+ switch flags {
+ case 0:
+ // Register.
+ return 0, nil, t.SetRSeq(addr, length, signature)
+ case linux.RSEQ_FLAG_UNREGISTER:
+ return 0, nil, t.ClearRSeq(addr, length, signature)
+ default:
+ // Unknown flag.
+ return 0, nil, syserror.EINVAL
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/sys_shm.go b/pkg/sentry/syscalls/linux/sys_shm.go
index d57ffb3a1..4a8bc24a2 100644
--- a/pkg/sentry/syscalls/linux/sys_shm.go
+++ b/pkg/sentry/syscalls/linux/sys_shm.go
@@ -39,10 +39,13 @@ func Shmget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, err
}
+ defer segment.DecRef()
return uintptr(segment.ID), nil, nil
}
// findSegment retrieves a shm segment by the given id.
+//
+// findSegment returns a reference on Shm.
func findSegment(t *kernel.Task, id shm.ID) (*shm.Shm, error) {
r := t.IPCNamespace().ShmRegistry()
segment := r.FindByID(id)
@@ -63,6 +66,7 @@ func Shmat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if err != nil {
return 0, nil, syserror.EINVAL
}
+ defer segment.DecRef()
opts, err := segment.ConfigureAttach(t, addr, shm.AttachOpts{
Execute: flag&linux.SHM_EXEC == linux.SHM_EXEC,
@@ -72,7 +76,6 @@ func Shmat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if err != nil {
return 0, nil, err
}
- defer segment.DecRef()
addr, err = t.MemoryManager().MMap(t, opts)
return uintptr(addr), nil, err
}
@@ -105,6 +108,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, syserror.EINVAL
}
+ defer segment.DecRef()
stat, err := segment.IPCStat(t)
if err == nil {
@@ -128,6 +132,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, syserror.EINVAL
}
+ defer segment.DecRef()
switch cmd {
case linux.IPC_SET:
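
// Sketch of the reference-counting discipline the sys_shm.go hunks enforce
// above: a lookup that returns an object also returns a reference, and every
// caller releases it with a deferred DecRef immediately after the lookup so
// all exit paths are covered. The registry and segment types are
// illustrative.
package main

import (
	"errors"
	"fmt"
	"sync/atomic"
)

type segment struct {
	id   int
	refs int64
}

func (s *segment) IncRef() { atomic.AddInt64(&s.refs, 1) }
func (s *segment) DecRef() { atomic.AddInt64(&s.refs, -1) }

type registry struct{ segs map[int]*segment }

// findSegment returns a reference on the segment; the caller must DecRef it.
func (r *registry) findSegment(id int) (*segment, error) {
	s, ok := r.segs[id]
	if !ok {
		return nil, errors.New("EINVAL")
	}
	s.IncRef()
	return s, nil
}

func stat(r *registry, id int) error {
	s, err := r.findSegment(id)
	if err != nil {
		return err
	}
	defer s.DecRef() // released on every return path below.
	fmt.Println("segment", s.id)
	return nil
}

func main() {
	r := &registry{segs: map[int]*segment{1: {id: 1, refs: 1}}}
	_ = stat(r, 1)
}
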
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index e3e554b88..4c6aa04a1 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -9,6 +9,7 @@ go_library(
"context.go",
"debug.go",
"dentry.go",
+ "device.go",
"file_description.go",
"file_description_impl_util.go",
"filesystem.go",
diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go
index 40f4c1d09..1bc9c4a38 100644
--- a/pkg/sentry/vfs/dentry.go
+++ b/pkg/sentry/vfs/dentry.go
@@ -85,12 +85,12 @@ type Dentry struct {
// mounts is accessed using atomic memory operations.
mounts uint32
- // mu synchronizes disowning and mounting over this Dentry.
- mu sync.Mutex
-
// children are child Dentries.
children map[string]*Dentry
+ // mu synchronizes disowning and mounting over this Dentry.
+ mu sync.Mutex
+
// impl is the DentryImpl associated with this Dentry. impl is immutable.
// This should be the last field in Dentry.
impl DentryImpl
@@ -199,6 +199,18 @@ func (d *Dentry) HasChildren() bool {
return len(d.children) != 0
}
+// Children returns a map containing all of d's children.
+func (d *Dentry) Children() map[string]*Dentry {
+ if !d.HasChildren() {
+ return nil
+ }
+ m := make(map[string]*Dentry)
+ for name, child := range d.children {
+ m[name] = child
+ }
+ return m
+}
+
// InsertChild makes child a child of d with the given name.
//
// InsertChild is a mutator of d and child.
@@ -222,6 +234,18 @@ func (d *Dentry) InsertChild(child *Dentry, name string) {
child.name = name
}
+// IsAncestorOf returns true if d is an ancestor of d2; that is, d is either
+// d2's parent or an ancestor of d2's parent.
+func (d *Dentry) IsAncestorOf(d2 *Dentry) bool {
+ for d2.parent != nil {
+ if d2.parent == d {
+ return true
+ }
+ d2 = d2.parent
+ }
+ return false
+}
+
// PrepareDeleteDentry must be called before attempting to delete the file
// represented by d. If PrepareDeleteDentry succeeds, the caller must call
// AbortDeleteDentry or CommitDeleteDentry depending on the deletion's outcome.
@@ -271,21 +295,6 @@ func (vfs *VirtualFilesystem) CommitDeleteDentry(d *Dentry) {
}
}
-// DeleteDentry combines PrepareDeleteDentry and CommitDeleteDentry, as
-// appropriate for in-memory filesystems that don't need to ensure that some
-// external state change succeeds before committing the deletion.
-//
-// DeleteDentry is a mutator of d and d.Parent().
-//
-// Preconditions: d is a child Dentry.
-func (vfs *VirtualFilesystem) DeleteDentry(mntns *MountNamespace, d *Dentry) error {
- if err := vfs.PrepareDeleteDentry(mntns, d); err != nil {
- return err
- }
- vfs.CommitDeleteDentry(d)
- return nil
-}
-
// ForceDeleteDentry causes d to become disowned. It should only be used in
// cases where VFS has no ability to stop the deletion (e.g. d represents the
// local state of a file on a remote filesystem on which the file has already
@@ -314,7 +323,7 @@ func (vfs *VirtualFilesystem) ForceDeleteDentry(d *Dentry) {
// CommitRenameExchangeDentry depending on the rename's outcome.
//
// Preconditions: from is a child Dentry. If to is not nil, it must be a child
-// Dentry from the same Filesystem.
+// Dentry from the same Filesystem. from != to.
func (vfs *VirtualFilesystem) PrepareRenameDentry(mntns *MountNamespace, from, to *Dentry) error {
if checkInvariants {
if from.parent == nil {
diff --git a/pkg/sentry/vfs/device.go b/pkg/sentry/vfs/device.go
new file mode 100644
index 000000000..cb672e36f
--- /dev/null
+++ b/pkg/sentry/vfs/device.go
@@ -0,0 +1,100 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// DeviceKind indicates whether a device is a block or character device.
+type DeviceKind uint32
+
+const (
+ // BlockDevice indicates a block device.
+ BlockDevice DeviceKind = iota
+
+ // CharDevice indicates a character device.
+ CharDevice
+)
+
+// String implements fmt.Stringer.String.
+func (kind DeviceKind) String() string {
+ switch kind {
+ case BlockDevice:
+ return "block"
+ case CharDevice:
+ return "character"
+ default:
+ return fmt.Sprintf("invalid device kind %d", kind)
+ }
+}
+
+type devTuple struct {
+ kind DeviceKind
+ major uint32
+ minor uint32
+}
+
+// A Device backs device special files.
+type Device interface {
+ // Open returns a FileDescription representing this device.
+ Open(ctx context.Context, mnt *Mount, d *Dentry, opts OpenOptions) (*FileDescription, error)
+}
+
+type registeredDevice struct {
+ dev Device
+ opts RegisterDeviceOptions
+}
+
+// RegisterDeviceOptions contains options to
+// VirtualFilesystem.RegisterDevice().
+type RegisterDeviceOptions struct {
+ // GroupName is the name shown for this device registration in
+ // /proc/devices. If GroupName is empty, this registration will not be
+ // shown in /proc/devices.
+ GroupName string
+}
+
+// RegisterDevice registers the given Device in vfs with the given major and
+// minor device numbers.
+func (vfs *VirtualFilesystem) RegisterDevice(kind DeviceKind, major, minor uint32, dev Device, opts *RegisterDeviceOptions) error {
+ tup := devTuple{kind, major, minor}
+ vfs.devicesMu.Lock()
+ defer vfs.devicesMu.Unlock()
+ if existing, ok := vfs.devices[tup]; ok {
+ return fmt.Errorf("%s device number (%d, %d) is already registered to device type %T", kind, major, minor, existing.dev)
+ }
+ vfs.devices[tup] = &registeredDevice{
+ dev: dev,
+ opts: *opts,
+ }
+ return nil
+}
+
+// OpenDeviceSpecialFile returns a FileDescription representing the given
+// device.
+func (vfs *VirtualFilesystem) OpenDeviceSpecialFile(ctx context.Context, mnt *Mount, d *Dentry, kind DeviceKind, major, minor uint32, opts *OpenOptions) (*FileDescription, error) {
+ tup := devTuple{kind, major, minor}
+ vfs.devicesMu.RLock()
+ defer vfs.devicesMu.RUnlock()
+ rd, ok := vfs.devices[tup]
+ if !ok {
+ return nil, syserror.ENXIO
+ }
+ return rd.dev.Open(ctx, mnt, d, *opts)
+}
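
// Hedged usage sketch for the registry above: a device implements vfs.Device
// and is registered under a (kind, major, minor) tuple; opening the matching
// device special file dispatches to it. The nullDevice type, the (1, 3)
// numbers, and the pkg/sentry/vfs import path are illustrative assumptions,
// not part of this change.
package nulldev

import (
	"gvisor.dev/gvisor/pkg/sentry/context"
	"gvisor.dev/gvisor/pkg/sentry/vfs"
	"gvisor.dev/gvisor/pkg/syserror"
)

// nullDevice is a stand-in device that refuses to be opened.
type nullDevice struct{}

// Open implements vfs.Device.Open.
func (nullDevice) Open(ctx context.Context, mnt *vfs.Mount, d *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
	return nil, syserror.ENXIO
}

// Register installs the device as character device (1, 3) under the "mem"
// group shown in /proc/devices.
func Register(vfsObj *vfs.VirtualFilesystem) error {
	return vfsObj.RegisterDevice(vfs.CharDevice, 1, 3, nullDevice{}, &vfs.RegisterDeviceOptions{
		GroupName: "mem",
	})
}
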
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 6575afd16..6afe280bc 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -20,8 +20,10 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -38,49 +40,58 @@ type FileDescription struct {
// operations.
refs int64
+ // statusFlags contains status flags, "initialized by open(2) and possibly
+ // modified by fcntl()" - fcntl(2). statusFlags is accessed using atomic
+ // memory operations.
+ statusFlags uint32
+
// vd is the filesystem location at which this FileDescription was opened.
// A reference is held on vd. vd is immutable.
vd VirtualDentry
+ opts FileDescriptionOptions
+
// impl is the FileDescriptionImpl associated with this Filesystem. impl is
// immutable. This should be the last field in FileDescription.
impl FileDescriptionImpl
}
-// Init must be called before first use of fd. It takes ownership of references
-// on mnt and d held by the caller.
-func (fd *FileDescription) Init(impl FileDescriptionImpl, mnt *Mount, d *Dentry) {
+// FileDescriptionOptions contains options to FileDescription.Init().
+type FileDescriptionOptions struct {
+ // If AllowDirectIO is true, allow O_DIRECT to be set on the file. This is
+ // usually only the case if O_DIRECT would actually have an effect.
+ AllowDirectIO bool
+
+ // If UseDentryMetadata is true, calls to FileDescription methods that
+ // interact with file and filesystem metadata (Stat, SetStat, StatFS,
+ // Listxattr, Getxattr, Setxattr, Removexattr) are implemented by calling
+ // the corresponding FilesystemImpl methods instead of the corresponding
+ // FileDescriptionImpl methods.
+ //
+ // UseDentryMetadata is intended for file descriptions that are implemented
+ // outside of individual filesystems, such as pipes, sockets, and device
+ // special files. FileDescriptions for which UseDentryMetadata is true may
+ // embed DentryMetadataFileDescriptionImpl to obtain appropriate
+ // implementations of FileDescriptionImpl methods that should not be
+ // called.
+ UseDentryMetadata bool
+}
+
+// Init must be called before first use of fd. It takes references on mnt and
+// d. statusFlags is the initial file description status flags, which is
+// usually the full set of flags passed to open(2).
+func (fd *FileDescription) Init(impl FileDescriptionImpl, statusFlags uint32, mnt *Mount, d *Dentry, opts *FileDescriptionOptions) {
fd.refs = 1
+ fd.statusFlags = statusFlags | linux.O_LARGEFILE
fd.vd = VirtualDentry{
mount: mnt,
dentry: d,
}
+ fd.vd.IncRef()
+ fd.opts = *opts
fd.impl = impl
}
-// Impl returns the FileDescriptionImpl associated with fd.
-func (fd *FileDescription) Impl() FileDescriptionImpl {
- return fd.impl
-}
-
-// Mount returns the mount on which fd was opened. It does not take a reference
-// on the returned Mount.
-func (fd *FileDescription) Mount() *Mount {
- return fd.vd.mount
-}
-
-// Dentry returns the dentry at which fd was opened. It does not take a
-// reference on the returned Dentry.
-func (fd *FileDescription) Dentry() *Dentry {
- return fd.vd.dentry
-}
-
-// VirtualDentry returns the location at which fd was opened. It does not take
-// a reference on the returned VirtualDentry.
-func (fd *FileDescription) VirtualDentry() VirtualDentry {
- return fd.vd
-}
-
// IncRef increments fd's reference count.
func (fd *FileDescription) IncRef() {
atomic.AddInt64(&fd.refs, 1)
@@ -112,6 +123,82 @@ func (fd *FileDescription) DecRef() {
}
}
+// Mount returns the mount on which fd was opened. It does not take a reference
+// on the returned Mount.
+func (fd *FileDescription) Mount() *Mount {
+ return fd.vd.mount
+}
+
+// Dentry returns the dentry at which fd was opened. It does not take a
+// reference on the returned Dentry.
+func (fd *FileDescription) Dentry() *Dentry {
+ return fd.vd.dentry
+}
+
+// VirtualDentry returns the location at which fd was opened. It does not take
+// a reference on the returned VirtualDentry.
+func (fd *FileDescription) VirtualDentry() VirtualDentry {
+ return fd.vd
+}
+
+// StatusFlags returns file description status flags, as for fcntl(F_GETFL).
+func (fd *FileDescription) StatusFlags() uint32 {
+ return atomic.LoadUint32(&fd.statusFlags)
+}
+
+// SetStatusFlags sets file description status flags, as for fcntl(F_SETFL).
+func (fd *FileDescription) SetStatusFlags(ctx context.Context, creds *auth.Credentials, flags uint32) error {
+ // Compare Linux's fs/fcntl.c:setfl().
+ oldFlags := fd.StatusFlags()
+ // Linux documents this check as "O_APPEND cannot be cleared if the file is
+ // marked as append-only and the file is open for write", which would make
+ // sense. However, the check as actually implemented seems to be "O_APPEND
+ // cannot be changed if the file is marked as append-only".
+ if (flags^oldFlags)&linux.O_APPEND != 0 {
+ stat, err := fd.Stat(ctx, StatOptions{
+ // There is no mask bit for stx_attributes.
+ Mask: 0,
+ // Linux just reads inode::i_flags directly.
+ Sync: linux.AT_STATX_DONT_SYNC,
+ })
+ if err != nil {
+ return err
+ }
+ if (stat.AttributesMask&linux.STATX_ATTR_APPEND != 0) && (stat.Attributes&linux.STATX_ATTR_APPEND != 0) {
+ return syserror.EPERM
+ }
+ }
+ if (flags&linux.O_NOATIME != 0) && (oldFlags&linux.O_NOATIME == 0) {
+ stat, err := fd.Stat(ctx, StatOptions{
+ Mask: linux.STATX_UID,
+ // Linux's inode_owner_or_capable() just reads inode::i_uid
+ // directly.
+ Sync: linux.AT_STATX_DONT_SYNC,
+ })
+ if err != nil {
+ return err
+ }
+ if stat.Mask&linux.STATX_UID == 0 {
+ return syserror.EPERM
+ }
+ if !CanActAsOwner(creds, auth.KUID(stat.UID)) {
+ return syserror.EPERM
+ }
+ }
+ if flags&linux.O_DIRECT != 0 && !fd.opts.AllowDirectIO {
+ return syserror.EINVAL
+ }
+ // TODO(jamieliu): FileDescriptionImpl.SetOAsync()?
+ const settableFlags = linux.O_APPEND | linux.O_ASYNC | linux.O_DIRECT | linux.O_NOATIME | linux.O_NONBLOCK
+ atomic.StoreUint32(&fd.statusFlags, (oldFlags&^settableFlags)|(flags&settableFlags))
+ return nil
+}
+
+// Impl returns the FileDescriptionImpl associated with fd.
+func (fd *FileDescription) Impl() FileDescriptionImpl {
+ return fd.impl
+}
+
// FileDescriptionImpl contains implementation details for an FileDescription.
// Implementations of FileDescriptionImpl should contain their associated
// FileDescription by value as their first field.
@@ -120,6 +207,8 @@ func (fd *FileDescription) DecRef() {
// be interpreted as IDs in the root UserNamespace (i.e. as auth.KUID and
// auth.KGID respectively).
//
+// All methods may return errors not specified.
+//
// FileDescriptionImpl is analogous to Linux's struct file_operations.
type FileDescriptionImpl interface {
// Release is called when the associated FileDescription reaches zero
@@ -131,14 +220,6 @@ type FileDescriptionImpl interface {
// prevent the file descriptor from being closed.
OnClose(ctx context.Context) error
- // StatusFlags returns file description status flags, as for
- // fcntl(F_GETFL).
- StatusFlags(ctx context.Context) (uint32, error)
-
- // SetStatusFlags sets file description status flags, as for
- // fcntl(F_SETFL).
- SetStatusFlags(ctx context.Context, flags uint32) error
-
// Stat returns metadata for the file represented by the FileDescription.
Stat(ctx context.Context, opts StatOptions) (linux.Statx, error)
@@ -156,6 +237,10 @@ type FileDescriptionImpl interface {
// PRead reads from the file into dst, starting at the given offset, and
// returns the number of bytes read. PRead is permitted to return partial
// reads with a nil error.
+ //
+ // Errors:
+ //
+ // - If opts.Flags specifies unsupported options, PRead returns EOPNOTSUPP.
PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error)
// Read is similar to PRead, but does not specify an offset.
@@ -165,6 +250,10 @@ type FileDescriptionImpl interface {
// the number of bytes read; note that POSIX 2.9.7 "Thread Interactions
// with Regular File Operations" requires that all operations that may
// mutate the FileDescription offset are serialized.
+ //
+ // Errors:
+ //
+ // - If opts.Flags specifies unsupported options, Read returns EOPNOTSUPP.
Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error)
// PWrite writes src to the file, starting at the given offset, and returns
@@ -174,6 +263,11 @@ type FileDescriptionImpl interface {
// As in Linux (but not POSIX), if O_APPEND is in effect for the
// FileDescription, PWrite should ignore the offset and append data to the
// end of the file.
+ //
+ // Errors:
+ //
+ // - If opts.Flags specifies unsupported options, PWrite returns
+ // EOPNOTSUPP.
PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error)
// Write is similar to PWrite, but does not specify an offset, which is
@@ -183,6 +277,10 @@ type FileDescriptionImpl interface {
// PWrite that uses a FileDescription offset, to make it possible for
// remote filesystems to implement O_APPEND correctly (i.e. atomically with
// respect to writers outside the scope of VFS).
+ //
+ // Errors:
+ //
+ // - If opts.Flags specifies unsupported options, Write returns EOPNOTSUPP.
Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error)
// IterDirents invokes cb on each entry in the directory represented by the
@@ -212,7 +310,21 @@ type FileDescriptionImpl interface {
// Ioctl implements the ioctl(2) syscall.
Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error)
- // TODO: extended attributes; file locking
+ // Listxattr returns all extended attribute names for the file.
+ Listxattr(ctx context.Context) ([]string, error)
+
+ // Getxattr returns the value associated with the given extended attribute
+ // for the file.
+ Getxattr(ctx context.Context, name string) (string, error)
+
+ // Setxattr changes the value associated with the given extended attribute
+ // for the file.
+ Setxattr(ctx context.Context, opts SetxattrOptions) error
+
+ // Removexattr removes the given extended attribute from the file.
+ Removexattr(ctx context.Context, name string) error
+
+ // TODO: file locking
}
// Dirent holds the information contained in struct linux_dirent64.
@@ -249,31 +361,49 @@ func (fd *FileDescription) OnClose(ctx context.Context) error {
return fd.impl.OnClose(ctx)
}
-// StatusFlags returns file description status flags, as for fcntl(F_GETFL).
-func (fd *FileDescription) StatusFlags(ctx context.Context) (uint32, error) {
- flags, err := fd.impl.StatusFlags(ctx)
- flags |= linux.O_LARGEFILE
- return flags, err
-}
-
-// SetStatusFlags sets file description status flags, as for fcntl(F_SETFL).
-func (fd *FileDescription) SetStatusFlags(ctx context.Context, flags uint32) error {
- return fd.impl.SetStatusFlags(ctx, flags)
-}
-
// Stat returns metadata for the file represented by fd.
func (fd *FileDescription) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ stat, err := fd.vd.mount.fs.impl.StatAt(ctx, rp, opts)
+ vfsObj.putResolvingPath(rp)
+ return stat, err
+ }
return fd.impl.Stat(ctx, opts)
}
// SetStat updates metadata for the file represented by fd.
func (fd *FileDescription) SetStat(ctx context.Context, opts SetStatOptions) error {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ err := fd.vd.mount.fs.impl.SetStatAt(ctx, rp, opts)
+ vfsObj.putResolvingPath(rp)
+ return err
+ }
return fd.impl.SetStat(ctx, opts)
}
// StatFS returns metadata for the filesystem containing the file represented
// by fd.
func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ statfs, err := fd.vd.mount.fs.impl.StatFSAt(ctx, rp)
+ vfsObj.putResolvingPath(rp)
+ return statfs, err
+ }
return fd.impl.StatFS(ctx)
}
@@ -329,6 +459,78 @@ func (fd *FileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.
return fd.impl.Ioctl(ctx, uio, args)
}
+// Listxattr returns all extended attribute names for the file represented by
+// fd.
+func (fd *FileDescription) Listxattr(ctx context.Context) ([]string, error) {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ names, err := fd.vd.mount.fs.impl.ListxattrAt(ctx, rp)
+ vfsObj.putResolvingPath(rp)
+ return names, err
+ }
+ names, err := fd.impl.Listxattr(ctx)
+ if err == syserror.ENOTSUP {
+ // Linux doesn't actually return ENOTSUP in this case; instead,
+ // fs/xattr.c:vfs_listxattr() falls back to allowing the security
+ // subsystem to return security extended attributes, which by default
+ // don't exist.
+ return nil, nil
+ }
+ return names, err
+}
+
+// Getxattr returns the value associated with the given extended attribute for
+// the file represented by fd.
+func (fd *FileDescription) Getxattr(ctx context.Context, name string) (string, error) {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ val, err := fd.vd.mount.fs.impl.GetxattrAt(ctx, rp, name)
+ vfsObj.putResolvingPath(rp)
+ return val, err
+ }
+ return fd.impl.Getxattr(ctx, name)
+}
+
+// Setxattr changes the value associated with the given extended attribute for
+// the file represented by fd.
+func (fd *FileDescription) Setxattr(ctx context.Context, opts SetxattrOptions) error {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ err := fd.vd.mount.fs.impl.SetxattrAt(ctx, rp, opts)
+ vfsObj.putResolvingPath(rp)
+ return err
+ }
+ return fd.impl.Setxattr(ctx, opts)
+}
+
+// Removexattr removes the given extended attribute from the file represented
+// by fd.
+func (fd *FileDescription) Removexattr(ctx context.Context, name string) error {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ err := fd.vd.mount.fs.impl.RemovexattrAt(ctx, rp, name)
+ vfsObj.putResolvingPath(rp)
+ return err
+ }
+ return fd.impl.Removexattr(ctx, name)
+}
+
// SyncFS instructs the filesystem containing fd to execute the semantics of
// syncfs(2).
func (fd *FileDescription) SyncFS(ctx context.Context) error {
@@ -347,7 +549,7 @@ func (fd *FileDescription) MappedName(ctx context.Context) string {
// DeviceID implements memmap.MappingIdentity.DeviceID.
func (fd *FileDescription) DeviceID() uint64 {
- stat, err := fd.impl.Stat(context.Background(), StatOptions{
+ stat, err := fd.Stat(context.Background(), StatOptions{
// There is no STATX_DEV; we assume that Stat will return it if it's
// available regardless of mask.
Mask: 0,
@@ -363,7 +565,7 @@ func (fd *FileDescription) DeviceID() uint64 {
// InodeID implements memmap.MappingIdentity.InodeID.
func (fd *FileDescription) InodeID() uint64 {
- stat, err := fd.impl.Stat(context.Background(), StatOptions{
+ stat, err := fd.Stat(context.Background(), StatOptions{
Mask: linux.STATX_INO,
// fs/proc/task_mmu.c:show_map_vma() just reads inode::i_ino directly.
Sync: linux.AT_STATX_DONT_SYNC,
@@ -376,5 +578,5 @@ func (fd *FileDescription) InodeID() uint64 {
// Msync implements memmap.MappingIdentity.Msync.
func (fd *FileDescription) Msync(ctx context.Context, mr memmap.MappableRange) error {
- return fd.impl.Sync(ctx)
+ return fd.Sync(ctx)
}
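
For orientation, a minimal sketch (not part of this change; the fsetxattr/fgetxattr helpers and errno handling are hypothetical) of how syscall-level code might drive the new per-file-description extended attribute wrappers added above:

package example

import (
	"gvisor.dev/gvisor/pkg/sentry/context"
	"gvisor.dev/gvisor/pkg/sentry/vfs"
)

// fsetxattr forwards to FileDescription.Setxattr, which dispatches to the
// FileDescriptionImpl or, when UseDentryMetadata is set, resolves a path
// operation rooted at the FD's own VirtualDentry and calls
// FilesystemImpl.SetxattrAt.
func fsetxattr(ctx context.Context, fd *vfs.FileDescription, name, value string, flags uint32) error {
	return fd.Setxattr(ctx, vfs.SetxattrOptions{
		Name:  name,
		Value: value,
		Flags: flags,
	})
}

// fgetxattr forwards to FileDescription.Getxattr; error-to-errno conversion
// is elided.
func fgetxattr(ctx context.Context, fd *vfs.FileDescription, name string) (string, error) {
	return fd.Getxattr(ctx, name)
}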
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index aae023254..66eb57bc2 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -127,6 +127,31 @@ func (FileDescriptionDefaultImpl) Ioctl(ctx context.Context, uio usermem.IO, arg
return 0, syserror.ENOTTY
}
+// Listxattr implements FileDescriptionImpl.Listxattr analogously to
+// inode_operations::listxattr == NULL in Linux.
+func (FileDescriptionDefaultImpl) Listxattr(ctx context.Context) ([]string, error) {
+ // This isn't exactly accurate; see FileDescription.Listxattr.
+ return nil, syserror.ENOTSUP
+}
+
+// Getxattr implements FileDescriptionImpl.Getxattr analogously to
+// inode::i_opflags & IOP_XATTR == 0 in Linux.
+func (FileDescriptionDefaultImpl) Getxattr(ctx context.Context, name string) (string, error) {
+ return "", syserror.ENOTSUP
+}
+
+// Setxattr implements FileDescriptionImpl.Setxattr analogously to
+// inode::i_opflags & IOP_XATTR == 0 in Linux.
+func (FileDescriptionDefaultImpl) Setxattr(ctx context.Context, opts SetxattrOptions) error {
+ return syserror.ENOTSUP
+}
+
+// Removexattr implements FileDescriptionImpl.Removexattr analogously to
+// inode::i_opflags & IOP_XATTR == 0 in Linux.
+func (FileDescriptionDefaultImpl) Removexattr(ctx context.Context, name string) error {
+ return syserror.ENOTSUP
+}
+
// DirectoryFileDescriptionDefaultImpl may be embedded by implementations of
// FileDescriptionImpl that always represent directories to obtain
// implementations of non-directory I/O methods that return EISDIR.
@@ -152,6 +177,21 @@ func (DirectoryFileDescriptionDefaultImpl) Write(ctx context.Context, src userme
return 0, syserror.EISDIR
}
+// DentryMetadataFileDescriptionImpl may be embedded by implementations of
+// FileDescriptionImpl for which FileDescriptionOptions.UseDentryMetadata is
+// true to obtain implementations of Stat and SetStat that panic; in that
+// case, FileDescription.Stat and SetStat are serviced by the FilesystemImpl
+// directly, so these methods should never be called.
+type DentryMetadataFileDescriptionImpl struct{}
+
+// Stat implements FileDescriptionImpl.Stat.
+func (DentryMetadataFileDescriptionImpl) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) {
+ panic("illegal call to DentryMetadataFileDescriptionImpl.Stat")
+}
+
+// SetStat implements FileDescriptionImpl.SetStat.
+func (DentryMetadataFileDescriptionImpl) SetStat(ctx context.Context, opts SetStatOptions) error {
+ panic("illegal call to DentryMetadataFileDescriptionImpl.SetStat")
+}
+
// DynamicBytesFileDescriptionImpl may be embedded by implementations of
// FileDescriptionImpl that represent read-only regular files whose contents
// are backed by a bytes.Buffer that is regenerated when necessary, consistent
@@ -174,6 +214,17 @@ type DynamicBytesSource interface {
Generate(ctx context.Context, buf *bytes.Buffer) error
}
+// StaticData implements DynamicBytesSource over a static string.
+type StaticData struct {
+ Data string
+}
+
+// Generate implements DynamicBytesSource.
+func (s *StaticData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ buf.WriteString(s.Data)
+ return nil
+}
+
// SetDataSource must be called exactly once on fd before first use.
func (fd *DynamicBytesFileDescriptionImpl) SetDataSource(data DynamicBytesSource) {
fd.data = data
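
As a usage sketch (the dumpStatic helper is hypothetical), StaticData satisfies DynamicBytesSource and can be passed to SetDataSource above, or exercised directly:

package example

import (
	"bytes"

	"gvisor.dev/gvisor/pkg/sentry/context"
	"gvisor.dev/gvisor/pkg/sentry/vfs"
)

// dumpStatic regenerates a StaticData source into a buffer, the same way
// DynamicBytesFileDescriptionImpl regenerates its backing buffer when
// necessary.
func dumpStatic(ctx context.Context) string {
	var src vfs.DynamicBytesSource = &vfs.StaticData{Data: "hello from a static file\n"}
	var buf bytes.Buffer
	if err := src.Generate(ctx, &buf); err != nil {
		// StaticData.Generate always returns nil.
		panic(err)
	}
	return buf.String()
}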
diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go
index ac7799296..9ed58512f 100644
--- a/pkg/sentry/vfs/file_description_impl_util_test.go
+++ b/pkg/sentry/vfs/file_description_impl_util_test.go
@@ -48,7 +48,7 @@ type genCountFD struct {
func newGenCountFD(mnt *Mount, vfsd *Dentry) *FileDescription {
var fd genCountFD
- fd.vfsfd.Init(&fd, mnt, vfsd)
+ fd.vfsfd.Init(&fd, 0 /* statusFlags */, mnt, vfsd, &FileDescriptionOptions{})
fd.DynamicBytesFileDescriptionImpl.SetDataSource(&fd)
return &fd.vfsfd
}
@@ -89,7 +89,7 @@ func TestGenCountFD(t *testing.T) {
creds := auth.CredentialsFromContext(ctx)
vfsObj := New() // vfs.New()
- vfsObj.MustRegisterFilesystemType("testfs", FDTestFilesystemType{})
+ vfsObj.MustRegisterFilesystemType("testfs", FDTestFilesystemType{}, &RegisterFilesystemTypeOptions{})
mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "testfs", &GetFilesystemOptions{})
if err != nil {
t.Fatalf("failed to create testfs root mount: %v", err)
diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go
index 8011eba3f..ea78f555b 100644
--- a/pkg/sentry/vfs/filesystem.go
+++ b/pkg/sentry/vfs/filesystem.go
@@ -108,6 +108,24 @@ func (fs *Filesystem) DecRef() {
// (responsible for actually implementing the operation) isn't known until path
// resolution is complete.
//
+// Unless otherwise specified, FilesystemImpl methods are responsible for
+// performing permission checks. In many cases, vfs package functions in
+// permissions.go may be used to help perform these checks.
+//
+// When multiple specified error conditions apply to a given method call, the
+// implementation may return any applicable errno unless otherwise specified,
+// but returning the earliest error specified is preferable to maximize
+// compatibility with Linux.
+//
+// All methods may return errors not specified, notably including:
+//
+// - ENOENT if a required path component does not exist.
+//
+// - ENOTDIR if an intermediate path component is not a directory.
+//
+// - Errors from vfs-package functions (ResolvingPath.Resolve*(),
+// Mount.CheckBeginWrite(), permission-checking functions, etc.)
+//
// For all methods that take or return linux.Statx, Statx.Uid and Statx.Gid
// should be interpreted as IDs in the root UserNamespace (i.e. as auth.KUID
// and auth.KGID respectively).
@@ -130,46 +148,223 @@ type FilesystemImpl interface {
// GetDentryAt does not correspond directly to a Linux syscall; it is used
// in the implementation of:
//
- // - Syscalls that need to resolve two paths: rename(), renameat(),
- // renameat2(), link(), linkat().
+ // - Syscalls that need to resolve two paths: link(), linkat().
//
// - Syscalls that need to refer to a filesystem position outside the
// context of a file description: chdir(), fchdir(), chroot(), mount(),
// umount().
GetDentryAt(ctx context.Context, rp *ResolvingPath, opts GetDentryOptions) (*Dentry, error)
+ // GetParentDentryAt returns a Dentry representing the directory at the
+ // second-to-last path component in rp. (Note that, despite the name, this
+ // is not necessarily the parent directory of the file at rp, since the
+ // last path component in rp may be "." or "..".) A reference is taken on
+ // the returned Dentry.
+ //
+ // GetParentDentryAt does not correspond directly to a Linux syscall; it is
+ // used in the implementation of the rename() family of syscalls, which
+ // must resolve the parent directories of two paths.
+ //
+ // Preconditions: !rp.Done().
+ //
+ // Postconditions: If GetParentDentryAt returns a nil error, then
+ // rp.Final(). If GetParentDentryAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
+ GetParentDentryAt(ctx context.Context, rp *ResolvingPath) (*Dentry, error)
+
// LinkAt creates a hard link at rp representing the same file as vd. It
// does not take ownership of references on vd.
//
- // The implementation is responsible for checking that vd.Mount() ==
- // rp.Mount(), and that vd does not represent a directory.
+ // Errors:
+ //
+ // - If the last path component in rp is "." or "..", LinkAt returns
+ // EEXIST.
+ //
+ // - If a file already exists at rp, LinkAt returns EEXIST.
+ //
+ // - If rp.MustBeDir(), LinkAt returns ENOENT.
+ //
+ // - If the directory in which the link would be created has been removed
+ // by RmdirAt or RenameAt, LinkAt returns ENOENT.
+ //
+ // - If rp.Mount != vd.Mount(), LinkAt returns EXDEV.
+ //
+ // - If vd represents a directory, LinkAt returns EPERM.
+ //
+ // - If vd represents a file for which all existing links have been
+ // removed, or a file created by open(O_TMPFILE|O_EXCL), LinkAt returns
+ // ENOENT. Equivalently, if vd represents a file with a link count of 0 not
+ // created by open(O_TMPFILE) without O_EXCL, LinkAt returns ENOENT.
+ //
+ // Preconditions: !rp.Done(). For the final path component in rp,
+ // !rp.ShouldFollowSymlink().
+ //
+ // Postconditions: If LinkAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
LinkAt(ctx context.Context, rp *ResolvingPath, vd VirtualDentry) error
// MkdirAt creates a directory at rp.
+ //
+ // Errors:
+ //
+ // - If the last path component in rp is "." or "..", MkdirAt returns
+ // EEXIST.
+ //
+ // - If a file already exists at rp, MkdirAt returns EEXIST.
+ //
+ // - If the directory in which the new directory would be created has been
+ // removed by RmdirAt or RenameAt, MkdirAt returns ENOENT.
+ //
+ // Preconditions: !rp.Done(). For the final path component in rp,
+ // !rp.ShouldFollowSymlink().
+ //
+ // Postconditions: If MkdirAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
MkdirAt(ctx context.Context, rp *ResolvingPath, opts MkdirOptions) error
// MknodAt creates a regular file, device special file, or named pipe at
// rp.
+ //
+ // Errors:
+ //
+ // - If the last path component in rp is "." or "..", MknodAt returns
+ // EEXIST.
+ //
+ // - If a file already exists at rp, MknodAt returns EEXIST.
+ //
+ // - If rp.MustBeDir(), MknodAt returns ENOENT.
+ //
+ // - If the directory in which the file would be created has been removed
+ // by RmdirAt or RenameAt, MknodAt returns ENOENT.
+ //
+ // Preconditions: !rp.Done(). For the final path component in rp,
+ // !rp.ShouldFollowSymlink().
+ //
+ // Postconditions: If MknodAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
MknodAt(ctx context.Context, rp *ResolvingPath, opts MknodOptions) error
// OpenAt returns an FileDescription providing access to the file at rp. A
// reference is taken on the returned FileDescription.
+ //
+ // Errors:
+ //
+ // - If opts.Flags specifies O_TMPFILE and this feature is unsupported by
+ // the implementation, OpenAt returns EOPNOTSUPP. (All other unsupported
+ // features are silently ignored, consistently with Linux's open*(2).)
OpenAt(ctx context.Context, rp *ResolvingPath, opts OpenOptions) (*FileDescription, error)
// ReadlinkAt returns the target of the symbolic link at rp.
+ //
+ // Errors:
+ //
+ // - If the file at rp is not a symbolic link, ReadlinkAt returns EINVAL.
ReadlinkAt(ctx context.Context, rp *ResolvingPath) (string, error)
- // RenameAt renames the Dentry represented by vd to rp. It does not take
- // ownership of references on vd.
+ // RenameAt renames the file named oldName in directory oldParentVD to rp.
+ // It does not take ownership of references on oldParentVD.
+ //
+ // Errors [1]:
+ //
+ // - If opts.Flags specifies unsupported options, RenameAt returns EINVAL.
+ //
+ // - If the last path component in rp is "." or "..", and opts.Flags
+ // contains RENAME_NOREPLACE, RenameAt returns EEXIST.
+ //
+ // - If the last path component in rp is "." or "..", and opts.Flags does
+ // not contain RENAME_NOREPLACE, RenameAt returns EBUSY.
+ //
+ // - If rp.Mount != oldParentVD.Mount(), RenameAt returns EXDEV.
+ //
+ // - If the renamed file is not a directory, and opts.MustBeDir is true,
+ // RenameAt returns ENOTDIR.
+ //
+ // - If renaming would replace an existing file and opts.Flags contains
+ // RENAME_NOREPLACE, RenameAt returns EEXIST.
+ //
+ // - If there is no existing file at rp and opts.Flags contains
+ // RENAME_EXCHANGE, RenameAt returns ENOENT.
+ //
+ // - If there is an existing non-directory file at rp, and rp.MustBeDir()
+ // is true, RenameAt returns ENOTDIR.
+ //
+ // - If the renamed file is not a directory, opts.Flags does not contain
+ // RENAME_EXCHANGE, and rp.MustBeDir() is true, RenameAt returns ENOTDIR.
+ // (This check is not subsumed by the check for directory replacement below
+ // since it applies even if there is no file to replace.)
+ //
+ // - If the renamed file is a directory, and the new parent directory of
+ // the renamed file is either the renamed directory or a descendant
+ // subdirectory of the renamed directory, RenameAt returns EINVAL.
//
- // The implementation is responsible for checking that vd.Mount() ==
- // rp.Mount().
- RenameAt(ctx context.Context, rp *ResolvingPath, vd VirtualDentry, opts RenameOptions) error
+ // - If renaming would exchange the renamed file with an ancestor directory
+ // of the renamed file, RenameAt returns EINVAL.
+ //
+ // - If renaming would replace an ancestor directory of the renamed file,
+ // RenameAt returns ENOTEMPTY. (This check would be subsumed by the
+ // non-empty directory check below; however, this check takes place before
+ // the self-rename check.)
+ //
+ // - If the renamed file would replace or exchange with itself (i.e. the
+ // source and destination paths resolve to the same file), RenameAt returns
+ // nil, skipping the checks described below.
+ //
+ // - If the source or destination directory is not writable by the provider
+ // of rp.Credentials(), RenameAt returns EACCES.
+ //
+ // - If the renamed file is a directory, and renaming would replace a
+ // non-directory file, RenameAt returns ENOTDIR.
+ //
+ // - If the renamed file is not a directory, and renaming would replace a
+ // directory, RenameAt returns EISDIR.
+ //
+ // - If the new parent directory of the renamed file has been removed by
+ // RmdirAt or a preceding call to RenameAt, RenameAt returns ENOENT.
+ //
+ // - If the renamed file is a directory, it is not writable by the
+ // provider of rp.Credentials(), and the source and destination parent
+ // directories are different, RenameAt returns EACCES. (This is nominally
+ // required to change the ".." entry in the renamed directory.)
+ //
+ // - If renaming would replace a non-empty directory, RenameAt returns
+ // ENOTEMPTY.
+ //
+ // Preconditions: !rp.Done(). For the final path component in rp,
+ // !rp.ShouldFollowSymlink(). oldName is not "." or "..".
+ //
+ // Postconditions: If RenameAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
+ //
+ // [1] "The worst of all namespace operations - renaming directory.
+ // "Perverted" doesn't even start to describe it. Somebody in UCB had a
+ // heck of a trip..." - fs/namei.c:vfs_rename()
+ RenameAt(ctx context.Context, rp *ResolvingPath, oldParentVD VirtualDentry, oldName string, opts RenameOptions) error
// RmdirAt removes the directory at rp.
+ //
+ // Errors:
+ //
+ // - If the last path component in rp is ".", RmdirAt returns EINVAL.
+ //
+ // - If the last path component in rp is "..", RmdirAt returns ENOTEMPTY.
+ //
+ // - If no file exists at rp, RmdirAt returns ENOENT.
+ //
+ // - If the file at rp exists but is not a directory, RmdirAt returns
+ // ENOTDIR.
+ //
+ // Preconditions: !rp.Done(). For the final path component in rp,
+ // !rp.ShouldFollowSymlink().
+ //
+ // Postconditions: If RmdirAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
RmdirAt(ctx context.Context, rp *ResolvingPath) error
// SetStatAt updates metadata for the file at the given path.
+ //
+ // Errors:
+ //
+ // - If opts specifies unsupported options, SetStatAt returns EINVAL.
SetStatAt(ctx context.Context, rp *ResolvingPath, opts SetStatOptions) error
// StatAt returns metadata for the file at rp.
@@ -181,11 +376,82 @@ type FilesystemImpl interface {
StatFSAt(ctx context.Context, rp *ResolvingPath) (linux.Statfs, error)
// SymlinkAt creates a symbolic link at rp referring to the given target.
+ //
+ // Errors:
+ //
+ // - If the last path component in rp is "." or "..", SymlinkAt returns
+ // EEXIST.
+ //
+ // - If a file already exists at rp, SymlinkAt returns EEXIST.
+ //
+ // - If rp.MustBeDir(), SymlinkAt returns ENOENT.
+ //
+ // - If the directory in which the symbolic link would be created has been
+ // removed by RmdirAt or RenameAt, SymlinkAt returns ENOENT.
+ //
+ // Preconditions: !rp.Done(). For the final path component in rp,
+ // !rp.ShouldFollowSymlink().
+ //
+ // Postconditions: If SymlinkAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
SymlinkAt(ctx context.Context, rp *ResolvingPath, target string) error
- // UnlinkAt removes the non-directory file at rp.
+ // UnlinkAt removes the file at rp.
+ //
+ // Errors:
+ //
+ // - If the last path component in rp is "." or "..", UnlinkAt returns
+ // EISDIR.
+ //
+ // - If no file exists at rp, UnlinkAt returns ENOENT.
+ //
+ // - If rp.MustBeDir(), and the file at rp exists and is not a directory,
+ // UnlinkAt returns ENOTDIR.
+ //
+ // - If the file at rp exists but is a directory, UnlinkAt returns EISDIR.
+ //
+ // Preconditions: !rp.Done(). For the final path component in rp,
+ // !rp.ShouldFollowSymlink().
+ //
+ // Postconditions: If UnlinkAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
UnlinkAt(ctx context.Context, rp *ResolvingPath) error
+ // ListxattrAt returns all extended attribute names for the file at rp.
+ //
+ // Errors:
+ //
+ // - If extended attributes are not supported by the filesystem,
+ // ListxattrAt returns nil. (See FileDescription.Listxattr for an
+ // explanation.)
+ ListxattrAt(ctx context.Context, rp *ResolvingPath) ([]string, error)
+
+ // GetxattrAt returns the value associated with the given extended
+ // attribute for the file at rp.
+ //
+ // Errors:
+ //
+ // - If extended attributes are not supported by the filesystem, GetxattrAt
+ // returns ENOTSUP.
+ GetxattrAt(ctx context.Context, rp *ResolvingPath, name string) (string, error)
+
+ // SetxattrAt changes the value associated with the given extended
+ // attribute for the file at rp.
+ //
+ // Errors:
+ //
+ // - If extended attributes are not supported by the filesystem, SetxattrAt
+ // returns ENOTSUP.
+ SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error
+
+ // RemovexattrAt removes the given extended attribute from the file at rp.
+ //
+ // Errors:
+ //
+ // - If extended attributes are not supported by the filesystem,
+ // RemovexattrAt returns ENOTSUP.
+ RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error
+
// PrependPath prepends a path from vd to vd.Mount().Root() to b.
//
// If vfsroot.Ok(), it is the contextual VFS root; if it is encountered
@@ -208,7 +474,7 @@ type FilesystemImpl interface {
// Preconditions: vd.Mount().Filesystem().Impl() == this FilesystemImpl.
PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error
- // TODO: extended attributes; inotify_add_watch(); bind()
+ // TODO: inotify_add_watch(); bind()
}
// PrependPathAtVFSRootError is returned by implementations of
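
To make the error ordering above concrete, a small sketch (helper name hypothetical) of the last-component checks a FilesystemImpl.RmdirAt is expected to perform before looking up the target:

package example

import (
	"gvisor.dev/gvisor/pkg/sentry/vfs"
	"gvisor.dev/gvisor/pkg/syserror"
)

// checkRmdirComponent applies the documented "." and ".." checks for RmdirAt.
// The subsequent lookup (ENOENT if the name does not exist, ENOTDIR if it is
// not a directory, ENOTEMPTY if it is not empty) is filesystem-specific and
// elided.
func checkRmdirComponent(rp *vfs.ResolvingPath) error {
	switch rp.Component() {
	case ".":
		return syserror.EINVAL
	case "..":
		return syserror.ENOTEMPTY
	}
	return nil
}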
diff --git a/pkg/sentry/vfs/filesystem_type.go b/pkg/sentry/vfs/filesystem_type.go
index c335e206d..023301780 100644
--- a/pkg/sentry/vfs/filesystem_type.go
+++ b/pkg/sentry/vfs/filesystem_type.go
@@ -15,6 +15,7 @@
package vfs
import (
+ "bytes"
"fmt"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -43,28 +44,70 @@ type GetFilesystemOptions struct {
InternalData interface{}
}
+type registeredFilesystemType struct {
+ fsType FilesystemType
+ opts RegisterFilesystemTypeOptions
+}
+
+// RegisterFilesystemTypeOptions contains options to
+// VirtualFilesystem.RegisterFilesystemType().
+type RegisterFilesystemTypeOptions struct {
+ // If AllowUserMount is true, allow calls to VirtualFilesystem.MountAt()
+ // for which MountOptions.InternalMount == false to use this filesystem
+ // type.
+ AllowUserMount bool
+
+ // If AllowUserList is true, make this filesystem type visible in
+ // /proc/filesystems.
+ AllowUserList bool
+
+ // If RequiresDevice is true, mark this filesystem type in
+ // /proc/filesystems as one that requires a block device as its mount
+ // source (i.e. the "nodev" marker is omitted).
+ RequiresDevice bool
+}
+
// RegisterFilesystemType registers the given FilesystemType in vfs with the
// given name.
-func (vfs *VirtualFilesystem) RegisterFilesystemType(name string, fsType FilesystemType) error {
+func (vfs *VirtualFilesystem) RegisterFilesystemType(name string, fsType FilesystemType, opts *RegisterFilesystemTypeOptions) error {
vfs.fsTypesMu.Lock()
defer vfs.fsTypesMu.Unlock()
if existing, ok := vfs.fsTypes[name]; ok {
- return fmt.Errorf("name %q is already registered to filesystem type %T", name, existing)
+ return fmt.Errorf("name %q is already registered to filesystem type %T", name, existing.fsType)
+ }
+ vfs.fsTypes[name] = &registeredFilesystemType{
+ fsType: fsType,
+ opts: *opts,
}
- vfs.fsTypes[name] = fsType
return nil
}
// MustRegisterFilesystemType is equivalent to RegisterFilesystemType but
// panics on failure.
-func (vfs *VirtualFilesystem) MustRegisterFilesystemType(name string, fsType FilesystemType) {
- if err := vfs.RegisterFilesystemType(name, fsType); err != nil {
+func (vfs *VirtualFilesystem) MustRegisterFilesystemType(name string, fsType FilesystemType, opts *RegisterFilesystemTypeOptions) {
+ if err := vfs.RegisterFilesystemType(name, fsType, opts); err != nil {
panic(fmt.Sprintf("failed to register filesystem type %T: %v", fsType, err))
}
}
-func (vfs *VirtualFilesystem) getFilesystemType(name string) FilesystemType {
+func (vfs *VirtualFilesystem) getFilesystemType(name string) *registeredFilesystemType {
vfs.fsTypesMu.RLock()
defer vfs.fsTypesMu.RUnlock()
return vfs.fsTypes[name]
}
+
+// GenerateProcFilesystems emits the contents of /proc/filesystems for vfs to
+// buf.
+func (vfs *VirtualFilesystem) GenerateProcFilesystems(buf *bytes.Buffer) {
+ vfs.fsTypesMu.RLock()
+ defer vfs.fsTypesMu.RUnlock()
+ for name, rft := range vfs.fsTypes {
+ if !rft.opts.AllowUserList {
+ continue
+ }
+ var nodev string
+ if !rft.opts.RequiresDevice {
+ nodev = "nodev"
+ }
+ fmt.Fprintf(buf, "%s\t%s\n", nodev, name)
+ }
+}
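
A registration sketch (the tmpfsType parameter stands in for a concrete FilesystemType) showing how the new options feed into GenerateProcFilesystems:

package example

import (
	"bytes"
	"fmt"

	"gvisor.dev/gvisor/pkg/sentry/vfs"
)

// registerAndList registers a filesystem type that userspace may both mount
// and see in /proc/filesystems, then renders the table.
func registerAndList(vfsObj *vfs.VirtualFilesystem, tmpfsType vfs.FilesystemType) {
	vfsObj.MustRegisterFilesystemType("tmpfs", tmpfsType, &vfs.RegisterFilesystemTypeOptions{
		AllowUserMount: true, // usable by non-internal MountAt calls
		AllowUserList:  true, // visible in /proc/filesystems
		// RequiresDevice is false, so the entry is prefixed with "nodev".
	})

	var buf bytes.Buffer
	vfsObj.GenerateProcFilesystems(&buf)
	fmt.Print(buf.String()) // e.g. "nodev\ttmpfs\n"
}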
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index ec23ab0dd..00177b371 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -112,11 +112,11 @@ type MountNamespace struct {
// configured by the given arguments. A reference is taken on the returned
// MountNamespace.
func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *GetFilesystemOptions) (*MountNamespace, error) {
- fsType := vfs.getFilesystemType(fsTypeName)
- if fsType == nil {
+ rft := vfs.getFilesystemType(fsTypeName)
+ if rft == nil {
return nil, syserror.ENODEV
}
- fs, root, err := fsType.GetFilesystem(ctx, vfs, creds, source, *opts)
+ fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, *opts)
if err != nil {
return nil, err
}
@@ -136,11 +136,14 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth
// MountAt creates and mounts a Filesystem configured by the given arguments.
func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error {
- fsType := vfs.getFilesystemType(fsTypeName)
- if fsType == nil {
+ rft := vfs.getFilesystemType(fsTypeName)
+ if rft == nil {
return syserror.ENODEV
}
- fs, root, err := fsType.GetFilesystem(ctx, vfs, creds, source, opts.GetFilesystemOptions)
+ if !opts.InternalMount && !rft.opts.AllowUserMount {
+ return syserror.ENODEV
+ }
+ fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, opts.GetFilesystemOptions)
if err != nil {
return err
}
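
A hedged sketch of a kernel-internal mount (the "proc" type name, "none" source, and target are illustrative) showing how InternalMount interacts with the AllowUserMount check above:

package example

import (
	"gvisor.dev/gvisor/pkg/sentry/context"
	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
	"gvisor.dev/gvisor/pkg/sentry/vfs"
)

// mountInternal mounts a filesystem type on behalf of the sentry itself.
// Setting InternalMount bypasses the AllowUserMount check; user-initiated
// mounts leave it false and get ENODEV for such filesystem types.
func mountInternal(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, target *vfs.PathOperation) error {
	return vfsObj.MountAt(ctx, creds, "none" /* source */, target, "proc", &vfs.MountOptions{
		InternalMount: true,
	})
}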
diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go
index 3ecbc8fc1..b7774bf28 100644
--- a/pkg/sentry/vfs/options.go
+++ b/pkg/sentry/vfs/options.go
@@ -50,6 +50,10 @@ type MknodOptions struct {
type MountOptions struct {
// GetFilesystemOptions contains options to FilesystemType.GetFilesystem().
GetFilesystemOptions GetFilesystemOptions
+
+ // If InternalMount is true, allow the use of filesystem types for which
+ // RegisterFilesystemTypeOptions.AllowUserMount == false.
+ InternalMount bool
}
// OpenOptions contains options to VirtualFilesystem.OpenAt() and
@@ -83,6 +87,9 @@ type ReadOptions struct {
type RenameOptions struct {
// Flags contains flags as specified for renameat2(2).
Flags uint32
+
+ // If MustBeDir is true, the renamed file must be a directory.
+ MustBeDir bool
}
// SetStatOptions contains options to VirtualFilesystem.SetStatAt(),
@@ -101,6 +108,20 @@ type SetStatOptions struct {
Stat linux.Statx
}
+// SetxattrOptions contains options to VirtualFilesystem.SetxattrAt(),
+// FilesystemImpl.SetxattrAt(), FileDescription.Setxattr(), and
+// FileDescriptionImpl.Setxattr().
+type SetxattrOptions struct {
+ // Name is the name of the extended attribute being mutated.
+ Name string
+
+ // Value is the extended attribute's new value.
+ Value string
+
+ // Flags contains flags as specified for setxattr/lsetxattr/fsetxattr(2).
+ Flags uint32
+}
+
// StatOptions contains options to VirtualFilesystem.StatAt(),
// FilesystemImpl.StatAt(), FileDescription.Stat(), and
// FileDescriptionImpl.Stat().
diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go
index 621f5a6f8..f0641d314 100644
--- a/pkg/sentry/vfs/resolving_path.go
+++ b/pkg/sentry/vfs/resolving_path.go
@@ -85,11 +85,11 @@ func init() {
// so error "constants" are really mutable vars, necessitating somewhat
// expensive interface object comparisons.
-type resolveMountRootError struct{}
+type resolveMountRootOrJumpError struct{}
// Error implements error.Error.
-func (resolveMountRootError) Error() string {
- return "resolving mount root"
+func (resolveMountRootOrJumpError) Error() string {
+ return "resolving mount root or jump"
}
type resolveMountPointError struct{}
@@ -112,30 +112,26 @@ var resolvingPathPool = sync.Pool{
},
}
-func (vfs *VirtualFilesystem) getResolvingPath(creds *auth.Credentials, pop *PathOperation) (*ResolvingPath, error) {
- path, err := fspath.Parse(pop.Pathname)
- if err != nil {
- return nil, err
- }
+func (vfs *VirtualFilesystem) getResolvingPath(creds *auth.Credentials, pop *PathOperation) *ResolvingPath {
rp := resolvingPathPool.Get().(*ResolvingPath)
rp.vfs = vfs
rp.root = pop.Root
rp.mount = pop.Start.mount
rp.start = pop.Start.dentry
- rp.pit = path.Begin
+ rp.pit = pop.Path.Begin
rp.flags = 0
if pop.FollowFinalSymlink {
rp.flags |= rpflagsFollowFinalSymlink
}
- rp.mustBeDir = path.Dir
- rp.mustBeDirOrig = path.Dir
+ rp.mustBeDir = pop.Path.Dir
+ rp.mustBeDirOrig = pop.Path.Dir
rp.symlinks = 0
rp.curPart = 0
rp.numOrigParts = 1
rp.creds = creds
- rp.parts[0] = path.Begin
- rp.origParts[0] = path.Begin
- return rp, nil
+ rp.parts[0] = pop.Path.Begin
+ rp.origParts[0] = pop.Path.Begin
+ return rp
}
func (vfs *VirtualFilesystem) putResolvingPath(rp *ResolvingPath) {
@@ -274,7 +270,7 @@ func (rp *ResolvingPath) ResolveParent(d *Dentry) (*Dentry, error) {
// ... of non-root mount.
rp.nextMount = vd.mount
rp.nextStart = vd.dentry
- return nil, resolveMountRootError{}
+ return nil, resolveMountRootOrJumpError{}
}
// ... of root mount.
parent = d
@@ -345,29 +341,34 @@ func (rp *ResolvingPath) ShouldFollowSymlink() bool {
// symlink target and returns nil. Otherwise it returns a non-nil error.
//
// Preconditions: !rp.Done().
+//
+// Postconditions: If HandleSymlink returns a nil error, then !rp.Done().
func (rp *ResolvingPath) HandleSymlink(target string) error {
if rp.symlinks >= linux.MaxSymlinkTraversals {
return syserror.ELOOP
}
- targetPath, err := fspath.Parse(target)
- if err != nil {
- return err
+ if len(target) == 0 {
+ return syserror.ENOENT
}
rp.symlinks++
+ targetPath := fspath.Parse(target)
if targetPath.Absolute {
rp.absSymlinkTarget = targetPath
return resolveAbsSymlinkError{}
}
- if !targetPath.Begin.Ok() {
- panic(fmt.Sprintf("symbolic link has non-empty target %q that is both relative and has no path components?", target))
- }
// Consume the path component that represented the symlink.
rp.Advance()
// Prepend the symlink target to the relative path.
+ if checkInvariants {
+ if !targetPath.HasComponents() {
+ panic(fmt.Sprintf("non-empty pathname %q parsed to relative path with no components", target))
+ }
+ }
rp.relpathPrepend(targetPath)
return nil
}
+// Preconditions: path.HasComponents().
func (rp *ResolvingPath) relpathPrepend(path fspath.Path) {
if rp.pit.Ok() {
rp.parts[rp.curPart] = rp.pit
@@ -385,11 +386,32 @@ func (rp *ResolvingPath) relpathPrepend(path fspath.Path) {
}
}
+// HandleJump is called when the current path component is a "magic" link to
+// the given VirtualDentry, like /proc/[pid]/fd/[fd]. If the calling Filesystem
+// method should continue path traversal, HandleJump updates the path
+// component stream to reflect the magic link target and returns nil. Otherwise
+// it returns a non-nil error.
+//
+// Preconditions: !rp.Done().
+func (rp *ResolvingPath) HandleJump(target VirtualDentry) error {
+ if rp.symlinks >= linux.MaxSymlinkTraversals {
+ return syserror.ELOOP
+ }
+ rp.symlinks++
+ // Consume the path component that represented the magic link.
+ rp.Advance()
+ // Unconditionally return a resolveMountRootOrJumpError, even if the Mount
+ // isn't changing, to force restarting at the new Dentry.
+ target.IncRef()
+ rp.nextMount = target.mount
+ rp.nextStart = target.dentry
+ return resolveMountRootOrJumpError{}
+}
+
func (rp *ResolvingPath) handleError(err error) bool {
switch err.(type) {
- case resolveMountRootError:
- // Switch to the new Mount. We hold references on the Mount and Dentry
- // (from VFS.getMountpointAt()).
+ case resolveMountRootOrJumpError:
+ // Switch to the new Mount. We hold references on the Mount and Dentry.
rp.decRefStartAndMount()
rp.mount = rp.nextMount
rp.start = rp.nextStart
@@ -407,9 +429,8 @@ func (rp *ResolvingPath) handleError(err error) bool {
return true
case resolveMountPointError:
- // Switch to the new Mount. We hold a reference on the Mount (from
- // VFS.getMountAt()), but borrow the reference on the mount root from
- // the Mount.
+ // Switch to the new Mount. We hold a reference on the Mount, but
+ // borrow the reference on the mount root from the Mount.
rp.decRefStartAndMount()
rp.mount = rp.nextMount
rp.start = rp.nextMount.root
@@ -447,6 +468,17 @@ func (rp *ResolvingPath) handleError(err error) bool {
}
}
+// canHandleError returns true if err is an error returned by rp.Resolve*()
+// that rp.handleError() may attempt to handle.
+func (rp *ResolvingPath) canHandleError(err error) bool {
+ switch err.(type) {
+ case resolveMountRootOrJumpError, resolveMountPointError, resolveAbsSymlinkError:
+ return true
+ default:
+ return false
+ }
+}
+
// MustBeDir returns true if the file traversed by rp must be a directory.
func (rp *ResolvingPath) MustBeDir() bool {
return rp.mustBeDir
diff --git a/pkg/sentry/vfs/testutil.go b/pkg/sentry/vfs/testutil.go
index 7a1d9e383..ee5c8b9e2 100644
--- a/pkg/sentry/vfs/testutil.go
+++ b/pkg/sentry/vfs/testutil.go
@@ -57,6 +57,11 @@ func (fs *FDTestFilesystem) GetDentryAt(ctx context.Context, rp *ResolvingPath,
return nil, syserror.EPERM
}
+// GetParentDentryAt implements FilesystemImpl.GetParentDentryAt.
+func (fs *FDTestFilesystem) GetParentDentryAt(ctx context.Context, rp *ResolvingPath) (*Dentry, error) {
+ return nil, syserror.EPERM
+}
+
// LinkAt implements FilesystemImpl.LinkAt.
func (fs *FDTestFilesystem) LinkAt(ctx context.Context, rp *ResolvingPath, vd VirtualDentry) error {
return syserror.EPERM
@@ -83,7 +88,7 @@ func (fs *FDTestFilesystem) ReadlinkAt(ctx context.Context, rp *ResolvingPath) (
}
// RenameAt implements FilesystemImpl.RenameAt.
-func (fs *FDTestFilesystem) RenameAt(ctx context.Context, rp *ResolvingPath, vd VirtualDentry, opts RenameOptions) error {
+func (fs *FDTestFilesystem) RenameAt(ctx context.Context, rp *ResolvingPath, oldParentVD VirtualDentry, oldName string, opts RenameOptions) error {
return syserror.EPERM
}
@@ -117,6 +122,26 @@ func (fs *FDTestFilesystem) UnlinkAt(ctx context.Context, rp *ResolvingPath) err
return syserror.EPERM
}
+// ListxattrAt implements FilesystemImpl.ListxattrAt.
+func (fs *FDTestFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath) ([]string, error) {
+ return nil, syserror.EPERM
+}
+
+// GetxattrAt implements FilesystemImpl.GetxattrAt.
+func (fs *FDTestFilesystem) GetxattrAt(ctx context.Context, rp *ResolvingPath, name string) (string, error) {
+ return "", syserror.EPERM
+}
+
+// SetxattrAt implements FilesystemImpl.SetxattrAt.
+func (fs *FDTestFilesystem) SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error {
+ return syserror.EPERM
+}
+
+// RemovexattrAt implements FilesystemImpl.RemovexattrAt.
+func (fs *FDTestFilesystem) RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error {
+ return syserror.EPERM
+}
+
// PrependPath implements FilesystemImpl.PrependPath.
func (fs *FDTestFilesystem) PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error {
b.PrependComponent(fmt.Sprintf("vfs.fdTestDentry:%p", vd.dentry.impl.(*fdTestDentry)))
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 7262b0d0a..ea2db7031 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -28,9 +28,11 @@
package vfs
import (
+ "fmt"
"sync"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/syserror"
@@ -73,23 +75,29 @@ type VirtualFilesystem struct {
// mountpoints is analogous to Linux's mountpoint_hashtable.
mountpoints map[*Dentry]map[*Mount]struct{}
+ // devices contains all registered Devices. devices is protected by
+ // devicesMu.
+ devicesMu sync.RWMutex
+ devices map[devTuple]*registeredDevice
+
+ // fsTypes contains all registered FilesystemTypes. fsTypes is protected by
+ // fsTypesMu.
+ fsTypesMu sync.RWMutex
+ fsTypes map[string]*registeredFilesystemType
+
// filesystems contains all Filesystems. filesystems is protected by
// filesystemsMu.
filesystemsMu sync.Mutex
filesystems map[*Filesystem]struct{}
-
- // fsTypes contains all FilesystemTypes that are usable in the
- // VirtualFilesystem. fsTypes is protected by fsTypesMu.
- fsTypesMu sync.RWMutex
- fsTypes map[string]FilesystemType
}
// New returns a new VirtualFilesystem with no mounts or FilesystemTypes.
func New() *VirtualFilesystem {
vfs := &VirtualFilesystem{
mountpoints: make(map[*Dentry]map[*Mount]struct{}),
+ devices: make(map[devTuple]*registeredDevice),
+ fsTypes: make(map[string]*registeredFilesystemType),
filesystems: make(map[*Filesystem]struct{}),
- fsTypes: make(map[string]FilesystemType),
}
vfs.mounts.Init()
return vfs
@@ -111,11 +119,11 @@ type PathOperation struct {
// are borrowed from the provider of the PathOperation (i.e. the caller of
// the VFS method to which the PathOperation was passed).
//
- // Invariants: Start.Ok(). If Pathname.Absolute, then Start == Root.
+ // Invariants: Start.Ok(). If Path.Absolute, then Start == Root.
Start VirtualDentry
// Path is the pathname traversed by this operation.
- Pathname string
+ Path fspath.Path
// If FollowFinalSymlink is true, and the Dentry traversed by the final
// path component represents a symbolic link, the symbolic link should be
@@ -126,10 +134,7 @@ type PathOperation struct {
// GetDentryAt returns a VirtualDentry representing the given path, at which a
// file must exist. A reference is taken on the returned VirtualDentry.
func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetDentryOptions) (VirtualDentry, error) {
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return VirtualDentry{}, err
- }
+ rp := vfs.getResolvingPath(creds, pop)
for {
d, err := rp.mount.fs.impl.GetDentryAt(ctx, rp, *opts)
if err == nil {
@@ -148,6 +153,33 @@ func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Crede
}
}
+// Preconditions: pop.Path.Begin.Ok().
+func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (VirtualDentry, string, error) {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ parent, err := rp.mount.fs.impl.GetParentDentryAt(ctx, rp)
+ if err == nil {
+ parentVD := VirtualDentry{
+ mount: rp.mount,
+ dentry: parent,
+ }
+ rp.mount.IncRef()
+ name := rp.Component()
+ vfs.putResolvingPath(rp)
+ return parentVD, name, nil
+ }
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.GetParentDentryAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return VirtualDentry{}, "", err
+ }
+ }
+}
+
// LinkAt creates a hard link at newpop representing the existing file at
// oldpop.
func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credentials, oldpop, newpop *PathOperation) error {
@@ -155,21 +187,36 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential
if err != nil {
return err
}
- rp, err := vfs.getResolvingPath(creds, newpop)
- if err != nil {
+
+ if !newpop.Path.Begin.Ok() {
oldVD.DecRef()
- return err
+ if newpop.Path.Absolute {
+ return syserror.EEXIST
+ }
+ return syserror.ENOENT
+ }
+ if newpop.FollowFinalSymlink {
+ oldVD.DecRef()
+ ctx.Warningf("VirtualFilesystem.LinkAt: file creation paths can't follow final symlink")
+ return syserror.EINVAL
}
+
+ rp := vfs.getResolvingPath(creds, newpop)
for {
err := rp.mount.fs.impl.LinkAt(ctx, rp, oldVD)
if err == nil {
- oldVD.DecRef()
vfs.putResolvingPath(rp)
+ oldVD.DecRef()
return nil
}
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.LinkAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ }
+ }
if !rp.handleError(err) {
- oldVD.DecRef()
vfs.putResolvingPath(rp)
+ oldVD.DecRef()
return err
}
}
@@ -177,19 +224,32 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential
// MkdirAt creates a directory at the given path.
func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MkdirOptions) error {
+ if !pop.Path.Begin.Ok() {
+ if pop.Path.Absolute {
+ return syserror.EEXIST
+ }
+ return syserror.ENOENT
+ }
+ if pop.FollowFinalSymlink {
+ ctx.Warningf("VirtualFilesystem.MkdirAt: file creation paths can't follow final symlink")
+ return syserror.EINVAL
+ }
// "Under Linux, apart from the permission bits, the S_ISVTX mode bit is
// also honored." - mkdir(2)
opts.Mode &= 0777 | linux.S_ISVTX
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return err
- }
+
+ rp := vfs.getResolvingPath(creds, pop)
for {
err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts)
if err == nil {
vfs.putResolvingPath(rp)
return nil
}
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.MkdirAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ }
+ }
if !rp.handleError(err) {
vfs.putResolvingPath(rp)
return err
@@ -200,16 +260,29 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia
// MknodAt creates a file of the given mode at the given path. It returns an
// error from the syserror package.
func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MknodOptions) error {
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return nil
+ if !pop.Path.Begin.Ok() {
+ if pop.Path.Absolute {
+ return syserror.EEXIST
+ }
+ return syserror.ENOENT
+ }
+ if pop.FollowFinalSymlink {
+ ctx.Warningf("VirtualFilesystem.MknodAt: file creation paths can't follow final symlink")
+ return syserror.EINVAL
}
+
+ rp := vfs.getResolvingPath(creds, pop)
for {
- if err = rp.mount.fs.impl.MknodAt(ctx, rp, *opts); err == nil {
+ err := rp.mount.fs.impl.MknodAt(ctx, rp, *opts)
+ if err == nil {
vfs.putResolvingPath(rp)
return nil
}
- // Handle mount traversals.
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.MknodAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ }
+ }
if !rp.handleError(err) {
vfs.putResolvingPath(rp)
return err
@@ -259,10 +332,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
if opts.Flags&linux.O_NOFOLLOW != 0 {
pop.FollowFinalSymlink = false
}
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return nil, err
- }
+ rp := vfs.getResolvingPath(creds, pop)
if opts.Flags&linux.O_DIRECTORY != 0 {
rp.mustBeDir = true
rp.mustBeDirOrig = true
@@ -282,10 +352,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
// ReadlinkAt returns the target of the symbolic link at the given path.
func (vfs *VirtualFilesystem) ReadlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (string, error) {
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return "", err
- }
+ rp := vfs.getResolvingPath(creds, pop)
for {
target, err := rp.mount.fs.impl.ReadlinkAt(ctx, rp)
if err == nil {
@@ -301,25 +368,59 @@ func (vfs *VirtualFilesystem) ReadlinkAt(ctx context.Context, creds *auth.Creden
// RenameAt renames the file at oldpop to newpop.
func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credentials, oldpop, newpop *PathOperation, opts *RenameOptions) error {
- oldVD, err := vfs.GetDentryAt(ctx, creds, oldpop, &GetDentryOptions{})
- if err != nil {
- return err
+ if !oldpop.Path.Begin.Ok() {
+ if oldpop.Path.Absolute {
+ return syserror.EBUSY
+ }
+ return syserror.ENOENT
+ }
+ if oldpop.FollowFinalSymlink {
+ ctx.Warningf("VirtualFilesystem.RenameAt: source path can't follow final symlink")
+ return syserror.EINVAL
}
- rp, err := vfs.getResolvingPath(creds, newpop)
+
+ oldParentVD, oldName, err := vfs.getParentDirAndName(ctx, creds, oldpop)
if err != nil {
- oldVD.DecRef()
return err
}
+ if oldName == "." || oldName == ".." {
+ oldParentVD.DecRef()
+ return syserror.EBUSY
+ }
+
+ if !newpop.Path.Begin.Ok() {
+ oldParentVD.DecRef()
+ if newpop.Path.Absolute {
+ return syserror.EBUSY
+ }
+ return syserror.ENOENT
+ }
+ if newpop.FollowFinalSymlink {
+ oldParentVD.DecRef()
+ ctx.Warningf("VirtualFilesystem.RenameAt: destination path can't follow final symlink")
+ return syserror.EINVAL
+ }
+
+ rp := vfs.getResolvingPath(creds, newpop)
+ renameOpts := *opts
+ if oldpop.Path.Dir {
+ renameOpts.MustBeDir = true
+ }
for {
- err := rp.mount.fs.impl.RenameAt(ctx, rp, oldVD, *opts)
+ err := rp.mount.fs.impl.RenameAt(ctx, rp, oldParentVD, oldName, renameOpts)
if err == nil {
- oldVD.DecRef()
vfs.putResolvingPath(rp)
+ oldParentVD.DecRef()
return nil
}
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.RenameAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ }
+ }
if !rp.handleError(err) {
- oldVD.DecRef()
vfs.putResolvingPath(rp)
+ oldParentVD.DecRef()
return err
}
}
@@ -327,16 +428,29 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti
// RmdirAt removes the directory at the given path.
func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error {
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return err
+ if !pop.Path.Begin.Ok() {
+ if pop.Path.Absolute {
+ return syserror.EBUSY
+ }
+ return syserror.ENOENT
+ }
+ if pop.FollowFinalSymlink {
+ ctx.Warningf("VirtualFilesystem.RmdirAt: file deletion paths can't follow final symlink")
+ return syserror.EINVAL
}
+
+ rp := vfs.getResolvingPath(creds, pop)
for {
err := rp.mount.fs.impl.RmdirAt(ctx, rp)
if err == nil {
vfs.putResolvingPath(rp)
return nil
}
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.RmdirAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ }
+ }
if !rp.handleError(err) {
vfs.putResolvingPath(rp)
return err
@@ -346,10 +460,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia
// SetStatAt changes metadata for the file at the given path.
func (vfs *VirtualFilesystem) SetStatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetStatOptions) error {
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return err
- }
+ rp := vfs.getResolvingPath(creds, pop)
for {
err := rp.mount.fs.impl.SetStatAt(ctx, rp, *opts)
if err == nil {
@@ -365,10 +476,7 @@ func (vfs *VirtualFilesystem) SetStatAt(ctx context.Context, creds *auth.Credent
// StatAt returns metadata for the file at the given path.
func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *StatOptions) (linux.Statx, error) {
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return linux.Statx{}, err
- }
+ rp := vfs.getResolvingPath(creds, pop)
for {
stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts)
if err == nil {
@@ -385,10 +493,7 @@ func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credential
// StatFSAt returns metadata for the filesystem containing the file at the
// given path.
func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (linux.Statfs, error) {
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return linux.Statfs{}, err
- }
+ rp := vfs.getResolvingPath(creds, pop)
for {
statfs, err := rp.mount.fs.impl.StatFSAt(ctx, rp)
if err == nil {
@@ -404,16 +509,29 @@ func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credenti
// SymlinkAt creates a symbolic link at the given path with the given target.
func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, target string) error {
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return err
+ if !pop.Path.Begin.Ok() {
+ if pop.Path.Absolute {
+ return syserror.EEXIST
+ }
+ return syserror.ENOENT
}
+ if pop.FollowFinalSymlink {
+ ctx.Warningf("VirtualFilesystem.SymlinkAt: file creation paths can't follow final symlink")
+ return syserror.EINVAL
+ }
+
+ rp := vfs.getResolvingPath(creds, pop)
for {
err := rp.mount.fs.impl.SymlinkAt(ctx, rp, target)
if err == nil {
vfs.putResolvingPath(rp)
return nil
}
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.SymlinkAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ }
+ }
if !rp.handleError(err) {
vfs.putResolvingPath(rp)
return err
@@ -423,16 +541,104 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent
// UnlinkAt deletes the non-directory file at the given path.
func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error {
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return err
+ if !pop.Path.Begin.Ok() {
+ if pop.Path.Absolute {
+ return syserror.EBUSY
+ }
+ return syserror.ENOENT
}
+ if pop.FollowFinalSymlink {
+ ctx.Warningf("VirtualFilesystem.UnlinkAt: file deletion paths can't follow final symlink")
+ return syserror.EINVAL
+ }
+
+ rp := vfs.getResolvingPath(creds, pop)
for {
err := rp.mount.fs.impl.UnlinkAt(ctx, rp)
if err == nil {
vfs.putResolvingPath(rp)
return nil
}
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.UnlinkAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// ListxattrAt returns all extended attribute names for the file at the given
+// path.
+func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) ([]string, error) {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ names, err := rp.mount.fs.impl.ListxattrAt(ctx, rp)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return names, nil
+ }
+ if err == syserror.ENOTSUP {
+ // Linux doesn't actually return ENOTSUP in this case; instead,
+ // fs/xattr.c:vfs_listxattr() falls back to allowing the security
+ // subsystem to return security extended attributes, which by
+ // default don't exist.
+ vfs.putResolvingPath(rp)
+ return nil, nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return nil, err
+ }
+ }
+}
+
+// GetxattrAt returns the value associated with the given extended attribute
+// for the file at the given path.
+func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) (string, error) {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ val, err := rp.mount.fs.impl.GetxattrAt(ctx, rp, name)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return val, nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return "", err
+ }
+ }
+}
+
+// SetxattrAt changes the value associated with the given extended attribute
+// for the file at the given path.
+func (vfs *VirtualFilesystem) SetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetxattrOptions) error {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ err := rp.mount.fs.impl.SetxattrAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// RemovexattrAt removes the given extended attribute from the file at the
+// given path.
+func (vfs *VirtualFilesystem) RemovexattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) error {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ err := rp.mount.fs.impl.RemovexattrAt(ctx, rp, name)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
if !rp.handleError(err) {
vfs.putResolvingPath(rp)
return err
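
Since PathOperation now carries a pre-parsed fspath.Path instead of a raw pathname string, callers parse up front; a sketch (the path literal is illustrative) using one of the new path-based xattr entry points:

package example

import (
	"gvisor.dev/gvisor/pkg/fspath"
	"gvisor.dev/gvisor/pkg/sentry/context"
	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
	"gvisor.dev/gvisor/pkg/sentry/vfs"
)

// getxattrByPath resolves a pathname and returns the named extended attribute.
func getxattrByPath(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, root vfs.VirtualDentry, name string) (string, error) {
	return vfsObj.GetxattrAt(ctx, creds, &vfs.PathOperation{
		Root:               root,
		Start:              root, // the path below is absolute, so Start == Root
		Path:               fspath.Parse("/etc/hostname"),
		FollowFinalSymlink: true,
	}, name)
}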
diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go
index 1987e89cc..2269f6237 100644
--- a/pkg/syserror/syserror.go
+++ b/pkg/syserror/syserror.go
@@ -45,6 +45,7 @@ var (
ELIBBAD = error(syscall.ELIBBAD)
ELOOP = error(syscall.ELOOP)
EMFILE = error(syscall.EMFILE)
+ EMLINK = error(syscall.EMLINK)
EMSGSIZE = error(syscall.EMSGSIZE)
ENAMETOOLONG = error(syscall.ENAMETOOLONG)
ENOATTR = ENODATA
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index 65d4d0cd8..e07ebd153 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -10,6 +10,7 @@ go_library(
"packet_buffer_state.go",
"tcpip.go",
"time_unsafe.go",
+ "timer.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip",
visibility = ["//visibility:public"],
@@ -26,3 +27,10 @@ go_test(
srcs = ["tcpip_test.go"],
embed = [":tcpip"],
)
+
+go_test(
+ name = "timer_test",
+ size = "small",
+ srcs = ["timer_test.go"],
+ deps = [":tcpip"],
+)
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index f1d837196..f2061c778 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -44,6 +44,7 @@ go_test(
],
deps = [
":header",
+ "//pkg/rand",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"@com_github_google_go-cmp//cmp:go_default_library",
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index fc671e439..135a60b12 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -15,6 +15,7 @@
package header
import (
+ "crypto/sha256"
"encoding/binary"
"strings"
@@ -102,6 +103,11 @@ const (
// bytes including and after the IIDOffsetInIPv6Address-th byte are
// for the IID.
IIDOffsetInIPv6Address = 8
+
+ // OpaqueIIDSecretKeyMinBytes is the recommended minimum number of bytes
+ // for the secret key used to generate an opaque interface identifier as
+ // outlined by RFC 7217.
+ OpaqueIIDSecretKeyMinBytes = 16
)
// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the
@@ -326,3 +332,42 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool {
}
return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80
}
+
+// AppendOpaqueInterfaceIdentifier appends a 64-bit opaque interface identifier
+// (IID) to buf as outlined by RFC 7217 and returns the extended buffer.
+//
+// The opaque IID is generated from the cryptographic hash of the concatenation
+// of the prefix, the NIC's name, the DAD counter (DAD retry counter) and the
+// secret key. The secret key SHOULD be at least OpaqueIIDSecretKeyMinBytes
+// bytes long and MUST be initialized to a pseudo-random number. See RFC 4086
+// for randomness requirements for security.
+//
+// If buf has enough capacity for the IID (IIDSize bytes), a new underlying
+// array for the buffer will not be allocated.
+func AppendOpaqueInterfaceIdentifier(buf []byte, prefix tcpip.Subnet, nicName string, dadCounter uint8, secretKey []byte) []byte {
+ // As per RFC 7217 section 5, the opaque identifier can be generated as a
+ // cryptographic hash of the concatenation of each of the function parameters.
+ // Note, we omit the optional Network_ID field.
+ h := sha256.New()
+ // h.Write never returns an error.
+ h.Write([]byte(prefix.ID()[:IIDOffsetInIPv6Address]))
+ h.Write([]byte(nicName))
+ h.Write([]byte{dadCounter})
+ h.Write(secretKey)
+
+ var sumBuf [sha256.Size]byte
+ sum := h.Sum(sumBuf[:0])
+
+ return append(buf, sum[:IIDSize]...)
+}
+
+// LinkLocalAddrWithOpaqueIID computes the default IPv6 link-local address with
+// an opaque IID.
+func LinkLocalAddrWithOpaqueIID(nicName string, dadCounter uint8, secretKey []byte) tcpip.Address {
+ lladdrb := [IPv6AddressSize]byte{
+ 0: 0xFE,
+ 1: 0x80,
+ }
+
+ return tcpip.Address(AppendOpaqueInterfaceIdentifier(lladdrb[:IIDOffsetInIPv6Address], IPv6LinkLocalPrefix.Subnet(), nicName, dadCounter, secretKey))
+}
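
A hedged usage sketch (illustrative, not from the patch; the NIC name "eth0" and the helper name are assumptions, and it relies on the gvisor.dev/gvisor/pkg/rand, pkg/tcpip and pkg/tcpip/header imports already used in the tests below) showing how an integrator could derive a stable link-local address from a per-host secret key with the helpers above.

// exampleOpaqueLinkLocal derives a stable, opaque-IID link-local address.
func exampleOpaqueLinkLocal() (tcpip.Address, error) {
	// The secret key SHOULD be at least OpaqueIIDSecretKeyMinBytes bytes and
	// come from a secure random source (see RFC 4086).
	secretKey := make([]byte, header.OpaqueIIDSecretKeyMinBytes)
	if _, err := rand.Read(secretKey); err != nil {
		return "", err
	}
	// The DAD counter starts at 0; incrementing it on a DAD conflict yields a
	// different, but still stable, IID for the same NIC and prefix.
	return header.LinkLocalAddrWithOpaqueIID("eth0", 0 /* dadCounter */, secretKey), nil
}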
diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go
index 42c5c6fc1..1994003ed 100644
--- a/pkg/tcpip/header/ipv6_test.go
+++ b/pkg/tcpip/header/ipv6_test.go
@@ -15,9 +15,12 @@
package header_test
import (
+ "bytes"
+ "crypto/sha256"
"testing"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -43,3 +46,163 @@ func TestLinkLocalAddr(t *testing.T) {
t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want)
}
}
+
+func TestAppendOpaqueInterfaceIdentifier(t *testing.T) {
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte
+ if n, err := rand.Read(secretKeyBuf[:]); err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want {
+ t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n)
+ }
+
+ tests := []struct {
+ name string
+ prefix tcpip.Subnet
+ nicName string
+ dadCounter uint8
+ secretKey []byte
+ }{
+ {
+ name: "SecretKey of minimum size",
+ prefix: header.IPv6LinkLocalPrefix.Subnet(),
+ nicName: "eth0",
+ dadCounter: 0,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes],
+ },
+ {
+ name: "SecretKey of less than minimum size",
+ prefix: func() tcpip.Subnet {
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: "\x01\x02\x03\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+ }
+ return addrWithPrefix.Subnet()
+ }(),
+ nicName: "eth10",
+ dadCounter: 1,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2],
+ },
+ {
+ name: "SecretKey of more than minimum size",
+ prefix: func() tcpip.Subnet {
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: "\x01\x02\x03\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+ }
+ return addrWithPrefix.Subnet()
+ }(),
+ nicName: "eth11",
+ dadCounter: 2,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2],
+ },
+ {
+ name: "Nil SecretKey and empty nicName",
+ prefix: func() tcpip.Subnet {
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: "\x01\x02\x03\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+ }
+ return addrWithPrefix.Subnet()
+ }(),
+ nicName: "",
+ dadCounter: 3,
+ secretKey: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ h := sha256.New()
+ h.Write([]byte(test.prefix.ID()[:header.IIDOffsetInIPv6Address]))
+ h.Write([]byte(test.nicName))
+ h.Write([]byte{test.dadCounter})
+ if k := test.secretKey; k != nil {
+ h.Write(k)
+ }
+ var hashSum [sha256.Size]byte
+ h.Sum(hashSum[:0])
+ want := hashSum[:header.IIDSize]
+
+ // Passing a nil buffer should result in a new buffer returned with the
+ // IID.
+ if got := header.AppendOpaqueInterfaceIdentifier(nil, test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) {
+ t.Errorf("got AppendOpaqueInterfaceIdentifier(nil, %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want)
+ }
+
+ // Passing a buffer with sufficient capacity for the IID should populate
+ // the buffer provided.
+ var iidBuf [header.IIDSize]byte
+ if got := header.AppendOpaqueInterfaceIdentifier(iidBuf[:0], test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) {
+ t.Errorf("got AppendOpaqueInterfaceIdentifier(iidBuf[:0], %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want)
+ }
+ if got := iidBuf[:]; !bytes.Equal(got, want) {
+ t.Errorf("got iidBuf = %x, want = %x", got, want)
+ }
+ })
+ }
+}
+
+func TestLinkLocalAddrWithOpaqueIID(t *testing.T) {
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte
+ if n, err := rand.Read(secretKeyBuf[:]); err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want {
+ t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n)
+ }
+
+ prefix := header.IPv6LinkLocalPrefix.Subnet()
+
+ tests := []struct {
+ name string
+ prefix tcpip.Subnet
+ nicName string
+ dadCounter uint8
+ secretKey []byte
+ }{
+ {
+ name: "SecretKey of minimum size",
+ nicName: "eth0",
+ dadCounter: 0,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes],
+ },
+ {
+ name: "SecretKey of less than minimum size",
+ nicName: "eth10",
+ dadCounter: 1,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2],
+ },
+ {
+ name: "SecretKey of more than minimum size",
+ nicName: "eth11",
+ dadCounter: 2,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2],
+ },
+ {
+ name: "Nil SecretKey and empty nicName",
+ nicName: "",
+ dadCounter: 3,
+ secretKey: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ addrBytes := [header.IPv6AddressSize]byte{
+ 0: 0xFE,
+ 1: 0x80,
+ }
+
+ want := tcpip.Address(header.AppendOpaqueInterfaceIdentifier(
+ addrBytes[:header.IIDOffsetInIPv6Address],
+ prefix,
+ test.nicName,
+ test.dadCounter,
+ test.secretKey,
+ ))
+
+ if got := header.LinkLocalAddrWithOpaqueIID(test.nicName, test.dadCounter, test.secretKey); got != want {
+ t.Errorf("got LinkLocalAddrWithOpaqueIID(%s, %d, %x) = %s, want = %s", test.nicName, test.dadCounter, test.secretKey, got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index da8482509..42cacb8a6 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -79,16 +79,16 @@ func (e *endpoint) MaxHeaderLength() uint16 {
func (e *endpoint) Close() {}
-func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketLooping, tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, tcpip.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams, stack.PacketLooping) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams) (int, *tcpip.Error) {
return 0, tcpip.ErrNotSupported
}
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 4144a7837..f1bc33adf 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -239,7 +239,7 @@ func TestIPv4Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
}); err != nil {
@@ -480,7 +480,7 @@ func TestIPv6Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
}); err != nil {
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index e645cf62c..4ee3d5b45 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -238,11 +238,11 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
pkt.NetworkHeader = buffer.View(ip)
- if loop&stack.PacketLoop != 0 {
+ if r.Loop&stack.PacketLoop != 0 {
// The inbound path expects the network header to still be in
// the PacketBuffer's Data field.
views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
@@ -256,7 +256,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
loopedR.Release()
}
- if loop&stack.PacketOut == 0 {
+ if r.Loop&stack.PacketOut == 0 {
return nil
}
if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
@@ -270,11 +270,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
- if loop&stack.PacketLoop != 0 {
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+ if r.Loop&stack.PacketLoop != 0 {
panic("multiple packets in local loop")
}
- if loop&stack.PacketOut == 0 {
+ if r.Loop&stack.PacketOut == 0 {
return len(pkts), nil
}
@@ -289,7 +289,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error {
// The packet already has an IP header, but there are a few required
// checks.
ip := header.IPv4(pkt.Data.First())
@@ -324,10 +324,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLo
ip.SetChecksum(0)
ip.SetChecksum(^ip.CalculateChecksum())
- if loop&stack.PacketLoop != 0 {
+ if r.Loop&stack.PacketLoop != 0 {
e.HandlePacket(r, pkt.Clone())
}
- if loop&stack.PacketOut == 0 {
+ if r.Loop&stack.PacketOut == 0 {
return nil
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index dd31f0fb7..58c3c79b9 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -112,11 +112,11 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
pkt.NetworkHeader = buffer.View(ip)
- if loop&stack.PacketLoop != 0 {
+ if r.Loop&stack.PacketLoop != 0 {
// The inbound path expects the network header to still be in
// the PacketBuffer's Data field.
views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
@@ -130,7 +130,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
loopedR.Release()
}
- if loop&stack.PacketOut == 0 {
+ if r.Loop&stack.PacketOut == 0 {
return nil
}
@@ -139,11 +139,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
- if loop&stack.PacketLoop != 0 {
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+ if r.Loop&stack.PacketLoop != 0 {
panic("not implemented")
}
- if loop&stack.PacketOut == 0 {
+ if r.Loop&stack.PacketOut == 0 {
return len(pkts), nil
}
@@ -161,8 +161,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac
// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet
// supported by IPv6.
-func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
- // TODO(b/119580726): Support IPv6 header-included packets.
+func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error {
+ // TODO(b/146666412): Support IPv6 header-included packets.
return tcpip.ErrNotSupported
}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 69077669a..b8f9517d0 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -59,6 +59,7 @@ go_test(
],
deps = [
":stack",
+ "//pkg/rand",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index 27bd02e76..35825ebf7 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -131,40 +131,34 @@ type NDPDispatcher interface {
OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error)
// OnDefaultRouterDiscovered will be called when a new default router is
- // discovered. Implementations must return true along with a new valid
- // route table if the newly discovered router should be remembered. If
- // an implementation returns false, the second return value will be
- // ignored.
+ // discovered. Implementations must return true if the newly discovered
+ // router should be remembered.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) (bool, []tcpip.Route)
+ OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool
// OnDefaultRouterInvalidated will be called when a discovered default
- // router is invalidated. Implementers must return a new valid route
- // table.
+ // router that was remembered is invalidated.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) []tcpip.Route
+ OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address)
// OnOnLinkPrefixDiscovered will be called when a new on-link prefix is
- // discovered. Implementations must return true along with a new valid
- // route table if the newly discovered on-link prefix should be
- // remembered. If an implementation returns false, the second return
- // value will be ignored.
+ // discovered. Implementations must return true if the newly discovered
+ // on-link prefix should be remembered.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) (bool, []tcpip.Route)
+ OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool
// OnOnLinkPrefixInvalidated will be called when a discovered on-link
- // prefix is invalidated. Implementers must return a new valid route
- // table.
+ // prefix that was remembered is invalidated.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) []tcpip.Route
+ OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet)
// OnAutoGenAddress will be called when a new prefix with its
// autonomous address-configuration flag set has been received and SLAAC
@@ -175,6 +169,15 @@ type NDPDispatcher interface {
// call functions on the stack itself.
OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool
+ // OnAutoGenAddressDeprecated will be called when an auto-generated
+ // address (as part of SLAAC) has been deprecated, but is still
+ // considered valid. Note, if an address is invalidated at the same
+ // time it is deprecated, the deprecation event MAY be omitted.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix)
+
// OnAutoGenAddressInvalidated will be called when an auto-generated
// address (as part of SLAAC) has been invalidated.
//
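
For illustration (a sketch, not part of the change): a minimal dispatcher under the new contract, where the discovery callbacks return only a bool and the integrator maintains its own route table outside the callbacks, since they may not call back into the stack. Only two methods are shown and all names are hypothetical.

// exampleDispatcher partially implements NDPDispatcher; the remaining
// methods are omitted for brevity.
type exampleDispatcher struct {
	mu      sync.Mutex
	routers map[tcpip.Address]struct{}
}

func (d *exampleDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool {
	d.mu.Lock()
	defer d.mu.Unlock()
	// Record the router; the route table is updated later, outside this
	// callback, because callbacks must not call into the stack.
	d.routers[addr] = struct{}{}
	return true // Ask the stack to remember this router.
}

func (d *exampleDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) {
	d.mu.Lock()
	defer d.mu.Unlock()
	delete(d.routers, addr)
}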
@@ -296,71 +299,27 @@ type dadState struct {
// defaultRouterState holds data associated with a default router discovered by
// a Router Advertisement (RA).
type defaultRouterState struct {
- invalidationTimer *time.Timer
-
- // Used to inform the timer not to invalidate the default router (R) in
- // a race condition (T1 is a goroutine that handles an RA from R and T2
- // is the goroutine that handles R's invalidation timer firing):
- // T1: Receive a new RA from R
- // T1: Obtain the NIC's lock before processing the RA
- // T2: R's invalidation timer fires, and gets blocked on obtaining the
- // NIC's lock
- // T1: Refreshes/extends R's lifetime & releases NIC's lock
- // T2: Obtains NIC's lock & invalidates R immediately
- //
- // To resolve this, T1 will check to see if the timer already fired, and
- // inform the timer using doNotInvalidate to not invalidate R, so that
- // once T2 obtains the lock, it will see that it is set to true and do
- // nothing further.
- doNotInvalidate *bool
+ invalidationTimer tcpip.CancellableTimer
}
// onLinkPrefixState holds data associated with an on-link prefix discovered by
// a Router Advertisement's Prefix Information option (PI) when the NDP
// configurations was configured to do so.
type onLinkPrefixState struct {
- invalidationTimer *time.Timer
-
- // Used to signal the timer not to invalidate the on-link prefix (P) in
- // a race condition (T1 is a goroutine that handles a PI for P and T2
- // is the goroutine that handles P's invalidation timer firing):
- // T1: Receive a new PI for P
- // T1: Obtain the NIC's lock before processing the PI
- // T2: P's invalidation timer fires, and gets blocked on obtaining the
- // NIC's lock
- // T1: Refreshes/extends P's lifetime & releases NIC's lock
- // T2: Obtains NIC's lock & invalidates P immediately
- //
- // To resolve this, T1 will check to see if the timer already fired, and
- // inform the timer using doNotInvalidate to not invalidate P, so that
- // once T2 obtains the lock, it will see that it is set to true and do
- // nothing further.
- doNotInvalidate *bool
+ invalidationTimer tcpip.CancellableTimer
}
// autoGenAddressState holds data associated with an address generated via
// SLAAC.
type autoGenAddressState struct {
- invalidationTimer *time.Timer
-
- // Used to signal the timer not to invalidate the SLAAC address (A) in
- // a race condition (T1 is a goroutine that handles a PI for A and T2
- // is the goroutine that handles A's invalidation timer firing):
- // T1: Receive a new PI for A
- // T1: Obtain the NIC's lock before processing the PI
- // T2: A's invalidation timer fires, and gets blocked on obtaining the
- // NIC's lock
- // T1: Refreshes/extends A's lifetime & releases NIC's lock
- // T2: Obtains NIC's lock & invalidates A immediately
- //
- // To resolve this, T1 will check to see if the timer already fired, and
- // inform the timer using doNotInvalidate to not invalidate A, so that
- // once T2 obtains the lock, it will see that it is set to true and do
- // nothing further.
- doNotInvalidate *bool
-
- // Nonzero only when the address is not valid forever (invalidationTimer
- // is not nil).
+ // A reference to the referencedNetworkEndpoint that this autoGenAddressState
+ // is holding state for.
+ ref *referencedNetworkEndpoint
+
+ deprecationTimer tcpip.CancellableTimer
+ invalidationTimer tcpip.CancellableTimer
+
+ // Nonzero only when the address is not valid forever.
validUntil time.Time
}
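
A hedged sketch (illustrative names, not part of the change) of the tcpip.CancellableTimer pattern that replaces the hand-rolled doNotInvalidate flags: the timer is constructed with the lock guarding the protected state, its function runs with that lock held, and Reset/StopLocked are only called while holding the lock, so a timer that fires but loses the race for the lock is safely cancelled.

var (
	mu      sync.Mutex
	entries = make(map[tcpip.Address]tcpip.CancellableTimer)
)

// rememberEntry mirrors how defaultRouterState and onLinkPrefixState now arm
// their invalidation timers.
func rememberEntry(addr tcpip.Address, lifetime time.Duration) {
	mu.Lock()
	defer mu.Unlock()
	t := tcpip.MakeCancellableTimer(&mu, func() {
		// Runs with mu held; safe to touch the guarded map directly.
		delete(entries, addr)
	})
	t.Reset(lifetime)
	entries[addr] = t
}

// refreshEntry mirrors the handleRA path: stop the pending invalidation, even
// if it already fired and is waiting on mu, then re-arm it.
func refreshEntry(addr tcpip.Address, lifetime time.Duration) {
	mu.Lock()
	defer mu.Unlock()
	if t, ok := entries[addr]; ok {
		t.StopLocked()
		t.Reset(lifetime)
		entries[addr] = t
	}
}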
@@ -562,7 +521,7 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
// handleRA handles a Router Advertisement message that arrived on the NIC
// this ndp is for. Does nothing if the NIC is configured to not handle RAs.
//
-// The NIC that ndp belongs to and its associated stack MUST be locked.
+// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// Is the NIC configured to handle RAs at all?
//
@@ -591,27 +550,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
case ok && rl != 0:
// This is an already discovered default router. Update
// the invalidation timer.
- timer := rtr.invalidationTimer
-
- // We should ALWAYS have an invalidation timer for a
- // discovered router.
- if timer == nil {
- panic("ndphandlera: RA invalidation timer should not be nil")
- }
-
- if !timer.Stop() {
- // If we reach this point, then we know the
- // timer fired after we already took the NIC
- // lock. Inform the timer not to invalidate the
- // router when it obtains the lock as we just
- // got a new RA that refreshes its lifetime to a
- // non-zero value. See
- // defaultRouterState.doNotInvalidate for more
- // details.
- *rtr.doNotInvalidate = true
- }
-
- timer.Reset(rl)
+ rtr.invalidationTimer.StopLocked()
+ rtr.invalidationTimer.Reset(rl)
+ ndp.defaultRouters[ip] = rtr
case ok && rl == 0:
// We know about the router but it is no longer to be
@@ -668,7 +609,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// invalidateDefaultRouter invalidates a discovered default router.
//
-// The NIC that ndp belongs to and its associated stack MUST be locked.
+// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
rtr, ok := ndp.defaultRouters[ip]
@@ -678,16 +619,13 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
return
}
- rtr.invalidationTimer.Stop()
- rtr.invalidationTimer = nil
- *rtr.doNotInvalidate = true
- rtr.doNotInvalidate = nil
+ rtr.invalidationTimer.StopLocked()
delete(ndp.defaultRouters, ip)
// Let the integrator know a discovered default router is invalidated.
- if ndp.nic.stack.ndpDisp != nil {
- ndp.nic.stack.routeTable = ndp.nic.stack.ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip)
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip)
}
}
@@ -696,43 +634,29 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
//
// The router identified by ip MUST NOT already be known by the NIC.
//
-// The NIC that ndp belongs to and its associated stack MUST be locked.
+// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
- if ndp.nic.stack.ndpDisp == nil {
+ ndpDisp := ndp.nic.stack.ndpDisp
+ if ndpDisp == nil {
return
}
// Inform the integrator when we discovered a default router.
- remember, routeTable := ndp.nic.stack.ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip)
- if !remember {
+ if !ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) {
// Informed by the integrator to not remember the router, do
// nothing further.
return
}
- // Used to signal the timer not to invalidate the default router (R) in
- // a race condition. See defaultRouterState.doNotInvalidate for more
- // details.
- var doNotInvalidate bool
-
- ndp.defaultRouters[ip] = defaultRouterState{
- invalidationTimer: time.AfterFunc(rl, func() {
- ndp.nic.stack.mu.Lock()
- defer ndp.nic.stack.mu.Unlock()
- ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
-
- if doNotInvalidate {
- doNotInvalidate = false
- return
- }
-
+ state := defaultRouterState{
+ invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
ndp.invalidateDefaultRouter(ip)
}),
- doNotInvalidate: &doNotInvalidate,
}
- ndp.nic.stack.routeTable = routeTable
+ state.invalidationTimer.Reset(rl)
+
+ ndp.defaultRouters[ip] = state
}
// rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6
@@ -740,42 +664,36 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
//
// The prefix identified by prefix MUST NOT already be known.
//
-// The NIC that ndp belongs to and its associated stack MUST be locked.
+// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) {
- if ndp.nic.stack.ndpDisp == nil {
+ ndpDisp := ndp.nic.stack.ndpDisp
+ if ndpDisp == nil {
return
}
// Inform the integrator when we discovered an on-link prefix.
- remember, routeTable := ndp.nic.stack.ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix)
- if !remember {
+ if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix) {
// Informed by the integrator to not remember the prefix, do
// nothing further.
return
}
- // Used to signal the timer not to invalidate the on-link prefix (P) in
- // a race condition. See onLinkPrefixState.doNotInvalidate for more
- // details.
- var doNotInvalidate bool
- var timer *time.Timer
-
- // Only create a timer if the lifetime is not infinite.
- if l < header.NDPInfiniteLifetime {
- timer = ndp.prefixInvalidationCallback(prefix, l, &doNotInvalidate)
+ state := onLinkPrefixState{
+ invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
+ ndp.invalidateOnLinkPrefix(prefix)
+ }),
}
- ndp.onLinkPrefixes[prefix] = onLinkPrefixState{
- invalidationTimer: timer,
- doNotInvalidate: &doNotInvalidate,
+ if l < header.NDPInfiniteLifetime {
+ state.invalidationTimer.Reset(l)
}
- ndp.nic.stack.routeTable = routeTable
+ ndp.onLinkPrefixes[prefix] = state
}
// invalidateOnLinkPrefix invalidates a discovered on-link prefix.
//
-// The NIC that ndp belongs to and its associated stack MUST be locked.
+// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
s, ok := ndp.onLinkPrefixes[prefix]
@@ -785,51 +703,23 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
return
}
- if s.invalidationTimer != nil {
- s.invalidationTimer.Stop()
- s.invalidationTimer = nil
- *s.doNotInvalidate = true
- }
-
- s.doNotInvalidate = nil
+ s.invalidationTimer.StopLocked()
delete(ndp.onLinkPrefixes, prefix)
// Let the integrator know a discovered on-link prefix is invalidated.
- if ndp.nic.stack.ndpDisp != nil {
- ndp.nic.stack.routeTable = ndp.nic.stack.ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix)
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix)
}
}
-// prefixInvalidationCallback returns a new on-link prefix invalidation timer
-// for prefix that fires after vl.
-//
-// doNotInvalidate is used to signal the timer when it fires at the same time
-// that a prefix's valid lifetime gets refreshed. See
-// onLinkPrefixState.doNotInvalidate for more details.
-func (ndp *ndpState) prefixInvalidationCallback(prefix tcpip.Subnet, vl time.Duration, doNotInvalidate *bool) *time.Timer {
- return time.AfterFunc(vl, func() {
- ndp.nic.stack.mu.Lock()
- defer ndp.nic.stack.mu.Unlock()
- ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
-
- if *doNotInvalidate {
- *doNotInvalidate = false
- return
- }
-
- ndp.invalidateOnLinkPrefix(prefix)
- })
-}
-
// handleOnLinkPrefixInformation handles a Prefix Information option with
// its on-link flag set, as per RFC 4861 section 6.3.4.
//
// handleOnLinkPrefixInformation assumes that the prefix this pi is for is
// not the link-local prefix and the on-link flag is set.
//
-// The NIC that ndp belongs to and its associated stack MUST be locked.
+// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) {
prefix := pi.Subnet()
prefixState, ok := ndp.onLinkPrefixes[prefix]
@@ -862,42 +752,17 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio
// This is an already discovered on-link prefix with a
// new non-zero valid lifetime.
+ //
// Update the invalidation timer.
- timer := prefixState.invalidationTimer
-
- if timer == nil && vl >= header.NDPInfiniteLifetime {
- // Had infinite valid lifetime before and
- // continues to have an invalid lifetime. Do
- // nothing further.
- return
- }
-
- if timer != nil && !timer.Stop() {
- // If we reach this point, then we know the timer alread fired
- // after we took the NIC lock. Inform the timer to not
- // invalidate the prefix once it obtains the lock as we just
- // got a new PI that refreshes its lifetime to a non-zero value.
- // See onLinkPrefixState.doNotInvalidate for more details.
- *prefixState.doNotInvalidate = true
- }
- if vl >= header.NDPInfiniteLifetime {
- // Prefix is now valid forever so we don't need
- // an invalidation timer.
- prefixState.invalidationTimer = nil
- ndp.onLinkPrefixes[prefix] = prefixState
- return
- }
+ prefixState.invalidationTimer.StopLocked()
- if timer != nil {
- // We already have a timer so just reset it to
- // expire after the new valid lifetime.
- timer.Reset(vl)
- return
+ if vl < header.NDPInfiniteLifetime {
+ // Prefix is valid for a finite lifetime, reset the timer to expire after
+ // the new valid lifetime.
+ prefixState.invalidationTimer.Reset(vl)
}
- // We do not have a timer so just create a new one.
- prefixState.invalidationTimer = ndp.prefixInvalidationCallback(prefix, vl, prefixState.doNotInvalidate)
ndp.onLinkPrefixes[prefix] = prefixState
}
@@ -907,7 +772,7 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio
// handleAutonomousPrefixInformation assumes that the prefix this pi is for is
// not the link-local prefix and the autonomous flag is set.
//
-// The NIC that ndp belongs to and its associated stack MUST be locked.
+// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) {
vl := pi.ValidLifetime()
pl := pi.PreferredLifetime()
@@ -922,103 +787,30 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform
prefix := pi.Subnet()
// Check if we already have an auto-generated address for prefix.
- for _, ref := range ndp.nic.endpoints {
- if ref.protocol != header.IPv6ProtocolNumber {
- continue
- }
-
- if ref.configType != slaac {
- continue
- }
-
- addr := ref.ep.ID().LocalAddress
- refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: ref.ep.PrefixLen()}
+ for addr, addrState := range ndp.autoGenAddresses {
+ refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: addrState.ref.ep.PrefixLen()}
if refAddrWithPrefix.Subnet() != prefix {
continue
}
- //
- // At this point, we know we are refreshing a SLAAC generated
- // IPv6 address with the prefix, prefix. Do the work as outlined
- // by RFC 4862 section 5.5.3.e.
- //
-
- addrState, ok := ndp.autoGenAddresses[addr]
- if !ok {
- panic(fmt.Sprintf("must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr))
- }
-
- // TODO(b/143713887): Handle deprecating auto-generated address
- // after the preferred lifetime.
-
- // As per RFC 4862 section 5.5.3.e, the valid lifetime of the
- // address generated by SLAAC is as follows:
- //
- // 1) If the received Valid Lifetime is greater than 2 hours or
- // greater than RemainingLifetime, set the valid lifetime of
- // the address to the advertised Valid Lifetime.
- //
- // 2) If RemainingLifetime is less than or equal to 2 hours,
- // ignore the advertised Valid Lifetime.
- //
- // 3) Otherwise, reset the valid lifetime of the address to 2
- // hours.
-
- // Handle the infinite valid lifetime separately as we do not
- // keep a timer in this case.
- if vl >= header.NDPInfiniteLifetime {
- if addrState.invalidationTimer != nil {
- // Valid lifetime was finite before, but now it
- // is valid forever.
- if !addrState.invalidationTimer.Stop() {
- *addrState.doNotInvalidate = true
- }
- addrState.invalidationTimer = nil
- addrState.validUntil = time.Time{}
- ndp.autoGenAddresses[addr] = addrState
- }
-
- return
- }
-
- var effectiveVl time.Duration
- var rl time.Duration
-
- // If the address was originally set to be valid forever,
- // assume the remaining time to be the maximum possible value.
- if addrState.invalidationTimer == nil {
- rl = header.NDPInfiniteLifetime
- } else {
- rl = time.Until(addrState.validUntil)
- }
-
- if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl {
- effectiveVl = vl
- } else if rl <= MinPrefixInformationValidLifetimeForUpdate {
- ndp.autoGenAddresses[addr] = addrState
- return
- } else {
- effectiveVl = MinPrefixInformationValidLifetimeForUpdate
- }
-
- if addrState.invalidationTimer == nil {
- addrState.invalidationTimer = ndp.autoGenAddrInvalidationTimer(addr, effectiveVl, addrState.doNotInvalidate)
- } else {
- if !addrState.invalidationTimer.Stop() {
- *addrState.doNotInvalidate = true
- }
- addrState.invalidationTimer.Reset(effectiveVl)
- }
-
- addrState.validUntil = time.Now().Add(effectiveVl)
- ndp.autoGenAddresses[addr] = addrState
+ // At this point, we know we are refreshing a SLAAC generated IPv6 address
+ // with the prefix, prefix. Do the work as outlined by RFC 4862 section
+ // 5.5.3.e.
+ ndp.refreshAutoGenAddressLifetimes(addr, pl, vl)
return
}
// We do not already have an address within the prefix, prefix. Do the
// work as outlined by RFC 4862 section 5.5.3.d if n is configured
// to auto-generated global addresses by SLAAC.
+ ndp.newAutoGenAddress(prefix, pl, vl)
+}
+// newAutoGenAddress generates a new SLAAC address with the provided lifetimes
+// for prefix.
+//
+// pl is the new preferred lifetime. vl is the new valid lifetime.
+func (ndp *ndpState) newAutoGenAddress(prefix tcpip.Subnet, pl, vl time.Duration) {
// Are we configured to auto-generate new global addresses?
if !ndp.configs.AutoGenGlobalAddresses {
return
@@ -1038,22 +830,24 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform
return
}
- // Only attempt to generate an interface-specific IID if we have a valid
- // link address.
- //
- // TODO(b/141011931): Validate a LinkEndpoint's link address
- // (provided by LinkEndpoint.LinkAddress) before reaching this
- // point.
- linkAddr := ndp.nic.linkEP.LinkAddress()
- if !header.IsValidUnicastEthernetAddress(linkAddr) {
- return
- }
+ addrBytes := []byte(prefix.ID())
+ if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil {
+ addrBytes = header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], prefix, oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), 0 /* dadCounter */, oIID.SecretKey)
+ } else {
+ // Only attempt to generate an interface-specific IID if we have a valid
+ // link address.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
+ // LinkEndpoint.LinkAddress) before reaching this point.
+ linkAddr := ndp.nic.linkEP.LinkAddress()
+ if !header.IsValidUnicastEthernetAddress(linkAddr) {
+ return
+ }
- // Generate an address within prefix from the modified EUI-64 of ndp's
- // NIC's Ethernet MAC address.
- addrBytes := make([]byte, header.IPv6AddressSize)
- copy(addrBytes[:header.IIDOffsetInIPv6Address], prefix.ID()[:header.IIDOffsetInIPv6Address])
- header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
+ // Generate an address within prefix from the modified EUI-64 of ndp's NIC's
+ // Ethernet MAC address.
+ header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
+ }
addr := tcpip.Address(addrBytes)
addrWithPrefix := tcpip.AddressWithPrefix{
Address: addr,
@@ -1066,37 +860,141 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform
}
// Inform the integrator that we have a new SLAAC address.
- if ndp.nic.stack.ndpDisp == nil {
+ ndpDisp := ndp.nic.stack.ndpDisp
+ if ndpDisp == nil {
return
}
- if !ndp.nic.stack.ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addrWithPrefix) {
+ if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addrWithPrefix) {
// Informed by the integrator not to add the address.
return
}
- if _, err := ndp.nic.addAddressLocked(tcpip.ProtocolAddress{
+ protocolAddr := tcpip.ProtocolAddress{
Protocol: header.IPv6ProtocolNumber,
AddressWithPrefix: addrWithPrefix,
- }, FirstPrimaryEndpoint, permanent, slaac); err != nil {
- panic(err)
+ }
+ // If the preferred lifetime is zero, then the address should be considered
+ // deprecated.
+ deprecated := pl == 0
+ ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated)
+ if err != nil {
+ log.Fatalf("ndp: error when adding address %s: %s", protocolAddr, err)
}
- // Setup the timers to deprecate and invalidate this newly generated
+ state := autoGenAddressState{
+ ref: ref,
+ deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
+ addrState, ok := ndp.autoGenAddresses[addr]
+ if !ok {
+ log.Fatalf("ndp: must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr)
+ }
+ addrState.ref.deprecated = true
+ ndp.notifyAutoGenAddressDeprecated(addr)
+ }),
+ invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
+ ndp.invalidateAutoGenAddress(addr)
+ }),
+ }
+
+ // Setup the initial timers to deprecate and invalidate this newly generated
// address.
- // TODO(b/143713887): Handle deprecating auto-generated addresses
- // after the preferred lifetime.
+ if !deprecated && pl < header.NDPInfiniteLifetime {
+ state.deprecationTimer.Reset(pl)
+ }
- var doNotInvalidate bool
- var vTimer *time.Timer
if vl < header.NDPInfiniteLifetime {
- vTimer = ndp.autoGenAddrInvalidationTimer(addr, vl, &doNotInvalidate)
+ state.invalidationTimer.Reset(vl)
+ state.validUntil = time.Now().Add(vl)
+ }
+
+ ndp.autoGenAddresses[addr] = state
+}
+
+// refreshAutoGenAddressLifetimes refreshes the lifetime of a SLAAC generated
+// address addr.
+//
+// pl is the new preferred lifetime. vl is the new valid lifetime.
+func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl time.Duration) {
+ addrState, ok := ndp.autoGenAddresses[addr]
+ if !ok {
+ log.Fatalf("ndp: SLAAC state not found to refresh lifetimes for %s", addr)
+ }
+ defer func() { ndp.autoGenAddresses[addr] = addrState }()
+
+ // If the preferred lifetime is zero, then the address should be considered
+ // deprecated.
+ deprecated := pl == 0
+ wasDeprecated := addrState.ref.deprecated
+ addrState.ref.deprecated = deprecated
+
+ // Only send the deprecation event if the deprecated status for addr just
+ // changed from non-deprecated to deprecated.
+ if !wasDeprecated && deprecated {
+ ndp.notifyAutoGenAddressDeprecated(addr)
+ }
+
+ // If addr was preferred for some finite lifetime before, stop the deprecation
+ // timer so it can be reset.
+ addrState.deprecationTimer.StopLocked()
+
+ // Reset the deprecation timer if addr has a finite preferred lifetime.
+ if !deprecated && pl < header.NDPInfiniteLifetime {
+ addrState.deprecationTimer.Reset(pl)
+ }
+
+ // As per RFC 4862 section 5.5.3.e, the valid lifetime of the address
+ // generated by SLAAC is updated as follows:
+ //
+ // 1) If the received Valid Lifetime is greater than 2 hours or greater than
+ // RemainingLifetime, set the valid lifetime of the address to the
+ // advertised Valid Lifetime.
+ //
+ // 2) If RemainingLifetime is less than or equal to 2 hours, ignore the
+ // advertised Valid Lifetime.
+ //
+ // 3) Otherwise, reset the valid lifetime of the address to 2 hours.
+
+ // Handle the infinite valid lifetime separately as we do not keep a timer in
+ // this case.
+ if vl >= header.NDPInfiniteLifetime {
+ addrState.invalidationTimer.StopLocked()
+ addrState.validUntil = time.Time{}
+ return
+ }
+
+ var effectiveVl time.Duration
+ var rl time.Duration
+
+ // If the address was originally set to be valid forever, assume the remaining
+ // time to be the maximum possible value.
+ if addrState.validUntil == (time.Time{}) {
+ rl = header.NDPInfiniteLifetime
+ } else {
+ rl = time.Until(addrState.validUntil)
+ }
+
+ if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl {
+ effectiveVl = vl
+ } else if rl <= MinPrefixInformationValidLifetimeForUpdate {
+ return
+ } else {
+ effectiveVl = MinPrefixInformationValidLifetimeForUpdate
}
- ndp.autoGenAddresses[addr] = autoGenAddressState{
- invalidationTimer: vTimer,
- doNotInvalidate: &doNotInvalidate,
- validUntil: time.Now().Add(vl),
+ addrState.invalidationTimer.StopLocked()
+ addrState.invalidationTimer.Reset(effectiveVl)
+ addrState.validUntil = time.Now().Add(effectiveVl)
+}
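
A worked restatement of the lifetime rules above (illustrative only; the helper name is hypothetical): for example, with a remaining lifetime of 3 hours and an advertised valid lifetime of 30 minutes, neither rule 1 nor rule 2 applies, so the result is clamped to the 2-hour MinPrefixInformationValidLifetimeForUpdate.

// effectiveValidLifetime restates the update rules; the boolean reports
// whether the advertisement should be applied at all.
func effectiveValidLifetime(advertised, remaining time.Duration) (time.Duration, bool) {
	switch {
	case advertised > MinPrefixInformationValidLifetimeForUpdate || advertised > remaining:
		return advertised, true // Rule 1: honor the advertised lifetime.
	case remaining <= MinPrefixInformationValidLifetimeForUpdate:
		return 0, false // Rule 2: ignore the advertised lifetime.
	default:
		return MinPrefixInformationValidLifetimeForUpdate, true // Rule 3: clamp to 2 hours.
	}
}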
+
+// notifyAutoGenAddressDeprecated notifies the stack's NDP dispatcher that addr
+// has been deprecated.
+func (ndp *ndpState) notifyAutoGenAddressDeprecated(addr tcpip.Address) {
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: validPrefixLenForAutoGen,
+ })
}
}
@@ -1120,23 +1018,16 @@ func (ndp *ndpState) invalidateAutoGenAddress(addr tcpip.Address) {
// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bool {
state, ok := ndp.autoGenAddresses[addr]
-
if !ok {
return false
}
- if state.invalidationTimer != nil {
- state.invalidationTimer.Stop()
- state.invalidationTimer = nil
- *state.doNotInvalidate = true
- }
-
- state.doNotInvalidate = nil
-
+ state.deprecationTimer.StopLocked()
+ state.invalidationTimer.StopLocked()
delete(ndp.autoGenAddresses, addr)
- if ndp.nic.stack.ndpDisp != nil {
- ndp.nic.stack.ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{
Address: addr,
PrefixLen: validPrefixLenForAutoGen,
})
@@ -1145,22 +1036,37 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo
return true
}
-// autoGenAddrInvalidationTimer returns a new invalidation timer for an
-// auto-generated address that fires after vl.
+// cleanupHostOnlyState cleans up any state that is only useful for hosts.
//
-// doNotInvalidate is used to inform the timer when it fires at the same time
-// that an auto-generated address's valid lifetime gets refreshed. See
-// autoGenAddrState.doNotInvalidate for more details.
-func (ndp *ndpState) autoGenAddrInvalidationTimer(addr tcpip.Address, vl time.Duration, doNotInvalidate *bool) *time.Timer {
- return time.AfterFunc(vl, func() {
- ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
-
- if *doNotInvalidate {
- *doNotInvalidate = false
- return
- }
-
+// cleanupHostOnlyState MUST be called when ndp's NIC is transitioning from a
+// host to a router. It invalidates all discovered on-link prefixes, discovered
+// default routers, and auto-generated addresses, since routers do not normally
+// process Router Advertisements to discover default routers or on-link
+// prefixes, nor auto-generate addresses via SLAAC.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupHostOnlyState() {
+ for addr := range ndp.autoGenAddresses {
ndp.invalidateAutoGenAddress(addr)
- })
+ }
+
+ if got := len(ndp.autoGenAddresses); got != 0 {
+ log.Fatalf("ndp: still have auto-generated addresses after cleaning up, found = %d", got)
+ }
+
+ for prefix := range ndp.onLinkPrefixes {
+ ndp.invalidateOnLinkPrefix(prefix)
+ }
+
+ if got := len(ndp.onLinkPrefixes); got != 0 {
+ log.Fatalf("ndp: still have discovered on-link prefixes after cleaning up, found = %d", got)
+ }
+
+ for router := range ndp.defaultRouters {
+ ndp.invalidateDefaultRouter(router)
+ }
+
+ if got := len(ndp.defaultRouters); got != 0 {
+ log.Fatalf("ndp: still have discovered default routers after cleaning up, found = %d", got)
+ }
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index d8e7ce67e..d334af289 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -21,6 +21,7 @@ import (
"time"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -29,6 +30,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
)
const (
@@ -45,8 +48,25 @@ var (
llAddr1 = header.LinkLocalAddr(linkAddr1)
llAddr2 = header.LinkLocalAddr(linkAddr2)
llAddr3 = header.LinkLocalAddr(linkAddr3)
+ dstAddr = tcpip.FullAddress{
+ Addr: "\x0a\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ Port: 25,
+ }
)
+func addrForSubnet(subnet tcpip.Subnet, linkAddr tcpip.LinkAddress) tcpip.AddressWithPrefix {
+ if !header.IsValidUnicastEthernetAddress(linkAddr) {
+ return tcpip.AddressWithPrefix{}
+ }
+
+ addrBytes := []byte(subnet.ID())
+ header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
+ return tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes),
+ PrefixLen: 64,
+ }
+}
+
// prefixSubnetAddr returns a prefix (Address + Length), the prefix's equivalent
// tcpip.Subnet, and an address where the lower half of the address is composed
// of the EUI-64 of linkAddr if it is a valid unicast ethernet address.
@@ -59,17 +79,7 @@ func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWi
subnet := prefix.Subnet()
- var addr tcpip.AddressWithPrefix
- if header.IsValidUnicastEthernetAddress(linkAddr) {
- addrBytes := []byte(subnet.ID())
- header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
- addr = tcpip.AddressWithPrefix{
- Address: tcpip.Address(addrBytes),
- PrefixLen: 64,
- }
- }
-
- return prefix, subnet, addr
+ return prefix, subnet, addrForSubnet(subnet, linkAddr)
}
// TestDADDisabled tests that an address successfully resolves immediately
@@ -79,7 +89,7 @@ func TestDADDisabled(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
}
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
s := stack.New(opts)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
@@ -132,6 +142,7 @@ type ndpAutoGenAddrEventType int
const (
newAddr ndpAutoGenAddrEventType = iota
+ deprecatedAddr
invalidatedAddr
)
@@ -163,7 +174,6 @@ type ndpDispatcher struct {
rememberPrefix bool
autoGenAddrC chan ndpAutoGenAddrEvent
rdnssC chan ndpRDNSSEvent
- routeTable []tcpip.Route
}
// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus.
@@ -179,106 +189,56 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, add
}
// Implements stack.NDPDispatcher.OnDefaultRouterDiscovered.
-func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) (bool, []tcpip.Route) {
- if n.routerC != nil {
- n.routerC <- ndpRouterEvent{
+func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool {
+ if c := n.routerC; c != nil {
+ c <- ndpRouterEvent{
nicID,
addr,
true,
}
}
- if !n.rememberRouter {
- return false, nil
- }
-
- rt := append([]tcpip.Route(nil), n.routeTable...)
- rt = append(rt, tcpip.Route{
- Destination: header.IPv6EmptySubnet,
- Gateway: addr,
- NIC: nicID,
- })
- n.routeTable = rt
- return true, rt
+ return n.rememberRouter
}
// Implements stack.NDPDispatcher.OnDefaultRouterInvalidated.
-func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) []tcpip.Route {
- if n.routerC != nil {
- n.routerC <- ndpRouterEvent{
+func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) {
+ if c := n.routerC; c != nil {
+ c <- ndpRouterEvent{
nicID,
addr,
false,
}
}
-
- var rt []tcpip.Route
- exclude := tcpip.Route{
- Destination: header.IPv6EmptySubnet,
- Gateway: addr,
- NIC: nicID,
- }
-
- for _, r := range n.routeTable {
- if r != exclude {
- rt = append(rt, r)
- }
- }
- n.routeTable = rt
- return rt
}
// Implements stack.NDPDispatcher.OnOnLinkPrefixDiscovered.
-func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) (bool, []tcpip.Route) {
- if n.prefixC != nil {
- n.prefixC <- ndpPrefixEvent{
+func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool {
+ if c := n.prefixC; c != nil {
+ c <- ndpPrefixEvent{
nicID,
prefix,
true,
}
}
- if !n.rememberPrefix {
- return false, nil
- }
-
- rt := append([]tcpip.Route(nil), n.routeTable...)
- rt = append(rt, tcpip.Route{
- Destination: prefix,
- NIC: nicID,
- })
- n.routeTable = rt
- return true, rt
+ return n.rememberPrefix
}
// Implements stack.NDPDispatcher.OnOnLinkPrefixInvalidated.
-func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) []tcpip.Route {
- if n.prefixC != nil {
- n.prefixC <- ndpPrefixEvent{
+func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) {
+ if c := n.prefixC; c != nil {
+ c <- ndpPrefixEvent{
nicID,
prefix,
false,
}
}
-
- var rt []tcpip.Route
- exclude := tcpip.Route{
- Destination: prefix,
- NIC: nicID,
- }
-
- for _, r := range n.routeTable {
- if r != exclude {
- rt = append(rt, r)
- }
- }
- n.routeTable = rt
- return rt
}
func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool {
- if n.autoGenAddrC != nil {
- n.autoGenAddrC <- ndpAutoGenAddrEvent{
+ if c := n.autoGenAddrC; c != nil {
+ c <- ndpAutoGenAddrEvent{
nicID,
addr,
newAddr,
@@ -287,9 +247,19 @@ func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWi
return true
}
+func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
+ if c := n.autoGenAddrC; c != nil {
+ c <- ndpAutoGenAddrEvent{
+ nicID,
+ addr,
+ deprecatedAddr,
+ }
+ }
+}
+
func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
- if n.autoGenAddrC != nil {
- n.autoGenAddrC <- ndpAutoGenAddrEvent{
+ if c := n.autoGenAddrC; c != nil {
+ c <- ndpAutoGenAddrEvent{
nicID,
addr,
invalidatedAddr,
@@ -299,8 +269,8 @@ func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpi
// Implements stack.NDPDispatcher.OnRecursiveDNSServerOption.
func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) {
- if n.rdnssC != nil {
- n.rdnssC <- ndpRDNSSEvent{
+ if c := n.rdnssC; c != nil {
+ c <- ndpRDNSSEvent{
nicID,
ndpRDNSS{
addrs,
@@ -330,6 +300,8 @@ func TestDADResolve(t *testing.T) {
}
for _, test := range tests {
+ test := test
+
t.Run(test.name, func(t *testing.T) {
t.Parallel()
@@ -520,7 +492,7 @@ func TestDADFail(t *testing.T) {
}
opts.NDPConfigs.RetransmitTimer = time.Second * 2
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
s := stack.New(opts)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
@@ -601,7 +573,7 @@ func TestDADStop(t *testing.T) {
NDPConfigs: ndpConfigs,
}
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
s := stack.New(opts)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
@@ -702,7 +674,7 @@ func TestSetNDPConfigurations(t *testing.T) {
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent),
}
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NDPDisp: &ndpDisp,
@@ -903,8 +875,6 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on
// TestNoRouterDiscovery tests that router discovery will not be performed if
// configured not to.
func TestNoRouterDiscovery(t *testing.T) {
- t.Parallel()
-
// Being configured to discover routers means handle and
// discover are set to true and forwarding is set to false.
// This tests all possible combinations of the configurations,
@@ -920,9 +890,9 @@ func TestNoRouterDiscovery(t *testing.T) {
t.Parallel()
ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 10),
+ routerC: make(chan ndpRouterEvent, 1),
}
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NDPConfigs: stack.NDPConfigurations{
@@ -942,21 +912,27 @@ func TestNoRouterDiscovery(t *testing.T) {
select {
case <-ndpDisp.routerC:
t.Fatal("unexpectedly discovered a router when configured not to")
- case <-time.After(defaultTimeout):
+ default:
}
})
}
}
+// Check e to make sure that the event is for addr on nic with ID 1, and the
+// discovered flag set to discovered.
+func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string {
+ return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e))
+}
+
// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not
// remember a discovered router when the dispatcher asks it not to.
func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
t.Parallel()
ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 10),
+ routerC: make(chan ndpRouterEvent, 1),
}
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NDPConfigs: stack.NDPConfigurations{
@@ -970,50 +946,25 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
t.Fatalf("CreateNIC(1) = %s", err)
}
- routeTable := []tcpip.Route{
- {
- header.IPv6EmptySubnet,
- llAddr3,
- 1,
- },
- }
- s.SetRouteTable(routeTable)
-
- // Rx an RA with short lifetime.
- lifetime := time.Duration(1)
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(lifetime)))
+ // Receive an RA for a router we should not remember.
+ const lifetimeSeconds = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds))
select {
- case r := <-ndpDisp.routerC:
- if r.nicID != 1 {
- t.Fatalf("got r.nicID = %d, want = 1", r.nicID)
- }
- if r.addr != llAddr2 {
- t.Fatalf("got r.addr = %s, want = %s", r.addr, llAddr2)
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, llAddr2, true); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
}
- if !r.discovered {
- t.Fatal("got r.discovered = false, want = true")
- }
- case <-time.After(defaultTimeout):
- t.Fatal("timeout waiting for router discovery event")
+ default:
+ t.Fatal("expected router discovery event")
}
- // Original route table should not have been modified.
- if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) {
- t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable)
- }
-
- // Wait for the normal invalidation time plus an extra second to
- // make sure we do not actually receive any invalidation events as
- // we should not have remembered the router in the first place.
+ // Wait for the invalidation time plus some buffer to make sure we do
+ // not actually receive any invalidation events as we should not have
+ // remembered the router in the first place.
select {
case <-ndpDisp.routerC:
t.Fatal("should not have received any router events")
- case <-time.After(lifetime*time.Second + defaultTimeout):
- }
-
- // Original route table should not have been modified.
- if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) {
- t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable)
+ case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
}
}
@@ -1021,10 +972,10 @@ func TestRouterDiscovery(t *testing.T) {
t.Parallel()
ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 10),
+ routerC: make(chan ndpRouterEvent, 1),
rememberRouter: true,
}
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NDPConfigs: stack.NDPConfigurations{
@@ -1034,22 +985,29 @@ func TestRouterDiscovery(t *testing.T) {
NDPDisp: &ndpDisp,
})
- waitForEvent := func(addr tcpip.Address, discovered bool, timeout time.Duration) {
+ expectRouterEvent := func(addr tcpip.Address, discovered bool) {
t.Helper()
select {
- case r := <-ndpDisp.routerC:
- if r.nicID != 1 {
- t.Fatalf("got r.nicID = %d, want = 1", r.nicID)
- }
- if r.addr != addr {
- t.Fatalf("got r.addr = %s, want = %s", r.addr, addr)
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, addr, discovered); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
}
- if r.discovered != discovered {
- t.Fatalf("got r.discovered = %t, want = %t", r.discovered, discovered)
+ default:
+ t.Fatal("expected router discovery event")
+ }
+ }
+
+ expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, addr, false); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
}
case <-time.After(timeout):
- t.Fatal("timeout waiting for router discovery event")
+ t.Fatal("timed out waiting for router discovery event")
}
}
@@ -1063,40 +1021,25 @@ func TestRouterDiscovery(t *testing.T) {
select {
case <-ndpDisp.routerC:
t.Fatal("unexpectedly discovered a router with 0 lifetime")
- case <-time.After(defaultTimeout):
+ default:
}
// Rx an RA from lladdr2 with a huge lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
- waitForEvent(llAddr2, true, defaultTimeout)
-
- // Should have a default route through the discovered router.
- if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}}; !cmp.Equal(got, want) {
- t.Fatalf("got GetRouteTable = %v, want = %v", got, want)
- }
+ expectRouterEvent(llAddr2, true)
// Rx an RA from another router (lladdr3) with non-zero lifetime.
- l3Lifetime := time.Duration(6)
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, uint16(l3Lifetime)))
- waitForEvent(llAddr3, true, defaultTimeout)
-
- // Should have default routes through the discovered routers.
- if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}, {header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) {
- t.Fatalf("got GetRouteTable = %v, want = %v", got, want)
- }
+ const l3LifetimeSeconds = 6
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds))
+ expectRouterEvent(llAddr3, true)
// Rx an RA from lladdr2 with lesser lifetime.
- l2Lifetime := time.Duration(2)
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(l2Lifetime)))
+ const l2LifetimeSeconds = 2
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds))
select {
case <-ndpDisp.routerC:
t.Fatal("Should not receive a router event when updating lifetimes for known routers")
- case <-time.After(defaultTimeout):
- }
-
- // Should still have a default route through the discovered routers.
- if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}, {header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) {
- t.Fatalf("got GetRouteTable = %v, want = %v", got, want)
+ default:
}
// Wait for lladdr2's router invalidation timer to fire. The lifetime
@@ -1106,31 +1049,15 @@ func TestRouterDiscovery(t *testing.T) {
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- waitForEvent(llAddr2, false, l2Lifetime*time.Second+defaultTimeout)
-
- // Should no longer have the default route through lladdr2.
- if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) {
- t.Fatalf("got GetRouteTable = %v, want = %v", got, want)
- }
+ expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultTimeout)
// Rx an RA from lladdr2 with huge lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
- waitForEvent(llAddr2, true, defaultTimeout)
-
- // Should have a default route through the discovered routers.
- if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}, {header.IPv6EmptySubnet, llAddr2, 1}}; !cmp.Equal(got, want) {
- t.Fatalf("got GetRouteTable = %v, want = %v", got, want)
- }
+ expectRouterEvent(llAddr2, true)
// Rx an RA from lladdr2 with zero lifetime. It should be invalidated.
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
- waitForEvent(llAddr2, false, defaultTimeout)
-
- // Should have deleted the default route through the router that just
- // got invalidated.
- if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) {
- t.Fatalf("got GetRouteTable = %v, want = %v", got, want)
- }
+ expectRouterEvent(llAddr2, false)
// Wait for lladdr3's router invalidation timer to fire. The lifetime
// of the router should have been updated to the most recent (smaller)
@@ -1139,13 +1066,7 @@ func TestRouterDiscovery(t *testing.T) {
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- waitForEvent(llAddr3, false, l3Lifetime*time.Second+defaultTimeout)
-
- // Should not have any routes now that all discovered routers have been
- // invalidated.
- if got := len(s.GetRouteTable()); got != 0 {
- t.Fatalf("got len(s.GetRouteTable()) = %d, want = 0", got)
- }
+ expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultTimeout)
}
// TestRouterDiscoveryMaxRouters tests that only
@@ -1154,10 +1075,10 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
t.Parallel()
ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 10),
+ routerC: make(chan ndpRouterEvent, 1),
rememberRouter: true,
}
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NDPConfigs: stack.NDPConfigurations{
@@ -1171,8 +1092,6 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
t.Fatalf("CreateNIC(1) = %s", err)
}
- expectedRt := [stack.MaxDiscoveredDefaultRouters]tcpip.Route{}
-
// Receive an RA from 2 more than the max number of discovered routers.
for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ {
linkAddr := []byte{2, 2, 3, 4, 5, 0}
@@ -1182,43 +1101,28 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5))
if i <= stack.MaxDiscoveredDefaultRouters {
- expectedRt[i-1] = tcpip.Route{header.IPv6EmptySubnet, llAddr, 1}
select {
- case r := <-ndpDisp.routerC:
- if r.nicID != 1 {
- t.Fatalf("got r.nicID = %d, want = 1", r.nicID)
- }
- if r.addr != llAddr {
- t.Fatalf("got r.addr = %s, want = %s", r.addr, llAddr)
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, llAddr, true); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
}
- if !r.discovered {
- t.Fatal("got r.discovered = false, want = true")
- }
- case <-time.After(defaultTimeout):
- t.Fatal("timeout waiting for router discovery event")
+ default:
+ t.Fatal("expected router discovery event")
}
} else {
select {
case <-ndpDisp.routerC:
t.Fatal("should not have discovered a new router after we already discovered the max number of routers")
- case <-time.After(defaultTimeout):
+ default:
}
}
}
-
- // Should only have default routes for the first
- // stack.MaxDiscoveredDefaultRouters discovered routers.
- if got := s.GetRouteTable(); !cmp.Equal(got, expectedRt[:]) {
- t.Fatalf("got GetRouteTable = %v, want = %v", got, expectedRt)
- }
}
// TestNoPrefixDiscovery tests that prefix discovery will not be performed if
// configured not to.
func TestNoPrefixDiscovery(t *testing.T) {
- t.Parallel()
-
prefix := tcpip.AddressWithPrefix{
Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"),
PrefixLen: 64,
@@ -1239,9 +1143,9 @@ func TestNoPrefixDiscovery(t *testing.T) {
t.Parallel()
ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 10),
+ prefixC: make(chan ndpPrefixEvent, 1),
}
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NDPConfigs: stack.NDPConfigurations{
@@ -1262,12 +1166,18 @@ func TestNoPrefixDiscovery(t *testing.T) {
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly discovered a prefix when configured not to")
- case <-time.After(defaultTimeout):
+ default:
}
})
}
}
+// Check e to make sure that the event is for prefix on the NIC with ID 1, and
+// that the discovered flag is set to discovered.
+func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string {
+ return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e))
+}
+
// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not
// remember a discovered on-link prefix when the dispatcher asks it not to.
func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
@@ -1276,9 +1186,9 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
prefix, subnet, _ := prefixSubnetAddr(0, "")
ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 10),
+ prefixC: make(chan ndpPrefixEvent, 1),
}
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NDPConfigs: stack.NDPConfigurations{
@@ -1293,50 +1203,25 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
t.Fatalf("CreateNIC(1) = %s", err)
}
- routeTable := []tcpip.Route{
- {
- header.IPv6EmptySubnet,
- llAddr3,
- 1,
- },
- }
- s.SetRouteTable(routeTable)
-
- // Rx an RA with prefix with a short lifetime.
- const lifetime = 1
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetime, 0))
+ // Receive an RA with prefix that we should not remember.
+ const lifetimeSeconds = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0))
select {
- case r := <-ndpDisp.prefixC:
- if r.nicID != 1 {
- t.Fatalf("got r.nicID = %d, want = 1", r.nicID)
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, subnet, true); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
}
- if r.prefix != subnet {
- t.Fatalf("got r.prefix = %s, want = %s", r.prefix, subnet)
- }
- if !r.discovered {
- t.Fatal("got r.discovered = false, want = true")
- }
- case <-time.After(defaultTimeout):
- t.Fatal("timeout waiting for prefix discovery event")
- }
-
- // Original route table should not have been modified.
- if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) {
- t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable)
+ default:
+ t.Fatal("expected prefix discovery event")
}
- // Wait for the normal invalidation time plus some buffer to
- // make sure we do not actually receive any invalidation events as
- // we should not have remembered the prefix in the first place.
+ // Wait for the invalidation time plus some buffer to make sure we do
+ // not actually receive any invalidation events as we should not have
+ // remembered the prefix in the first place.
select {
case <-ndpDisp.prefixC:
t.Fatal("should not have received any prefix events")
- case <-time.After(lifetime*time.Second + defaultTimeout):
- }
-
- // Original route table should not have been modified.
- if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) {
- t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable)
+ case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
}
}
@@ -1348,10 +1233,10 @@ func TestPrefixDiscovery(t *testing.T) {
prefix3, subnet3, _ := prefixSubnetAddr(2, "")
ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 10),
+ prefixC: make(chan ndpPrefixEvent, 1),
rememberPrefix: true,
}
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NDPConfigs: stack.NDPConfigurations{
@@ -1361,74 +1246,48 @@ func TestPrefixDiscovery(t *testing.T) {
NDPDisp: &ndpDisp,
})
- waitForEvent := func(subnet tcpip.Subnet, discovered bool, timeout time.Duration) {
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) {
t.Helper()
select {
- case r := <-ndpDisp.prefixC:
- if r.nicID != 1 {
- t.Fatalf("got r.nicID = %d, want = 1", r.nicID)
- }
- if r.prefix != subnet {
- t.Fatalf("got r.prefix = %s, want = %s", r.prefix, subnet)
- }
- if r.discovered != discovered {
- t.Fatalf("got r.discovered = %t, want = %t", r.discovered, discovered)
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, prefix, discovered); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(timeout):
- t.Fatal("timeout waiting for prefix discovery event")
+ default:
+ t.Fatal("expected prefix discovery event")
}
}
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
// Receive an RA with prefix1 in an NDP Prefix Information option (PI)
// with zero valid lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly discovered a prefix with 0 lifetime")
- case <-time.After(defaultTimeout):
+ default:
}
// Receive an RA with prefix1 in an NDP Prefix Information option (PI)
// with non-zero lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0))
- waitForEvent(subnet1, true, defaultTimeout)
-
- // Should have added a device route for subnet1 through the nic.
- if got, want := s.GetRouteTable(), []tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) {
- t.Fatalf("got GetRouteTable = %v, want = %v", got, want)
- }
+ expectPrefixEvent(subnet1, true)
// Receive an RA with prefix2 in a PI.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0))
- waitForEvent(subnet2, true, defaultTimeout)
-
- // Should have added a device route for subnet2 through the nic.
- if got, want := s.GetRouteTable(), []tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}, {subnet2, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) {
- t.Fatalf("got GetRouteTable = %v, want = %v", got, want)
- }
+ expectPrefixEvent(subnet2, true)
// Receive an RA with prefix3 in a PI.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0))
- waitForEvent(subnet3, true, defaultTimeout)
-
- // Should have added a device route for subnet3 through the nic.
- if got, want := s.GetRouteTable(), []tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}, {subnet2, tcpip.Address([]byte(nil)), 1}, {subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) {
- t.Fatalf("got GetRouteTable = %v, want = %v", got, want)
- }
+ expectPrefixEvent(subnet3, true)
// Receive an RA with prefix1 in a PI with lifetime = 0.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
- waitForEvent(subnet1, false, defaultTimeout)
-
- // Should have removed the device route for subnet1 through the nic.
- if got, want := s.GetRouteTable(), []tcpip.Route{{subnet2, tcpip.Address([]byte(nil)), 1}, {subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) {
- t.Fatalf("got GetRouteTable = %v, want = %v", got, want)
- }
+ expectPrefixEvent(subnet1, false)
// Receive an RA with prefix2 in a PI with lesser lifetime.
lifetime := uint32(2)
@@ -1436,31 +1295,23 @@ func TestPrefixDiscovery(t *testing.T) {
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly received prefix event when updating lifetime")
- case <-time.After(defaultTimeout):
- }
-
- // Should not have updated route table.
- if got, want := s.GetRouteTable(), []tcpip.Route{{subnet2, tcpip.Address([]byte(nil)), 1}, {subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) {
- t.Fatalf("got GetRouteTable = %v, want = %v", got, want)
+ default:
}
// Wait for prefix2's most recent invalidation timer plus some buffer to
// expire.
- waitForEvent(subnet2, false, time.Duration(lifetime)*time.Second+defaultTimeout)
-
- // Should have removed the device route for subnet2 through the nic.
- if got, want := s.GetRouteTable(), []tcpip.Route{{subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) {
- t.Fatalf("got GetRouteTable = %v, want = %v", got, want)
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, subnet2, false); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(time.Duration(lifetime)*time.Second + defaultTimeout):
+ t.Fatal("timed out waiting for prefix discovery event")
}
// Receive RA to invalidate prefix3.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0))
- waitForEvent(subnet3, false, defaultTimeout)
-
- // Should not have any routes.
- if got := len(s.GetRouteTable()); got != 0 {
- t.Fatalf("got len(s.GetRouteTable()) = %d, want = 0", got)
- }
+ expectPrefixEvent(subnet3, false)
}
func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
@@ -1482,10 +1333,10 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
subnet := prefix.Subnet()
ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 10),
+ prefixC: make(chan ndpPrefixEvent, 1),
rememberPrefix: true,
}
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NDPConfigs: stack.NDPConfigurations{
@@ -1495,33 +1346,27 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
NDPDisp: &ndpDisp,
})
- waitForEvent := func(discovered bool, timeout time.Duration) {
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) {
t.Helper()
select {
- case r := <-ndpDisp.prefixC:
- if r.nicID != 1 {
- t.Errorf("got r.nicID = %d, want = 1", r.nicID)
- }
- if r.prefix != subnet {
- t.Errorf("got r.prefix = %s, want = %s", r.prefix, subnet)
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, prefix, discovered); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
}
- if r.discovered != discovered {
- t.Errorf("got r.discovered = %t, want = %t", r.discovered, discovered)
- }
- case <-time.After(timeout):
- t.Fatal("timeout waiting for prefix discovery event")
+ default:
+ t.Fatal("expected prefix discovery event")
}
}
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
// Receive an RA with prefix in an NDP Prefix Information option (PI)
// with infinite valid lifetime which should not get invalidated.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0))
- waitForEvent(true, defaultTimeout)
+ expectPrefixEvent(subnet, true)
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
@@ -1531,11 +1376,18 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
// Receive an RA with finite lifetime.
// The prefix should get invalidated after 1s.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0))
- waitForEvent(false, testInfiniteLifetime)
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, subnet, false); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(testInfiniteLifetime):
+ t.Fatal("timed out waiting for prefix discovery event")
+ }
// Receive an RA with finite lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0))
- waitForEvent(true, defaultTimeout)
+ expectPrefixEvent(subnet, true)
// Receive an RA with prefix with an infinite lifetime.
// The prefix should not be invalidated.
@@ -1558,7 +1410,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
// Receive an RA with 0 lifetime.
// The prefix should get invalidated.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 0, 0))
- waitForEvent(false, defaultTimeout)
+ expectPrefixEvent(subnet, false)
}
// TestPrefixDiscoveryMaxRouters tests that only
@@ -1570,7 +1422,7 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3),
rememberPrefix: true,
}
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NDPConfigs: stack.NDPConfigurations{
@@ -1586,7 +1438,6 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
}
optSer := make(header.NDPOptionsSerializer, stack.MaxDiscoveredOnLinkPrefixes+2)
- expectedRt := [stack.MaxDiscoveredOnLinkPrefixes]tcpip.Route{}
prefixes := [stack.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{}
// Receive an RA with 2 more than the max number of discovered on-link
@@ -1606,43 +1457,27 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
copy(buf[14:], prefix.Address)
optSer[i] = header.NDPPrefixInformation(buf[:])
-
- if i < stack.MaxDiscoveredOnLinkPrefixes {
- expectedRt[i] = tcpip.Route{prefixes[i], tcpip.Address([]byte(nil)), 1}
- }
}
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer))
for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ {
if i < stack.MaxDiscoveredOnLinkPrefixes {
select {
- case r := <-ndpDisp.prefixC:
- if r.nicID != 1 {
- t.Fatalf("got r.nicID = %d, want = 1", r.nicID)
- }
- if r.prefix != prefixes[i] {
- t.Fatalf("got r.prefix = %s, want = %s", r.prefix, prefixes[i])
- }
- if !r.discovered {
- t.Fatal("got r.discovered = false, want = true")
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultTimeout):
- t.Fatal("timeout waiting for prefix discovery event")
+ default:
+ t.Fatal("expected prefix discovery event")
}
} else {
select {
case <-ndpDisp.prefixC:
t.Fatal("should not have discovered a new prefix after we already discovered the max number of prefixes")
- case <-time.After(defaultTimeout):
+ default:
}
}
}
-
- // Should only have device routes for the first
- // stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes.
- if got := s.GetRouteTable(); !cmp.Equal(got, expectedRt[:]) {
- t.Fatalf("got GetRouteTable = %v, want = %v", got, expectedRt)
- }
}
// Checks to see if list contains an IPv6 address, item.
@@ -1663,8 +1498,6 @@ func contains(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool {
// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to.
func TestNoAutoGenAddr(t *testing.T) {
- t.Parallel()
-
prefix, _, _ := prefixSubnetAddr(0, "")
// Being configured to auto-generate addresses means handle and
@@ -1682,9 +1515,9 @@ func TestNoAutoGenAddr(t *testing.T) {
t.Parallel()
ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NDPConfigs: stack.NDPConfigurations{
@@ -1705,12 +1538,18 @@ func TestNoAutoGenAddr(t *testing.T) {
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly auto-generated an address when configured not to")
- case <-time.After(defaultTimeout):
+ default:
}
})
}
}
+// Check e to make sure that the event is for addr on nic with ID 1, and the
+// event type is set to eventType.
+func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string {
+ return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e))
+}
+
// TestAutoGenAddr tests that an address is properly generated and invalidated
// when configured to do so.
func TestAutoGenAddr(t *testing.T) {
@@ -1726,9 +1565,9 @@ func TestAutoGenAddr(t *testing.T) {
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NDPConfigs: stack.NDPConfigurations{
@@ -1738,42 +1577,36 @@ func TestAutoGenAddr(t *testing.T) {
NDPDisp: &ndpDisp,
})
- waitForEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
t.Helper()
select {
- case r := <-ndpDisp.autoGenAddrC:
- if r.nicID != 1 {
- t.Fatalf("got r.nicID = %d, want = 1", r.nicID)
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- if r.addr != addr {
- t.Fatalf("got r.addr = %s, want = %s", r.addr, addr)
- }
- if r.eventType != eventType {
- t.Fatalf("got r.eventType = %v, want = %v", r.eventType, eventType)
- }
- case <-time.After(timeout):
- t.Fatal("timeout waiting for addr auto gen event")
+ default:
+ t.Fatal("expected addr auto gen event")
}
}
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
// Receive an RA with prefix1 in an NDP Prefix Information option (PI)
// with zero valid lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly auto-generated an address with 0 lifetime")
- case <-time.After(defaultTimeout):
+ default:
}
// Receive an RA with prefix1 in an NDP Prefix Information option (PI)
// with non-zero lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
- waitForEvent(addr1, newAddr, defaultTimeout)
+ expectAutoGenAddrEvent(addr1, newAddr)
if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
t.Fatalf("Should have %s in the list of addresses", addr1)
}
@@ -1784,12 +1617,12 @@ func TestAutoGenAddr(t *testing.T) {
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime")
- case <-time.After(defaultTimeout):
+ default:
}
// Receive an RA with prefix2 in a PI.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
- waitForEvent(addr2, newAddr, defaultTimeout)
+ expectAutoGenAddrEvent(addr2, newAddr)
if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
t.Fatalf("Should have %s in the list of addresses", addr1)
}
@@ -1802,11 +1635,18 @@ func TestAutoGenAddr(t *testing.T) {
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix")
- case <-time.After(defaultTimeout):
+ default:
}
// Wait for addr of prefix1 to be invalidated.
- waitForEvent(addr1, invalidatedAddr, newMinVLDuration+defaultTimeout)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(newMinVLDuration + defaultTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
if contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
t.Fatalf("Should not have %s in the list of addresses", addr1)
}
@@ -1815,12 +1655,529 @@ func TestAutoGenAddr(t *testing.T) {
}
}
+// stackAndNdpDispatcherWithDefaultRoute returns an ndpDispatcher,
+// channel.Endpoint and stack.Stack.
+//
+// stack.Stack will have a default route through the router (llAddr3) installed
+// and a static link-address (linkAddr3) added to the link address cache for the
+// router.
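+//
+// The default route is what lets endpoints created by addrForNewConnection
+// connect to dstAddr; the tests below use such connections to observe source
+// address selection.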
+func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) {
+ t.Helper()
+ ndpDisp := &ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: ndpDisp,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ Gateway: llAddr3,
+ NIC: nicID,
+ }})
+ s.AddLinkAddress(nicID, llAddr3, linkAddr3)
+ return ndpDisp, e, s
+}
+
+// addrForNewConnection returns the local address used when creating a new
+// connection.
+func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address {
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
+ }
+ defer ep.Close()
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+ t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
+ }
+ if err := ep.Connect(dstAddr); err != nil {
+ t.Fatalf("ep.Connect(%+v): %s", dstAddr, err)
+ }
+ got, err := ep.GetLocalAddress()
+ if err != nil {
+ t.Fatalf("ep.GetLocalAddress(): %s", err)
+ }
+ return got.Addr
+}
+
+// addrForNewConnectionWithAddr returns the local address used when creating a
+// new connection with a specific local address.
+func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address {
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
+ }
+ defer ep.Close()
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+ t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
+ }
+ if err := ep.Bind(addr); err != nil {
+ t.Fatalf("ep.Bind(%+v): %s", addr, err)
+ }
+ if err := ep.Connect(dstAddr); err != nil {
+ t.Fatalf("ep.Connect(%+v): %s", dstAddr, err)
+ }
+ got, err := ep.GetLocalAddress()
+ if err != nil {
+ t.Fatalf("ep.GetLocalAddress(): %s", err)
+ }
+ return got.Addr
+}
+
+// TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when
+// receiving a PI with 0 preferred lifetime.
+func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
+ const nicID = 1
+
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
+
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if got != addr {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ }
+
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
+
+ // Receive PI for prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ expectPrimaryAddr(addr1)
+
+	// Deprecate addr for prefix1 immediately.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, deprecatedAddr)
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+	// addr1 should still be the primary endpoint as there are no other addresses.
+ expectPrimaryAddr(addr1)
+
+ // Refresh lifetimes of addr generated from prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
+
+ // Receive PI for prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
+
+	// Deprecate addr for prefix2 immediately.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr2, deprecatedAddr)
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr1 should be the primary endpoint now since addr2 is deprecated but
+ // addr1 is not.
+ expectPrimaryAddr(addr1)
+ // addr2 is deprecated but if explicitly requested, it should be used.
+ fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID}
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address)
+ }
+
+ // Another PI w/ 0 preferred lifetime should not result in a deprecation
+ // event.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address)
+ }
+
+ // Refresh lifetimes of addr generated from prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr2)
+}
+
+// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated
+// when its preferred lifetime expires.
+func TestAutoGenAddrTimerDeprecation(t *testing.T) {
+ const nicID = 1
+ const newMinVL = 2
+ newMinVLDuration := newMinVL * time.Second
+ saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
+
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ }
+
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
+
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if got != addr {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ }
+
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
+
+ // Receive PI for prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
+
+ // Receive a PI for prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr1)
+
+ // Refresh lifetime for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
+
+ // Wait for addr of prefix1 to be deprecated.
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout)
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr2 should be the primary endpoint now since addr1 is deprecated but
+ // addr2 is not.
+ expectPrimaryAddr(addr2)
+ // addr1 is deprecated but if explicitly requested, it should be used.
+ fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID}
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address)
+ }
+
+ // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make
+ // sure we do not get a deprecation event again.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr2)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address)
+ }
+
+ // Refresh lifetimes for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ // addr1 is the primary endpoint again since it is non-deprecated now.
+ expectPrimaryAddr(addr1)
+
+ // Wait for addr of prefix1 to be deprecated.
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout)
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr2 should be the primary endpoint now since it is not deprecated.
+ expectPrimaryAddr(addr2)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address)
+ }
+
+ // Wait for addr of prefix1 to be invalidated.
+ expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultTimeout)
+ if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
+
+ // Refresh both lifetimes for addr of prefix2 to the same value.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+
+ // Wait for a deprecation then invalidation events, or just an invalidation
+ // event. We need to cover both cases but cannot deterministically hit both
+ // cases because the deprecation and invalidation handlers could be handled in
+ // either deprecation then invalidation, or invalidation then deprecation
+ // (which should be cancelled by the invalidation handler).
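+	// (Both timers are scheduled for roughly the same time here because the
+	// preferred and valid lifetimes were just refreshed to the same value.)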
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" {
+ // If we get a deprecation event first, we should get an invalidation
+ // event almost immediately after.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
+ // If we get an invalidation event first, we should not get a deprecation
+ // event after.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ case <-time.After(defaultTimeout):
+ }
+ } else {
+ t.Fatalf("got unexpected auto-generated event")
+ }
+
+ case <-time.After(newMinVLDuration + defaultTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should not have %s in the list of addresses", addr2)
+ }
+ // Should not have any primary endpoints.
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if want := (tcpip.AddressWithPrefix{}); got != want {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want)
+ }
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
+ }
+ defer ep.Close()
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+ t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
+ }
+
+ if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute {
+ t.Errorf("got ep.Connect(%+v) = %v, want = %s", dstAddr, err, tcpip.ErrNoRoute)
+ }
+}
+
+// Tests transitioning a SLAAC address's valid lifetime between finite and
+// infinite values.
+func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
+ const infiniteVLSeconds = 2
+ const minVLSeconds = 1
+ savedIL := header.NDPInfiniteLifetime
+ savedMinVL := stack.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = savedMinVL
+ header.NDPInfiniteLifetime = savedIL
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second
+ header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+
+ tests := []struct {
+ name string
+ infiniteVL uint32
+ }{
+ {
+ name: "EqualToInfiniteVL",
+ infiniteVL: infiniteVLSeconds,
+ },
+ // Our implementation supports changing header.NDPInfiniteLifetime for tests
+ // such that a packet can be received where the lifetime field has a value
+ // greater than header.NDPInfiniteLifetime. Because of this, we test to make
+ // sure that receiving a value greater than header.NDPInfiniteLifetime is
+ // handled the same as when receiving a value equal to
+ // header.NDPInfiniteLifetime.
+ {
+ name: "MoreThanInfiniteVL",
+ infiniteVL: infiniteVLSeconds + 1,
+ },
+ }
+
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because we need to do some teardown work after the
+ // parallel tests complete.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
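+	//
+	// As a minimal sketch (somePackage.SomeGlobal and testValue below are
+	// placeholders, not names used in this file), the pattern is:
+	//
+	//	saved := somePackage.SomeGlobal
+	//	defer func() { somePackage.SomeGlobal = saved }() // teardown
+	//	somePackage.SomeGlobal = testValue
+	//
+	//	t.Run("group", func(t *testing.T) {
+	//		for _, test := range tests {
+	//			test := test
+	//			t.Run(test.name, func(t *testing.T) {
+	//				t.Parallel()
+	//				// ...
+	//			})
+	//		}
+	//	})
+	//
+	//	// The deferred restore runs only after the group Run returns,
+	//	// i.e. after all parallel subtests have completed.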
+ t.Run("group", func(t *testing.T) {
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+				// Receive an RA with prefix with finite VL.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+
+				// Receive a new RA with prefix with infinite VL.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0))
+
+ // Receive a new RA with prefix with finite VL.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+
+ case <-time.After(minVLSeconds*time.Second + defaultTimeout):
+ t.Fatal("timeout waiting for addr auto gen event")
+ }
+ })
+ }
+ })
+}
+
// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an
// auto-generated address only gets updated when required to, as specified in
// RFC 4862 section 5.5.3.e.
func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
const infiniteVL = 4294967295
- const newMinVL = 5
+ const newMinVL = 4
saved := stack.MinPrefixInformationValidLifetimeForUpdate
defer func() {
stack.MinPrefixInformationValidLifetimeForUpdate = saved
@@ -1896,78 +2253,81 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
const delta = 500 * time.Millisecond
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
- }
- e := channel.New(10, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because we need to do some teardown work after the
+ // parallel tests complete.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run("group", func(t *testing.T) {
+ for _, test := range tests {
+ test := test
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
- // Receive an RA with prefix with initial VL, test.ovl.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0))
- select {
- case r := <-ndpDisp.autoGenAddrC:
- if r.nicID != 1 {
- t.Fatalf("got r.nicID = %d, want = 1", r.nicID)
- }
- if r.addr != addr {
- t.Fatalf("got r.addr = %s, want = %s", r.addr, addr)
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
}
- if r.eventType != newAddr {
- t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr)
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
}
- case <-time.After(defaultTimeout):
- t.Fatal("timeout waiting for addr auto gen event")
- }
- // Receive an new RA with prefix with new VL, test.nvl.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0))
+ // Receive an RA with prefix with initial VL,
+ // test.ovl.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
- //
- // Validate that the VL for the address got set to
- // test.evl.
- //
+				// Receive a new RA with prefix with new VL,
+ // test.nvl.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0))
- // Make sure we do not get any invalidation events
- // until atleast 500ms (delta) before test.evl.
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpectedly received an auto gen addr event")
- case <-time.After(time.Duration(test.evl)*time.Second - delta):
- }
+ //
+ // Validate that the VL for the address got set
+ // to test.evl.
+ //
- // Wait for another second (2x delta), but now we expect
- // the invalidation event.
- select {
- case r := <-ndpDisp.autoGenAddrC:
- if r.nicID != 1 {
- t.Fatalf("got r.nicID = %d, want = 1", r.nicID)
- }
- if r.addr != addr {
- t.Fatalf("got r.addr = %s, want = %s", r.addr, addr)
+ // Make sure we do not get any invalidation
+				// events until at least 500ms (delta) before
+ // test.evl.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly received an auto gen addr event")
+ case <-time.After(time.Duration(test.evl)*time.Second - delta):
}
- if r.eventType != invalidatedAddr {
- t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr)
+
+ // Wait for another second (2x delta), but now
+ // we expect the invalidation event.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+
+ case <-time.After(2 * delta):
+ t.Fatal("timeout waiting for addr auto gen event")
}
- case <-time.After(2 * delta):
- t.Fatal("timeout waiting for addr auto gen event")
- }
- })
- }
+ })
+ }
+ })
}
// TestAutoGenAddrRemoval tests that when auto-generated addresses are removed
@@ -1979,9 +2339,9 @@ func TestAutoGenAddrRemoval(t *testing.T) {
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NDPConfigs: stack.NDPConfigurations{
@@ -1995,51 +2355,37 @@ func TestAutoGenAddrRemoval(t *testing.T) {
t.Fatalf("CreateNIC(1) = %s", err)
}
- // Receive an RA with prefix with its valid lifetime = lifetime.
- const lifetime = 5
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetime, 0))
- select {
- case r := <-ndpDisp.autoGenAddrC:
- if r.nicID != 1 {
- t.Fatalf("got r.nicID = %d, want = 1", r.nicID)
- }
- if r.addr != addr {
- t.Fatalf("got r.addr = %s, want = %s", r.addr, addr)
- }
- if r.eventType != newAddr {
- t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr)
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
}
- case <-time.After(defaultTimeout):
- t.Fatal("timeout waiting for addr auto gen event")
}
- // Remove the address.
+ // Receive a PI to auto-generate an address.
+ const lifetimeSeconds = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0))
+ expectAutoGenAddrEvent(addr, newAddr)
+
+ // Removing the address should result in an invalidation event
+ // immediately.
if err := s.RemoveAddress(1, addr.Address); err != nil {
t.Fatalf("RemoveAddress(_, %s) = %s", addr.Address, err)
}
-
- // Should get the invalidation event immediately.
- select {
- case r := <-ndpDisp.autoGenAddrC:
- if r.nicID != 1 {
- t.Fatalf("got r.nicID = %d, want = 1", r.nicID)
- }
- if r.addr != addr {
- t.Fatalf("got r.addr = %s, want = %s", r.addr, addr)
- }
- if r.eventType != invalidatedAddr {
- t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr)
- }
- case <-time.After(defaultTimeout):
- t.Fatal("timeout waiting for addr auto gen event")
- }
+ expectAutoGenAddrEvent(addr, invalidatedAddr)
// Wait for the original valid lifetime to make sure the original timer
// got stopped/cleaned up.
select {
case <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpectedly received an auto gen addr event")
- case <-time.After(lifetime*time.Second + defaultTimeout):
+ t.Fatal("unexpectedly received an auto gen addr event")
+ case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
}
}
@@ -2051,9 +2397,9 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NDPConfigs: stack.NDPConfigurations{
@@ -2077,12 +2423,12 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
// Receive a PI where the generated address will be the same as the one
// that we already have assigned statically.
- const lifetime = 5
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetime, 0))
+ const lifetimeSeconds = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0))
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically")
- case <-time.After(defaultTimeout):
+ default:
}
if !contains(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
@@ -2093,13 +2439,117 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(lifetime*time.Second + defaultTimeout):
+ case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
}
if !contains(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
}
}
+// TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use
+// opaque interface identifiers when configured to do so.
+func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
+ t.Parallel()
+
+ const nicID = 1
+ const nicName = "nic1"
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
+ secretKey := secretKeyBuf[:]
+ n, err := rand.Read(secretKey)
+ if err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ }
+ if n != header.OpaqueIIDSecretKeyMinBytes {
+ t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
+ }
+
+ prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1)
+ prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1)
+ // addr1 and addr2 are the addresses that are expected to be generated when
+ // stack.Stack is configured to generate opaque interface identifiers as
+ // defined by RFC 7217.
+ addrBytes := []byte(subnet1.ID())
+ addr1 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet1, nicName, 0, secretKey)),
+ PrefixLen: 64,
+ }
+ addrBytes = []byte(subnet2.ID())
+ addr2 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet2, nicName, 0, secretKey)),
+ PrefixLen: 64,
+ }
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
+ },
+ })
+ opts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
+		t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ // Receive an RA with prefix1 in a PI.
+ const validLifetimeSecondPrefix1 = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, validLifetimeSecondPrefix1, 0))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+
+ // Receive an RA with prefix2 in a PI with a large valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+
+ // Wait for addr of prefix1 to be invalidated.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+}
+
// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event
// to the integrator when an RA is received with the NDP Recursive DNS Server
// option with at least one valid address.
@@ -2247,3 +2697,257 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
})
}
}
+
+// TestCleanupHostOnlyStateOnBecomingRouter tests that all discovered routers
+// and prefixes, and auto-generated addresses get invalidated when a NIC
+// becomes a router.
+func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) {
+ t.Parallel()
+
+ const (
+ lifetimeSeconds = 5
+ maxEvents = 4
+ nicID1 = 1
+ nicID2 = 2
+ )
+
+ prefix1, subnet1, e1Addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, subnet2, e1Addr2 := prefixSubnetAddr(1, linkAddr1)
+ e2Addr1 := addrForSubnet(subnet1, linkAddr2)
+ e2Addr2 := addrForSubnet(subnet2, linkAddr2)
+
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, maxEvents),
+ rememberRouter: true,
+ prefixC: make(chan ndpPrefixEvent, maxEvents),
+ rememberPrefix: true,
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, maxEvents),
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ DiscoverOnLinkPrefixes: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ e1 := channel.New(0, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err)
+ }
+
+ e2 := channel.New(0, 1280, linkAddr2)
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err)
+ }
+
+ expectRouterEvent := func() (bool, ndpRouterEvent) {
+ select {
+ case e := <-ndpDisp.routerC:
+ return true, e
+ default:
+ }
+
+ return false, ndpRouterEvent{}
+ }
+
+ expectPrefixEvent := func() (bool, ndpPrefixEvent) {
+ select {
+ case e := <-ndpDisp.prefixC:
+ return true, e
+ default:
+ }
+
+ return false, ndpPrefixEvent{}
+ }
+
+ expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) {
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ return true, e
+ default:
+ }
+
+ return false, ndpAutoGenAddrEvent{}
+ }
+
+ // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr1 and
+ // llAddr2) w/ PI (for prefix1 in RA from llAddr1 and prefix2 in RA from
+ // llAddr2) to discover multiple routers and prefixes, and auto-gen
+ // multiple addresses.
+
+ e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr1, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
+ // We have other tests that make sure we receive the *correct* events
+ // on normal discovery of routers/prefixes, and auto-generated
+ // addresses. Here we just make sure we get an event and let other tests
+ // handle the correctness check.
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr1, nicID1)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1)
+ }
+
+ e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr2, nicID1)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1)
+ }
+
+ e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr1, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr1, nicID2)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+		t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr1, nicID2)
+ }
+
+ e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr2, nicID2)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2)
+ }
+
+ // We should have the auto-generated addresses added.
+ nicinfo := s.NICInfo()
+ nic1Addrs := nicinfo[nicID1].ProtocolAddresses
+ nic2Addrs := nicinfo[nicID2].ProtocolAddresses
+ if !contains(nic1Addrs, e1Addr1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
+ }
+ if !contains(nic1Addrs, e1Addr2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
+ }
+ if !contains(nic2Addrs, e2Addr1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
+ }
+ if !contains(nic2Addrs, e2Addr2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
+ }
+
+ // We can't proceed any further if we already failed the test (missing
+ // some discovery/auto-generated address events or addresses).
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ s.SetForwarding(true)
+
+ // Collect invalidation events after becoming a router
+ gotRouterEvents := make(map[ndpRouterEvent]int)
+ for i := 0; i < maxEvents; i++ {
+ ok, e := expectRouterEvent()
+ if !ok {
+ t.Errorf("expected %d router events after becoming a router; got = %d", maxEvents, i)
+ break
+ }
+ gotRouterEvents[e]++
+ }
+ gotPrefixEvents := make(map[ndpPrefixEvent]int)
+ for i := 0; i < maxEvents; i++ {
+ ok, e := expectPrefixEvent()
+ if !ok {
+ t.Errorf("expected %d prefix events after becoming a router; got = %d", maxEvents, i)
+ break
+ }
+ gotPrefixEvents[e]++
+ }
+ gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int)
+ for i := 0; i < maxEvents; i++ {
+ ok, e := expectAutoGenAddrEvent()
+ if !ok {
+ t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", maxEvents, i)
+ break
+ }
+ gotAutoGenAddrEvents[e]++
+ }
+
+ // No need to proceed any further if we already failed the test (missing
+ // some invalidation events).
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ expectedRouterEvents := map[ndpRouterEvent]int{
+ {nicID: nicID1, addr: llAddr1, discovered: false}: 1,
+ {nicID: nicID1, addr: llAddr2, discovered: false}: 1,
+ {nicID: nicID2, addr: llAddr1, discovered: false}: 1,
+ {nicID: nicID2, addr: llAddr2, discovered: false}: 1,
+ }
+ if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" {
+ t.Errorf("router events mismatch (-want +got):\n%s", diff)
+ }
+ expectedPrefixEvents := map[ndpPrefixEvent]int{
+ {nicID: nicID1, prefix: subnet1, discovered: false}: 1,
+ {nicID: nicID1, prefix: subnet2, discovered: false}: 1,
+ {nicID: nicID2, prefix: subnet1, discovered: false}: 1,
+ {nicID: nicID2, prefix: subnet2, discovered: false}: 1,
+ }
+ if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" {
+ t.Errorf("prefix events mismatch (-want +got):\n%s", diff)
+ }
+ expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{
+ {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1,
+ {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1,
+ {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1,
+ {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1,
+ }
+ if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" {
+ t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff)
+ }
+
+ // Make sure the auto-generated addresses got removed.
+ nicinfo = s.NICInfo()
+ nic1Addrs = nicinfo[nicID1].ProtocolAddresses
+ nic2Addrs = nicinfo[nicID2].ProtocolAddresses
+ if contains(nic1Addrs, e1Addr1) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
+ }
+ if contains(nic1Addrs, e1Addr2) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
+ }
+ if contains(nic2Addrs, e2Addr1) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
+ }
+ if contains(nic2Addrs, e2Addr2) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
+ }
+
+ // Should not get any more events (invalidation timers should have been
+ // cancelled when we transitioned into a router).
+ time.Sleep(lifetimeSeconds*time.Second + defaultTimeout)
+ select {
+ case <-ndpDisp.routerC:
+ t.Error("unexpected router event")
+ default:
+ }
+ select {
+ case <-ndpDisp.prefixC:
+ t.Error("unexpected prefix event")
+ default:
+ }
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Error("unexpected auto-generated address event")
+ default:
+ }
+}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index e8401c673..3810c6602 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -27,11 +27,11 @@ import (
// NIC represents a "network interface card" to which the networking stack is
// attached.
type NIC struct {
- stack *Stack
- id tcpip.NICID
- name string
- linkEP LinkEndpoint
- loopback bool
+ stack *Stack
+ id tcpip.NICID
+ name string
+ linkEP LinkEndpoint
+ context NICContext
mu sync.RWMutex
spoofing bool
@@ -85,7 +85,7 @@ const (
)
// newNIC returns a new NIC using the default NDP configurations from stack.
-func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC {
+func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC {
// TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For
// example, make sure that the link address it provides is a valid
// unicast ethernet address.
@@ -99,7 +99,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback
id: id,
name: name,
linkEP: ep,
- loopback: loopback,
+ context: ctx,
primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint),
endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
mcastJoins: make(map[NetworkEndpointID]int32),
@@ -174,23 +174,28 @@ func (n *NIC) enable() *tcpip.Error {
return err
}
- if !n.stack.autoGenIPv6LinkLocal {
+ // Do not auto-generate an IPv6 link-local address for loopback devices.
+ if !n.stack.autoGenIPv6LinkLocal || n.isLoopback() {
return nil
}
- l2addr := n.linkEP.LinkAddress()
+ var addr tcpip.Address
+ if oIID := n.stack.opaqueIIDOpts; oIID.NICNameFromID != nil {
+ addr = header.LinkLocalAddrWithOpaqueIID(oIID.NICNameFromID(n.ID(), n.name), 0, oIID.SecretKey)
+ } else {
+ l2addr := n.linkEP.LinkAddress()
- // Only attempt to generate the link-local address if we have a
- // valid MAC address.
- //
- // TODO(b/141011931): Validate a LinkEndpoint's link address
- // (provided by LinkEndpoint.LinkAddress) before reaching this
- // point.
- if !header.IsValidUnicastEthernetAddress(l2addr) {
- return nil
- }
+ // Only attempt to generate the link-local address if we have a valid MAC
+ // address.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
+ // LinkEndpoint.LinkAddress) before reaching this point.
+ if !header.IsValidUnicastEthernetAddress(l2addr) {
+ return nil
+ }
- addr := header.LinkLocalAddr(l2addr)
+ addr = header.LinkLocalAddr(l2addr)
+ }
_, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{
Protocol: header.IPv6ProtocolNumber,
@@ -203,6 +208,18 @@ func (n *NIC) enable() *tcpip.Error {
return err
}
+// becomeIPv6Router transitions n into an IPv6 router.
+//
+// When transitioning into an IPv6 router, host-only state (NDP discovered
+// routers, discovered on-link prefixes, and auto-generated addresses) will
+// be cleaned up/invalidated.
+func (n *NIC) becomeIPv6Router() {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ n.ndp.cleanupHostOnlyState()
+}
+
// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
// to start delivering packets.
func (n *NIC) attachLinkEndpoint() {
@@ -223,6 +240,10 @@ func (n *NIC) isPromiscuousMode() bool {
return rv
}
+func (n *NIC) isLoopback() bool {
+ return n.linkEP.Capabilities()&CapabilityLoopback != 0
+}
+
// setSpoofing enables or disables address spoofing.
func (n *NIC) setSpoofing(enable bool) {
n.mu.Lock()
@@ -232,17 +253,47 @@ func (n *NIC) setSpoofing(enable bool) {
// primaryEndpoint returns the primary endpoint of n for the given network
// protocol.
+//
+// primaryEndpoint will return the first non-deprecated endpoint if such an
+// endpoint exists. If no non-deprecated endpoint exists, the first deprecated
+// endpoint will be returned.
func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint {
n.mu.RLock()
defer n.mu.RUnlock()
+ var deprecatedEndpoint *referencedNetworkEndpoint
for _, r := range n.primary[protocol] {
- if r.isValidForOutgoing() && r.tryIncRef() {
- return r
+ if !r.isValidForOutgoing() {
+ continue
+ }
+
+ if !r.deprecated {
+ if r.tryIncRef() {
+ // r is not deprecated, so return it immediately.
+ //
+ // If we kept track of a deprecated endpoint, decrement its reference
+ // count since it was incremented when we decided to keep track of it.
+ if deprecatedEndpoint != nil {
+ deprecatedEndpoint.decRefLocked()
+ deprecatedEndpoint = nil
+ }
+
+ return r
+ }
+ } else if deprecatedEndpoint == nil && r.tryIncRef() {
+ // We prefer an endpoint that is not deprecated, but we keep track of r in
+ // case n doesn't have any non-deprecated endpoints.
+ //
+ // If we end up finding a more preferred endpoint, r's reference count
+ // will be decremented when such an endpoint is found.
+ deprecatedEndpoint = r
}
}
- return nil
+ // n doesn't have any valid non-deprecated endpoints, so return
+ // deprecatedEndpoint (which may be nil if n doesn't have any valid deprecated
+ // endpoints either).
+ return deprecatedEndpoint
}
// hasPermanentAddrLocked returns true if n has a permanent (including currently
@@ -350,7 +401,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
Address: address,
PrefixLen: netProto.DefaultPrefixLen(),
},
- }, peb, temporary, static)
+ }, peb, temporary, static, false)
n.mu.Unlock()
return ref
@@ -399,10 +450,10 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p
}
}
- return n.addAddressLocked(protocolAddress, peb, permanent, static)
+ return n.addAddressLocked(protocolAddress, peb, permanent, static, false)
}
-func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType) (*referencedNetworkEndpoint, *tcpip.Error) {
+func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) {
// TODO(b/141022673): Validate IP address before adding them.
// Sanity check.
@@ -438,6 +489,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
protocol: protocolAddress.Protocol,
kind: kind,
configType: configType,
+ deprecated: deprecated,
}
// Set up cache if link address resolution exists for this protocol.
@@ -536,6 +588,51 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
return addrs
}
+// primaryAddress returns the primary address associated with this NIC.
+//
+// primaryAddress will return the first non-deprecated address if such an
+// address exists. If no non-deprecated address exists, the first deprecated
+// address will be returned.
+func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ list, ok := n.primary[proto]
+ if !ok {
+ return tcpip.AddressWithPrefix{}
+ }
+
+ var deprecatedEndpoint *referencedNetworkEndpoint
+ for _, ref := range list {
+		// Don't include tentative, expired or temporary endpoints to avoid confusion
+ // and prevent the caller from using those.
+ switch ref.getKind() {
+ case permanentTentative, permanentExpired, temporary:
+ continue
+ }
+
+ if !ref.deprecated {
+ return tcpip.AddressWithPrefix{
+ Address: ref.ep.ID().LocalAddress,
+ PrefixLen: ref.ep.PrefixLen(),
+ }
+ }
+
+ if deprecatedEndpoint == nil {
+ deprecatedEndpoint = ref
+ }
+ }
+
+ if deprecatedEndpoint != nil {
+ return tcpip.AddressWithPrefix{
+ Address: deprecatedEndpoint.ep.ID().LocalAddress,
+ PrefixLen: deprecatedEndpoint.ep.PrefixLen(),
+ }
+ }
+
+ return tcpip.AddressWithPrefix{}
+}
+
// AddAddressRange adds a range of addresses to n, so that it starts accepting
// packets targeted at the given addresses and network protocol. The range is
// given by a subnet address, and all addresses contained in the subnet are
@@ -1092,6 +1189,11 @@ type referencedNetworkEndpoint struct {
// configType is the method that was used to configure this endpoint.
// This must never change after the endpoint is added to a NIC.
configType networkEndpointConfigType
+
+ // deprecated indicates whether or not the endpoint should be considered
+ // deprecated. That is, when deprecated is true, other endpoints that are not
+ // deprecated should be preferred.
+ deprecated bool
}
func (r *referencedNetworkEndpoint) getKind() networkEndpointKind {
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 61fd46d66..2b8751d49 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -234,15 +234,15 @@ type NetworkEndpoint interface {
// WritePacket writes a packet to the given destination address and
// protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have
// already been set.
- WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error
+ WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error
// WritePackets writes packets to the given destination address and
// protocol. pkts must not be zero length.
- WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams, loop PacketLooping) (int, *tcpip.Error)
+ WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error)
// WriteHeaderIncludedPacket writes a packet that includes a network
// header to the given destination address.
- WriteHeaderIncludedPacket(r *Route, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error
+ WriteHeaderIncludedPacket(r *Route, pkt tcpip.PacketBuffer) *tcpip.Error
// ID returns the network protocol endpoint ID.
ID() *NetworkEndpointID
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 34307ae07..517f4b941 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -158,7 +158,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.Pack
return tcpip.ErrInvalidEndpointState
}
- err := r.ref.ep.WritePacket(r, gso, params, r.Loop, pkt)
+ err := r.ref.ep.WritePacket(r, gso, params, pkt)
if err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
} else {
@@ -174,7 +174,7 @@ func (r *Route) WritePackets(gso *GSO, pkts []tcpip.PacketBuffer, params Network
return 0, tcpip.ErrInvalidEndpointState
}
- n, err := r.ref.ep.WritePackets(r, gso, pkts, params, r.Loop)
+ n, err := r.ref.ep.WritePackets(r, gso, pkts, params)
if err != nil {
r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(pkts) - n))
}
@@ -195,7 +195,7 @@ func (r *Route) WriteHeaderIncludedPacket(pkt tcpip.PacketBuffer) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- if err := r.ref.ep.WriteHeaderIncludedPacket(r, r.Loop, pkt); err != nil {
+ if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 0e88643a4..41bf9fd9b 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -352,6 +352,38 @@ func (u *uniqueIDGenerator) UniqueID() uint64 {
return atomic.AddUint64((*uint64)(u), 1)
}
+// NICNameFromID is a function that returns a stable name for the specified NIC,
+// even if different NIC IDs are used to refer to the same NIC in different
+// program runs. It is used when generating opaque interface identifiers (IIDs).
+// If the NIC was created with a name, it will be passed to NICNameFromID.
+//
+// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are
+// generated for the same prefix on different NICs.
+type NICNameFromID func(tcpip.NICID, string) string
+
+// OpaqueInterfaceIdentifierOptions holds the options related to the generation
+// of opaque interface identifiers (IIDs) as defined by RFC 7217.
+type OpaqueInterfaceIdentifierOptions struct {
+ // NICNameFromID is a function that returns a stable name for a specified NIC,
+ // even if the NIC ID changes over time.
+ //
+ // Must be specified to generate the opaque IID.
+ NICNameFromID NICNameFromID
+
+ // SecretKey is a pseudo-random number used as the secret key when generating
+ // opaque IIDs as defined by RFC 7217. The key SHOULD be at least
+ // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness
+ // requirements for security as outlined by RFC 4086. SecretKey MUST NOT
+ // change between program runs, unless explicitly changed.
+ //
+ // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey
+ // MUST NOT be modified after Stack is created.
+ //
+ // May be nil, but a nil value is highly discouraged to maintain
+ // some level of randomness between nodes.
+ SecretKey []byte
+}
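
For illustration, a minimal sketch of how an integrator might wire these options into a stack, modeled on the tests elsewhere in this change; the NIC ID, the NIC name "nic1", and the channel endpoint parameters are assumptions for the example, not requirements:

package main

import (
	"log"

	"gvisor.dev/gvisor/pkg/rand"
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// newIPv6StackWithOpaqueIIDs builds a stack that auto-generates IPv6
// addresses using RFC 7217 opaque interface identifiers.
func newIPv6StackWithOpaqueIIDs() *stack.Stack {
	// A real integrator would persist this key: it SHOULD be at least
	// OpaqueIIDSecretKeyMinBytes of random data and MUST NOT change
	// between runs (see the SecretKey field docs above).
	var secretKey [header.OpaqueIIDSecretKeyMinBytes]byte
	if _, err := rand.Read(secretKey[:]); err != nil {
		log.Fatalf("rand.Read(_): %s", err)
	}

	s := stack.New(stack.Options{
		NetworkProtocols:     []stack.NetworkProtocol{ipv6.NewProtocol()},
		AutoGenIPv6LinkLocal: true,
		OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
			// Echo back the name the NIC was created with; any function that
			// returns a name stable across runs works here.
			NICNameFromID: func(_ tcpip.NICID, nicName string) string {
				return nicName
			},
			SecretKey: secretKey[:],
		},
	})

	// With opaque IIDs configured, link-local generation does not require a
	// valid MAC address on the link endpoint.
	ep := channel.New(0, 1280, "")
	if err := s.CreateNICWithOptions(1, ep, stack.NICOptions{Name: "nic1"}); err != nil {
		log.Fatalf("CreateNICWithOptions(1, _, _) = %s", err)
	}
	return s
}

With AutoGenIPv6LinkLocal set, the NIC's link-local address is then derived via header.LinkLocalAddrWithOpaqueIID rather than from the MAC's modified EUI-64.
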
+
// Stack is a networking stack, with all supported protocols, NICs, and route
// table.
type Stack struct {
@@ -412,8 +444,8 @@ type Stack struct {
ndpConfigs NDPConfigurations
// autoGenIPv6LinkLocal determines whether or not the stack will attempt
- // to auto-generate an IPv6 link-local address for newly enabled NICs.
- // See the AutoGenIPv6LinkLocal field of Options for more details.
+ // to auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // NICs. See the AutoGenIPv6LinkLocal field of Options for more details.
autoGenIPv6LinkLocal bool
// ndpDisp is the NDP event dispatcher that is used to send the netstack
@@ -422,6 +454,10 @@ type Stack struct {
// uniqueIDGenerator is a generator of unique identifiers.
uniqueIDGenerator UniqueID
+
+ // opaqueIIDOpts hold the options for generating opaque interface identifiers
+	// opaqueIIDOpts holds the options for generating opaque interface identifiers
+ opaqueIIDOpts OpaqueInterfaceIdentifierOptions
}
// UniqueID is an abstract generator of unique identifiers.
@@ -460,13 +496,15 @@ type Options struct {
// before assigning an address to a NIC.
NDPConfigs NDPConfigurations
- // AutoGenIPv6LinkLocal determins whether or not the stack will attempt
- // to auto-generate an IPv6 link-local address for newly enabled NICs.
+ // AutoGenIPv6LinkLocal determines whether or not the stack will attempt to
+ // auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // NICs.
+ //
// Note, setting this to true does not mean that a link-local address
- // will be assigned right away, or at all. If Duplicate Address
- // Detection is enabled, an address will only be assigned if it
- // successfully resolves. If it fails, no further attempt will be made
- // to auto-generate an IPv6 link-local address.
+ // will be assigned right away, or at all. If Duplicate Address Detection
+ // is enabled, an address will only be assigned if it successfully resolves.
+ // If it fails, no further attempt will be made to auto-generate an IPv6
+ // link-local address.
//
// The generated link-local address will follow RFC 4291 Appendix A
// guidelines.
@@ -479,6 +517,10 @@ type Options struct {
// RawFactory produces raw endpoints. Raw endpoints are enabled only if
// this is non-nil.
RawFactory RawFactory
+
+	// OpaqueIIDOpts holds the options for generating opaque interface identifiers
+ // (IIDs) as outlined by RFC 7217.
+ OpaqueIIDOpts OpaqueInterfaceIdentifierOptions
}
// TransportEndpointInfo holds useful information about a transport endpoint
@@ -549,6 +591,7 @@ func New(opts Options) *Stack {
autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
uniqueIDGenerator: opts.UniqueID,
ndpDisp: opts.NDPDisp,
+ opaqueIIDOpts: opts.OpaqueIIDOpts,
}
// Add specified network protocols.
@@ -662,11 +705,31 @@ func (s *Stack) Stats() tcpip.Stats {
}
// SetForwarding enables or disables the packet forwarding between NICs.
+//
+// When forwarding becomes enabled, any host-only state on all NICs will be
+// cleaned up.
func (s *Stack) SetForwarding(enable bool) {
// TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // If forwarding status didn't change, do nothing further.
+ if s.forwarding == enable {
+ return
+ }
+
s.forwarding = enable
- s.mu.Unlock()
+
+ // If this stack does not support IPv6, do nothing further.
+ if _, ok := s.networkProtocols[header.IPv6ProtocolNumber]; !ok {
+ return
+ }
+
+ if enable {
+ for _, nic := range s.nics {
+ nic.becomeIPv6Router()
+ }
+ }
}
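
For integrators, a small behavioral sketch (assuming a stack configured like the NDP tests in this change, with routers, prefixes, and SLAAC addresses already learned from received RAs):

// s was created with NDPConfigs{HandleRAs: true, DiscoverDefaultRouters: true,
// DiscoverOnLinkPrefixes: true, AutoGenGlobalAddresses: true} and an NDPDisp.
s.SetForwarding(true)
// Every NIC transitions into an IPv6 router: the dispatcher receives
// invalidation events for each discovered router, discovered prefix, and
// auto-generated address, and the associated timers are cancelled. Calling
// SetForwarding(true) again is a no-op since the forwarding state is
// unchanged.
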
// Forwarding returns if the packet forwarding between NICs is enabled.
@@ -733,9 +796,30 @@ func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNum
return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue)
}
-// createNIC creates a NIC with the provided id and link-layer endpoint, and
-// optionally enable it.
-func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error {
+// NICContext is an opaque value used to store client-supplied NIC metadata.
+type NICContext interface{}
+
+// NICOptions specifies the configuration of a NIC as it is being created.
+// The zero value creates an enabled, unnamed NIC.
+type NICOptions struct {
+ // Name specifies the name of the NIC.
+ Name string
+
+ // Disabled specifies whether to avoid calling Attach on the passed
+ // LinkEndpoint.
+ Disabled bool
+
+ // Context specifies user-defined data that will be returned in stack.NICInfo
+ // for the NIC. Clients of this library can use it to add metadata that
+ // should be tracked alongside a NIC, to avoid having to keep a
+ // map[tcpip.NICID]metadata mirroring stack.Stack's nic map.
+ Context NICContext
+}
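
As a hedged usage sketch (the NIC ID, the name "mgmt0", and the metadata type are arbitrary; imports come from pkg/tcpip, pkg/tcpip/link/loopback, and pkg/tcpip/stack), a named, initially disabled NIC carrying client metadata could be created and later enabled like so:

// ifaceMeta is an arbitrary, caller-defined metadata type.
type ifaceMeta struct {
	OwnerID int
}

// addManagementNIC creates a named NIC without attaching its link endpoint,
// stores caller metadata alongside it, and then enables it.
func addManagementNIC(s *stack.Stack) *tcpip.Error {
	opts := stack.NICOptions{
		Name:     "mgmt0",
		Disabled: true, // LinkEndpoint.Attach is not called yet.
		Context:  &ifaceMeta{OwnerID: 42},
	}
	if err := s.CreateNICWithOptions(2, loopback.New(), opts); err != nil {
		return err
	}

	// The metadata is surfaced back through NICInfo.
	if info, ok := s.NICInfo()[2]; ok {
		_ = info.Context.(*ifaceMeta)
	}

	// Attach the link endpoint and start delivering packets.
	return s.EnableNIC(2)
}

Storing metadata in Context avoids keeping a separate map[tcpip.NICID]metadata that mirrors the stack's own NIC map, which is the use case the field is meant for.
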
+
+// CreateNICWithOptions creates a NIC with the provided id, LinkEndpoint, and
+// NICOptions. See the documentation on type NICOptions for details on how
+// NICs can be configured.
+func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -744,44 +828,20 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled,
return tcpip.ErrDuplicateNICID
}
- n := newNIC(s, id, name, ep, loopback)
+ n := newNIC(s, id, opts.Name, ep, opts.Context)
s.nics[id] = n
- if enabled {
+ if !opts.Disabled {
return n.enable()
}
return nil
}
-// CreateNIC creates a NIC with the provided id and link-layer endpoint.
+// CreateNIC creates a NIC with the provided id and LinkEndpoint and calls
+// `LinkEndpoint.Attach` to start delivering packets to it.
func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
- return s.createNIC(id, "", ep, true, false)
-}
-
-// CreateNamedNIC creates a NIC with the provided id and link-layer endpoint,
-// and a human-readable name.
-func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
- return s.createNIC(id, name, ep, true, false)
-}
-
-// CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer
-// endpoint, and a human-readable name.
-func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
- return s.createNIC(id, name, ep, true, true)
-}
-
-// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint,
-// but leave it disable. Stack.EnableNIC must be called before the link-layer
-// endpoint starts delivering packets to it.
-func (s *Stack) CreateDisabledNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
- return s.createNIC(id, "", ep, false, false)
-}
-
-// CreateDisabledNamedNIC is a combination of CreateNamedNIC and
-// CreateDisabledNIC.
-func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
- return s.createNIC(id, name, ep, false, false)
+ return s.CreateNICWithOptions(id, ep, NICOptions{})
}
// EnableNIC enables the given NIC so that the link-layer endpoint can start
@@ -835,6 +895,18 @@ type NICInfo struct {
MTU uint32
Stats NICStats
+
+	// Context is optional, user-supplied data provided in CreateNICWithOptions.
+ // See type NICOptions for more details.
+ Context NICContext
+}
+
+// HasNIC returns true if the NICID is defined in the stack.
+func (s *Stack) HasNIC(id tcpip.NICID) bool {
+ s.mu.RLock()
+ _, ok := s.nics[id]
+ s.mu.RUnlock()
+ return ok
}
// NICInfo returns a map of NICIDs to their associated information.
@@ -848,7 +920,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
Up: true, // Netstack interfaces are always up.
Running: nic.linkEP.IsAttached(),
Promiscuous: nic.isPromiscuousMode(),
- Loopback: nic.linkEP.Capabilities()&CapabilityLoopback != 0,
+ Loopback: nic.isLoopback(),
}
nics[id] = NICInfo{
Name: nic.name,
@@ -857,6 +929,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
Flags: flags,
MTU: nic.linkEP.MTU(),
Stats: nic.stats,
+ Context: nic.context,
}
}
return nics
@@ -973,9 +1046,11 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress {
return nics
}
-// GetMainNICAddress returns the first primary address and prefix for the given
-// NIC and protocol. Returns an error if the NIC doesn't exist and an empty
-// value if the NIC doesn't have a primary address for the given protocol.
+// GetMainNICAddress returns the first non-deprecated primary address and prefix
+// for the given NIC and protocol. If no non-deprecated primary address exists,
+// a deprecated primary address and prefix will be returned. Returns an error if
+// the NIC doesn't exist and an empty value if the NIC doesn't have a primary
+// address for the given protocol.
func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -985,12 +1060,7 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol
return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID
}
- for _, a := range nic.PrimaryAddresses() {
- if a.Protocol == protocol {
- return a.AddressWithPrefix, nil
- }
- }
- return tcpip.AddressWithPrefix{}, nil
+ return nic.primaryAddress(protocol), nil
}
func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) {
@@ -1012,7 +1082,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if id != 0 && !needRoute {
if nic, ok := s.nics[id]; ok {
if ref := s.getRefEP(nic, localAddr, netProto); ref != nil {
- return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.loopback, multicastLoop && !nic.loopback), nil
+ return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil
}
}
} else {
@@ -1028,7 +1098,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
remoteAddr = ref.ep.ID().LocalAddress
}
- r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.loopback, multicastLoop && !nic.loopback)
+ r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback())
if needRoute {
r.NextHop = route.Gateway
}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 8fc034ca1..44e5229cc 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -27,10 +27,12 @@ import (
"time"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -122,7 +124,7 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return f.ep.Capabilities()
}
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
@@ -133,7 +135,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
b[1] = f.id.LocalAddress[0]
b[2] = byte(params.Protocol)
- if loop&stack.PacketLoop != 0 {
+ if r.Loop&stack.PacketLoop != 0 {
views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
views[0] = pkt.Header.View()
views = append(views, pkt.Data.Views()...)
@@ -141,7 +143,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
})
}
- if loop&stack.PacketOut == 0 {
+ if r.Loop&stack.PacketOut == 0 {
return nil
}
@@ -149,11 +151,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
+func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
panic("not implemented")
}
-func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -1894,55 +1896,67 @@ func TestNICForwarding(t *testing.T) {
}
// TestNICAutoGenAddr tests the auto-generation of IPv6 link-local addresses
-// (or lack there-of if disabled (default)). Note, DAD will be disabled in
-// these tests.
+// using the modified EUI-64 of the NIC's MAC address (or lack thereof if
+// disabled (default)). Note, DAD will be disabled in these tests.
func TestNICAutoGenAddr(t *testing.T) {
tests := []struct {
name string
autoGen bool
linkAddr tcpip.LinkAddress
+ iidOpts stack.OpaqueInterfaceIdentifierOptions
shouldGen bool
}{
{
"Disabled",
false,
linkAddr1,
+ stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(nicID tcpip.NICID, _ string) string {
+ return fmt.Sprintf("nic%d", nicID)
+ },
+ },
false,
},
{
"Enabled",
true,
linkAddr1,
+ stack.OpaqueInterfaceIdentifierOptions{},
true,
},
{
"Nil MAC",
true,
tcpip.LinkAddress([]byte(nil)),
+ stack.OpaqueInterfaceIdentifierOptions{},
false,
},
{
"Empty MAC",
true,
tcpip.LinkAddress(""),
+ stack.OpaqueInterfaceIdentifierOptions{},
false,
},
{
"Invalid MAC",
true,
tcpip.LinkAddress("\x01\x02\x03"),
+ stack.OpaqueInterfaceIdentifierOptions{},
false,
},
{
"Multicast MAC",
true,
tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ stack.OpaqueInterfaceIdentifierOptions{},
false,
},
{
"Unspecified MAC",
true,
tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"),
+ stack.OpaqueInterfaceIdentifierOptions{},
false,
},
}
@@ -1951,13 +1965,12 @@ func TestNICAutoGenAddr(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ OpaqueIIDOpts: test.iidOpts,
}
if test.autoGen {
- // Only set opts.AutoGenIPv6LinkLocal when
- // test.autoGen is true because
- // opts.AutoGenIPv6LinkLocal should be false by
- // default.
+ // Only set opts.AutoGenIPv6LinkLocal when test.autoGen is true because
+ // opts.AutoGenIPv6LinkLocal should be false by default.
opts.AutoGenIPv6LinkLocal = true
}
@@ -1973,9 +1986,171 @@ func TestNICAutoGenAddr(t *testing.T) {
}
if test.shouldGen {
+ // Should have auto-generated an address and resolved immediately (DAD
+ // is disabled).
+ if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want)
+ }
+ } else {
+ // Should not have auto-generated an address.
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+ }
+ })
+ }
+}
+
+// TestNICContextPreservation tests that you can read out via stack.NICInfo the
+// Context data you pass via NICOptions.Context in stack.CreateNICWithOptions.
+func TestNICContextPreservation(t *testing.T) {
+ var ctx *int
+ tests := []struct {
+ name string
+ opts stack.NICOptions
+ want stack.NICContext
+ }{
+ {
+ "context_set",
+ stack.NICOptions{Context: ctx},
+ ctx,
+ },
+ {
+ "context_not_set",
+ stack.NICOptions{},
+ nil,
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{})
+ id := tcpip.NICID(1)
+ ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"))
+ if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil {
+ t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err)
+ }
+ nicinfos := s.NICInfo()
+ nicinfo, ok := nicinfos[id]
+ if !ok {
+ t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos)
+ }
+ if got, want := nicinfo.Context == test.want, true; got != want {
+				t.Fatalf("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want)
+ }
+ })
+ }
+}
+
+// TestNICAutoGenAddrWithOpaque tests the auto-generation of IPv6 link-local
+// addresses with opaque interface identifiers. Link Local addresses should
+// always be generated with opaque IIDs if configured to use them, even if the
+// NIC has an invalid MAC address.
+func TestNICAutoGenAddrWithOpaque(t *testing.T) {
+ const nicID = 1
+
+ var secretKey [header.OpaqueIIDSecretKeyMinBytes]byte
+ n, err := rand.Read(secretKey[:])
+ if err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ }
+ if n != header.OpaqueIIDSecretKeyMinBytes {
+ t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", header.OpaqueIIDSecretKeyMinBytes, n)
+ }
+
+ tests := []struct {
+ name string
+ nicName string
+ autoGen bool
+ linkAddr tcpip.LinkAddress
+ secretKey []byte
+ }{
+ {
+ name: "Disabled",
+ nicName: "nic1",
+ autoGen: false,
+ linkAddr: linkAddr1,
+ secretKey: secretKey[:],
+ },
+ {
+ name: "Enabled",
+ nicName: "nic1",
+ autoGen: true,
+ linkAddr: linkAddr1,
+ secretKey: secretKey[:],
+ },
+ // These are all cases where we would not have generated a
+ // link-local address if opaque IIDs were disabled.
+ {
+ name: "Nil MAC and empty nicName",
+ nicName: "",
+ autoGen: true,
+ linkAddr: tcpip.LinkAddress([]byte(nil)),
+ secretKey: secretKey[:1],
+ },
+ {
+ name: "Empty MAC and empty nicName",
+ autoGen: true,
+ linkAddr: tcpip.LinkAddress(""),
+ secretKey: secretKey[:2],
+ },
+ {
+ name: "Invalid MAC",
+ nicName: "test",
+ autoGen: true,
+ linkAddr: tcpip.LinkAddress("\x01\x02\x03"),
+ secretKey: secretKey[:3],
+ },
+ {
+ name: "Multicast MAC",
+ nicName: "test2",
+ autoGen: true,
+ linkAddr: tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ secretKey: secretKey[:4],
+ },
+ {
+ name: "Unspecified MAC and nil SecretKey",
+ nicName: "test3",
+ autoGen: true,
+ linkAddr: tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: test.secretKey,
+ },
+ }
+
+ if test.autoGen {
+ // Only set opts.AutoGenIPv6LinkLocal when
+ // test.autoGen is true because
+ // opts.AutoGenIPv6LinkLocal should be false by
+ // default.
+ opts.AutoGenIPv6LinkLocal = true
+ }
+
+ e := channel.New(10, 1280, test.linkAddr)
+ s := stack.New(opts)
+ nicOpts := stack.NICOptions{Name: test.nicName}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+				t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err)
+ }
+
+ if test.autoGen {
// Should have auto-generated an address and
// resolved immediately (DAD is disabled).
- if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
+ if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddrWithOpaqueIID(test.nicName, 0, test.secretKey), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want)
}
} else {
@@ -1988,6 +2163,56 @@ func TestNICAutoGenAddr(t *testing.T) {
}
}
+// TestNoLinkLocalAutoGenForLoopbackNIC tests that IPv6 link-local addresses are
+// not auto-generated for loopback NICs.
+func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
+ const nicID = 1
+ const nicName = "nicName"
+
+ tests := []struct {
+ name string
+ opaqueIIDOpts stack.OpaqueInterfaceIdentifierOptions
+ }{
+ {
+ name: "IID From MAC",
+ opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{},
+ },
+ {
+ name: "Opaque IID",
+ opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: true,
+ OpaqueIIDOpts: test.opaqueIIDOpts,
+ }
+
+ e := loopback.New()
+ s := stack.New(opts)
+ nicOpts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got stack.GetMainNICAddress(%d, _) = %s, want = %s", nicID, addr, want)
+ }
+ })
+ }
+}
+
// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6
// link-local addresses will only be assigned after the DAD process resolves.
func TestNICAutoGenAddrDoesDAD(t *testing.T) {
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 3b28b06d0..5e9237de9 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -41,7 +41,7 @@ const (
type testContext struct {
t *testing.T
- linkEPs map[string]*channel.Endpoint
+ linkEps map[tcpip.NICID]*channel.Endpoint
s *stack.Stack
ep tcpip.Endpoint
@@ -61,35 +61,29 @@ func (c *testContext) createV6Endpoint(v6only bool) {
c.t.Fatalf("NewEndpoint failed: %v", err)
}
- var v tcpip.V6OnlyOption
- if v6only {
- v = 1
- }
- if err := c.ep.SetSockOpt(v); err != nil {
+ if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil {
c.t.Fatalf("SetSockOpt failed: %v", err)
}
}
-// newDualTestContextMultiNic creates the testing context and also linkEpNames
-// named NICs.
-func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext {
+// newDualTestContextMultiNIC creates the testing context and a NIC for each ID in linkEpIDs.
+func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
- linkEPs := make(map[string]*channel.Endpoint)
- for i, linkEpName := range linkEpNames {
- channelEP := channel.New(256, mtu, "")
- nicID := tcpip.NICID(i + 1)
- if err := s.CreateNamedNIC(nicID, linkEpName, channelEP); err != nil {
+ linkEps := make(map[tcpip.NICID]*channel.Endpoint)
+ for _, linkEpID := range linkEpIDs {
+ channelEp := channel.New(256, mtu, "")
+ if err := s.CreateNIC(linkEpID, channelEp); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
- linkEPs[linkEpName] = channelEP
+ linkEps[linkEpID] = channelEp
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
+ if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil {
t.Fatalf("AddAddress IPv4 failed: %v", err)
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
+ if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
t.Fatalf("AddAddress IPv6 failed: %v", err)
}
}
@@ -108,7 +102,7 @@ func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string)
return &testContext{
t: t,
s: s,
- linkEPs: linkEPs,
+ linkEps: linkEps,
}
}
@@ -125,7 +119,7 @@ func newPayload() []byte {
return b
}
-func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) {
+func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -156,7 +150,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEPs[linkEpName].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
@@ -186,7 +180,7 @@ func TestTransportDemuxerRegister(t *testing.T) {
func TestDistribution(t *testing.T) {
type endpointSockopts struct {
reuse int
- bindToDevice string
+ bindToDevice tcpip.NICID
}
for _, test := range []struct {
name string
@@ -194,71 +188,71 @@ func TestDistribution(t *testing.T) {
endpoints []endpointSockopts
// wantedDistribution is the wanted ratio of packets received on each
// endpoint for each NIC on which packets are injected.
- wantedDistributions map[string][]float64
+ wantedDistributions map[tcpip.NICID][]float64
}{
{
"BindPortReuse",
// 5 endpoints that all have reuse set.
[]endpointSockopts{
- endpointSockopts{1, ""},
- endpointSockopts{1, ""},
- endpointSockopts{1, ""},
- endpointSockopts{1, ""},
- endpointSockopts{1, ""},
+ {1, 0},
+ {1, 0},
+ {1, 0},
+ {1, 0},
+ {1, 0},
},
- map[string][]float64{
+ map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed evenly.
- "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2},
+ 1: {0.2, 0.2, 0.2, 0.2, 0.2},
},
},
{
"BindToDevice",
// 3 endpoints with various bindings.
[]endpointSockopts{
- endpointSockopts{0, "dev0"},
- endpointSockopts{0, "dev1"},
- endpointSockopts{0, "dev2"},
+ {0, 1},
+ {0, 2},
+ {0, 3},
},
- map[string][]float64{
+ map[tcpip.NICID][]float64{
// Injected packets on dev0 go only to the endpoint bound to dev0.
- "dev0": []float64{1, 0, 0},
+ 1: {1, 0, 0},
// Injected packets on dev1 go only to the endpoint bound to dev1.
- "dev1": []float64{0, 1, 0},
+ 2: {0, 1, 0},
// Injected packets on dev2 go only to the endpoint bound to dev2.
- "dev2": []float64{0, 0, 1},
+ 3: {0, 0, 1},
},
},
{
"ReuseAndBindToDevice",
// 6 endpoints with various bindings.
[]endpointSockopts{
- endpointSockopts{1, "dev0"},
- endpointSockopts{1, "dev0"},
- endpointSockopts{1, "dev1"},
- endpointSockopts{1, "dev1"},
- endpointSockopts{1, "dev1"},
- endpointSockopts{1, ""},
+ {1, 1},
+ {1, 1},
+ {1, 2},
+ {1, 2},
+ {1, 2},
+ {1, 0},
},
- map[string][]float64{
+ map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed among endpoints bound to
// dev0.
- "dev0": []float64{0.5, 0.5, 0, 0, 0, 0},
+ 1: {0.5, 0.5, 0, 0, 0, 0},
// Injected packets on dev1 get distributed among endpoints bound to
// dev1 or unbound.
- "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
+ 2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
// Injected packets on dev999 go only to the unbound.
- "dev999": []float64{0, 0, 0, 0, 0, 1},
+ 1000: {0, 0, 0, 0, 0, 1},
},
},
} {
t.Run(test.name, func(t *testing.T) {
for device, wantedDistribution := range test.wantedDistributions {
- t.Run(device, func(t *testing.T) {
- var devices []string
+ t.Run(string(device), func(t *testing.T) {
+ var devices []tcpip.NICID
for d := range test.wantedDistributions {
devices = append(devices, d)
}
- c := newDualTestContextMultiNic(t, defaultMTU, devices)
+ c := newDualTestContextMultiNIC(t, defaultMTU, devices)
defer c.cleanup()
c.createV6Endpoint(false)
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 748ce4ea5..f50604a8a 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -102,13 +102,23 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+// SetSockOptBool sets a socket option. Currently not supported.
+func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
// SetSockOptInt sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error {
+func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ return false, tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return -1, tcpip.ErrUnknownProtocolOption
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index f62fd729f..72b5ce179 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -423,17 +423,25 @@ type Endpoint interface {
// SetSockOpt sets a socket option. opt should be one of the *Option types.
SetSockOpt(opt interface{}) *Error
+ // SetSockOptBool sets a socket option, for simple cases where a value
+ // has the bool type.
+ SetSockOptBool(opt SockOptBool, v bool) *Error
+
// SetSockOptInt sets a socket option, for simple cases where a value
// has the int type.
- SetSockOptInt(opt SockOpt, v int) *Error
+ SetSockOptInt(opt SockOptInt, v int) *Error
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// *Option types.
GetSockOpt(opt interface{}) *Error
+ // GetSockOptBool gets a socket option for simple cases where a return
+ // value has the bool type.
+ GetSockOptBool(SockOptBool) (bool, *Error)
+
// GetSockOptInt gets a socket option for simple cases where a return
// value has the int type.
- GetSockOptInt(SockOpt) (int, *Error)
+ GetSockOptInt(SockOptInt) (int, *Error)
// State returns a socket's lifecycle state. The returned value is
// protocol-specific and is primarily used for diagnostics.
@@ -488,13 +496,22 @@ type WriteOptions struct {
Atomic bool
}
-// SockOpt represents socket options which values have the int type.
-type SockOpt int
+// SockOptBool represents socket options whose values have the bool type.
+type SockOptBool int
+
+const (
+ // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6
+ // socket is to be restricted to sending and receiving IPv6 packets only.
+ V6OnlyOption SockOptBool = iota
+)
+
+// SockOptInt represents socket options whose values have the int type.
+type SockOptInt int
const (
// ReceiveQueueSizeOption is used in GetSockOptInt to specify that the
// number of unread bytes in the input buffer should be returned.
- ReceiveQueueSizeOption SockOpt = iota
+ ReceiveQueueSizeOption SockOptInt = iota
// SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
// specify the send buffer size option.
@@ -521,10 +538,6 @@ const (
// the endpoint should be cleared and returned.
type ErrorOption struct{}
-// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6
-// socket is to be restricted to sending and receiving IPv6 packets only.
-type V6OnlyOption int
-
// CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be
// held until segments are full by the TCP transport protocol.
type CorkOption int
@@ -539,7 +552,7 @@ type ReusePortOption int
// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets
// should bind only on a specific NIC.
-type BindToDeviceOption string
+type BindToDeviceOption NICID
// QuickAckOption is stubbed out in SetSockOpt/GetSockOpt.
type QuickAckOption int
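
With the interface split above, "simple" boolean and integer options take
distinct accessors keyed by SockOptBool and SockOptInt, while struct-style
options such as BindToDeviceOption (now a NICID) still go through
SetSockOpt/GetSockOpt. A minimal caller-side sketch; the helper name and the
concrete values are illustrative assumptions, not part of this change:

package example

import "gvisor.dev/gvisor/pkg/tcpip"

// configureEndpoint exercises the three option paths side by side.
func configureEndpoint(ep tcpip.Endpoint) *tcpip.Error {
	// Boolean-valued options now go through {Set,Get}SockOptBool.
	if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
		return err
	}
	// Integer-valued options keep using {Set,Get}SockOptInt, keyed by SockOptInt.
	if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1<<20); err != nil {
		return err
	}
	// BindToDeviceOption is now a NICID rather than a device name; 0 unbinds.
	return ep.SetSockOpt(tcpip.BindToDeviceOption(1))
}
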
diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go
new file mode 100644
index 000000000..f5f01f32f
--- /dev/null
+++ b/pkg/tcpip/timer.go
@@ -0,0 +1,161 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip
+
+import (
+ "sync"
+ "time"
+)
+
+// cancellableTimerInstance is a specific instance of CancellableTimer.
+//
+// Different instances are created each time CancellableTimer is Reset, so each
+// timer instance has its own earlyReturn signal. This addresses a bug where a
+// CancellableTimer that is stopped and reset in quick succession could have one
+// timer instance's earlyReturn signal affected or seen by another timer
+// instance.
+//
+// Consider the following scenario where timer instances share a common
+// earlyReturn signal (T1 creates, stops and resets a CancellableTimer under a
+// lock L; T2, T3, T4 and T5 are goroutines that handle the first (A), second
+// (B), third (C), and fourth (D) instance of the timer firing, respectively):
+// T1: Obtain L
+// T1: Create a new CancellableTimer w/ lock L (create instance A)
+// T2: instance A fires, blocked trying to obtain L.
+// T1: Attempt to stop instance A (set earlyReturn = true)
+// T1: Reset timer (create instance B)
+// T3: instance B fires, blocked trying to obtain L.
+// T1: Attempt to stop instance B (set earlyReturn = true)
+// T1: Reset timer (create instance C)
+// T4: instance C fires, blocked trying to obtain L.
+// T1: Attempt to stop instance C (set earlyReturn = true)
+// T1: Reset timer (create instance D)
+// T5: instance D fires, blocked trying to obtain L.
+// T1: Release L
+//
+// Now that T1 has released L, any of the 4 timer instances can take L and check
+// earlyReturn. If the timers simply check earlyReturn and then do nothing
+// further, then instance D will never early return even though it was not
+// requested to stop. If the timers reset earlyReturn before early returning,
+// then all but one of the timers will do work when only one was expected to.
+// If CancellableTimer resets earlyReturn when resetting, then all the timers
+// will fire (again, when only one was expected to).
+//
+// To address the above concerns the simplest solution was to give each timer
+// its own earlyReturn signal.
+type cancellableTimerInstance struct {
+ timer *time.Timer
+
+ // Used to inform the timer to return early when it is stopped while the
+ // lock it tries to obtain on firing is already held (T1 is a goroutine that
+ // tries to cancel the timer and T2 is the goroutine that handles the timer
+ // firing):
+ // T1: Obtain the lock, then call StopLocked()
+ // T2: timer fires, and gets blocked on obtaining the lock
+ // T1: Releases lock
+ // T2: Obtains lock, then does unintended work
+ //
+ // To resolve this, T1 will check to see if the timer already fired, and
+ // inform the timer using earlyReturn to return early so that once T2 obtains
+ // the lock, it will see that it is set to true and do nothing further.
+ earlyReturn *bool
+}
+
+// stop stops the timer instance t from firing if it hasn't fired already. If it
+// has fired and is blocked at obtaining the lock, earlyReturn will be set to
+// true so that it will early return when it obtains the lock.
+func (t *cancellableTimerInstance) stop() {
+ if t.timer != nil {
+ t.timer.Stop()
+ *t.earlyReturn = true
+ }
+}
+
+// CancellableTimer is a timer that does some work and can be safely cancelled
+// when it fires at the same time some "related work" is being done.
+//
+// The term "related work" is defined as some work that needs to be done while
+// holding some lock that the timer must also hold while doing some work.
+type CancellableTimer struct {
+ // The active instance of a cancellable timer.
+ instance cancellableTimerInstance
+
+ // locker is the lock taken by the timer immediately after it fires and must
+ // be held when attempting to stop the timer.
+ //
+ // Must never change after being assigned.
+ locker sync.Locker
+
+ // fn is the function that will be called when a timer fires and has not been
+ // signaled to early return.
+ //
+ // fn MUST NOT attempt to lock locker.
+ //
+ // Must never change after being assigned.
+ fn func()
+}
+
+// StopLocked prevents the Timer from firing if it has not fired already.
+//
+// If the timer is blocked on obtaining the t.locker lock when StopLocked is
+// called, it will early return instead of calling t.fn.
+//
+// Note, t will be modified.
+//
+// t.locker MUST be locked.
+func (t *CancellableTimer) StopLocked() {
+ t.instance.stop()
+
+ // Nothing to do with the stopped instance anymore.
+ t.instance = cancellableTimerInstance{}
+}
+
+// Reset changes the timer to expire after duration d.
+//
+// Note, t will be modified.
+//
+// Reset should only be called on stopped or expired timers. To be safe, callers
+// should always call StopLocked before calling Reset.
+func (t *CancellableTimer) Reset(d time.Duration) {
+ // Create a new instance.
+ earlyReturn := false
+ t.instance = cancellableTimerInstance{
+ timer: time.AfterFunc(d, func() {
+ t.locker.Lock()
+ defer t.locker.Unlock()
+
+ if earlyReturn {
+ // If we reach this point, it means that the timer fired while another
+ // goroutine called StopLocked while it had the lock. Simply return
+ // here and do nothing further.
+ earlyReturn = false
+ return
+ }
+
+ t.fn()
+ }),
+ earlyReturn: &earlyReturn,
+ }
+}
+
+// MakeCancellableTimer returns an unscheduled CancellableTimer with the given
+// locker and fn.
+//
+// fn MUST NOT attempt to lock locker.
+//
+// Callers must call Reset to schedule the timer to fire.
+func MakeCancellableTimer(locker sync.Locker, fn func()) CancellableTimer {
+ return CancellableTimer{locker: locker, fn: fn}
+}
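
The tests below drive the timer directly; as a usage sketch, the intended
pattern looks roughly like the following (type and field names are
illustrative assumptions). fn runs with the locker already held by the timer,
so it must not try to lock it again:

package example

import (
	"sync"
	"time"

	"gvisor.dev/gvisor/pkg/tcpip"
)

type expirer struct {
	mu      sync.Mutex
	expired bool
	timer   tcpip.CancellableTimer
}

func newExpirer() *expirer {
	e := &expirer{}
	e.timer = tcpip.MakeCancellableTimer(&e.mu, func() {
		// Called with e.mu held by the timer; must not lock e.mu here.
		e.expired = true
	})
	return e
}

// restart stops any pending instance and schedules a new one, all under mu as
// StopLocked requires.
func (e *expirer) restart(d time.Duration) {
	e.mu.Lock()
	defer e.mu.Unlock()
	e.timer.StopLocked()
	e.timer.Reset(d)
}
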
diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go
new file mode 100644
index 000000000..1f735d735
--- /dev/null
+++ b/pkg/tcpip/timer_test.go
@@ -0,0 +1,236 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package timer_test
+
+import (
+ "sync"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ shortDuration = 1 * time.Nanosecond
+ middleDuration = 100 * time.Millisecond
+ longDuration = 1 * time.Second
+)
+
+func TestCancellableTimerFire(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ timer := tcpip.MakeCancellableTimer(&lock, func() {
+ ch <- struct{}{}
+ })
+ timer.Reset(shortDuration)
+
+ // Wait for timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestCancellableTimerResetFromLongDuration(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(middleDuration)
+
+ lock.Lock()
+ timer.StopLocked()
+ lock.Unlock()
+
+ timer.Reset(shortDuration)
+
+ // Wait for timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestCancellableTimerResetFromShortDuration(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ lock.Lock()
+ timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(shortDuration)
+ timer.StopLocked()
+ lock.Unlock()
+
+ // Wait for timer to fire if it wasn't correctly stopped.
+ select {
+ case <-ch:
+ t.Fatal("timer fired after being stopped")
+ case <-time.After(middleDuration):
+ }
+
+ timer.Reset(shortDuration)
+
+ // Wait for timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestCancellableTimerImmediatelyStop(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ for i := 0; i < 1000; i++ {
+ lock.Lock()
+ timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(shortDuration)
+ timer.StopLocked()
+ lock.Unlock()
+ }
+
+ // Wait for timer to fire if it wasn't correctly stopped.
+ select {
+ case <-ch:
+ t.Fatal("timer fired after being stopped")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ lock.Lock()
+ timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(shortDuration)
+ timer.StopLocked()
+ lock.Unlock()
+
+ for i := 0; i < 10; i++ {
+ timer.Reset(middleDuration)
+
+ lock.Lock()
+ // Sleep until the timer fires and gets blocked trying to take the lock.
+ time.Sleep(middleDuration * 2)
+ timer.StopLocked()
+ lock.Unlock()
+ }
+
+ // Wait for double the duration so timers that weren't correctly stopped can
+ // fire.
+ select {
+ case <-ch:
+ t.Fatal("timer fired after being stopped")
+ case <-time.After(middleDuration * 2):
+ }
+}
+
+func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ lock.Lock()
+ timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(shortDuration)
+ for i := 0; i < 10; i++ {
+ // Sleep until the timer fires and gets blocked trying to take the lock.
+ time.Sleep(middleDuration)
+ timer.StopLocked()
+ timer.Reset(shortDuration)
+ }
+ lock.Unlock()
+
+ // Wait for the last timer instance to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestManyCancellableTimerResetUnderLock(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ lock.Lock()
+ timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(shortDuration)
+ for i := 0; i < 10; i++ {
+ timer.StopLocked()
+ timer.Reset(shortDuration)
+ }
+ lock.Unlock()
+
+ // Wait for the last timer instance to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 9c40931b5..c7ce74cdd 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -350,13 +350,23 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
+// SetSockOptBool sets a socket option. Currently not supported.
+func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ return nil
+}
+
// SetSockOptInt sets a socket option. Currently not supported.
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ return false, tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 0010b5e5f..07ffa8aba 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -247,17 +247,17 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// used with SetSockOpt, and this function always returns
// tcpip.ErrNotSupported.
func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrNotSupported
+ return tcpip.ErrUnknownProtocolOption
}
-// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
-func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
+func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
- return 0, tcpip.ErrNotSupported
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
@@ -265,6 +265,16 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrNotSupported
}
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ return false, tcpip.ErrNotSupported
+}
+
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+ return 0, tcpip.ErrNotSupported
+}
+
// HandlePacket implements stack.PacketEndpoint.HandlePacket.
func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
ep.rcvMu.Lock()
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 5aafe2615..85f7eb76b 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -509,13 +509,38 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
+func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
-func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.KeepaliveEnabledOption:
+ *o = 0
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ return false, tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
@@ -544,21 +569,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
return -1, tcpip.ErrUnknownProtocolOption
}
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
- case tcpip.ErrorOption:
- return nil
-
- case *tcpip.KeepaliveEnabledOption:
- *o = 0
- return nil
-
- default:
- return tcpip.ErrUnknownProtocolOption
- }
-}
-
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
func (e *endpoint) HandlePacket(route *stack.Route, pkt tcpip.PacketBuffer) {
e.rcvMu.Lock()
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index dfaa4a559..4f361b226 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -391,9 +391,8 @@ func testV4Accept(t *testing.T, c *context.Context) {
// Make sure we get the same error when calling the original ep and the
// new one. This validates that v4-mapped endpoints are still able to
// query the V6Only flag, whereas pure v4 endpoints are not.
- var v tcpip.V6OnlyOption
- expected := c.EP.GetSockOpt(&v)
- if err := nep.GetSockOpt(&v); err != expected {
+ _, expected := c.EP.GetSockOptBool(tcpip.V6OnlyOption)
+ if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != expected {
t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected)
}
@@ -531,8 +530,7 @@ func TestV6AcceptOnV6(t *testing.T) {
// Make sure we can still query the v6 only status of the new endpoint,
// that is, that it is in fact a v6 socket.
- var v tcpip.V6OnlyOption
- if err := nep.GetSockOpt(&v); err != nil {
+ if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil {
t.Fatalf("GetSockOpt failed failed: %v", err)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index dd8b47cbe..920b24975 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1145,8 +1145,31 @@ func (e *endpoint) zeroReceiveWindow(scale uint8) bool {
return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0
}
+// SetSockOptBool sets a socket option.
+func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // We only allow this to be set when we're in the initial state.
+ if e.state != StateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.v6only = v
+ }
+
+ return nil
+}
+
// SetSockOptInt sets a socket option.
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
switch opt {
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
@@ -1256,19 +1279,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
case tcpip.BindToDeviceOption:
- e.mu.Lock()
- defer e.mu.Unlock()
- if v == "" {
- e.bindToDevice = 0
- return nil
- }
- for nicID, nic := range e.stack.NICInfo() {
- if nic.Name == string(v) {
- e.bindToDevice = nicID
- return nil
- }
+ id := tcpip.NICID(v)
+ if id != 0 && !e.stack.HasNIC(id) {
+ return tcpip.ErrUnknownDevice
}
- return tcpip.ErrUnknownDevice
+ e.mu.Lock()
+ e.bindToDevice = id
+ e.mu.Unlock()
+ return nil
case tcpip.QuickAckOption:
if v == 0 {
@@ -1289,23 +1307,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.notifyProtocolGoroutine(notifyMSSChanged)
return nil
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrInvalidEndpointState
- }
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- // We only allow this to be set when we're in the initial state.
- if e.state != StateInitial {
- return tcpip.ErrInvalidEndpointState
- }
-
- e.v6only = v != 0
- return nil
-
case tcpip.TTLOption:
e.mu.Lock()
e.ttl = uint8(v)
@@ -1446,8 +1447,27 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
return e.rcvBufUsed, nil
}
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ switch opt {
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return false, tcpip.ErrUnknownProtocolOption
+ }
+
+ e.mu.Lock()
+ v := e.v6only
+ e.mu.Unlock()
+
+ return v, nil
+ }
+
+ return false, tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
@@ -1525,12 +1545,8 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case *tcpip.BindToDeviceOption:
e.mu.RLock()
- defer e.mu.RUnlock()
- if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
- *o = tcpip.BindToDeviceOption(nic.Name)
- return nil
- }
- *o = ""
+ *o = tcpip.BindToDeviceOption(e.bindToDevice)
+ e.mu.RUnlock()
return nil
case *tcpip.QuickAckOption:
@@ -1540,22 +1556,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
- case *tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrUnknownProtocolOption
- }
-
- e.mu.Lock()
- v := e.v6only
- e.mu.Unlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
-
case *tcpip.TTLOption:
e.mu.Lock()
*o = tcpip.TTLOption(e.ttl)
@@ -1974,6 +1974,15 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
return nil
}
+ if e.state == StateInitial {
+ // If listen is called on an unbound socket, the socket is
+ // automatically bound to a random free port with the local
+ // address set to INADDR_ANY.
+ if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
+ return err
+ }
+ }
+
// Endpoint must be bound before it can transition to listen mode.
if e.state != StateBound {
e.stats.ReadErrors.InvalidEndpointState.Increment()
@@ -2033,6 +2042,10 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
e.mu.Lock()
defer e.mu.Unlock()
+ return e.bindLocked(addr)
+}
+
+func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// Don't allow binding once endpoint is not in the initial state
// anymore. This is because once the endpoint goes into a connected or
// listen state, it is already bound.
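
With bindLocked factored out of Bind, Listen on an unbound endpoint performs
the implicit bind itself. A minimal sketch of the caller-visible effect (the
helper below is illustrative, not part of the change):

package example

import "gvisor.dev/gvisor/pkg/tcpip"

// listenAnyPort relies on the new behavior: an unbound TCP endpoint passed to
// Listen is automatically bound to a random free port on the unspecified
// address, so an explicit Bind(tcpip.FullAddress{}) beforehand is no longer
// required.
func listenAnyPort(ep tcpip.Endpoint) *tcpip.Error {
	return ep.Listen(10)
}
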
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index e8fe4dab5..1aa0733d0 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -1083,12 +1083,12 @@ func TestTrafficClassV6(t *testing.T) {
func TestConnectBindToDevice(t *testing.T) {
for _, test := range []struct {
name string
- device string
+ device tcpip.NICID
want tcp.EndpointState
}{
- {"RightDevice", "nic1", tcp.StateEstablished},
- {"WrongDevice", "nic2", tcp.StateSynSent},
- {"AnyDevice", "", tcp.StateEstablished},
+ {"RightDevice", 1, tcp.StateEstablished},
+ {"WrongDevice", 2, tcp.StateSynSent},
+ {"AnyDevice", 0, tcp.StateEstablished},
} {
t.Run(test.name, func(t *testing.T) {
c := context.New(t, defaultMTU)
@@ -3794,46 +3794,41 @@ func TestBindToDeviceOption(t *testing.T) {
}
defer ep.Close()
- if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
- t.Errorf("CreateNamedNIC failed: %v", err)
- }
-
- // Make an nameless NIC.
- if err := s.CreateNIC(54321, loopback.New()); err != nil {
+ if err := s.CreateNIC(321, loopback.New()); err != nil {
t.Errorf("CreateNIC failed: %v", err)
}
- // strPtr is used instead of taking the address of string literals, which is
+ // nicIDPtr is used instead of taking the address of NICID literals, which is
// a compiler error.
- strPtr := func(s string) *string {
+ nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
return &s
}
testActions := []struct {
name string
- setBindToDevice *string
+ setBindToDevice *tcpip.NICID
setBindToDeviceError *tcpip.Error
getBindToDevice tcpip.BindToDeviceOption
}{
- {"GetDefaultValue", nil, nil, ""},
- {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
- {"BindToExistent", strPtr("my_device"), nil, "my_device"},
- {"UnbindToDevice", strPtr(""), nil, ""},
+ {"GetDefaultValue", nil, nil, 0},
+ {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
+ {"BindToExistent", nicIDPtr(321), nil, 321},
+ {"UnbindToDevice", nicIDPtr(0), nil, 0},
}
for _, testAction := range testActions {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
- t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
}
}
- bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
- if ep.GetSockOpt(&bindToDevice) != nil {
- t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ bindToDevice := tcpip.BindToDeviceOption(88888)
+ if err := ep.GetSockOpt(&bindToDevice); err != nil {
+ t.Errorf("GetSockOpt got %v, want %v", err, nil)
}
if got, want := bindToDevice, testAction.getBindToDevice; got != want {
- t.Errorf("bindToDevice got %q, want %q", got, want)
+ t.Errorf("bindToDevice got %d, want %d", got, want)
}
})
}
@@ -4027,12 +4022,12 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
switch network {
case "ipv4":
case "ipv6":
- if err := ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
- t.Fatalf("SetSockOpt(V6OnlyOption(1)) failed: %v", err)
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+ t.Fatalf("SetSockOpt(V6OnlyOption(true)) failed: %v", err)
}
case "dual":
- if err := ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil {
- t.Fatalf("SetSockOpt(V6OnlyOption(0)) failed: %v", err)
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil {
+ t.Fatalf("SetSockOpt(V6OnlyOption(false)) failed: %v", err)
}
default:
t.Fatalf("unknown network: '%s'", network)
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index b0a376eba..822907998 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -158,15 +158,17 @@ func New(t *testing.T, mtu uint32) *Context {
if testing.Verbose() {
wep = sniffer.New(ep)
}
- if err := s.CreateNamedNIC(1, "nic1", wep); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ opts := stack.NICOptions{Name: "nic1"}
+ if err := s.CreateNICWithOptions(1, wep, opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
}
wep2 := stack.LinkEndpoint(channel.New(1000, mtu, ""))
if testing.Verbose() {
wep2 = sniffer.New(channel.New(1000, mtu, ""))
}
- if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ opts2 := stack.NICOptions{Name: "nic2"}
+ if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil {
+ t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err)
}
if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil {
@@ -473,11 +475,7 @@ func (c *Context) CreateV6Endpoint(v6only bool) {
c.t.Fatalf("NewEndpoint failed: %v", err)
}
- var v tcpip.V6OnlyOption
- if v6only {
- v = 1
- }
- if err := c.EP.SetSockOpt(v); err != nil {
+ if err := c.EP.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil {
c.t.Fatalf("SetSockOpt failed failed: %v", err)
}
}
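
The same migration applies wherever CreateNamedNIC was used: a NIC's name is
now supplied through stack.NICOptions. A small sketch (the helper name is made
up, and the *tcpip.Error return is assumed from the surrounding usage):

package example

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// addNamedNIC creates NIC "id" with a human-readable name attached.
func addNamedNIC(s *stack.Stack, id tcpip.NICID, ep stack.LinkEndpoint) *tcpip.Error {
	return s.CreateNICWithOptions(id, ep, stack.NICOptions{Name: "nic1"})
}
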
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 1ac4705af..864dc8733 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -456,14 +456,9 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
-// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
- return nil
-}
-
-// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch v := opt.(type) {
+// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
+func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
@@ -478,8 +473,20 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- e.v6only = v != 0
+ e.v6only = v
+ }
+
+ return nil
+}
+
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ return nil
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
case tcpip.TTLOption:
e.mu.Lock()
e.ttl = uint8(v)
@@ -624,19 +631,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
case tcpip.BindToDeviceOption:
- e.mu.Lock()
- defer e.mu.Unlock()
- if v == "" {
- e.bindToDevice = 0
- return nil
- }
- for nicID, nic := range e.stack.NICInfo() {
- if nic.Name == string(v) {
- e.bindToDevice = nicID
- return nil
- }
+ id := tcpip.NICID(v)
+ if id != 0 && !e.stack.HasNIC(id) {
+ return tcpip.ErrUnknownDevice
}
- return tcpip.ErrUnknownDevice
+ e.mu.Lock()
+ e.bindToDevice = id
+ e.mu.Unlock()
+ return nil
case tcpip.BroadcastOption:
e.mu.Lock()
@@ -660,8 +662,27 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ switch opt {
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return false, tcpip.ErrUnknownProtocolOption
+ }
+
+ e.mu.Lock()
+ v := e.v6only
+ e.mu.Unlock()
+
+ return v, nil
+ }
+
+ return false, tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
@@ -695,22 +716,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrUnknownProtocolOption
- }
-
- e.mu.Lock()
- v := e.v6only
- e.mu.Unlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
-
case *tcpip.TTLOption:
e.mu.Lock()
*o = tcpip.TTLOption(e.ttl)
@@ -757,12 +762,8 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case *tcpip.BindToDeviceOption:
e.mu.RLock()
- defer e.mu.RUnlock()
- if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
- *o = tcpip.BindToDeviceOption(nic.Name)
- return nil
- }
- *o = tcpip.BindToDeviceOption("")
+ *o = tcpip.BindToDeviceOption(e.bindToDevice)
+ e.mu.RUnlock()
return nil
case *tcpip.KeepaliveEnabledOption:
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 7051a7a9c..0a82bc4fa 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -335,7 +335,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
c.createEndpoint(flow.sockProto())
if flow.isV6Only() {
- if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
+ if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
c.t.Fatalf("SetSockOpt failed: %v", err)
}
} else if flow.isBroadcast() {
@@ -508,46 +508,42 @@ func TestBindToDeviceOption(t *testing.T) {
}
defer ep.Close()
- if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
- t.Errorf("CreateNamedNIC failed: %v", err)
+ opts := stack.NICOptions{Name: "my_device"}
+ if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil {
+ t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
}
- // Make an nameless NIC.
- if err := s.CreateNIC(54321, loopback.New()); err != nil {
- t.Errorf("CreateNIC failed: %v", err)
- }
-
- // strPtr is used instead of taking the address of string literals, which is
+ // nicIDPtr is used instead of taking the address of NICID literals, which is
// a compiler error.
- strPtr := func(s string) *string {
+ nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
return &s
}
testActions := []struct {
name string
- setBindToDevice *string
+ setBindToDevice *tcpip.NICID
setBindToDeviceError *tcpip.Error
getBindToDevice tcpip.BindToDeviceOption
}{
- {"GetDefaultValue", nil, nil, ""},
- {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
- {"BindToExistent", strPtr("my_device"), nil, "my_device"},
- {"UnbindToDevice", strPtr(""), nil, ""},
+ {"GetDefaultValue", nil, nil, 0},
+ {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
+ {"BindToExistent", nicIDPtr(321), nil, 321},
+ {"UnbindToDevice", nicIDPtr(0), nil, 0},
}
for _, testAction := range testActions {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
- t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
}
}
- bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
- if ep.GetSockOpt(&bindToDevice) != nil {
- t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ bindToDevice := tcpip.BindToDeviceOption(88888)
+ if err := ep.GetSockOpt(&bindToDevice); err != nil {
+ t.Errorf("GetSockOpt got %v, want %v", err, nil)
}
if got, want := bindToDevice, testAction.getBindToDevice; got != want {
- t.Errorf("bindToDevice got %q, want %q", got, want)
+ t.Errorf("bindToDevice got %d, want %d", got, want)
}
})
}