diff options
Diffstat (limited to 'pkg')
140 files changed, 8454 insertions, 2857 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 9553f164d..716ff22d2 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -41,6 +41,7 @@ go_library( "poll.go", "prctl.go", "ptrace.go", + "rseq.go", "rusage.go", "sched.go", "seccomp.go", diff --git a/pkg/abi/linux/rseq.go b/pkg/abi/linux/rseq.go new file mode 100644 index 000000000..76253ba30 --- /dev/null +++ b/pkg/abi/linux/rseq.go @@ -0,0 +1,130 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +// Flags passed to rseq(2). +// +// Defined in include/uapi/linux/rseq.h. +const ( + // RSEQ_FLAG_UNREGISTER unregisters the current thread. + RSEQ_FLAG_UNREGISTER = 1 << 0 +) + +// Critical section flags used in RSeqCriticalSection.Flags and RSeq.Flags. +// +// Defined in include/uapi/linux/rseq.h. +const ( + // RSEQ_CS_FLAG_NO_RESTART_ON_PREEMPT inhibits restart on preemption. + RSEQ_CS_FLAG_NO_RESTART_ON_PREEMPT = 1 << 0 + + // RSEQ_CS_FLAG_NO_RESTART_ON_SIGNAL inhibits restart on signal + // delivery. + RSEQ_CS_FLAG_NO_RESTART_ON_SIGNAL = 1 << 1 + + // RSEQ_CS_FLAG_NO_RESTART_ON_MIGRATE inhibits restart on CPU + // migration. + RSEQ_CS_FLAG_NO_RESTART_ON_MIGRATE = 1 << 2 +) + +// RSeqCriticalSection describes a restartable sequences critical section. It +// is equivalent to struct rseq_cs, defined in include/uapi/linux/rseq.h. +// +// In userspace, this structure is always aligned to 32 bytes. +// +// +marshal +type RSeqCriticalSection struct { + // Version is the version of this structure. Version 0 is defined here. + Version uint32 + + // Flags are the critical section flags, defined above. + Flags uint32 + + // Start is the start address of the critical section. + Start uint64 + + // PostCommitOffset is the offset from Start of the first instruction + // outside of the critical section. + PostCommitOffset uint64 + + // Abort is the abort address. It must be outside the critical section, + // and the 4 bytes prior must match the abort signature. + Abort uint64 +} + +const ( + // SizeOfRSeqCriticalSection is the size of RSeqCriticalSection. + SizeOfRSeqCriticalSection = 32 + + // SizeOfRSeqSignature is the size of the signature immediately + // preceding RSeqCriticalSection.Abort. + SizeOfRSeqSignature = 4 +) + +// Special values for RSeq.CPUID, defined in include/uapi/linux/rseq.h. +const ( + // RSEQ_CPU_ID_UNINITIALIZED indicates that this thread has not + // performed rseq initialization. + RSEQ_CPU_ID_UNINITIALIZED = ^uint32(0) // -1 + + // RSEQ_CPU_ID_REGISTRATION_FAILED indicates that rseq initialization + // failed. + RSEQ_CPU_ID_REGISTRATION_FAILED = ^uint32(1) // -2 +) + +// RSeq is the thread-local restartable sequences config/status. It +// is equivalent to struct rseq, defined in include/uapi/linux/rseq.h. +// +// In userspace, this structure is always aligned to 32 bytes. +type RSeq struct { + // CPUIDStart contains the current CPU ID if rseq is initialized. + // + // This field should only be read by the thread which registered this + // structure, and must be read atomically. + CPUIDStart uint32 + + // CPUID contains the current CPU ID or one of the CPU ID special + // values defined above. + // + // This field should only be read by the thread which registered this + // structure, and must be read atomically. + CPUID uint32 + + // RSeqCriticalSection is a pointer to the current RSeqCriticalSection + // block, or NULL. It is reset to NULL by the kernel on restart or + // non-restarting preempt/signal. + // + // This field should only be written by the thread which registered + // this structure, and must be written atomically. + RSeqCriticalSection uint64 + + // Flags are the critical section flags that apply to all critical + // sections on this thread, defined above. + Flags uint32 +} + +const ( + // SizeOfRSeq is the size of RSeq. + // + // Note that RSeq is naively 24 bytes. However, it has 32-byte + // alignment, which in C increases sizeof to 32. That is the size that + // the Linux kernel uses. + SizeOfRSeq = 32 + + // AlignOfRSeq is the standard alignment of RSeq. + AlignOfRSeq = 32 + + // OffsetOfRSeqCriticalSection is the offset of RSeqCriticalSection in RSeq. + OffsetOfRSeqCriticalSection = 8 +) diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD index 0c5f50397..ca540363c 100644 --- a/pkg/fspath/BUILD +++ b/pkg/fspath/BUILD @@ -14,7 +14,6 @@ go_library( "fspath.go", ], importpath = "gvisor.dev/gvisor/pkg/fspath", - deps = ["//pkg/syserror"], ) go_test( @@ -25,5 +24,4 @@ go_test( "fspath_test.go", ], embed = [":fspath"], - deps = ["//pkg/syserror"], ) diff --git a/pkg/fspath/fspath.go b/pkg/fspath/fspath.go index f68752560..9fb3fee24 100644 --- a/pkg/fspath/fspath.go +++ b/pkg/fspath/fspath.go @@ -18,19 +18,17 @@ package fspath import ( "strings" - - "gvisor.dev/gvisor/pkg/syserror" ) const pathSep = '/' -// Parse parses a pathname as described by path_resolution(7). -func Parse(pathname string) (Path, error) { +// Parse parses a pathname as described by path_resolution(7), except that +// empty pathnames will be parsed successfully to a Path for which +// Path.Absolute == Path.Dir == Path.HasComponents() == false. (This is +// necessary to support AT_EMPTY_PATH.) +func Parse(pathname string) Path { if len(pathname) == 0 { - // "... POSIX decrees that an empty pathname must not be resolved - // successfully. Linux returns ENOENT in this case." - - // path_resolution(7) - return Path{}, syserror.ENOENT + return Path{} } // Skip leading path separators. i := 0 @@ -41,7 +39,7 @@ func Parse(pathname string) (Path, error) { return Path{ Absolute: true, Dir: true, - }, nil + } } } // Skip trailing path separators. This is required by Iterator.Next. This @@ -64,7 +62,7 @@ func Parse(pathname string) (Path, error) { }, Absolute: i != 0, Dir: j != len(pathname)-1, - }, nil + } } // Path contains the information contained in a pathname string. @@ -111,6 +109,12 @@ func (p Path) String() string { return b.String() } +// HasComponents returns true if p contains a non-zero number of path +// components. +func (p Path) HasComponents() bool { + return p.Begin.Ok() +} + // An Iterator represents either a path component in a Path or a terminal // iterator indicating that the end of the path has been reached. // diff --git a/pkg/fspath/fspath_test.go b/pkg/fspath/fspath_test.go index 215b35622..d5e9a549a 100644 --- a/pkg/fspath/fspath_test.go +++ b/pkg/fspath/fspath_test.go @@ -18,15 +18,10 @@ import ( "reflect" "strings" "testing" - - "gvisor.dev/gvisor/pkg/syserror" ) func TestParseIteratorPartialPathnames(t *testing.T) { - path, err := Parse("/foo//bar///baz////") - if err != nil { - t.Fatalf("Parse failed: %v", err) - } + path := Parse("/foo//bar///baz////") // Parse strips leading slashes, and records their presence as // Path.Absolute. if !path.Absolute { @@ -71,6 +66,12 @@ func TestParse(t *testing.T) { } tests := []testCase{ { + pathname: "", + relpath: []string{}, + abs: false, + dir: false, + }, + { pathname: "/", relpath: []string{}, abs: true, @@ -113,10 +114,7 @@ func TestParse(t *testing.T) { for _, test := range tests { t.Run(test.pathname, func(t *testing.T) { - p, err := Parse(test.pathname) - if err != nil { - t.Fatalf("failed to parse pathname %q: %v", test.pathname, err) - } + p := Parse(test.pathname) t.Logf("pathname %q => path %q", test.pathname, p) if p.Absolute != test.abs { t.Errorf("path absoluteness: got %v, wanted %v", p.Absolute, test.abs) @@ -134,10 +132,3 @@ func TestParse(t *testing.T) { }) } } - -func TestParseEmptyPathname(t *testing.T) { - p, err := Parse("") - if err != syserror.ENOENT { - t.Errorf("parsing empty pathname: got (%v, %v), wanted (<unspecified>, ENOENT)", p, err) - } -} diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go index 498ca4669..81ec98a77 100644 --- a/pkg/sentry/arch/arch.go +++ b/pkg/sentry/arch/arch.go @@ -125,9 +125,9 @@ type Context interface { // SetTLS sets the current TLS pointer. Returns false if value is invalid. SetTLS(value uintptr) bool - // SetRSEQInterruptedIP sets the register that contains the old IP when a - // restartable sequence is interrupted. - SetRSEQInterruptedIP(value uintptr) + // SetOldRSeqInterruptedIP sets the register that contains the old IP + // when an "old rseq" restartable sequence is interrupted. + SetOldRSeqInterruptedIP(value uintptr) // StateData returns a pointer to underlying architecture state. StateData() *State diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go index 67daa6c24..2aa08b1a9 100644 --- a/pkg/sentry/arch/arch_amd64.go +++ b/pkg/sentry/arch/arch_amd64.go @@ -174,8 +174,8 @@ func (c *context64) SetTLS(value uintptr) bool { return true } -// SetRSEQInterruptedIP implements Context.SetRSEQInterruptedIP. -func (c *context64) SetRSEQInterruptedIP(value uintptr) { +// SetOldRSeqInterruptedIP implements Context.SetOldRSeqInterruptedIP. +func (c *context64) SetOldRSeqInterruptedIP(value uintptr) { c.Regs.R10 = uint64(value) } diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go index c0a6e884b..a2f966cb6 100644 --- a/pkg/sentry/fs/file.go +++ b/pkg/sentry/fs/file.go @@ -555,6 +555,10 @@ type lockedWriter struct { // // This applies only to Write, not WriteAt. Offset int64 + + // Err contains the first error encountered while copying. This is + // useful to determine whether Writer or Reader failed during io.Copy. + Err error } // Write implements io.Writer.Write. @@ -590,5 +594,8 @@ func (w *lockedWriter) WriteAt(buf []byte, offset int64) (int, error) { break } } + if w.Err == nil { + w.Err = err + } return written, err } diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index b2e8d9c77..9ca695a95 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -53,7 +53,7 @@ go_template_instance( "Key": "uint64", "Range": "memmap.MappableRange", "Value": "uint64", - "Functions": "fileRangeSetFunctions", + "Functions": "FileRangeSetFunctions", }, ) diff --git a/pkg/sentry/fs/fsutil/file_range_set.go b/pkg/sentry/fs/fsutil/file_range_set.go index 0a5466b0a..f52d712e3 100644 --- a/pkg/sentry/fs/fsutil/file_range_set.go +++ b/pkg/sentry/fs/fsutil/file_range_set.go @@ -34,25 +34,25 @@ import ( // // type FileRangeSet <generated by go_generics> -// fileRangeSetFunctions implements segment.Functions for FileRangeSet. -type fileRangeSetFunctions struct{} +// FileRangeSetFunctions implements segment.Functions for FileRangeSet. +type FileRangeSetFunctions struct{} // MinKey implements segment.Functions.MinKey. -func (fileRangeSetFunctions) MinKey() uint64 { +func (FileRangeSetFunctions) MinKey() uint64 { return 0 } // MaxKey implements segment.Functions.MaxKey. -func (fileRangeSetFunctions) MaxKey() uint64 { +func (FileRangeSetFunctions) MaxKey() uint64 { return math.MaxUint64 } // ClearValue implements segment.Functions.ClearValue. -func (fileRangeSetFunctions) ClearValue(_ *uint64) { +func (FileRangeSetFunctions) ClearValue(_ *uint64) { } // Merge implements segment.Functions.Merge. -func (fileRangeSetFunctions) Merge(mr1 memmap.MappableRange, frstart1 uint64, _ memmap.MappableRange, frstart2 uint64) (uint64, bool) { +func (FileRangeSetFunctions) Merge(mr1 memmap.MappableRange, frstart1 uint64, _ memmap.MappableRange, frstart2 uint64) (uint64, bool) { if frstart1+mr1.Length() != frstart2 { return 0, false } @@ -60,7 +60,7 @@ func (fileRangeSetFunctions) Merge(mr1 memmap.MappableRange, frstart1 uint64, _ } // Split implements segment.Functions.Split. -func (fileRangeSetFunctions) Split(mr memmap.MappableRange, frstart uint64, split uint64) (uint64, uint64) { +func (FileRangeSetFunctions) Split(mr memmap.MappableRange, frstart uint64, split uint64) (uint64, uint64) { return frstart, frstart + (split - mr.Start) } diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 0e46c5fb7..9bf4b4527 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -604,6 +604,10 @@ func (s *statusData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ( fmt.Fprintf(&buf, "CapEff:\t%016x\n", creds.EffectiveCaps) fmt.Fprintf(&buf, "CapBnd:\t%016x\n", creds.BoundingCaps) fmt.Fprintf(&buf, "Seccomp:\t%d\n", s.t.SeccompMode()) + // We unconditionally report a single NUMA node. See + // pkg/sentry/syscalls/linux/sys_mempolicy.go. + fmt.Fprintf(&buf, "Mems_allowed:\t1\n") + fmt.Fprintf(&buf, "Mems_allowed_list:\t0\n") return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*statusData)(nil)}}, 0 } diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go index 311798811..389c330a0 100644 --- a/pkg/sentry/fs/splice.go +++ b/pkg/sentry/fs/splice.go @@ -167,6 +167,11 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, if !srcPipe && !opts.SrcOffset { atomic.StoreInt64(&src.offset, src.offset+n) } + + // Don't report any errors if we have some progress without data loss. + if w.Err == nil { + err = nil + } } // Drop locks. diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go index 934828c12..6b07f6bf2 100644 --- a/pkg/sentry/fs/tty/master.go +++ b/pkg/sentry/fs/tty/master.go @@ -77,7 +77,7 @@ func (mi *masterInodeOperations) Release(ctx context.Context) { } // Truncate implements fs.InodeOperations.Truncate. -func (masterInodeOperations) Truncate(context.Context, *fs.Inode, int64) error { +func (*masterInodeOperations) Truncate(context.Context, *fs.Inode, int64) error { return nil } diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index 880b7bcd3..bc90330bc 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -74,6 +74,7 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/binary", + "//pkg/fspath", "//pkg/sentry/context", "//pkg/sentry/context/contexttest", "//pkg/sentry/fsimpl/ext/disklayout", diff --git a/pkg/sentry/fsimpl/ext/benchmark/BUILD b/pkg/sentry/fsimpl/ext/benchmark/BUILD index bfc46dfa6..4fc8296ef 100644 --- a/pkg/sentry/fsimpl/ext/benchmark/BUILD +++ b/pkg/sentry/fsimpl/ext/benchmark/BUILD @@ -7,6 +7,7 @@ go_test( size = "small", srcs = ["benchmark_test.go"], deps = [ + "//pkg/fspath", "//pkg/sentry/context", "//pkg/sentry/context/contexttest", "//pkg/sentry/fsimpl/ext", diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go index 177ce2cb9..a56b03711 100644 --- a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go +++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go @@ -24,6 +24,7 @@ import ( "strings" "testing" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/context/contexttest" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext" @@ -49,7 +50,9 @@ func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesys // Create VFS. vfsObj := vfs.New() - vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{}) + vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + }) mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())}) if err != nil { f.Close() @@ -121,7 +124,7 @@ func BenchmarkVFS2Ext4fsStat(b *testing.B) { stat, err := vfsfs.StatAt(ctx, creds, &vfs.PathOperation{ Root: *root, Start: *root, - Pathname: filePath, + Path: fspath.Parse(filePath), FollowFinalSymlink: true, }, &vfs.StatOptions{}) if err != nil { @@ -150,9 +153,9 @@ func BenchmarkVFS2ExtfsMountStat(b *testing.B) { creds := auth.CredentialsFromContext(ctx) mountPointName := "/1/" pop := vfs.PathOperation{ - Root: *root, - Start: *root, - Pathname: mountPointName, + Root: *root, + Start: *root, + Path: fspath.Parse(mountPointName), } // Save the mount point for later use. @@ -181,7 +184,7 @@ func BenchmarkVFS2ExtfsMountStat(b *testing.B) { stat, err := vfsfs.StatAt(ctx, creds, &vfs.PathOperation{ Root: *root, Start: *root, - Pathname: filePath, + Path: fspath.Parse(filePath), FollowFinalSymlink: true, }, &vfs.StatOptions{}) if err != nil { diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent.go b/pkg/sentry/fsimpl/ext/disklayout/extent.go index 567523d32..4110649ab 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/extent.go +++ b/pkg/sentry/fsimpl/ext/disklayout/extent.go @@ -29,8 +29,12 @@ package disklayout // byte (i * sb.BlockSize()) to ((i+1) * sb.BlockSize()). const ( - // ExtentStructsSize is the size of all the three extent on-disk structs. - ExtentStructsSize = 12 + // ExtentHeaderSize is the size of the header of an extent tree node. + ExtentHeaderSize = 12 + + // ExtentEntrySize is the size of an entry in an extent tree node. + // This size is the same for both leaf and internal nodes. + ExtentEntrySize = 12 // ExtentMagic is the magic number which must be present in the header. ExtentMagic = 0xf30a @@ -57,7 +61,7 @@ type ExtentNode struct { Entries []ExtentEntryPair } -// ExtentEntry reprsents an extent tree node entry. The entry can either be +// ExtentEntry represents an extent tree node entry. The entry can either be // an ExtentIdx or Extent itself. This exists to simplify navigation logic. type ExtentEntry interface { // FileBlock returns the first file block number covered by this entry. diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go index b0fad9b71..8762b90db 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go @@ -21,7 +21,7 @@ import ( // TestExtentSize tests that the extent structs are of the correct // size. func TestExtentSize(t *testing.T) { - assertSize(t, ExtentHeader{}, ExtentStructsSize) - assertSize(t, ExtentIdx{}, ExtentStructsSize) - assertSize(t, Extent{}, ExtentStructsSize) + assertSize(t, ExtentHeader{}, ExtentHeaderSize) + assertSize(t, ExtentIdx{}, ExtentEntrySize) + assertSize(t, Extent{}, ExtentEntrySize) } diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go index e9f756732..6c14a1e2d 100644 --- a/pkg/sentry/fsimpl/ext/ext_test.go +++ b/pkg/sentry/fsimpl/ext/ext_test.go @@ -25,6 +25,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/context/contexttest" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" @@ -65,7 +66,9 @@ func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesys // Create VFS. vfsObj := vfs.New() - vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{}) + vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + }) mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())}) if err != nil { f.Close() @@ -140,7 +143,7 @@ func TestSeek(t *testing.T) { fd, err := vfsfs.OpenAt( ctx, auth.CredentialsFromContext(ctx), - &vfs.PathOperation{Root: *root, Start: *root, Pathname: test.path}, + &vfs.PathOperation{Root: *root, Start: *root, Path: fspath.Parse(test.path)}, &vfs.OpenOptions{}, ) if err != nil { @@ -359,7 +362,7 @@ func TestStatAt(t *testing.T) { got, err := vfsfs.StatAt(ctx, auth.CredentialsFromContext(ctx), - &vfs.PathOperation{Root: *root, Start: *root, Pathname: test.path}, + &vfs.PathOperation{Root: *root, Start: *root, Path: fspath.Parse(test.path)}, &vfs.StatOptions{}, ) if err != nil { @@ -429,7 +432,7 @@ func TestRead(t *testing.T) { fd, err := vfsfs.OpenAt( ctx, auth.CredentialsFromContext(ctx), - &vfs.PathOperation{Root: *root, Start: *root, Pathname: test.absPath}, + &vfs.PathOperation{Root: *root, Start: *root, Path: fspath.Parse(test.absPath)}, &vfs.OpenOptions{}, ) if err != nil { @@ -565,7 +568,7 @@ func TestIterDirents(t *testing.T) { fd, err := vfsfs.OpenAt( ctx, auth.CredentialsFromContext(ctx), - &vfs.PathOperation{Root: *root, Start: *root, Pathname: test.path}, + &vfs.PathOperation{Root: *root, Start: *root, Path: fspath.Parse(test.path)}, &vfs.OpenOptions{}, ) if err != nil { diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go index 3d3ebaca6..11dcc0346 100644 --- a/pkg/sentry/fsimpl/ext/extent_file.go +++ b/pkg/sentry/fsimpl/ext/extent_file.go @@ -57,7 +57,7 @@ func newExtentFile(regFile regularFile) (*extentFile, error) { func (f *extentFile) buildExtTree() error { rootNodeData := f.regFile.inode.diskInode.Data() - binary.Unmarshal(rootNodeData[:disklayout.ExtentStructsSize], binary.LittleEndian, &f.root.Header) + binary.Unmarshal(rootNodeData[:disklayout.ExtentHeaderSize], binary.LittleEndian, &f.root.Header) // Root node can not have more than 4 entries: 60 bytes = 1 header + 4 entries. if f.root.Header.NumEntries > 4 { @@ -67,7 +67,7 @@ func (f *extentFile) buildExtTree() error { } f.root.Entries = make([]disklayout.ExtentEntryPair, f.root.Header.NumEntries) - for i, off := uint16(0), disklayout.ExtentStructsSize; i < f.root.Header.NumEntries; i, off = i+1, off+disklayout.ExtentStructsSize { + for i, off := uint16(0), disklayout.ExtentEntrySize; i < f.root.Header.NumEntries; i, off = i+1, off+disklayout.ExtentEntrySize { var curEntry disklayout.ExtentEntry if f.root.Header.Height == 0 { // Leaf node. @@ -76,7 +76,7 @@ func (f *extentFile) buildExtTree() error { // Internal node. curEntry = &disklayout.ExtentIdx{} } - binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentStructsSize], binary.LittleEndian, curEntry) + binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentEntrySize], binary.LittleEndian, curEntry) f.root.Entries[i].Entry = curEntry } @@ -105,7 +105,7 @@ func (f *extentFile) buildExtTreeFromDisk(entry disklayout.ExtentEntry) (*diskla } entries := make([]disklayout.ExtentEntryPair, header.NumEntries) - for i, off := uint16(0), off+disklayout.ExtentStructsSize; i < header.NumEntries; i, off = i+1, off+disklayout.ExtentStructsSize { + for i, off := uint16(0), off+disklayout.ExtentEntrySize; i < header.NumEntries; i, off = i+1, off+disklayout.ExtentEntrySize { var curEntry disklayout.ExtentEntry if header.Height == 0 { // Leaf node. diff --git a/pkg/sentry/fsimpl/ext/file_description.go b/pkg/sentry/fsimpl/ext/file_description.go index 5eca2b83f..841274daf 100644 --- a/pkg/sentry/fsimpl/ext/file_description.go +++ b/pkg/sentry/fsimpl/ext/file_description.go @@ -26,13 +26,6 @@ import ( type fileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl - - // flags is the same as vfs.OpenOptions.Flags which are passed to - // vfs.FilesystemImpl.OpenAt. - // TODO(b/134676337): syscalls like read(2), write(2), fchmod(2), fchown(2), - // fgetxattr(2), ioctl(2), mmap(2) should fail with EBADF if O_PATH is set. - // Only close(2), fstat(2), fstatfs(2) should work. - flags uint32 } func (fd *fileDescription) filesystem() *filesystem { @@ -43,18 +36,6 @@ func (fd *fileDescription) inode() *inode { return fd.vfsfd.Dentry().Impl().(*dentry).inode } -// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags. -func (fd *fileDescription) StatusFlags(ctx context.Context) (uint32, error) { - return fd.flags, nil -} - -// SetStatusFlags implements vfs.FileDescriptionImpl.SetStatusFlags. -func (fd *fileDescription) SetStatusFlags(ctx context.Context, flags uint32) error { - // None of the flags settable by fcntl(F_SETFL) are supported, so this is a - // no-op. - return nil -} - // Stat implements vfs.FileDescriptionImpl.Stat. func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { var stat linux.Statx diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go index e7aa3b41b..616fc002a 100644 --- a/pkg/sentry/fsimpl/ext/filesystem.go +++ b/pkg/sentry/fsimpl/ext/filesystem.go @@ -275,6 +275,16 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op return vfsd, nil } +// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt. +func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) { + vfsd, inode, err := fs.walk(rp, true) + if err != nil { + return nil, err + } + inode.incRef() + return vfsd, nil +} + // OpenAt implements vfs.FilesystemImpl.OpenAt. func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) { vfsd, inode, err := fs.walk(rp, false) @@ -378,7 +388,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v } // RenameAt implements vfs.FilesystemImpl.RenameAt. -func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry, opts vfs.RenameOptions) error { +func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error { if rp.Done() { return syserror.ENOENT } @@ -443,6 +453,42 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return syserror.EROFS } +// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. +func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { + _, _, err := fs.walk(rp, false) + if err != nil { + return nil, err + } + return nil, syserror.ENOTSUP +} + +// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. +func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { + _, _, err := fs.walk(rp, false) + if err != nil { + return "", err + } + return "", syserror.ENOTSUP +} + +// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. +func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { + _, _, err := fs.walk(rp, false) + if err != nil { + return err + } + return syserror.ENOTSUP +} + +// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. +func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { + _, _, err := fs.walk(rp, false) + if err != nil { + return err + } + return syserror.ENOTSUP +} + // PrependPath implements vfs.FilesystemImpl.PrependPath. func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { fs.mu.RLock() diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go index 24249525c..8608805bf 100644 --- a/pkg/sentry/fsimpl/ext/inode.go +++ b/pkg/sentry/fsimpl/ext/inode.go @@ -157,10 +157,7 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v switch in.impl.(type) { case *regularFile: var fd regularFileFD - fd.flags = flags - mnt.IncRef() - vfsd.IncRef() - fd.vfsfd.Init(&fd, mnt, vfsd) + fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{}) return &fd.vfsfd, nil case *directory: // Can't open directories writably. This check is not necessary for a read @@ -169,10 +166,7 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v return nil, syserror.EISDIR } var fd directoryFD - fd.flags = flags - mnt.IncRef() - vfsd.IncRef() - fd.vfsfd.Init(&fd, mnt, vfsd) + fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{}) return &fd.vfsfd, nil case *symlink: if flags&linux.O_PATH == 0 { @@ -180,10 +174,7 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v return nil, syserror.ELOOP } var fd symlinkFD - fd.flags = flags - mnt.IncRef() - vfsd.IncRef() - fd.vfsfd.Init(&fd, mnt, vfsd) + fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{}) return &fd.vfsfd, nil default: panic(fmt.Sprintf("unknown inode type: %T", in.impl)) diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index 52596c090..39c03ee9d 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -25,6 +25,7 @@ go_library( "inode_impl_util.go", "kernfs.go", "slot_list.go", + "symlink.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs", visibility = ["//pkg/sentry:internal"], @@ -49,6 +50,7 @@ go_test( deps = [ ":kernfs", "//pkg/abi/linux", + "//pkg/fspath", "//pkg/sentry/context", "//pkg/sentry/context/contexttest", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go index 30c06baf0..606ca692d 100644 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -15,6 +15,8 @@ package kernfs import ( + "fmt" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -26,7 +28,10 @@ import ( // DynamicBytesFile implements kernfs.Inode and represents a read-only // file whose contents are backed by a vfs.DynamicBytesSource. // -// Must be initialized with Init before first use. +// Must be instantiated with NewDynamicBytesFile or initialized with Init +// before first use. +// +// +stateify savable type DynamicBytesFile struct { InodeAttrs InodeNoopRefCount @@ -36,9 +41,14 @@ type DynamicBytesFile struct { data vfs.DynamicBytesSource } -// Init intializes a dynamic bytes file. -func (f *DynamicBytesFile) Init(creds *auth.Credentials, ino uint64, data vfs.DynamicBytesSource) { - f.InodeAttrs.Init(creds, ino, linux.ModeRegular|0444) +var _ Inode = (*DynamicBytesFile)(nil) + +// Init initializes a dynamic bytes file. +func (f *DynamicBytesFile) Init(creds *auth.Credentials, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) { + if perm&^linux.PermissionsMask != 0 { + panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) + } + f.InodeAttrs.Init(creds, ino, linux.ModeRegular|perm) f.data = data } @@ -59,23 +69,21 @@ func (f *DynamicBytesFile) SetStat(*vfs.Filesystem, vfs.SetStatOptions) error { // DynamicBytesFile. // // Must be initialized with Init before first use. +// +// +stateify savable type DynamicBytesFD struct { vfs.FileDescriptionDefaultImpl vfs.DynamicBytesFileDescriptionImpl vfsfd vfs.FileDescription inode Inode - flags uint32 } // Init initializes a DynamicBytesFD. func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, flags uint32) { - m.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref. - d.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref. - fd.flags = flags fd.inode = d.Impl().(*Dentry).inode fd.SetDataSource(data) - fd.vfsfd.Init(fd, m, d) + fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{}) } // Seek implements vfs.FileDescriptionImpl.Seek. @@ -117,15 +125,3 @@ func (fd *DynamicBytesFD) SetStat(context.Context, vfs.SetStatOptions) error { // DynamicBytesFiles are immutable. return syserror.EPERM } - -// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags. -func (fd *DynamicBytesFD) StatusFlags(ctx context.Context) (uint32, error) { - return fd.flags, nil -} - -// SetStatusFlags implements vfs.FileDescriptionImpl.SetStatusFlags. -func (fd *DynamicBytesFD) SetStatusFlags(ctx context.Context, flags uint32) error { - // None of the flags settable by fcntl(F_SETFL) are supported, so this is a - // no-op. - return nil -} diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go index d6c18937a..bcf069b5f 100644 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -39,17 +39,13 @@ type GenericDirectoryFD struct { vfsfd vfs.FileDescription children *OrderedChildren - flags uint32 off int64 } // Init initializes a GenericDirectoryFD. func (fd *GenericDirectoryFD) Init(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, flags uint32) { - m.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref. - d.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref. fd.children = children - fd.flags = flags - fd.vfsfd.Init(fd, m, d) + fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{}) } // VFSFileDescription returns a pointer to the vfs.FileDescription representing @@ -156,7 +152,10 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent fd.off++ } - return nil + var err error + relOffset := fd.off - int64(len(fd.children.set)) - 2 + fd.off, err = fd.inode().IterDirents(ctx, cb, fd.off, relOffset) + return err } // Seek implements vfs.FileDecriptionImpl.Seek. @@ -180,18 +179,6 @@ func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int return offset, nil } -// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags. -func (fd *GenericDirectoryFD) StatusFlags(ctx context.Context) (uint32, error) { - return fd.flags, nil -} - -// SetStatusFlags implements vfs.FileDescriptionImpl.SetStatusFlags. -func (fd *GenericDirectoryFD) SetStatusFlags(ctx context.Context, flags uint32) error { - // None of the flags settable by fcntl(F_SETFL) are supported, so this is a - // no-op. - return nil -} - // Stat implements vfs.FileDescriptionImpl.Stat. func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { fs := fd.filesystem() diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index db486b6c1..79759e0fc 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -44,39 +44,37 @@ func (fs *Filesystem) stepExistingLocked(ctx context.Context, rp *vfs.ResolvingP return nil, err } afterSymlink: + name := rp.Component() + // Revalidation must be skipped if name is "." or ".."; d or its parent + // respectively can't be expected to transition from invalidated back to + // valid, so detecting invalidation and retrying would loop forever. This + // is consistent with Linux: fs/namei.c:walk_component() => lookup_fast() + // calls d_revalidate(), but walk_component() => handle_dots() does not. + if name == "." { + rp.Advance() + return vfsd, nil + } + if name == ".." { + nextVFSD, err := rp.ResolveParent(vfsd) + if err != nil { + return nil, err + } + rp.Advance() + return nextVFSD, nil + } d.dirMu.Lock() - nextVFSD, err := rp.ResolveComponent(vfsd) - d.dirMu.Unlock() + nextVFSD, err := rp.ResolveChild(vfsd, name) if err != nil { + d.dirMu.Unlock() return nil, err } - if nextVFSD != nil { - // Cached dentry exists, revalidate. - next := nextVFSD.Impl().(*Dentry) - if !next.inode.Valid(ctx) { - d.dirMu.Lock() - rp.VirtualFilesystem().ForceDeleteDentry(nextVFSD) - d.dirMu.Unlock() - fs.deferDecRef(nextVFSD) // Reference from Lookup. - nextVFSD = nil - } - } - if nextVFSD == nil { - // Dentry isn't cached; it either doesn't exist or failed - // revalidation. Attempt to resolve it via Lookup. - name := rp.Component() - var err error - nextVFSD, err = d.inode.Lookup(ctx, name) - // Reference on nextVFSD dropped by a corresponding Valid. - if err != nil { - return nil, err - } - d.InsertChild(name, nextVFSD) + next, err := fs.revalidateChildLocked(ctx, rp.VirtualFilesystem(), d, name, nextVFSD) + d.dirMu.Unlock() + if err != nil { + return nil, err } - next := nextVFSD.Impl().(*Dentry) - // Resolve any symlink at current path component. - if rp.ShouldFollowSymlink() && d.isSymlink() { + if rp.ShouldFollowSymlink() && next.isSymlink() { // TODO: VFS2 needs something extra for /proc/[pid]/fd/ "magic symlinks". target, err := next.inode.Readlink(ctx) if err != nil { @@ -89,7 +87,44 @@ afterSymlink: } rp.Advance() - return nextVFSD, nil + return &next.vfsd, nil +} + +// revalidateChildLocked must be called after a call to parent.vfsd.Child(name) +// or vfs.ResolvingPath.ResolveChild(name) returns childVFSD (which may be +// nil) to verify that the returned child (or lack thereof) is correct. +// +// Preconditions: Filesystem.mu must be locked for at least reading. +// parent.dirMu must be locked. parent.isDir(). name is not "." or "..". +// +// Postconditions: Caller must call fs.processDeferredDecRefs*. +func (fs *Filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *Dentry, name string, childVFSD *vfs.Dentry) (*Dentry, error) { + if childVFSD != nil { + // Cached dentry exists, revalidate. + child := childVFSD.Impl().(*Dentry) + if !child.inode.Valid(ctx) { + vfsObj.ForceDeleteDentry(childVFSD) + fs.deferDecRef(childVFSD) // Reference from Lookup. + childVFSD = nil + } + } + if childVFSD == nil { + // Dentry isn't cached; it either doesn't exist or failed + // revalidation. Attempt to resolve it via Lookup. + // + // FIXME(b/144498111): Inode.Lookup() should return *(kernfs.)Dentry, + // not *vfs.Dentry, since (kernfs.)Filesystem assumes that all dentries + // in the filesystem are (kernfs.)Dentry and performs vfs.DentryImpl + // casts accordingly. + var err error + childVFSD, err = parent.inode.Lookup(ctx, name) + if err != nil { + return nil, err + } + // Reference on childVFSD dropped by a corresponding Valid. + parent.insertChildLocked(name, childVFSD) + } + return childVFSD.Impl().(*Dentry), nil } // walkExistingLocked resolves rp to an existing file. @@ -242,6 +277,19 @@ func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op return vfsd, nil } +// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt. +func (fs *Filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) { + fs.mu.RLock() + defer fs.processDeferredDecRefs() + defer fs.mu.RUnlock() + vfsd, _, err := fs.walkParentDirLocked(ctx, rp) + if err != nil { + return nil, err + } + vfsd.IncRef() // Ownership transferred to caller. + return vfsd, nil +} + // LinkAt implements vfs.FilesystemImpl.LinkAt. func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error { if rp.Done() { @@ -459,40 +507,42 @@ func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st } // RenameAt implements vfs.FilesystemImpl.RenameAt. -func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry, opts vfs.RenameOptions) error { - noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0 - exchange := opts.Flags&linux.RENAME_EXCHANGE != 0 - whiteout := opts.Flags&linux.RENAME_WHITEOUT != 0 - if exchange && (noReplace || whiteout) { - // Can't specify RENAME_NOREPLACE or RENAME_WHITEOUT with RENAME_EXCHANGE. - return syserror.EINVAL - } - if exchange || whiteout { - // Exchange and Whiteout flags are not supported on kernfs. +func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error { + // Only RENAME_NOREPLACE is supported. + if opts.Flags&^linux.RENAME_NOREPLACE != 0 { return syserror.EINVAL } + noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0 fs.mu.Lock() defer fs.mu.Lock() + // Resolve the destination directory first to verify that it's on this + // Mount. + dstDirVFSD, dstDirInode, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return err + } mnt := rp.Mount() - if mnt != vd.Mount() { + if mnt != oldParentVD.Mount() { return syserror.EXDEV } - if err := mnt.CheckBeginWrite(); err != nil { return err } defer mnt.EndWrite() - dstDirVFSD, dstDirInode, err := fs.walkParentDirLocked(ctx, rp) + srcDirVFSD := oldParentVD.Dentry() + srcDir := srcDirVFSD.Impl().(*Dentry) + srcDir.dirMu.Lock() + src, err := fs.revalidateChildLocked(ctx, rp.VirtualFilesystem(), srcDir, oldName, srcDirVFSD.Child(oldName)) + srcDir.dirMu.Unlock() fs.processDeferredDecRefsLocked() if err != nil { return err } - - srcVFSD := vd.Dentry() - srcDirVFSD := srcVFSD.Parent() + srcVFSD := &src.vfsd // Can we remove the src dentry? if err := checkDeleteLocked(rp, srcVFSD); err != nil { @@ -683,6 +733,58 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return nil } +// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. +func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { + fs.mu.RLock() + _, _, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return nil, err + } + // kernfs currently does not support extended attributes. + return nil, syserror.ENOTSUP +} + +// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. +func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { + fs.mu.RLock() + _, _, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return "", err + } + // kernfs currently does not support extended attributes. + return "", syserror.ENOTSUP +} + +// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. +func (fs *Filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { + fs.mu.RLock() + _, _, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return err + } + // kernfs currently does not support extended attributes. + return syserror.ENOTSUP +} + +// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. +func (fs *Filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { + fs.mu.RLock() + _, _, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return err + } + // kernfs currently does not support extended attributes. + return syserror.ENOTSUP +} + // PrependPath implements vfs.FilesystemImpl.PrependPath. func (fs *Filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { fs.mu.RLock() diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 7b45b702a..752e0f659 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -139,6 +139,11 @@ func (*InodeNotDirectory) Lookup(ctx context.Context, name string) (*vfs.Dentry, panic("Lookup called on non-directory inode") } +// IterDirents implements Inode.IterDirents. +func (*InodeNotDirectory) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { + panic("IterDirents called on non-directory inode") +} + // Valid implements Inode.Valid. func (*InodeNotDirectory) Valid(context.Context) bool { return true @@ -156,6 +161,11 @@ func (*InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*vfs.Dent return nil, syserror.ENOENT } +// IterDirents implements Inode.IterDirents. +func (*InodeNoDynamicLookup) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { + return offset, nil +} + // Valid implements Inode.Valid. func (*InodeNoDynamicLookup) Valid(ctx context.Context) bool { return true @@ -490,3 +500,13 @@ func (o *OrderedChildren) nthLocked(i int64) *slot { } return nil } + +// InodeSymlink partially implements Inode interface for symlinks. +type InodeSymlink struct { + InodeNotDirectory +} + +// Open implements Inode.Open. +func (InodeSymlink) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) { + return nil, syserror.ELOOP +} diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index bb01c3d01..d69b299ae 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -239,14 +239,22 @@ func (d *Dentry) destroy() { // // Precondition: d must represent a directory inode. func (d *Dentry) InsertChild(name string, child *vfs.Dentry) { + d.dirMu.Lock() + d.insertChildLocked(name, child) + d.dirMu.Unlock() +} + +// insertChildLocked is equivalent to InsertChild, with additional +// preconditions. +// +// Precondition: d.dirMu must be locked. +func (d *Dentry) insertChildLocked(name string, child *vfs.Dentry) { if !d.isDir() { panic(fmt.Sprintf("InsertChild called on non-directory Dentry: %+v.", d)) } vfsDentry := d.VFSDentry() vfsDentry.IncRef() // DecRef in child's Dentry.destroy. - d.dirMu.Lock() vfsDentry.InsertChild(child, name) - d.dirMu.Unlock() } // The Inode interface maps filesystem-level operations that operate on paths to @@ -396,6 +404,15 @@ type inodeDynamicLookup interface { // Valid should return true if this inode is still valid, or needs to // be resolved again by a call to Lookup. Valid(ctx context.Context) bool + + // IterDirents is used to iterate over dynamically created entries. It invokes + // cb on each entry in the directory represented by the FileDescription. + // 'offset' is the offset for the entire IterDirents call, which may include + // results from the caller. 'relOffset' is the offset inside the entries + // returned by this IterDirents invocation. In other words, + // 'offset+relOffset+1' is the value that should be set in vfs.Dirent.NextOff, + // while 'relOffset' is the place where iteration should start from. + IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) } type inodeSymlink interface { diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index f78bb7b04..4b6b95f5f 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -24,6 +24,7 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/context/contexttest" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" @@ -58,7 +59,9 @@ func newTestSystem(t *testing.T, rootFn RootDentryFn) *TestSystem { ctx := contexttest.Context(t) creds := auth.CredentialsFromContext(ctx) v := vfs.New() - v.MustRegisterFilesystemType("testfs", &fsType{rootFn: rootFn}) + v.MustRegisterFilesystemType("testfs", &fsType{rootFn: rootFn}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + }) mns, err := v.NewMountNamespace(ctx, creds, "", "testfs", &vfs.GetFilesystemOptions{}) if err != nil { t.Fatalf("Failed to create testfs root mount: %v", err) @@ -82,9 +85,9 @@ func newTestSystem(t *testing.T, rootFn RootDentryFn) *TestSystem { // Precondition: path should be relative path. func (s *TestSystem) PathOpAtRoot(path string) vfs.PathOperation { return vfs.PathOperation{ - Root: s.root, - Start: s.root, - Pathname: path, + Root: s.root, + Start: s.root, + Path: fspath.Parse(path), } } @@ -132,7 +135,7 @@ type file struct { func (fs *filesystem) newFile(creds *auth.Credentials, content string) *kernfs.Dentry { f := &file{} f.content = content - f.DynamicBytesFile.Init(creds, fs.NextIno(), f) + f.DynamicBytesFile.Init(creds, fs.NextIno(), f, 0777) d := &kernfs.Dentry{} d.Init(f) diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go new file mode 100644 index 000000000..068063f4e --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/symlink.go @@ -0,0 +1,45 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernfs + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" +) + +type staticSymlink struct { + InodeAttrs + InodeNoopRefCount + InodeSymlink + + target string +} + +var _ Inode = (*staticSymlink)(nil) + +// NewStaticSymlink creates a new symlink file pointing to 'target'. +func NewStaticSymlink(creds *auth.Credentials, ino uint64, perm linux.FileMode, target string) *Dentry { + inode := &staticSymlink{target: target} + inode.Init(creds, ino, linux.ModeSymlink|perm) + + d := &Dentry{} + d.Init(inode) + return d +} + +func (s *staticSymlink) Readlink(_ context.Context) (string, error) { + return s.target, nil +} diff --git a/pkg/sentry/fsimpl/memfs/filesystem.go b/pkg/sentry/fsimpl/memfs/filesystem.go deleted file mode 100644 index 1f2a5122a..000000000 --- a/pkg/sentry/fsimpl/memfs/filesystem.go +++ /dev/null @@ -1,592 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package memfs - -import ( - "fmt" - "sync/atomic" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" -) - -// stepLocked resolves rp.Component() in parent directory vfsd. -// -// stepLocked is loosely analogous to fs/namei.c:walk_component(). -// -// Preconditions: filesystem.mu must be locked. !rp.Done(). inode == -// vfsd.Impl().(*dentry).inode. -func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode) (*vfs.Dentry, *inode, error) { - if !inode.isDir() { - return nil, nil, syserror.ENOTDIR - } - if err := inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil { - return nil, nil, err - } -afterSymlink: - nextVFSD, err := rp.ResolveComponent(vfsd) - if err != nil { - return nil, nil, err - } - if nextVFSD == nil { - // Since the Dentry tree is the sole source of truth for memfs, if it's - // not in the Dentry tree, it doesn't exist. - return nil, nil, syserror.ENOENT - } - nextInode := nextVFSD.Impl().(*dentry).inode - if symlink, ok := nextInode.impl.(*symlink); ok && rp.ShouldFollowSymlink() { - // TODO: symlink traversals update access time - if err := rp.HandleSymlink(symlink.target); err != nil { - return nil, nil, err - } - goto afterSymlink // don't check the current directory again - } - rp.Advance() - return nextVFSD, nextInode, nil -} - -// walkExistingLocked resolves rp to an existing file. -// -// walkExistingLocked is loosely analogous to Linux's -// fs/namei.c:path_lookupat(). -// -// Preconditions: filesystem.mu must be locked. -func walkExistingLocked(rp *vfs.ResolvingPath) (*vfs.Dentry, *inode, error) { - vfsd := rp.Start() - inode := vfsd.Impl().(*dentry).inode - for !rp.Done() { - var err error - vfsd, inode, err = stepLocked(rp, vfsd, inode) - if err != nil { - return nil, nil, err - } - } - if rp.MustBeDir() && !inode.isDir() { - return nil, nil, syserror.ENOTDIR - } - return vfsd, inode, nil -} - -// walkParentDirLocked resolves all but the last path component of rp to an -// existing directory. It does not check that the returned directory is -// searchable by the provider of rp. -// -// walkParentDirLocked is loosely analogous to Linux's -// fs/namei.c:path_parentat(). -// -// Preconditions: filesystem.mu must be locked. !rp.Done(). -func walkParentDirLocked(rp *vfs.ResolvingPath) (*vfs.Dentry, *inode, error) { - vfsd := rp.Start() - inode := vfsd.Impl().(*dentry).inode - for !rp.Final() { - var err error - vfsd, inode, err = stepLocked(rp, vfsd, inode) - if err != nil { - return nil, nil, err - } - } - if !inode.isDir() { - return nil, nil, syserror.ENOTDIR - } - return vfsd, inode, nil -} - -// checkCreateLocked checks that a file named rp.Component() may be created in -// directory parentVFSD, then returns rp.Component(). -// -// Preconditions: filesystem.mu must be locked. parentInode == -// parentVFSD.Impl().(*dentry).inode. parentInode.isDir() == true. -func checkCreateLocked(rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInode *inode) (string, error) { - if err := parentInode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil { - return "", err - } - pc := rp.Component() - if pc == "." || pc == ".." { - return "", syserror.EEXIST - } - childVFSD, err := rp.ResolveChild(parentVFSD, pc) - if err != nil { - return "", err - } - if childVFSD != nil { - return "", syserror.EEXIST - } - if parentVFSD.IsDisowned() { - return "", syserror.ENOENT - } - return pc, nil -} - -// checkDeleteLocked checks that the file represented by vfsd may be deleted. -func checkDeleteLocked(vfsd *vfs.Dentry) error { - parentVFSD := vfsd.Parent() - if parentVFSD == nil { - return syserror.EBUSY - } - if parentVFSD.IsDisowned() { - return syserror.ENOENT - } - return nil -} - -// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. -func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { - fs.mu.RLock() - defer fs.mu.RUnlock() - vfsd, inode, err := walkExistingLocked(rp) - if err != nil { - return nil, err - } - if opts.CheckSearchable { - if !inode.isDir() { - return nil, syserror.ENOTDIR - } - if err := inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil { - return nil, err - } - } - inode.incRef() - return vfsd, nil -} - -// LinkAt implements vfs.FilesystemImpl.LinkAt. -func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error { - if rp.Done() { - return syserror.EEXIST - } - fs.mu.Lock() - defer fs.mu.Unlock() - parentVFSD, parentInode, err := walkParentDirLocked(rp) - if err != nil { - return err - } - pc, err := checkCreateLocked(rp, parentVFSD, parentInode) - if err != nil { - return err - } - if rp.Mount() != vd.Mount() { - return syserror.EXDEV - } - if err := rp.Mount().CheckBeginWrite(); err != nil { - return err - } - defer rp.Mount().EndWrite() - d := vd.Dentry().Impl().(*dentry) - if d.inode.isDir() { - return syserror.EPERM - } - d.inode.incLinksLocked() - child := fs.newDentry(d.inode) - parentVFSD.InsertChild(&child.vfsd, pc) - parentInode.impl.(*directory).childList.PushBack(child) - return nil -} - -// MkdirAt implements vfs.FilesystemImpl.MkdirAt. -func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { - if rp.Done() { - return syserror.EEXIST - } - fs.mu.Lock() - defer fs.mu.Unlock() - parentVFSD, parentInode, err := walkParentDirLocked(rp) - if err != nil { - return err - } - pc, err := checkCreateLocked(rp, parentVFSD, parentInode) - if err != nil { - return err - } - if err := rp.Mount().CheckBeginWrite(); err != nil { - return err - } - defer rp.Mount().EndWrite() - child := fs.newDentry(fs.newDirectory(rp.Credentials(), opts.Mode)) - parentVFSD.InsertChild(&child.vfsd, pc) - parentInode.impl.(*directory).childList.PushBack(child) - parentInode.incLinksLocked() // from child's ".." - return nil -} - -// MknodAt implements vfs.FilesystemImpl.MknodAt. -func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { - if rp.Done() { - return syserror.EEXIST - } - fs.mu.Lock() - defer fs.mu.Unlock() - parentVFSD, parentInode, err := walkParentDirLocked(rp) - if err != nil { - return err - } - pc, err := checkCreateLocked(rp, parentVFSD, parentInode) - if err != nil { - return err - } - if err := rp.Mount().CheckBeginWrite(); err != nil { - return err - } - defer rp.Mount().EndWrite() - - switch opts.Mode.FileType() { - case 0: - // "Zero file type is equivalent to type S_IFREG." - mknod(2) - fallthrough - case linux.ModeRegular: - // TODO(b/138862511): Implement. - return syserror.EINVAL - - case linux.ModeNamedPipe: - child := fs.newDentry(fs.newNamedPipe(rp.Credentials(), opts.Mode)) - parentVFSD.InsertChild(&child.vfsd, pc) - parentInode.impl.(*directory).childList.PushBack(child) - return nil - - case linux.ModeSocket: - // TODO(b/138862511): Implement. - return syserror.EINVAL - - case linux.ModeCharacterDevice: - fallthrough - case linux.ModeBlockDevice: - // TODO(b/72101894): We don't support creating block or character - // devices at the moment. - // - // When we start supporting block and character devices, we'll - // need to check for CAP_MKNOD here. - return syserror.EPERM - - default: - // "EINVAL - mode requested creation of something other than a - // regular file, device special file, FIFO or socket." - mknod(2) - return syserror.EINVAL - } -} - -// OpenAt implements vfs.FilesystemImpl.OpenAt. -func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - // Filter out flags that are not supported by memfs. O_DIRECTORY and - // O_NOFOLLOW have no effect here (they're handled by VFS by setting - // appropriate bits in rp), but are returned by - // FileDescriptionImpl.StatusFlags(). O_NONBLOCK is supported only by - // pipes. - opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NONBLOCK - - if opts.Flags&linux.O_CREAT == 0 { - fs.mu.RLock() - defer fs.mu.RUnlock() - vfsd, inode, err := walkExistingLocked(rp) - if err != nil { - return nil, err - } - return inode.open(ctx, rp, vfsd, opts.Flags, false) - } - - mustCreate := opts.Flags&linux.O_EXCL != 0 - vfsd := rp.Start() - inode := vfsd.Impl().(*dentry).inode - fs.mu.Lock() - defer fs.mu.Unlock() - if rp.Done() { - if rp.MustBeDir() { - return nil, syserror.EISDIR - } - if mustCreate { - return nil, syserror.EEXIST - } - return inode.open(ctx, rp, vfsd, opts.Flags, false) - } -afterTrailingSymlink: - // Walk to the parent directory of the last path component. - for !rp.Final() { - var err error - vfsd, inode, err = stepLocked(rp, vfsd, inode) - if err != nil { - return nil, err - } - } - if !inode.isDir() { - return nil, syserror.ENOTDIR - } - // Check for search permission in the parent directory. - if err := inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil { - return nil, err - } - // Reject attempts to open directories with O_CREAT. - if rp.MustBeDir() { - return nil, syserror.EISDIR - } - pc := rp.Component() - if pc == "." || pc == ".." { - return nil, syserror.EISDIR - } - // Determine whether or not we need to create a file. - childVFSD, err := rp.ResolveChild(vfsd, pc) - if err != nil { - return nil, err - } - if childVFSD == nil { - // Already checked for searchability above; now check for writability. - if err := inode.checkPermissions(rp.Credentials(), vfs.MayWrite, true); err != nil { - return nil, err - } - if err := rp.Mount().CheckBeginWrite(); err != nil { - return nil, err - } - defer rp.Mount().EndWrite() - // Create and open the child. - childInode := fs.newRegularFile(rp.Credentials(), opts.Mode) - child := fs.newDentry(childInode) - vfsd.InsertChild(&child.vfsd, pc) - inode.impl.(*directory).childList.PushBack(child) - return childInode.open(ctx, rp, &child.vfsd, opts.Flags, true) - } - // Open existing file or follow symlink. - if mustCreate { - return nil, syserror.EEXIST - } - childInode := childVFSD.Impl().(*dentry).inode - if symlink, ok := childInode.impl.(*symlink); ok && rp.ShouldFollowSymlink() { - // TODO: symlink traversals update access time - if err := rp.HandleSymlink(symlink.target); err != nil { - return nil, err - } - // rp.Final() may no longer be true since we now need to resolve the - // symlink target. - goto afterTrailingSymlink - } - return childInode.open(ctx, rp, childVFSD, opts.Flags, false) -} - -func (i *inode) open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afterCreate bool) (*vfs.FileDescription, error) { - ats := vfs.AccessTypesForOpenFlags(flags) - if !afterCreate { - if err := i.checkPermissions(rp.Credentials(), ats, i.isDir()); err != nil { - return nil, err - } - } - mnt := rp.Mount() - switch impl := i.impl.(type) { - case *regularFile: - var fd regularFileFD - fd.flags = flags - fd.readable = vfs.MayReadFileWithOpenFlags(flags) - fd.writable = vfs.MayWriteFileWithOpenFlags(flags) - if fd.writable { - if err := mnt.CheckBeginWrite(); err != nil { - return nil, err - } - // mnt.EndWrite() is called by regularFileFD.Release(). - } - mnt.IncRef() - vfsd.IncRef() - fd.vfsfd.Init(&fd, mnt, vfsd) - if flags&linux.O_TRUNC != 0 { - impl.mu.Lock() - impl.data = impl.data[:0] - atomic.StoreInt64(&impl.dataLen, 0) - impl.mu.Unlock() - } - return &fd.vfsfd, nil - case *directory: - // Can't open directories writably. - if ats&vfs.MayWrite != 0 { - return nil, syserror.EISDIR - } - var fd directoryFD - mnt.IncRef() - vfsd.IncRef() - fd.vfsfd.Init(&fd, mnt, vfsd) - fd.flags = flags - return &fd.vfsfd, nil - case *symlink: - // Can't open symlinks without O_PATH (which is unimplemented). - return nil, syserror.ELOOP - case *namedPipe: - return newNamedPipeFD(ctx, impl, rp, vfsd, flags) - default: - panic(fmt.Sprintf("unknown inode type: %T", i.impl)) - } -} - -// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. -func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { - fs.mu.RLock() - _, inode, err := walkExistingLocked(rp) - fs.mu.RUnlock() - if err != nil { - return "", err - } - symlink, ok := inode.impl.(*symlink) - if !ok { - return "", syserror.EINVAL - } - return symlink.target, nil -} - -// RenameAt implements vfs.FilesystemImpl.RenameAt. -func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry, opts vfs.RenameOptions) error { - if rp.Done() { - return syserror.ENOENT - } - fs.mu.Lock() - defer fs.mu.Unlock() - parentVFSD, parentInode, err := walkParentDirLocked(rp) - if err != nil { - return err - } - _, err = checkCreateLocked(rp, parentVFSD, parentInode) - if err != nil { - return err - } - if err := rp.Mount().CheckBeginWrite(); err != nil { - return err - } - defer rp.Mount().EndWrite() - // TODO: actually implement RenameAt - return syserror.EPERM -} - -// RmdirAt implements vfs.FilesystemImpl.RmdirAt. -func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { - fs.mu.Lock() - defer fs.mu.Unlock() - vfsd, inode, err := walkExistingLocked(rp) - if err != nil { - return err - } - if err := rp.Mount().CheckBeginWrite(); err != nil { - return err - } - defer rp.Mount().EndWrite() - if err := checkDeleteLocked(vfsd); err != nil { - return err - } - if !inode.isDir() { - return syserror.ENOTDIR - } - if vfsd.HasChildren() { - return syserror.ENOTEMPTY - } - if err := rp.VirtualFilesystem().DeleteDentry(vfs.MountNamespaceFromContext(ctx), vfsd); err != nil { - return err - } - // Remove from parent directory's childList. - vfsd.Parent().Impl().(*dentry).inode.impl.(*directory).childList.Remove(vfsd.Impl().(*dentry)) - inode.decRef() - return nil -} - -// SetStatAt implements vfs.FilesystemImpl.SetStatAt. -func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { - fs.mu.RLock() - _, _, err := walkExistingLocked(rp) - fs.mu.RUnlock() - if err != nil { - return err - } - if opts.Stat.Mask == 0 { - return nil - } - // TODO: implement inode.setStat - return syserror.EPERM -} - -// StatAt implements vfs.FilesystemImpl.StatAt. -func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { - fs.mu.RLock() - _, inode, err := walkExistingLocked(rp) - fs.mu.RUnlock() - if err != nil { - return linux.Statx{}, err - } - var stat linux.Statx - inode.statTo(&stat) - return stat, nil -} - -// StatFSAt implements vfs.FilesystemImpl.StatFSAt. -func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { - fs.mu.RLock() - _, _, err := walkExistingLocked(rp) - fs.mu.RUnlock() - if err != nil { - return linux.Statfs{}, err - } - // TODO: actually implement statfs - return linux.Statfs{}, syserror.ENOSYS -} - -// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. -func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { - if rp.Done() { - return syserror.EEXIST - } - fs.mu.Lock() - defer fs.mu.Unlock() - parentVFSD, parentInode, err := walkParentDirLocked(rp) - if err != nil { - return err - } - pc, err := checkCreateLocked(rp, parentVFSD, parentInode) - if err != nil { - return err - } - if err := rp.Mount().CheckBeginWrite(); err != nil { - return err - } - defer rp.Mount().EndWrite() - child := fs.newDentry(fs.newSymlink(rp.Credentials(), target)) - parentVFSD.InsertChild(&child.vfsd, pc) - parentInode.impl.(*directory).childList.PushBack(child) - return nil -} - -// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt. -func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { - fs.mu.Lock() - defer fs.mu.Unlock() - vfsd, inode, err := walkExistingLocked(rp) - if err != nil { - return err - } - if err := rp.Mount().CheckBeginWrite(); err != nil { - return err - } - defer rp.Mount().EndWrite() - if err := checkDeleteLocked(vfsd); err != nil { - return err - } - if inode.isDir() { - return syserror.EISDIR - } - if err := rp.VirtualFilesystem().DeleteDentry(vfs.MountNamespaceFromContext(ctx), vfsd); err != nil { - return err - } - // Remove from parent directory's childList. - vfsd.Parent().Impl().(*dentry).inode.impl.(*directory).childList.Remove(vfsd.Impl().(*dentry)) - inode.decLinksLocked() - return nil -} - -// PrependPath implements vfs.FilesystemImpl.PrependPath. -func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { - fs.mu.RLock() - defer fs.mu.RUnlock() - return vfs.GenericPrependPath(vfsroot, vd, b) -} diff --git a/pkg/sentry/fsimpl/memfs/regular_file.go b/pkg/sentry/fsimpl/memfs/regular_file.go deleted file mode 100644 index b7f4853b3..000000000 --- a/pkg/sentry/fsimpl/memfs/regular_file.go +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package memfs - -import ( - "io" - "sync" - "sync/atomic" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" -) - -type regularFile struct { - inode inode - - mu sync.RWMutex - data []byte - // dataLen is len(data), but accessed using atomic memory operations to - // avoid locking in inode.stat(). - dataLen int64 -} - -func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode { - file := ®ularFile{} - file.inode.init(file, fs, creds, mode) - file.inode.nlink = 1 // from parent directory - return &file.inode -} - -type regularFileFD struct { - fileDescription - - // These are immutable. - readable bool - writable bool - - // off is the file offset. off is accessed using atomic memory operations. - // offMu serializes operations that may mutate off. - off int64 - offMu sync.Mutex -} - -// Release implements vfs.FileDescriptionImpl.Release. -func (fd *regularFileFD) Release() { - if fd.writable { - fd.vfsfd.VirtualDentry().Mount().EndWrite() - } -} - -// PRead implements vfs.FileDescriptionImpl.PRead. -func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { - if !fd.readable { - return 0, syserror.EINVAL - } - f := fd.inode().impl.(*regularFile) - f.mu.RLock() - if offset >= int64(len(f.data)) { - f.mu.RUnlock() - return 0, io.EOF - } - n, err := dst.CopyOut(ctx, f.data[offset:]) - f.mu.RUnlock() - return int64(n), err -} - -// Read implements vfs.FileDescriptionImpl.Read. -func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { - fd.offMu.Lock() - n, err := fd.PRead(ctx, dst, fd.off, opts) - fd.off += n - fd.offMu.Unlock() - return n, err -} - -// PWrite implements vfs.FileDescriptionImpl.PWrite. -func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { - if !fd.writable { - return 0, syserror.EINVAL - } - if offset < 0 { - return 0, syserror.EINVAL - } - srclen := src.NumBytes() - if srclen == 0 { - return 0, nil - } - f := fd.inode().impl.(*regularFile) - f.mu.Lock() - end := offset + srclen - if end < offset { - // Overflow. - f.mu.Unlock() - return 0, syserror.EFBIG - } - if end > f.dataLen { - f.data = append(f.data, make([]byte, end-f.dataLen)...) - atomic.StoreInt64(&f.dataLen, end) - } - n, err := src.CopyIn(ctx, f.data[offset:end]) - f.mu.Unlock() - return int64(n), err -} - -// Write implements vfs.FileDescriptionImpl.Write. -func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { - fd.offMu.Lock() - n, err := fd.PWrite(ctx, src, fd.off, opts) - fd.off += n - fd.offMu.Unlock() - return n, err -} - -// Seek implements vfs.FileDescriptionImpl.Seek. -func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { - fd.offMu.Lock() - defer fd.offMu.Unlock() - switch whence { - case linux.SEEK_SET: - // use offset as specified - case linux.SEEK_CUR: - offset += fd.off - case linux.SEEK_END: - offset += atomic.LoadInt64(&fd.inode().impl.(*regularFile).dataLen) - default: - return 0, syserror.EINVAL - } - if offset < 0 { - return 0, syserror.EINVAL - } - fd.off = offset - return offset, nil -} - -// Sync implements vfs.FileDescriptionImpl.Sync. -func (fd *regularFileFD) Sync(ctx context.Context) error { - return nil -} diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index ade6ac946..1f44b3217 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -6,15 +6,17 @@ package(licenses = ["notice"]) go_library( name = "proc", srcs = [ - "filesystems.go", + "filesystem.go", "loadavg.go", "meminfo.go", "mounts.go", "net.go", - "proc.go", "stat.go", "sys.go", "task.go", + "task_files.go", + "tasks.go", + "tasks_files.go", "version.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/proc", @@ -24,8 +26,10 @@ go_library( "//pkg/log", "//pkg/sentry/context", "//pkg/sentry/fs", + "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/inet", "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", "//pkg/sentry/limits", "//pkg/sentry/mm", "//pkg/sentry/socket", @@ -34,17 +38,40 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/sentry/vfs", + "//pkg/syserror", ], ) go_test( name = "proc_test", size = "small", - srcs = ["net_test.go"], + srcs = [ + "boot_test.go", + "net_test.go", + "tasks_test.go", + ], embed = [":proc"], deps = [ "//pkg/abi/linux", + "//pkg/cpuid", + "//pkg/fspath", + "//pkg/memutil", + "//pkg/sentry/context", "//pkg/sentry/context/contexttest", + "//pkg/sentry/fs", "//pkg/sentry/inet", + "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/sched", + "//pkg/sentry/limits", + "//pkg/sentry/loader", + "//pkg/sentry/pgalloc", + "//pkg/sentry/platform", + "//pkg/sentry/platform/kvm", + "//pkg/sentry/platform/ptrace", + "//pkg/sentry/time", + "//pkg/sentry/usermem", + "//pkg/sentry/vfs", + "//pkg/syserror", ], ) diff --git a/pkg/sentry/fsimpl/proc/boot_test.go b/pkg/sentry/fsimpl/proc/boot_test.go new file mode 100644 index 000000000..84a93ee56 --- /dev/null +++ b/pkg/sentry/fsimpl/proc/boot_test.go @@ -0,0 +1,149 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "flag" + "fmt" + "os" + "runtime" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/memutil" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/kernel/sched" + "gvisor.dev/gvisor/pkg/sentry/limits" + "gvisor.dev/gvisor/pkg/sentry/loader" + "gvisor.dev/gvisor/pkg/sentry/pgalloc" + "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/time" + + // Platforms are plugable. + _ "gvisor.dev/gvisor/pkg/sentry/platform/kvm" + _ "gvisor.dev/gvisor/pkg/sentry/platform/ptrace" +) + +var ( + platformFlag = flag.String("platform", "ptrace", "specify which platform to use") +) + +// boot initializes a new bare bones kernel for test. +func boot() (*kernel.Kernel, error) { + platformCtr, err := platform.Lookup(*platformFlag) + if err != nil { + return nil, fmt.Errorf("platform not found: %v", err) + } + deviceFile, err := platformCtr.OpenDevice() + if err != nil { + return nil, fmt.Errorf("creating platform: %v", err) + } + plat, err := platformCtr.New(deviceFile) + if err != nil { + return nil, fmt.Errorf("creating platform: %v", err) + } + + k := &kernel.Kernel{ + Platform: plat, + } + + mf, err := createMemoryFile() + if err != nil { + return nil, err + } + k.SetMemoryFile(mf) + + // Pass k as the platform since it is savable, unlike the actual platform. + vdso, err := loader.PrepareVDSO(nil, k) + if err != nil { + return nil, fmt.Errorf("creating vdso: %v", err) + } + + // Create timekeeper. + tk, err := kernel.NewTimekeeper(k, vdso.ParamPage.FileRange()) + if err != nil { + return nil, fmt.Errorf("creating timekeeper: %v", err) + } + tk.SetClocks(time.NewCalibratedClocks()) + + creds := auth.NewRootCredentials(auth.NewRootUserNamespace()) + + // Initiate the Kernel object, which is required by the Context passed + // to createVFS in order to mount (among other things) procfs. + if err = k.Init(kernel.InitKernelArgs{ + ApplicationCores: uint(runtime.GOMAXPROCS(-1)), + FeatureSet: cpuid.HostFeatureSet(), + Timekeeper: tk, + RootUserNamespace: creds.UserNamespace, + Vdso: vdso, + RootUTSNamespace: kernel.NewUTSNamespace("hostname", "domain", creds.UserNamespace), + RootIPCNamespace: kernel.NewIPCNamespace(creds.UserNamespace), + RootAbstractSocketNamespace: kernel.NewAbstractSocketNamespace(), + PIDNamespace: kernel.NewRootPIDNamespace(creds.UserNamespace), + }); err != nil { + return nil, fmt.Errorf("initializing kernel: %v", err) + } + + ctx := k.SupervisorContext() + + // Create mount namespace without root as it's the minimum required to create + // the global thread group. + mntns, err := fs.NewMountNamespace(ctx, nil) + if err != nil { + return nil, err + } + ls, err := limits.NewLinuxLimitSet() + if err != nil { + return nil, err + } + tg := k.NewThreadGroup(mntns, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, ls) + k.TestOnly_SetGlobalInit(tg) + + return k, nil +} + +// createTask creates a new bare bones task for tests. +func createTask(ctx context.Context, name string, tc *kernel.ThreadGroup) (*kernel.Task, error) { + k := kernel.KernelFromContext(ctx) + config := &kernel.TaskConfig{ + Kernel: k, + ThreadGroup: tc, + TaskContext: &kernel.TaskContext{Name: name}, + Credentials: auth.CredentialsFromContext(ctx), + AllowedCPUMask: sched.NewFullCPUSet(k.ApplicationCores()), + UTSNamespace: kernel.UTSNamespaceFromContext(ctx), + IPCNamespace: kernel.IPCNamespaceFromContext(ctx), + AbstractSocketNamespace: kernel.NewAbstractSocketNamespace(), + } + return k.TaskSet().NewTask(config) +} + +func createMemoryFile() (*pgalloc.MemoryFile, error) { + const memfileName = "test-memory" + memfd, err := memutil.CreateMemFD(memfileName, 0) + if err != nil { + return nil, fmt.Errorf("error creating memfd: %v", err) + } + memfile := os.NewFile(uintptr(memfd), memfileName) + mf, err := pgalloc.NewMemoryFile(memfile, pgalloc.MemoryFileOpts{}) + if err != nil { + memfile.Close() + return nil, fmt.Errorf("error creating pgalloc.MemoryFile: %v", err) + } + return mf, nil +} diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go new file mode 100644 index 000000000..d09182c77 --- /dev/null +++ b/pkg/sentry/fsimpl/proc/filesystem.go @@ -0,0 +1,69 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package proc implements a partial in-memory file system for procfs. +package proc + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" +) + +// procFSType is the factory class for procfs. +// +// +stateify savable +type procFSType struct{} + +var _ vfs.FilesystemType = (*procFSType)(nil) + +// GetFilesystem implements vfs.FilesystemType. +func (ft *procFSType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { + k := kernel.KernelFromContext(ctx) + if k == nil { + return nil, nil, fmt.Errorf("procfs requires a kernel") + } + pidns := kernel.PIDNamespaceFromContext(ctx) + if pidns == nil { + return nil, nil, fmt.Errorf("procfs requires a PID namespace") + } + + procfs := &kernfs.Filesystem{} + procfs.VFSFilesystem().Init(vfsObj, procfs) + + _, dentry := newTasksInode(procfs, k, pidns) + return procfs.VFSFilesystem(), dentry.VFSDentry(), nil +} + +// dynamicInode is an overfitted interface for common Inodes with +// dynamicByteSource types used in procfs. +type dynamicInode interface { + kernfs.Inode + vfs.DynamicBytesSource + + Init(creds *auth.Credentials, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) +} + +func newDentry(creds *auth.Credentials, ino uint64, perm linux.FileMode, inode dynamicInode) *kernfs.Dentry { + inode.Init(creds, ino, inode, perm) + + d := &kernfs.Dentry{} + d.Init(inode) + return d +} diff --git a/pkg/sentry/fsimpl/proc/loadavg.go b/pkg/sentry/fsimpl/proc/loadavg.go index 9135afef1..5351d86e8 100644 --- a/pkg/sentry/fsimpl/proc/loadavg.go +++ b/pkg/sentry/fsimpl/proc/loadavg.go @@ -19,15 +19,17 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" ) // loadavgData backs /proc/loadavg. // // +stateify savable -type loadavgData struct{} +type loadavgData struct { + kernfs.DynamicBytesFile +} -var _ vfs.DynamicBytesSource = (*loadavgData)(nil) +var _ dynamicInode = (*loadavgData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. func (d *loadavgData) Generate(ctx context.Context, buf *bytes.Buffer) error { diff --git a/pkg/sentry/fsimpl/proc/meminfo.go b/pkg/sentry/fsimpl/proc/meminfo.go index 9a827cd66..cbdd4f3fc 100644 --- a/pkg/sentry/fsimpl/proc/meminfo.go +++ b/pkg/sentry/fsimpl/proc/meminfo.go @@ -19,21 +19,23 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/sentry/vfs" ) // meminfoData implements vfs.DynamicBytesSource for /proc/meminfo. // // +stateify savable type meminfoData struct { + kernfs.DynamicBytesFile + // k is the owning Kernel. k *kernel.Kernel } -var _ vfs.DynamicBytesSource = (*meminfoData)(nil) +var _ dynamicInode = (*meminfoData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. func (d *meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error { diff --git a/pkg/sentry/fsimpl/proc/stat.go b/pkg/sentry/fsimpl/proc/stat.go index 720db3828..50894a534 100644 --- a/pkg/sentry/fsimpl/proc/stat.go +++ b/pkg/sentry/fsimpl/proc/stat.go @@ -20,8 +20,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/vfs" ) // cpuStats contains the breakdown of CPU time for /proc/stat. @@ -66,11 +66,13 @@ func (c cpuStats) String() string { // // +stateify savable type statData struct { + kernfs.DynamicBytesFile + // k is the owning Kernel. k *kernel.Kernel } -var _ vfs.DynamicBytesSource = (*statData)(nil) +var _ dynamicInode = (*statData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. func (s *statData) Generate(ctx context.Context, buf *bytes.Buffer) error { diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index c46e05c3a..11a64c777 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -15,247 +15,176 @@ package proc import ( - "bytes" - "fmt" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/limits" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" ) -// mapsCommon is embedded by mapsData and smapsData. -type mapsCommon struct { - t *kernel.Task -} - -// mm gets the kernel task's MemoryManager. No additional reference is taken on -// mm here. This is safe because MemoryManager.destroy is required to leave the -// MemoryManager in a state where it's still usable as a DynamicBytesSource. -func (md *mapsCommon) mm() *mm.MemoryManager { - var tmm *mm.MemoryManager - md.t.WithMuLocked(func(t *kernel.Task) { - if mm := t.MemoryManager(); mm != nil { - tmm = mm - } - }) - return tmm -} - -// mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps. +// taskInode represents the inode for /proc/PID/ directory. // // +stateify savable -type mapsData struct { - mapsCommon +type taskInode struct { + kernfs.InodeNotSymlink + kernfs.InodeDirectoryNoNewChildren + kernfs.InodeNoDynamicLookup + kernfs.InodeAttrs + kernfs.OrderedChildren + + task *kernel.Task } -var _ vfs.DynamicBytesSource = (*mapsData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (md *mapsData) Generate(ctx context.Context, buf *bytes.Buffer) error { - if mm := md.mm(); mm != nil { - mm.ReadMapsDataInto(ctx, buf) +var _ kernfs.Inode = (*taskInode)(nil) + +func newTaskInode(inoGen InoGenerator, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool) *kernfs.Dentry { + contents := map[string]*kernfs.Dentry{ + //"auxv": newAuxvec(t, msrc), + //"cmdline": newExecArgInode(t, msrc, cmdlineExecArg), + //"comm": newComm(t, msrc), + //"environ": newExecArgInode(t, msrc, environExecArg), + //"exe": newExe(t, msrc), + //"fd": newFdDir(t, msrc), + //"fdinfo": newFdInfoDir(t, msrc), + //"gid_map": newGIDMap(t, msrc), + "io": newTaskOwnedFile(task, inoGen.NextIno(), defaultPermission, newIO(task, isThreadGroup)), + "maps": newTaskOwnedFile(task, inoGen.NextIno(), defaultPermission, &mapsData{task: task}), + //"mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc), + //"mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc), + //"ns": newNamespaceDir(t, msrc), + "smaps": newTaskOwnedFile(task, inoGen.NextIno(), defaultPermission, &smapsData{task: task}), + "stat": newTaskOwnedFile(task, inoGen.NextIno(), defaultPermission, &taskStatData{t: task, pidns: pidns, tgstats: isThreadGroup}), + "statm": newTaskOwnedFile(task, inoGen.NextIno(), defaultPermission, &statmData{t: task}), + "status": newTaskOwnedFile(task, inoGen.NextIno(), defaultPermission, &statusData{t: task, pidns: pidns}), + //"uid_map": newUIDMap(t, msrc), } - return nil -} - -// smapsData implements vfs.DynamicBytesSource for /proc/[pid]/smaps. -// -// +stateify savable -type smapsData struct { - mapsCommon -} - -var _ vfs.DynamicBytesSource = (*smapsData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (sd *smapsData) Generate(ctx context.Context, buf *bytes.Buffer) error { - if mm := sd.mm(); mm != nil { - mm.ReadSmapsDataInto(ctx, buf) + if isThreadGroup { + //contents["task"] = p.newSubtasks(t, msrc) } - return nil -} - -// +stateify savable -type taskStatData struct { - t *kernel.Task + //if len(p.cgroupControllers) > 0 { + // contents["cgroup"] = newCGroupInode(t, msrc, p.cgroupControllers) + //} - // If tgstats is true, accumulate fault stats (not implemented) and CPU - // time across all tasks in t's thread group. - tgstats bool + taskInode := &taskInode{task: task} + // Note: credentials are overridden by taskOwnedInode. + taskInode.InodeAttrs.Init(task.Credentials(), inoGen.NextIno(), linux.ModeDirectory|0555) - // pidns is the PID namespace associated with the proc filesystem that - // includes the file using this statData. - pidns *kernel.PIDNamespace -} - -var _ vfs.DynamicBytesSource = (*taskStatData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (s *taskStatData) Generate(ctx context.Context, buf *bytes.Buffer) error { - fmt.Fprintf(buf, "%d ", s.pidns.IDOfTask(s.t)) - fmt.Fprintf(buf, "(%s) ", s.t.Name()) - fmt.Fprintf(buf, "%c ", s.t.StateStatus()[0]) - ppid := kernel.ThreadID(0) - if parent := s.t.Parent(); parent != nil { - ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup()) - } - fmt.Fprintf(buf, "%d ", ppid) - fmt.Fprintf(buf, "%d ", s.pidns.IDOfProcessGroup(s.t.ThreadGroup().ProcessGroup())) - fmt.Fprintf(buf, "%d ", s.pidns.IDOfSession(s.t.ThreadGroup().Session())) - fmt.Fprintf(buf, "0 0 " /* tty_nr tpgid */) - fmt.Fprintf(buf, "0 " /* flags */) - fmt.Fprintf(buf, "0 0 0 0 " /* minflt cminflt majflt cmajflt */) - var cputime usage.CPUStats - if s.tgstats { - cputime = s.t.ThreadGroup().CPUStats() - } else { - cputime = s.t.CPUStats() - } - fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime)) - cputime = s.t.ThreadGroup().JoinedChildCPUStats() - fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime)) - fmt.Fprintf(buf, "%d %d ", s.t.Priority(), s.t.Niceness()) - fmt.Fprintf(buf, "%d ", s.t.ThreadGroup().Count()) + inode := &taskOwnedInode{Inode: taskInode, owner: task} + dentry := &kernfs.Dentry{} + dentry.Init(inode) - // itrealvalue. Since kernel 2.6.17, this field is no longer - // maintained, and is hard coded as 0. - fmt.Fprintf(buf, "0 ") + taskInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + links := taskInode.OrderedChildren.Populate(dentry, contents) + taskInode.IncLinks(links) - // Start time is relative to boot time, expressed in clock ticks. - fmt.Fprintf(buf, "%d ", linux.ClockTFromDuration(s.t.StartTime().Sub(s.t.Kernel().Timekeeper().BootTime()))) + return dentry +} - var vss, rss uint64 - s.t.WithMuLocked(func(t *kernel.Task) { - if mm := t.MemoryManager(); mm != nil { - vss = mm.VirtualMemorySize() - rss = mm.ResidentSetSize() - } - }) - fmt.Fprintf(buf, "%d %d ", vss, rss/usermem.PageSize) +// Valid implements kernfs.inodeDynamicLookup. This inode remains valid as long +// as the task is still running. When it's dead, another tasks with the same +// PID could replace it. +func (i *taskInode) Valid(ctx context.Context) bool { + return i.task.ExitState() != kernel.TaskExitDead +} - // rsslim. - fmt.Fprintf(buf, "%d ", s.t.ThreadGroup().Limits().Get(limits.Rss).Cur) +// Open implements kernfs.Inode. +func (i *taskInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) { + fd := &kernfs.GenericDirectoryFD{} + fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, flags) + return fd.VFSFileDescription(), nil +} - fmt.Fprintf(buf, "0 0 0 0 0 " /* startcode endcode startstack kstkesp kstkeip */) - fmt.Fprintf(buf, "0 0 0 0 0 " /* signal blocked sigignore sigcatch wchan */) - fmt.Fprintf(buf, "0 0 " /* nswap cnswap */) - terminationSignal := linux.Signal(0) - if s.t == s.t.ThreadGroup().Leader() { - terminationSignal = s.t.ThreadGroup().TerminationSignal() +// SetStat implements kernfs.Inode. +func (i *taskInode) SetStat(_ *vfs.Filesystem, opts vfs.SetStatOptions) error { + stat := opts.Stat + if stat.Mask&linux.STATX_MODE != 0 { + return syserror.EPERM } - fmt.Fprintf(buf, "%d ", terminationSignal) - fmt.Fprintf(buf, "0 0 0 " /* processor rt_priority policy */) - fmt.Fprintf(buf, "0 0 0 " /* delayacct_blkio_ticks guest_time cguest_time */) - fmt.Fprintf(buf, "0 0 0 0 0 0 0 " /* start_data end_data start_brk arg_start arg_end env_start env_end */) - fmt.Fprintf(buf, "0\n" /* exit_code */) - return nil } -// statmData implements vfs.DynamicBytesSource for /proc/[pid]/statm. -// -// +stateify savable -type statmData struct { - t *kernel.Task +// taskOwnedInode implements kernfs.Inode and overrides inode owner with task +// effective user and group. +type taskOwnedInode struct { + kernfs.Inode + + // owner is the task that owns this inode. + owner *kernel.Task } -var _ vfs.DynamicBytesSource = (*statmData)(nil) +var _ kernfs.Inode = (*taskOwnedInode)(nil) -// Generate implements vfs.DynamicBytesSource.Generate. -func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error { - var vss, rss uint64 - s.t.WithMuLocked(func(t *kernel.Task) { - if mm := t.MemoryManager(); mm != nil { - vss = mm.VirtualMemorySize() - rss = mm.ResidentSetSize() - } - }) +func newTaskOwnedFile(task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) *kernfs.Dentry { + // Note: credentials are overridden by taskOwnedInode. + inode.Init(task.Credentials(), ino, inode, perm) - fmt.Fprintf(buf, "%d %d 0 0 0 0 0\n", vss/usermem.PageSize, rss/usermem.PageSize) - return nil + taskInode := &taskOwnedInode{Inode: inode, owner: task} + d := &kernfs.Dentry{} + d.Init(taskInode) + return d } -// statusData implements vfs.DynamicBytesSource for /proc/[pid]/status. -// -// +stateify savable -type statusData struct { - t *kernel.Task - pidns *kernel.PIDNamespace +// Stat implements kernfs.Inode. +func (i *taskOwnedInode) Stat(fs *vfs.Filesystem) linux.Statx { + stat := i.Inode.Stat(fs) + uid, gid := i.getOwner(linux.FileMode(stat.Mode)) + stat.UID = uint32(uid) + stat.GID = uint32(gid) + return stat } -var _ vfs.DynamicBytesSource = (*statusData)(nil) +// CheckPermissions implements kernfs.Inode. +func (i *taskOwnedInode) CheckPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error { + mode := i.Mode() + uid, gid := i.getOwner(mode) + return vfs.GenericCheckPermissions( + creds, + ats, + mode.FileType() == linux.ModeDirectory, + uint16(mode), + uid, + gid, + ) +} -// Generate implements vfs.DynamicBytesSource.Generate. -func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error { - fmt.Fprintf(buf, "Name:\t%s\n", s.t.Name()) - fmt.Fprintf(buf, "State:\t%s\n", s.t.StateStatus()) - fmt.Fprintf(buf, "Tgid:\t%d\n", s.pidns.IDOfThreadGroup(s.t.ThreadGroup())) - fmt.Fprintf(buf, "Pid:\t%d\n", s.pidns.IDOfTask(s.t)) - ppid := kernel.ThreadID(0) - if parent := s.t.Parent(); parent != nil { - ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup()) +func (i *taskOwnedInode) getOwner(mode linux.FileMode) (auth.KUID, auth.KGID) { + // By default, set the task owner as the file owner. + creds := i.owner.Credentials() + uid := creds.EffectiveKUID + gid := creds.EffectiveKGID + + // Linux doesn't apply dumpability adjustments to world readable/executable + // directories so that applications can stat /proc/PID to determine the + // effective UID of a process. See fs/proc/base.c:task_dump_owner. + if mode.FileType() == linux.ModeDirectory && mode.Permissions() == 0555 { + return uid, gid } - fmt.Fprintf(buf, "PPid:\t%d\n", ppid) - tpid := kernel.ThreadID(0) - if tracer := s.t.Tracer(); tracer != nil { - tpid = s.pidns.IDOfTask(tracer) + + // If the task is not dumpable, then root (in the namespace preferred) + // owns the file. + m := getMM(i.owner) + if m == nil { + return auth.RootKUID, auth.RootKGID } - fmt.Fprintf(buf, "TracerPid:\t%d\n", tpid) - var fds int - var vss, rss, data uint64 - s.t.WithMuLocked(func(t *kernel.Task) { - if fdTable := t.FDTable(); fdTable != nil { - fds = fdTable.Size() + if m.Dumpability() != mm.UserDumpable { + uid = auth.RootKUID + if kuid := creds.UserNamespace.MapToKUID(auth.RootUID); kuid.Ok() { + uid = kuid } - if mm := t.MemoryManager(); mm != nil { - vss = mm.VirtualMemorySize() - rss = mm.ResidentSetSize() - data = mm.VirtualDataSize() + gid = auth.RootKGID + if kgid := creds.UserNamespace.MapToKGID(auth.RootGID); kgid.Ok() { + gid = kgid } - }) - fmt.Fprintf(buf, "FDSize:\t%d\n", fds) - fmt.Fprintf(buf, "VmSize:\t%d kB\n", vss>>10) - fmt.Fprintf(buf, "VmRSS:\t%d kB\n", rss>>10) - fmt.Fprintf(buf, "VmData:\t%d kB\n", data>>10) - fmt.Fprintf(buf, "Threads:\t%d\n", s.t.ThreadGroup().Count()) - creds := s.t.Credentials() - fmt.Fprintf(buf, "CapInh:\t%016x\n", creds.InheritableCaps) - fmt.Fprintf(buf, "CapPrm:\t%016x\n", creds.PermittedCaps) - fmt.Fprintf(buf, "CapEff:\t%016x\n", creds.EffectiveCaps) - fmt.Fprintf(buf, "CapBnd:\t%016x\n", creds.BoundingCaps) - fmt.Fprintf(buf, "Seccomp:\t%d\n", s.t.SeccompMode()) - return nil -} - -// ioUsage is the /proc/<pid>/io and /proc/<pid>/task/<tid>/io data provider. -type ioUsage interface { - // IOUsage returns the io usage data. - IOUsage() *usage.IO -} - -// +stateify savable -type ioData struct { - ioUsage + } + return uid, gid } -var _ vfs.DynamicBytesSource = (*ioData)(nil) - -// Generate implements vfs.DynamicBytesSource.Generate. -func (i *ioData) Generate(ctx context.Context, buf *bytes.Buffer) error { - io := usage.IO{} - io.Accumulate(i.IOUsage()) - - fmt.Fprintf(buf, "char: %d\n", io.CharsRead) - fmt.Fprintf(buf, "wchar: %d\n", io.CharsWritten) - fmt.Fprintf(buf, "syscr: %d\n", io.ReadSyscalls) - fmt.Fprintf(buf, "syscw: %d\n", io.WriteSyscalls) - fmt.Fprintf(buf, "read_bytes: %d\n", io.BytesRead) - fmt.Fprintf(buf, "write_bytes: %d\n", io.BytesWritten) - fmt.Fprintf(buf, "cancelled_write_bytes: %d\n", io.BytesWriteCancelled) - return nil +func newIO(t *kernel.Task, isThreadGroup bool) *ioData { + if isThreadGroup { + return &ioData{ioUsage: t.ThreadGroup()} + } + return &ioData{ioUsage: t} } diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go new file mode 100644 index 000000000..93f0e1aa8 --- /dev/null +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -0,0 +1,272 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "bytes" + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/limits" + "gvisor.dev/gvisor/pkg/sentry/mm" + "gvisor.dev/gvisor/pkg/sentry/usage" + "gvisor.dev/gvisor/pkg/sentry/usermem" +) + +// mm gets the kernel task's MemoryManager. No additional reference is taken on +// mm here. This is safe because MemoryManager.destroy is required to leave the +// MemoryManager in a state where it's still usable as a DynamicBytesSource. +func getMM(task *kernel.Task) *mm.MemoryManager { + var tmm *mm.MemoryManager + task.WithMuLocked(func(t *kernel.Task) { + if mm := t.MemoryManager(); mm != nil { + tmm = mm + } + }) + return tmm +} + +// mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps. +// +// +stateify savable +type mapsData struct { + kernfs.DynamicBytesFile + + task *kernel.Task +} + +var _ dynamicInode = (*mapsData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *mapsData) Generate(ctx context.Context, buf *bytes.Buffer) error { + if mm := getMM(d.task); mm != nil { + mm.ReadMapsDataInto(ctx, buf) + } + return nil +} + +// smapsData implements vfs.DynamicBytesSource for /proc/[pid]/smaps. +// +// +stateify savable +type smapsData struct { + kernfs.DynamicBytesFile + + task *kernel.Task +} + +var _ dynamicInode = (*smapsData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *smapsData) Generate(ctx context.Context, buf *bytes.Buffer) error { + if mm := getMM(d.task); mm != nil { + mm.ReadSmapsDataInto(ctx, buf) + } + return nil +} + +// +stateify savable +type taskStatData struct { + kernfs.DynamicBytesFile + + t *kernel.Task + + // If tgstats is true, accumulate fault stats (not implemented) and CPU + // time across all tasks in t's thread group. + tgstats bool + + // pidns is the PID namespace associated with the proc filesystem that + // includes the file using this statData. + pidns *kernel.PIDNamespace +} + +var _ dynamicInode = (*taskStatData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (s *taskStatData) Generate(ctx context.Context, buf *bytes.Buffer) error { + fmt.Fprintf(buf, "%d ", s.pidns.IDOfTask(s.t)) + fmt.Fprintf(buf, "(%s) ", s.t.Name()) + fmt.Fprintf(buf, "%c ", s.t.StateStatus()[0]) + ppid := kernel.ThreadID(0) + if parent := s.t.Parent(); parent != nil { + ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup()) + } + fmt.Fprintf(buf, "%d ", ppid) + fmt.Fprintf(buf, "%d ", s.pidns.IDOfProcessGroup(s.t.ThreadGroup().ProcessGroup())) + fmt.Fprintf(buf, "%d ", s.pidns.IDOfSession(s.t.ThreadGroup().Session())) + fmt.Fprintf(buf, "0 0 " /* tty_nr tpgid */) + fmt.Fprintf(buf, "0 " /* flags */) + fmt.Fprintf(buf, "0 0 0 0 " /* minflt cminflt majflt cmajflt */) + var cputime usage.CPUStats + if s.tgstats { + cputime = s.t.ThreadGroup().CPUStats() + } else { + cputime = s.t.CPUStats() + } + fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime)) + cputime = s.t.ThreadGroup().JoinedChildCPUStats() + fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime)) + fmt.Fprintf(buf, "%d %d ", s.t.Priority(), s.t.Niceness()) + fmt.Fprintf(buf, "%d ", s.t.ThreadGroup().Count()) + + // itrealvalue. Since kernel 2.6.17, this field is no longer + // maintained, and is hard coded as 0. + fmt.Fprintf(buf, "0 ") + + // Start time is relative to boot time, expressed in clock ticks. + fmt.Fprintf(buf, "%d ", linux.ClockTFromDuration(s.t.StartTime().Sub(s.t.Kernel().Timekeeper().BootTime()))) + + var vss, rss uint64 + s.t.WithMuLocked(func(t *kernel.Task) { + if mm := t.MemoryManager(); mm != nil { + vss = mm.VirtualMemorySize() + rss = mm.ResidentSetSize() + } + }) + fmt.Fprintf(buf, "%d %d ", vss, rss/usermem.PageSize) + + // rsslim. + fmt.Fprintf(buf, "%d ", s.t.ThreadGroup().Limits().Get(limits.Rss).Cur) + + fmt.Fprintf(buf, "0 0 0 0 0 " /* startcode endcode startstack kstkesp kstkeip */) + fmt.Fprintf(buf, "0 0 0 0 0 " /* signal blocked sigignore sigcatch wchan */) + fmt.Fprintf(buf, "0 0 " /* nswap cnswap */) + terminationSignal := linux.Signal(0) + if s.t == s.t.ThreadGroup().Leader() { + terminationSignal = s.t.ThreadGroup().TerminationSignal() + } + fmt.Fprintf(buf, "%d ", terminationSignal) + fmt.Fprintf(buf, "0 0 0 " /* processor rt_priority policy */) + fmt.Fprintf(buf, "0 0 0 " /* delayacct_blkio_ticks guest_time cguest_time */) + fmt.Fprintf(buf, "0 0 0 0 0 0 0 " /* start_data end_data start_brk arg_start arg_end env_start env_end */) + fmt.Fprintf(buf, "0\n" /* exit_code */) + + return nil +} + +// statmData implements vfs.DynamicBytesSource for /proc/[pid]/statm. +// +// +stateify savable +type statmData struct { + kernfs.DynamicBytesFile + + t *kernel.Task +} + +var _ dynamicInode = (*statmData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error { + var vss, rss uint64 + s.t.WithMuLocked(func(t *kernel.Task) { + if mm := t.MemoryManager(); mm != nil { + vss = mm.VirtualMemorySize() + rss = mm.ResidentSetSize() + } + }) + + fmt.Fprintf(buf, "%d %d 0 0 0 0 0\n", vss/usermem.PageSize, rss/usermem.PageSize) + return nil +} + +// statusData implements vfs.DynamicBytesSource for /proc/[pid]/status. +// +// +stateify savable +type statusData struct { + kernfs.DynamicBytesFile + + t *kernel.Task + pidns *kernel.PIDNamespace +} + +var _ dynamicInode = (*statusData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error { + fmt.Fprintf(buf, "Name:\t%s\n", s.t.Name()) + fmt.Fprintf(buf, "State:\t%s\n", s.t.StateStatus()) + fmt.Fprintf(buf, "Tgid:\t%d\n", s.pidns.IDOfThreadGroup(s.t.ThreadGroup())) + fmt.Fprintf(buf, "Pid:\t%d\n", s.pidns.IDOfTask(s.t)) + ppid := kernel.ThreadID(0) + if parent := s.t.Parent(); parent != nil { + ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup()) + } + fmt.Fprintf(buf, "PPid:\t%d\n", ppid) + tpid := kernel.ThreadID(0) + if tracer := s.t.Tracer(); tracer != nil { + tpid = s.pidns.IDOfTask(tracer) + } + fmt.Fprintf(buf, "TracerPid:\t%d\n", tpid) + var fds int + var vss, rss, data uint64 + s.t.WithMuLocked(func(t *kernel.Task) { + if fdTable := t.FDTable(); fdTable != nil { + fds = fdTable.Size() + } + if mm := t.MemoryManager(); mm != nil { + vss = mm.VirtualMemorySize() + rss = mm.ResidentSetSize() + data = mm.VirtualDataSize() + } + }) + fmt.Fprintf(buf, "FDSize:\t%d\n", fds) + fmt.Fprintf(buf, "VmSize:\t%d kB\n", vss>>10) + fmt.Fprintf(buf, "VmRSS:\t%d kB\n", rss>>10) + fmt.Fprintf(buf, "VmData:\t%d kB\n", data>>10) + fmt.Fprintf(buf, "Threads:\t%d\n", s.t.ThreadGroup().Count()) + creds := s.t.Credentials() + fmt.Fprintf(buf, "CapInh:\t%016x\n", creds.InheritableCaps) + fmt.Fprintf(buf, "CapPrm:\t%016x\n", creds.PermittedCaps) + fmt.Fprintf(buf, "CapEff:\t%016x\n", creds.EffectiveCaps) + fmt.Fprintf(buf, "CapBnd:\t%016x\n", creds.BoundingCaps) + fmt.Fprintf(buf, "Seccomp:\t%d\n", s.t.SeccompMode()) + // We unconditionally report a single NUMA node. See + // pkg/sentry/syscalls/linux/sys_mempolicy.go. + fmt.Fprintf(buf, "Mems_allowed:\t1\n") + fmt.Fprintf(buf, "Mems_allowed_list:\t0\n") + return nil +} + +// ioUsage is the /proc/<pid>/io and /proc/<pid>/task/<tid>/io data provider. +type ioUsage interface { + // IOUsage returns the io usage data. + IOUsage() *usage.IO +} + +// +stateify savable +type ioData struct { + kernfs.DynamicBytesFile + + ioUsage +} + +var _ dynamicInode = (*ioData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (i *ioData) Generate(ctx context.Context, buf *bytes.Buffer) error { + io := usage.IO{} + io.Accumulate(i.IOUsage()) + + fmt.Fprintf(buf, "char: %d\n", io.CharsRead) + fmt.Fprintf(buf, "wchar: %d\n", io.CharsWritten) + fmt.Fprintf(buf, "syscr: %d\n", io.ReadSyscalls) + fmt.Fprintf(buf, "syscw: %d\n", io.WriteSyscalls) + fmt.Fprintf(buf, "read_bytes: %d\n", io.BytesRead) + fmt.Fprintf(buf, "write_bytes: %d\n", io.BytesWritten) + fmt.Fprintf(buf, "cancelled_write_bytes: %d\n", io.BytesWriteCancelled) + return nil +} diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go new file mode 100644 index 000000000..d8f92d52f --- /dev/null +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -0,0 +1,218 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "sort" + "strconv" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +const ( + defaultPermission = 0444 + selfName = "self" + threadSelfName = "thread-self" +) + +// InoGenerator generates unique inode numbers for a given filesystem. +type InoGenerator interface { + NextIno() uint64 +} + +// tasksInode represents the inode for /proc/ directory. +// +// +stateify savable +type tasksInode struct { + kernfs.InodeNotSymlink + kernfs.InodeDirectoryNoNewChildren + kernfs.InodeAttrs + kernfs.OrderedChildren + + inoGen InoGenerator + pidns *kernel.PIDNamespace + + // '/proc/self' and '/proc/thread-self' have custom directory offsets in + // Linux. So handle them outside of OrderedChildren. + selfSymlink *vfs.Dentry + threadSelfSymlink *vfs.Dentry +} + +var _ kernfs.Inode = (*tasksInode)(nil) + +func newTasksInode(inoGen InoGenerator, k *kernel.Kernel, pidns *kernel.PIDNamespace) (*tasksInode, *kernfs.Dentry) { + root := auth.NewRootCredentials(pidns.UserNamespace()) + contents := map[string]*kernfs.Dentry{ + //"cpuinfo": newCPUInfo(ctx, msrc), + //"filesystems": seqfile.NewSeqFileInode(ctx, &filesystemsData{}, msrc), + "loadavg": newDentry(root, inoGen.NextIno(), defaultPermission, &loadavgData{}), + "meminfo": newDentry(root, inoGen.NextIno(), defaultPermission, &meminfoData{k: k}), + "mounts": kernfs.NewStaticSymlink(root, inoGen.NextIno(), defaultPermission, "self/mounts"), + "stat": newDentry(root, inoGen.NextIno(), defaultPermission, &statData{k: k}), + //"uptime": newUptime(ctx, msrc), + //"version": newVersionData(root, inoGen.NextIno(), k), + "version": newDentry(root, inoGen.NextIno(), defaultPermission, &versionData{k: k}), + } + + inode := &tasksInode{ + pidns: pidns, + inoGen: inoGen, + selfSymlink: newSelfSymlink(root, inoGen.NextIno(), 0444, pidns).VFSDentry(), + threadSelfSymlink: newThreadSelfSymlink(root, inoGen.NextIno(), 0444, pidns).VFSDentry(), + } + inode.InodeAttrs.Init(root, inoGen.NextIno(), linux.ModeDirectory|0555) + + dentry := &kernfs.Dentry{} + dentry.Init(inode) + + inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + links := inode.OrderedChildren.Populate(dentry, contents) + inode.IncLinks(links) + + return inode, dentry +} + +// Lookup implements kernfs.inodeDynamicLookup. +func (i *tasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { + // Try to lookup a corresponding task. + tid, err := strconv.ParseUint(name, 10, 64) + if err != nil { + // If it failed to parse, check if it's one of the special handled files. + switch name { + case selfName: + return i.selfSymlink, nil + case threadSelfName: + return i.threadSelfSymlink, nil + } + return nil, syserror.ENOENT + } + + task := i.pidns.TaskWithID(kernel.ThreadID(tid)) + if task == nil { + return nil, syserror.ENOENT + } + + taskDentry := newTaskInode(i.inoGen, task, i.pidns, true) + return taskDentry.VFSDentry(), nil +} + +// Valid implements kernfs.inodeDynamicLookup. +func (i *tasksInode) Valid(ctx context.Context) bool { + return true +} + +// IterDirents implements kernfs.inodeDynamicLookup. +func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) { + // fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256 + const FIRST_PROCESS_ENTRY = 256 + + // Use maxTaskID to shortcut searches that will result in 0 entries. + const maxTaskID = kernel.TasksLimit + 1 + if offset >= maxTaskID { + return offset, nil + } + + // According to Linux (fs/proc/base.c:proc_pid_readdir()), process directories + // start at offset FIRST_PROCESS_ENTRY with '/proc/self', followed by + // '/proc/thread-self' and then '/proc/[pid]'. + if offset < FIRST_PROCESS_ENTRY { + offset = FIRST_PROCESS_ENTRY + } + + if offset == FIRST_PROCESS_ENTRY { + dirent := vfs.Dirent{ + Name: selfName, + Type: linux.DT_LNK, + Ino: i.inoGen.NextIno(), + NextOff: offset + 1, + } + if !cb.Handle(dirent) { + return offset, nil + } + offset++ + } + if offset == FIRST_PROCESS_ENTRY+1 { + dirent := vfs.Dirent{ + Name: threadSelfName, + Type: linux.DT_LNK, + Ino: i.inoGen.NextIno(), + NextOff: offset + 1, + } + if !cb.Handle(dirent) { + return offset, nil + } + offset++ + } + + // Collect all tasks that TGIDs are greater than the offset specified. Per + // Linux we only include in directory listings if it's the leader. But for + // whatever crazy reason, you can still walk to the given node. + var tids []int + startTid := offset - FIRST_PROCESS_ENTRY - 2 + for _, tg := range i.pidns.ThreadGroups() { + tid := i.pidns.IDOfThreadGroup(tg) + if int64(tid) < startTid { + continue + } + if leader := tg.Leader(); leader != nil { + tids = append(tids, int(tid)) + } + } + + if len(tids) == 0 { + return offset, nil + } + + sort.Ints(tids) + for _, tid := range tids { + dirent := vfs.Dirent{ + Name: strconv.FormatUint(uint64(tid), 10), + Type: linux.DT_DIR, + Ino: i.inoGen.NextIno(), + NextOff: FIRST_PROCESS_ENTRY + 2 + int64(tid) + 1, + } + if !cb.Handle(dirent) { + return offset, nil + } + offset++ + } + return maxTaskID, nil +} + +// Open implements kernfs.Inode. +func (i *tasksInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) { + fd := &kernfs.GenericDirectoryFD{} + fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, flags) + return fd.VFSFileDescription(), nil +} + +func (i *tasksInode) Stat(vsfs *vfs.Filesystem) linux.Statx { + stat := i.InodeAttrs.Stat(vsfs) + + // Add dynamic children to link count. + for _, tg := range i.pidns.ThreadGroups() { + if leader := tg.Leader(); leader != nil { + stat.Nlink++ + } + } + + return stat +} diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go new file mode 100644 index 000000000..91f30a798 --- /dev/null +++ b/pkg/sentry/fsimpl/proc/tasks_files.go @@ -0,0 +1,92 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "fmt" + "strconv" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/syserror" +) + +type selfSymlink struct { + kernfs.InodeAttrs + kernfs.InodeNoopRefCount + kernfs.InodeSymlink + + pidns *kernel.PIDNamespace +} + +var _ kernfs.Inode = (*selfSymlink)(nil) + +func newSelfSymlink(creds *auth.Credentials, ino uint64, perm linux.FileMode, pidns *kernel.PIDNamespace) *kernfs.Dentry { + inode := &selfSymlink{pidns: pidns} + inode.Init(creds, ino, linux.ModeSymlink|perm) + + d := &kernfs.Dentry{} + d.Init(inode) + return d +} + +func (s *selfSymlink) Readlink(ctx context.Context) (string, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + // Who is reading this link? + return "", syserror.EINVAL + } + tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup()) + if tgid == 0 { + return "", syserror.ENOENT + } + return strconv.FormatUint(uint64(tgid), 10), nil +} + +type threadSelfSymlink struct { + kernfs.InodeAttrs + kernfs.InodeNoopRefCount + kernfs.InodeSymlink + + pidns *kernel.PIDNamespace +} + +var _ kernfs.Inode = (*threadSelfSymlink)(nil) + +func newThreadSelfSymlink(creds *auth.Credentials, ino uint64, perm linux.FileMode, pidns *kernel.PIDNamespace) *kernfs.Dentry { + inode := &threadSelfSymlink{pidns: pidns} + inode.Init(creds, ino, linux.ModeSymlink|perm) + + d := &kernfs.Dentry{} + d.Init(inode) + return d +} + +func (s *threadSelfSymlink) Readlink(ctx context.Context) (string, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + // Who is reading this link? + return "", syserror.EINVAL + } + tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup()) + tid := s.pidns.IDOfTask(t) + if tid == 0 || tgid == 0 { + return "", syserror.ENOENT + } + return fmt.Sprintf("%d/task/%d", tgid, tid), nil +} diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go new file mode 100644 index 000000000..ca8c87ec2 --- /dev/null +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -0,0 +1,555 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "fmt" + "math" + "path" + "strconv" + "testing" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +var ( + // Next offset 256 by convention. Adds 1 for the next offset. + selfLink = vfs.Dirent{Type: linux.DT_LNK, NextOff: 256 + 0 + 1} + threadSelfLink = vfs.Dirent{Type: linux.DT_LNK, NextOff: 256 + 1 + 1} + + // /proc/[pid] next offset starts at 256+2 (files above), then adds the + // PID, and adds 1 for the next offset. + proc1 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 1 + 1} + proc2 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 2 + 1} + proc3 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 3 + 1} +) + +type testIterDirentsCallback struct { + dirents []vfs.Dirent +} + +func (t *testIterDirentsCallback) Handle(d vfs.Dirent) bool { + t.dirents = append(t.dirents, d) + return true +} + +func checkDots(dirs []vfs.Dirent) ([]vfs.Dirent, error) { + if got := len(dirs); got < 2 { + return dirs, fmt.Errorf("wrong number of dirents, want at least: 2, got: %d: %v", got, dirs) + } + for i, want := range []string{".", ".."} { + if got := dirs[i].Name; got != want { + return dirs, fmt.Errorf("wrong name, want: %s, got: %s", want, got) + } + if got := dirs[i].Type; got != linux.DT_DIR { + return dirs, fmt.Errorf("wrong type, want: %d, got: %d", linux.DT_DIR, got) + } + } + return dirs[2:], nil +} + +func checkTasksStaticFiles(gots []vfs.Dirent) ([]vfs.Dirent, error) { + wants := map[string]vfs.Dirent{ + "loadavg": {Type: linux.DT_REG}, + "meminfo": {Type: linux.DT_REG}, + "mounts": {Type: linux.DT_LNK}, + "self": selfLink, + "stat": {Type: linux.DT_REG}, + "thread-self": threadSelfLink, + "version": {Type: linux.DT_REG}, + } + return checkFiles(gots, wants) +} + +func checkTaskStaticFiles(gots []vfs.Dirent) ([]vfs.Dirent, error) { + wants := map[string]vfs.Dirent{ + "io": {Type: linux.DT_REG}, + "maps": {Type: linux.DT_REG}, + "smaps": {Type: linux.DT_REG}, + "stat": {Type: linux.DT_REG}, + "statm": {Type: linux.DT_REG}, + "status": {Type: linux.DT_REG}, + } + return checkFiles(gots, wants) +} + +func checkFiles(gots []vfs.Dirent, wants map[string]vfs.Dirent) ([]vfs.Dirent, error) { + // Go over all files, when there is a match, the file is removed from both + // 'gots' and 'wants'. wants is expected to reach 0, as all files must + // be present. Remaining files in 'gots', is returned to caller to decide + // whether this is valid or not. + for i := 0; i < len(gots); i++ { + got := gots[i] + want, ok := wants[got.Name] + if !ok { + continue + } + if want.Type != got.Type { + return gots, fmt.Errorf("wrong file type, want: %v, got: %v: %+v", want.Type, got.Type, got) + } + if want.NextOff != 0 && want.NextOff != got.NextOff { + return gots, fmt.Errorf("wrong dirent offset, want: %v, got: %v: %+v", want.NextOff, got.NextOff, got) + } + + delete(wants, got.Name) + gots = append(gots[0:i], gots[i+1:]...) + i-- + } + if len(wants) != 0 { + return gots, fmt.Errorf("not all files were found, missing: %+v", wants) + } + return gots, nil +} + +func setup() (context.Context, *vfs.VirtualFilesystem, vfs.VirtualDentry, error) { + k, err := boot() + if err != nil { + return nil, nil, vfs.VirtualDentry{}, fmt.Errorf("creating kernel: %v", err) + } + + ctx := k.SupervisorContext() + creds := auth.CredentialsFromContext(ctx) + + vfsObj := vfs.New() + vfsObj.MustRegisterFilesystemType("procfs", &procFSType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + }) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "procfs", &vfs.GetFilesystemOptions{}) + if err != nil { + return nil, nil, vfs.VirtualDentry{}, fmt.Errorf("NewMountNamespace(): %v", err) + } + return ctx, vfsObj, mntns.Root(), nil +} + +func TestTasksEmpty(t *testing.T) { + ctx, vfsObj, root, err := setup() + if err != nil { + t.Fatalf("Setup failed: %v", err) + } + defer root.DecRef() + + fd, err := vfsObj.OpenAt( + ctx, + auth.CredentialsFromContext(ctx), + &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse("/")}, + &vfs.OpenOptions{}, + ) + if err != nil { + t.Fatalf("vfsfs.OpenAt failed: %v", err) + } + + cb := testIterDirentsCallback{} + if err := fd.Impl().IterDirents(ctx, &cb); err != nil { + t.Fatalf("IterDirents(): %v", err) + } + cb.dirents, err = checkDots(cb.dirents) + if err != nil { + t.Error(err.Error()) + } + cb.dirents, err = checkTasksStaticFiles(cb.dirents) + if err != nil { + t.Error(err.Error()) + } + if len(cb.dirents) != 0 { + t.Errorf("found more files than expected: %+v", cb.dirents) + } +} + +func TestTasks(t *testing.T) { + ctx, vfsObj, root, err := setup() + if err != nil { + t.Fatalf("Setup failed: %v", err) + } + defer root.DecRef() + + k := kernel.KernelFromContext(ctx) + var tasks []*kernel.Task + for i := 0; i < 5; i++ { + tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) + task, err := createTask(ctx, fmt.Sprintf("name-%d", i), tc) + if err != nil { + t.Fatalf("CreateTask(): %v", err) + } + tasks = append(tasks, task) + } + + fd, err := vfsObj.OpenAt( + ctx, + auth.CredentialsFromContext(ctx), + &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse("/")}, + &vfs.OpenOptions{}, + ) + if err != nil { + t.Fatalf("vfsfs.OpenAt(/) failed: %v", err) + } + + cb := testIterDirentsCallback{} + if err := fd.Impl().IterDirents(ctx, &cb); err != nil { + t.Fatalf("IterDirents(): %v", err) + } + cb.dirents, err = checkDots(cb.dirents) + if err != nil { + t.Error(err.Error()) + } + cb.dirents, err = checkTasksStaticFiles(cb.dirents) + if err != nil { + t.Error(err.Error()) + } + lastPid := 0 + for _, d := range cb.dirents { + pid, err := strconv.Atoi(d.Name) + if err != nil { + t.Fatalf("Invalid process directory %q", d.Name) + } + if lastPid > pid { + t.Errorf("pids not in order: %v", cb.dirents) + } + found := false + for _, t := range tasks { + if k.TaskSet().Root.IDOfTask(t) == kernel.ThreadID(pid) { + found = true + } + } + if !found { + t.Errorf("Additional task ID %d listed: %v", pid, tasks) + } + // Next offset starts at 256+2 ('self' and 'thread-self'), then adds the + // PID, and adds 1 for the next offset. + if want := int64(256 + 2 + pid + 1); d.NextOff != want { + t.Errorf("Wrong dirent offset want: %d got: %d: %+v", want, d.NextOff, d) + } + } + + // Test lookup. + for _, path := range []string{"/1", "/2"} { + fd, err := vfsObj.OpenAt( + ctx, + auth.CredentialsFromContext(ctx), + &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse(path)}, + &vfs.OpenOptions{}, + ) + if err != nil { + t.Fatalf("vfsfs.OpenAt(%q) failed: %v", path, err) + } + buf := make([]byte, 1) + bufIOSeq := usermem.BytesIOSequence(buf) + if _, err := fd.Read(ctx, bufIOSeq, vfs.ReadOptions{}); err != syserror.EISDIR { + t.Errorf("wrong error reading directory: %v", err) + } + } + + if _, err := vfsObj.OpenAt( + ctx, + auth.CredentialsFromContext(ctx), + &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse("/9999")}, + &vfs.OpenOptions{}, + ); err != syserror.ENOENT { + t.Fatalf("wrong error from vfsfs.OpenAt(/9999): %v", err) + } +} + +func TestTasksOffset(t *testing.T) { + ctx, vfsObj, root, err := setup() + if err != nil { + t.Fatalf("Setup failed: %v", err) + } + defer root.DecRef() + + k := kernel.KernelFromContext(ctx) + for i := 0; i < 3; i++ { + tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) + if _, err := createTask(ctx, fmt.Sprintf("name-%d", i), tc); err != nil { + t.Fatalf("CreateTask(): %v", err) + } + } + + for _, tc := range []struct { + name string + offset int64 + wants map[string]vfs.Dirent + }{ + { + name: "small offset", + offset: 100, + wants: map[string]vfs.Dirent{ + "self": selfLink, + "thread-self": threadSelfLink, + "1": proc1, + "2": proc2, + "3": proc3, + }, + }, + { + name: "offset at start", + offset: 256, + wants: map[string]vfs.Dirent{ + "self": selfLink, + "thread-self": threadSelfLink, + "1": proc1, + "2": proc2, + "3": proc3, + }, + }, + { + name: "skip /proc/self", + offset: 257, + wants: map[string]vfs.Dirent{ + "thread-self": threadSelfLink, + "1": proc1, + "2": proc2, + "3": proc3, + }, + }, + { + name: "skip symlinks", + offset: 258, + wants: map[string]vfs.Dirent{ + "1": proc1, + "2": proc2, + "3": proc3, + }, + }, + { + name: "skip first process", + offset: 260, + wants: map[string]vfs.Dirent{ + "2": proc2, + "3": proc3, + }, + }, + { + name: "last process", + offset: 261, + wants: map[string]vfs.Dirent{ + "3": proc3, + }, + }, + { + name: "after last", + offset: 262, + wants: nil, + }, + { + name: "TaskLimit+1", + offset: kernel.TasksLimit + 1, + wants: nil, + }, + { + name: "max", + offset: math.MaxInt64, + wants: nil, + }, + } { + t.Run(tc.name, func(t *testing.T) { + fd, err := vfsObj.OpenAt( + ctx, + auth.CredentialsFromContext(ctx), + &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse("/")}, + &vfs.OpenOptions{}, + ) + if err != nil { + t.Fatalf("vfsfs.OpenAt(/) failed: %v", err) + } + if _, err := fd.Impl().Seek(ctx, tc.offset, linux.SEEK_SET); err != nil { + t.Fatalf("Seek(%d, SEEK_SET): %v", tc.offset, err) + } + + cb := testIterDirentsCallback{} + if err := fd.Impl().IterDirents(ctx, &cb); err != nil { + t.Fatalf("IterDirents(): %v", err) + } + if cb.dirents, err = checkFiles(cb.dirents, tc.wants); err != nil { + t.Error(err.Error()) + } + if len(cb.dirents) != 0 { + t.Errorf("found more files than expected: %+v", cb.dirents) + } + }) + } +} + +func TestTask(t *testing.T) { + ctx, vfsObj, root, err := setup() + if err != nil { + t.Fatalf("Setup failed: %v", err) + } + defer root.DecRef() + + k := kernel.KernelFromContext(ctx) + tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) + _, err = createTask(ctx, "name", tc) + if err != nil { + t.Fatalf("CreateTask(): %v", err) + } + + fd, err := vfsObj.OpenAt( + ctx, + auth.CredentialsFromContext(ctx), + &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse("/1")}, + &vfs.OpenOptions{}, + ) + if err != nil { + t.Fatalf("vfsfs.OpenAt(/1) failed: %v", err) + } + + cb := testIterDirentsCallback{} + if err := fd.Impl().IterDirents(ctx, &cb); err != nil { + t.Fatalf("IterDirents(): %v", err) + } + cb.dirents, err = checkDots(cb.dirents) + if err != nil { + t.Error(err.Error()) + } + cb.dirents, err = checkTaskStaticFiles(cb.dirents) + if err != nil { + t.Error(err.Error()) + } + if len(cb.dirents) != 0 { + t.Errorf("found more files than expected: %+v", cb.dirents) + } +} + +func TestProcSelf(t *testing.T) { + ctx, vfsObj, root, err := setup() + if err != nil { + t.Fatalf("Setup failed: %v", err) + } + defer root.DecRef() + + k := kernel.KernelFromContext(ctx) + tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) + task, err := createTask(ctx, "name", tc) + if err != nil { + t.Fatalf("CreateTask(): %v", err) + } + + fd, err := vfsObj.OpenAt( + task, + auth.CredentialsFromContext(ctx), + &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse("/self/"), FollowFinalSymlink: true}, + &vfs.OpenOptions{}, + ) + if err != nil { + t.Fatalf("vfsfs.OpenAt(/self/) failed: %v", err) + } + + cb := testIterDirentsCallback{} + if err := fd.Impl().IterDirents(ctx, &cb); err != nil { + t.Fatalf("IterDirents(): %v", err) + } + cb.dirents, err = checkDots(cb.dirents) + if err != nil { + t.Error(err.Error()) + } + cb.dirents, err = checkTaskStaticFiles(cb.dirents) + if err != nil { + t.Error(err.Error()) + } + if len(cb.dirents) != 0 { + t.Errorf("found more files than expected: %+v", cb.dirents) + } +} + +func iterateDir(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, fd *vfs.FileDescription) { + t.Logf("Iterating: /proc%s", fd.MappedName(ctx)) + + cb := testIterDirentsCallback{} + if err := fd.Impl().IterDirents(ctx, &cb); err != nil { + t.Fatalf("IterDirents(): %v", err) + } + var err error + cb.dirents, err = checkDots(cb.dirents) + if err != nil { + t.Error(err.Error()) + } + for _, d := range cb.dirents { + childPath := path.Join(fd.MappedName(ctx), d.Name) + if d.Type == linux.DT_LNK { + link, err := vfsObj.ReadlinkAt( + ctx, + auth.CredentialsFromContext(ctx), + &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse(childPath)}, + ) + if err != nil { + t.Errorf("vfsfs.ReadlinkAt(%v) failed: %v", childPath, err) + } else { + t.Logf("Skipping symlink: /proc%s => %s", childPath, link) + } + continue + } + + t.Logf("Opening: /proc%s", childPath) + child, err := vfsObj.OpenAt( + ctx, + auth.CredentialsFromContext(ctx), + &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse(childPath)}, + &vfs.OpenOptions{}, + ) + if err != nil { + t.Errorf("vfsfs.OpenAt(%v) failed: %v", childPath, err) + continue + } + stat, err := child.Stat(ctx, vfs.StatOptions{}) + if err != nil { + t.Errorf("Stat(%v) failed: %v", childPath, err) + } + if got := linux.FileMode(stat.Mode).DirentType(); got != d.Type { + t.Errorf("wrong file mode, stat: %v, dirent: %v", got, d.Type) + } + if d.Type == linux.DT_DIR { + // Found another dir, let's do it again! + iterateDir(ctx, t, vfsObj, root, child) + } + } +} + +// TestTree iterates all directories and stats every file. +func TestTree(t *testing.T) { + uberCtx, vfsObj, root, err := setup() + if err != nil { + t.Fatalf("Setup failed: %v", err) + } + defer root.DecRef() + + k := kernel.KernelFromContext(uberCtx) + var tasks []*kernel.Task + for i := 0; i < 5; i++ { + tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) + task, err := createTask(uberCtx, fmt.Sprintf("name-%d", i), tc) + if err != nil { + t.Fatalf("CreateTask(): %v", err) + } + tasks = append(tasks, task) + } + + ctx := tasks[0] + fd, err := vfsObj.OpenAt( + ctx, + auth.CredentialsFromContext(uberCtx), + &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse("/")}, + &vfs.OpenOptions{}, + ) + if err != nil { + t.Fatalf("vfsfs.OpenAt(/) failed: %v", err) + } + iterateDir(ctx, t, vfsObj, root, fd) +} diff --git a/pkg/sentry/fsimpl/proc/version.go b/pkg/sentry/fsimpl/proc/version.go index e1643d4e0..367f2396b 100644 --- a/pkg/sentry/fsimpl/proc/version.go +++ b/pkg/sentry/fsimpl/proc/version.go @@ -19,19 +19,21 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/vfs" ) // versionData implements vfs.DynamicBytesSource for /proc/version. // // +stateify savable type versionData struct { + kernfs.DynamicBytesFile + // k is the owning Kernel. k *kernel.Kernel } -var _ vfs.DynamicBytesSource = (*versionData)(nil) +var _ dynamicInode = (*versionData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. func (v *versionData) Generate(ctx context.Context, buf *bytes.Buffer) error { diff --git a/pkg/sentry/fsimpl/memfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index 0cc751eb8..a5b285987 100644 --- a/pkg/sentry/fsimpl/memfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -1,14 +1,13 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) -load("//tools/go_generics:defs.bzl", "go_template_instance") - go_template_instance( name = "dentry_list", out = "dentry_list.go", - package = "memfs", + package = "tmpfs", prefix = "dentry", template = "//pkg/ilist:generic_list", types = { @@ -18,25 +17,34 @@ go_template_instance( ) go_library( - name = "memfs", + name = "tmpfs", srcs = [ "dentry_list.go", "directory.go", "filesystem.go", - "memfs.go", "named_pipe.go", "regular_file.go", "symlink.go", + "tmpfs.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/memfs", + importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs", deps = [ "//pkg/abi/linux", "//pkg/amutex", "//pkg/fspath", + "//pkg/log", "//pkg/sentry/arch", "//pkg/sentry/context", + "//pkg/sentry/fs", + "//pkg/sentry/fs/fsutil", + "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/pipe", + "//pkg/sentry/memmap", + "//pkg/sentry/pgalloc", + "//pkg/sentry/platform", + "//pkg/sentry/safemem", + "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/sentry/vfs", "//pkg/syserror", @@ -48,8 +56,9 @@ go_test( size = "small", srcs = ["benchmark_test.go"], deps = [ - ":memfs", + ":tmpfs", "//pkg/abi/linux", + "//pkg/fspath", "//pkg/refs", "//pkg/sentry/context", "//pkg/sentry/context/contexttest", @@ -62,15 +71,20 @@ go_test( ) go_test( - name = "memfs_test", + name = "tmpfs_test", size = "small", - srcs = ["pipe_test.go"], - embed = [":memfs"], + srcs = [ + "pipe_test.go", + "regular_file_test.go", + ], + embed = [":tmpfs"], deps = [ "//pkg/abi/linux", + "//pkg/fspath", "//pkg/sentry/context", "//pkg/sentry/context/contexttest", "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/contexttest", "//pkg/sentry/usermem", "//pkg/sentry/vfs", "//pkg/syserror", diff --git a/pkg/sentry/fsimpl/memfs/benchmark_test.go b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go index 4a7a94a52..d88c83499 100644 --- a/pkg/sentry/fsimpl/memfs/benchmark_test.go +++ b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go @@ -21,12 +21,13 @@ import ( "testing" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/context/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" _ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/memfs" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" @@ -175,8 +176,10 @@ func BenchmarkVFS2MemfsStat(b *testing.B) { // Create VFS. vfsObj := vfs.New() - vfsObj.MustRegisterFilesystemType("memfs", memfs.FilesystemType{}) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{}) + vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + }) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) if err != nil { b.Fatalf("failed to create tmpfs root mount: %v", err) } @@ -193,9 +196,9 @@ func BenchmarkVFS2MemfsStat(b *testing.B) { for i := depth; i > 0; i-- { name := fmt.Sprintf("%d", i) pop := vfs.PathOperation{ - Root: root, - Start: vd, - Pathname: name, + Root: root, + Start: vd, + Path: fspath.Parse(name), } if err := vfsObj.MkdirAt(ctx, creds, &pop, &vfs.MkdirOptions{ Mode: 0755, @@ -216,7 +219,7 @@ func BenchmarkVFS2MemfsStat(b *testing.B) { fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ Root: root, Start: vd, - Pathname: filename, + Path: fspath.Parse(filename), FollowFinalSymlink: true, }, &vfs.OpenOptions{ Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, @@ -237,7 +240,7 @@ func BenchmarkVFS2MemfsStat(b *testing.B) { stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ Root: root, Start: root, - Pathname: filePath, + Path: fspath.Parse(filePath), FollowFinalSymlink: true, }, &vfs.StatOptions{}) if err != nil { @@ -364,8 +367,10 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) { // Create VFS. vfsObj := vfs.New() - vfsObj.MustRegisterFilesystemType("memfs", memfs.FilesystemType{}) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{}) + vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + }) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) if err != nil { b.Fatalf("failed to create tmpfs root mount: %v", err) } @@ -378,9 +383,9 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) { root := mntns.Root() defer root.DecRef() pop := vfs.PathOperation{ - Root: root, - Start: root, - Pathname: mountPointName, + Root: root, + Start: root, + Path: fspath.Parse(mountPointName), } if err := vfsObj.MkdirAt(ctx, creds, &pop, &vfs.MkdirOptions{ Mode: 0755, @@ -394,7 +399,7 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) { } defer mountPoint.DecRef() // Create and mount the submount. - if err := vfsObj.MountAt(ctx, creds, "", &pop, "memfs", &vfs.MountOptions{}); err != nil { + if err := vfsObj.MountAt(ctx, creds, "", &pop, "tmpfs", &vfs.MountOptions{}); err != nil { b.Fatalf("failed to mount tmpfs submount: %v", err) } filePathBuilder.WriteString(mountPointName) @@ -408,9 +413,9 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) { for i := depth; i > 0; i-- { name := fmt.Sprintf("%d", i) pop := vfs.PathOperation{ - Root: root, - Start: vd, - Pathname: name, + Root: root, + Start: vd, + Path: fspath.Parse(name), } if err := vfsObj.MkdirAt(ctx, creds, &pop, &vfs.MkdirOptions{ Mode: 0755, @@ -438,7 +443,7 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) { fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ Root: root, Start: vd, - Pathname: filename, + Path: fspath.Parse(filename), FollowFinalSymlink: true, }, &vfs.OpenOptions{ Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, @@ -458,7 +463,7 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) { stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ Root: root, Start: root, - Pathname: filePath, + Path: fspath.Parse(filePath), FollowFinalSymlink: true, }, &vfs.StatOptions{}) if err != nil { diff --git a/pkg/sentry/fsimpl/memfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go index 0bd82e480..887ca2619 100644 --- a/pkg/sentry/fsimpl/memfs/directory.go +++ b/pkg/sentry/fsimpl/tmpfs/directory.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package memfs +package tmpfs import ( "gvisor.dev/gvisor/pkg/abi/linux" diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go new file mode 100644 index 000000000..26979729e --- /dev/null +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -0,0 +1,698 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tmpfs + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// Sync implements vfs.FilesystemImpl.Sync. +func (fs *filesystem) Sync(ctx context.Context) error { + // All filesystem state is in-memory. + return nil +} + +// stepLocked resolves rp.Component() to an existing file, starting from the +// given directory. +// +// stepLocked is loosely analogous to fs/namei.c:walk_component(). +// +// Preconditions: filesystem.mu must be locked. !rp.Done(). +func stepLocked(rp *vfs.ResolvingPath, d *dentry) (*dentry, error) { + if !d.inode.isDir() { + return nil, syserror.ENOTDIR + } + if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil { + return nil, err + } +afterSymlink: + nextVFSD, err := rp.ResolveComponent(&d.vfsd) + if err != nil { + return nil, err + } + if nextVFSD == nil { + // Since the Dentry tree is the sole source of truth for tmpfs, if it's + // not in the Dentry tree, it doesn't exist. + return nil, syserror.ENOENT + } + next := nextVFSD.Impl().(*dentry) + if symlink, ok := next.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() { + // TODO: symlink traversals update access time + if err := rp.HandleSymlink(symlink.target); err != nil { + return nil, err + } + goto afterSymlink // don't check the current directory again + } + rp.Advance() + return next, nil +} + +// walkParentDirLocked resolves all but the last path component of rp to an +// existing directory, starting from the given directory (which is usually +// rp.Start().Impl().(*dentry)). It does not check that the returned directory +// is searchable by the provider of rp. +// +// walkParentDirLocked is loosely analogous to Linux's +// fs/namei.c:path_parentat(). +// +// Preconditions: filesystem.mu must be locked. !rp.Done(). +func walkParentDirLocked(rp *vfs.ResolvingPath, d *dentry) (*dentry, error) { + for !rp.Final() { + next, err := stepLocked(rp, d) + if err != nil { + return nil, err + } + d = next + } + if !d.inode.isDir() { + return nil, syserror.ENOTDIR + } + return d, nil +} + +// resolveLocked resolves rp to an existing file. +// +// resolveLocked is loosely analogous to Linux's fs/namei.c:path_lookupat(). +// +// Preconditions: filesystem.mu must be locked. +func resolveLocked(rp *vfs.ResolvingPath) (*dentry, error) { + d := rp.Start().Impl().(*dentry) + for !rp.Done() { + next, err := stepLocked(rp, d) + if err != nil { + return nil, err + } + d = next + } + if rp.MustBeDir() && !d.inode.isDir() { + return nil, syserror.ENOTDIR + } + return d, nil +} + +// doCreateAt checks that creating a file at rp is permitted, then invokes +// create to do so. +// +// doCreateAt is loosely analogous to a conjunction of Linux's +// fs/namei.c:filename_create() and done_path_create(). +// +// Preconditions: !rp.Done(). For the final path component in rp, +// !rp.ShouldFollowSymlink(). +func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(parent *dentry, name string) error) error { + fs.mu.Lock() + defer fs.mu.Unlock() + parent, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry)) + if err != nil { + return err + } + if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil { + return err + } + name := rp.Component() + if name == "." || name == ".." { + return syserror.EEXIST + } + // Call parent.vfsd.Child() instead of stepLocked() or rp.ResolveChild(), + // because if the child exists we want to return EEXIST immediately instead + // of attempting symlink/mount traversal. + if parent.vfsd.Child(name) != nil { + return syserror.EEXIST + } + if !dir && rp.MustBeDir() { + return syserror.ENOENT + } + // In memfs, the only way to cause a dentry to be disowned is by removing + // it from the filesystem, so this check is equivalent to checking if + // parent has been removed. + if parent.vfsd.IsDisowned() { + return syserror.ENOENT + } + mnt := rp.Mount() + if err := mnt.CheckBeginWrite(); err != nil { + return err + } + defer mnt.EndWrite() + return create(parent, name) +} + +// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. +func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + d, err := resolveLocked(rp) + if err != nil { + return nil, err + } + if opts.CheckSearchable { + if !d.inode.isDir() { + return nil, syserror.ENOTDIR + } + if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec, true /* isDir */); err != nil { + return nil, err + } + } + d.IncRef() + return &d.vfsd, nil +} + +// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt. +func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + d, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry)) + if err != nil { + return nil, err + } + d.IncRef() + return &d.vfsd, nil +} + +// LinkAt implements vfs.FilesystemImpl.LinkAt. +func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error { + return fs.doCreateAt(rp, false /* dir */, func(parent *dentry, name string) error { + if rp.Mount() != vd.Mount() { + return syserror.EXDEV + } + d := vd.Dentry().Impl().(*dentry) + if d.inode.isDir() { + return syserror.EPERM + } + if d.inode.nlink == 0 { + return syserror.ENOENT + } + if d.inode.nlink == maxLinks { + return syserror.EMLINK + } + d.inode.incLinksLocked() + child := fs.newDentry(d.inode) + parent.vfsd.InsertChild(&child.vfsd, name) + parent.inode.impl.(*directory).childList.PushBack(child) + return nil + }) +} + +// MkdirAt implements vfs.FilesystemImpl.MkdirAt. +func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { + return fs.doCreateAt(rp, true /* dir */, func(parent *dentry, name string) error { + if parent.inode.nlink == maxLinks { + return syserror.EMLINK + } + parent.inode.incLinksLocked() // from child's ".." + child := fs.newDentry(fs.newDirectory(rp.Credentials(), opts.Mode)) + parent.vfsd.InsertChild(&child.vfsd, name) + parent.inode.impl.(*directory).childList.PushBack(child) + return nil + }) +} + +// MknodAt implements vfs.FilesystemImpl.MknodAt. +func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { + return fs.doCreateAt(rp, false /* dir */, func(parent *dentry, name string) error { + switch opts.Mode.FileType() { + case 0, linux.S_IFREG: + child := fs.newDentry(fs.newRegularFile(rp.Credentials(), opts.Mode)) + parent.vfsd.InsertChild(&child.vfsd, name) + parent.inode.impl.(*directory).childList.PushBack(child) + return nil + case linux.S_IFIFO: + child := fs.newDentry(fs.newNamedPipe(rp.Credentials(), opts.Mode)) + parent.vfsd.InsertChild(&child.vfsd, name) + parent.inode.impl.(*directory).childList.PushBack(child) + return nil + case linux.S_IFBLK, linux.S_IFCHR, linux.S_IFSOCK: + // Not yet supported. + return syserror.EPERM + default: + return syserror.EINVAL + } + }) +} + +// OpenAt implements vfs.FilesystemImpl.OpenAt. +func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + if opts.Flags&linux.O_TMPFILE != 0 { + // Not yet supported. + return nil, syserror.EOPNOTSUPP + } + + // Handle O_CREAT and !O_CREAT separately, since in the latter case we + // don't need fs.mu for writing. + if opts.Flags&linux.O_CREAT == 0 { + fs.mu.RLock() + defer fs.mu.RUnlock() + d, err := resolveLocked(rp) + if err != nil { + return nil, err + } + return d.open(ctx, rp, opts.Flags, false /* afterCreate */) + } + + mustCreate := opts.Flags&linux.O_EXCL != 0 + start := rp.Start().Impl().(*dentry) + fs.mu.Lock() + defer fs.mu.Unlock() + if rp.Done() { + // Reject attempts to open directories with O_CREAT. + if rp.MustBeDir() { + return nil, syserror.EISDIR + } + if mustCreate { + return nil, syserror.EEXIST + } + return start.open(ctx, rp, opts.Flags, false /* afterCreate */) + } +afterTrailingSymlink: + parent, err := walkParentDirLocked(rp, start) + if err != nil { + return nil, err + } + // Check for search permission in the parent directory. + if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil { + return nil, err + } + // Reject attempts to open directories with O_CREAT. + if rp.MustBeDir() { + return nil, syserror.EISDIR + } + name := rp.Component() + if name == "." || name == ".." { + return nil, syserror.EISDIR + } + // Determine whether or not we need to create a file. + child, err := stepLocked(rp, parent) + if err == syserror.ENOENT { + // Already checked for searchability above; now check for writability. + if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite, true); err != nil { + return nil, err + } + if err := rp.Mount().CheckBeginWrite(); err != nil { + return nil, err + } + defer rp.Mount().EndWrite() + // Create and open the child. + child := fs.newDentry(fs.newRegularFile(rp.Credentials(), opts.Mode)) + parent.vfsd.InsertChild(&child.vfsd, name) + parent.inode.impl.(*directory).childList.PushBack(child) + return child.open(ctx, rp, opts.Flags, true) + } + if err != nil { + return nil, err + } + // Do we need to resolve a trailing symlink? + if !rp.Done() { + start = parent + goto afterTrailingSymlink + } + // Open existing file. + if mustCreate { + return nil, syserror.EEXIST + } + return child.open(ctx, rp, opts.Flags, false) +} + +func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, flags uint32, afterCreate bool) (*vfs.FileDescription, error) { + ats := vfs.AccessTypesForOpenFlags(flags) + if !afterCreate { + if err := d.inode.checkPermissions(rp.Credentials(), ats, d.inode.isDir()); err != nil { + return nil, err + } + } + mnt := rp.Mount() + switch impl := d.inode.impl.(type) { + case *regularFile: + var fd regularFileFD + fd.readable = vfs.MayReadFileWithOpenFlags(flags) + fd.writable = vfs.MayWriteFileWithOpenFlags(flags) + if fd.writable { + if err := mnt.CheckBeginWrite(); err != nil { + return nil, err + } + // mnt.EndWrite() is called by regularFileFD.Release(). + } + fd.vfsfd.Init(&fd, flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}) + if flags&linux.O_TRUNC != 0 { + impl.mu.Lock() + impl.data.Truncate(0, impl.memFile) + atomic.StoreUint64(&impl.size, 0) + impl.mu.Unlock() + } + return &fd.vfsfd, nil + case *directory: + // Can't open directories writably. + if ats&vfs.MayWrite != 0 { + return nil, syserror.EISDIR + } + var fd directoryFD + fd.vfsfd.Init(&fd, flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}) + return &fd.vfsfd, nil + case *symlink: + // Can't open symlinks without O_PATH (which is unimplemented). + return nil, syserror.ELOOP + case *namedPipe: + return newNamedPipeFD(ctx, impl, rp, &d.vfsd, flags) + default: + panic(fmt.Sprintf("unknown inode type: %T", d.inode.impl)) + } +} + +// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. +func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + d, err := resolveLocked(rp) + if err != nil { + return "", err + } + symlink, ok := d.inode.impl.(*symlink) + if !ok { + return "", syserror.EINVAL + } + return symlink.target, nil +} + +// RenameAt implements vfs.FilesystemImpl.RenameAt. +func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error { + if opts.Flags != 0 { + // TODO(b/145974740): Support renameat2 flags. + return syserror.EINVAL + } + + // Resolve newParent first to verify that it's on this Mount. + fs.mu.Lock() + defer fs.mu.Unlock() + newParent, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry)) + if err != nil { + return err + } + newName := rp.Component() + if newName == "." || newName == ".." { + return syserror.EBUSY + } + mnt := rp.Mount() + if mnt != oldParentVD.Mount() { + return syserror.EXDEV + } + if err := mnt.CheckBeginWrite(); err != nil { + return err + } + defer mnt.EndWrite() + + oldParent := oldParentVD.Dentry().Impl().(*dentry) + if err := oldParent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil { + return err + } + // Call vfs.Dentry.Child() instead of stepLocked() or rp.ResolveChild(), + // because if the existing child is a symlink or mount point then we want + // to rename over it rather than follow it. + renamedVFSD := oldParent.vfsd.Child(oldName) + if renamedVFSD == nil { + return syserror.ENOENT + } + renamed := renamedVFSD.Impl().(*dentry) + if renamed.inode.isDir() { + if renamed == newParent || renamedVFSD.IsAncestorOf(&newParent.vfsd) { + return syserror.EINVAL + } + if oldParent != newParent { + // Writability is needed to change renamed's "..". + if err := renamed.inode.checkPermissions(rp.Credentials(), vfs.MayWrite, true /* isDir */); err != nil { + return err + } + } + } else { + if opts.MustBeDir || rp.MustBeDir() { + return syserror.ENOTDIR + } + } + + if err := newParent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil { + return err + } + replacedVFSD := newParent.vfsd.Child(newName) + var replaced *dentry + if replacedVFSD != nil { + replaced = replacedVFSD.Impl().(*dentry) + if replaced.inode.isDir() { + if !renamed.inode.isDir() { + return syserror.EISDIR + } + if replaced.vfsd.HasChildren() { + return syserror.ENOTEMPTY + } + } else { + if rp.MustBeDir() { + return syserror.ENOTDIR + } + if renamed.inode.isDir() { + return syserror.ENOTDIR + } + } + } else { + if renamed.inode.isDir() && newParent.inode.nlink == maxLinks { + return syserror.EMLINK + } + } + if newParent.vfsd.IsDisowned() { + return syserror.ENOENT + } + + // Linux places this check before some of those above; we do it here for + // simplicity, under the assumption that applications are not intentionally + // doing noop renames expecting them to succeed where non-noop renames + // would fail. + if renamedVFSD == replacedVFSD { + return nil + } + vfsObj := rp.VirtualFilesystem() + oldParentDir := oldParent.inode.impl.(*directory) + newParentDir := newParent.inode.impl.(*directory) + if err := vfsObj.PrepareRenameDentry(vfs.MountNamespaceFromContext(ctx), renamedVFSD, replacedVFSD); err != nil { + return err + } + if replaced != nil { + newParentDir.childList.Remove(replaced) + if replaced.inode.isDir() { + newParent.inode.decLinksLocked() // from replaced's ".." + } + replaced.inode.decLinksLocked() + } + oldParentDir.childList.Remove(renamed) + newParentDir.childList.PushBack(renamed) + if renamed.inode.isDir() { + oldParent.inode.decLinksLocked() + newParent.inode.incLinksLocked() + } + // TODO: update timestamps and parent directory sizes + vfsObj.CommitRenameReplaceDentry(renamedVFSD, &newParent.vfsd, newName, replacedVFSD) + return nil +} + +// RmdirAt implements vfs.FilesystemImpl.RmdirAt. +func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { + fs.mu.Lock() + defer fs.mu.Unlock() + parent, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry)) + if err != nil { + return err + } + if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil { + return err + } + name := rp.Component() + if name == "." { + return syserror.EINVAL + } + if name == ".." { + return syserror.ENOTEMPTY + } + childVFSD := parent.vfsd.Child(name) + if childVFSD == nil { + return syserror.ENOENT + } + child := childVFSD.Impl().(*dentry) + if !child.inode.isDir() { + return syserror.ENOTDIR + } + if childVFSD.HasChildren() { + return syserror.ENOTEMPTY + } + mnt := rp.Mount() + if err := mnt.CheckBeginWrite(); err != nil { + return err + } + defer mnt.EndWrite() + vfsObj := rp.VirtualFilesystem() + if err := vfsObj.PrepareDeleteDentry(vfs.MountNamespaceFromContext(ctx), childVFSD); err != nil { + return err + } + parent.inode.impl.(*directory).childList.Remove(child) + parent.inode.decLinksLocked() // from child's ".." + child.inode.decLinksLocked() + vfsObj.CommitDeleteDentry(childVFSD) + return nil +} + +// SetStatAt implements vfs.FilesystemImpl.SetStatAt. +func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { + fs.mu.RLock() + defer fs.mu.RUnlock() + _, err := resolveLocked(rp) + if err != nil { + return err + } + if opts.Stat.Mask == 0 { + return nil + } + // TODO: implement inode.setStat + return syserror.EPERM +} + +// StatAt implements vfs.FilesystemImpl.StatAt. +func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + d, err := resolveLocked(rp) + if err != nil { + return linux.Statx{}, err + } + var stat linux.Statx + d.inode.statTo(&stat) + return stat, nil +} + +// StatFSAt implements vfs.FilesystemImpl.StatFSAt. +func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + _, err := resolveLocked(rp) + if err != nil { + return linux.Statfs{}, err + } + // TODO: actually implement statfs + return linux.Statfs{}, syserror.ENOSYS +} + +// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. +func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { + return fs.doCreateAt(rp, false /* dir */, func(parent *dentry, name string) error { + child := fs.newDentry(fs.newSymlink(rp.Credentials(), target)) + parent.vfsd.InsertChild(&child.vfsd, name) + parent.inode.impl.(*directory).childList.PushBack(child) + return nil + }) +} + +// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt. +func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { + fs.mu.Lock() + defer fs.mu.Unlock() + parent, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry)) + if err != nil { + return err + } + if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil { + return err + } + name := rp.Component() + if name == "." || name == ".." { + return syserror.EISDIR + } + childVFSD := parent.vfsd.Child(name) + if childVFSD == nil { + return syserror.ENOENT + } + child := childVFSD.Impl().(*dentry) + if child.inode.isDir() { + return syserror.EISDIR + } + if !rp.MustBeDir() { + return syserror.ENOTDIR + } + mnt := rp.Mount() + if err := mnt.CheckBeginWrite(); err != nil { + return err + } + defer mnt.EndWrite() + vfsObj := rp.VirtualFilesystem() + if err := vfsObj.PrepareDeleteDentry(vfs.MountNamespaceFromContext(ctx), childVFSD); err != nil { + return err + } + parent.inode.impl.(*directory).childList.Remove(child) + child.inode.decLinksLocked() + vfsObj.CommitDeleteDentry(childVFSD) + return nil +} + +// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. +func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + _, err := resolveLocked(rp) + if err != nil { + return nil, err + } + // TODO(b/127675828): support extended attributes + return nil, syserror.ENOTSUP +} + +// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. +func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + _, err := resolveLocked(rp) + if err != nil { + return "", err + } + // TODO(b/127675828): support extended attributes + return "", syserror.ENOTSUP +} + +// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. +func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { + fs.mu.RLock() + defer fs.mu.RUnlock() + _, err := resolveLocked(rp) + if err != nil { + return err + } + // TODO(b/127675828): support extended attributes + return syserror.ENOTSUP +} + +// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. +func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { + fs.mu.RLock() + defer fs.mu.RUnlock() + _, err := resolveLocked(rp) + if err != nil { + return err + } + // TODO(b/127675828): support extended attributes + return syserror.ENOTSUP +} + +// PrependPath implements vfs.FilesystemImpl.PrependPath. +func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { + fs.mu.RLock() + defer fs.mu.RUnlock() + return vfs.GenericPrependPath(vfsroot, vd, b) +} diff --git a/pkg/sentry/fsimpl/memfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go index 91cb4b1fc..40bde54de 100644 --- a/pkg/sentry/fsimpl/memfs/named_pipe.go +++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package memfs +package tmpfs import ( "gvisor.dev/gvisor/pkg/abi/linux" @@ -55,8 +55,6 @@ func newNamedPipeFD(ctx context.Context, np *namedPipe, rp *vfs.ResolvingPath, v return nil, err } mnt := rp.Mount() - mnt.IncRef() - vfsd.IncRef() - fd.vfsfd.Init(&fd, mnt, vfsd) + fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{}) return &fd.vfsfd, nil } diff --git a/pkg/sentry/fsimpl/memfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go index 5bf527c80..70b42a6ec 100644 --- a/pkg/sentry/fsimpl/memfs/pipe_test.go +++ b/pkg/sentry/fsimpl/tmpfs/pipe_test.go @@ -12,13 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package memfs +package tmpfs import ( "bytes" "testing" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/context/contexttest" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -38,7 +39,7 @@ func TestSeparateFDs(t *testing.T) { pop := vfs.PathOperation{ Root: root, Start: root, - Pathname: fileName, + Path: fspath.Parse(fileName), FollowFinalSymlink: true, } rfdchan := make(chan *vfs.FileDescription) @@ -76,7 +77,7 @@ func TestNonblockingRead(t *testing.T) { pop := vfs.PathOperation{ Root: root, Start: root, - Pathname: fileName, + Path: fspath.Parse(fileName), FollowFinalSymlink: true, } openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY | linux.O_NONBLOCK} @@ -108,7 +109,7 @@ func TestNonblockingWriteError(t *testing.T) { pop := vfs.PathOperation{ Root: root, Start: root, - Pathname: fileName, + Path: fspath.Parse(fileName), FollowFinalSymlink: true, } openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY | linux.O_NONBLOCK} @@ -126,7 +127,7 @@ func TestSingleFD(t *testing.T) { pop := vfs.PathOperation{ Root: root, Start: root, - Pathname: fileName, + Path: fspath.Parse(fileName), FollowFinalSymlink: true, } openOpts := vfs.OpenOptions{Flags: linux.O_RDWR} @@ -151,8 +152,10 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy // Create VFS. vfsObj := vfs.New() - vfsObj.MustRegisterFilesystemType("memfs", FilesystemType{}) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{}) + vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + }) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) if err != nil { t.Fatalf("failed to create tmpfs root mount: %v", err) } @@ -160,10 +163,9 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy // Create the pipe. root := mntns.Root() pop := vfs.PathOperation{ - Root: root, - Start: root, - Pathname: fileName, - FollowFinalSymlink: true, + Root: root, + Start: root, + Path: fspath.Parse(fileName), } mknodOpts := vfs.MknodOptions{Mode: linux.ModeNamedPipe | 0644} if err := vfsObj.MknodAt(ctx, creds, &pop, &mknodOpts); err != nil { @@ -174,7 +176,7 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ Root: root, Start: root, - Pathname: fileName, + Path: fspath.Parse(fileName), FollowFinalSymlink: true, }, &vfs.StatOptions{}) if err != nil { diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go new file mode 100644 index 000000000..f51e247a7 --- /dev/null +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -0,0 +1,357 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tmpfs + +import ( + "io" + "math" + "sync" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/memmap" + "gvisor.dev/gvisor/pkg/sentry/pgalloc" + "gvisor.dev/gvisor/pkg/sentry/safemem" + "gvisor.dev/gvisor/pkg/sentry/usage" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +type regularFile struct { + inode inode + + // memFile is a platform.File used to allocate pages to this regularFile. + memFile *pgalloc.MemoryFile + + // mu protects the fields below. + mu sync.RWMutex + + // data maps offsets into the file to offsets into memFile that store + // the file's data. + data fsutil.FileRangeSet + + // size is the size of data, but accessed using atomic memory + // operations to avoid locking in inode.stat(). + size uint64 + + // seals represents file seals on this inode. + seals uint32 +} + +func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode { + file := ®ularFile{ + memFile: fs.memFile, + } + file.inode.init(file, fs, creds, mode) + file.inode.nlink = 1 // from parent directory + return &file.inode +} + +type regularFileFD struct { + fileDescription + + // These are immutable. + readable bool + writable bool + + // off is the file offset. off is accessed using atomic memory operations. + // offMu serializes operations that may mutate off. + off int64 + offMu sync.Mutex +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *regularFileFD) Release() { + if fd.writable { + fd.vfsfd.VirtualDentry().Mount().EndWrite() + } +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + if !fd.readable { + return 0, syserror.EINVAL + } + if offset < 0 { + return 0, syserror.EINVAL + } + if dst.NumBytes() == 0 { + return 0, nil + } + f := fd.inode().impl.(*regularFile) + rw := getRegularFileReadWriter(f, offset) + n, err := dst.CopyOutFrom(ctx, rw) + putRegularFileReadWriter(rw) + return int64(n), err +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + fd.offMu.Lock() + n, err := fd.PRead(ctx, dst, fd.off, opts) + fd.off += n + fd.offMu.Unlock() + return n, err +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + if !fd.writable { + return 0, syserror.EINVAL + } + if offset < 0 { + return 0, syserror.EINVAL + } + srclen := src.NumBytes() + if srclen == 0 { + return 0, nil + } + f := fd.inode().impl.(*regularFile) + end := offset + srclen + if end < offset { + // Overflow. + return 0, syserror.EFBIG + } + rw := getRegularFileReadWriter(f, offset) + n, err := src.CopyInTo(ctx, rw) + putRegularFileReadWriter(rw) + return n, err +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + fd.offMu.Lock() + n, err := fd.PWrite(ctx, src, fd.off, opts) + fd.off += n + fd.offMu.Unlock() + return n, err +} + +// Seek implements vfs.FileDescriptionImpl.Seek. +func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + fd.offMu.Lock() + defer fd.offMu.Unlock() + switch whence { + case linux.SEEK_SET: + // use offset as specified + case linux.SEEK_CUR: + offset += fd.off + case linux.SEEK_END: + offset += int64(atomic.LoadUint64(&fd.inode().impl.(*regularFile).size)) + default: + return 0, syserror.EINVAL + } + if offset < 0 { + return 0, syserror.EINVAL + } + fd.off = offset + return offset, nil +} + +// Sync implements vfs.FileDescriptionImpl.Sync. +func (fd *regularFileFD) Sync(ctx context.Context) error { + return nil +} + +// regularFileReadWriter implements safemem.Reader and Safemem.Writer. +type regularFileReadWriter struct { + file *regularFile + + // Offset into the file to read/write at. Note that this may be + // different from the FD offset if PRead/PWrite is used. + off uint64 +} + +var regularFileReadWriterPool = sync.Pool{ + New: func() interface{} { + return ®ularFileReadWriter{} + }, +} + +func getRegularFileReadWriter(file *regularFile, offset int64) *regularFileReadWriter { + rw := regularFileReadWriterPool.Get().(*regularFileReadWriter) + rw.file = file + rw.off = uint64(offset) + return rw +} + +func putRegularFileReadWriter(rw *regularFileReadWriter) { + rw.file = nil + regularFileReadWriterPool.Put(rw) +} + +// ReadToBlocks implements safemem.Reader.ReadToBlocks. +func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { + rw.file.mu.RLock() + + // Compute the range to read (limited by file size and overflow-checked). + if rw.off >= rw.file.size { + rw.file.mu.RUnlock() + return 0, io.EOF + } + end := rw.file.size + if rend := rw.off + dsts.NumBytes(); rend > rw.off && rend < end { + end = rend + } + + var done uint64 + seg, gap := rw.file.data.Find(uint64(rw.off)) + for rw.off < end { + mr := memmap.MappableRange{uint64(rw.off), uint64(end)} + switch { + case seg.Ok(): + // Get internal mappings. + ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read) + if err != nil { + rw.file.mu.RUnlock() + return done, err + } + + // Copy from internal mappings. + n, err := safemem.CopySeq(dsts, ims) + done += n + rw.off += uint64(n) + dsts = dsts.DropFirst64(n) + if err != nil { + rw.file.mu.RUnlock() + return done, err + } + + // Continue. + seg, gap = seg.NextNonEmpty() + + case gap.Ok(): + // Tmpfs holes are zero-filled. + gapmr := gap.Range().Intersect(mr) + dst := dsts.TakeFirst64(gapmr.Length()) + n, err := safemem.ZeroSeq(dst) + done += n + rw.off += uint64(n) + dsts = dsts.DropFirst64(n) + if err != nil { + rw.file.mu.RUnlock() + return done, err + } + + // Continue. + seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{} + } + } + rw.file.mu.RUnlock() + return done, nil +} + +// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. +func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { + rw.file.mu.Lock() + + // Compute the range to write (overflow-checked). + end := rw.off + srcs.NumBytes() + if end <= rw.off { + end = math.MaxInt64 + } + + // Check if seals prevent either file growth or all writes. + switch { + case rw.file.seals&linux.F_SEAL_WRITE != 0: // Write sealed + rw.file.mu.Unlock() + return 0, syserror.EPERM + case end > rw.file.size && rw.file.seals&linux.F_SEAL_GROW != 0: // Grow sealed + // When growth is sealed, Linux effectively allows writes which would + // normally grow the file to partially succeed up to the current EOF, + // rounded down to the page boundary before the EOF. + // + // This happens because writes (and thus the growth check) for tmpfs + // files proceed page-by-page on Linux, and the final write to the page + // containing EOF fails, resulting in a partial write up to the start of + // that page. + // + // To emulate this behaviour, artifically truncate the write to the + // start of the page containing the current EOF. + // + // See Linux, mm/filemap.c:generic_perform_write() and + // mm/shmem.c:shmem_write_begin(). + if pgstart := uint64(usermem.Addr(rw.file.size).RoundDown()); end > pgstart { + end = pgstart + } + if end <= rw.off { + // Truncation would result in no data being written. + rw.file.mu.Unlock() + return 0, syserror.EPERM + } + } + + // Page-aligned mr for when we need to allocate memory. RoundUp can't + // overflow since end is an int64. + pgstartaddr := usermem.Addr(rw.off).RoundDown() + pgendaddr, _ := usermem.Addr(end).RoundUp() + pgMR := memmap.MappableRange{uint64(pgstartaddr), uint64(pgendaddr)} + + var ( + done uint64 + retErr error + ) + seg, gap := rw.file.data.Find(uint64(rw.off)) + for rw.off < end { + mr := memmap.MappableRange{uint64(rw.off), uint64(end)} + switch { + case seg.Ok(): + // Get internal mappings. + ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Write) + if err != nil { + retErr = err + goto exitLoop + } + + // Copy to internal mappings. + n, err := safemem.CopySeq(ims, srcs) + done += n + rw.off += uint64(n) + srcs = srcs.DropFirst64(n) + if err != nil { + retErr = err + goto exitLoop + } + + // Continue. + seg, gap = seg.NextNonEmpty() + + case gap.Ok(): + // Allocate memory for the write. + gapMR := gap.Range().Intersect(pgMR) + fr, err := rw.file.memFile.Allocate(gapMR.Length(), usage.Tmpfs) + if err != nil { + retErr = err + goto exitLoop + } + + // Write to that memory as usual. + seg, gap = rw.file.data.Insert(gap, gapMR, fr.Start), fsutil.FileRangeGapIterator{} + } + } +exitLoop: + // If the write ends beyond the file's previous size, it causes the + // file to grow. + if rw.off > rw.file.size { + atomic.StoreUint64(&rw.file.size, rw.off) + } + + rw.file.mu.Unlock() + return done, retErr +} diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go new file mode 100644 index 000000000..3731c5b6f --- /dev/null +++ b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go @@ -0,0 +1,224 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tmpfs + +import ( + "bytes" + "fmt" + "io" + "testing" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" +) + +// newFileFD creates a new file in a new tmpfs mount, and returns the FD. If +// the returned err is not nil, then cleanup should be called when the FD is no +// longer needed. +func newFileFD(ctx context.Context, filename string) (*vfs.FileDescription, func(), error) { + creds := auth.CredentialsFromContext(ctx) + + vfsObj := vfs.New() + vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + }) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) + if err != nil { + return nil, nil, fmt.Errorf("failed to create tmpfs root mount: %v", err) + } + root := mntns.Root() + + // Create the file that will be write/read. + fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filename), + FollowFinalSymlink: true, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, + Mode: 0644, + }) + if err != nil { + root.DecRef() + mntns.DecRef(vfsObj) + return nil, nil, fmt.Errorf("failed to create file %q: %v", filename, err) + } + + return fd, func() { + root.DecRef() + mntns.DecRef(vfsObj) + }, nil +} + +// Test that we can write some data to a file and read it back.` +func TestSimpleWriteRead(t *testing.T) { + ctx := contexttest.Context(t) + fd, cleanup, err := newFileFD(ctx, "simpleReadWrite") + if err != nil { + t.Fatal(err) + } + defer cleanup() + + // Write. + data := []byte("foobarbaz") + n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{}) + if err != nil { + t.Fatalf("fd.Write failed: %v", err) + } + if n != int64(len(data)) { + t.Errorf("fd.Write got short write length %d, want %d", n, len(data)) + } + if got, want := fd.Impl().(*regularFileFD).off, int64(len(data)); got != want { + t.Errorf("fd.Write left offset at %d, want %d", got, want) + } + + // Seek back to beginning. + if _, err := fd.Seek(ctx, 0, linux.SEEK_SET); err != nil { + t.Fatalf("fd.Seek failed: %v", err) + } + if got, want := fd.Impl().(*regularFileFD).off, int64(0); got != want { + t.Errorf("fd.Seek(0) left offset at %d, want %d", got, want) + } + + // Read. + buf := make([]byte, len(data)) + n, err = fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) + if err != nil && err != io.EOF { + t.Fatalf("fd.Read failed: %v", err) + } + if n != int64(len(data)) { + t.Errorf("fd.Read got short read length %d, want %d", n, len(data)) + } + if got, want := string(buf), string(data); got != want { + t.Errorf("Read got %q want %s", got, want) + } + if got, want := fd.Impl().(*regularFileFD).off, int64(len(data)); got != want { + t.Errorf("fd.Write left offset at %d, want %d", got, want) + } +} + +func TestPWrite(t *testing.T) { + ctx := contexttest.Context(t) + fd, cleanup, err := newFileFD(ctx, "PRead") + if err != nil { + t.Fatal(err) + } + defer cleanup() + + // Fill file with 1k 'a's. + data := bytes.Repeat([]byte{'a'}, 1000) + n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{}) + if err != nil { + t.Fatalf("fd.Write failed: %v", err) + } + if n != int64(len(data)) { + t.Errorf("fd.Write got short write length %d, want %d", n, len(data)) + } + + // Write "gVisor is awesome" at various offsets. + buf := []byte("gVisor is awesome") + offsets := []int{0, 1, 2, 10, 20, 50, 100, len(data) - 100, len(data) - 1, len(data), len(data) + 1} + for _, offset := range offsets { + name := fmt.Sprintf("PWrite offset=%d", offset) + t.Run(name, func(t *testing.T) { + n, err := fd.PWrite(ctx, usermem.BytesIOSequence(buf), int64(offset), vfs.WriteOptions{}) + if err != nil { + t.Errorf("fd.PWrite got err %v want nil", err) + } + if n != int64(len(buf)) { + t.Errorf("fd.PWrite got %d bytes want %d", n, len(buf)) + } + + // Update data to reflect expected file contents. + if len(data) < offset+len(buf) { + data = append(data, make([]byte, (offset+len(buf))-len(data))...) + } + copy(data[offset:], buf) + + // Read the whole file and compare with data. + readBuf := make([]byte, len(data)) + n, err = fd.PRead(ctx, usermem.BytesIOSequence(readBuf), 0, vfs.ReadOptions{}) + if err != nil { + t.Fatalf("fd.PRead failed: %v", err) + } + if n != int64(len(data)) { + t.Errorf("fd.PRead got short read length %d, want %d", n, len(data)) + } + if got, want := string(readBuf), string(data); got != want { + t.Errorf("PRead got %q want %s", got, want) + } + + }) + } +} + +func TestPRead(t *testing.T) { + ctx := contexttest.Context(t) + fd, cleanup, err := newFileFD(ctx, "PRead") + if err != nil { + t.Fatal(err) + } + defer cleanup() + + // Write 100 sequences of 'gVisor is awesome'. + data := bytes.Repeat([]byte("gVisor is awsome"), 100) + n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{}) + if err != nil { + t.Fatalf("fd.Write failed: %v", err) + } + if n != int64(len(data)) { + t.Errorf("fd.Write got short write length %d, want %d", n, len(data)) + } + + // Read various sizes from various offsets. + sizes := []int{0, 1, 2, 10, 20, 50, 100, 1000} + offsets := []int{0, 1, 2, 10, 20, 50, 100, 1000, len(data) - 100, len(data) - 1, len(data), len(data) + 1} + + for _, size := range sizes { + for _, offset := range offsets { + name := fmt.Sprintf("PRead offset=%d size=%d", offset, size) + t.Run(name, func(t *testing.T) { + var ( + wantRead []byte + wantErr error + ) + if offset < len(data) { + wantRead = data[offset:] + } else if size > 0 { + wantErr = io.EOF + } + if offset+size < len(data) { + wantRead = wantRead[:size] + } + buf := make([]byte, size) + n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), int64(offset), vfs.ReadOptions{}) + if err != wantErr { + t.Errorf("fd.PRead got err %v want %v", err, wantErr) + } + if n != int64(len(wantRead)) { + t.Errorf("fd.PRead got %d bytes want %d", n, len(wantRead)) + } + if got := string(buf[:n]); got != string(wantRead) { + t.Errorf("fd.PRead got %q want %q", got, string(wantRead)) + } + }) + } + } +} diff --git a/pkg/sentry/fsimpl/memfs/symlink.go b/pkg/sentry/fsimpl/tmpfs/symlink.go index b2ac2cbeb..5246aca84 100644 --- a/pkg/sentry/fsimpl/memfs/symlink.go +++ b/pkg/sentry/fsimpl/tmpfs/symlink.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package memfs +package tmpfs import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" diff --git a/pkg/sentry/fsimpl/memfs/memfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 4cb2a4e0f..7be6faa5b 100644 --- a/pkg/sentry/fsimpl/memfs/memfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -12,29 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package memfs provides a filesystem implementation that behaves like tmpfs: +// Package tmpfs provides a filesystem implementation that behaves like tmpfs: // the Dentry tree is the sole source of truth for the state of the filesystem. // -// memfs is intended primarily to demonstrate filesystem implementation -// patterns. Real uses cases for an in-memory filesystem should use tmpfs -// instead. -// // Lock order: // // filesystem.mu // regularFileFD.offMu // regularFile.mu // inode.mu -package memfs +package tmpfs import ( "fmt" + "math" "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" ) @@ -46,6 +44,9 @@ type FilesystemType struct{} type filesystem struct { vfsfs vfs.Filesystem + // memFile is used to allocate pages to for regular files. + memFile *pgalloc.MemoryFile + // mu serializes changes to the Dentry tree. mu sync.RWMutex @@ -54,7 +55,13 @@ type filesystem struct { // GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { - var fs filesystem + memFileProvider := pgalloc.MemoryFileProviderFromContext(ctx) + if memFileProvider == nil { + panic("MemoryFileProviderFromContext returned nil") + } + fs := filesystem{ + memFile: memFileProvider.MemoryFile(), + } fs.vfsfs.Init(vfsObj, &fs) root := fs.newDentry(fs.newDirectory(creds, 01777)) return &fs.vfsfs, &root.vfsd, nil @@ -64,12 +71,6 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt func (fs *filesystem) Release() { } -// Sync implements vfs.FilesystemImpl.Sync. -func (fs *filesystem) Sync(ctx context.Context) error { - // All filesystem state is in-memory. - return nil -} - // dentry implements vfs.DentryImpl. type dentry struct { vfsd vfs.Dentry @@ -79,11 +80,11 @@ type dentry struct { // immutable. inode *inode - // memfs doesn't count references on dentries; because the dentry tree is + // tmpfs doesn't count references on dentries; because the dentry tree is // the sole source of truth, it is by definition always consistent with the // state of the filesystem. However, it does count references on inodes, // because inode resources are released when all references are dropped. - // (memfs doesn't really have resources to release, but we implement + // (tmpfs doesn't really have resources to release, but we implement // reference counting because tmpfs regular files will.) // dentryEntry (ugh) links dentries into their parent directory.childList. @@ -137,6 +138,8 @@ type inode struct { impl interface{} // immutable } +const maxLinks = math.MaxUint32 + func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode linux.FileMode) { i.refs = 1 i.mode = uint32(mode) @@ -147,25 +150,33 @@ func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, i.impl = impl } -// Preconditions: filesystem.mu must be locked for writing. +// incLinksLocked increments i's link count. +// +// Preconditions: filesystem.mu must be locked for writing. i.nlink != 0. +// i.nlink < maxLinks. func (i *inode) incLinksLocked() { - if atomic.AddUint32(&i.nlink, 1) <= 1 { - panic("memfs.inode.incLinksLocked() called with no existing links") + if i.nlink == 0 { + panic("tmpfs.inode.incLinksLocked() called with no existing links") } + if i.nlink == maxLinks { + panic("memfs.inode.incLinksLocked() called with maximum link count") + } + atomic.AddUint32(&i.nlink, 1) } -// Preconditions: filesystem.mu must be locked for writing. +// decLinksLocked decrements i's link count. +// +// Preconditions: filesystem.mu must be locked for writing. i.nlink != 0. func (i *inode) decLinksLocked() { - if nlink := atomic.AddUint32(&i.nlink, ^uint32(0)); nlink == 0 { - i.decRef() - } else if nlink == ^uint32(0) { // negative overflow - panic("memfs.inode.decLinksLocked() called with no existing links") + if i.nlink == 0 { + panic("tmpfs.inode.decLinksLocked() called with no existing links") } + atomic.AddUint32(&i.nlink, ^uint32(0)) } func (i *inode) incRef() { if atomic.AddInt64(&i.refs, 1) <= 1 { - panic("memfs.inode.incRef() called without holding a reference") + panic("tmpfs.inode.incRef() called without holding a reference") } } @@ -184,14 +195,14 @@ func (i *inode) tryIncRef() bool { func (i *inode) decRef() { if refs := atomic.AddInt64(&i.refs, -1); refs == 0 { // This is unnecessary; it's mostly to simulate what tmpfs would do. - if regfile, ok := i.impl.(*regularFile); ok { - regfile.mu.Lock() - regfile.data = nil - atomic.StoreInt64(®file.dataLen, 0) - regfile.mu.Unlock() + if regFile, ok := i.impl.(*regularFile); ok { + regFile.mu.Lock() + regFile.data.DropAll(regFile.memFile) + atomic.StoreUint64(®File.size, 0) + regFile.mu.Unlock() } } else if refs < 0 { - panic("memfs.inode.decRef() called without holding a reference") + panic("tmpfs.inode.decRef() called without holding a reference") } } @@ -215,7 +226,7 @@ func (i *inode) statTo(stat *linux.Statx) { case *regularFile: stat.Mode |= linux.S_IFREG stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS - stat.Size = uint64(atomic.LoadInt64(&impl.dataLen)) + stat.Size = uint64(atomic.LoadUint64(&impl.size)) // In tmpfs, this will be FileRangeSet.Span() / 512 (but also cached in // a uint64 accessed using atomic memory operations to avoid taking // locks). @@ -256,13 +267,11 @@ func (i *inode) direntType() uint8 { } } -// fileDescription is embedded by memfs implementations of +// fileDescription is embedded by tmpfs implementations of // vfs.FileDescriptionImpl. type fileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl - - flags uint32 // status flags; immutable } func (fd *fileDescription) filesystem() *filesystem { @@ -273,18 +282,6 @@ func (fd *fileDescription) inode() *inode { return fd.vfsfd.Dentry().Impl().(*dentry).inode } -// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags. -func (fd *fileDescription) StatusFlags(ctx context.Context) (uint32, error) { - return fd.flags, nil -} - -// SetStatusFlags implements vfs.FileDescriptionImpl.SetStatusFlags. -func (fd *fileDescription) SetStatusFlags(ctx context.Context, flags uint32) error { - // None of the flags settable by fcntl(F_SETFL) are supported, so this is a - // no-op. - return nil -} - // Stat implements vfs.FileDescriptionImpl.Stat. func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { var stat linux.Statx diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index bd3fb4c03..8653d2f63 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -762,7 +762,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, mounts.IncRef() } - tg := k.newThreadGroup(mounts, args.PIDNamespace, NewSignalHandlers(), linux.SIGCHLD, args.Limits, k.monotonicClock) + tg := k.NewThreadGroup(mounts, args.PIDNamespace, NewSignalHandlers(), linux.SIGCHLD, args.Limits) ctx := args.NewContext(k) // Get the root directory from the MountNamespace. @@ -1191,6 +1191,11 @@ func (k *Kernel) GlobalInit() *ThreadGroup { return k.globalInit } +// TestOnly_SetGlobalInit sets the thread group with ID 1 in the root PID namespace. +func (k *Kernel) TestOnly_SetGlobalInit(tg *ThreadGroup) { + k.globalInit = tg +} + // ApplicationCores returns the number of CPUs visible to sandboxed // applications. func (k *Kernel) ApplicationCores() uint { diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go index 24ea002ba..b14429854 100644 --- a/pkg/sentry/kernel/rseq.go +++ b/pkg/sentry/kernel/rseq.go @@ -15,17 +15,29 @@ package kernel import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/hostcpu" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" ) -// Restartable sequences, as described in https://lwn.net/Articles/650333/. +// Restartable sequences. +// +// We support two different APIs for restartable sequences. +// +// 1. The upstream interface added in v4.18. +// 2. The interface described in https://lwn.net/Articles/650333/. +// +// Throughout this file and other parts of the kernel, the latter is referred +// to as "old rseq". This interface was never merged upstream, but is supported +// for a limited set of applications that use it regardless. -// RSEQCriticalRegion describes a restartable sequence critical region. +// OldRSeqCriticalRegion describes an old rseq critical region. // // +stateify savable -type RSEQCriticalRegion struct { +type OldRSeqCriticalRegion struct { // When a task in this thread group has its CPU preempted (as defined by // platform.ErrContextCPUPreempted) or has a signal delivered to an // application handler while its instruction pointer is in CriticalSection, @@ -35,86 +47,359 @@ type RSEQCriticalRegion struct { Restart usermem.Addr } -// RSEQAvailable returns true if t supports restartable sequences. -func (t *Task) RSEQAvailable() bool { +// RSeqAvailable returns true if t supports (old and new) restartable sequences. +func (t *Task) RSeqAvailable() bool { return t.k.useHostCores && t.k.Platform.DetectsCPUPreemption() } -// RSEQCriticalRegion returns a copy of t's thread group's current restartable -// sequence. -func (t *Task) RSEQCriticalRegion() RSEQCriticalRegion { - return *t.tg.rscr.Load().(*RSEQCriticalRegion) +// SetRSeq registers addr as this thread's rseq structure. +// +// Preconditions: The caller must be running on the task goroutine. +func (t *Task) SetRSeq(addr usermem.Addr, length, signature uint32) error { + if t.rseqAddr != 0 { + if t.rseqAddr != addr { + return syserror.EINVAL + } + if t.rseqSignature != signature { + return syserror.EINVAL + } + return syserror.EBUSY + } + + // rseq must be aligned and correctly sized. + if addr&(linux.AlignOfRSeq-1) != 0 { + return syserror.EINVAL + } + if length != linux.SizeOfRSeq { + return syserror.EINVAL + } + if _, ok := t.MemoryManager().CheckIORange(addr, linux.SizeOfRSeq); !ok { + return syserror.EFAULT + } + + t.rseqAddr = addr + t.rseqSignature = signature + + // Initialize the CPUID. + // + // Linux implicitly does this on return from userspace, where failure + // would cause SIGSEGV. + if err := t.rseqUpdateCPU(); err != nil { + t.rseqAddr = 0 + t.rseqSignature = 0 + + t.Debugf("Failed to copy CPU to %#x for rseq: %v", t.rseqAddr, err) + t.forceSignal(linux.SIGSEGV, false /* unconditional */) + t.SendSignal(SignalInfoPriv(linux.SIGSEGV)) + return syserror.EFAULT + } + + return nil } -// SetRSEQCriticalRegion replaces t's thread group's restartable sequence. +// ClearRSeq unregisters addr as this thread's rseq structure. // -// Preconditions: t.RSEQAvailable() == true. -func (t *Task) SetRSEQCriticalRegion(rscr RSEQCriticalRegion) error { +// Preconditions: The caller must be running on the task goroutine. +func (t *Task) ClearRSeq(addr usermem.Addr, length, signature uint32) error { + if t.rseqAddr == 0 { + return syserror.EINVAL + } + if t.rseqAddr != addr { + return syserror.EINVAL + } + if length != linux.SizeOfRSeq { + return syserror.EINVAL + } + if t.rseqSignature != signature { + return syserror.EPERM + } + + if err := t.rseqClearCPU(); err != nil { + return err + } + + t.rseqAddr = 0 + t.rseqSignature = 0 + + if t.oldRSeqCPUAddr == 0 { + // rseqCPU no longer needed. + t.rseqCPU = -1 + } + + return nil +} + +// OldRSeqCriticalRegion returns a copy of t's thread group's current +// old restartable sequence. +func (t *Task) OldRSeqCriticalRegion() OldRSeqCriticalRegion { + return *t.tg.oldRSeqCritical.Load().(*OldRSeqCriticalRegion) +} + +// SetOldRSeqCriticalRegion replaces t's thread group's old restartable +// sequence. +// +// Preconditions: t.RSeqAvailable() == true. +func (t *Task) SetOldRSeqCriticalRegion(r OldRSeqCriticalRegion) error { // These checks are somewhat more lenient than in Linux, which (bizarrely) - // requires rscr.CriticalSection to be non-empty and rscr.Restart to be - // outside of rscr.CriticalSection, even if rscr.CriticalSection.Start == 0 + // requires r.CriticalSection to be non-empty and r.Restart to be + // outside of r.CriticalSection, even if r.CriticalSection.Start == 0 // (which disables the critical region). - if rscr.CriticalSection.Start == 0 { - rscr.CriticalSection.End = 0 - rscr.Restart = 0 - t.tg.rscr.Store(&rscr) + if r.CriticalSection.Start == 0 { + r.CriticalSection.End = 0 + r.Restart = 0 + t.tg.oldRSeqCritical.Store(&r) return nil } - if rscr.CriticalSection.Start >= rscr.CriticalSection.End { + if r.CriticalSection.Start >= r.CriticalSection.End { return syserror.EINVAL } - if rscr.CriticalSection.Contains(rscr.Restart) { + if r.CriticalSection.Contains(r.Restart) { return syserror.EINVAL } - // TODO(jamieliu): check that rscr.CriticalSection and rscr.Restart are in - // the application address range, for consistency with Linux - t.tg.rscr.Store(&rscr) + // TODO(jamieliu): check that r.CriticalSection and r.Restart are in + // the application address range, for consistency with Linux. + t.tg.oldRSeqCritical.Store(&r) return nil } -// RSEQCPUAddr returns the address that RSEQ will keep updated with t's CPU -// number. +// OldRSeqCPUAddr returns the address that old rseq will keep updated with t's +// CPU number. // // Preconditions: The caller must be running on the task goroutine. -func (t *Task) RSEQCPUAddr() usermem.Addr { - return t.rseqCPUAddr +func (t *Task) OldRSeqCPUAddr() usermem.Addr { + return t.oldRSeqCPUAddr } -// SetRSEQCPUAddr replaces the address that RSEQ will keep updated with t's CPU -// number. +// SetOldRSeqCPUAddr replaces the address that old rseq will keep updated with +// t's CPU number. // -// Preconditions: t.RSEQAvailable() == true. The caller must be running on the +// Preconditions: t.RSeqAvailable() == true. The caller must be running on the // task goroutine. t's AddressSpace must be active. -func (t *Task) SetRSEQCPUAddr(addr usermem.Addr) error { - t.rseqCPUAddr = addr - if addr != 0 { - t.rseqCPU = int32(hostcpu.GetCPU()) - if err := t.rseqCopyOutCPU(); err != nil { - t.rseqCPUAddr = 0 - t.rseqCPU = -1 - return syserror.EINVAL // yes, EINVAL, not err or EFAULT - } - } else { - t.rseqCPU = -1 +func (t *Task) SetOldRSeqCPUAddr(addr usermem.Addr) error { + t.oldRSeqCPUAddr = addr + + // Check that addr is writable. + // + // N.B. rseqUpdateCPU may fail on a bad t.rseqAddr as well. That's + // unfortunate, but unlikely in a correct program. + if err := t.rseqUpdateCPU(); err != nil { + t.oldRSeqCPUAddr = 0 + return syserror.EINVAL // yes, EINVAL, not err or EFAULT } return nil } // Preconditions: The caller must be running on the task goroutine. t's // AddressSpace must be active. -func (t *Task) rseqCopyOutCPU() error { +func (t *Task) rseqUpdateCPU() error { + if t.rseqAddr == 0 && t.oldRSeqCPUAddr == 0 { + t.rseqCPU = -1 + return nil + } + + t.rseqCPU = int32(hostcpu.GetCPU()) + + // Update both CPUs, even if one fails. + rerr := t.rseqCopyOutCPU() + oerr := t.oldRSeqCopyOutCPU() + + if rerr != nil { + return rerr + } + return oerr +} + +// Preconditions: The caller must be running on the task goroutine. t's +// AddressSpace must be active. +func (t *Task) oldRSeqCopyOutCPU() error { + if t.oldRSeqCPUAddr == 0 { + return nil + } + buf := t.CopyScratchBuffer(4) usermem.ByteOrder.PutUint32(buf, uint32(t.rseqCPU)) - _, err := t.CopyOutBytes(t.rseqCPUAddr, buf) + _, err := t.CopyOutBytes(t.oldRSeqCPUAddr, buf) + return err +} + +// Preconditions: The caller must be running on the task goroutine. t's +// AddressSpace must be active. +func (t *Task) rseqCopyOutCPU() error { + if t.rseqAddr == 0 { + return nil + } + + buf := t.CopyScratchBuffer(8) + // CPUIDStart and CPUID are the first two fields in linux.RSeq. + usermem.ByteOrder.PutUint32(buf, uint32(t.rseqCPU)) // CPUIDStart + usermem.ByteOrder.PutUint32(buf[4:], uint32(t.rseqCPU)) // CPUID + // N.B. This write is not atomic, but since this occurs on the task + // goroutine then as long as userspace uses a single-instruction read + // it can't see an invalid value. + _, err := t.CopyOutBytes(t.rseqAddr, buf) + return err +} + +// Preconditions: The caller must be running on the task goroutine. t's +// AddressSpace must be active. +func (t *Task) rseqClearCPU() error { + buf := t.CopyScratchBuffer(8) + // CPUIDStart and CPUID are the first two fields in linux.RSeq. + usermem.ByteOrder.PutUint32(buf, 0) // CPUIDStart + usermem.ByteOrder.PutUint32(buf[4:], linux.RSEQ_CPU_ID_UNINITIALIZED) // CPUID + // N.B. This write is not atomic, but since this occurs on the task + // goroutine then as long as userspace uses a single-instruction read + // it can't see an invalid value. + _, err := t.CopyOutBytes(t.rseqAddr, buf) return err } +// rseqAddrInterrupt checks if IP is in a critical section, and aborts if so. +// +// This is a bit complex since both the RSeq and RSeqCriticalSection structs +// are stored in userspace. So we must: +// +// 1. Copy in the address of RSeqCriticalSection from RSeq. +// 2. Copy in RSeqCriticalSection itself. +// 3. Validate critical section struct version, address range, abort address. +// 4. Validate the abort signature (4 bytes preceding abort IP match expected +// signature). +// 5. Clear address of RSeqCriticalSection from RSeq. +// 6. Finally, conditionally abort. +// +// See kernel/rseq.c:rseq_ip_fixup for reference. +// +// Preconditions: The caller must be running on the task goroutine. t's +// AddressSpace must be active. +func (t *Task) rseqAddrInterrupt() { + if t.rseqAddr == 0 { + return + } + + critAddrAddr, ok := t.rseqAddr.AddLength(linux.OffsetOfRSeqCriticalSection) + if !ok { + // SetRSeq should validate this. + panic(fmt.Sprintf("t.rseqAddr (%#x) not large enough", t.rseqAddr)) + } + + if t.Arch().Width() != 8 { + // We only handle 64-bit for now. + t.Debugf("Only 64-bit rseq supported.") + t.forceSignal(linux.SIGSEGV, false /* unconditional */) + t.SendSignal(SignalInfoPriv(linux.SIGSEGV)) + return + } + + buf := t.CopyScratchBuffer(8) + if _, err := t.CopyInBytes(critAddrAddr, buf); err != nil { + t.Debugf("Failed to copy critical section address from %#x for rseq: %v", critAddrAddr, err) + t.forceSignal(linux.SIGSEGV, false /* unconditional */) + t.SendSignal(SignalInfoPriv(linux.SIGSEGV)) + return + } + + critAddr := usermem.Addr(usermem.ByteOrder.Uint64(buf)) + if critAddr == 0 { + return + } + + buf = t.CopyScratchBuffer(linux.SizeOfRSeqCriticalSection) + if _, err := t.CopyInBytes(critAddr, buf); err != nil { + t.Debugf("Failed to copy critical section from %#x for rseq: %v", critAddr, err) + t.forceSignal(linux.SIGSEGV, false /* unconditional */) + t.SendSignal(SignalInfoPriv(linux.SIGSEGV)) + return + } + + // Manually marshal RSeqCriticalSection as this is in the hot path when + // rseq is enabled. It must be as fast as possible. + // + // TODO(b/130243041): Replace with go_marshal. + cs := linux.RSeqCriticalSection{ + Version: usermem.ByteOrder.Uint32(buf[0:4]), + Flags: usermem.ByteOrder.Uint32(buf[4:8]), + Start: usermem.ByteOrder.Uint64(buf[8:16]), + PostCommitOffset: usermem.ByteOrder.Uint64(buf[16:24]), + Abort: usermem.ByteOrder.Uint64(buf[24:32]), + } + + if cs.Version != 0 { + t.Debugf("Unknown version in %+v", cs) + t.forceSignal(linux.SIGSEGV, false /* unconditional */) + t.SendSignal(SignalInfoPriv(linux.SIGSEGV)) + return + } + + start := usermem.Addr(cs.Start) + critRange, ok := start.ToRange(cs.PostCommitOffset) + if !ok { + t.Debugf("Invalid start and offset in %+v", cs) + t.forceSignal(linux.SIGSEGV, false /* unconditional */) + t.SendSignal(SignalInfoPriv(linux.SIGSEGV)) + return + } + + abort := usermem.Addr(cs.Abort) + if critRange.Contains(abort) { + t.Debugf("Abort in critical section in %+v", cs) + t.forceSignal(linux.SIGSEGV, false /* unconditional */) + t.SendSignal(SignalInfoPriv(linux.SIGSEGV)) + return + } + + // Verify signature. + sigAddr := abort - linux.SizeOfRSeqSignature + + buf = t.CopyScratchBuffer(linux.SizeOfRSeqSignature) + if _, err := t.CopyInBytes(sigAddr, buf); err != nil { + t.Debugf("Failed to copy critical section signature from %#x for rseq: %v", sigAddr, err) + t.forceSignal(linux.SIGSEGV, false /* unconditional */) + t.SendSignal(SignalInfoPriv(linux.SIGSEGV)) + return + } + + sig := usermem.ByteOrder.Uint32(buf) + if sig != t.rseqSignature { + t.Debugf("Mismatched rseq signature %d != %d", sig, t.rseqSignature) + t.forceSignal(linux.SIGSEGV, false /* unconditional */) + t.SendSignal(SignalInfoPriv(linux.SIGSEGV)) + return + } + + // Clear the critical section address. + // + // NOTE(b/143949567): We don't support any rseq flags, so we always + // restart if we are in the critical section, and thus *always* clear + // critAddrAddr. + if _, err := t.MemoryManager().ZeroOut(t, critAddrAddr, int64(t.Arch().Width()), usermem.IOOpts{ + AddressSpaceActive: true, + }); err != nil { + t.Debugf("Failed to clear critical section address from %#x for rseq: %v", critAddrAddr, err) + t.forceSignal(linux.SIGSEGV, false /* unconditional */) + t.SendSignal(SignalInfoPriv(linux.SIGSEGV)) + return + } + + // Finally we can actually decide whether or not to restart. + if !critRange.Contains(usermem.Addr(t.Arch().IP())) { + return + } + + t.Arch().SetIP(uintptr(cs.Abort)) +} + // Preconditions: The caller must be running on the task goroutine. -func (t *Task) rseqInterrupt() { - rscr := t.tg.rscr.Load().(*RSEQCriticalRegion) - if ip := t.Arch().IP(); rscr.CriticalSection.Contains(usermem.Addr(ip)) { - t.Debugf("Interrupted RSEQ critical section at %#x; restarting at %#x", ip, rscr.Restart) - t.Arch().SetIP(uintptr(rscr.Restart)) - t.Arch().SetRSEQInterruptedIP(ip) +func (t *Task) oldRSeqInterrupt() { + r := t.tg.oldRSeqCritical.Load().(*OldRSeqCriticalRegion) + if ip := t.Arch().IP(); r.CriticalSection.Contains(usermem.Addr(ip)) { + t.Debugf("Interrupted rseq critical section at %#x; restarting at %#x", ip, r.Restart) + t.Arch().SetIP(uintptr(r.Restart)) + t.Arch().SetOldRSeqInterruptedIP(ip) } } + +// Preconditions: The caller must be running on the task goroutine. +func (t *Task) rseqInterrupt() { + t.rseqAddrInterrupt() + t.oldRSeqInterrupt() +} diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index 5bd610f68..19034a21e 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -71,9 +71,20 @@ type Registry struct { mu sync.Mutex `state:"nosave"` // shms maps segment ids to segments. + // + // shms holds all referenced segments, which are removed on the last + // DecRef. Thus, it cannot itself hold a reference on the Shm. + // + // Since removal only occurs after the last (unlocked) DecRef, there + // exists a short window during which a Shm still exists in Shm, but is + // unreferenced. Users must use TryIncRef to determine if the Shm is + // still valid. shms map[ID]*Shm // keysToShms maps segment keys to segments. + // + // Shms in keysToShms are guaranteed to be referenced, as they are + // removed by disassociateKey before the last DecRef. keysToShms map[Key]*Shm // Sum of the sizes of all existing segments rounded up to page size, in @@ -95,10 +106,18 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry { } // FindByID looks up a segment given an ID. +// +// FindByID returns a reference on Shm. func (r *Registry) FindByID(id ID) *Shm { r.mu.Lock() defer r.mu.Unlock() - return r.shms[id] + s := r.shms[id] + // Take a reference on s. If TryIncRef fails, s has reached the last + // DecRef, but hasn't quite been removed from r.shms yet. + if s != nil && s.TryIncRef() { + return s + } + return nil } // dissociateKey removes the association between a segment and its key, @@ -119,6 +138,8 @@ func (r *Registry) dissociateKey(s *Shm) { // FindOrCreate looks up or creates a segment in the registry. It's functionally // analogous to open(2). +// +// FindOrCreate returns a reference on Shm. func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size uint64, mode linux.FileMode, private, create, exclusive bool) (*Shm, error) { if (create || private) && (size < linux.SHMMIN || size > linux.SHMMAX) { // "A new segment was to be created and size is less than SHMMIN or @@ -166,6 +187,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size ui return nil, syserror.EEXIST } + shm.IncRef() return shm, nil } @@ -193,7 +215,14 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size ui // Need to create a new segment. creator := fs.FileOwnerFromContext(ctx) perms := fs.FilePermsFromMode(mode) - return r.newShm(ctx, pid, key, creator, perms, size) + s, err := r.newShm(ctx, pid, key, creator, perms, size) + if err != nil { + return nil, err + } + // The initial reference is held by s itself. Take another to return to + // the caller. + s.IncRef() + return s, nil } // newShm creates a new segment in the registry. @@ -296,22 +325,26 @@ func (r *Registry) remove(s *Shm) { // Shm represents a single shared memory segment. // -// Shm segment are backed directly by an allocation from platform -// memory. Segments are always mapped as a whole, greatly simplifying how -// mappings are tracked. However note that mremap and munmap calls may cause the -// vma for a segment to become fragmented; which requires special care when -// unmapping a segment. See mm/shm.go. +// Shm segment are backed directly by an allocation from platform memory. +// Segments are always mapped as a whole, greatly simplifying how mappings are +// tracked. However note that mremap and munmap calls may cause the vma for a +// segment to become fragmented; which requires special care when unmapping a +// segment. See mm/shm.go. // // Segments persist until they are explicitly marked for destruction via -// shmctl(SHM_RMID). +// MarkDestroyed(). // // Shm implements memmap.Mappable and memmap.MappingIdentity. // // +stateify savable type Shm struct { - // AtomicRefCount tracks the number of references to this segment from - // maps. A segment always holds a reference to itself, until it's marked for + // AtomicRefCount tracks the number of references to this segment. + // + // A segment holds a reference to itself until it is marked for // destruction. + // + // In addition to direct users, the MemoryManager will hold references + // via MappingIdentity. refs.AtomicRefCount mfp pgalloc.MemoryFileProvider @@ -484,9 +517,8 @@ type AttachOpts struct { // ConfigureAttach creates an mmap configuration for the segment with the // requested attach options. // -// ConfigureAttach returns with a ref on s on success. The caller should drop -// this once the map is installed. This reference prevents s from being -// destroyed before the returned configuration is used. +// Postconditions: The returned MMapOpts are valid only as long as a reference +// continues to be held on s. func (s *Shm) ConfigureAttach(ctx context.Context, addr usermem.Addr, opts AttachOpts) (memmap.MMapOpts, error) { s.mu.Lock() defer s.mu.Unlock() @@ -504,7 +536,6 @@ func (s *Shm) ConfigureAttach(ctx context.Context, addr usermem.Addr, opts Attac // in the user namespace that governs its IPC namespace." - man shmat(2) return memmap.MMapOpts{}, syserror.EACCES } - s.IncRef() return memmap.MMapOpts{ Length: s.size, Offset: 0, @@ -549,10 +580,15 @@ func (s *Shm) IPCStat(ctx context.Context) (*linux.ShmidDS, error) { } creds := auth.CredentialsFromContext(ctx) - nattach := uint64(s.ReadRefs()) - // Don't report the self-reference we keep prior to being marked for - // destruction. However, also don't report a count of -1 for segments marked - // as destroyed, with no mappings. + // Use the reference count as a rudimentary count of the number of + // attaches. We exclude: + // + // 1. The reference the caller holds. + // 2. The self-reference held by s prior to destruction. + // + // Note that this may still overcount by including transient references + // used in concurrent calls. + nattach := uint64(s.ReadRefs()) - 1 if !s.pendingDestruction { nattach-- } @@ -620,18 +656,17 @@ func (s *Shm) MarkDestroyed() { s.registry.dissociateKey(s) s.mu.Lock() - // Only drop the segment's self-reference once, when destruction is - // requested. Otherwise, repeated calls to shmctl(IPC_RMID) would force a - // segment to be destroyed prematurely, potentially with active maps to the - // segment's address range. Remaining references are dropped when the - // segment is detached or unmaped. + defer s.mu.Unlock() if !s.pendingDestruction { s.pendingDestruction = true - s.mu.Unlock() // Must release s.mu before calling s.DecRef. + // Drop the self-reference so destruction occurs when all + // external references are gone. + // + // N.B. This cannot be the final DecRef, as the caller also + // holds a reference. s.DecRef() return } - s.mu.Unlock() } // checkOwnership verifies whether a segment may be accessed by ctx as an diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index ab0c6c4aa..d25a7903b 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -489,18 +489,43 @@ type Task struct { // netns is protected by mu. netns is owned by the task goroutine. netns bool - // If rseqPreempted is true, before the next call to p.Switch(), interrupt - // RSEQ critical regions as defined by tg.rseq and write the task - // goroutine's CPU number to rseqCPUAddr. rseqCPU is the last CPU number - // written to rseqCPUAddr. + // If rseqPreempted is true, before the next call to p.Switch(), + // interrupt rseq critical regions as defined by rseqAddr and + // tg.oldRSeqCritical and write the task goroutine's CPU number to + // rseqAddr/oldRSeqCPUAddr. // - // If rseqCPUAddr is 0, rseqCPU is -1. + // We support two ABIs for restartable sequences: // - // rseqCPUAddr, rseqCPU, and rseqPreempted are exclusive to the task - // goroutine. + // 1. The upstream interface added in v4.18, + // 2. An "old" interface never merged upstream. In the implementation, + // this is referred to as "old rseq". + // + // rseqPreempted is exclusive to the task goroutine. rseqPreempted bool `state:"nosave"` - rseqCPUAddr usermem.Addr - rseqCPU int32 + + // rseqCPU is the last CPU number written to rseqAddr/oldRSeqCPUAddr. + // + // If rseq is unused, rseqCPU is -1 for convenient use in + // platform.Context.Switch. + // + // rseqCPU is exclusive to the task goroutine. + rseqCPU int32 + + // oldRSeqCPUAddr is a pointer to the userspace old rseq CPU variable. + // + // oldRSeqCPUAddr is exclusive to the task goroutine. + oldRSeqCPUAddr usermem.Addr + + // rseqAddr is a pointer to the userspace linux.RSeq structure. + // + // rseqAddr is exclusive to the task goroutine. + rseqAddr usermem.Addr + + // rseqSignature is the signature that the rseq abort IP must be signed + // with. + // + // rseqSignature is exclusive to the task goroutine. + rseqSignature uint32 // copyScratchBuffer is a buffer available to CopyIn/CopyOut // implementations that require an intermediate buffer to copy data diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index 3eadfedb4..247bd4aba 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -236,14 +236,19 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { } else if opts.NewPIDNamespace { pidns = pidns.NewChild(userns) } + tg := t.tg + rseqAddr := usermem.Addr(0) + rseqSignature := uint32(0) if opts.NewThreadGroup { tg.mounts.IncRef() sh := t.tg.signalHandlers if opts.NewSignalHandlers { sh = sh.Fork() } - tg = t.k.newThreadGroup(tg.mounts, pidns, sh, opts.TerminationSignal, tg.limits.GetCopy(), t.k.monotonicClock) + tg = t.k.NewThreadGroup(tg.mounts, pidns, sh, opts.TerminationSignal, tg.limits.GetCopy()) + rseqAddr = t.rseqAddr + rseqSignature = t.rseqSignature } cfg := &TaskConfig{ @@ -260,6 +265,8 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { UTSNamespace: utsns, IPCNamespace: ipcns, AbstractSocketNamespace: t.abstractSockets, + RSeqAddr: rseqAddr, + RSeqSignature: rseqSignature, ContainerID: t.ContainerID(), } if opts.NewThreadGroup { diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index 90a6190f1..fa6528386 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -190,9 +190,11 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { t.updateRSSLocked() // Restartable sequence state is discarded. t.rseqPreempted = false - t.rseqCPUAddr = 0 t.rseqCPU = -1 - t.tg.rscr.Store(&RSEQCriticalRegion{}) + t.rseqAddr = 0 + t.rseqSignature = 0 + t.oldRSeqCPUAddr = 0 + t.tg.oldRSeqCritical.Store(&OldRSeqCriticalRegion{}) t.tg.pidns.owner.mu.Unlock() // Remove FDs with the CloseOnExec flag set. diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go index d97f8c189..6357273d3 100644 --- a/pkg/sentry/kernel/task_run.go +++ b/pkg/sentry/kernel/task_run.go @@ -169,12 +169,22 @@ func (*runApp) execute(t *Task) taskRunState { // Apply restartable sequences. if t.rseqPreempted { t.rseqPreempted = false - if t.rseqCPUAddr != 0 { + if t.rseqAddr != 0 || t.oldRSeqCPUAddr != 0 { + // Linux writes the CPU on every preemption. We only do + // so if it changed. Thus we may delay delivery of + // SIGSEGV if rseqAddr/oldRSeqCPUAddr is invalid. cpu := int32(hostcpu.GetCPU()) if t.rseqCPU != cpu { t.rseqCPU = cpu if err := t.rseqCopyOutCPU(); err != nil { - t.Warningf("Failed to copy CPU to %#x for RSEQ: %v", t.rseqCPUAddr, err) + t.Debugf("Failed to copy CPU to %#x for rseq: %v", t.rseqAddr, err) + t.forceSignal(linux.SIGSEGV, false) + t.SendSignal(SignalInfoPriv(linux.SIGSEGV)) + // Re-enter the task run loop for signal delivery. + return (*runApp)(nil) + } + if err := t.oldRSeqCopyOutCPU(); err != nil { + t.Debugf("Failed to copy CPU to %#x for old rseq: %v", t.oldRSeqCPUAddr, err) t.forceSignal(linux.SIGSEGV, false) t.SendSignal(SignalInfoPriv(linux.SIGSEGV)) // Re-enter the task run loop for signal delivery. @@ -320,7 +330,7 @@ func (*runApp) execute(t *Task) taskRunState { return (*runApp)(nil) case platform.ErrContextCPUPreempted: - // Ensure that RSEQ critical sections are interrupted and per-thread + // Ensure that rseq critical sections are interrupted and per-thread // CPU values are updated before the next platform.Context.Switch(). t.rseqPreempted = true return (*runApp)(nil) diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index 3522a4ae5..58af16ee2 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/sentry/kernel/sched" "gvisor.dev/gvisor/pkg/sentry/usage" + "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" ) @@ -79,6 +80,13 @@ type TaskConfig struct { // AbstractSocketNamespace is the AbstractSocketNamespace of the new task. AbstractSocketNamespace *AbstractSocketNamespace + // RSeqAddr is a pointer to the the userspace linux.RSeq structure. + RSeqAddr usermem.Addr + + // RSeqSignature is the signature that the rseq abort IP must be signed + // with. + RSeqSignature uint32 + // ContainerID is the container the new task belongs to. ContainerID string } @@ -126,6 +134,8 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { ipcns: cfg.IPCNamespace, abstractSockets: cfg.AbstractSocketNamespace, rseqCPU: -1, + rseqAddr: cfg.RSeqAddr, + rseqSignature: cfg.RSeqSignature, futexWaiter: futex.NewWaiter(), containerID: cfg.ContainerID, } diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go index 72568d296..c0197a563 100644 --- a/pkg/sentry/kernel/thread_group.go +++ b/pkg/sentry/kernel/thread_group.go @@ -238,8 +238,8 @@ type ThreadGroup struct { // execed is protected by the TaskSet mutex. execed bool - // rscr is the thread group's RSEQ critical region. - rscr atomic.Value `state:".(*RSEQCriticalRegion)"` + // oldRSeqCritical is the thread group's old rseq critical region. + oldRSeqCritical atomic.Value `state:".(*OldRSeqCriticalRegion)"` // mounts is the thread group's mount namespace. This does not really // correspond to a "mount namespace" in Linux, but is more like a @@ -256,35 +256,35 @@ type ThreadGroup struct { tty *TTY } -// newThreadGroup returns a new, empty thread group in PID namespace ns. The +// NewThreadGroup returns a new, empty thread group in PID namespace ns. The // thread group leader will send its parent terminationSignal when it exits. // The new thread group isn't visible to the system until a task has been // created inside of it by a successful call to TaskSet.NewTask. -func (k *Kernel) newThreadGroup(mounts *fs.MountNamespace, ns *PIDNamespace, sh *SignalHandlers, terminationSignal linux.Signal, limits *limits.LimitSet, monotonicClock *timekeeperClock) *ThreadGroup { +func (k *Kernel) NewThreadGroup(mntns *fs.MountNamespace, pidns *PIDNamespace, sh *SignalHandlers, terminationSignal linux.Signal, limits *limits.LimitSet) *ThreadGroup { tg := &ThreadGroup{ threadGroupNode: threadGroupNode{ - pidns: ns, + pidns: pidns, }, signalHandlers: sh, terminationSignal: terminationSignal, ioUsage: &usage.IO{}, limits: limits, - mounts: mounts, + mounts: mntns, } tg.itimerRealTimer = ktime.NewTimer(k.monotonicClock, &itimerRealListener{tg: tg}) tg.timers = make(map[linux.TimerID]*IntervalTimer) - tg.rscr.Store(&RSEQCriticalRegion{}) + tg.oldRSeqCritical.Store(&OldRSeqCriticalRegion{}) return tg } -// saveRscr is invoked by stateify. -func (tg *ThreadGroup) saveRscr() *RSEQCriticalRegion { - return tg.rscr.Load().(*RSEQCriticalRegion) +// saveOldRSeqCritical is invoked by stateify. +func (tg *ThreadGroup) saveOldRSeqCritical() *OldRSeqCriticalRegion { + return tg.oldRSeqCritical.Load().(*OldRSeqCriticalRegion) } -// loadRscr is invoked by stateify. -func (tg *ThreadGroup) loadRscr(rscr *RSEQCriticalRegion) { - tg.rscr.Store(rscr) +// loadOldRSeqCritical is invoked by stateify. +func (tg *ThreadGroup) loadOldRSeqCritical(r *OldRSeqCriticalRegion) { + tg.oldRSeqCritical.Store(r) } // SignalHandlers returns the signal handlers used by tg. diff --git a/pkg/sentry/mm/procfs.go b/pkg/sentry/mm/procfs.go index 8c2246bb4..79610acb7 100644 --- a/pkg/sentry/mm/procfs.go +++ b/pkg/sentry/mm/procfs.go @@ -66,8 +66,6 @@ func (mm *MemoryManager) ReadMapsDataInto(ctx context.Context, buf *bytes.Buffer var start usermem.Addr for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() { - // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get - // "panic: autosave error: type usermem.Addr is not registered". mm.appendVMAMapsEntryLocked(ctx, vseg, buf) } @@ -81,7 +79,6 @@ func (mm *MemoryManager) ReadMapsDataInto(ctx context.Context, buf *bytes.Buffer // // Artifically adjust the seqfile handle so we only output vsyscall entry once. if start != vsyscallEnd { - // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd. buf.WriteString(vsyscallMapsEntry) } } @@ -97,8 +94,6 @@ func (mm *MemoryManager) ReadMapsSeqFileData(ctx context.Context, handle seqfile start = *handle.(*usermem.Addr) } for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() { - // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get - // "panic: autosave error: type usermem.Addr is not registered". vmaAddr := vseg.End() data = append(data, seqfile.SeqData{ Buf: mm.vmaMapsEntryLocked(ctx, vseg), @@ -116,7 +111,6 @@ func (mm *MemoryManager) ReadMapsSeqFileData(ctx context.Context, handle seqfile // // Artifically adjust the seqfile handle so we only output vsyscall entry once. if start != vsyscallEnd { - // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd. vmaAddr := vsyscallEnd data = append(data, seqfile.SeqData{ Buf: []byte(vsyscallMapsEntry), @@ -187,15 +181,12 @@ func (mm *MemoryManager) ReadSmapsDataInto(ctx context.Context, buf *bytes.Buffe var start usermem.Addr for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() { - // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get - // "panic: autosave error: type usermem.Addr is not registered". mm.vmaSmapsEntryIntoLocked(ctx, vseg, buf) } // We always emulate vsyscall, so advertise it here. See // ReadMapsSeqFileData for additional commentary. if start != vsyscallEnd { - // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd. buf.WriteString(vsyscallSmapsEntry) } } @@ -211,8 +202,6 @@ func (mm *MemoryManager) ReadSmapsSeqFileData(ctx context.Context, handle seqfil start = *handle.(*usermem.Addr) } for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() { - // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get - // "panic: autosave error: type usermem.Addr is not registered". vmaAddr := vseg.End() data = append(data, seqfile.SeqData{ Buf: mm.vmaSmapsEntryLocked(ctx, vseg), @@ -223,7 +212,6 @@ func (mm *MemoryManager) ReadSmapsSeqFileData(ctx context.Context, handle seqfil // We always emulate vsyscall, so advertise it here. See // ReadMapsSeqFileData for additional commentary. if start != vsyscallEnd { - // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd. vmaAddr := vsyscallEnd data = append(data, seqfile.SeqData{ Buf: []byte(vsyscallSmapsEntry), diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 5c52d4007..f3afd98da 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -34,6 +34,8 @@ go_library( "machine_arm64_unsafe.go", "machine_unsafe.go", "physical_map.go", + "physical_map_amd64.go", + "physical_map_arm64.go", "virtual_map.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/platform/kvm", diff --git a/pkg/sentry/platform/kvm/physical_map.go b/pkg/sentry/platform/kvm/physical_map.go index 586e91bb2..91de5dab1 100644 --- a/pkg/sentry/platform/kvm/physical_map.go +++ b/pkg/sentry/platform/kvm/physical_map.go @@ -24,15 +24,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usermem" ) -const ( - // reservedMemory is a chunk of physical memory reserved starting at - // physical address zero. There are some special pages in this region, - // so we just call the whole thing off. - // - // Other architectures may define this to be zero. - reservedMemory = 0x100000000 -) - type region struct { virtual uintptr length uintptr @@ -59,8 +50,7 @@ func fillAddressSpace() (excludedRegions []region) { // We can cut vSize in half, because the kernel will be using the top // half and we ignore it while constructing mappings. It's as if we've // already excluded half the possible addresses. - vSize := uintptr(1) << ring0.VirtualAddressBits() - vSize = vSize >> 1 + vSize := ring0.UserspaceSize // We exclude reservedMemory below from our physical memory size, so it // needs to be dropped here as well. Otherwise, we could end up with diff --git a/pkg/sentry/fsimpl/proc/filesystems.go b/pkg/sentry/platform/kvm/physical_map_amd64.go index 0e016bca5..c5adfb577 100644 --- a/pkg/sentry/fsimpl/proc/filesystems.go +++ b/pkg/sentry/platform/kvm/physical_map_amd64.go @@ -12,14 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proc +package kvm -// filesystemsData implements vfs.DynamicBytesSource for /proc/filesystems. -// -// +stateify savable -type filesystemsData struct{} - -// TODO(gvisor.dev/issue/1195): Implement vfs.DynamicBytesSource.Generate for -// filesystemsData. We would need to retrive filesystem names from -// vfs.VirtualFilesystem. Also needs vfs replacement for -// fs.Filesystem.AllowUserList() and fs.FilesystemRequiresDev. +const ( + // reservedMemory is a chunk of physical memory reserved starting at + // physical address zero. There are some special pages in this region, + // so we just call the whole thing off. + reservedMemory = 0x100000000 +) diff --git a/pkg/sentry/fsimpl/proc/proc.go b/pkg/sentry/platform/kvm/physical_map_arm64.go index 31dec36de..4d8561453 100644 --- a/pkg/sentry/fsimpl/proc/proc.go +++ b/pkg/sentry/platform/kvm/physical_map_arm64.go @@ -12,5 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package proc implements a partial in-memory file system for procfs. -package proc +package kvm + +const ( + reservedMemory = 0 +) diff --git a/pkg/sentry/platform/ptrace/stub_amd64.s b/pkg/sentry/platform/ptrace/stub_amd64.s index 64c718d21..16f9c523e 100644 --- a/pkg/sentry/platform/ptrace/stub_amd64.s +++ b/pkg/sentry/platform/ptrace/stub_amd64.s @@ -64,6 +64,8 @@ begin: CMPQ AX, $0 JL error + MOVQ $0, BX + // SIGSTOP to wait for attach. // // The SYSCALL instruction will be used for future syscall injection by @@ -73,23 +75,26 @@ begin: MOVQ $SIGSTOP, SI SYSCALL - // The tracer may "detach" and/or allow code execution here in three cases: - // - // 1. New (traced) stub threads are explicitly detached by the - // goroutine in newSubprocess. However, they are detached while in - // group-stop, so they do not execute code here. - // - // 2. If a tracer thread exits, it implicitly detaches from the stub, - // potentially allowing code execution here. However, the Go runtime - // never exits individual threads, so this case never occurs. - // - // 3. subprocess.createStub clones a new stub process that is untraced, + // The sentry sets BX to 1 when creating stub process. + CMPQ BX, $1 + JE clone + + // Notify the Sentry that syscall exited. +done: + INT $3 + // Be paranoid. + JMP done +clone: + // subprocess.createStub clones a new stub process that is untraced, // thus executing this code. We setup the PDEATHSIG before SIGSTOPing // ourselves for attach by the tracer. // // R15 has been updated with the expected PPID. - JMP begin + CMPQ AX, $0 + JE begin + // The clone syscall returns a non-zero value. + JMP done error: // Exit with -errno. MOVQ AX, DI diff --git a/pkg/sentry/platform/ptrace/stub_arm64.s b/pkg/sentry/platform/ptrace/stub_arm64.s index 2c5e4d5cb..6162df02a 100644 --- a/pkg/sentry/platform/ptrace/stub_arm64.s +++ b/pkg/sentry/platform/ptrace/stub_arm64.s @@ -59,6 +59,8 @@ begin: CMP $0x0, R0 BLT error + MOVD $0, R9 + // SIGSTOP to wait for attach. // // The SYSCALL instruction will be used for future syscall injection by @@ -66,22 +68,26 @@ begin: MOVD $SYS_KILL, R8 MOVD $SIGSTOP, R1 SVC - // The tracer may "detach" and/or allow code execution here in three cases: - // - // 1. New (traced) stub threads are explicitly detached by the - // goroutine in newSubprocess. However, they are detached while in - // group-stop, so they do not execute code here. - // - // 2. If a tracer thread exits, it implicitly detaches from the stub, - // potentially allowing code execution here. However, the Go runtime - // never exits individual threads, so this case never occurs. - // - // 3. subprocess.createStub clones a new stub process that is untraced, + + // The sentry sets R9 to 1 when creating stub process. + CMP $1, R9 + BEQ clone + +done: + // Notify the Sentry that syscall exited. + BRK $3 + B done // Be paranoid. +clone: + // subprocess.createStub clones a new stub process that is untraced, // thus executing this code. We setup the PDEATHSIG before SIGSTOPing // ourselves for attach by the tracer. // // R7 has been updated with the expected PPID. - B begin + CMP $0, R0 + BEQ begin + + // The clone system call returned a non-zero value. + B done error: // Exit with -errno. diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index ddb1f41e3..20244fd95 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -21,6 +21,7 @@ import ( "sync" "syscall" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/procid" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -429,13 +430,15 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) { } for { - // Execute the syscall instruction. - if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 { + // Execute the syscall instruction. The task has to stop on the + // trap instruction which is right after the syscall + // instruction. + if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_CONT, uintptr(t.tid), 0, 0, 0, 0); errno != 0 { panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno)) } sig := t.wait(stopped) - if sig == (syscallEvent | syscall.SIGTRAP) { + if sig == syscall.SIGTRAP { // Reached syscall-enter-stop. break } else { @@ -447,18 +450,6 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) { } } - // Complete the actual system call. - if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 { - panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno)) - } - - // Wait for syscall-exit-stop. "[Signal-delivery-stop] never happens - // between syscall-enter-stop and syscall-exit-stop; it happens *after* - // syscall-exit-stop.)" - ptrace(2), "Syscall-stops" - if sig := t.wait(stopped); sig != (syscallEvent | syscall.SIGTRAP) { - t.dumpAndPanic(fmt.Sprintf("wait failed: expected SIGTRAP, got %v [%d]", sig, sig)) - } - // Grab registers. if err := t.getRegs(regs); err != nil { panic(fmt.Sprintf("ptrace get regs failed: %v", err)) @@ -541,14 +532,14 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool { if isSingleStepping(regs) { if _, _, errno := syscall.RawSyscall6( syscall.SYS_PTRACE, - syscall.PTRACE_SYSEMU_SINGLESTEP, + unix.PTRACE_SYSEMU_SINGLESTEP, uintptr(t.tid), 0, 0, 0, 0); errno != 0 { panic(fmt.Sprintf("ptrace sysemu failed: %v", errno)) } } else { if _, _, errno := syscall.RawSyscall6( syscall.SYS_PTRACE, - syscall.PTRACE_SYSEMU, + unix.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 { panic(fmt.Sprintf("ptrace sysemu failed: %v", errno)) } diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go index 606dc2b1d..e99798c56 100644 --- a/pkg/sentry/platform/ptrace/subprocess_amd64.go +++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go @@ -141,9 +141,11 @@ func (t *thread) adjustInitRegsRip() { t.initRegs.Rip -= initRegsRipAdjustment } -// Pass the expected PPID to the child via R15 when creating stub process +// Pass the expected PPID to the child via R15 when creating stub process. func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) { initregs.R15 = uint64(ppid) + // Rbx has to be set to 1 when creating stub process. + initregs.Rbx = 1 } // patchSignalInfo patches the signal info to account for hitting the seccomp diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go index 62a686ee7..7b975137f 100644 --- a/pkg/sentry/platform/ptrace/subprocess_arm64.go +++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go @@ -127,6 +127,8 @@ func (t *thread) adjustInitRegsRip() { // Pass the expected PPID to the child via X7 when creating stub process func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) { initregs.Regs[7] = uint64(ppid) + // R9 has to be set to 1 when creating stub process. + initregs.Regs[9] = 1 } // patchSignalInfo patches the signal info to account for hitting the seccomp diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go index cf13ea5e4..74968dfdf 100644 --- a/pkg/sentry/platform/ptrace/subprocess_linux.go +++ b/pkg/sentry/platform/ptrace/subprocess_linux.go @@ -54,7 +54,7 @@ func probeSeccomp() bool { for { // Attempt an emulation. - if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 { + if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, unix.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 { panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno)) } diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD index f1af18265..87f4552b5 100644 --- a/pkg/sentry/platform/ring0/BUILD +++ b/pkg/sentry/platform/ring0/BUILD @@ -71,6 +71,7 @@ go_library( "lib_amd64.go", "lib_amd64.s", "lib_arm64.go", + "lib_arm64.s", "ring0.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ring0", diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go index fbfbd9bab..dc0eeec01 100644 --- a/pkg/sentry/platform/ring0/defs_arm64.go +++ b/pkg/sentry/platform/ring0/defs_arm64.go @@ -73,6 +73,9 @@ type CPUArchState struct { // application context pointer. appAddr uintptr + + // lazyVFP is the value of cpacr_el1. + lazyVFP uintptr } // ErrorCode returns the last error code. diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index 29c475882..64e9c0845 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -31,6 +31,11 @@ #define RSV_REG R18_PLATFORM #define RSV_REG_APP R9 +#define FPEN_NOTRAP 0x3 +#define FPEN_SHIFT 20 + +#define FPEN_ENABLE (FPEN_NOTRAP << FPEN_SHIFT) + #define REGISTERS_SAVE(reg, offset) \ MOVD R0, offset+PTRACE_R0(reg); \ MOVD R1, offset+PTRACE_R1(reg); \ @@ -279,6 +284,16 @@ #define IRQ_DISABLE \ MSR $2, DAIFClr; +#define VFP_ENABLE \ + MOVD $FPEN_ENABLE, R0; \ + WORD $0xd5181040; \ //MSR R0, CPACR_EL1 + ISB $15; + +#define VFP_DISABLE \ + MOVD $0x0, R0; \ + WORD $0xd5181040; \ //MSR R0, CPACR_EL1 + ISB $15; + #define KERNEL_ENTRY_FROM_EL0 \ SUB $16, RSP, RSP; \ // step1, save r18, r9 into kernel temporary stack. STP (RSV_REG, RSV_REG_APP), 16*0(RSP); \ @@ -318,6 +333,11 @@ TEXT ·Halt(SB),NOSPLIT,$0 BNE mmio_exit MOVD $0, CPU_REGISTERS+PTRACE_R9(RSV_REG) mmio_exit: + // Disable fpsimd. + WORD $0xd5381041 // MRS CPACR_EL1, R1 + MOVD R1, CPU_LAZY_VFP(RSV_REG) + VFP_DISABLE + // MMIO_EXIT. MOVD $0, R9 MOVD R0, 0xffff000000001000(R9) @@ -337,6 +357,73 @@ TEXT ·Current(SB),NOSPLIT,$0-8 #define STACK_FRAME_SIZE 16 TEXT ·kernelExitToEl0(SB),NOSPLIT,$0 + // Step1, save sentry context into memory. + REGISTERS_SAVE(RSV_REG, CPU_REGISTERS) + MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG) + + WORD $0xd5384003 // MRS SPSR_EL1, R3 + MOVD R3, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG) + MOVD R30, CPU_REGISTERS+PTRACE_PC(RSV_REG) + MOVD RSP, R3 + MOVD R3, CPU_REGISTERS+PTRACE_SP(RSV_REG) + + MOVD CPU_REGISTERS+PTRACE_R3(RSV_REG), R3 + + // Step2, save SP_EL1, PSTATE into kernel temporary stack. + // switch to temporary stack. + LOAD_KERNEL_STACK(RSV_REG) + WORD $0xd538d092 //MRS TPIDR_EL1, R18 + + SUB $STACK_FRAME_SIZE, RSP, RSP + MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R11 + MOVD CPU_REGISTERS+PTRACE_PSTATE(RSV_REG), R12 + STP (R11, R12), 16*0(RSP) + + MOVD CPU_REGISTERS+PTRACE_R11(RSV_REG), R11 + MOVD CPU_REGISTERS+PTRACE_R12(RSV_REG), R12 + + // Step3, test user pagetable. + // If user pagetable is empty, trapped in el1_ia. + WORD $0xd538d092 //MRS TPIDR_EL1, R18 + SWITCH_TO_APP_PAGETABLE(RSV_REG) + WORD $0xd538d092 //MRS TPIDR_EL1, R18 + SWITCH_TO_KVM_PAGETABLE(RSV_REG) + WORD $0xd538d092 //MRS TPIDR_EL1, R18 + + // If pagetable is not empty, recovery kernel temporary stack. + ADD $STACK_FRAME_SIZE, RSP, RSP + + // Step4, load app context pointer. + MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP + + // Step5, prepare the environment for container application. + // set sp_el0. + MOVD PTRACE_SP(RSV_REG_APP), R1 + WORD $0xd5184101 //MSR R1, SP_EL0 + // set pc. + MOVD PTRACE_PC(RSV_REG_APP), R1 + MSR R1, ELR_EL1 + // set pstate. + MOVD PTRACE_PSTATE(RSV_REG_APP), R1 + WORD $0xd5184001 //MSR R1, SPSR_EL1 + + // RSV_REG & RSV_REG_APP will be loaded at the end. + REGISTERS_LOAD(RSV_REG_APP, 0) + + // switch to user pagetable. + MOVD PTRACE_R18(RSV_REG_APP), RSV_REG + MOVD PTRACE_R9(RSV_REG_APP), RSV_REG_APP + + SUB $STACK_FRAME_SIZE, RSP, RSP + STP (RSV_REG, RSV_REG_APP), 16*0(RSP) + + WORD $0xd538d092 //MRS TPIDR_EL1, R18 + + SWITCH_TO_APP_PAGETABLE(RSV_REG) + + LDP 16*0(RSP), (RSV_REG, RSV_REG_APP) + ADD $STACK_FRAME_SIZE, RSP, RSP + ERET() TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 @@ -382,6 +469,8 @@ TEXT ·El1_sync(SB),NOSPLIT,$0 BEQ el1_svc CMP $ESR_ELx_EC_BREAKPT_CUR, R24 BGE el1_dbg + CMP $ESR_ELx_EC_FP_ASIMD, R24 + BEQ el1_fpsimd_acc B el1_invalid el1_da: @@ -402,6 +491,10 @@ el1_svc: el1_dbg: B ·Shutdown(SB) +el1_fpsimd_acc: + VFP_ENABLE + B ·kernelExitToEl1(SB) // Resume. + el1_invalid: B ·Shutdown(SB) @@ -528,6 +621,10 @@ TEXT ·Vectors(SB),NOSPLIT,$0 B ·El0_error_invalid(SB) nop31Instructions() + // The exception-vector-table is required to be 11-bits aligned. + // Please see Linux source code as reference: arch/arm64/kernel/entry.s. + // For gvisor, I defined it as 4K in length, filled the 2nd 2K part with NOPs. + // So that, I can safely move the 1st 2K part into the address with 11-bits alignment. WORD $0xd503201f //nop nop31Instructions() WORD $0xd503201f diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go index 900ee6380..8bcfe1032 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.go +++ b/pkg/sentry/platform/ring0/lib_arm64.go @@ -16,10 +16,24 @@ package ring0 -// LoadFloatingPoint loads floating point state by the most efficient mechanism -// available (set by Init). -var LoadFloatingPoint func(*byte) +// CPACREL1 returns the value of the CPACR_EL1 register. +func CPACREL1() (value uintptr) -// SaveFloatingPoint saves floating point state by the most efficient mechanism -// available (set by Init). -var SaveFloatingPoint func(*byte) +// FPCR returns the value of FPCR register. +func FPCR() (value uintptr) + +// SetFPCR writes the FPCR value. +func SetFPCR(value uintptr) + +// FPSR returns the value of FPSR register. +func FPSR() (value uintptr) + +// SetFPSR writes the FPSR value. +func SetFPSR(value uintptr) + +// SaveVRegs saves V0-V31 registers. +// V0-V31: 32 128-bit registers for floating point and simd. +func SaveVRegs(*byte) + +// LoadVRegs loads V0-V31 registers. +func LoadVRegs(*byte) diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s new file mode 100644 index 000000000..0e6a6235b --- /dev/null +++ b/pkg/sentry/platform/ring0/lib_arm64.s @@ -0,0 +1,121 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "funcdata.h" +#include "textflag.h" + +TEXT ·CPACREL1(SB),NOSPLIT,$0-8 + WORD $0xd5381041 // MRS CPACR_EL1, R1 + MOVD R1, ret+0(FP) + RET + +TEXT ·GetFPCR(SB),NOSPLIT,$0-8 + WORD $0xd53b4201 // MRS NZCV, R1 + MOVD R1, ret+0(FP) + RET + +TEXT ·GetFPSR(SB),NOSPLIT,$0-8 + WORD $0xd53b4421 // MRS FPSR, R1 + MOVD R1, ret+0(FP) + RET + +TEXT ·SetFPCR(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R1 + WORD $0xd51b4201 // MSR R1, NZCV + RET + +TEXT ·SetFPSR(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R1 + WORD $0xd51b4421 // MSR R1, FPSR + RET + +TEXT ·SaveVRegs(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R0 + + // Skip aarch64_ctx, fpsr, fpcr. + FMOVD F0, 16*1(R0) + FMOVD F1, 16*2(R0) + FMOVD F2, 16*3(R0) + FMOVD F3, 16*4(R0) + FMOVD F4, 16*5(R0) + FMOVD F5, 16*6(R0) + FMOVD F6, 16*7(R0) + FMOVD F7, 16*8(R0) + FMOVD F8, 16*9(R0) + FMOVD F9, 16*10(R0) + FMOVD F10, 16*11(R0) + FMOVD F11, 16*12(R0) + FMOVD F12, 16*13(R0) + FMOVD F13, 16*14(R0) + FMOVD F14, 16*15(R0) + FMOVD F15, 16*16(R0) + FMOVD F16, 16*17(R0) + FMOVD F17, 16*18(R0) + FMOVD F18, 16*19(R0) + FMOVD F19, 16*20(R0) + FMOVD F20, 16*21(R0) + FMOVD F21, 16*22(R0) + FMOVD F22, 16*23(R0) + FMOVD F23, 16*24(R0) + FMOVD F24, 16*25(R0) + FMOVD F25, 16*26(R0) + FMOVD F26, 16*27(R0) + FMOVD F27, 16*28(R0) + FMOVD F28, 16*29(R0) + FMOVD F29, 16*30(R0) + FMOVD F30, 16*31(R0) + FMOVD F31, 16*32(R0) + ISB $15 + + RET + +TEXT ·LoadVRegs(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R0 + + // Skip aarch64_ctx, fpsr, fpcr. + FMOVD 16*1(R0), F0 + FMOVD 16*2(R0), F1 + FMOVD 16*3(R0), F2 + FMOVD 16*4(R0), F3 + FMOVD 16*5(R0), F4 + FMOVD 16*6(R0), F5 + FMOVD 16*7(R0), F6 + FMOVD 16*8(R0), F7 + FMOVD 16*9(R0), F8 + FMOVD 16*10(R0), F9 + FMOVD 16*11(R0), F10 + FMOVD 16*12(R0), F11 + FMOVD 16*13(R0), F12 + FMOVD 16*14(R0), F13 + FMOVD 16*15(R0), F14 + FMOVD 16*16(R0), F15 + FMOVD 16*17(R0), F16 + FMOVD 16*18(R0), F17 + FMOVD 16*19(R0), F18 + FMOVD 16*20(R0), F19 + FMOVD 16*21(R0), F20 + FMOVD 16*22(R0), F21 + FMOVD 16*23(R0), F22 + FMOVD 16*24(R0), F23 + FMOVD 16*25(R0), F24 + FMOVD 16*26(R0), F25 + FMOVD 16*27(R0), F26 + FMOVD 16*28(R0), F27 + FMOVD 16*29(R0), F28 + FMOVD 16*30(R0), F29 + FMOVD 16*31(R0), F30 + FMOVD 16*32(R0), F31 + ISB $15 + + RET diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go index d7aa1c7cc..cd2a65f97 100644 --- a/pkg/sentry/platform/ring0/offsets_arm64.go +++ b/pkg/sentry/platform/ring0/offsets_arm64.go @@ -39,6 +39,7 @@ func Emit(w io.Writer) { fmt.Fprintf(w, "#define CPU_TTBR0_APP 0x%02x\n", reflect.ValueOf(&c.ttbr0App).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "#define CPU_VECTOR_CODE 0x%02x\n", reflect.ValueOf(&c.vecCode).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "#define CPU_APP_ADDR 0x%02x\n", reflect.ValueOf(&c.appAddr).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_LAZY_VFP 0x%02x\n", reflect.ValueOf(&c.lazyVFP).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "\n// Bits.\n") fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet) diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index af1a4e95f..4301b697c 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -471,6 +471,9 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con case linux.SOL_IP: switch h.Type { case linux.IP_TOS: + if length < linux.SizeOfControlMessageTOS { + return socket.ControlMessages{}, syserror.EINVAL + } cmsgs.IP.HasTOS = true binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS) i += AlignUp(length, width) @@ -481,6 +484,9 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con case linux.SOL_IPV6: switch h.Type { case linux.IPV6_TCLASS: + if length < linux.SizeOfControlMessageTClass { + return socket.ControlMessages{}, syserror.EINVAL + } cmsgs.IP.HasTClass = true binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass) i += AlignUp(length, width) diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 79589e3c8..136821963 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -22,7 +22,6 @@ go_library( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", - "//pkg/sentry/safemem", "//pkg/sentry/socket", "//pkg/sentry/socket/netlink/port", "//pkg/sentry/socket/unix", diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 4a1b87a9a..d2e3644a6 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -29,7 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port" "gvisor.dev/gvisor/pkg/sentry/socket/unix" @@ -500,29 +499,29 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have trunc := flags&linux.MSG_TRUNC != 0 r := unix.EndpointReader{ + Ctx: t, Endpoint: s.ep, Peek: flags&linux.MSG_PEEK != 0, } + doRead := func() (int64, error) { + return dst.CopyOutFrom(t, &r) + } + // If MSG_TRUNC is set with a zero byte destination then we still need // to read the message and discard it, or in the case where MSG_PEEK is // set, leave it be. In both cases the full message length must be - // returned. However, the memory manager for the destination will not read - // the endpoint if the destination is zero length. - // - // In order for the endpoint to be read when the destination size is zero, - // we must cause a read of the endpoint by using a separate fake zero - // length block sequence and calling the EndpointReader directly. + // returned. if trunc && dst.Addrs.NumBytes() == 0 { - // Perform a read to a zero byte block sequence. We can ignore the - // original destination since it was zero bytes. The length returned by - // ReadToBlocks is ignored and we return the full message length to comply - // with MSG_TRUNC. - _, err := r.ReadToBlocks(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(make([]byte, 0)))) - return int(r.MsgSize), linux.MSG_TRUNC, from, fromLen, socket.ControlMessages{}, syserr.FromError(err) + doRead = func() (int64, error) { + err := r.Truncate() + // Always return zero for bytes read since the destination size is + // zero. + return 0, err + } } - if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { + if n, err := doRead(); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { var mflags int if n < int64(r.MsgSize) { mflags |= linux.MSG_TRUNC @@ -540,7 +539,7 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have defer s.EventUnregister(&e) for { - if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock { + if n, err := doRead(); err != syserror.ErrWouldBlock { var mflags int if n < int64(r.MsgSize) { mflags |= linux.MSG_TRUNC diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 140851c17..764f11a6b 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -222,17 +222,25 @@ type commonEndpoint interface { // transport.Endpoint.SetSockOpt. SetSockOpt(interface{}) *tcpip.Error + // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and + // transport.Endpoint.SetSockOptBool. + SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error + // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and // transport.Endpoint.SetSockOptInt. - SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error // GetSockOpt implements tcpip.Endpoint.GetSockOpt and // transport.Endpoint.GetSockOpt. GetSockOpt(interface{}) *tcpip.Error + // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and + // transport.Endpoint.GetSockOpt. + GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and // transport.Endpoint.GetSockOpt. - GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) + GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) } // SocketOperations encapsulates all the state needed to represent a network stack @@ -977,13 +985,23 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - if len(v) == 0 { + if v == 0 { return []byte{}, nil } if outLen < linux.IFNAMSIZ { return nil, syserr.ErrInvalidArgument } - return append([]byte(v), 0), nil + s := t.NetworkContext() + if s == nil { + return nil, syserr.ErrNoDevice + } + nic, ok := s.Interfaces()[int32(v)] + if !ok { + // The NICID no longer indicates a valid interface, probably because that + // interface was removed. + return nil, syserr.ErrUnknownDevice + } + return append([]byte(nic.Name), 0), nil case linux.SO_BROADCAST: if outLen < sizeOfInt32 { @@ -1213,12 +1231,15 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf return nil, syserr.ErrInvalidArgument } - var v tcpip.V6OnlyOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.V6OnlyOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + var o uint32 + if v { + o = 1 + } + return int32(o), nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1427,7 +1448,20 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i if n == -1 { n = len(optVal) } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n]))) + name := string(optVal[:n]) + if name == "" { + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(0))) + } + s := t.NetworkContext() + if s == nil { + return syserr.ErrNoDevice + } + for nicID, nic := range s.Interfaces() { + if nic.Name == name { + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(nicID))) + } + } + return syserr.ErrUnknownDevice case linux.SO_BROADCAST: if len(optVal) < sizeOfInt32 { @@ -1621,7 +1655,7 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.V6OnlyOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0)) case linux.IPV6_ADD_MEMBERSHIP, linux.IPV6_DROP_MEMBERSHIP, diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go index 2ec1a662d..2447f24ef 100644 --- a/pkg/sentry/socket/unix/io.go +++ b/pkg/sentry/socket/unix/io.go @@ -83,6 +83,19 @@ type EndpointReader struct { ControlTrunc bool } +// Truncate calls RecvMsg on the endpoint without writing to a destination. +func (r *EndpointReader) Truncate() error { + // Ignore bytes read since it will always be zero. + _, ms, c, ct, err := r.Endpoint.RecvMsg(r.Ctx, [][]byte{}, r.Creds, r.NumRights, r.Peek, r.From) + r.Control = c + r.ControlTrunc = ct + r.MsgSize = ms + if err != nil { + return err.ToError() + } + return nil +} + // ReadToBlocks implements safemem.Reader.ReadToBlocks. func (r *EndpointReader) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { return safemem.FromVecReaderFunc{func(bufs [][]byte) (int64, error) { diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 529a7a7a9..37c7ac3c1 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -175,17 +175,25 @@ type Endpoint interface { // types. SetSockOpt(opt interface{}) *tcpip.Error + // SetSockOptBool sets a socket option for simple cases when a value has + // the int type. + SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error + // SetSockOptInt sets a socket option for simple cases when a value has // the int type. - SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error // GetSockOpt gets a socket option. opt should be a pointer to one of the // tcpip.*Option types. GetSockOpt(opt interface{}) *tcpip.Error + // GetSockOptBool gets a socket option for simple cases when a return + // value has the int type. + GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) + // GetSockOptInt gets a socket option for simple cases when a return // value has the int type. - GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) + GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) // State returns the current state of the socket, as represented by Linux in // procfs. @@ -851,11 +859,19 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } -func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { return nil } -func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + return nil +} + +func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + return false, tcpip.ErrUnknownProtocolOption +} + +func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 885758054..91effe89a 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -544,8 +544,27 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if senderRequested { r.From = &tcpip.FullAddress{} } + + doRead := func() (int64, error) { + return dst.CopyOutFrom(t, &r) + } + + // If MSG_TRUNC is set with a zero byte destination then we still need + // to read the message and discard it, or in the case where MSG_PEEK is + // set, leave it be. In both cases the full message length must be + // returned. + if trunc && dst.Addrs.NumBytes() == 0 { + doRead = func() (int64, error) { + err := r.Truncate() + // Always return zero for bytes read since the destination size is + // zero. + return 0, err + } + + } + var total int64 - if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait { + if n, err := doRead(); err != syserror.ErrWouldBlock || dontWait { var from linux.SockAddr var fromLen uint32 if r.From != nil && len([]byte(r.From.Addr)) != 0 { @@ -580,7 +599,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags defer s.EventUnregister(&e) for { - if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock { + if n, err := doRead(); err != syserror.ErrWouldBlock { var from linux.SockAddr var fromLen uint32 if r.From != nil { diff --git a/pkg/sentry/strace/linux64_amd64.go b/pkg/sentry/strace/linux64_amd64.go index 9fa2f0e16..1e823b685 100644 --- a/pkg/sentry/strace/linux64_amd64.go +++ b/pkg/sentry/strace/linux64_amd64.go @@ -378,5 +378,7 @@ func init() { syscallTable{ os: abi.Linux, arch: arch.AMD64, - syscalls: linuxAMD64}) + syscalls: linuxAMD64, + }, + ) } diff --git a/pkg/sentry/strace/select.go b/pkg/sentry/strace/select.go index dea309fda..c77d418e6 100644 --- a/pkg/sentry/strace/select.go +++ b/pkg/sentry/strace/select.go @@ -36,6 +36,9 @@ func fdsFromSet(t *kernel.Task, set []byte) []int { } func fdSet(t *kernel.Task, nfds int, addr usermem.Addr) string { + if nfds < 0 { + return fmt.Sprintf("%#x (negative nfds)", addr) + } if addr == 0 { return "null" } diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 6766ba587..a76975cee 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -30,6 +30,7 @@ go_library( "sys_random.go", "sys_read.go", "sys_rlimit.go", + "sys_rseq.go", "sys_rusage.go", "sys_sched.go", "sys_seccomp.go", diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go index 272ae9991..479c5f6ff 100644 --- a/pkg/sentry/syscalls/linux/linux64_amd64.go +++ b/pkg/sentry/syscalls/linux/linux64_amd64.go @@ -377,7 +377,7 @@ var AMD64 = &kernel.SyscallTable{ 331: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil), 332: syscalls.Supported("statx", Statx), 333: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil), - 334: syscalls.ErrorWithEvent("rseq", syserror.ENOSYS, "", nil), + 334: syscalls.PartiallySupported("rseq", RSeq, "Not supported on all platforms.", nil), // Linux skips ahead to syscall 424 to sync numbers between arches. 424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil), diff --git a/pkg/sentry/syscalls/linux/linux64_arm64.go b/pkg/sentry/syscalls/linux/linux64_arm64.go index 3b584eed9..d3f61f5e8 100644 --- a/pkg/sentry/syscalls/linux/linux64_arm64.go +++ b/pkg/sentry/syscalls/linux/linux64_arm64.go @@ -307,7 +307,7 @@ var ARM64 = &kernel.SyscallTable{ 290: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil), 291: syscalls.Supported("statx", Statx), 292: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil), - 293: syscalls.ErrorWithEvent("rseq", syserror.ENOSYS, "", nil), + 293: syscalls.PartiallySupported("rseq", RSeq, "Not supported on all platforms.", nil), // Linux skips ahead to syscall 424 to sync numbers between arches. 424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil), diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go index b9bd25464..bde17a767 100644 --- a/pkg/sentry/syscalls/linux/sys_futex.go +++ b/pkg/sentry/syscalls/linux/sys_futex.go @@ -226,6 +226,11 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if mask == 0 { return 0, nil, syserror.EINVAL } + if val <= 0 { + // The Linux kernel wakes one waiter even if val is + // non-positive. + val = 1 + } n, err := t.Futex().Wake(t, addr, private, mask, val) return uintptr(n), nil, err @@ -242,6 +247,11 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.FUTEX_WAKE_OP: op := uint32(val3) + if val <= 0 { + // The Linux kernel wakes one waiter even if val is + // non-positive. + val = 1 + } n, err := t.Futex().WakeOp(t, addr, naddr, private, val, nreq, op) return uintptr(n), nil, err diff --git a/pkg/sentry/syscalls/linux/sys_rseq.go b/pkg/sentry/syscalls/linux/sys_rseq.go new file mode 100644 index 000000000..90db10ea6 --- /dev/null +++ b/pkg/sentry/syscalls/linux/sys_rseq.go @@ -0,0 +1,48 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/syserror" +) + +// RSeq implements syscall rseq(2). +func RSeq(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + length := args[1].Uint() + flags := args[2].Int() + signature := args[3].Uint() + + if !t.RSeqAvailable() { + // Event for applications that want rseq on a configuration + // that doesn't support them. + t.Kernel().EmitUnimplementedEvent(t) + return 0, nil, syserror.ENOSYS + } + + switch flags { + case 0: + // Register. + return 0, nil, t.SetRSeq(addr, length, signature) + case linux.RSEQ_FLAG_UNREGISTER: + return 0, nil, t.ClearRSeq(addr, length, signature) + default: + // Unknown flag. + return 0, nil, syserror.EINVAL + } +} diff --git a/pkg/sentry/syscalls/linux/sys_shm.go b/pkg/sentry/syscalls/linux/sys_shm.go index d57ffb3a1..4a8bc24a2 100644 --- a/pkg/sentry/syscalls/linux/sys_shm.go +++ b/pkg/sentry/syscalls/linux/sys_shm.go @@ -39,10 +39,13 @@ func Shmget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if err != nil { return 0, nil, err } + defer segment.DecRef() return uintptr(segment.ID), nil, nil } // findSegment retrives a shm segment by the given id. +// +// findSegment returns a reference on Shm. func findSegment(t *kernel.Task, id shm.ID) (*shm.Shm, error) { r := t.IPCNamespace().ShmRegistry() segment := r.FindByID(id) @@ -63,6 +66,7 @@ func Shmat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if err != nil { return 0, nil, syserror.EINVAL } + defer segment.DecRef() opts, err := segment.ConfigureAttach(t, addr, shm.AttachOpts{ Execute: flag&linux.SHM_EXEC == linux.SHM_EXEC, @@ -72,7 +76,6 @@ func Shmat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if err != nil { return 0, nil, err } - defer segment.DecRef() addr, err = t.MemoryManager().MMap(t, opts) return uintptr(addr), nil, err } @@ -105,6 +108,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if err != nil { return 0, nil, syserror.EINVAL } + defer segment.DecRef() stat, err := segment.IPCStat(t) if err == nil { @@ -128,6 +132,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if err != nil { return 0, nil, syserror.EINVAL } + defer segment.DecRef() switch cmd { case linux.IPC_SET: diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index e3e554b88..4c6aa04a1 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -9,6 +9,7 @@ go_library( "context.go", "debug.go", "dentry.go", + "device.go", "file_description.go", "file_description_impl_util.go", "filesystem.go", diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go index 40f4c1d09..1bc9c4a38 100644 --- a/pkg/sentry/vfs/dentry.go +++ b/pkg/sentry/vfs/dentry.go @@ -85,12 +85,12 @@ type Dentry struct { // mounts is accessed using atomic memory operations. mounts uint32 - // mu synchronizes disowning and mounting over this Dentry. - mu sync.Mutex - // children are child Dentries. children map[string]*Dentry + // mu synchronizes disowning and mounting over this Dentry. + mu sync.Mutex + // impl is the DentryImpl associated with this Dentry. impl is immutable. // This should be the last field in Dentry. impl DentryImpl @@ -199,6 +199,18 @@ func (d *Dentry) HasChildren() bool { return len(d.children) != 0 } +// Children returns a map containing all of d's children. +func (d *Dentry) Children() map[string]*Dentry { + if !d.HasChildren() { + return nil + } + m := make(map[string]*Dentry) + for name, child := range d.children { + m[name] = child + } + return m +} + // InsertChild makes child a child of d with the given name. // // InsertChild is a mutator of d and child. @@ -222,6 +234,18 @@ func (d *Dentry) InsertChild(child *Dentry, name string) { child.name = name } +// IsAncestorOf returns true if d is an ancestor of d2; that is, d is either +// d2's parent or an ancestor of d2's parent. +func (d *Dentry) IsAncestorOf(d2 *Dentry) bool { + for d2.parent != nil { + if d2.parent == d { + return true + } + d2 = d2.parent + } + return false +} + // PrepareDeleteDentry must be called before attempting to delete the file // represented by d. If PrepareDeleteDentry succeeds, the caller must call // AbortDeleteDentry or CommitDeleteDentry depending on the deletion's outcome. @@ -271,21 +295,6 @@ func (vfs *VirtualFilesystem) CommitDeleteDentry(d *Dentry) { } } -// DeleteDentry combines PrepareDeleteDentry and CommitDeleteDentry, as -// appropriate for in-memory filesystems that don't need to ensure that some -// external state change succeeds before committing the deletion. -// -// DeleteDentry is a mutator of d and d.Parent(). -// -// Preconditions: d is a child Dentry. -func (vfs *VirtualFilesystem) DeleteDentry(mntns *MountNamespace, d *Dentry) error { - if err := vfs.PrepareDeleteDentry(mntns, d); err != nil { - return err - } - vfs.CommitDeleteDentry(d) - return nil -} - // ForceDeleteDentry causes d to become disowned. It should only be used in // cases where VFS has no ability to stop the deletion (e.g. d represents the // local state of a file on a remote filesystem on which the file has already @@ -314,7 +323,7 @@ func (vfs *VirtualFilesystem) ForceDeleteDentry(d *Dentry) { // CommitRenameExchangeDentry depending on the rename's outcome. // // Preconditions: from is a child Dentry. If to is not nil, it must be a child -// Dentry from the same Filesystem. +// Dentry from the same Filesystem. from != to. func (vfs *VirtualFilesystem) PrepareRenameDentry(mntns *MountNamespace, from, to *Dentry) error { if checkInvariants { if from.parent == nil { diff --git a/pkg/sentry/vfs/device.go b/pkg/sentry/vfs/device.go new file mode 100644 index 000000000..cb672e36f --- /dev/null +++ b/pkg/sentry/vfs/device.go @@ -0,0 +1,100 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/syserror" +) + +// DeviceKind indicates whether a device is a block or character device. +type DeviceKind uint32 + +const ( + // BlockDevice indicates a block device. + BlockDevice DeviceKind = iota + + // CharDevice indicates a character device. + CharDevice +) + +// String implements fmt.Stringer.String. +func (kind DeviceKind) String() string { + switch kind { + case BlockDevice: + return "block" + case CharDevice: + return "character" + default: + return fmt.Sprintf("invalid device kind %d", kind) + } +} + +type devTuple struct { + kind DeviceKind + major uint32 + minor uint32 +} + +// A Device backs device special files. +type Device interface { + // Open returns a FileDescription representing this device. + Open(ctx context.Context, mnt *Mount, d *Dentry, opts OpenOptions) (*FileDescription, error) +} + +type registeredDevice struct { + dev Device + opts RegisterDeviceOptions +} + +// RegisterDeviceOptions contains options to +// VirtualFilesystem.RegisterDevice(). +type RegisterDeviceOptions struct { + // GroupName is the name shown for this device registration in + // /proc/devices. If GroupName is empty, this registration will not be + // shown in /proc/devices. + GroupName string +} + +// RegisterDevice registers the given Device in vfs with the given major and +// minor device numbers. +func (vfs *VirtualFilesystem) RegisterDevice(kind DeviceKind, major, minor uint32, dev Device, opts *RegisterDeviceOptions) error { + tup := devTuple{kind, major, minor} + vfs.devicesMu.Lock() + defer vfs.devicesMu.Unlock() + if existing, ok := vfs.devices[tup]; ok { + return fmt.Errorf("%s device number (%d, %d) is already registered to device type %T", kind, major, minor, existing.dev) + } + vfs.devices[tup] = ®isteredDevice{ + dev: dev, + opts: *opts, + } + return nil +} + +// OpenDeviceSpecialFile returns a FileDescription representing the given +// device. +func (vfs *VirtualFilesystem) OpenDeviceSpecialFile(ctx context.Context, mnt *Mount, d *Dentry, kind DeviceKind, major, minor uint32, opts *OpenOptions) (*FileDescription, error) { + tup := devTuple{kind, major, minor} + vfs.devicesMu.RLock() + defer vfs.devicesMu.RUnlock() + rd, ok := vfs.devices[tup] + if !ok { + return nil, syserror.ENXIO + } + return rd.dev.Open(ctx, mnt, d, *opts) +} diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 6575afd16..6afe280bc 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -20,8 +20,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -38,49 +40,58 @@ type FileDescription struct { // operations. refs int64 + // statusFlags contains status flags, "initialized by open(2) and possibly + // modified by fcntl()" - fcntl(2). statusFlags is accessed using atomic + // memory operations. + statusFlags uint32 + // vd is the filesystem location at which this FileDescription was opened. // A reference is held on vd. vd is immutable. vd VirtualDentry + opts FileDescriptionOptions + // impl is the FileDescriptionImpl associated with this Filesystem. impl is // immutable. This should be the last field in FileDescription. impl FileDescriptionImpl } -// Init must be called before first use of fd. It takes ownership of references -// on mnt and d held by the caller. -func (fd *FileDescription) Init(impl FileDescriptionImpl, mnt *Mount, d *Dentry) { +// FileDescriptionOptions contains options to FileDescription.Init(). +type FileDescriptionOptions struct { + // If AllowDirectIO is true, allow O_DIRECT to be set on the file. This is + // usually only the case if O_DIRECT would actually have an effect. + AllowDirectIO bool + + // If UseDentryMetadata is true, calls to FileDescription methods that + // interact with file and filesystem metadata (Stat, SetStat, StatFS, + // Listxattr, Getxattr, Setxattr, Removexattr) are implemented by calling + // the corresponding FilesystemImpl methods instead of the corresponding + // FileDescriptionImpl methods. + // + // UseDentryMetadata is intended for file descriptions that are implemented + // outside of individual filesystems, such as pipes, sockets, and device + // special files. FileDescriptions for which UseDentryMetadata is true may + // embed DentryMetadataFileDescriptionImpl to obtain appropriate + // implementations of FileDescriptionImpl methods that should not be + // called. + UseDentryMetadata bool +} + +// Init must be called before first use of fd. It takes references on mnt and +// d. statusFlags is the initial file description status flags, which is +// usually the full set of flags passed to open(2). +func (fd *FileDescription) Init(impl FileDescriptionImpl, statusFlags uint32, mnt *Mount, d *Dentry, opts *FileDescriptionOptions) { fd.refs = 1 + fd.statusFlags = statusFlags | linux.O_LARGEFILE fd.vd = VirtualDentry{ mount: mnt, dentry: d, } + fd.vd.IncRef() + fd.opts = *opts fd.impl = impl } -// Impl returns the FileDescriptionImpl associated with fd. -func (fd *FileDescription) Impl() FileDescriptionImpl { - return fd.impl -} - -// Mount returns the mount on which fd was opened. It does not take a reference -// on the returned Mount. -func (fd *FileDescription) Mount() *Mount { - return fd.vd.mount -} - -// Dentry returns the dentry at which fd was opened. It does not take a -// reference on the returned Dentry. -func (fd *FileDescription) Dentry() *Dentry { - return fd.vd.dentry -} - -// VirtualDentry returns the location at which fd was opened. It does not take -// a reference on the returned VirtualDentry. -func (fd *FileDescription) VirtualDentry() VirtualDentry { - return fd.vd -} - // IncRef increments fd's reference count. func (fd *FileDescription) IncRef() { atomic.AddInt64(&fd.refs, 1) @@ -112,6 +123,82 @@ func (fd *FileDescription) DecRef() { } } +// Mount returns the mount on which fd was opened. It does not take a reference +// on the returned Mount. +func (fd *FileDescription) Mount() *Mount { + return fd.vd.mount +} + +// Dentry returns the dentry at which fd was opened. It does not take a +// reference on the returned Dentry. +func (fd *FileDescription) Dentry() *Dentry { + return fd.vd.dentry +} + +// VirtualDentry returns the location at which fd was opened. It does not take +// a reference on the returned VirtualDentry. +func (fd *FileDescription) VirtualDentry() VirtualDentry { + return fd.vd +} + +// StatusFlags returns file description status flags, as for fcntl(F_GETFL). +func (fd *FileDescription) StatusFlags() uint32 { + return atomic.LoadUint32(&fd.statusFlags) +} + +// SetStatusFlags sets file description status flags, as for fcntl(F_SETFL). +func (fd *FileDescription) SetStatusFlags(ctx context.Context, creds *auth.Credentials, flags uint32) error { + // Compare Linux's fs/fcntl.c:setfl(). + oldFlags := fd.StatusFlags() + // Linux documents this check as "O_APPEND cannot be cleared if the file is + // marked as append-only and the file is open for write", which would make + // sense. However, the check as actually implemented seems to be "O_APPEND + // cannot be changed if the file is marked as append-only". + if (flags^oldFlags)&linux.O_APPEND != 0 { + stat, err := fd.Stat(ctx, StatOptions{ + // There is no mask bit for stx_attributes. + Mask: 0, + // Linux just reads inode::i_flags directly. + Sync: linux.AT_STATX_DONT_SYNC, + }) + if err != nil { + return err + } + if (stat.AttributesMask&linux.STATX_ATTR_APPEND != 0) && (stat.Attributes&linux.STATX_ATTR_APPEND != 0) { + return syserror.EPERM + } + } + if (flags&linux.O_NOATIME != 0) && (oldFlags&linux.O_NOATIME == 0) { + stat, err := fd.Stat(ctx, StatOptions{ + Mask: linux.STATX_UID, + // Linux's inode_owner_or_capable() just reads inode::i_uid + // directly. + Sync: linux.AT_STATX_DONT_SYNC, + }) + if err != nil { + return err + } + if stat.Mask&linux.STATX_UID == 0 { + return syserror.EPERM + } + if !CanActAsOwner(creds, auth.KUID(stat.UID)) { + return syserror.EPERM + } + } + if flags&linux.O_DIRECT != 0 && !fd.opts.AllowDirectIO { + return syserror.EINVAL + } + // TODO(jamieliu): FileDescriptionImpl.SetOAsync()? + const settableFlags = linux.O_APPEND | linux.O_ASYNC | linux.O_DIRECT | linux.O_NOATIME | linux.O_NONBLOCK + atomic.StoreUint32(&fd.statusFlags, (oldFlags&^settableFlags)|(flags&settableFlags)) + return nil +} + +// Impl returns the FileDescriptionImpl associated with fd. +func (fd *FileDescription) Impl() FileDescriptionImpl { + return fd.impl +} + // FileDescriptionImpl contains implementation details for an FileDescription. // Implementations of FileDescriptionImpl should contain their associated // FileDescription by value as their first field. @@ -120,6 +207,8 @@ func (fd *FileDescription) DecRef() { // be interpreted as IDs in the root UserNamespace (i.e. as auth.KUID and // auth.KGID respectively). // +// All methods may return errors not specified. +// // FileDescriptionImpl is analogous to Linux's struct file_operations. type FileDescriptionImpl interface { // Release is called when the associated FileDescription reaches zero @@ -131,14 +220,6 @@ type FileDescriptionImpl interface { // prevent the file descriptor from being closed. OnClose(ctx context.Context) error - // StatusFlags returns file description status flags, as for - // fcntl(F_GETFL). - StatusFlags(ctx context.Context) (uint32, error) - - // SetStatusFlags sets file description status flags, as for - // fcntl(F_SETFL). - SetStatusFlags(ctx context.Context, flags uint32) error - // Stat returns metadata for the file represented by the FileDescription. Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) @@ -156,6 +237,10 @@ type FileDescriptionImpl interface { // PRead reads from the file into dst, starting at the given offset, and // returns the number of bytes read. PRead is permitted to return partial // reads with a nil error. + // + // Errors: + // + // - If opts.Flags specifies unsupported options, PRead returns EOPNOTSUPP. PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) // Read is similar to PRead, but does not specify an offset. @@ -165,6 +250,10 @@ type FileDescriptionImpl interface { // the number of bytes read; note that POSIX 2.9.7 "Thread Interactions // with Regular File Operations" requires that all operations that may // mutate the FileDescription offset are serialized. + // + // Errors: + // + // - If opts.Flags specifies unsupported options, Read returns EOPNOTSUPP. Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) // PWrite writes src to the file, starting at the given offset, and returns @@ -174,6 +263,11 @@ type FileDescriptionImpl interface { // As in Linux (but not POSIX), if O_APPEND is in effect for the // FileDescription, PWrite should ignore the offset and append data to the // end of the file. + // + // Errors: + // + // - If opts.Flags specifies unsupported options, PWrite returns + // EOPNOTSUPP. PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) // Write is similar to PWrite, but does not specify an offset, which is @@ -183,6 +277,10 @@ type FileDescriptionImpl interface { // PWrite that uses a FileDescription offset, to make it possible for // remote filesystems to implement O_APPEND correctly (i.e. atomically with // respect to writers outside the scope of VFS). + // + // Errors: + // + // - If opts.Flags specifies unsupported options, Write returns EOPNOTSUPP. Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) // IterDirents invokes cb on each entry in the directory represented by the @@ -212,7 +310,21 @@ type FileDescriptionImpl interface { // Ioctl implements the ioctl(2) syscall. Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) - // TODO: extended attributes; file locking + // Listxattr returns all extended attribute names for the file. + Listxattr(ctx context.Context) ([]string, error) + + // Getxattr returns the value associated with the given extended attribute + // for the file. + Getxattr(ctx context.Context, name string) (string, error) + + // Setxattr changes the value associated with the given extended attribute + // for the file. + Setxattr(ctx context.Context, opts SetxattrOptions) error + + // Removexattr removes the given extended attribute from the file. + Removexattr(ctx context.Context, name string) error + + // TODO: file locking } // Dirent holds the information contained in struct linux_dirent64. @@ -249,31 +361,49 @@ func (fd *FileDescription) OnClose(ctx context.Context) error { return fd.impl.OnClose(ctx) } -// StatusFlags returns file description status flags, as for fcntl(F_GETFL). -func (fd *FileDescription) StatusFlags(ctx context.Context) (uint32, error) { - flags, err := fd.impl.StatusFlags(ctx) - flags |= linux.O_LARGEFILE - return flags, err -} - -// SetStatusFlags sets file description status flags, as for fcntl(F_SETFL). -func (fd *FileDescription) SetStatusFlags(ctx context.Context, flags uint32) error { - return fd.impl.SetStatusFlags(ctx, flags) -} - // Stat returns metadata for the file represented by fd. func (fd *FileDescription) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) { + if fd.opts.UseDentryMetadata { + vfsObj := fd.vd.mount.vfs + rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ + Root: fd.vd, + Start: fd.vd, + }) + stat, err := fd.vd.mount.fs.impl.StatAt(ctx, rp, opts) + vfsObj.putResolvingPath(rp) + return stat, err + } return fd.impl.Stat(ctx, opts) } // SetStat updates metadata for the file represented by fd. func (fd *FileDescription) SetStat(ctx context.Context, opts SetStatOptions) error { + if fd.opts.UseDentryMetadata { + vfsObj := fd.vd.mount.vfs + rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ + Root: fd.vd, + Start: fd.vd, + }) + err := fd.vd.mount.fs.impl.SetStatAt(ctx, rp, opts) + vfsObj.putResolvingPath(rp) + return err + } return fd.impl.SetStat(ctx, opts) } // StatFS returns metadata for the filesystem containing the file represented // by fd. func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) { + if fd.opts.UseDentryMetadata { + vfsObj := fd.vd.mount.vfs + rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ + Root: fd.vd, + Start: fd.vd, + }) + statfs, err := fd.vd.mount.fs.impl.StatFSAt(ctx, rp) + vfsObj.putResolvingPath(rp) + return statfs, err + } return fd.impl.StatFS(ctx) } @@ -329,6 +459,78 @@ func (fd *FileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch. return fd.impl.Ioctl(ctx, uio, args) } +// Listxattr returns all extended attribute names for the file represented by +// fd. +func (fd *FileDescription) Listxattr(ctx context.Context) ([]string, error) { + if fd.opts.UseDentryMetadata { + vfsObj := fd.vd.mount.vfs + rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ + Root: fd.vd, + Start: fd.vd, + }) + names, err := fd.vd.mount.fs.impl.ListxattrAt(ctx, rp) + vfsObj.putResolvingPath(rp) + return names, err + } + names, err := fd.impl.Listxattr(ctx) + if err == syserror.ENOTSUP { + // Linux doesn't actually return ENOTSUP in this case; instead, + // fs/xattr.c:vfs_listxattr() falls back to allowing the security + // subsystem to return security extended attributes, which by default + // don't exist. + return nil, nil + } + return names, err +} + +// Getxattr returns the value associated with the given extended attribute for +// the file represented by fd. +func (fd *FileDescription) Getxattr(ctx context.Context, name string) (string, error) { + if fd.opts.UseDentryMetadata { + vfsObj := fd.vd.mount.vfs + rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ + Root: fd.vd, + Start: fd.vd, + }) + val, err := fd.vd.mount.fs.impl.GetxattrAt(ctx, rp, name) + vfsObj.putResolvingPath(rp) + return val, err + } + return fd.impl.Getxattr(ctx, name) +} + +// Setxattr changes the value associated with the given extended attribute for +// the file represented by fd. +func (fd *FileDescription) Setxattr(ctx context.Context, opts SetxattrOptions) error { + if fd.opts.UseDentryMetadata { + vfsObj := fd.vd.mount.vfs + rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ + Root: fd.vd, + Start: fd.vd, + }) + err := fd.vd.mount.fs.impl.SetxattrAt(ctx, rp, opts) + vfsObj.putResolvingPath(rp) + return err + } + return fd.impl.Setxattr(ctx, opts) +} + +// Removexattr removes the given extended attribute from the file represented +// by fd. +func (fd *FileDescription) Removexattr(ctx context.Context, name string) error { + if fd.opts.UseDentryMetadata { + vfsObj := fd.vd.mount.vfs + rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ + Root: fd.vd, + Start: fd.vd, + }) + err := fd.vd.mount.fs.impl.RemovexattrAt(ctx, rp, name) + vfsObj.putResolvingPath(rp) + return err + } + return fd.impl.Removexattr(ctx, name) +} + // SyncFS instructs the filesystem containing fd to execute the semantics of // syncfs(2). func (fd *FileDescription) SyncFS(ctx context.Context) error { @@ -347,7 +549,7 @@ func (fd *FileDescription) MappedName(ctx context.Context) string { // DeviceID implements memmap.MappingIdentity.DeviceID. func (fd *FileDescription) DeviceID() uint64 { - stat, err := fd.impl.Stat(context.Background(), StatOptions{ + stat, err := fd.Stat(context.Background(), StatOptions{ // There is no STATX_DEV; we assume that Stat will return it if it's // available regardless of mask. Mask: 0, @@ -363,7 +565,7 @@ func (fd *FileDescription) DeviceID() uint64 { // InodeID implements memmap.MappingIdentity.InodeID. func (fd *FileDescription) InodeID() uint64 { - stat, err := fd.impl.Stat(context.Background(), StatOptions{ + stat, err := fd.Stat(context.Background(), StatOptions{ Mask: linux.STATX_INO, // fs/proc/task_mmu.c:show_map_vma() just reads inode::i_ino directly. Sync: linux.AT_STATX_DONT_SYNC, @@ -376,5 +578,5 @@ func (fd *FileDescription) InodeID() uint64 { // Msync implements memmap.MappingIdentity.Msync. func (fd *FileDescription) Msync(ctx context.Context, mr memmap.MappableRange) error { - return fd.impl.Sync(ctx) + return fd.Sync(ctx) } diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index aae023254..66eb57bc2 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -127,6 +127,31 @@ func (FileDescriptionDefaultImpl) Ioctl(ctx context.Context, uio usermem.IO, arg return 0, syserror.ENOTTY } +// Listxattr implements FileDescriptionImpl.Listxattr analogously to +// inode_operations::listxattr == NULL in Linux. +func (FileDescriptionDefaultImpl) Listxattr(ctx context.Context) ([]string, error) { + // This isn't exactly accurate; see FileDescription.Listxattr. + return nil, syserror.ENOTSUP +} + +// Getxattr implements FileDescriptionImpl.Getxattr analogously to +// inode::i_opflags & IOP_XATTR == 0 in Linux. +func (FileDescriptionDefaultImpl) Getxattr(ctx context.Context, name string) (string, error) { + return "", syserror.ENOTSUP +} + +// Setxattr implements FileDescriptionImpl.Setxattr analogously to +// inode::i_opflags & IOP_XATTR == 0 in Linux. +func (FileDescriptionDefaultImpl) Setxattr(ctx context.Context, opts SetxattrOptions) error { + return syserror.ENOTSUP +} + +// Removexattr implements FileDescriptionImpl.Removexattr analogously to +// inode::i_opflags & IOP_XATTR == 0 in Linux. +func (FileDescriptionDefaultImpl) Removexattr(ctx context.Context, name string) error { + return syserror.ENOTSUP +} + // DirectoryFileDescriptionDefaultImpl may be embedded by implementations of // FileDescriptionImpl that always represent directories to obtain // implementations of non-directory I/O methods that return EISDIR. @@ -152,6 +177,21 @@ func (DirectoryFileDescriptionDefaultImpl) Write(ctx context.Context, src userme return 0, syserror.EISDIR } +// DentryMetadataFileDescriptionImpl may be embedded by implementations of +// FileDescriptionImpl for which FileDescriptionOptions.UseDentryMetadata is +// true to obtain implementations of Stat and SetStat that panic. +type DentryMetadataFileDescriptionImpl struct{} + +// Stat implements FileDescriptionImpl.Stat. +func (DentryMetadataFileDescriptionImpl) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) { + panic("illegal call to DentryMetadataFileDescriptionImpl.Stat") +} + +// SetStat implements FileDescriptionImpl.SetStat. +func (DentryMetadataFileDescriptionImpl) SetStat(ctx context.Context, opts SetStatOptions) error { + panic("illegal call to DentryMetadataFileDescriptionImpl.SetStat") +} + // DynamicBytesFileDescriptionImpl may be embedded by implementations of // FileDescriptionImpl that represent read-only regular files whose contents // are backed by a bytes.Buffer that is regenerated when necessary, consistent @@ -174,6 +214,17 @@ type DynamicBytesSource interface { Generate(ctx context.Context, buf *bytes.Buffer) error } +// StaticData implements DynamicBytesSource over a static string. +type StaticData struct { + Data string +} + +// Generate implements DynamicBytesSource. +func (s *StaticData) Generate(ctx context.Context, buf *bytes.Buffer) error { + buf.WriteString(s.Data) + return nil +} + // SetDataSource must be called exactly once on fd before first use. func (fd *DynamicBytesFileDescriptionImpl) SetDataSource(data DynamicBytesSource) { fd.data = data diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go index ac7799296..9ed58512f 100644 --- a/pkg/sentry/vfs/file_description_impl_util_test.go +++ b/pkg/sentry/vfs/file_description_impl_util_test.go @@ -48,7 +48,7 @@ type genCountFD struct { func newGenCountFD(mnt *Mount, vfsd *Dentry) *FileDescription { var fd genCountFD - fd.vfsfd.Init(&fd, mnt, vfsd) + fd.vfsfd.Init(&fd, 0 /* statusFlags */, mnt, vfsd, &FileDescriptionOptions{}) fd.DynamicBytesFileDescriptionImpl.SetDataSource(&fd) return &fd.vfsfd } @@ -89,7 +89,7 @@ func TestGenCountFD(t *testing.T) { creds := auth.CredentialsFromContext(ctx) vfsObj := New() // vfs.New() - vfsObj.MustRegisterFilesystemType("testfs", FDTestFilesystemType{}) + vfsObj.MustRegisterFilesystemType("testfs", FDTestFilesystemType{}, &RegisterFilesystemTypeOptions{}) mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "testfs", &GetFilesystemOptions{}) if err != nil { t.Fatalf("failed to create testfs root mount: %v", err) diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go index 8011eba3f..ea78f555b 100644 --- a/pkg/sentry/vfs/filesystem.go +++ b/pkg/sentry/vfs/filesystem.go @@ -108,6 +108,24 @@ func (fs *Filesystem) DecRef() { // (responsible for actually implementing the operation) isn't known until path // resolution is complete. // +// Unless otherwise specified, FilesystemImpl methods are responsible for +// performing permission checks. In many cases, vfs package functions in +// permissions.go may be used to help perform these checks. +// +// When multiple specified error conditions apply to a given method call, the +// implementation may return any applicable errno unless otherwise specified, +// but returning the earliest error specified is preferable to maximize +// compatibility with Linux. +// +// All methods may return errors not specified, notably including: +// +// - ENOENT if a required path component does not exist. +// +// - ENOTDIR if an intermediate path component is not a directory. +// +// - Errors from vfs-package functions (ResolvingPath.Resolve*(), +// Mount.CheckBeginWrite(), permission-checking functions, etc.) +// // For all methods that take or return linux.Statx, Statx.Uid and Statx.Gid // should be interpreted as IDs in the root UserNamespace (i.e. as auth.KUID // and auth.KGID respectively). @@ -130,46 +148,223 @@ type FilesystemImpl interface { // GetDentryAt does not correspond directly to a Linux syscall; it is used // in the implementation of: // - // - Syscalls that need to resolve two paths: rename(), renameat(), - // renameat2(), link(), linkat(). + // - Syscalls that need to resolve two paths: link(), linkat(). // // - Syscalls that need to refer to a filesystem position outside the // context of a file description: chdir(), fchdir(), chroot(), mount(), // umount(). GetDentryAt(ctx context.Context, rp *ResolvingPath, opts GetDentryOptions) (*Dentry, error) + // GetParentDentryAt returns a Dentry representing the directory at the + // second-to-last path component in rp. (Note that, despite the name, this + // is not necessarily the parent directory of the file at rp, since the + // last path component in rp may be "." or "..".) A reference is taken on + // the returned Dentry. + // + // GetParentDentryAt does not correspond directly to a Linux syscall; it is + // used in the implementation of the rename() family of syscalls, which + // must resolve the parent directories of two paths. + // + // Preconditions: !rp.Done(). + // + // Postconditions: If GetParentDentryAt returns a nil error, then + // rp.Final(). If GetParentDentryAt returns an error returned by + // ResolvingPath.Resolve*(), then !rp.Done(). + GetParentDentryAt(ctx context.Context, rp *ResolvingPath) (*Dentry, error) + // LinkAt creates a hard link at rp representing the same file as vd. It // does not take ownership of references on vd. // - // The implementation is responsible for checking that vd.Mount() == - // rp.Mount(), and that vd does not represent a directory. + // Errors: + // + // - If the last path component in rp is "." or "..", LinkAt returns + // EEXIST. + // + // - If a file already exists at rp, LinkAt returns EEXIST. + // + // - If rp.MustBeDir(), LinkAt returns ENOENT. + // + // - If the directory in which the link would be created has been removed + // by RmdirAt or RenameAt, LinkAt returns ENOENT. + // + // - If rp.Mount != vd.Mount(), LinkAt returns EXDEV. + // + // - If vd represents a directory, LinkAt returns EPERM. + // + // - If vd represents a file for which all existing links have been + // removed, or a file created by open(O_TMPFILE|O_EXCL), LinkAt returns + // ENOENT. Equivalently, if vd represents a file with a link count of 0 not + // created by open(O_TMPFILE) without O_EXCL, LinkAt returns ENOENT. + // + // Preconditions: !rp.Done(). For the final path component in rp, + // !rp.ShouldFollowSymlink(). + // + // Postconditions: If LinkAt returns an error returned by + // ResolvingPath.Resolve*(), then !rp.Done(). LinkAt(ctx context.Context, rp *ResolvingPath, vd VirtualDentry) error // MkdirAt creates a directory at rp. + // + // Errors: + // + // - If the last path component in rp is "." or "..", MkdirAt returns + // EEXIST. + // + // - If a file already exists at rp, MkdirAt returns EEXIST. + // + // - If the directory in which the new directory would be created has been + // removed by RmdirAt or RenameAt, MkdirAt returns ENOENT. + // + // Preconditions: !rp.Done(). For the final path component in rp, + // !rp.ShouldFollowSymlink(). + // + // Postconditions: If MkdirAt returns an error returned by + // ResolvingPath.Resolve*(), then !rp.Done(). MkdirAt(ctx context.Context, rp *ResolvingPath, opts MkdirOptions) error // MknodAt creates a regular file, device special file, or named pipe at // rp. + // + // Errors: + // + // - If the last path component in rp is "." or "..", MknodAt returns + // EEXIST. + // + // - If a file already exists at rp, MknodAt returns EEXIST. + // + // - If rp.MustBeDir(), MknodAt returns ENOENT. + // + // - If the directory in which the file would be created has been removed + // by RmdirAt or RenameAt, MknodAt returns ENOENT. + // + // Preconditions: !rp.Done(). For the final path component in rp, + // !rp.ShouldFollowSymlink(). + // + // Postconditions: If MknodAt returns an error returned by + // ResolvingPath.Resolve*(), then !rp.Done(). MknodAt(ctx context.Context, rp *ResolvingPath, opts MknodOptions) error // OpenAt returns an FileDescription providing access to the file at rp. A // reference is taken on the returned FileDescription. + // + // Errors: + // + // - If opts.Flags specifies O_TMPFILE and this feature is unsupported by + // the implementation, OpenAt returns EOPNOTSUPP. (All other unsupported + // features are silently ignored, consistently with Linux's open*(2).) OpenAt(ctx context.Context, rp *ResolvingPath, opts OpenOptions) (*FileDescription, error) // ReadlinkAt returns the target of the symbolic link at rp. + // + // Errors: + // + // - If the file at rp is not a symbolic link, ReadlinkAt returns EINVAL. ReadlinkAt(ctx context.Context, rp *ResolvingPath) (string, error) - // RenameAt renames the Dentry represented by vd to rp. It does not take - // ownership of references on vd. + // RenameAt renames the file named oldName in directory oldParentVD to rp. + // It does not take ownership of references on oldParentVD. + // + // Errors [1]: + // + // - If opts.Flags specifies unsupported options, RenameAt returns EINVAL. + // + // - If the last path component in rp is "." or "..", and opts.Flags + // contains RENAME_NOREPLACE, RenameAt returns EEXIST. + // + // - If the last path component in rp is "." or "..", and opts.Flags does + // not contain RENAME_NOREPLACE, RenameAt returns EBUSY. + // + // - If rp.Mount != oldParentVD.Mount(), RenameAt returns EXDEV. + // + // - If the renamed file is not a directory, and opts.MustBeDir is true, + // RenameAt returns ENOTDIR. + // + // - If renaming would replace an existing file and opts.Flags contains + // RENAME_NOREPLACE, RenameAt returns EEXIST. + // + // - If there is no existing file at rp and opts.Flags contains + // RENAME_EXCHANGE, RenameAt returns ENOENT. + // + // - If there is an existing non-directory file at rp, and rp.MustBeDir() + // is true, RenameAt returns ENOTDIR. + // + // - If the renamed file is not a directory, opts.Flags does not contain + // RENAME_EXCHANGE, and rp.MustBeDir() is true, RenameAt returns ENOTDIR. + // (This check is not subsumed by the check for directory replacement below + // since it applies even if there is no file to replace.) + // + // - If the renamed file is a directory, and the new parent directory of + // the renamed file is either the renamed directory or a descendant + // subdirectory of the renamed directory, RenameAt returns EINVAL. // - // The implementation is responsible for checking that vd.Mount() == - // rp.Mount(). - RenameAt(ctx context.Context, rp *ResolvingPath, vd VirtualDentry, opts RenameOptions) error + // - If renaming would exchange the renamed file with an ancestor directory + // of the renamed file, RenameAt returns EINVAL. + // + // - If renaming would replace an ancestor directory of the renamed file, + // RenameAt returns ENOTEMPTY. (This check would be subsumed by the + // non-empty directory check below; however, this check takes place before + // the self-rename check.) + // + // - If the renamed file would replace or exchange with itself (i.e. the + // source and destination paths resolve to the same file), RenameAt returns + // nil, skipping the checks described below. + // + // - If the source or destination directory is not writable by the provider + // of rp.Credentials(), RenameAt returns EACCES. + // + // - If the renamed file is a directory, and renaming would replace a + // non-directory file, RenameAt returns ENOTDIR. + // + // - If the renamed file is not a directory, and renaming would replace a + // directory, RenameAt returns EISDIR. + // + // - If the new parent directory of the renamed file has been removed by + // RmdirAt or a preceding call to RenameAt, RenameAt returns ENOENT. + // + // - If the renamed file is a directory, it is not writable by the + // provider of rp.Credentials(), and the source and destination parent + // directories are different, RenameAt returns EACCES. (This is nominally + // required to change the ".." entry in the renamed directory.) + // + // - If renaming would replace a non-empty directory, RenameAt returns + // ENOTEMPTY. + // + // Preconditions: !rp.Done(). For the final path component in rp, + // !rp.ShouldFollowSymlink(). oldName is not "." or "..". + // + // Postconditions: If RenameAt returns an error returned by + // ResolvingPath.Resolve*(), then !rp.Done(). + // + // [1] "The worst of all namespace operations - renaming directory. + // "Perverted" doesn't even start to describe it. Somebody in UCB had a + // heck of a trip..." - fs/namei.c:vfs_rename() + RenameAt(ctx context.Context, rp *ResolvingPath, oldParentVD VirtualDentry, oldName string, opts RenameOptions) error // RmdirAt removes the directory at rp. + // + // Errors: + // + // - If the last path component in rp is ".", RmdirAt returns EINVAL. + // + // - If the last path component in rp is "..", RmdirAt returns ENOTEMPTY. + // + // - If no file exists at rp, RmdirAt returns ENOENT. + // + // - If the file at rp exists but is not a directory, RmdirAt returns + // ENOTDIR. + // + // Preconditions: !rp.Done(). For the final path component in rp, + // !rp.ShouldFollowSymlink(). + // + // Postconditions: If RmdirAt returns an error returned by + // ResolvingPath.Resolve*(), then !rp.Done(). RmdirAt(ctx context.Context, rp *ResolvingPath) error // SetStatAt updates metadata for the file at the given path. + // + // Errors: + // + // - If opts specifies unsupported options, SetStatAt returns EINVAL. SetStatAt(ctx context.Context, rp *ResolvingPath, opts SetStatOptions) error // StatAt returns metadata for the file at rp. @@ -181,11 +376,82 @@ type FilesystemImpl interface { StatFSAt(ctx context.Context, rp *ResolvingPath) (linux.Statfs, error) // SymlinkAt creates a symbolic link at rp referring to the given target. + // + // Errors: + // + // - If the last path component in rp is "." or "..", SymlinkAt returns + // EEXIST. + // + // - If a file already exists at rp, SymlinkAt returns EEXIST. + // + // - If rp.MustBeDir(), SymlinkAt returns ENOENT. + // + // - If the directory in which the symbolic link would be created has been + // removed by RmdirAt or RenameAt, SymlinkAt returns ENOENT. + // + // Preconditions: !rp.Done(). For the final path component in rp, + // !rp.ShouldFollowSymlink(). + // + // Postconditions: If SymlinkAt returns an error returned by + // ResolvingPath.Resolve*(), then !rp.Done(). SymlinkAt(ctx context.Context, rp *ResolvingPath, target string) error - // UnlinkAt removes the non-directory file at rp. + // UnlinkAt removes the file at rp. + // + // Errors: + // + // - If the last path component in rp is "." or "..", UnlinkAt returns + // EISDIR. + // + // - If no file exists at rp, UnlinkAt returns ENOENT. + // + // - If rp.MustBeDir(), and the file at rp exists and is not a directory, + // UnlinkAt returns ENOTDIR. + // + // - If the file at rp exists but is a directory, UnlinkAt returns EISDIR. + // + // Preconditions: !rp.Done(). For the final path component in rp, + // !rp.ShouldFollowSymlink(). + // + // Postconditions: If UnlinkAt returns an error returned by + // ResolvingPath.Resolve*(), then !rp.Done(). UnlinkAt(ctx context.Context, rp *ResolvingPath) error + // ListxattrAt returns all extended attribute names for the file at rp. + // + // Errors: + // + // - If extended attributes are not supported by the filesystem, + // ListxattrAt returns nil. (See FileDescription.Listxattr for an + // explanation.) + ListxattrAt(ctx context.Context, rp *ResolvingPath) ([]string, error) + + // GetxattrAt returns the value associated with the given extended + // attribute for the file at rp. + // + // Errors: + // + // - If extended attributes are not supported by the filesystem, GetxattrAt + // returns ENOTSUP. + GetxattrAt(ctx context.Context, rp *ResolvingPath, name string) (string, error) + + // SetxattrAt changes the value associated with the given extended + // attribute for the file at rp. + // + // Errors: + // + // - If extended attributes are not supported by the filesystem, SetxattrAt + // returns ENOTSUP. + SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error + + // RemovexattrAt removes the given extended attribute from the file at rp. + // + // Errors: + // + // - If extended attributes are not supported by the filesystem, + // RemovexattrAt returns ENOTSUP. + RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error + // PrependPath prepends a path from vd to vd.Mount().Root() to b. // // If vfsroot.Ok(), it is the contextual VFS root; if it is encountered @@ -208,7 +474,7 @@ type FilesystemImpl interface { // Preconditions: vd.Mount().Filesystem().Impl() == this FilesystemImpl. PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error - // TODO: extended attributes; inotify_add_watch(); bind() + // TODO: inotify_add_watch(); bind() } // PrependPathAtVFSRootError is returned by implementations of diff --git a/pkg/sentry/vfs/filesystem_type.go b/pkg/sentry/vfs/filesystem_type.go index c335e206d..023301780 100644 --- a/pkg/sentry/vfs/filesystem_type.go +++ b/pkg/sentry/vfs/filesystem_type.go @@ -15,6 +15,7 @@ package vfs import ( + "bytes" "fmt" "gvisor.dev/gvisor/pkg/sentry/context" @@ -43,28 +44,70 @@ type GetFilesystemOptions struct { InternalData interface{} } +type registeredFilesystemType struct { + fsType FilesystemType + opts RegisterFilesystemTypeOptions +} + +// RegisterFilesystemTypeOptions contains options to +// VirtualFilesystem.RegisterFilesystem(). +type RegisterFilesystemTypeOptions struct { + // If AllowUserMount is true, allow calls to VirtualFilesystem.MountAt() + // for which MountOptions.InternalMount == false to use this filesystem + // type. + AllowUserMount bool + + // If AllowUserList is true, make this filesystem type visible in + // /proc/filesystems. + AllowUserList bool + + // If RequiresDevice is true, indicate that mounting this filesystem + // requires a block device as the mount source in /proc/filesystems. + RequiresDevice bool +} + // RegisterFilesystemType registers the given FilesystemType in vfs with the // given name. -func (vfs *VirtualFilesystem) RegisterFilesystemType(name string, fsType FilesystemType) error { +func (vfs *VirtualFilesystem) RegisterFilesystemType(name string, fsType FilesystemType, opts *RegisterFilesystemTypeOptions) error { vfs.fsTypesMu.Lock() defer vfs.fsTypesMu.Unlock() if existing, ok := vfs.fsTypes[name]; ok { - return fmt.Errorf("name %q is already registered to filesystem type %T", name, existing) + return fmt.Errorf("name %q is already registered to filesystem type %T", name, existing.fsType) + } + vfs.fsTypes[name] = ®isteredFilesystemType{ + fsType: fsType, + opts: *opts, } - vfs.fsTypes[name] = fsType return nil } // MustRegisterFilesystemType is equivalent to RegisterFilesystemType but // panics on failure. -func (vfs *VirtualFilesystem) MustRegisterFilesystemType(name string, fsType FilesystemType) { - if err := vfs.RegisterFilesystemType(name, fsType); err != nil { +func (vfs *VirtualFilesystem) MustRegisterFilesystemType(name string, fsType FilesystemType, opts *RegisterFilesystemTypeOptions) { + if err := vfs.RegisterFilesystemType(name, fsType, opts); err != nil { panic(fmt.Sprintf("failed to register filesystem type %T: %v", fsType, err)) } } -func (vfs *VirtualFilesystem) getFilesystemType(name string) FilesystemType { +func (vfs *VirtualFilesystem) getFilesystemType(name string) *registeredFilesystemType { vfs.fsTypesMu.RLock() defer vfs.fsTypesMu.RUnlock() return vfs.fsTypes[name] } + +// GenerateProcFilesystems emits the contents of /proc/filesystems for vfs to +// buf. +func (vfs *VirtualFilesystem) GenerateProcFilesystems(buf *bytes.Buffer) { + vfs.fsTypesMu.RLock() + defer vfs.fsTypesMu.RUnlock() + for name, rft := range vfs.fsTypes { + if !rft.opts.AllowUserList { + continue + } + var nodev string + if !rft.opts.RequiresDevice { + nodev = "nodev" + } + fmt.Fprintf(buf, "%s\t%s\n", nodev, name) + } +} diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index ec23ab0dd..00177b371 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -112,11 +112,11 @@ type MountNamespace struct { // configured by the given arguments. A reference is taken on the returned // MountNamespace. func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *GetFilesystemOptions) (*MountNamespace, error) { - fsType := vfs.getFilesystemType(fsTypeName) - if fsType == nil { + rft := vfs.getFilesystemType(fsTypeName) + if rft == nil { return nil, syserror.ENODEV } - fs, root, err := fsType.GetFilesystem(ctx, vfs, creds, source, *opts) + fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, *opts) if err != nil { return nil, err } @@ -136,11 +136,14 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth // MountAt creates and mounts a Filesystem configured by the given arguments. func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error { - fsType := vfs.getFilesystemType(fsTypeName) - if fsType == nil { + rft := vfs.getFilesystemType(fsTypeName) + if rft == nil { return syserror.ENODEV } - fs, root, err := fsType.GetFilesystem(ctx, vfs, creds, source, opts.GetFilesystemOptions) + if !opts.InternalMount && !rft.opts.AllowUserMount { + return syserror.ENODEV + } + fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, opts.GetFilesystemOptions) if err != nil { return err } diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index 3ecbc8fc1..b7774bf28 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -50,6 +50,10 @@ type MknodOptions struct { type MountOptions struct { // GetFilesystemOptions contains options to FilesystemType.GetFilesystem(). GetFilesystemOptions GetFilesystemOptions + + // If InternalMount is true, allow the use of filesystem types for which + // RegisterFilesystemTypeOptions.AllowUserMount == false. + InternalMount bool } // OpenOptions contains options to VirtualFilesystem.OpenAt() and @@ -83,6 +87,9 @@ type ReadOptions struct { type RenameOptions struct { // Flags contains flags as specified for renameat2(2). Flags uint32 + + // If MustBeDir is true, the renamed file must be a directory. + MustBeDir bool } // SetStatOptions contains options to VirtualFilesystem.SetStatAt(), @@ -101,6 +108,20 @@ type SetStatOptions struct { Stat linux.Statx } +// SetxattrOptions contains options to VirtualFilesystem.SetxattrAt(), +// FilesystemImpl.SetxattrAt(), FileDescription.Setxattr(), and +// FileDescriptionImpl.Setxattr(). +type SetxattrOptions struct { + // Name is the name of the extended attribute being mutated. + Name string + + // Value is the extended attribute's new value. + Value string + + // Flags contains flags as specified for setxattr/lsetxattr/fsetxattr(2). + Flags uint32 +} + // StatOptions contains options to VirtualFilesystem.StatAt(), // FilesystemImpl.StatAt(), FileDescription.Stat(), and // FileDescriptionImpl.Stat(). diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go index 621f5a6f8..f0641d314 100644 --- a/pkg/sentry/vfs/resolving_path.go +++ b/pkg/sentry/vfs/resolving_path.go @@ -85,11 +85,11 @@ func init() { // so error "constants" are really mutable vars, necessitating somewhat // expensive interface object comparisons. -type resolveMountRootError struct{} +type resolveMountRootOrJumpError struct{} // Error implements error.Error. -func (resolveMountRootError) Error() string { - return "resolving mount root" +func (resolveMountRootOrJumpError) Error() string { + return "resolving mount root or jump" } type resolveMountPointError struct{} @@ -112,30 +112,26 @@ var resolvingPathPool = sync.Pool{ }, } -func (vfs *VirtualFilesystem) getResolvingPath(creds *auth.Credentials, pop *PathOperation) (*ResolvingPath, error) { - path, err := fspath.Parse(pop.Pathname) - if err != nil { - return nil, err - } +func (vfs *VirtualFilesystem) getResolvingPath(creds *auth.Credentials, pop *PathOperation) *ResolvingPath { rp := resolvingPathPool.Get().(*ResolvingPath) rp.vfs = vfs rp.root = pop.Root rp.mount = pop.Start.mount rp.start = pop.Start.dentry - rp.pit = path.Begin + rp.pit = pop.Path.Begin rp.flags = 0 if pop.FollowFinalSymlink { rp.flags |= rpflagsFollowFinalSymlink } - rp.mustBeDir = path.Dir - rp.mustBeDirOrig = path.Dir + rp.mustBeDir = pop.Path.Dir + rp.mustBeDirOrig = pop.Path.Dir rp.symlinks = 0 rp.curPart = 0 rp.numOrigParts = 1 rp.creds = creds - rp.parts[0] = path.Begin - rp.origParts[0] = path.Begin - return rp, nil + rp.parts[0] = pop.Path.Begin + rp.origParts[0] = pop.Path.Begin + return rp } func (vfs *VirtualFilesystem) putResolvingPath(rp *ResolvingPath) { @@ -274,7 +270,7 @@ func (rp *ResolvingPath) ResolveParent(d *Dentry) (*Dentry, error) { // ... of non-root mount. rp.nextMount = vd.mount rp.nextStart = vd.dentry - return nil, resolveMountRootError{} + return nil, resolveMountRootOrJumpError{} } // ... of root mount. parent = d @@ -345,29 +341,34 @@ func (rp *ResolvingPath) ShouldFollowSymlink() bool { // symlink target and returns nil. Otherwise it returns a non-nil error. // // Preconditions: !rp.Done(). +// +// Postconditions: If HandleSymlink returns a nil error, then !rp.Done(). func (rp *ResolvingPath) HandleSymlink(target string) error { if rp.symlinks >= linux.MaxSymlinkTraversals { return syserror.ELOOP } - targetPath, err := fspath.Parse(target) - if err != nil { - return err + if len(target) == 0 { + return syserror.ENOENT } rp.symlinks++ + targetPath := fspath.Parse(target) if targetPath.Absolute { rp.absSymlinkTarget = targetPath return resolveAbsSymlinkError{} } - if !targetPath.Begin.Ok() { - panic(fmt.Sprintf("symbolic link has non-empty target %q that is both relative and has no path components?", target)) - } // Consume the path component that represented the symlink. rp.Advance() // Prepend the symlink target to the relative path. + if checkInvariants { + if !targetPath.HasComponents() { + panic(fmt.Sprintf("non-empty pathname %q parsed to relative path with no components", target)) + } + } rp.relpathPrepend(targetPath) return nil } +// Preconditions: path.HasComponents(). func (rp *ResolvingPath) relpathPrepend(path fspath.Path) { if rp.pit.Ok() { rp.parts[rp.curPart] = rp.pit @@ -385,11 +386,32 @@ func (rp *ResolvingPath) relpathPrepend(path fspath.Path) { } } +// HandleJump is called when the current path component is a "magic" link to +// the given VirtualDentry, like /proc/[pid]/fd/[fd]. If the calling Filesystem +// method should continue path traversal, HandleMagicSymlink updates the path +// component stream to reflect the magic link target and returns nil. Otherwise +// it returns a non-nil error. +// +// Preconditions: !rp.Done(). +func (rp *ResolvingPath) HandleJump(target VirtualDentry) error { + if rp.symlinks >= linux.MaxSymlinkTraversals { + return syserror.ELOOP + } + rp.symlinks++ + // Consume the path component that represented the magic link. + rp.Advance() + // Unconditionally return a resolveMountRootOrJumpError, even if the Mount + // isn't changing, to force restarting at the new Dentry. + target.IncRef() + rp.nextMount = target.mount + rp.nextStart = target.dentry + return resolveMountRootOrJumpError{} +} + func (rp *ResolvingPath) handleError(err error) bool { switch err.(type) { - case resolveMountRootError: - // Switch to the new Mount. We hold references on the Mount and Dentry - // (from VFS.getMountpointAt()). + case resolveMountRootOrJumpError: + // Switch to the new Mount. We hold references on the Mount and Dentry. rp.decRefStartAndMount() rp.mount = rp.nextMount rp.start = rp.nextStart @@ -407,9 +429,8 @@ func (rp *ResolvingPath) handleError(err error) bool { return true case resolveMountPointError: - // Switch to the new Mount. We hold a reference on the Mount (from - // VFS.getMountAt()), but borrow the reference on the mount root from - // the Mount. + // Switch to the new Mount. We hold a reference on the Mount, but + // borrow the reference on the mount root from the Mount. rp.decRefStartAndMount() rp.mount = rp.nextMount rp.start = rp.nextMount.root @@ -447,6 +468,17 @@ func (rp *ResolvingPath) handleError(err error) bool { } } +// canHandleError returns true if err is an error returned by rp.Resolve*() +// that rp.handleError() may attempt to handle. +func (rp *ResolvingPath) canHandleError(err error) bool { + switch err.(type) { + case resolveMountRootOrJumpError, resolveMountPointError, resolveAbsSymlinkError: + return true + default: + return false + } +} + // MustBeDir returns true if the file traversed by rp must be a directory. func (rp *ResolvingPath) MustBeDir() bool { return rp.mustBeDir diff --git a/pkg/sentry/vfs/testutil.go b/pkg/sentry/vfs/testutil.go index 7a1d9e383..ee5c8b9e2 100644 --- a/pkg/sentry/vfs/testutil.go +++ b/pkg/sentry/vfs/testutil.go @@ -57,6 +57,11 @@ func (fs *FDTestFilesystem) GetDentryAt(ctx context.Context, rp *ResolvingPath, return nil, syserror.EPERM } +// GetParentDentryAt implements FilesystemImpl.GetParentDentryAt. +func (fs *FDTestFilesystem) GetParentDentryAt(ctx context.Context, rp *ResolvingPath) (*Dentry, error) { + return nil, syserror.EPERM +} + // LinkAt implements FilesystemImpl.LinkAt. func (fs *FDTestFilesystem) LinkAt(ctx context.Context, rp *ResolvingPath, vd VirtualDentry) error { return syserror.EPERM @@ -83,7 +88,7 @@ func (fs *FDTestFilesystem) ReadlinkAt(ctx context.Context, rp *ResolvingPath) ( } // RenameAt implements FilesystemImpl.RenameAt. -func (fs *FDTestFilesystem) RenameAt(ctx context.Context, rp *ResolvingPath, vd VirtualDentry, opts RenameOptions) error { +func (fs *FDTestFilesystem) RenameAt(ctx context.Context, rp *ResolvingPath, oldParentVD VirtualDentry, oldName string, opts RenameOptions) error { return syserror.EPERM } @@ -117,6 +122,26 @@ func (fs *FDTestFilesystem) UnlinkAt(ctx context.Context, rp *ResolvingPath) err return syserror.EPERM } +// ListxattrAt implements FilesystemImpl.ListxattrAt. +func (fs *FDTestFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath) ([]string, error) { + return nil, syserror.EPERM +} + +// GetxattrAt implements FilesystemImpl.GetxattrAt. +func (fs *FDTestFilesystem) GetxattrAt(ctx context.Context, rp *ResolvingPath, name string) (string, error) { + return "", syserror.EPERM +} + +// SetxattrAt implements FilesystemImpl.SetxattrAt. +func (fs *FDTestFilesystem) SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error { + return syserror.EPERM +} + +// RemovexattrAt implements FilesystemImpl.RemovexattrAt. +func (fs *FDTestFilesystem) RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error { + return syserror.EPERM +} + // PrependPath implements FilesystemImpl.PrependPath. func (fs *FDTestFilesystem) PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error { b.PrependComponent(fmt.Sprintf("vfs.fdTestDentry:%p", vd.dentry.impl.(*fdTestDentry))) diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 7262b0d0a..ea2db7031 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -28,9 +28,11 @@ package vfs import ( + "fmt" "sync" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" @@ -73,23 +75,29 @@ type VirtualFilesystem struct { // mountpoints is analogous to Linux's mountpoint_hashtable. mountpoints map[*Dentry]map[*Mount]struct{} + // devices contains all registered Devices. devices is protected by + // devicesMu. + devicesMu sync.RWMutex + devices map[devTuple]*registeredDevice + + // fsTypes contains all registered FilesystemTypes. fsTypes is protected by + // fsTypesMu. + fsTypesMu sync.RWMutex + fsTypes map[string]*registeredFilesystemType + // filesystems contains all Filesystems. filesystems is protected by // filesystemsMu. filesystemsMu sync.Mutex filesystems map[*Filesystem]struct{} - - // fsTypes contains all FilesystemTypes that are usable in the - // VirtualFilesystem. fsTypes is protected by fsTypesMu. - fsTypesMu sync.RWMutex - fsTypes map[string]FilesystemType } // New returns a new VirtualFilesystem with no mounts or FilesystemTypes. func New() *VirtualFilesystem { vfs := &VirtualFilesystem{ mountpoints: make(map[*Dentry]map[*Mount]struct{}), + devices: make(map[devTuple]*registeredDevice), + fsTypes: make(map[string]*registeredFilesystemType), filesystems: make(map[*Filesystem]struct{}), - fsTypes: make(map[string]FilesystemType), } vfs.mounts.Init() return vfs @@ -111,11 +119,11 @@ type PathOperation struct { // are borrowed from the provider of the PathOperation (i.e. the caller of // the VFS method to which the PathOperation was passed). // - // Invariants: Start.Ok(). If Pathname.Absolute, then Start == Root. + // Invariants: Start.Ok(). If Path.Absolute, then Start == Root. Start VirtualDentry // Path is the pathname traversed by this operation. - Pathname string + Path fspath.Path // If FollowFinalSymlink is true, and the Dentry traversed by the final // path component represents a symbolic link, the symbolic link should be @@ -126,10 +134,7 @@ type PathOperation struct { // GetDentryAt returns a VirtualDentry representing the given path, at which a // file must exist. A reference is taken on the returned VirtualDentry. func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetDentryOptions) (VirtualDentry, error) { - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return VirtualDentry{}, err - } + rp := vfs.getResolvingPath(creds, pop) for { d, err := rp.mount.fs.impl.GetDentryAt(ctx, rp, *opts) if err == nil { @@ -148,6 +153,33 @@ func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Crede } } +// Preconditions: pop.Path.Begin.Ok(). +func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (VirtualDentry, string, error) { + rp := vfs.getResolvingPath(creds, pop) + for { + parent, err := rp.mount.fs.impl.GetParentDentryAt(ctx, rp) + if err == nil { + parentVD := VirtualDentry{ + mount: rp.mount, + dentry: parent, + } + rp.mount.IncRef() + name := rp.Component() + vfs.putResolvingPath(rp) + return parentVD, name, nil + } + if checkInvariants { + if rp.canHandleError(err) && rp.Done() { + panic(fmt.Sprintf("%T.GetParentDentryAt() consumed all path components and returned %T", rp.mount.fs.impl, err)) + } + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return VirtualDentry{}, "", err + } + } +} + // LinkAt creates a hard link at newpop representing the existing file at // oldpop. func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credentials, oldpop, newpop *PathOperation) error { @@ -155,21 +187,36 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential if err != nil { return err } - rp, err := vfs.getResolvingPath(creds, newpop) - if err != nil { + + if !newpop.Path.Begin.Ok() { oldVD.DecRef() - return err + if newpop.Path.Absolute { + return syserror.EEXIST + } + return syserror.ENOENT + } + if newpop.FollowFinalSymlink { + oldVD.DecRef() + ctx.Warningf("VirtualFilesystem.LinkAt: file creation paths can't follow final symlink") + return syserror.EINVAL } + + rp := vfs.getResolvingPath(creds, newpop) for { err := rp.mount.fs.impl.LinkAt(ctx, rp, oldVD) if err == nil { - oldVD.DecRef() vfs.putResolvingPath(rp) + oldVD.DecRef() return nil } + if checkInvariants { + if rp.canHandleError(err) && rp.Done() { + panic(fmt.Sprintf("%T.LinkAt() consumed all path components and returned %T", rp.mount.fs.impl, err)) + } + } if !rp.handleError(err) { - oldVD.DecRef() vfs.putResolvingPath(rp) + oldVD.DecRef() return err } } @@ -177,19 +224,32 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential // MkdirAt creates a directory at the given path. func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MkdirOptions) error { + if !pop.Path.Begin.Ok() { + if pop.Path.Absolute { + return syserror.EEXIST + } + return syserror.ENOENT + } + if pop.FollowFinalSymlink { + ctx.Warningf("VirtualFilesystem.MkdirAt: file creation paths can't follow final symlink") + return syserror.EINVAL + } // "Under Linux, apart from the permission bits, the S_ISVTX mode bit is // also honored." - mkdir(2) opts.Mode &= 0777 | linux.S_ISVTX - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return err - } + + rp := vfs.getResolvingPath(creds, pop) for { err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts) if err == nil { vfs.putResolvingPath(rp) return nil } + if checkInvariants { + if rp.canHandleError(err) && rp.Done() { + panic(fmt.Sprintf("%T.MkdirAt() consumed all path components and returned %T", rp.mount.fs.impl, err)) + } + } if !rp.handleError(err) { vfs.putResolvingPath(rp) return err @@ -200,16 +260,29 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia // MknodAt creates a file of the given mode at the given path. It returns an // error from the syserror package. func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MknodOptions) error { - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return nil + if !pop.Path.Begin.Ok() { + if pop.Path.Absolute { + return syserror.EEXIST + } + return syserror.ENOENT + } + if pop.FollowFinalSymlink { + ctx.Warningf("VirtualFilesystem.MknodAt: file creation paths can't follow final symlink") + return syserror.EINVAL } + + rp := vfs.getResolvingPath(creds, pop) for { - if err = rp.mount.fs.impl.MknodAt(ctx, rp, *opts); err == nil { + err := rp.mount.fs.impl.MknodAt(ctx, rp, *opts) + if err != nil { vfs.putResolvingPath(rp) return nil } - // Handle mount traversals. + if checkInvariants { + if rp.canHandleError(err) && rp.Done() { + panic(fmt.Sprintf("%T.MknodAt() consumed all path components and returned %T", rp.mount.fs.impl, err)) + } + } if !rp.handleError(err) { vfs.putResolvingPath(rp) return err @@ -259,10 +332,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential if opts.Flags&linux.O_NOFOLLOW != 0 { pop.FollowFinalSymlink = false } - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return nil, err - } + rp := vfs.getResolvingPath(creds, pop) if opts.Flags&linux.O_DIRECTORY != 0 { rp.mustBeDir = true rp.mustBeDirOrig = true @@ -282,10 +352,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential // ReadlinkAt returns the target of the symbolic link at the given path. func (vfs *VirtualFilesystem) ReadlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (string, error) { - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return "", err - } + rp := vfs.getResolvingPath(creds, pop) for { target, err := rp.mount.fs.impl.ReadlinkAt(ctx, rp) if err == nil { @@ -301,25 +368,59 @@ func (vfs *VirtualFilesystem) ReadlinkAt(ctx context.Context, creds *auth.Creden // RenameAt renames the file at oldpop to newpop. func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credentials, oldpop, newpop *PathOperation, opts *RenameOptions) error { - oldVD, err := vfs.GetDentryAt(ctx, creds, oldpop, &GetDentryOptions{}) - if err != nil { - return err + if !oldpop.Path.Begin.Ok() { + if oldpop.Path.Absolute { + return syserror.EBUSY + } + return syserror.ENOENT + } + if oldpop.FollowFinalSymlink { + ctx.Warningf("VirtualFilesystem.RenameAt: source path can't follow final symlink") + return syserror.EINVAL } - rp, err := vfs.getResolvingPath(creds, newpop) + + oldParentVD, oldName, err := vfs.getParentDirAndName(ctx, creds, oldpop) if err != nil { - oldVD.DecRef() return err } + if oldName == "." || oldName == ".." { + oldParentVD.DecRef() + return syserror.EBUSY + } + + if !newpop.Path.Begin.Ok() { + oldParentVD.DecRef() + if newpop.Path.Absolute { + return syserror.EBUSY + } + return syserror.ENOENT + } + if newpop.FollowFinalSymlink { + oldParentVD.DecRef() + ctx.Warningf("VirtualFilesystem.RenameAt: destination path can't follow final symlink") + return syserror.EINVAL + } + + rp := vfs.getResolvingPath(creds, newpop) + renameOpts := *opts + if oldpop.Path.Dir { + renameOpts.MustBeDir = true + } for { - err := rp.mount.fs.impl.RenameAt(ctx, rp, oldVD, *opts) + err := rp.mount.fs.impl.RenameAt(ctx, rp, oldParentVD, oldName, renameOpts) if err == nil { - oldVD.DecRef() vfs.putResolvingPath(rp) + oldParentVD.DecRef() return nil } + if checkInvariants { + if rp.canHandleError(err) && rp.Done() { + panic(fmt.Sprintf("%T.RenameAt() consumed all path components and returned %T", rp.mount.fs.impl, err)) + } + } if !rp.handleError(err) { - oldVD.DecRef() vfs.putResolvingPath(rp) + oldParentVD.DecRef() return err } } @@ -327,16 +428,29 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti // RmdirAt removes the directory at the given path. func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error { - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return err + if !pop.Path.Begin.Ok() { + if pop.Path.Absolute { + return syserror.EBUSY + } + return syserror.ENOENT + } + if pop.FollowFinalSymlink { + ctx.Warningf("VirtualFilesystem.RmdirAt: file deletion paths can't follow final symlink") + return syserror.EINVAL } + + rp := vfs.getResolvingPath(creds, pop) for { err := rp.mount.fs.impl.RmdirAt(ctx, rp) if err == nil { vfs.putResolvingPath(rp) return nil } + if checkInvariants { + if rp.canHandleError(err) && rp.Done() { + panic(fmt.Sprintf("%T.RmdirAt() consumed all path components and returned %T", rp.mount.fs.impl, err)) + } + } if !rp.handleError(err) { vfs.putResolvingPath(rp) return err @@ -346,10 +460,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia // SetStatAt changes metadata for the file at the given path. func (vfs *VirtualFilesystem) SetStatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetStatOptions) error { - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return err - } + rp := vfs.getResolvingPath(creds, pop) for { err := rp.mount.fs.impl.SetStatAt(ctx, rp, *opts) if err == nil { @@ -365,10 +476,7 @@ func (vfs *VirtualFilesystem) SetStatAt(ctx context.Context, creds *auth.Credent // StatAt returns metadata for the file at the given path. func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *StatOptions) (linux.Statx, error) { - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return linux.Statx{}, err - } + rp := vfs.getResolvingPath(creds, pop) for { stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts) if err == nil { @@ -385,10 +493,7 @@ func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credential // StatFSAt returns metadata for the filesystem containing the file at the // given path. func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (linux.Statfs, error) { - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return linux.Statfs{}, err - } + rp := vfs.getResolvingPath(creds, pop) for { statfs, err := rp.mount.fs.impl.StatFSAt(ctx, rp) if err == nil { @@ -404,16 +509,29 @@ func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credenti // SymlinkAt creates a symbolic link at the given path with the given target. func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, target string) error { - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return err + if !pop.Path.Begin.Ok() { + if pop.Path.Absolute { + return syserror.EEXIST + } + return syserror.ENOENT } + if pop.FollowFinalSymlink { + ctx.Warningf("VirtualFilesystem.SymlinkAt: file creation paths can't follow final symlink") + return syserror.EINVAL + } + + rp := vfs.getResolvingPath(creds, pop) for { err := rp.mount.fs.impl.SymlinkAt(ctx, rp, target) if err == nil { vfs.putResolvingPath(rp) return nil } + if checkInvariants { + if rp.canHandleError(err) && rp.Done() { + panic(fmt.Sprintf("%T.SymlinkAt() consumed all path components and returned %T", rp.mount.fs.impl, err)) + } + } if !rp.handleError(err) { vfs.putResolvingPath(rp) return err @@ -423,16 +541,104 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent // UnlinkAt deletes the non-directory file at the given path. func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error { - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return err + if !pop.Path.Begin.Ok() { + if pop.Path.Absolute { + return syserror.EBUSY + } + return syserror.ENOENT } + if pop.FollowFinalSymlink { + ctx.Warningf("VirtualFilesystem.UnlinkAt: file deletion paths can't follow final symlink") + return syserror.EINVAL + } + + rp := vfs.getResolvingPath(creds, pop) for { err := rp.mount.fs.impl.UnlinkAt(ctx, rp) if err == nil { vfs.putResolvingPath(rp) return nil } + if checkInvariants { + if rp.canHandleError(err) && rp.Done() { + panic(fmt.Sprintf("%T.UnlinkAt() consumed all path components and returned %T", rp.mount.fs.impl, err)) + } + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// ListxattrAt returns all extended attribute names for the file at the given +// path. +func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) ([]string, error) { + rp := vfs.getResolvingPath(creds, pop) + for { + names, err := rp.mount.fs.impl.ListxattrAt(ctx, rp) + if err == nil { + vfs.putResolvingPath(rp) + return names, nil + } + if err == syserror.ENOTSUP { + // Linux doesn't actually return ENOTSUP in this case; instead, + // fs/xattr.c:vfs_listxattr() falls back to allowing the security + // subsystem to return security extended attributes, which by + // default don't exist. + vfs.putResolvingPath(rp) + return nil, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return nil, err + } + } +} + +// GetxattrAt returns the value associated with the given extended attribute +// for the file at the given path. +func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) (string, error) { + rp := vfs.getResolvingPath(creds, pop) + for { + val, err := rp.mount.fs.impl.GetxattrAt(ctx, rp, name) + if err == nil { + vfs.putResolvingPath(rp) + return val, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return "", err + } + } +} + +// SetxattrAt changes the value associated with the given extended attribute +// for the file at the given path. +func (vfs *VirtualFilesystem) SetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetxattrOptions) error { + rp := vfs.getResolvingPath(creds, pop) + for { + err := rp.mount.fs.impl.SetxattrAt(ctx, rp, *opts) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// RemovexattrAt removes the given extended attribute from the file at rp. +func (vfs *VirtualFilesystem) RemovexattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) error { + rp := vfs.getResolvingPath(creds, pop) + for { + err := rp.mount.fs.impl.RemovexattrAt(ctx, rp, name) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } if !rp.handleError(err) { vfs.putResolvingPath(rp) return err diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go index 1987e89cc..2269f6237 100644 --- a/pkg/syserror/syserror.go +++ b/pkg/syserror/syserror.go @@ -45,6 +45,7 @@ var ( ELIBBAD = error(syscall.ELIBBAD) ELOOP = error(syscall.ELOOP) EMFILE = error(syscall.EMFILE) + EMLINK = error(syscall.EMLINK) EMSGSIZE = error(syscall.EMSGSIZE) ENAMETOOLONG = error(syscall.ENAMETOOLONG) ENOATTR = ENODATA diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 65d4d0cd8..e07ebd153 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -10,6 +10,7 @@ go_library( "packet_buffer_state.go", "tcpip.go", "time_unsafe.go", + "timer.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip", visibility = ["//visibility:public"], @@ -26,3 +27,10 @@ go_test( srcs = ["tcpip_test.go"], embed = [":tcpip"], ) + +go_test( + name = "timer_test", + size = "small", + srcs = ["timer_test.go"], + deps = [":tcpip"], +) diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index f1d837196..f2061c778 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -44,6 +44,7 @@ go_test( ], deps = [ ":header", + "//pkg/rand", "//pkg/tcpip", "//pkg/tcpip/buffer", "@com_github_google_go-cmp//cmp:go_default_library", diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index fc671e439..135a60b12 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -15,6 +15,7 @@ package header import ( + "crypto/sha256" "encoding/binary" "strings" @@ -102,6 +103,11 @@ const ( // bytes including and after the IIDOffsetInIPv6Address-th byte are // for the IID. IIDOffsetInIPv6Address = 8 + + // OpaqueIIDSecretKeyMinBytes is the recommended minimum number of bytes + // for the secret key used to generate an opaque interface identifier as + // outlined by RFC 7217. + OpaqueIIDSecretKeyMinBytes = 16 ) // IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the @@ -326,3 +332,42 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool { } return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80 } + +// AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier +// (IID) to buf as outlined by RFC 7217 and returns the extended buffer. +// +// The opaque IID is generated from the cryptographic hash of the concatenation +// of the prefix, NIC's name, DAD counter (DAD retry counter) and the secret +// key. The secret key SHOULD be at least OpaqueIIDSecretKeyMinBytes bytes and +// MUST be generated to a pseudo-random number. See RFC 4086 for randomness +// requirements for security. +// +// If buf has enough capacity for the IID (IIDSize bytes), a new underlying +// array for the buffer will not be allocated. +func AppendOpaqueInterfaceIdentifier(buf []byte, prefix tcpip.Subnet, nicName string, dadCounter uint8, secretKey []byte) []byte { + // As per RFC 7217 section 5, the opaque identifier can be generated as a + // cryptographic hash of the concatenation of each of the function parameters. + // Note, we omit the optional Network_ID field. + h := sha256.New() + // h.Write never returns an error. + h.Write([]byte(prefix.ID()[:IIDOffsetInIPv6Address])) + h.Write([]byte(nicName)) + h.Write([]byte{dadCounter}) + h.Write(secretKey) + + var sumBuf [sha256.Size]byte + sum := h.Sum(sumBuf[:0]) + + return append(buf, sum[:IIDSize]...) +} + +// LinkLocalAddrWithOpaqueIID computes the default IPv6 link-local address with +// an opaque IID. +func LinkLocalAddrWithOpaqueIID(nicName string, dadCounter uint8, secretKey []byte) tcpip.Address { + lladdrb := [IPv6AddressSize]byte{ + 0: 0xFE, + 1: 0x80, + } + + return tcpip.Address(AppendOpaqueInterfaceIdentifier(lladdrb[:IIDOffsetInIPv6Address], IPv6LinkLocalPrefix.Subnet(), nicName, dadCounter, secretKey)) +} diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go index 42c5c6fc1..1994003ed 100644 --- a/pkg/tcpip/header/ipv6_test.go +++ b/pkg/tcpip/header/ipv6_test.go @@ -15,9 +15,12 @@ package header_test import ( + "bytes" + "crypto/sha256" "testing" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -43,3 +46,163 @@ func TestLinkLocalAddr(t *testing.T) { t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want) } } + +func TestAppendOpaqueInterfaceIdentifier(t *testing.T) { + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte + if n, err := rand.Read(secretKeyBuf[:]); err != nil { + t.Fatalf("rand.Read(_): %s", err) + } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want { + t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n) + } + + tests := []struct { + name string + prefix tcpip.Subnet + nicName string + dadCounter uint8 + secretKey []byte + }{ + { + name: "SecretKey of minimum size", + prefix: header.IPv6LinkLocalPrefix.Subnet(), + nicName: "eth0", + dadCounter: 0, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes], + }, + { + name: "SecretKey of less than minimum size", + prefix: func() tcpip.Subnet { + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: "\x01\x02\x03\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: header.IIDOffsetInIPv6Address * 8, + } + return addrWithPrefix.Subnet() + }(), + nicName: "eth10", + dadCounter: 1, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2], + }, + { + name: "SecretKey of more than minimum size", + prefix: func() tcpip.Subnet { + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: "\x01\x02\x03\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: header.IIDOffsetInIPv6Address * 8, + } + return addrWithPrefix.Subnet() + }(), + nicName: "eth11", + dadCounter: 2, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2], + }, + { + name: "Nil SecretKey and empty nicName", + prefix: func() tcpip.Subnet { + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: "\x01\x02\x03\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: header.IIDOffsetInIPv6Address * 8, + } + return addrWithPrefix.Subnet() + }(), + nicName: "", + dadCounter: 3, + secretKey: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + h := sha256.New() + h.Write([]byte(test.prefix.ID()[:header.IIDOffsetInIPv6Address])) + h.Write([]byte(test.nicName)) + h.Write([]byte{test.dadCounter}) + if k := test.secretKey; k != nil { + h.Write(k) + } + var hashSum [sha256.Size]byte + h.Sum(hashSum[:0]) + want := hashSum[:header.IIDSize] + + // Passing a nil buffer should result in a new buffer returned with the + // IID. + if got := header.AppendOpaqueInterfaceIdentifier(nil, test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) { + t.Errorf("got AppendOpaqueInterfaceIdentifier(nil, %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want) + } + + // Passing a buffer with sufficient capacity for the IID should populate + // the buffer provided. + var iidBuf [header.IIDSize]byte + if got := header.AppendOpaqueInterfaceIdentifier(iidBuf[:0], test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) { + t.Errorf("got AppendOpaqueInterfaceIdentifier(iidBuf[:0], %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want) + } + if got := iidBuf[:]; !bytes.Equal(got, want) { + t.Errorf("got iidBuf = %x, want = %x", got, want) + } + }) + } +} + +func TestLinkLocalAddrWithOpaqueIID(t *testing.T) { + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte + if n, err := rand.Read(secretKeyBuf[:]); err != nil { + t.Fatalf("rand.Read(_): %s", err) + } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want { + t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n) + } + + prefix := header.IPv6LinkLocalPrefix.Subnet() + + tests := []struct { + name string + prefix tcpip.Subnet + nicName string + dadCounter uint8 + secretKey []byte + }{ + { + name: "SecretKey of minimum size", + nicName: "eth0", + dadCounter: 0, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes], + }, + { + name: "SecretKey of less than minimum size", + nicName: "eth10", + dadCounter: 1, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2], + }, + { + name: "SecretKey of more than minimum size", + nicName: "eth11", + dadCounter: 2, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2], + }, + { + name: "Nil SecretKey and empty nicName", + nicName: "", + dadCounter: 3, + secretKey: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + addrBytes := [header.IPv6AddressSize]byte{ + 0: 0xFE, + 1: 0x80, + } + + want := tcpip.Address(header.AppendOpaqueInterfaceIdentifier( + addrBytes[:header.IIDOffsetInIPv6Address], + prefix, + test.nicName, + test.dadCounter, + test.secretKey, + )) + + if got := header.LinkLocalAddrWithOpaqueIID(test.nicName, test.dadCounter, test.secretKey); got != want { + t.Errorf("got LinkLocalAddrWithOpaqueIID(%s, %d, %x) = %s, want = %s", test.nicName, test.dadCounter, test.secretKey, got, want) + } + }) + } +} diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index da8482509..42cacb8a6 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -79,16 +79,16 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketLooping, tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams, stack.PacketLooping) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams) (int, *tcpip.Error) { return 0, tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 4144a7837..f1bc33adf 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -239,7 +239,7 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{ + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { @@ -480,7 +480,7 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{ + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index e645cf62c..4ee3d5b45 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -238,11 +238,11 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) - if loop&stack.PacketLoop != 0 { + if r.Loop&stack.PacketLoop != 0 { // The inbound path expects the network header to still be in // the PacketBuffer's Data field. views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) @@ -256,7 +256,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw loopedR.Release() } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return nil } if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { @@ -270,11 +270,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { - if loop&stack.PacketLoop != 0 { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { + if r.Loop&stack.PacketLoop != 0 { panic("multiple packets in local loop") } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return len(pkts), nil } @@ -289,7 +289,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. ip := header.IPv4(pkt.Data.First()) @@ -324,10 +324,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLo ip.SetChecksum(0) ip.SetChecksum(^ip.CalculateChecksum()) - if loop&stack.PacketLoop != 0 { + if r.Loop&stack.PacketLoop != 0 { e.HandlePacket(r, pkt.Clone()) } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return nil } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index dd31f0fb7..58c3c79b9 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -112,11 +112,11 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) - if loop&stack.PacketLoop != 0 { + if r.Loop&stack.PacketLoop != 0 { // The inbound path expects the network header to still be in // the PacketBuffer's Data field. views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) @@ -130,7 +130,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw loopedR.Release() } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return nil } @@ -139,11 +139,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { - if loop&stack.PacketLoop != 0 { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { + if r.Loop&stack.PacketLoop != 0 { panic("not implemented") } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return len(pkts), nil } @@ -161,8 +161,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet // supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { - // TODO(b/119580726): Support IPv6 header-included packets. +func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { + // TODO(b/146666412): Support IPv6 header-included packets. return tcpip.ErrNotSupported } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 69077669a..b8f9517d0 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -59,6 +59,7 @@ go_test( ], deps = [ ":stack", + "//pkg/rand", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 27bd02e76..35825ebf7 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -131,40 +131,34 @@ type NDPDispatcher interface { OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) // OnDefaultRouterDiscovered will be called when a new default router is - // discovered. Implementations must return true along with a new valid - // route table if the newly discovered router should be remembered. If - // an implementation returns false, the second return value will be - // ignored. + // discovered. Implementations must return true if the newly discovered + // router should be remembered. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) (bool, []tcpip.Route) + OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool // OnDefaultRouterInvalidated will be called when a discovered default - // router is invalidated. Implementers must return a new valid route - // table. + // router that was remembered is invalidated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) []tcpip.Route + OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) // OnOnLinkPrefixDiscovered will be called when a new on-link prefix is - // discovered. Implementations must return true along with a new valid - // route table if the newly discovered on-link prefix should be - // remembered. If an implementation returns false, the second return - // value will be ignored. + // discovered. Implementations must return true if the newly discovered + // on-link prefix should be remembered. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) (bool, []tcpip.Route) + OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool // OnOnLinkPrefixInvalidated will be called when a discovered on-link - // prefix is invalidated. Implementers must return a new valid route - // table. + // prefix that was remembered is invalidated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) []tcpip.Route + OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) // OnAutoGenAddress will be called when a new prefix with its // autonomous address-configuration flag set has been received and SLAAC @@ -175,6 +169,15 @@ type NDPDispatcher interface { // call functions on the stack itself. OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool + // OnAutoGenAddressDeprecated will be called when an auto-generated + // address (as part of SLAAC) has been deprecated, but is still + // considered valid. Note, if an address is invalidated at the same + // time it is deprecated, the deprecation event MAY be omitted. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) + // OnAutoGenAddressInvalidated will be called when an auto-generated // address (as part of SLAAC) has been invalidated. // @@ -296,71 +299,27 @@ type dadState struct { // defaultRouterState holds data associated with a default router discovered by // a Router Advertisement (RA). type defaultRouterState struct { - invalidationTimer *time.Timer - - // Used to inform the timer not to invalidate the default router (R) in - // a race condition (T1 is a goroutine that handles an RA from R and T2 - // is the goroutine that handles R's invalidation timer firing): - // T1: Receive a new RA from R - // T1: Obtain the NIC's lock before processing the RA - // T2: R's invalidation timer fires, and gets blocked on obtaining the - // NIC's lock - // T1: Refreshes/extends R's lifetime & releases NIC's lock - // T2: Obtains NIC's lock & invalidates R immediately - // - // To resolve this, T1 will check to see if the timer already fired, and - // inform the timer using doNotInvalidate to not invalidate R, so that - // once T2 obtains the lock, it will see that it is set to true and do - // nothing further. - doNotInvalidate *bool + invalidationTimer tcpip.CancellableTimer } // onLinkPrefixState holds data associated with an on-link prefix discovered by // a Router Advertisement's Prefix Information option (PI) when the NDP // configurations was configured to do so. type onLinkPrefixState struct { - invalidationTimer *time.Timer - - // Used to signal the timer not to invalidate the on-link prefix (P) in - // a race condition (T1 is a goroutine that handles a PI for P and T2 - // is the goroutine that handles P's invalidation timer firing): - // T1: Receive a new PI for P - // T1: Obtain the NIC's lock before processing the PI - // T2: P's invalidation timer fires, and gets blocked on obtaining the - // NIC's lock - // T1: Refreshes/extends P's lifetime & releases NIC's lock - // T2: Obtains NIC's lock & invalidates P immediately - // - // To resolve this, T1 will check to see if the timer already fired, and - // inform the timer using doNotInvalidate to not invalidate P, so that - // once T2 obtains the lock, it will see that it is set to true and do - // nothing further. - doNotInvalidate *bool + invalidationTimer tcpip.CancellableTimer } // autoGenAddressState holds data associated with an address generated via // SLAAC. type autoGenAddressState struct { - invalidationTimer *time.Timer - - // Used to signal the timer not to invalidate the SLAAC address (A) in - // a race condition (T1 is a goroutine that handles a PI for A and T2 - // is the goroutine that handles A's invalidation timer firing): - // T1: Receive a new PI for A - // T1: Obtain the NIC's lock before processing the PI - // T2: A's invalidation timer fires, and gets blocked on obtaining the - // NIC's lock - // T1: Refreshes/extends A's lifetime & releases NIC's lock - // T2: Obtains NIC's lock & invalidates A immediately - // - // To resolve this, T1 will check to see if the timer already fired, and - // inform the timer using doNotInvalidate to not invalidate A, so that - // once T2 obtains the lock, it will see that it is set to true and do - // nothing further. - doNotInvalidate *bool - - // Nonzero only when the address is not valid forever (invalidationTimer - // is not nil). + // A reference to the referencedNetworkEndpoint that this autoGenAddressState + // is holding state for. + ref *referencedNetworkEndpoint + + deprecationTimer tcpip.CancellableTimer + invalidationTimer tcpip.CancellableTimer + + // Nonzero only when the address is not valid forever. validUntil time.Time } @@ -562,7 +521,7 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { // handleRA handles a Router Advertisement message that arrived on the NIC // this ndp is for. Does nothing if the NIC is configured to not handle RAs. // -// The NIC that ndp belongs to and its associated stack MUST be locked. +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // Is the NIC configured to handle RAs at all? // @@ -591,27 +550,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { case ok && rl != 0: // This is an already discovered default router. Update // the invalidation timer. - timer := rtr.invalidationTimer - - // We should ALWAYS have an invalidation timer for a - // discovered router. - if timer == nil { - panic("ndphandlera: RA invalidation timer should not be nil") - } - - if !timer.Stop() { - // If we reach this point, then we know the - // timer fired after we already took the NIC - // lock. Inform the timer not to invalidate the - // router when it obtains the lock as we just - // got a new RA that refreshes its lifetime to a - // non-zero value. See - // defaultRouterState.doNotInvalidate for more - // details. - *rtr.doNotInvalidate = true - } - - timer.Reset(rl) + rtr.invalidationTimer.StopLocked() + rtr.invalidationTimer.Reset(rl) + ndp.defaultRouters[ip] = rtr case ok && rl == 0: // We know about the router but it is no longer to be @@ -668,7 +609,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // invalidateDefaultRouter invalidates a discovered default router. // -// The NIC that ndp belongs to and its associated stack MUST be locked. +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { rtr, ok := ndp.defaultRouters[ip] @@ -678,16 +619,13 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { return } - rtr.invalidationTimer.Stop() - rtr.invalidationTimer = nil - *rtr.doNotInvalidate = true - rtr.doNotInvalidate = nil + rtr.invalidationTimer.StopLocked() delete(ndp.defaultRouters, ip) // Let the integrator know a discovered default router is invalidated. - if ndp.nic.stack.ndpDisp != nil { - ndp.nic.stack.routeTable = ndp.nic.stack.ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip) + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip) } } @@ -696,43 +634,29 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { // // The router identified by ip MUST NOT already be known by the NIC. // -// The NIC that ndp belongs to and its associated stack MUST be locked. +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { - if ndp.nic.stack.ndpDisp == nil { + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { return } // Inform the integrator when we discovered a default router. - remember, routeTable := ndp.nic.stack.ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) - if !remember { + if !ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) { // Informed by the integrator to not remember the router, do // nothing further. return } - // Used to signal the timer not to invalidate the default router (R) in - // a race condition. See defaultRouterState.doNotInvalidate for more - // details. - var doNotInvalidate bool - - ndp.defaultRouters[ip] = defaultRouterState{ - invalidationTimer: time.AfterFunc(rl, func() { - ndp.nic.stack.mu.Lock() - defer ndp.nic.stack.mu.Unlock() - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() - - if doNotInvalidate { - doNotInvalidate = false - return - } - + state := defaultRouterState{ + invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { ndp.invalidateDefaultRouter(ip) }), - doNotInvalidate: &doNotInvalidate, } - ndp.nic.stack.routeTable = routeTable + state.invalidationTimer.Reset(rl) + + ndp.defaultRouters[ip] = state } // rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6 @@ -740,42 +664,36 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { // // The prefix identified by prefix MUST NOT already be known. // -// The NIC that ndp belongs to and its associated stack MUST be locked. +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) { - if ndp.nic.stack.ndpDisp == nil { + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { return } // Inform the integrator when we discovered an on-link prefix. - remember, routeTable := ndp.nic.stack.ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix) - if !remember { + if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix) { // Informed by the integrator to not remember the prefix, do // nothing further. return } - // Used to signal the timer not to invalidate the on-link prefix (P) in - // a race condition. See onLinkPrefixState.doNotInvalidate for more - // details. - var doNotInvalidate bool - var timer *time.Timer - - // Only create a timer if the lifetime is not infinite. - if l < header.NDPInfiniteLifetime { - timer = ndp.prefixInvalidationCallback(prefix, l, &doNotInvalidate) + state := onLinkPrefixState{ + invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { + ndp.invalidateOnLinkPrefix(prefix) + }), } - ndp.onLinkPrefixes[prefix] = onLinkPrefixState{ - invalidationTimer: timer, - doNotInvalidate: &doNotInvalidate, + if l < header.NDPInfiniteLifetime { + state.invalidationTimer.Reset(l) } - ndp.nic.stack.routeTable = routeTable + ndp.onLinkPrefixes[prefix] = state } // invalidateOnLinkPrefix invalidates a discovered on-link prefix. // -// The NIC that ndp belongs to and its associated stack MUST be locked. +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { s, ok := ndp.onLinkPrefixes[prefix] @@ -785,51 +703,23 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { return } - if s.invalidationTimer != nil { - s.invalidationTimer.Stop() - s.invalidationTimer = nil - *s.doNotInvalidate = true - } - - s.doNotInvalidate = nil + s.invalidationTimer.StopLocked() delete(ndp.onLinkPrefixes, prefix) // Let the integrator know a discovered on-link prefix is invalidated. - if ndp.nic.stack.ndpDisp != nil { - ndp.nic.stack.routeTable = ndp.nic.stack.ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix) + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix) } } -// prefixInvalidationCallback returns a new on-link prefix invalidation timer -// for prefix that fires after vl. -// -// doNotInvalidate is used to signal the timer when it fires at the same time -// that a prefix's valid lifetime gets refreshed. See -// onLinkPrefixState.doNotInvalidate for more details. -func (ndp *ndpState) prefixInvalidationCallback(prefix tcpip.Subnet, vl time.Duration, doNotInvalidate *bool) *time.Timer { - return time.AfterFunc(vl, func() { - ndp.nic.stack.mu.Lock() - defer ndp.nic.stack.mu.Unlock() - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() - - if *doNotInvalidate { - *doNotInvalidate = false - return - } - - ndp.invalidateOnLinkPrefix(prefix) - }) -} - // handleOnLinkPrefixInformation handles a Prefix Information option with // its on-link flag set, as per RFC 4861 section 6.3.4. // // handleOnLinkPrefixInformation assumes that the prefix this pi is for is // not the link-local prefix and the on-link flag is set. // -// The NIC that ndp belongs to and its associated stack MUST be locked. +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) { prefix := pi.Subnet() prefixState, ok := ndp.onLinkPrefixes[prefix] @@ -862,42 +752,17 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio // This is an already discovered on-link prefix with a // new non-zero valid lifetime. + // // Update the invalidation timer. - timer := prefixState.invalidationTimer - - if timer == nil && vl >= header.NDPInfiniteLifetime { - // Had infinite valid lifetime before and - // continues to have an invalid lifetime. Do - // nothing further. - return - } - - if timer != nil && !timer.Stop() { - // If we reach this point, then we know the timer alread fired - // after we took the NIC lock. Inform the timer to not - // invalidate the prefix once it obtains the lock as we just - // got a new PI that refreshes its lifetime to a non-zero value. - // See onLinkPrefixState.doNotInvalidate for more details. - *prefixState.doNotInvalidate = true - } - if vl >= header.NDPInfiniteLifetime { - // Prefix is now valid forever so we don't need - // an invalidation timer. - prefixState.invalidationTimer = nil - ndp.onLinkPrefixes[prefix] = prefixState - return - } + prefixState.invalidationTimer.StopLocked() - if timer != nil { - // We already have a timer so just reset it to - // expire after the new valid lifetime. - timer.Reset(vl) - return + if vl < header.NDPInfiniteLifetime { + // Prefix is valid for a finite lifetime, reset the timer to expire after + // the new valid lifetime. + prefixState.invalidationTimer.Reset(vl) } - // We do not have a timer so just create a new one. - prefixState.invalidationTimer = ndp.prefixInvalidationCallback(prefix, vl, prefixState.doNotInvalidate) ndp.onLinkPrefixes[prefix] = prefixState } @@ -907,7 +772,7 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio // handleAutonomousPrefixInformation assumes that the prefix this pi is for is // not the link-local prefix and the autonomous flag is set. // -// The NIC that ndp belongs to and its associated stack MUST be locked. +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) { vl := pi.ValidLifetime() pl := pi.PreferredLifetime() @@ -922,103 +787,30 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform prefix := pi.Subnet() // Check if we already have an auto-generated address for prefix. - for _, ref := range ndp.nic.endpoints { - if ref.protocol != header.IPv6ProtocolNumber { - continue - } - - if ref.configType != slaac { - continue - } - - addr := ref.ep.ID().LocalAddress - refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: ref.ep.PrefixLen()} + for addr, addrState := range ndp.autoGenAddresses { + refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: addrState.ref.ep.PrefixLen()} if refAddrWithPrefix.Subnet() != prefix { continue } - // - // At this point, we know we are refreshing a SLAAC generated - // IPv6 address with the prefix, prefix. Do the work as outlined - // by RFC 4862 section 5.5.3.e. - // - - addrState, ok := ndp.autoGenAddresses[addr] - if !ok { - panic(fmt.Sprintf("must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr)) - } - - // TODO(b/143713887): Handle deprecating auto-generated address - // after the preferred lifetime. - - // As per RFC 4862 section 5.5.3.e, the valid lifetime of the - // address generated by SLAAC is as follows: - // - // 1) If the received Valid Lifetime is greater than 2 hours or - // greater than RemainingLifetime, set the valid lifetime of - // the address to the advertised Valid Lifetime. - // - // 2) If RemainingLifetime is less than or equal to 2 hours, - // ignore the advertised Valid Lifetime. - // - // 3) Otherwise, reset the valid lifetime of the address to 2 - // hours. - - // Handle the infinite valid lifetime separately as we do not - // keep a timer in this case. - if vl >= header.NDPInfiniteLifetime { - if addrState.invalidationTimer != nil { - // Valid lifetime was finite before, but now it - // is valid forever. - if !addrState.invalidationTimer.Stop() { - *addrState.doNotInvalidate = true - } - addrState.invalidationTimer = nil - addrState.validUntil = time.Time{} - ndp.autoGenAddresses[addr] = addrState - } - - return - } - - var effectiveVl time.Duration - var rl time.Duration - - // If the address was originally set to be valid forever, - // assume the remaining time to be the maximum possible value. - if addrState.invalidationTimer == nil { - rl = header.NDPInfiniteLifetime - } else { - rl = time.Until(addrState.validUntil) - } - - if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { - effectiveVl = vl - } else if rl <= MinPrefixInformationValidLifetimeForUpdate { - ndp.autoGenAddresses[addr] = addrState - return - } else { - effectiveVl = MinPrefixInformationValidLifetimeForUpdate - } - - if addrState.invalidationTimer == nil { - addrState.invalidationTimer = ndp.autoGenAddrInvalidationTimer(addr, effectiveVl, addrState.doNotInvalidate) - } else { - if !addrState.invalidationTimer.Stop() { - *addrState.doNotInvalidate = true - } - addrState.invalidationTimer.Reset(effectiveVl) - } - - addrState.validUntil = time.Now().Add(effectiveVl) - ndp.autoGenAddresses[addr] = addrState + // At this point, we know we are refreshing a SLAAC generated IPv6 address + // with the prefix prefix. Do the work as outlined by RFC 4862 section + // 5.5.3.e. + ndp.refreshAutoGenAddressLifetimes(addr, pl, vl) return } // We do not already have an address within the prefix, prefix. Do the // work as outlined by RFC 4862 section 5.5.3.d if n is configured // to auto-generated global addresses by SLAAC. + ndp.newAutoGenAddress(prefix, pl, vl) +} +// newAutoGenAddress generates a new SLAAC address with the provided lifetimes +// for prefix. +// +// pl is the new preferred lifetime. vl is the new valid lifetime. +func (ndp *ndpState) newAutoGenAddress(prefix tcpip.Subnet, pl, vl time.Duration) { // Are we configured to auto-generate new global addresses? if !ndp.configs.AutoGenGlobalAddresses { return @@ -1038,22 +830,24 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform return } - // Only attempt to generate an interface-specific IID if we have a valid - // link address. - // - // TODO(b/141011931): Validate a LinkEndpoint's link address - // (provided by LinkEndpoint.LinkAddress) before reaching this - // point. - linkAddr := ndp.nic.linkEP.LinkAddress() - if !header.IsValidUnicastEthernetAddress(linkAddr) { - return - } + addrBytes := []byte(prefix.ID()) + if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { + addrBytes = header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], prefix, oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), 0 /* dadCounter */, oIID.SecretKey) + } else { + // Only attempt to generate an interface-specific IID if we have a valid + // link address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by + // LinkEndpoint.LinkAddress) before reaching this point. + linkAddr := ndp.nic.linkEP.LinkAddress() + if !header.IsValidUnicastEthernetAddress(linkAddr) { + return + } - // Generate an address within prefix from the modified EUI-64 of ndp's - // NIC's Ethernet MAC address. - addrBytes := make([]byte, header.IPv6AddressSize) - copy(addrBytes[:header.IIDOffsetInIPv6Address], prefix.ID()[:header.IIDOffsetInIPv6Address]) - header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + // Generate an address within prefix from the modified EUI-64 of ndp's NIC's + // Ethernet MAC address. + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + } addr := tcpip.Address(addrBytes) addrWithPrefix := tcpip.AddressWithPrefix{ Address: addr, @@ -1066,37 +860,141 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform } // Inform the integrator that we have a new SLAAC address. - if ndp.nic.stack.ndpDisp == nil { + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { return } - if !ndp.nic.stack.ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addrWithPrefix) { + if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addrWithPrefix) { // Informed by the integrator not to add the address. return } - if _, err := ndp.nic.addAddressLocked(tcpip.ProtocolAddress{ + protocolAddr := tcpip.ProtocolAddress{ Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addrWithPrefix, - }, FirstPrimaryEndpoint, permanent, slaac); err != nil { - panic(err) + } + // If the preferred lifetime is zero, then the address should be considered + // deprecated. + deprecated := pl == 0 + ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated) + if err != nil { + log.Fatalf("ndp: error when adding address %s: %s", protocolAddr, err) } - // Setup the timers to deprecate and invalidate this newly generated + state := autoGenAddressState{ + ref: ref, + deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { + addrState, ok := ndp.autoGenAddresses[addr] + if !ok { + log.Fatalf("ndp: must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr) + } + addrState.ref.deprecated = true + ndp.notifyAutoGenAddressDeprecated(addr) + }), + invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { + ndp.invalidateAutoGenAddress(addr) + }), + } + + // Setup the initial timers to deprecate and invalidate this newly generated // address. - // TODO(b/143713887): Handle deprecating auto-generated addresses - // after the preferred lifetime. + if !deprecated && pl < header.NDPInfiniteLifetime { + state.deprecationTimer.Reset(pl) + } - var doNotInvalidate bool - var vTimer *time.Timer if vl < header.NDPInfiniteLifetime { - vTimer = ndp.autoGenAddrInvalidationTimer(addr, vl, &doNotInvalidate) + state.invalidationTimer.Reset(vl) + state.validUntil = time.Now().Add(vl) + } + + ndp.autoGenAddresses[addr] = state +} + +// refreshAutoGenAddressLifetimes refreshes the lifetime of a SLAAC generated +// address addr. +// +// pl is the new preferred lifetime. vl is the new valid lifetime. +func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl time.Duration) { + addrState, ok := ndp.autoGenAddresses[addr] + if !ok { + log.Fatalf("ndp: SLAAC state not found to refresh lifetimes for %s", addr) + } + defer func() { ndp.autoGenAddresses[addr] = addrState }() + + // If the preferred lifetime is zero, then the address should be considered + // deprecated. + deprecated := pl == 0 + wasDeprecated := addrState.ref.deprecated + addrState.ref.deprecated = deprecated + + // Only send the deprecation event if the deprecated status for addr just + // changed from non-deprecated to deprecated. + if !wasDeprecated && deprecated { + ndp.notifyAutoGenAddressDeprecated(addr) + } + + // If addr was preferred for some finite lifetime before, stop the deprecation + // timer so it can be reset. + addrState.deprecationTimer.StopLocked() + + // Reset the deprecation timer if addr has a finite preferred lifetime. + if !deprecated && pl < header.NDPInfiniteLifetime { + addrState.deprecationTimer.Reset(pl) + } + + // As per RFC 4862 section 5.5.3.e, the valid lifetime of the address + // + // + // 1) If the received Valid Lifetime is greater than 2 hours or greater than + // RemainingLifetime, set the valid lifetime of the address to the + // advertised Valid Lifetime. + // + // 2) If RemainingLifetime is less than or equal to 2 hours, ignore the + // advertised Valid Lifetime. + // + // 3) Otherwise, reset the valid lifetime of the address to 2 hours. + + // Handle the infinite valid lifetime separately as we do not keep a timer in + // this case. + if vl >= header.NDPInfiniteLifetime { + addrState.invalidationTimer.StopLocked() + addrState.validUntil = time.Time{} + return + } + + var effectiveVl time.Duration + var rl time.Duration + + // If the address was originally set to be valid forever, assume the remaining + // time to be the maximum possible value. + if addrState.validUntil == (time.Time{}) { + rl = header.NDPInfiniteLifetime + } else { + rl = time.Until(addrState.validUntil) + } + + if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { + effectiveVl = vl + } else if rl <= MinPrefixInformationValidLifetimeForUpdate { + return + } else { + effectiveVl = MinPrefixInformationValidLifetimeForUpdate } - ndp.autoGenAddresses[addr] = autoGenAddressState{ - invalidationTimer: vTimer, - doNotInvalidate: &doNotInvalidate, - validUntil: time.Now().Add(vl), + addrState.invalidationTimer.StopLocked() + addrState.invalidationTimer.Reset(effectiveVl) + addrState.validUntil = time.Now().Add(effectiveVl) +} + +// notifyAutoGenAddressDeprecated notifies the stack's NDP dispatcher that addr +// has been deprecated. +func (ndp *ndpState) notifyAutoGenAddressDeprecated(addr tcpip.Address) { + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: validPrefixLenForAutoGen, + }) } } @@ -1120,23 +1018,16 @@ func (ndp *ndpState) invalidateAutoGenAddress(addr tcpip.Address) { // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bool { state, ok := ndp.autoGenAddresses[addr] - if !ok { return false } - if state.invalidationTimer != nil { - state.invalidationTimer.Stop() - state.invalidationTimer = nil - *state.doNotInvalidate = true - } - - state.doNotInvalidate = nil - + state.deprecationTimer.StopLocked() + state.invalidationTimer.StopLocked() delete(ndp.autoGenAddresses, addr) - if ndp.nic.stack.ndpDisp != nil { - ndp.nic.stack.ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{ + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{ Address: addr, PrefixLen: validPrefixLenForAutoGen, }) @@ -1145,22 +1036,37 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo return true } -// autoGenAddrInvalidationTimer returns a new invalidation timer for an -// auto-generated address that fires after vl. +// cleanupHostOnlyState cleans up any state that is only useful for hosts. // -// doNotInvalidate is used to inform the timer when it fires at the same time -// that an auto-generated address's valid lifetime gets refreshed. See -// autoGenAddrState.doNotInvalidate for more details. -func (ndp *ndpState) autoGenAddrInvalidationTimer(addr tcpip.Address, vl time.Duration, doNotInvalidate *bool) *time.Timer { - return time.AfterFunc(vl, func() { - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() - - if *doNotInvalidate { - *doNotInvalidate = false - return - } - +// cleanupHostOnlyState MUST be called when ndp's NIC is transitioning from a +// host to a router. This function will invalidate all discovered on-link +// prefixes, discovered routers, and auto-generated addresses as routers do not +// normally process Router Advertisements to discover default routers and +// on-link prefixes, and auto-generate addresses via SLAAC. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupHostOnlyState() { + for addr, _ := range ndp.autoGenAddresses { ndp.invalidateAutoGenAddress(addr) - }) + } + + if got := len(ndp.autoGenAddresses); got != 0 { + log.Fatalf("ndp: still have auto-generated addresses after cleaning up, found = %d", got) + } + + for prefix, _ := range ndp.onLinkPrefixes { + ndp.invalidateOnLinkPrefix(prefix) + } + + if got := len(ndp.onLinkPrefixes); got != 0 { + log.Fatalf("ndp: still have discovered on-link prefixes after cleaning up, found = %d", got) + } + + for router, _ := range ndp.defaultRouters { + ndp.invalidateDefaultRouter(router) + } + + if got := len(ndp.defaultRouters); got != 0 { + log.Fatalf("ndp: still have discovered default routers after cleaning up, found = %d", got) + } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index d8e7ce67e..d334af289 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -21,6 +21,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -29,6 +30,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" ) const ( @@ -45,8 +48,25 @@ var ( llAddr1 = header.LinkLocalAddr(linkAddr1) llAddr2 = header.LinkLocalAddr(linkAddr2) llAddr3 = header.LinkLocalAddr(linkAddr3) + dstAddr = tcpip.FullAddress{ + Addr: "\x0a\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + Port: 25, + } ) +func addrForSubnet(subnet tcpip.Subnet, linkAddr tcpip.LinkAddress) tcpip.AddressWithPrefix { + if !header.IsValidUnicastEthernetAddress(linkAddr) { + return tcpip.AddressWithPrefix{} + } + + addrBytes := []byte(subnet.ID()) + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + return tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: 64, + } +} + // prefixSubnetAddr returns a prefix (Address + Length), the prefix's equivalent // tcpip.Subnet, and an address where the lower half of the address is composed // of the EUI-64 of linkAddr if it is a valid unicast ethernet address. @@ -59,17 +79,7 @@ func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWi subnet := prefix.Subnet() - var addr tcpip.AddressWithPrefix - if header.IsValidUnicastEthernetAddress(linkAddr) { - addrBytes := []byte(subnet.ID()) - header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) - addr = tcpip.AddressWithPrefix{ - Address: tcpip.Address(addrBytes), - PrefixLen: 64, - } - } - - return prefix, subnet, addr + return prefix, subnet, addrForSubnet(subnet, linkAddr) } // TestDADDisabled tests that an address successfully resolves immediately @@ -79,7 +89,7 @@ func TestDADDisabled(t *testing.T) { NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(opts) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -132,6 +142,7 @@ type ndpAutoGenAddrEventType int const ( newAddr ndpAutoGenAddrEventType = iota + deprecatedAddr invalidatedAddr ) @@ -163,7 +174,6 @@ type ndpDispatcher struct { rememberPrefix bool autoGenAddrC chan ndpAutoGenAddrEvent rdnssC chan ndpRDNSSEvent - routeTable []tcpip.Route } // Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. @@ -179,106 +189,56 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, add } // Implements stack.NDPDispatcher.OnDefaultRouterDiscovered. -func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) (bool, []tcpip.Route) { - if n.routerC != nil { - n.routerC <- ndpRouterEvent{ +func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool { + if c := n.routerC; c != nil { + c <- ndpRouterEvent{ nicID, addr, true, } } - if !n.rememberRouter { - return false, nil - } - - rt := append([]tcpip.Route(nil), n.routeTable...) - rt = append(rt, tcpip.Route{ - Destination: header.IPv6EmptySubnet, - Gateway: addr, - NIC: nicID, - }) - n.routeTable = rt - return true, rt + return n.rememberRouter } // Implements stack.NDPDispatcher.OnDefaultRouterInvalidated. -func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) []tcpip.Route { - if n.routerC != nil { - n.routerC <- ndpRouterEvent{ +func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) { + if c := n.routerC; c != nil { + c <- ndpRouterEvent{ nicID, addr, false, } } - - var rt []tcpip.Route - exclude := tcpip.Route{ - Destination: header.IPv6EmptySubnet, - Gateway: addr, - NIC: nicID, - } - - for _, r := range n.routeTable { - if r != exclude { - rt = append(rt, r) - } - } - n.routeTable = rt - return rt } // Implements stack.NDPDispatcher.OnOnLinkPrefixDiscovered. -func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) (bool, []tcpip.Route) { - if n.prefixC != nil { - n.prefixC <- ndpPrefixEvent{ +func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool { + if c := n.prefixC; c != nil { + c <- ndpPrefixEvent{ nicID, prefix, true, } } - if !n.rememberPrefix { - return false, nil - } - - rt := append([]tcpip.Route(nil), n.routeTable...) - rt = append(rt, tcpip.Route{ - Destination: prefix, - NIC: nicID, - }) - n.routeTable = rt - return true, rt + return n.rememberPrefix } // Implements stack.NDPDispatcher.OnOnLinkPrefixInvalidated. -func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) []tcpip.Route { - if n.prefixC != nil { - n.prefixC <- ndpPrefixEvent{ +func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) { + if c := n.prefixC; c != nil { + c <- ndpPrefixEvent{ nicID, prefix, false, } } - - var rt []tcpip.Route - exclude := tcpip.Route{ - Destination: prefix, - NIC: nicID, - } - - for _, r := range n.routeTable { - if r != exclude { - rt = append(rt, r) - } - } - n.routeTable = rt - return rt } func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool { - if n.autoGenAddrC != nil { - n.autoGenAddrC <- ndpAutoGenAddrEvent{ + if c := n.autoGenAddrC; c != nil { + c <- ndpAutoGenAddrEvent{ nicID, addr, newAddr, @@ -287,9 +247,19 @@ func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWi return true } +func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { + if c := n.autoGenAddrC; c != nil { + c <- ndpAutoGenAddrEvent{ + nicID, + addr, + deprecatedAddr, + } + } +} + func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { - if n.autoGenAddrC != nil { - n.autoGenAddrC <- ndpAutoGenAddrEvent{ + if c := n.autoGenAddrC; c != nil { + c <- ndpAutoGenAddrEvent{ nicID, addr, invalidatedAddr, @@ -299,8 +269,8 @@ func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpi // Implements stack.NDPDispatcher.OnRecursiveDNSServerOption. func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) { - if n.rdnssC != nil { - n.rdnssC <- ndpRDNSSEvent{ + if c := n.rdnssC; c != nil { + c <- ndpRDNSSEvent{ nicID, ndpRDNSS{ addrs, @@ -330,6 +300,8 @@ func TestDADResolve(t *testing.T) { } for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { t.Parallel() @@ -520,7 +492,7 @@ func TestDADFail(t *testing.T) { } opts.NDPConfigs.RetransmitTimer = time.Second * 2 - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(opts) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -601,7 +573,7 @@ func TestDADStop(t *testing.T) { NDPConfigs: ndpConfigs, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(opts) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -702,7 +674,7 @@ func TestSetNDPConfigurations(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPDisp: &ndpDisp, @@ -903,8 +875,6 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on // TestNoRouterDiscovery tests that router discovery will not be performed if // configured not to. func TestNoRouterDiscovery(t *testing.T) { - t.Parallel() - // Being configured to discover routers means handle and // discover are set to true and forwarding is set to false. // This tests all possible combinations of the configurations, @@ -920,9 +890,9 @@ func TestNoRouterDiscovery(t *testing.T) { t.Parallel() ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 10), + routerC: make(chan ndpRouterEvent, 1), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -942,21 +912,27 @@ func TestNoRouterDiscovery(t *testing.T) { select { case <-ndpDisp.routerC: t.Fatal("unexpectedly discovered a router when configured not to") - case <-time.After(defaultTimeout): + default: } }) } } +// Check e to make sure that the event is for addr on nic with ID 1, and the +// discovered flag set to discovered. +func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string { + return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e)) +} + // TestRouterDiscoveryDispatcherNoRemember tests that the stack does not // remember a discovered router when the dispatcher asks it not to. func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { t.Parallel() ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 10), + routerC: make(chan ndpRouterEvent, 1), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -970,50 +946,25 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { t.Fatalf("CreateNIC(1) = %s", err) } - routeTable := []tcpip.Route{ - { - header.IPv6EmptySubnet, - llAddr3, - 1, - }, - } - s.SetRouteTable(routeTable) - - // Rx an RA with short lifetime. - lifetime := time.Duration(1) - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(lifetime))) + // Receive an RA for a router we should not remember. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds)) select { - case r := <-ndpDisp.routerC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.addr != llAddr2 { - t.Fatalf("got r.addr = %s, want = %s", r.addr, llAddr2) + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, llAddr2, true); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) } - if !r.discovered { - t.Fatal("got r.discovered = false, want = true") - } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for router discovery event") + default: + t.Fatal("expected router discovery event") } - // Original route table should not have been modified. - if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable) - } - - // Wait for the normal invalidation time plus an extra second to - // make sure we do not actually receive any invalidation events as - // we should not have remembered the router in the first place. + // Wait for the invalidation time plus some buffer to make sure we do + // not actually receive any invalidation events as we should not have + // remembered the router in the first place. select { case <-ndpDisp.routerC: t.Fatal("should not have received any router events") - case <-time.After(lifetime*time.Second + defaultTimeout): - } - - // Original route table should not have been modified. - if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable) + case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): } } @@ -1021,10 +972,10 @@ func TestRouterDiscovery(t *testing.T) { t.Parallel() ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 10), + routerC: make(chan ndpRouterEvent, 1), rememberRouter: true, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1034,22 +985,29 @@ func TestRouterDiscovery(t *testing.T) { NDPDisp: &ndpDisp, }) - waitForEvent := func(addr tcpip.Address, discovered bool, timeout time.Duration) { + expectRouterEvent := func(addr tcpip.Address, discovered bool) { t.Helper() select { - case r := <-ndpDisp.routerC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.addr != addr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, discovered); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) } - if r.discovered != discovered { - t.Fatalf("got r.discovered = %t, want = %t", r.discovered, discovered) + default: + t.Fatal("expected router discovery event") + } + } + + expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { + t.Helper() + + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, false); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) } case <-time.After(timeout): - t.Fatal("timeout waiting for router discovery event") + t.Fatal("timed out waiting for router discovery event") } } @@ -1063,40 +1021,25 @@ func TestRouterDiscovery(t *testing.T) { select { case <-ndpDisp.routerC: t.Fatal("unexpectedly discovered a router with 0 lifetime") - case <-time.After(defaultTimeout): + default: } // Rx an RA from lladdr2 with a huge lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - waitForEvent(llAddr2, true, defaultTimeout) - - // Should have a default route through the discovered router. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } + expectRouterEvent(llAddr2, true) // Rx an RA from another router (lladdr3) with non-zero lifetime. - l3Lifetime := time.Duration(6) - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, uint16(l3Lifetime))) - waitForEvent(llAddr3, true, defaultTimeout) - - // Should have default routes through the discovered routers. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}, {header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } + const l3LifetimeSeconds = 6 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) + expectRouterEvent(llAddr3, true) // Rx an RA from lladdr2 with lesser lifetime. - l2Lifetime := time.Duration(2) - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(l2Lifetime))) + const l2LifetimeSeconds = 2 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) select { case <-ndpDisp.routerC: t.Fatal("Should not receive a router event when updating lifetimes for known routers") - case <-time.After(defaultTimeout): - } - - // Should still have a default route through the discovered routers. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}, {header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + default: } // Wait for lladdr2's router invalidation timer to fire. The lifetime @@ -1106,31 +1049,15 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - waitForEvent(llAddr2, false, l2Lifetime*time.Second+defaultTimeout) - - // Should no longer have the default route through lladdr2. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } + expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultTimeout) // Rx an RA from lladdr2 with huge lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - waitForEvent(llAddr2, true, defaultTimeout) - - // Should have a default route through the discovered routers. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}, {header.IPv6EmptySubnet, llAddr2, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } + expectRouterEvent(llAddr2, true) // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - waitForEvent(llAddr2, false, defaultTimeout) - - // Should have deleted the default route through the router that just - // got invalidated. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } + expectRouterEvent(llAddr2, false) // Wait for lladdr3's router invalidation timer to fire. The lifetime // of the router should have been updated to the most recent (smaller) @@ -1139,13 +1066,7 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - waitForEvent(llAddr3, false, l3Lifetime*time.Second+defaultTimeout) - - // Should not have any routes now that all discovered routers have been - // invalidated. - if got := len(s.GetRouteTable()); got != 0 { - t.Fatalf("got len(s.GetRouteTable()) = %d, want = 0", got) - } + expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultTimeout) } // TestRouterDiscoveryMaxRouters tests that only @@ -1154,10 +1075,10 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { t.Parallel() ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 10), + routerC: make(chan ndpRouterEvent, 1), rememberRouter: true, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1171,8 +1092,6 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { t.Fatalf("CreateNIC(1) = %s", err) } - expectedRt := [stack.MaxDiscoveredDefaultRouters]tcpip.Route{} - // Receive an RA from 2 more than the max number of discovered routers. for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ { linkAddr := []byte{2, 2, 3, 4, 5, 0} @@ -1182,43 +1101,28 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5)) if i <= stack.MaxDiscoveredDefaultRouters { - expectedRt[i-1] = tcpip.Route{header.IPv6EmptySubnet, llAddr, 1} select { - case r := <-ndpDisp.routerC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.addr != llAddr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, llAddr) + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, llAddr, true); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) } - if !r.discovered { - t.Fatal("got r.discovered = false, want = true") - } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for router discovery event") + default: + t.Fatal("expected router discovery event") } } else { select { case <-ndpDisp.routerC: t.Fatal("should not have discovered a new router after we already discovered the max number of routers") - case <-time.After(defaultTimeout): + default: } } } - - // Should only have default routes for the first - // stack.MaxDiscoveredDefaultRouters discovered routers. - if got := s.GetRouteTable(); !cmp.Equal(got, expectedRt[:]) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, expectedRt) - } } // TestNoPrefixDiscovery tests that prefix discovery will not be performed if // configured not to. func TestNoPrefixDiscovery(t *testing.T) { - t.Parallel() - prefix := tcpip.AddressWithPrefix{ Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), PrefixLen: 64, @@ -1239,9 +1143,9 @@ func TestNoPrefixDiscovery(t *testing.T) { t.Parallel() ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 10), + prefixC: make(chan ndpPrefixEvent, 1), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1262,12 +1166,18 @@ func TestNoPrefixDiscovery(t *testing.T) { select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly discovered a prefix when configured not to") - case <-time.After(defaultTimeout): + default: } }) } } +// Check e to make sure that the event is for prefix on nic with ID 1, and the +// discovered flag set to discovered. +func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string { + return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e)) +} + // TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not // remember a discovered on-link prefix when the dispatcher asks it not to. func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { @@ -1276,9 +1186,9 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { prefix, subnet, _ := prefixSubnetAddr(0, "") ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 10), + prefixC: make(chan ndpPrefixEvent, 1), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1293,50 +1203,25 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { t.Fatalf("CreateNIC(1) = %s", err) } - routeTable := []tcpip.Route{ - { - header.IPv6EmptySubnet, - llAddr3, - 1, - }, - } - s.SetRouteTable(routeTable) - - // Rx an RA with prefix with a short lifetime. - const lifetime = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetime, 0)) + // Receive an RA with prefix that we should not remember. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0)) select { - case r := <-ndpDisp.prefixC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet, true); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - if r.prefix != subnet { - t.Fatalf("got r.prefix = %s, want = %s", r.prefix, subnet) - } - if !r.discovered { - t.Fatal("got r.discovered = false, want = true") - } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for prefix discovery event") - } - - // Original route table should not have been modified. - if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable) + default: + t.Fatal("expected prefix discovery event") } - // Wait for the normal invalidation time plus some buffer to - // make sure we do not actually receive any invalidation events as - // we should not have remembered the prefix in the first place. + // Wait for the invalidation time plus some buffer to make sure we do + // not actually receive any invalidation events as we should not have + // remembered the prefix in the first place. select { case <-ndpDisp.prefixC: t.Fatal("should not have received any prefix events") - case <-time.After(lifetime*time.Second + defaultTimeout): - } - - // Original route table should not have been modified. - if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable) + case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): } } @@ -1348,10 +1233,10 @@ func TestPrefixDiscovery(t *testing.T) { prefix3, subnet3, _ := prefixSubnetAddr(2, "") ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 10), + prefixC: make(chan ndpPrefixEvent, 1), rememberPrefix: true, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1361,74 +1246,48 @@ func TestPrefixDiscovery(t *testing.T) { NDPDisp: &ndpDisp, }) - waitForEvent := func(subnet tcpip.Subnet, discovered bool, timeout time.Duration) { + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { t.Helper() select { - case r := <-ndpDisp.prefixC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.prefix != subnet { - t.Fatalf("got r.prefix = %s, want = %s", r.prefix, subnet) - } - if r.discovered != discovered { - t.Fatalf("got r.discovered = %t, want = %t", r.discovered, discovered) + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - case <-time.After(timeout): - t.Fatal("timeout waiting for prefix discovery event") + default: + t.Fatal("expected prefix discovery event") } } - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with zero valid lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly discovered a prefix with 0 lifetime") - case <-time.After(defaultTimeout): + default: } // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with non-zero lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) - waitForEvent(subnet1, true, defaultTimeout) - - // Should have added a device route for subnet1 through the nic. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } + expectPrefixEvent(subnet1, true) // Receive an RA with prefix2 in a PI. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) - waitForEvent(subnet2, true, defaultTimeout) - - // Should have added a device route for subnet2 through the nic. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}, {subnet2, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } + expectPrefixEvent(subnet2, true) // Receive an RA with prefix3 in a PI. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) - waitForEvent(subnet3, true, defaultTimeout) - - // Should have added a device route for subnet3 through the nic. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}, {subnet2, tcpip.Address([]byte(nil)), 1}, {subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } + expectPrefixEvent(subnet3, true) // Receive an RA with prefix1 in a PI with lifetime = 0. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - waitForEvent(subnet1, false, defaultTimeout) - - // Should have removed the device route for subnet1 through the nic. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet2, tcpip.Address([]byte(nil)), 1}, {subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } + expectPrefixEvent(subnet1, false) // Receive an RA with prefix2 in a PI with lesser lifetime. lifetime := uint32(2) @@ -1436,31 +1295,23 @@ func TestPrefixDiscovery(t *testing.T) { select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly received prefix event when updating lifetime") - case <-time.After(defaultTimeout): - } - - // Should not have updated route table. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet2, tcpip.Address([]byte(nil)), 1}, {subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + default: } // Wait for prefix2's most recent invalidation timer plus some buffer to // expire. - waitForEvent(subnet2, false, time.Duration(lifetime)*time.Second+defaultTimeout) - - // Should have removed the device route for subnet2 through the nic. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet2, false); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + case <-time.After(time.Duration(lifetime)*time.Second + defaultTimeout): + t.Fatal("timed out waiting for prefix discovery event") } // Receive RA to invalidate prefix3. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) - waitForEvent(subnet3, false, defaultTimeout) - - // Should not have any routes. - if got := len(s.GetRouteTable()); got != 0 { - t.Fatalf("got len(s.GetRouteTable()) = %d, want = 0", got) - } + expectPrefixEvent(subnet3, false) } func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { @@ -1482,10 +1333,10 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { subnet := prefix.Subnet() ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 10), + prefixC: make(chan ndpPrefixEvent, 1), rememberPrefix: true, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1495,33 +1346,27 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { NDPDisp: &ndpDisp, }) - waitForEvent := func(discovered bool, timeout time.Duration) { + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { t.Helper() select { - case r := <-ndpDisp.prefixC: - if r.nicID != 1 { - t.Errorf("got r.nicID = %d, want = 1", r.nicID) - } - if r.prefix != subnet { - t.Errorf("got r.prefix = %s, want = %s", r.prefix, subnet) + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - if r.discovered != discovered { - t.Errorf("got r.discovered = %t, want = %t", r.discovered, discovered) - } - case <-time.After(timeout): - t.Fatal("timeout waiting for prefix discovery event") + default: + t.Fatal("expected prefix discovery event") } } - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - // Receive an RA with prefix in an NDP Prefix Information option (PI) // with infinite valid lifetime which should not get invalidated. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) - waitForEvent(true, defaultTimeout) + expectPrefixEvent(subnet, true) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") @@ -1531,11 +1376,18 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // Receive an RA with finite lifetime. // The prefix should get invalidated after 1s. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) - waitForEvent(false, testInfiniteLifetime) + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet, false); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + case <-time.After(testInfiniteLifetime): + t.Fatal("timed out waiting for prefix discovery event") + } // Receive an RA with finite lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) - waitForEvent(true, defaultTimeout) + expectPrefixEvent(subnet, true) // Receive an RA with prefix with an infinite lifetime. // The prefix should not be invalidated. @@ -1558,7 +1410,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // Receive an RA with 0 lifetime. // The prefix should get invalidated. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 0, 0)) - waitForEvent(false, defaultTimeout) + expectPrefixEvent(subnet, false) } // TestPrefixDiscoveryMaxRouters tests that only @@ -1570,7 +1422,7 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3), rememberPrefix: true, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1586,7 +1438,6 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { } optSer := make(header.NDPOptionsSerializer, stack.MaxDiscoveredOnLinkPrefixes+2) - expectedRt := [stack.MaxDiscoveredOnLinkPrefixes]tcpip.Route{} prefixes := [stack.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{} // Receive an RA with 2 more than the max number of discovered on-link @@ -1606,43 +1457,27 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { copy(buf[14:], prefix.Address) optSer[i] = header.NDPPrefixInformation(buf[:]) - - if i < stack.MaxDiscoveredOnLinkPrefixes { - expectedRt[i] = tcpip.Route{prefixes[i], tcpip.Address([]byte(nil)), 1} - } } e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { if i < stack.MaxDiscoveredOnLinkPrefixes { select { - case r := <-ndpDisp.prefixC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.prefix != prefixes[i] { - t.Fatalf("got r.prefix = %s, want = %s", r.prefix, prefixes[i]) - } - if !r.discovered { - t.Fatal("got r.discovered = false, want = true") + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for prefix discovery event") + default: + t.Fatal("expected prefix discovery event") } } else { select { case <-ndpDisp.prefixC: t.Fatal("should not have discovered a new prefix after we already discovered the max number of prefixes") - case <-time.After(defaultTimeout): + default: } } } - - // Should only have device routes for the first - // stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes. - if got := s.GetRouteTable(); !cmp.Equal(got, expectedRt[:]) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, expectedRt) - } } // Checks to see if list contains an IPv6 address, item. @@ -1663,8 +1498,6 @@ func contains(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool { // TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. func TestNoAutoGenAddr(t *testing.T) { - t.Parallel() - prefix, _, _ := prefixSubnetAddr(0, "") // Being configured to auto-generate addresses means handle and @@ -1682,9 +1515,9 @@ func TestNoAutoGenAddr(t *testing.T) { t.Parallel() ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1705,12 +1538,18 @@ func TestNoAutoGenAddr(t *testing.T) { select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly auto-generated an address when configured not to") - case <-time.After(defaultTimeout): + default: } }) } } +// Check e to make sure that the event is for addr on nic with ID 1, and the +// event type is set to eventType. +func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string { + return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e)) +} + // TestAutoGenAddr tests that an address is properly generated and invalidated // when configured to do so. func TestAutoGenAddr(t *testing.T) { @@ -1726,9 +1565,9 @@ func TestAutoGenAddr(t *testing.T) { prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1738,42 +1577,36 @@ func TestAutoGenAddr(t *testing.T) { NDPDisp: &ndpDisp, }) - waitForEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() select { - case r := <-ndpDisp.autoGenAddrC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - if r.addr != addr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) - } - if r.eventType != eventType { - t.Fatalf("got r.eventType = %v, want = %v", r.eventType, eventType) - } - case <-time.After(timeout): - t.Fatal("timeout waiting for addr auto gen event") + default: + t.Fatal("expected addr auto gen event") } } - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with zero valid lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly auto-generated an address with 0 lifetime") - case <-time.After(defaultTimeout): + default: } // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with non-zero lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - waitForEvent(addr1, newAddr, defaultTimeout) + expectAutoGenAddrEvent(addr1, newAddr) if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) { t.Fatalf("Should have %s in the list of addresses", addr1) } @@ -1784,12 +1617,12 @@ func TestAutoGenAddr(t *testing.T) { select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") - case <-time.After(defaultTimeout): + default: } // Receive an RA with prefix2 in a PI. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - waitForEvent(addr2, newAddr, defaultTimeout) + expectAutoGenAddrEvent(addr2, newAddr) if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) { t.Fatalf("Should have %s in the list of addresses", addr1) } @@ -1802,11 +1635,18 @@ func TestAutoGenAddr(t *testing.T) { select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") - case <-time.After(defaultTimeout): + default: } // Wait for addr of prefix1 to be invalidated. - waitForEvent(addr1, invalidatedAddr, newMinVLDuration+defaultTimeout) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(newMinVLDuration + defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } if contains(s.NICInfo()[1].ProtocolAddresses, addr1) { t.Fatalf("Should not have %s in the list of addresses", addr1) } @@ -1815,12 +1655,529 @@ func TestAutoGenAddr(t *testing.T) { } } +// stackAndNdpDispatcherWithDefaultRoute returns an ndpDispatcher, +// channel.Endpoint and stack.Stack. +// +// stack.Stack will have a default route through the router (llAddr3) installed +// and a static link-address (linkAddr3) added to the link address cache for the +// router. +func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { + t.Helper() + ndpDisp := &ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: ndpDisp, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + Gateway: llAddr3, + NIC: nicID, + }}) + s.AddLinkAddress(nicID, llAddr3, linkAddr3) + return ndpDisp, e, s +} + +// addrForNewConnection returns the local address used when creating a new +// connection. +func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } + defer ep.Close() + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) + } + if err := ep.Connect(dstAddr); err != nil { + t.Fatalf("ep.Connect(%+v): %s", dstAddr, err) + } + got, err := ep.GetLocalAddress() + if err != nil { + t.Fatalf("ep.GetLocalAddress(): %s", err) + } + return got.Addr +} + +// addrForNewConnectionWithAddr returns the local address used when creating a +// new connection with a specific local address. +func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address { + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } + defer ep.Close() + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) + } + if err := ep.Bind(addr); err != nil { + t.Fatalf("ep.Bind(%+v): %s", addr, err) + } + if err := ep.Connect(dstAddr); err != nil { + t.Fatalf("ep.Connect(%+v): %s", dstAddr, err) + } + got, err := ep.GetLocalAddress() + if err != nil { + t.Fatalf("ep.GetLocalAddress(): %s", err) + } + return got.Addr +} + +// TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when +// receiving a PI with 0 preferred lifetime. +func TestAutoGenAddrDeprecateFromPI(t *testing.T) { + const nicID = 1 + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() + + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if got != addr { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + } + + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } + + // Receive PI for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + expectAutoGenAddrEvent(addr1, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + expectPrimaryAddr(addr1) + + // Deprecate addr for prefix1 immedaitely. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + // addr should still be the primary endpoint as there are no other addresses. + expectPrimaryAddr(addr1) + + // Refresh lifetimes of addr generated from prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) + + // Receive PI for prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) + + // Deprecate addr for prefix2 immedaitely. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, deprecatedAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr1 should be the primary endpoint now since addr2 is deprecated but + // addr1 is not. + expectPrimaryAddr(addr1) + // addr2 is deprecated but if explicitly requested, it should be used. + fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} + if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address) + } + + // Another PI w/ 0 preferred lifetime should not result in a deprecation + // event. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) + if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address) + } + + // Refresh lifetimes of addr generated from prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr2) +} + +// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated +// when its preferred lifetime expires. +func TestAutoGenAddrTimerDeprecation(t *testing.T) { + const nicID = 1 + const newMinVL = 2 + newMinVLDuration := newMinVL * time.Second + saved := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = saved + }() + stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } + + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() + + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if got != addr { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + } + + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } + + // Receive PI for prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) + + // Receive a PI for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) + expectAutoGenAddrEvent(addr1, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr1) + + // Refresh lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) + + // Wait for addr of prefix1 to be deprecated. + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr2 should be the primary endpoint now since addr1 is deprecated but + // addr2 is not. + expectPrimaryAddr(addr2) + // addr1 is deprecated but if explicitly requested, it should be used. + fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + } + + // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make + // sure we do not get a deprecation event again. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr2) + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + } + + // Refresh lifetimes for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + // addr1 is the primary endpoint again since it is non-deprecated now. + expectPrimaryAddr(addr1) + + // Wait for addr of prefix1 to be deprecated. + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr2 should be the primary endpoint now since it is not deprecated. + expectPrimaryAddr(addr2) + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + } + + // Wait for addr of prefix1 to be invalidated. + expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultTimeout) + if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) + + // Refresh both lifetimes for addr of prefix2 to the same value. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + + // Wait for a deprecation then invalidation events, or just an invalidation + // event. We need to cover both cases but cannot deterministically hit both + // cases because the deprecation and invalidation handlers could be handled in + // either deprecation then invalidation, or invalidation then deprecation + // (which should be cancelled by the invalidation handler). + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { + // If we get a deprecation event first, we should get an invalidation + // event almost immediately after. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { + // If we get an invalidation event first, we should not get a deprecation + // event after. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + case <-time.After(defaultTimeout): + } + } else { + t.Fatalf("got unexpected auto-generated event") + } + + case <-time.After(newMinVLDuration + defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should not have %s in the list of addresses", addr2) + } + // Should not have any primary endpoints. + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if want := (tcpip.AddressWithPrefix{}); got != want { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want) + } + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } + defer ep.Close() + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) + } + + if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { + t.Errorf("got ep.Connect(%+v) = %v, want = %s", dstAddr, err, tcpip.ErrNoRoute) + } +} + +// Tests transitioning a SLAAC address's valid lifetime between finite and +// infinite values. +func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { + const infiniteVLSeconds = 2 + const minVLSeconds = 1 + savedIL := header.NDPInfiniteLifetime + savedMinVL := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = savedMinVL + header.NDPInfiniteLifetime = savedIL + }() + stack.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second + header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + tests := []struct { + name string + infiniteVL uint32 + }{ + { + name: "EqualToInfiniteVL", + infiniteVL: infiniteVLSeconds, + }, + // Our implementation supports changing header.NDPInfiniteLifetime for tests + // such that a packet can be received where the lifetime field has a value + // greater than header.NDPInfiniteLifetime. Because of this, we test to make + // sure that receiving a value greater than header.NDPInfiniteLifetime is + // handled the same as when receiving a value equal to + // header.NDPInfiniteLifetime. + { + name: "MoreThanInfiniteVL", + infiniteVL: infiniteVLSeconds + 1, + }, + } + + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the + // parallel tests complete. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run("group", func(t *testing.T) { + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA with finite prefix. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + + default: + t.Fatal("expected addr auto gen event") + } + + // Receive an new RA with prefix with infinite VL. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0)) + + // Receive a new RA with prefix with finite VL. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + + case <-time.After(minVLSeconds*time.Second + defaultTimeout): + t.Fatal("timeout waiting for addr auto gen event") + } + }) + } + }) +} + // TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an // auto-generated address only gets updated when required to, as specified in // RFC 4862 section 5.5.3.e. func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { const infiniteVL = 4294967295 - const newMinVL = 5 + const newMinVL = 4 saved := stack.MinPrefixInformationValidLifetimeForUpdate defer func() { stack.MinPrefixInformationValidLifetimeForUpdate = saved @@ -1896,78 +2253,81 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { const delta = 500 * time.Millisecond - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - }) + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the + // parallel tests complete. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run("group", func(t *testing.T) { + for _, test := range tests { + test := test - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + t.Run(test.name, func(t *testing.T) { + t.Parallel() - // Receive an RA with prefix with initial VL, test.ovl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) - select { - case r := <-ndpDisp.autoGenAddrC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.addr != addr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), } - if r.eventType != newAddr { - t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr) + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for addr auto gen event") - } - // Receive an new RA with prefix with new VL, test.nvl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) + // Receive an RA with prefix with initial VL, + // test.ovl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } - // - // Validate that the VL for the address got set to - // test.evl. - // + // Receive an new RA with prefix with new VL, + // test.nvl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) - // Make sure we do not get any invalidation events - // until atleast 500ms (delta) before test.evl. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly received an auto gen addr event") - case <-time.After(time.Duration(test.evl)*time.Second - delta): - } + // + // Validate that the VL for the address got set + // to test.evl. + // - // Wait for another second (2x delta), but now we expect - // the invalidation event. - select { - case r := <-ndpDisp.autoGenAddrC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.addr != addr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) + // Make sure we do not get any invalidation + // events until atleast 500ms (delta) before + // test.evl. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly received an auto gen addr event") + case <-time.After(time.Duration(test.evl)*time.Second - delta): } - if r.eventType != invalidatedAddr { - t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr) + + // Wait for another second (2x delta), but now + // we expect the invalidation event. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + + case <-time.After(2 * delta): + t.Fatal("timeout waiting for addr auto gen event") } - case <-time.After(2 * delta): - t.Fatal("timeout waiting for addr auto gen event") - } - }) - } + }) + } + }) } // TestAutoGenAddrRemoval tests that when auto-generated addresses are removed @@ -1979,9 +2339,9 @@ func TestAutoGenAddrRemoval(t *testing.T) { prefix, _, addr := prefixSubnetAddr(0, linkAddr1) ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1995,51 +2355,37 @@ func TestAutoGenAddrRemoval(t *testing.T) { t.Fatalf("CreateNIC(1) = %s", err) } - // Receive an RA with prefix with its valid lifetime = lifetime. - const lifetime = 5 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetime, 0)) - select { - case r := <-ndpDisp.autoGenAddrC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.addr != addr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) - } - if r.eventType != newAddr { - t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr) + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for addr auto gen event") } - // Remove the address. + // Receive a PI to auto-generate an address. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0)) + expectAutoGenAddrEvent(addr, newAddr) + + // Removing the address should result in an invalidation event + // immediately. if err := s.RemoveAddress(1, addr.Address); err != nil { t.Fatalf("RemoveAddress(_, %s) = %s", addr.Address, err) } - - // Should get the invalidation event immediately. - select { - case r := <-ndpDisp.autoGenAddrC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.addr != addr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) - } - if r.eventType != invalidatedAddr { - t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr) - } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for addr auto gen event") - } + expectAutoGenAddrEvent(addr, invalidatedAddr) // Wait for the original valid lifetime to make sure the original timer // got stopped/cleaned up. select { case <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly received an auto gen addr event") - case <-time.After(lifetime*time.Second + defaultTimeout): + t.Fatal("unexpectedly received an auto gen addr event") + case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): } } @@ -2051,9 +2397,9 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { prefix, _, addr := prefixSubnetAddr(0, linkAddr1) ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -2077,12 +2423,12 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { // Receive a PI where the generated address will be the same as the one // that we already have assigned statically. - const lifetime = 5 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetime, 0)) + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically") - case <-time.After(defaultTimeout): + default: } if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) @@ -2093,13 +2439,117 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetime*time.Second + defaultTimeout): + case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): } if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) } } +// TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use +// opaque interface identifiers when configured to do so. +func TestAutoGenAddrWithOpaqueIID(t *testing.T) { + t.Parallel() + + const nicID = 1 + const nicName = "nic1" + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte + secretKey := secretKeyBuf[:] + n, err := rand.Read(secretKey) + if err != nil { + t.Fatalf("rand.Read(_): %s", err) + } + if n != header.OpaqueIIDSecretKeyMinBytes { + t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) + } + + prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1) + prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1) + // addr1 and addr2 are the addresses that are expected to be generated when + // stack.Stack is configured to generate opaque interface identifiers as + // defined by RFC 7217. + addrBytes := []byte(subnet1.ID()) + addr1 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet1, nicName, 0, secretKey)), + PrefixLen: 64, + } + addrBytes = []byte(subnet2.ID()) + addr2 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet2, nicName, 0, secretKey)), + PrefixLen: 64, + } + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + }) + opts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v, _) = %s", nicID, opts, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + // Receive an RA with prefix1 in a PI. + const validLifetimeSecondPrefix1 = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, validLifetimeSecondPrefix1, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + + // Receive an RA with prefix2 in a PI with a large valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + + // Wait for addr of prefix1 to be invalidated. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } +} + // TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event // to the integrator when an RA is received with the NDP Recursive DNS Server // option with at least one valid address. @@ -2247,3 +2697,257 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { }) } } + +// TestCleanupHostOnlyStateOnBecomingRouter tests that all discovered routers +// and prefixes, and auto-generated addresses get invalidated when a NIC +// becomes a router. +func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { + t.Parallel() + + const ( + lifetimeSeconds = 5 + maxEvents = 4 + nicID1 = 1 + nicID2 = 2 + ) + + prefix1, subnet1, e1Addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, subnet2, e1Addr2 := prefixSubnetAddr(1, linkAddr1) + e2Addr1 := addrForSubnet(subnet1, linkAddr2) + e2Addr2 := addrForSubnet(subnet2, linkAddr2) + + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, maxEvents), + rememberRouter: true, + prefixC: make(chan ndpPrefixEvent, maxEvents), + rememberPrefix: true, + autoGenAddrC: make(chan ndpAutoGenAddrEvent, maxEvents), + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + DiscoverOnLinkPrefixes: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + e1 := channel.New(0, 1280, linkAddr1) + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) + } + + e2 := channel.New(0, 1280, linkAddr2) + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) + } + + expectRouterEvent := func() (bool, ndpRouterEvent) { + select { + case e := <-ndpDisp.routerC: + return true, e + default: + } + + return false, ndpRouterEvent{} + } + + expectPrefixEvent := func() (bool, ndpPrefixEvent) { + select { + case e := <-ndpDisp.prefixC: + return true, e + default: + } + + return false, ndpPrefixEvent{} + } + + expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) { + select { + case e := <-ndpDisp.autoGenAddrC: + return true, e + default: + } + + return false, ndpAutoGenAddrEvent{} + } + + // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr1 and + // llAddr2) w/ PI (for prefix1 in RA from llAddr1 and prefix2 in RA from + // llAddr2) to discover multiple routers and prefixes, and auto-gen + // multiple addresses. + + e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr1, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) + // We have other tests that make sure we receive the *correct* events + // on normal discovery of routers/prefixes, and auto-generated + // addresses. Here we just make sure we get an event and let other tests + // handle the correctness check. + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr1, nicID1) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1) + } + + e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr2, nicID1) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1) + } + + e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr1, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr1, nicID2) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2) + } + + e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr2, nicID2) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2) + } + + // We should have the auto-generated addresses added. + nicinfo := s.NICInfo() + nic1Addrs := nicinfo[nicID1].ProtocolAddresses + nic2Addrs := nicinfo[nicID2].ProtocolAddresses + if !contains(nic1Addrs, e1Addr1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) + } + if !contains(nic1Addrs, e1Addr2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) + } + if !contains(nic2Addrs, e2Addr1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) + } + if !contains(nic2Addrs, e2Addr2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) + } + + // We can't proceed any further if we already failed the test (missing + // some discovery/auto-generated address events or addresses). + if t.Failed() { + t.FailNow() + } + + s.SetForwarding(true) + + // Collect invalidation events after becoming a router + gotRouterEvents := make(map[ndpRouterEvent]int) + for i := 0; i < maxEvents; i++ { + ok, e := expectRouterEvent() + if !ok { + t.Errorf("expected %d router events after becoming a router; got = %d", maxEvents, i) + break + } + gotRouterEvents[e]++ + } + gotPrefixEvents := make(map[ndpPrefixEvent]int) + for i := 0; i < maxEvents; i++ { + ok, e := expectPrefixEvent() + if !ok { + t.Errorf("expected %d prefix events after becoming a router; got = %d", maxEvents, i) + break + } + gotPrefixEvents[e]++ + } + gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int) + for i := 0; i < maxEvents; i++ { + ok, e := expectAutoGenAddrEvent() + if !ok { + t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", maxEvents, i) + break + } + gotAutoGenAddrEvents[e]++ + } + + // No need to proceed any further if we already failed the test (missing + // some invalidation events). + if t.Failed() { + t.FailNow() + } + + expectedRouterEvents := map[ndpRouterEvent]int{ + {nicID: nicID1, addr: llAddr1, discovered: false}: 1, + {nicID: nicID1, addr: llAddr2, discovered: false}: 1, + {nicID: nicID2, addr: llAddr1, discovered: false}: 1, + {nicID: nicID2, addr: llAddr2, discovered: false}: 1, + } + if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" { + t.Errorf("router events mismatch (-want +got):\n%s", diff) + } + expectedPrefixEvents := map[ndpPrefixEvent]int{ + {nicID: nicID1, prefix: subnet1, discovered: false}: 1, + {nicID: nicID1, prefix: subnet2, discovered: false}: 1, + {nicID: nicID2, prefix: subnet1, discovered: false}: 1, + {nicID: nicID2, prefix: subnet2, discovered: false}: 1, + } + if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" { + t.Errorf("prefix events mismatch (-want +got):\n%s", diff) + } + expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{ + {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1, + {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1, + {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1, + {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1, + } + if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" { + t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff) + } + + // Make sure the auto-generated addresses got removed. + nicinfo = s.NICInfo() + nic1Addrs = nicinfo[nicID1].ProtocolAddresses + nic2Addrs = nicinfo[nicID2].ProtocolAddresses + if contains(nic1Addrs, e1Addr1) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) + } + if contains(nic1Addrs, e1Addr2) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) + } + if contains(nic2Addrs, e2Addr1) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) + } + if contains(nic2Addrs, e2Addr2) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) + } + + // Should not get any more events (invalidation timers should have been + // cancelled when we transitioned into a router). + time.Sleep(lifetimeSeconds*time.Second + defaultTimeout) + select { + case <-ndpDisp.routerC: + t.Error("unexpected router event") + default: + } + select { + case <-ndpDisp.prefixC: + t.Error("unexpected prefix event") + default: + } + select { + case <-ndpDisp.autoGenAddrC: + t.Error("unexpected auto-generated address event") + default: + } +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index e8401c673..3810c6602 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -27,11 +27,11 @@ import ( // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { - stack *Stack - id tcpip.NICID - name string - linkEP LinkEndpoint - loopback bool + stack *Stack + id tcpip.NICID + name string + linkEP LinkEndpoint + context NICContext mu sync.RWMutex spoofing bool @@ -85,7 +85,7 @@ const ( ) // newNIC returns a new NIC using the default NDP configurations from stack. -func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC { +func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC { // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For // example, make sure that the link address it provides is a valid // unicast ethernet address. @@ -99,7 +99,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback id: id, name: name, linkEP: ep, - loopback: loopback, + context: ctx, primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint), endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), mcastJoins: make(map[NetworkEndpointID]int32), @@ -174,23 +174,28 @@ func (n *NIC) enable() *tcpip.Error { return err } - if !n.stack.autoGenIPv6LinkLocal { + // Do not auto-generate an IPv6 link-local address for loopback devices. + if !n.stack.autoGenIPv6LinkLocal || n.isLoopback() { return nil } - l2addr := n.linkEP.LinkAddress() + var addr tcpip.Address + if oIID := n.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { + addr = header.LinkLocalAddrWithOpaqueIID(oIID.NICNameFromID(n.ID(), n.name), 0, oIID.SecretKey) + } else { + l2addr := n.linkEP.LinkAddress() - // Only attempt to generate the link-local address if we have a - // valid MAC address. - // - // TODO(b/141011931): Validate a LinkEndpoint's link address - // (provided by LinkEndpoint.LinkAddress) before reaching this - // point. - if !header.IsValidUnicastEthernetAddress(l2addr) { - return nil - } + // Only attempt to generate the link-local address if we have a valid MAC + // address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by + // LinkEndpoint.LinkAddress) before reaching this point. + if !header.IsValidUnicastEthernetAddress(l2addr) { + return nil + } - addr := header.LinkLocalAddr(l2addr) + addr = header.LinkLocalAddr(l2addr) + } _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ Protocol: header.IPv6ProtocolNumber, @@ -203,6 +208,18 @@ func (n *NIC) enable() *tcpip.Error { return err } +// becomeIPv6Router transitions n into an IPv6 router. +// +// When transitioning into an IPv6 router, host-only state (NDP discovered +// routers, discovered on-link prefixes, and auto-generated addresses) will +// be cleaned up/invalidated. +func (n *NIC) becomeIPv6Router() { + n.mu.Lock() + defer n.mu.Unlock() + + n.ndp.cleanupHostOnlyState() +} + // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it // to start delivering packets. func (n *NIC) attachLinkEndpoint() { @@ -223,6 +240,10 @@ func (n *NIC) isPromiscuousMode() bool { return rv } +func (n *NIC) isLoopback() bool { + return n.linkEP.Capabilities()&CapabilityLoopback != 0 +} + // setSpoofing enables or disables address spoofing. func (n *NIC) setSpoofing(enable bool) { n.mu.Lock() @@ -232,17 +253,47 @@ func (n *NIC) setSpoofing(enable bool) { // primaryEndpoint returns the primary endpoint of n for the given network // protocol. +// +// primaryEndpoint will return the first non-deprecated endpoint if such an +// endpoint exists. If no non-deprecated endpoint exists, the first deprecated +// endpoint will be returned. func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint { n.mu.RLock() defer n.mu.RUnlock() + var deprecatedEndpoint *referencedNetworkEndpoint for _, r := range n.primary[protocol] { - if r.isValidForOutgoing() && r.tryIncRef() { - return r + if !r.isValidForOutgoing() { + continue + } + + if !r.deprecated { + if r.tryIncRef() { + // r is not deprecated, so return it immediately. + // + // If we kept track of a deprecated endpoint, decrement its reference + // count since it was incremented when we decided to keep track of it. + if deprecatedEndpoint != nil { + deprecatedEndpoint.decRefLocked() + deprecatedEndpoint = nil + } + + return r + } + } else if deprecatedEndpoint == nil && r.tryIncRef() { + // We prefer an endpoint that is not deprecated, but we keep track of r in + // case n doesn't have any non-deprecated endpoints. + // + // If we end up finding a more preferred endpoint, r's reference count + // will be decremented when such an endpoint is found. + deprecatedEndpoint = r } } - return nil + // n doesn't have any valid non-deprecated endpoints, so return + // deprecatedEndpoint (which may be nil if n doesn't have any valid deprecated + // endpoints either). + return deprecatedEndpoint } // hasPermanentAddrLocked returns true if n has a permanent (including currently @@ -350,7 +401,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t Address: address, PrefixLen: netProto.DefaultPrefixLen(), }, - }, peb, temporary, static) + }, peb, temporary, static, false) n.mu.Unlock() return ref @@ -399,10 +450,10 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p } } - return n.addAddressLocked(protocolAddress, peb, permanent, static) + return n.addAddressLocked(protocolAddress, peb, permanent, static, false) } -func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType) (*referencedNetworkEndpoint, *tcpip.Error) { +func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) { // TODO(b/141022673): Validate IP address before adding them. // Sanity check. @@ -438,6 +489,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar protocol: protocolAddress.Protocol, kind: kind, configType: configType, + deprecated: deprecated, } // Set up cache if link address resolution exists for this protocol. @@ -536,6 +588,51 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { return addrs } +// primaryAddress returns the primary address associated with this NIC. +// +// primaryAddress will return the first non-deprecated address if such an +// address exists. If no non-deprecated address exists, the first deprecated +// address will be returned. +func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix { + n.mu.RLock() + defer n.mu.RUnlock() + + list, ok := n.primary[proto] + if !ok { + return tcpip.AddressWithPrefix{} + } + + var deprecatedEndpoint *referencedNetworkEndpoint + for _, ref := range list { + // Don't include tentative, expired or tempory endpoints to avoid confusion + // and prevent the caller from using those. + switch ref.getKind() { + case permanentTentative, permanentExpired, temporary: + continue + } + + if !ref.deprecated { + return tcpip.AddressWithPrefix{ + Address: ref.ep.ID().LocalAddress, + PrefixLen: ref.ep.PrefixLen(), + } + } + + if deprecatedEndpoint == nil { + deprecatedEndpoint = ref + } + } + + if deprecatedEndpoint != nil { + return tcpip.AddressWithPrefix{ + Address: deprecatedEndpoint.ep.ID().LocalAddress, + PrefixLen: deprecatedEndpoint.ep.PrefixLen(), + } + } + + return tcpip.AddressWithPrefix{} +} + // AddAddressRange adds a range of addresses to n, so that it starts accepting // packets targeted at the given addresses and network protocol. The range is // given by a subnet address, and all addresses contained in the subnet are @@ -1092,6 +1189,11 @@ type referencedNetworkEndpoint struct { // configType is the method that was used to configure this endpoint. // This must never change after the endpoint is added to a NIC. configType networkEndpointConfigType + + // deprecated indicates whether or not the endpoint should be considered + // deprecated. That is, when deprecated is true, other endpoints that are not + // deprecated should be preferred. + deprecated bool } func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 61fd46d66..2b8751d49 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -234,15 +234,15 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have // already been set. - WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and // protocol. pkts must not be zero length. - WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams, loop PacketLooping) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. - WriteHeaderIncludedPacket(r *Route, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error + WriteHeaderIncludedPacket(r *Route, pkt tcpip.PacketBuffer) *tcpip.Error // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 34307ae07..517f4b941 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -158,7 +158,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.Pack return tcpip.ErrInvalidEndpointState } - err := r.ref.ep.WritePacket(r, gso, params, r.Loop, pkt) + err := r.ref.ep.WritePacket(r, gso, params, pkt) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { @@ -174,7 +174,7 @@ func (r *Route) WritePackets(gso *GSO, pkts []tcpip.PacketBuffer, params Network return 0, tcpip.ErrInvalidEndpointState } - n, err := r.ref.ep.WritePackets(r, gso, pkts, params, r.Loop) + n, err := r.ref.ep.WritePackets(r, gso, pkts, params) if err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(pkts) - n)) } @@ -195,7 +195,7 @@ func (r *Route) WriteHeaderIncludedPacket(pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - if err := r.ref.ep.WriteHeaderIncludedPacket(r, r.Loop, pkt); err != nil { + if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 0e88643a4..41bf9fd9b 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -352,6 +352,38 @@ func (u *uniqueIDGenerator) UniqueID() uint64 { return atomic.AddUint64((*uint64)(u), 1) } +// NICNameFromID is a function that returns a stable name for the specified NIC, +// even if different NIC IDs are used to refer to the same NIC in different +// program runs. It is used when generating opaque interface identifiers (IIDs). +// If the NIC was created with a name, it will be passed to NICNameFromID. +// +// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are +// generated for the same prefix on differnt NICs. +type NICNameFromID func(tcpip.NICID, string) string + +// OpaqueInterfaceIdentifierOptions holds the options related to the generation +// of opaque interface indentifiers (IIDs) as defined by RFC 7217. +type OpaqueInterfaceIdentifierOptions struct { + // NICNameFromID is a function that returns a stable name for a specified NIC, + // even if the NIC ID changes over time. + // + // Must be specified to generate the opaque IID. + NICNameFromID NICNameFromID + + // SecretKey is a pseudo-random number used as the secret key when generating + // opaque IIDs as defined by RFC 7217. The key SHOULD be at least + // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness + // requirements for security as outlined by RFC 4086. SecretKey MUST NOT + // change between program runs, unless explicitly changed. + // + // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey + // MUST NOT be modified after Stack is created. + // + // May be nil, but a nil value is highly discouraged to maintain + // some level of randomness between nodes. + SecretKey []byte +} + // Stack is a networking stack, with all supported protocols, NICs, and route // table. type Stack struct { @@ -412,8 +444,8 @@ type Stack struct { ndpConfigs NDPConfigurations // autoGenIPv6LinkLocal determines whether or not the stack will attempt - // to auto-generate an IPv6 link-local address for newly enabled NICs. - // See the AutoGenIPv6LinkLocal field of Options for more details. + // to auto-generate an IPv6 link-local address for newly enabled non-loopback + // NICs. See the AutoGenIPv6LinkLocal field of Options for more details. autoGenIPv6LinkLocal bool // ndpDisp is the NDP event dispatcher that is used to send the netstack @@ -422,6 +454,10 @@ type Stack struct { // uniqueIDGenerator is a generator of unique identifiers. uniqueIDGenerator UniqueID + + // opaqueIIDOpts hold the options for generating opaque interface identifiers + // (IIDs) as outlined by RFC 7217. + opaqueIIDOpts OpaqueInterfaceIdentifierOptions } // UniqueID is an abstract generator of unique identifiers. @@ -460,13 +496,15 @@ type Options struct { // before assigning an address to a NIC. NDPConfigs NDPConfigurations - // AutoGenIPv6LinkLocal determins whether or not the stack will attempt - // to auto-generate an IPv6 link-local address for newly enabled NICs. + // AutoGenIPv6LinkLocal determines whether or not the stack will attempt to + // auto-generate an IPv6 link-local address for newly enabled non-loopback + // NICs. + // // Note, setting this to true does not mean that a link-local address - // will be assigned right away, or at all. If Duplicate Address - // Detection is enabled, an address will only be assigned if it - // successfully resolves. If it fails, no further attempt will be made - // to auto-generate an IPv6 link-local address. + // will be assigned right away, or at all. If Duplicate Address Detection + // is enabled, an address will only be assigned if it successfully resolves. + // If it fails, no further attempt will be made to auto-generate an IPv6 + // link-local address. // // The generated link-local address will follow RFC 4291 Appendix A // guidelines. @@ -479,6 +517,10 @@ type Options struct { // RawFactory produces raw endpoints. Raw endpoints are enabled only if // this is non-nil. RawFactory RawFactory + + // OpaqueIIDOpts hold the options for generating opaque interface identifiers + // (IIDs) as outlined by RFC 7217. + OpaqueIIDOpts OpaqueInterfaceIdentifierOptions } // TransportEndpointInfo holds useful information about a transport endpoint @@ -549,6 +591,7 @@ func New(opts Options) *Stack { autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, uniqueIDGenerator: opts.UniqueID, ndpDisp: opts.NDPDisp, + opaqueIIDOpts: opts.OpaqueIIDOpts, } // Add specified network protocols. @@ -662,11 +705,31 @@ func (s *Stack) Stats() tcpip.Stats { } // SetForwarding enables or disables the packet forwarding between NICs. +// +// When forwarding becomes enabled, any host-only state on all NICs will be +// cleaned up. func (s *Stack) SetForwarding(enable bool) { // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. s.mu.Lock() + defer s.mu.Unlock() + + // If forwarding status didn't change, do nothing further. + if s.forwarding == enable { + return + } + s.forwarding = enable - s.mu.Unlock() + + // If this stack does not support IPv6, do nothing further. + if _, ok := s.networkProtocols[header.IPv6ProtocolNumber]; !ok { + return + } + + if enable { + for _, nic := range s.nics { + nic.becomeIPv6Router() + } + } } // Forwarding returns if the packet forwarding between NICs is enabled. @@ -733,9 +796,30 @@ func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNum return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue) } -// createNIC creates a NIC with the provided id and link-layer endpoint, and -// optionally enable it. -func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error { +// NICContext is an opaque pointer used to store client-supplied NIC metadata. +type NICContext interface{} + +// NICOptions specifies the configuration of a NIC as it is being created. +// The zero value creates an enabled, unnamed NIC. +type NICOptions struct { + // Name specifies the name of the NIC. + Name string + + // Disabled specifies whether to avoid calling Attach on the passed + // LinkEndpoint. + Disabled bool + + // Context specifies user-defined data that will be returned in stack.NICInfo + // for the NIC. Clients of this library can use it to add metadata that + // should be tracked alongside a NIC, to avoid having to keep a + // map[tcpip.NICID]metadata mirroring stack.Stack's nic map. + Context NICContext +} + +// CreateNICWithOptions creates a NIC with the provided id, LinkEndpoint, and +// NICOptions. See the documentation on type NICOptions for details on how +// NICs can be configured. +func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -744,44 +828,20 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, return tcpip.ErrDuplicateNICID } - n := newNIC(s, id, name, ep, loopback) + n := newNIC(s, id, opts.Name, ep, opts.Context) s.nics[id] = n - if enabled { + if !opts.Disabled { return n.enable() } return nil } -// CreateNIC creates a NIC with the provided id and link-layer endpoint. +// CreateNIC creates a NIC with the provided id and LinkEndpoint and calls +// `LinkEndpoint.Attach` to start delivering packets to it. func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, "", ep, true, false) -} - -// CreateNamedNIC creates a NIC with the provided id and link-layer endpoint, -// and a human-readable name. -func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, name, ep, true, false) -} - -// CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer -// endpoint, and a human-readable name. -func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, name, ep, true, true) -} - -// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint, -// but leave it disable. Stack.EnableNIC must be called before the link-layer -// endpoint starts delivering packets to it. -func (s *Stack) CreateDisabledNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, "", ep, false, false) -} - -// CreateDisabledNamedNIC is a combination of CreateNamedNIC and -// CreateDisabledNIC. -func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, name, ep, false, false) + return s.CreateNICWithOptions(id, ep, NICOptions{}) } // EnableNIC enables the given NIC so that the link-layer endpoint can start @@ -835,6 +895,18 @@ type NICInfo struct { MTU uint32 Stats NICStats + + // Context is user-supplied data optionally supplied in CreateNICWithOptions. + // See type NICOptions for more details. + Context NICContext +} + +// HasNIC returns true if the NICID is defined in the stack. +func (s *Stack) HasNIC(id tcpip.NICID) bool { + s.mu.RLock() + _, ok := s.nics[id] + s.mu.RUnlock() + return ok } // NICInfo returns a map of NICIDs to their associated information. @@ -848,7 +920,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { Up: true, // Netstack interfaces are always up. Running: nic.linkEP.IsAttached(), Promiscuous: nic.isPromiscuousMode(), - Loopback: nic.linkEP.Capabilities()&CapabilityLoopback != 0, + Loopback: nic.isLoopback(), } nics[id] = NICInfo{ Name: nic.name, @@ -857,6 +929,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { Flags: flags, MTU: nic.linkEP.MTU(), Stats: nic.stats, + Context: nic.context, } } return nics @@ -973,9 +1046,11 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress { return nics } -// GetMainNICAddress returns the first primary address and prefix for the given -// NIC and protocol. Returns an error if the NIC doesn't exist and an empty -// value if the NIC doesn't have a primary address for the given protocol. +// GetMainNICAddress returns the first non-deprecated primary address and prefix +// for the given NIC and protocol. If no non-deprecated primary address exists, +// a deprecated primary address and prefix will be returned. Returns an error if +// the NIC doesn't exist and an empty value if the NIC doesn't have a primary +// address for the given protocol. func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() @@ -985,12 +1060,7 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID } - for _, a := range nic.PrimaryAddresses() { - if a.Protocol == protocol { - return a.AddressWithPrefix, nil - } - } - return tcpip.AddressWithPrefix{}, nil + return nic.primaryAddress(protocol), nil } func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { @@ -1012,7 +1082,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if id != 0 && !needRoute { if nic, ok := s.nics[id]; ok { if ref := s.getRefEP(nic, localAddr, netProto); ref != nil { - return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.loopback, multicastLoop && !nic.loopback), nil + return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil } } } else { @@ -1028,7 +1098,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n remoteAddr = ref.ep.ID().LocalAddress } - r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.loopback, multicastLoop && !nic.loopback) + r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()) if needRoute { r.NextHop = route.Gateway } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 8fc034ca1..44e5229cc 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -27,10 +27,12 @@ import ( "time" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -122,7 +124,7 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return f.ep.Capabilities() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ @@ -133,7 +135,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params b[1] = f.id.LocalAddress[0] b[2] = byte(params.Protocol) - if loop&stack.PacketLoop != 0 { + if r.Loop&stack.PacketLoop != 0 { views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) views[0] = pkt.Header.View() views = append(views, pkt.Data.Views()...) @@ -141,7 +143,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return nil } @@ -149,11 +151,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { +func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -1894,55 +1896,67 @@ func TestNICForwarding(t *testing.T) { } // TestNICAutoGenAddr tests the auto-generation of IPv6 link-local addresses -// (or lack there-of if disabled (default)). Note, DAD will be disabled in -// these tests. +// using the modified EUI-64 of the NIC's MAC address (or lack there-of if +// disabled (default)). Note, DAD will be disabled in these tests. func TestNICAutoGenAddr(t *testing.T) { tests := []struct { name string autoGen bool linkAddr tcpip.LinkAddress + iidOpts stack.OpaqueInterfaceIdentifierOptions shouldGen bool }{ { "Disabled", false, linkAddr1, + stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(nicID tcpip.NICID, _ string) string { + return fmt.Sprintf("nic%d", nicID) + }, + }, false, }, { "Enabled", true, linkAddr1, + stack.OpaqueInterfaceIdentifierOptions{}, true, }, { "Nil MAC", true, tcpip.LinkAddress([]byte(nil)), + stack.OpaqueInterfaceIdentifierOptions{}, false, }, { "Empty MAC", true, tcpip.LinkAddress(""), + stack.OpaqueInterfaceIdentifierOptions{}, false, }, { "Invalid MAC", true, tcpip.LinkAddress("\x01\x02\x03"), + stack.OpaqueInterfaceIdentifierOptions{}, false, }, { "Multicast MAC", true, tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + stack.OpaqueInterfaceIdentifierOptions{}, false, }, { "Unspecified MAC", true, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"), + stack.OpaqueInterfaceIdentifierOptions{}, false, }, } @@ -1951,13 +1965,12 @@ func TestNICAutoGenAddr(t *testing.T) { t.Run(test.name, func(t *testing.T) { opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + OpaqueIIDOpts: test.iidOpts, } if test.autoGen { - // Only set opts.AutoGenIPv6LinkLocal when - // test.autoGen is true because - // opts.AutoGenIPv6LinkLocal should be false by - // default. + // Only set opts.AutoGenIPv6LinkLocal when test.autoGen is true because + // opts.AutoGenIPv6LinkLocal should be false by default. opts.AutoGenIPv6LinkLocal = true } @@ -1973,9 +1986,171 @@ func TestNICAutoGenAddr(t *testing.T) { } if test.shouldGen { + // Should have auto-generated an address and resolved immediately (DAD + // is disabled). + if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + } + } else { + // Should not have auto-generated an address. + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + } + }) + } +} + +// TestNICContextPreservation tests that you can read out via stack.NICInfo the +// Context data you pass via NICContext.Context in stack.CreateNICWithOptions. +func TestNICContextPreservation(t *testing.T) { + var ctx *int + tests := []struct { + name string + opts stack.NICOptions + want stack.NICContext + }{ + { + "context_set", + stack.NICOptions{Context: ctx}, + ctx, + }, + { + "context_not_set", + stack.NICOptions{}, + nil, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{}) + id := tcpip.NICID(1) + ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil { + t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err) + } + nicinfos := s.NICInfo() + nicinfo, ok := nicinfos[id] + if !ok { + t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos) + } + if got, want := nicinfo.Context == test.want, true; got != want { + t.Fatal("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want) + } + }) + } +} + +// TestNICAutoGenAddrWithOpaque tests the auto-generation of IPv6 link-local +// addresses with opaque interface identifiers. Link Local addresses should +// always be generated with opaque IIDs if configured to use them, even if the +// NIC has an invalid MAC address. +func TestNICAutoGenAddrWithOpaque(t *testing.T) { + const nicID = 1 + + var secretKey [header.OpaqueIIDSecretKeyMinBytes]byte + n, err := rand.Read(secretKey[:]) + if err != nil { + t.Fatalf("rand.Read(_): %s", err) + } + if n != header.OpaqueIIDSecretKeyMinBytes { + t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", header.OpaqueIIDSecretKeyMinBytes, n) + } + + tests := []struct { + name string + nicName string + autoGen bool + linkAddr tcpip.LinkAddress + secretKey []byte + }{ + { + name: "Disabled", + nicName: "nic1", + autoGen: false, + linkAddr: linkAddr1, + secretKey: secretKey[:], + }, + { + name: "Enabled", + nicName: "nic1", + autoGen: true, + linkAddr: linkAddr1, + secretKey: secretKey[:], + }, + // These are all cases where we would not have generated a + // link-local address if opaque IIDs were disabled. + { + name: "Nil MAC and empty nicName", + nicName: "", + autoGen: true, + linkAddr: tcpip.LinkAddress([]byte(nil)), + secretKey: secretKey[:1], + }, + { + name: "Empty MAC and empty nicName", + autoGen: true, + linkAddr: tcpip.LinkAddress(""), + secretKey: secretKey[:2], + }, + { + name: "Invalid MAC", + nicName: "test", + autoGen: true, + linkAddr: tcpip.LinkAddress("\x01\x02\x03"), + secretKey: secretKey[:3], + }, + { + name: "Multicast MAC", + nicName: "test2", + autoGen: true, + linkAddr: tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + secretKey: secretKey[:4], + }, + { + name: "Unspecified MAC and nil SecretKey", + nicName: "test3", + autoGen: true, + linkAddr: tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: test.secretKey, + }, + } + + if test.autoGen { + // Only set opts.AutoGenIPv6LinkLocal when + // test.autoGen is true because + // opts.AutoGenIPv6LinkLocal should be false by + // default. + opts.AutoGenIPv6LinkLocal = true + } + + e := channel.New(10, 1280, test.linkAddr) + s := stack.New(opts) + nicOpts := stack.NICOptions{Name: test.nicName} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) + } + + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) + } + + if test.autoGen { // Should have auto-generated an address and // resolved immediately (DAD is disabled). - if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { + if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddrWithOpaqueIID(test.nicName, 0, test.secretKey), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) } } else { @@ -1988,6 +2163,56 @@ func TestNICAutoGenAddr(t *testing.T) { } } +// TestNoLinkLocalAutoGenForLoopbackNIC tests that IPv6 link-local addresses are +// not auto-generated for loopback NICs. +func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { + const nicID = 1 + const nicName = "nicName" + + tests := []struct { + name string + opaqueIIDOpts stack.OpaqueInterfaceIdentifierOptions + }{ + { + name: "IID From MAC", + opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{}, + }, + { + name: "Opaque IID", + opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: true, + OpaqueIIDOpts: test.opaqueIIDOpts, + } + + e := loopback.New() + s := stack.New(opts) + nicOpts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Errorf("got stack.GetMainNICAddress(%d, _) = %s, want = %s", nicID, addr, want) + } + }) + } +} + // TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6 // link-local addresses will only be assigned after the DAD process resolves. func TestNICAutoGenAddrDoesDAD(t *testing.T) { diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 3b28b06d0..5e9237de9 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -41,7 +41,7 @@ const ( type testContext struct { t *testing.T - linkEPs map[string]*channel.Endpoint + linkEps map[tcpip.NICID]*channel.Endpoint s *stack.Stack ep tcpip.Endpoint @@ -61,35 +61,29 @@ func (c *testContext) createV6Endpoint(v6only bool) { c.t.Fatalf("NewEndpoint failed: %v", err) } - var v tcpip.V6OnlyOption - if v6only { - v = 1 - } - if err := c.ep.SetSockOpt(v); err != nil { + if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil { c.t.Fatalf("SetSockOpt failed: %v", err) } } -// newDualTestContextMultiNic creates the testing context and also linkEpNames -// named NICs. -func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext { +// newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs. +func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - linkEPs := make(map[string]*channel.Endpoint) - for i, linkEpName := range linkEpNames { - channelEP := channel.New(256, mtu, "") - nicID := tcpip.NICID(i + 1) - if err := s.CreateNamedNIC(nicID, linkEpName, channelEP); err != nil { + linkEps := make(map[tcpip.NICID]*channel.Endpoint) + for _, linkEpID := range linkEpIDs { + channelEp := channel.New(256, mtu, "") + if err := s.CreateNIC(linkEpID, channelEp); err != nil { t.Fatalf("CreateNIC failed: %v", err) } - linkEPs[linkEpName] = channelEP + linkEps[linkEpID] = channelEp - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { + if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil { t.Fatalf("AddAddress IPv4 failed: %v", err) } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil { + if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil { t.Fatalf("AddAddress IPv6 failed: %v", err) } } @@ -108,7 +102,7 @@ func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) return &testContext{ t: t, s: s, - linkEPs: linkEPs, + linkEps: linkEps, } } @@ -125,7 +119,7 @@ func newPayload() []byte { return b } -func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) { +func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -156,7 +150,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEPs[linkEpName].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -186,7 +180,7 @@ func TestTransportDemuxerRegister(t *testing.T) { func TestDistribution(t *testing.T) { type endpointSockopts struct { reuse int - bindToDevice string + bindToDevice tcpip.NICID } for _, test := range []struct { name string @@ -194,71 +188,71 @@ func TestDistribution(t *testing.T) { endpoints []endpointSockopts // wantedDistribution is the wanted ratio of packets received on each // endpoint for each NIC on which packets are injected. - wantedDistributions map[string][]float64 + wantedDistributions map[tcpip.NICID][]float64 }{ { "BindPortReuse", // 5 endpoints that all have reuse set. []endpointSockopts{ - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, + {1, 0}, + {1, 0}, + {1, 0}, + {1, 0}, + {1, 0}, }, - map[string][]float64{ + map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed evenly. - "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2}, + 1: {0.2, 0.2, 0.2, 0.2, 0.2}, }, }, { "BindToDevice", // 3 endpoints with various bindings. []endpointSockopts{ - endpointSockopts{0, "dev0"}, - endpointSockopts{0, "dev1"}, - endpointSockopts{0, "dev2"}, + {0, 1}, + {0, 2}, + {0, 3}, }, - map[string][]float64{ + map[tcpip.NICID][]float64{ // Injected packets on dev0 go only to the endpoint bound to dev0. - "dev0": []float64{1, 0, 0}, + 1: {1, 0, 0}, // Injected packets on dev1 go only to the endpoint bound to dev1. - "dev1": []float64{0, 1, 0}, + 2: {0, 1, 0}, // Injected packets on dev2 go only to the endpoint bound to dev2. - "dev2": []float64{0, 0, 1}, + 3: {0, 0, 1}, }, }, { "ReuseAndBindToDevice", // 6 endpoints with various bindings. []endpointSockopts{ - endpointSockopts{1, "dev0"}, - endpointSockopts{1, "dev0"}, - endpointSockopts{1, "dev1"}, - endpointSockopts{1, "dev1"}, - endpointSockopts{1, "dev1"}, - endpointSockopts{1, ""}, + {1, 1}, + {1, 1}, + {1, 2}, + {1, 2}, + {1, 2}, + {1, 0}, }, - map[string][]float64{ + map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed among endpoints bound to // dev0. - "dev0": []float64{0.5, 0.5, 0, 0, 0, 0}, + 1: {0.5, 0.5, 0, 0, 0, 0}, // Injected packets on dev1 get distributed among endpoints bound to // dev1 or unbound. - "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, + 2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, // Injected packets on dev999 go only to the unbound. - "dev999": []float64{0, 0, 0, 0, 0, 1}, + 1000: {0, 0, 0, 0, 0, 1}, }, }, } { t.Run(test.name, func(t *testing.T) { for device, wantedDistribution := range test.wantedDistributions { - t.Run(device, func(t *testing.T) { - var devices []string + t.Run(string(device), func(t *testing.T) { + var devices []tcpip.NICID for d := range test.wantedDistributions { devices = append(devices, d) } - c := newDualTestContextMultiNic(t, defaultMTU, devices) + c := newDualTestContextMultiNIC(t, defaultMTU, devices) defer c.cleanup() c.createV6Endpoint(false) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 748ce4ea5..f50604a8a 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -102,13 +102,23 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// SetSockOptBool sets a socket option. Currently not supported. +func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + // SetSockOptInt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error { +func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + return false, tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index f62fd729f..72b5ce179 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -423,17 +423,25 @@ type Endpoint interface { // SetSockOpt sets a socket option. opt should be one of the *Option types. SetSockOpt(opt interface{}) *Error + // SetSockOptBool sets a socket option, for simple cases where a value + // has the bool type. + SetSockOptBool(opt SockOptBool, v bool) *Error + // SetSockOptInt sets a socket option, for simple cases where a value // has the int type. - SetSockOptInt(opt SockOpt, v int) *Error + SetSockOptInt(opt SockOptInt, v int) *Error // GetSockOpt gets a socket option. opt should be a pointer to one of the // *Option types. GetSockOpt(opt interface{}) *Error + // GetSockOptBool gets a socket option for simple cases where a return + // value has the bool type. + GetSockOptBool(SockOptBool) (bool, *Error) + // GetSockOptInt gets a socket option for simple cases where a return // value has the int type. - GetSockOptInt(SockOpt) (int, *Error) + GetSockOptInt(SockOptInt) (int, *Error) // State returns a socket's lifecycle state. The returned value is // protocol-specific and is primarily used for diagnostics. @@ -488,13 +496,22 @@ type WriteOptions struct { Atomic bool } -// SockOpt represents socket options which values have the int type. -type SockOpt int +// SockOptBool represents socket options which values have the bool type. +type SockOptBool int + +const ( + // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 + // socket is to be restricted to sending and receiving IPv6 packets only. + V6OnlyOption SockOptBool = iota +) + +// SockOptInt represents socket options which values have the int type. +type SockOptInt int const ( // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the // number of unread bytes in the input buffer should be returned. - ReceiveQueueSizeOption SockOpt = iota + ReceiveQueueSizeOption SockOptInt = iota // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to // specify the send buffer size option. @@ -521,10 +538,6 @@ const ( // the endpoint should be cleared and returned. type ErrorOption struct{} -// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6 -// socket is to be restricted to sending and receiving IPv6 packets only. -type V6OnlyOption int - // CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be // held until segments are full by the TCP transport protocol. type CorkOption int @@ -539,7 +552,7 @@ type ReusePortOption int // BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets // should bind only on a specific NIC. -type BindToDeviceOption string +type BindToDeviceOption NICID // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. type QuickAckOption int diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go new file mode 100644 index 000000000..f5f01f32f --- /dev/null +++ b/pkg/tcpip/timer.go @@ -0,0 +1,161 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import ( + "sync" + "time" +) + +// cancellableTimerInstance is a specific instance of CancellableTimer. +// +// Different instances are created each time CancellableTimer is Reset so each +// timer has its own earlyReturn signal. This is to address a bug when a +// CancellableTimer is stopped and reset in quick succession resulting in a +// timer instance's earlyReturn signal being affected or seen by another timer +// instance. +// +// Consider the following sceneario where timer instances share a common +// earlyReturn signal (T1 creates, stops and resets a Cancellable timer under a +// lock L; T2, T3, T4 and T5 are goroutines that handle the first (A), second +// (B), third (C), and fourth (D) instance of the timer firing, respectively): +// T1: Obtain L +// T1: Create a new CancellableTimer w/ lock L (create instance A) +// T2: instance A fires, blocked trying to obtain L. +// T1: Attempt to stop instance A (set earlyReturn = true) +// T1: Reset timer (create instance B) +// T3: instance B fires, blocked trying to obtain L. +// T1: Attempt to stop instance B (set earlyReturn = true) +// T1: Reset timer (create instance C) +// T4: instance C fires, blocked trying to obtain L. +// T1: Attempt to stop instance C (set earlyReturn = true) +// T1: Reset timer (create instance D) +// T5: instance D fires, blocked trying to obtain L. +// T1: Release L +// +// Now that T1 has released L, any of the 4 timer instances can take L and check +// earlyReturn. If the timers simply check earlyReturn and then do nothing +// further, then instance D will never early return even though it was not +// requested to stop. If the timers reset earlyReturn before early returning, +// then all but one of the timers will do work when only one was expected to. +// If CancellableTimer resets earlyReturn when resetting, then all the timers +// will fire (again, when only one was expected to). +// +// To address the above concerns the simplest solution was to give each timer +// its own earlyReturn signal. +type cancellableTimerInstance struct { + timer *time.Timer + + // Used to inform the timer to early return when it gets stopped while the + // lock the timer tries to obtain when fired is held (T1 is a goroutine that + // tries to cancel the timer and T2 is the goroutine that handles the timer + // firing): + // T1: Obtain the lock, then call StopLocked() + // T2: timer fires, and gets blocked on obtaining the lock + // T1: Releases lock + // T2: Obtains lock does unintended work + // + // To resolve this, T1 will check to see if the timer already fired, and + // inform the timer using earlyReturn to return early so that once T2 obtains + // the lock, it will see that it is set to true and do nothing further. + earlyReturn *bool +} + +// stop stops the timer instance t from firing if it hasn't fired already. If it +// has fired and is blocked at obtaining the lock, earlyReturn will be set to +// true so that it will early return when it obtains the lock. +func (t *cancellableTimerInstance) stop() { + if t.timer != nil { + t.timer.Stop() + *t.earlyReturn = true + } +} + +// CancellableTimer is a timer that does some work and can be safely cancelled +// when it fires at the same time some "related work" is being done. +// +// The term "related work" is defined as some work that needs to be done while +// holding some lock that the timer must also hold while doing some work. +type CancellableTimer struct { + // The active instance of a cancellable timer. + instance cancellableTimerInstance + + // locker is the lock taken by the timer immediately after it fires and must + // be held when attempting to stop the timer. + // + // Must never change after being assigned. + locker sync.Locker + + // fn is the function that will be called when a timer fires and has not been + // signaled to early return. + // + // fn MUST NOT attempt to lock locker. + // + // Must never change after being assigned. + fn func() +} + +// StopLocked prevents the Timer from firing if it has not fired already. +// +// If the timer is blocked on obtaining the t.locker lock when StopLocked is +// called, it will early return instead of calling t.fn. +// +// Note, t will be modified. +// +// t.locker MUST be locked. +func (t *CancellableTimer) StopLocked() { + t.instance.stop() + + // Nothing to do with the stopped instance anymore. + t.instance = cancellableTimerInstance{} +} + +// Reset changes the timer to expire after duration d. +// +// Note, t will be modified. +// +// Reset should only be called on stopped or expired timers. To be safe, callers +// should always call StopLocked before calling Reset. +func (t *CancellableTimer) Reset(d time.Duration) { + // Create a new instance. + earlyReturn := false + t.instance = cancellableTimerInstance{ + timer: time.AfterFunc(d, func() { + t.locker.Lock() + defer t.locker.Unlock() + + if earlyReturn { + // If we reach this point, it means that the timer fired while another + // goroutine called StopLocked while it had the lock. Simply return + // here and do nothing further. + earlyReturn = false + return + } + + t.fn() + }), + earlyReturn: &earlyReturn, + } +} + +// MakeCancellableTimer returns an unscheduled CancellableTimer with the given +// locker and fn. +// +// fn MUST NOT attempt to lock locker. +// +// Callers must call Reset to schedule the timer to fire. +func MakeCancellableTimer(locker sync.Locker, fn func()) CancellableTimer { + return CancellableTimer{locker: locker, fn: fn} +} diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go new file mode 100644 index 000000000..1f735d735 --- /dev/null +++ b/pkg/tcpip/timer_test.go @@ -0,0 +1,236 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package timer_test + +import ( + "sync" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + shortDuration = 1 * time.Nanosecond + middleDuration = 100 * time.Millisecond + longDuration = 1 * time.Second +) + +func TestCancellableTimerFire(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + var lock sync.Mutex + + timer := tcpip.MakeCancellableTimer(&lock, func() { + ch <- struct{}{} + }) + timer.Reset(shortDuration) + + // Wait for timer to fire. + select { + case <-ch: + case <-time.After(middleDuration): + t.Fatal("timed out waiting for timer to fire") + } + + // The timer should have fired only once. + select { + case <-ch: + t.Fatal("no other timers should have fired") + case <-time.After(middleDuration): + } +} + +func TestCancellableTimerResetFromLongDuration(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + var lock sync.Mutex + + timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer.Reset(middleDuration) + + lock.Lock() + timer.StopLocked() + lock.Unlock() + + timer.Reset(shortDuration) + + // Wait for timer to fire. + select { + case <-ch: + case <-time.After(middleDuration): + t.Fatal("timed out waiting for timer to fire") + } + + // The timer should have fired only once. + select { + case <-ch: + t.Fatal("no other timers should have fired") + case <-time.After(middleDuration): + } +} + +func TestCancellableTimerResetFromShortDuration(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + var lock sync.Mutex + + lock.Lock() + timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer.Reset(shortDuration) + timer.StopLocked() + lock.Unlock() + + // Wait for timer to fire if it wasn't correctly stopped. + select { + case <-ch: + t.Fatal("timer fired after being stopped") + case <-time.After(middleDuration): + } + + timer.Reset(shortDuration) + + // Wait for timer to fire. + select { + case <-ch: + case <-time.After(middleDuration): + t.Fatal("timed out waiting for timer to fire") + } + + // The timer should have fired only once. + select { + case <-ch: + t.Fatal("no other timers should have fired") + case <-time.After(middleDuration): + } +} + +func TestCancellableTimerImmediatelyStop(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + var lock sync.Mutex + + for i := 0; i < 1000; i++ { + lock.Lock() + timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer.Reset(shortDuration) + timer.StopLocked() + lock.Unlock() + } + + // Wait for timer to fire if it wasn't correctly stopped. + select { + case <-ch: + t.Fatal("timer fired after being stopped") + case <-time.After(middleDuration): + } +} + +func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + var lock sync.Mutex + + lock.Lock() + timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer.Reset(shortDuration) + timer.StopLocked() + lock.Unlock() + + for i := 0; i < 10; i++ { + timer.Reset(middleDuration) + + lock.Lock() + // Sleep until the timer fires and gets blocked trying to take the lock. + time.Sleep(middleDuration * 2) + timer.StopLocked() + lock.Unlock() + } + + // Wait for double the duration so timers that weren't correctly stopped can + // fire. + select { + case <-ch: + t.Fatal("timer fired after being stopped") + case <-time.After(middleDuration * 2): + } +} + +func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + var lock sync.Mutex + + lock.Lock() + timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer.Reset(shortDuration) + for i := 0; i < 10; i++ { + // Sleep until the timer fires and gets blocked trying to take the lock. + time.Sleep(middleDuration) + timer.StopLocked() + timer.Reset(shortDuration) + } + lock.Unlock() + + // Wait for double the duration for the last timer to fire. + select { + case <-ch: + case <-time.After(middleDuration): + t.Fatal("timed out waiting for timer to fire") + } + + // The timer should have fired only once. + select { + case <-ch: + t.Fatal("no other timers should have fired") + case <-time.After(middleDuration): + } +} + +func TestManyCancellableTimerResetUnderLock(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + var lock sync.Mutex + + lock.Lock() + timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer.Reset(shortDuration) + for i := 0; i < 10; i++ { + timer.StopLocked() + timer.Reset(shortDuration) + } + lock.Unlock() + + // Wait for double the duration for the last timer to fire. + select { + case <-ch: + case <-time.After(middleDuration): + t.Fatal("timed out waiting for timer to fire") + } + + // The timer should have fired only once. + select { + case <-ch: + t.Fatal("no other timers should have fired") + case <-time.After(middleDuration): + } +} diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 9c40931b5..c7ce74cdd 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -350,13 +350,23 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +// SetSockOptBool sets a socket option. Currently not supported. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + return nil +} + // SetSockOptInt sets a socket option. Currently not supported. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + return false, tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 0010b5e5f..07ffa8aba 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -247,17 +247,17 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // used with SetSockOpt, and this function always returns // tcpip.ErrNotSupported. func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return tcpip.ErrNotSupported + return tcpip.ErrUnknownProtocolOption } -// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. +func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { - return 0, tcpip.ErrNotSupported +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. @@ -265,6 +265,16 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrNotSupported } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + return false, tcpip.ErrNotSupported +} + +// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { + return 0, tcpip.ErrNotSupported +} + // HandlePacket implements stack.PacketEndpoint.HandlePacket. func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { ep.rcvMu.Lock() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 5aafe2615..85f7eb76b 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -509,13 +509,38 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + return nil + + case *tcpip.KeepaliveEnabledOption: + *o = 0 + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + return false, tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 @@ -544,21 +569,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption } -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { - case tcpip.ErrorOption: - return nil - - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } -} - // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. func (e *endpoint) HandlePacket(route *stack.Route, pkt tcpip.PacketBuffer) { e.rcvMu.Lock() diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index dfaa4a559..4f361b226 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -391,9 +391,8 @@ func testV4Accept(t *testing.T, c *context.Context) { // Make sure we get the same error when calling the original ep and the // new one. This validates that v4-mapped endpoints are still able to // query the V6Only flag, whereas pure v4 endpoints are not. - var v tcpip.V6OnlyOption - expected := c.EP.GetSockOpt(&v) - if err := nep.GetSockOpt(&v); err != expected { + _, expected := c.EP.GetSockOptBool(tcpip.V6OnlyOption) + if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != expected { t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected) } @@ -531,8 +530,7 @@ func TestV6AcceptOnV6(t *testing.T) { // Make sure we can still query the v6 only status of the new endpoint, // that is, that it is in fact a v6 socket. - var v tcpip.V6OnlyOption - if err := nep.GetSockOpt(&v); err != nil { + if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil { t.Fatalf("GetSockOpt failed failed: %v", err) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index dd8b47cbe..920b24975 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1145,8 +1145,31 @@ func (e *endpoint) zeroReceiveWindow(scale uint8) bool { return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0 } +// SetSockOptBool sets a socket option. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return tcpip.ErrInvalidEndpointState + } + + e.mu.Lock() + defer e.mu.Unlock() + + // We only allow this to be set when we're in the initial state. + if e.state != StateInitial { + return tcpip.ErrInvalidEndpointState + } + + e.v6only = v + } + + return nil +} + // SetSockOptInt sets a socket option. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max @@ -1256,19 +1279,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil case tcpip.BindToDeviceOption: - e.mu.Lock() - defer e.mu.Unlock() - if v == "" { - e.bindToDevice = 0 - return nil - } - for nicID, nic := range e.stack.NICInfo() { - if nic.Name == string(v) { - e.bindToDevice = nicID - return nil - } + id := tcpip.NICID(v) + if id != 0 && !e.stack.HasNIC(id) { + return tcpip.ErrUnknownDevice } - return tcpip.ErrUnknownDevice + e.mu.Lock() + e.bindToDevice = id + e.mu.Unlock() + return nil case tcpip.QuickAckOption: if v == 0 { @@ -1289,23 +1307,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.notifyProtocolGoroutine(notifyMSSChanged) return nil - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrInvalidEndpointState - } - - e.mu.Lock() - defer e.mu.Unlock() - - // We only allow this to be set when we're in the initial state. - if e.state != StateInitial { - return tcpip.ErrInvalidEndpointState - } - - e.v6only = v != 0 - return nil - case tcpip.TTLOption: e.mu.Lock() e.ttl = uint8(v) @@ -1446,8 +1447,27 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { return e.rcvBufUsed, nil } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + switch opt { + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return false, tcpip.ErrUnknownProtocolOption + } + + e.mu.Lock() + v := e.v6only + e.mu.Unlock() + + return v, nil + } + + return false, tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() @@ -1525,12 +1545,8 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case *tcpip.BindToDeviceOption: e.mu.RLock() - defer e.mu.RUnlock() - if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { - *o = tcpip.BindToDeviceOption(nic.Name) - return nil - } - *o = "" + *o = tcpip.BindToDeviceOption(e.bindToDevice) + e.mu.RUnlock() return nil case *tcpip.QuickAckOption: @@ -1540,22 +1556,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil - case *tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrUnknownProtocolOption - } - - e.mu.Lock() - v := e.v6only - e.mu.Unlock() - - *o = 0 - if v { - *o = 1 - } - return nil - case *tcpip.TTLOption: e.mu.Lock() *o = tcpip.TTLOption(e.ttl) @@ -1974,6 +1974,15 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { return nil } + if e.state == StateInitial { + // The listen is called on an unbound socket, the socket is + // automatically bound to a random free port with the local + // address set to INADDR_ANY. + if err := e.bindLocked(tcpip.FullAddress{}); err != nil { + return err + } + } + // Endpoint must be bound before it can transition to listen mode. if e.state != StateBound { e.stats.ReadErrors.InvalidEndpointState.Increment() @@ -2033,6 +2042,10 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() + return e.bindLocked(addr) +} + +func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // Don't allow binding once endpoint is not in the initial state // anymore. This is because once the endpoint goes into a connected or // listen state, it is already bound. diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index e8fe4dab5..1aa0733d0 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1083,12 +1083,12 @@ func TestTrafficClassV6(t *testing.T) { func TestConnectBindToDevice(t *testing.T) { for _, test := range []struct { name string - device string + device tcpip.NICID want tcp.EndpointState }{ - {"RightDevice", "nic1", tcp.StateEstablished}, - {"WrongDevice", "nic2", tcp.StateSynSent}, - {"AnyDevice", "", tcp.StateEstablished}, + {"RightDevice", 1, tcp.StateEstablished}, + {"WrongDevice", 2, tcp.StateSynSent}, + {"AnyDevice", 0, tcp.StateEstablished}, } { t.Run(test.name, func(t *testing.T) { c := context.New(t, defaultMTU) @@ -3794,46 +3794,41 @@ func TestBindToDeviceOption(t *testing.T) { } defer ep.Close() - if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { - t.Errorf("CreateNamedNIC failed: %v", err) - } - - // Make an nameless NIC. - if err := s.CreateNIC(54321, loopback.New()); err != nil { + if err := s.CreateNIC(321, loopback.New()); err != nil { t.Errorf("CreateNIC failed: %v", err) } - // strPtr is used instead of taking the address of string literals, which is + // nicIDPtr is used instead of taking the address of NICID literals, which is // a compiler error. - strPtr := func(s string) *string { + nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { return &s } testActions := []struct { name string - setBindToDevice *string + setBindToDevice *tcpip.NICID setBindToDeviceError *tcpip.Error getBindToDevice tcpip.BindToDeviceOption }{ - {"GetDefaultValue", nil, nil, ""}, - {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, - {"BindToExistent", strPtr("my_device"), nil, "my_device"}, - {"UnbindToDevice", strPtr(""), nil, ""}, + {"GetDefaultValue", nil, nil, 0}, + {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, + {"BindToExistent", nicIDPtr(321), nil, 321}, + {"UnbindToDevice", nicIDPtr(0), nil, 0}, } for _, testAction := range testActions { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { - t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr) } } - bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") - if ep.GetSockOpt(&bindToDevice) != nil { - t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + bindToDevice := tcpip.BindToDeviceOption(88888) + if err := ep.GetSockOpt(&bindToDevice); err != nil { + t.Errorf("GetSockOpt got %v, want %v", err, nil) } if got, want := bindToDevice, testAction.getBindToDevice; got != want { - t.Errorf("bindToDevice got %q, want %q", got, want) + t.Errorf("bindToDevice got %d, want %d", got, want) } }) } @@ -4027,12 +4022,12 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { switch network { case "ipv4": case "ipv6": - if err := ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(1)) failed: %v", err) + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOpt(V6OnlyOption(true)) failed: %v", err) } case "dual": - if err := ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(0)) failed: %v", err) + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil { + t.Fatalf("SetSockOpt(V6OnlyOption(false)) failed: %v", err) } default: t.Fatalf("unknown network: '%s'", network) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index b0a376eba..822907998 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -158,15 +158,17 @@ func New(t *testing.T, mtu uint32) *Context { if testing.Verbose() { wep = sniffer.New(ep) } - if err := s.CreateNamedNIC(1, "nic1", wep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + opts := stack.NICOptions{Name: "nic1"} + if err := s.CreateNICWithOptions(1, wep, opts); err != nil { + t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) } wep2 := stack.LinkEndpoint(channel.New(1000, mtu, "")) if testing.Verbose() { wep2 = sniffer.New(channel.New(1000, mtu, "")) } - if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + opts2 := stack.NICOptions{Name: "nic2"} + if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil { + t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err) } if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil { @@ -473,11 +475,7 @@ func (c *Context) CreateV6Endpoint(v6only bool) { c.t.Fatalf("NewEndpoint failed: %v", err) } - var v tcpip.V6OnlyOption - if v6only { - v = 1 - } - if err := c.EP.SetSockOpt(v); err != nil { + if err := c.EP.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil { c.t.Fatalf("SetSockOpt failed failed: %v", err) } } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 1ac4705af..864dc8733 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -456,14 +456,9 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } -// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { - return nil -} - -// SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { +// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -478,8 +473,20 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - e.v6only = v != 0 + e.v6only = v + } + + return nil +} +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + return nil +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch v := opt.(type) { case tcpip.TTLOption: e.mu.Lock() e.ttl = uint8(v) @@ -624,19 +631,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() case tcpip.BindToDeviceOption: - e.mu.Lock() - defer e.mu.Unlock() - if v == "" { - e.bindToDevice = 0 - return nil - } - for nicID, nic := range e.stack.NICInfo() { - if nic.Name == string(v) { - e.bindToDevice = nicID - return nil - } + id := tcpip.NICID(v) + if id != 0 && !e.stack.HasNIC(id) { + return tcpip.ErrUnknownDevice } - return tcpip.ErrUnknownDevice + e.mu.Lock() + e.bindToDevice = id + e.mu.Unlock() + return nil case tcpip.BroadcastOption: e.mu.Lock() @@ -660,8 +662,27 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + switch opt { + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return false, tcpip.ErrUnknownProtocolOption + } + + e.mu.Lock() + v := e.v6only + e.mu.Unlock() + + return v, nil + } + + return false, tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 @@ -695,22 +716,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case tcpip.ErrorOption: return nil - case *tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrUnknownProtocolOption - } - - e.mu.Lock() - v := e.v6only - e.mu.Unlock() - - *o = 0 - if v { - *o = 1 - } - return nil - case *tcpip.TTLOption: e.mu.Lock() *o = tcpip.TTLOption(e.ttl) @@ -757,12 +762,8 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case *tcpip.BindToDeviceOption: e.mu.RLock() - defer e.mu.RUnlock() - if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { - *o = tcpip.BindToDeviceOption(nic.Name) - return nil - } - *o = tcpip.BindToDeviceOption("") + *o = tcpip.BindToDeviceOption(e.bindToDevice) + e.mu.RUnlock() return nil case *tcpip.KeepaliveEnabledOption: diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 7051a7a9c..0a82bc4fa 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -335,7 +335,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { c.createEndpoint(flow.sockProto()) if flow.isV6Only() { - if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { + if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { c.t.Fatalf("SetSockOpt failed: %v", err) } } else if flow.isBroadcast() { @@ -508,46 +508,42 @@ func TestBindToDeviceOption(t *testing.T) { } defer ep.Close() - if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { - t.Errorf("CreateNamedNIC failed: %v", err) + opts := stack.NICOptions{Name: "my_device"} + if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil { + t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) } - // Make an nameless NIC. - if err := s.CreateNIC(54321, loopback.New()); err != nil { - t.Errorf("CreateNIC failed: %v", err) - } - - // strPtr is used instead of taking the address of string literals, which is + // nicIDPtr is used instead of taking the address of NICID literals, which is // a compiler error. - strPtr := func(s string) *string { + nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { return &s } testActions := []struct { name string - setBindToDevice *string + setBindToDevice *tcpip.NICID setBindToDeviceError *tcpip.Error getBindToDevice tcpip.BindToDeviceOption }{ - {"GetDefaultValue", nil, nil, ""}, - {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, - {"BindToExistent", strPtr("my_device"), nil, "my_device"}, - {"UnbindToDevice", strPtr(""), nil, ""}, + {"GetDefaultValue", nil, nil, 0}, + {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, + {"BindToExistent", nicIDPtr(321), nil, 321}, + {"UnbindToDevice", nicIDPtr(0), nil, 0}, } for _, testAction := range testActions { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { - t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr) } } - bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") - if ep.GetSockOpt(&bindToDevice) != nil { - t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + bindToDevice := tcpip.BindToDeviceOption(88888) + if err := ep.GetSockOpt(&bindToDevice); err != nil { + t.Errorf("GetSockOpt got %v, want %v", err, nil) } if got, want := bindToDevice, testAction.getBindToDevice; got != want { - t.Errorf("bindToDevice got %q, want %q", got, want) + t.Errorf("bindToDevice got %d, want %d", got, want) } }) } |