diff options
Diffstat (limited to 'pkg')
61 files changed, 2170 insertions, 309 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index a4bb62013..05ca5342f 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -29,6 +29,7 @@ go_library( "file_amd64.go", "file_arm64.go", "fs.go", + "fuse.go", "futex.go", "inotify.go", "ioctl.go", diff --git a/pkg/abi/linux/fuse.go b/pkg/abi/linux/fuse.go new file mode 100644 index 000000000..d3ebbccc4 --- /dev/null +++ b/pkg/abi/linux/fuse.go @@ -0,0 +1,143 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +// +marshal +type FUSEOpcode uint32 + +// +marshal +type FUSEOpID uint64 + +// Opcodes for FUSE operations. Analogous to the opcodes in include/linux/fuse.h. +const ( + FUSE_LOOKUP FUSEOpcode = 1 + FUSE_FORGET = 2 /* no reply */ + FUSE_GETATTR = 3 + FUSE_SETATTR = 4 + FUSE_READLINK = 5 + FUSE_SYMLINK = 6 + _ + FUSE_MKNOD = 8 + FUSE_MKDIR = 9 + FUSE_UNLINK = 10 + FUSE_RMDIR = 11 + FUSE_RENAME = 12 + FUSE_LINK = 13 + FUSE_OPEN = 14 + FUSE_READ = 15 + FUSE_WRITE = 16 + FUSE_STATFS = 17 + FUSE_RELEASE = 18 + _ + FUSE_FSYNC = 20 + FUSE_SETXATTR = 21 + FUSE_GETXATTR = 22 + FUSE_LISTXATTR = 23 + FUSE_REMOVEXATTR = 24 + FUSE_FLUSH = 25 + FUSE_INIT = 26 + FUSE_OPENDIR = 27 + FUSE_READDIR = 28 + FUSE_RELEASEDIR = 29 + FUSE_FSYNCDIR = 30 + FUSE_GETLK = 31 + FUSE_SETLK = 32 + FUSE_SETLKW = 33 + FUSE_ACCESS = 34 + FUSE_CREATE = 35 + FUSE_INTERRUPT = 36 + FUSE_BMAP = 37 + FUSE_DESTROY = 38 + FUSE_IOCTL = 39 + FUSE_POLL = 40 + FUSE_NOTIFY_REPLY = 41 + FUSE_BATCH_FORGET = 42 +) + +const ( + // FUSE_MIN_READ_BUFFER is the minimum size the read can be for any FUSE filesystem. + // This is the minimum size Linux supports. See linux.fuse.h. + FUSE_MIN_READ_BUFFER uint32 = 8192 +) + +// FUSEHeaderIn is the header read by the daemon with each request. +// +// +marshal +type FUSEHeaderIn struct { + // Len specifies the total length of the data, including this header. + Len uint32 + + // Opcode specifies the kind of operation of the request. + Opcode FUSEOpcode + + // Unique specifies the unique identifier for this request. + Unique FUSEOpID + + // NodeID is the ID of the filesystem object being operated on. + NodeID uint64 + + // UID is the UID of the requesting process. + UID uint32 + + // GID is the GID of the requesting process. + GID uint32 + + // PID is the PID of the requesting process. + PID uint32 + + _ uint32 +} + +// FUSEHeaderOut is the header written by the daemon when it processes +// a request and wants to send a reply (almost all operations require a +// reply; if they do not, this will be explicitly documented). +// +// +marshal +type FUSEHeaderOut struct { + // Len specifies the total length of the data, including this header. + Len uint32 + + // Error specifies the error that occurred (0 if none). + Error int32 + + // Unique specifies the unique identifier of the corresponding request. + Unique FUSEOpID +} + +// FUSEWriteIn is the header written by a daemon when it makes a +// write request to the FUSE filesystem. +// +// +marshal +type FUSEWriteIn struct { + // Fh specifies the file handle that is being written to. + Fh uint64 + + // Offset is the offset of the write. + Offset uint64 + + // Size is the size of data being written. + Size uint32 + + // WriteFlags is the flags used during the write. + WriteFlags uint32 + + // LockOwner is the ID of the lock owner. + LockOwner uint64 + + // Flags is the flags for the request. + Flags uint32 + + _ uint32 +} diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index 789369220..5fb419bcd 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -8,7 +8,6 @@ go_template_instance( out = "dirty_set_impl.go", imports = { "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", }, package = "fsutil", prefix = "Dirty", @@ -25,14 +24,14 @@ go_template_instance( name = "frame_ref_set_impl", out = "frame_ref_set_impl.go", imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", + "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", }, package = "fsutil", prefix = "FrameRef", template = "//pkg/segment:generic_set", types = { "Key": "uint64", - "Range": "platform.FileRange", + "Range": "memmap.FileRange", "Value": "uint64", "Functions": "FrameRefSetFunctions", }, @@ -43,7 +42,6 @@ go_template_instance( out = "file_range_set_impl.go", imports = { "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", }, package = "fsutil", prefix = "FileRange", @@ -86,7 +84,6 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/state", diff --git a/pkg/sentry/fs/fsutil/dirty_set.go b/pkg/sentry/fs/fsutil/dirty_set.go index c6cd45087..2c9446c1d 100644 --- a/pkg/sentry/fs/fsutil/dirty_set.go +++ b/pkg/sentry/fs/fsutil/dirty_set.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/usermem" ) @@ -159,7 +158,7 @@ func (ds *DirtySet) AllowClean(mr memmap.MappableRange) { // repeatedly until all bytes have been written. max is the true size of the // cached object; offsets beyond max will not be passed to writeAt, even if // they are marked dirty. -func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, dirty *DirtySet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { +func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, dirty *DirtySet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { var changedDirty bool defer func() { if changedDirty { @@ -194,7 +193,7 @@ func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet // successful partial write, SyncDirtyAll will call it repeatedly until all // bytes have been written. max is the true size of the cached object; offsets // beyond max will not be passed to writeAt, even if they are marked dirty. -func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { +func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { dseg := dirty.FirstSegment() for dseg.Ok() { if err := syncDirtyRange(ctx, dseg.Range(), cache, max, mem, writeAt); err != nil { @@ -210,7 +209,7 @@ func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max } // Preconditions: mr must be page-aligned. -func syncDirtyRange(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { +func syncDirtyRange(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { for cseg := cache.LowerBoundSegment(mr.Start); cseg.Ok() && cseg.Start() < mr.End; cseg = cseg.NextSegment() { wbr := cseg.Range().Intersect(mr) if max < wbr.Start { diff --git a/pkg/sentry/fs/fsutil/file_range_set.go b/pkg/sentry/fs/fsutil/file_range_set.go index 5643cdac9..bbafebf03 100644 --- a/pkg/sentry/fs/fsutil/file_range_set.go +++ b/pkg/sentry/fs/fsutil/file_range_set.go @@ -23,13 +23,12 @@ import ( "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/usermem" ) // FileRangeSet maps offsets into a memmap.Mappable to offsets into a -// platform.File. It is used to implement Mappables that store data in +// memmap.File. It is used to implement Mappables that store data in // sparsely-allocated memory. // // type FileRangeSet <generated by go_generics> @@ -65,20 +64,20 @@ func (FileRangeSetFunctions) Split(mr memmap.MappableRange, frstart uint64, spli } // FileRange returns the FileRange mapped by seg. -func (seg FileRangeIterator) FileRange() platform.FileRange { +func (seg FileRangeIterator) FileRange() memmap.FileRange { return seg.FileRangeOf(seg.Range()) } // FileRangeOf returns the FileRange mapped by mr. // // Preconditions: seg.Range().IsSupersetOf(mr). mr.Length() != 0. -func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) platform.FileRange { +func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) memmap.FileRange { frstart := seg.Value() + (mr.Start - seg.Start()) - return platform.FileRange{frstart, frstart + mr.Length()} + return memmap.FileRange{frstart, frstart + mr.Length()} } // Fill attempts to ensure that all memmap.Mappable offsets in required are -// mapped to a platform.File offset, by allocating from mf with the given +// mapped to a memmap.File offset, by allocating from mf with the given // memory usage kind and invoking readAt to store data into memory. (If readAt // returns a successful partial read, Fill will call it repeatedly until all // bytes have been read.) EOF is handled consistently with the requirements of @@ -141,7 +140,7 @@ func (frs *FileRangeSet) Fill(ctx context.Context, required, optional memmap.Map } // Drop removes segments for memmap.Mappable offsets in mr, freeing the -// corresponding platform.FileRanges. +// corresponding memmap.FileRanges. // // Preconditions: mr must be page-aligned. func (frs *FileRangeSet) Drop(mr memmap.MappableRange, mf *pgalloc.MemoryFile) { @@ -154,7 +153,7 @@ func (frs *FileRangeSet) Drop(mr memmap.MappableRange, mf *pgalloc.MemoryFile) { } // DropAll removes all segments in mr, freeing the corresponding -// platform.FileRanges. +// memmap.FileRanges. func (frs *FileRangeSet) DropAll(mf *pgalloc.MemoryFile) { for seg := frs.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { mf.DecRef(seg.FileRange()) diff --git a/pkg/sentry/fs/fsutil/frame_ref_set.go b/pkg/sentry/fs/fsutil/frame_ref_set.go index dd6f5aba6..a808894df 100644 --- a/pkg/sentry/fs/fsutil/frame_ref_set.go +++ b/pkg/sentry/fs/fsutil/frame_ref_set.go @@ -17,7 +17,7 @@ package fsutil import ( "math" - "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usage" ) @@ -39,7 +39,7 @@ func (FrameRefSetFunctions) ClearValue(val *uint64) { } // Merge implements segment.Functions.Merge. -func (FrameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform.FileRange, val2 uint64) (uint64, bool) { +func (FrameRefSetFunctions) Merge(_ memmap.FileRange, val1 uint64, _ memmap.FileRange, val2 uint64) (uint64, bool) { if val1 != val2 { return 0, false } @@ -47,13 +47,13 @@ func (FrameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform. } // Split implements segment.Functions.Split. -func (FrameRefSetFunctions) Split(_ platform.FileRange, val uint64, _ uint64) (uint64, uint64) { +func (FrameRefSetFunctions) Split(_ memmap.FileRange, val uint64, _ uint64) (uint64, uint64) { return val, val } // IncRefAndAccount adds a reference on the range fr. All newly inserted segments // are accounted as host page cache memory mappings. -func (refs *FrameRefSet) IncRefAndAccount(fr platform.FileRange) { +func (refs *FrameRefSet) IncRefAndAccount(fr memmap.FileRange) { seg, gap := refs.Find(fr.Start) for { switch { @@ -74,7 +74,7 @@ func (refs *FrameRefSet) IncRefAndAccount(fr platform.FileRange) { // DecRefAndAccount removes a reference on the range fr and untracks segments // that are removed from memory accounting. -func (refs *FrameRefSet) DecRefAndAccount(fr platform.FileRange) { +func (refs *FrameRefSet) DecRefAndAccount(fr memmap.FileRange) { seg := refs.FindSegment(fr.Start) for seg.Ok() && seg.Start() < fr.End { diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go index e82afd112..ef0113b52 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) @@ -126,7 +125,7 @@ func (f *HostFileMapper) DecRefOn(mr memmap.MappableRange) { // offsets in fr or until the next call to UnmapAll. // // Preconditions: The caller must hold a reference on all offsets in fr. -func (f *HostFileMapper) MapInternal(fr platform.FileRange, fd int, write bool) (safemem.BlockSeq, error) { +func (f *HostFileMapper) MapInternal(fr memmap.FileRange, fd int, write bool) (safemem.BlockSeq, error) { chunks := ((fr.End + chunkMask) >> chunkShift) - (fr.Start >> chunkShift) f.mapsMu.Lock() defer f.mapsMu.Unlock() @@ -146,7 +145,7 @@ func (f *HostFileMapper) MapInternal(fr platform.FileRange, fd int, write bool) } // Preconditions: f.mapsMu must be locked. -func (f *HostFileMapper) forEachMappingBlockLocked(fr platform.FileRange, fd int, write bool, fn func(safemem.Block)) error { +func (f *HostFileMapper) forEachMappingBlockLocked(fr memmap.FileRange, fd int, write bool, fn func(safemem.Block)) error { prot := syscall.PROT_READ if write { prot |= syscall.PROT_WRITE diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go index 78fec553e..c15d8a946 100644 --- a/pkg/sentry/fs/fsutil/host_mappable.go +++ b/pkg/sentry/fs/fsutil/host_mappable.go @@ -21,18 +21,17 @@ import ( "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) -// HostMappable implements memmap.Mappable and platform.File over a +// HostMappable implements memmap.Mappable and memmap.File over a // CachedFileObject. // // Lock order (compare the lock order model in mm/mm.go): // truncateMu ("fs locks") // mu ("memmap.Mappable locks not taken by Translate") -// ("platform.File locks") +// ("memmap.File locks") // backingFile ("CachedFileObject locks") // // +stateify savable @@ -124,24 +123,24 @@ func (h *HostMappable) NotifyChangeFD() error { return nil } -// MapInternal implements platform.File.MapInternal. -func (h *HostMappable) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +// MapInternal implements memmap.File.MapInternal. +func (h *HostMappable) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { return h.hostFileMapper.MapInternal(fr, h.backingFile.FD(), at.Write) } -// FD implements platform.File.FD. +// FD implements memmap.File.FD. func (h *HostMappable) FD() int { return h.backingFile.FD() } -// IncRef implements platform.File.IncRef. -func (h *HostMappable) IncRef(fr platform.FileRange) { +// IncRef implements memmap.File.IncRef. +func (h *HostMappable) IncRef(fr memmap.FileRange) { mr := memmap.MappableRange{Start: fr.Start, End: fr.End} h.hostFileMapper.IncRefOn(mr) } -// DecRef implements platform.File.DecRef. -func (h *HostMappable) DecRef(fr platform.FileRange) { +// DecRef implements memmap.File.DecRef. +func (h *HostMappable) DecRef(fr memmap.FileRange) { mr := memmap.MappableRange{Start: fr.Start, End: fr.End} h.hostFileMapper.DecRefOn(mr) } diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go index 800c8b4e1..fe8b0b6ac 100644 --- a/pkg/sentry/fs/fsutil/inode_cached.go +++ b/pkg/sentry/fs/fsutil/inode_cached.go @@ -26,7 +26,6 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" @@ -934,7 +933,7 @@ func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange // InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. func (c *CachingInodeOperations) InvalidateUnsavable(ctx context.Context) error { - // Whether we have a host fd (and consequently what platform.File is + // Whether we have a host fd (and consequently what memmap.File is // mapped) can change across save/restore, so invalidate all translations // unconditionally. c.mapsMu.Lock() @@ -999,10 +998,10 @@ func (c *CachingInodeOperations) Evict(ctx context.Context, er pgalloc.Evictable } } -// IncRef implements platform.File.IncRef. This is used when we directly map an -// underlying host fd and CachingInodeOperations is used as the platform.File +// IncRef implements memmap.File.IncRef. This is used when we directly map an +// underlying host fd and CachingInodeOperations is used as the memmap.File // during translation. -func (c *CachingInodeOperations) IncRef(fr platform.FileRange) { +func (c *CachingInodeOperations) IncRef(fr memmap.FileRange) { // Hot path. Avoid defers. c.dataMu.Lock() seg, gap := c.refs.Find(fr.Start) @@ -1024,10 +1023,10 @@ func (c *CachingInodeOperations) IncRef(fr platform.FileRange) { } } -// DecRef implements platform.File.DecRef. This is used when we directly map an -// underlying host fd and CachingInodeOperations is used as the platform.File +// DecRef implements memmap.File.DecRef. This is used when we directly map an +// underlying host fd and CachingInodeOperations is used as the memmap.File // during translation. -func (c *CachingInodeOperations) DecRef(fr platform.FileRange) { +func (c *CachingInodeOperations) DecRef(fr memmap.FileRange) { // Hot path. Avoid defers. c.dataMu.Lock() seg := c.refs.FindSegment(fr.Start) @@ -1046,15 +1045,15 @@ func (c *CachingInodeOperations) DecRef(fr platform.FileRange) { c.dataMu.Unlock() } -// MapInternal implements platform.File.MapInternal. This is used when we +// MapInternal implements memmap.File.MapInternal. This is used when we // directly map an underlying host fd and CachingInodeOperations is used as the -// platform.File during translation. -func (c *CachingInodeOperations) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +// memmap.File during translation. +func (c *CachingInodeOperations) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { return c.hostFileMapper.MapInternal(fr, c.backingFile.FD(), at.Write) } -// FD implements platform.File.FD. This is used when we directly map an -// underlying host fd and CachingInodeOperations is used as the platform.File +// FD implements memmap.File.FD. This is used when we directly map an +// underlying host fd and CachingInodeOperations is used as the memmap.File // during translation. func (c *CachingInodeOperations) FD() int { return c.backingFile.FD() diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD index 737007748..67649e811 100644 --- a/pkg/sentry/fsimpl/fuse/BUILD +++ b/pkg/sentry/fsimpl/fuse/BUILD @@ -1,12 +1,28 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") licenses(["notice"]) +go_template_instance( + name = "request_list", + out = "request_list.go", + package = "fuse", + prefix = "request", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*Request", + "Linker": "*Request", + }, +) + go_library( name = "fuse", srcs = [ + "connection.go", "dev.go", "fusefs.go", + "register.go", + "request_list.go", ], visibility = ["//pkg/sentry:internal"], deps = [ @@ -18,7 +34,30 @@ go_library( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", + "//pkg/sync", + "//pkg/syserror", + "//pkg/usermem", + "//pkg/waiter", + "//tools/go_marshal/marshal", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +go_test( + name = "dev_test", + size = "small", + srcs = ["dev_test.go"], + library = ":fuse", + deps = [ + "//pkg/abi/linux", + "//pkg/sentry/fsimpl/testutil", + "//pkg/sentry/fsimpl/tmpfs", + "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/vfs", "//pkg/syserror", "//pkg/usermem", + "//pkg/waiter", + "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/fsimpl/fuse/connection.go b/pkg/sentry/fsimpl/fuse/connection.go new file mode 100644 index 000000000..f330da0bd --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/connection.go @@ -0,0 +1,255 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "errors" + "fmt" + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" +) + +// MaxActiveRequestsDefault is the default setting controlling the upper bound +// on the number of active requests at any given time. +const MaxActiveRequestsDefault = 10000 + +var ( + // Ordinary requests have even IDs, while interrupts IDs are odd. + InitReqBit uint64 = 1 + ReqIDStep uint64 = 2 +) + +// Request represents a FUSE operation request that hasn't been sent to the +// server yet. +// +// +stateify savable +type Request struct { + requestEntry + + id linux.FUSEOpID + hdr *linux.FUSEHeaderIn + data []byte +} + +// Response represents an actual response from the server, including the +// response payload. +// +// +stateify savable +type Response struct { + opcode linux.FUSEOpcode + hdr linux.FUSEHeaderOut + data []byte +} + +// Connection is the struct by which the sentry communicates with the FUSE server daemon. +type Connection struct { + fd *DeviceFD + + // MaxWrite is the daemon's maximum size of a write buffer. + // This is negotiated during FUSE_INIT. + MaxWrite uint32 +} + +// NewFUSEConnection creates a FUSE connection to fd +func NewFUSEConnection(_ context.Context, fd *vfs.FileDescription, maxInFlightRequests uint64) (*Connection, error) { + // Mark the device as ready so it can be used. /dev/fuse can only be used if the FD was used to + // mount a FUSE filesystem. + fuseFD := fd.Impl().(*DeviceFD) + fuseFD.mounted = true + + // Create the writeBuf for the header to be stored in. + hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes()) + fuseFD.writeBuf = make([]byte, hdrLen) + fuseFD.completions = make(map[linux.FUSEOpID]*futureResponse) + fuseFD.fullQueueCh = make(chan struct{}, maxInFlightRequests) + fuseFD.writeCursor = 0 + + return &Connection{ + fd: fuseFD, + }, nil +} + +// NewRequest creates a new request that can be sent to the FUSE server. +func (conn *Connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*Request, error) { + conn.fd.mu.Lock() + defer conn.fd.mu.Unlock() + conn.fd.nextOpID += linux.FUSEOpID(ReqIDStep) + + hdrLen := (*linux.FUSEHeaderIn)(nil).SizeBytes() + hdr := linux.FUSEHeaderIn{ + Len: uint32(hdrLen + payload.SizeBytes()), + Opcode: opcode, + Unique: conn.fd.nextOpID, + NodeID: ino, + UID: uint32(creds.EffectiveKUID), + GID: uint32(creds.EffectiveKGID), + PID: pid, + } + + buf := make([]byte, hdr.Len) + hdr.MarshalUnsafe(buf[:hdrLen]) + payload.MarshalUnsafe(buf[hdrLen:]) + + return &Request{ + id: hdr.Unique, + hdr: &hdr, + data: buf, + }, nil +} + +// Call makes a request to the server and blocks the invoking task until a +// server responds with a response. +// NOTE: If no task is provided then the Call will simply enqueue the request +// and return a nil response. No blocking will happen in this case. Instead, +// this is used to signify that the processing of this request will happen by +// the kernel.Task that writes the response. See FUSE_INIT for such an +// invocation. +func (conn *Connection) Call(t *kernel.Task, r *Request) (*Response, error) { + fut, err := conn.callFuture(t, r) + if err != nil { + return nil, err + } + + return fut.resolve(t) +} + +// Error returns the error of the FUSE call. +func (r *Response) Error() error { + errno := r.hdr.Error + if errno >= 0 { + return nil + } + + sysErrNo := syscall.Errno(-errno) + return error(sysErrNo) +} + +// UnmarshalPayload unmarshals the response data into m. +func (r *Response) UnmarshalPayload(m marshal.Marshallable) error { + hdrLen := r.hdr.SizeBytes() + haveDataLen := r.hdr.Len - uint32(hdrLen) + wantDataLen := uint32(m.SizeBytes()) + + if haveDataLen < wantDataLen { + return fmt.Errorf("payload too small. Minimum data lenth required: %d, but got data length %d", wantDataLen, haveDataLen) + } + + m.UnmarshalUnsafe(r.data[hdrLen:]) + return nil +} + +// callFuture makes a request to the server and returns a future response. +// Call resolve() when the response needs to be fulfilled. +func (conn *Connection) callFuture(t *kernel.Task, r *Request) (*futureResponse, error) { + conn.fd.mu.Lock() + defer conn.fd.mu.Unlock() + + // Is the queue full? + // + // We must busy wait here until the request can be queued. We don't + // block on the fd.fullQueueCh with a lock - so after being signalled, + // before we acquire the lock, it is possible that a barging task enters + // and queues a request. As a result, upon acquiring the lock we must + // again check if the room is available. + // + // This can potentially starve a request forever but this can only happen + // if there are always too many ongoing requests all the time. The + // supported maxActiveRequests setting should be really high to avoid this. + for conn.fd.numActiveRequests == conn.fd.fs.opts.maxActiveRequests { + if t == nil { + // Since there is no task that is waiting. We must error out. + return nil, errors.New("FUSE request queue full") + } + + log.Infof("Blocking request %v from being queued. Too many active requests: %v", + r.id, conn.fd.numActiveRequests) + conn.fd.mu.Unlock() + err := t.Block(conn.fd.fullQueueCh) + conn.fd.mu.Lock() + if err != nil { + return nil, err + } + } + + return conn.callFutureLocked(t, r) +} + +// callFutureLocked makes a request to the server and returns a future response. +func (conn *Connection) callFutureLocked(t *kernel.Task, r *Request) (*futureResponse, error) { + conn.fd.queue.PushBack(r) + conn.fd.numActiveRequests += 1 + fut := newFutureResponse(r.hdr.Opcode) + conn.fd.completions[r.id] = fut + + // Signal the readers that there is something to read. + conn.fd.waitQueue.Notify(waiter.EventIn) + + return fut, nil +} + +// futureResponse represents an in-flight request, that may or may not have +// completed yet. Convert it to a resolved Response by calling Resolve, but note +// that this may block. +// +// +stateify savable +type futureResponse struct { + opcode linux.FUSEOpcode + ch chan struct{} + hdr *linux.FUSEHeaderOut + data []byte +} + +// newFutureResponse creates a future response to a FUSE request. +func newFutureResponse(opcode linux.FUSEOpcode) *futureResponse { + return &futureResponse{ + opcode: opcode, + ch: make(chan struct{}), + } +} + +// resolve blocks the task until the server responds to its corresponding request, +// then returns a resolved response. +func (f *futureResponse) resolve(t *kernel.Task) (*Response, error) { + // If there is no Task associated with this request - then we don't try to resolve + // the response. Instead, the task writing the response (proxy to the server) will + // process the response on our behalf. + if t == nil { + log.Infof("fuse.Response.resolve: Not waiting on a response from server.") + return nil, nil + } + + if err := t.Block(f.ch); err != nil { + return nil, err + } + + return f.getResponse(), nil +} + +// getResponse creates a Response from the data the futureResponse has. +func (f *futureResponse) getResponse() *Response { + return &Response{ + opcode: f.opcode, + hdr: *f.hdr, + data: f.data, + } +} diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go index c9e12a94f..f3443ac71 100644 --- a/pkg/sentry/fsimpl/fuse/dev.go +++ b/pkg/sentry/fsimpl/fuse/dev.go @@ -15,13 +15,17 @@ package fuse import ( + "syscall" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" ) const fuseDevMinor = 229 @@ -54,9 +58,43 @@ type DeviceFD struct { // mounted specifies whether a FUSE filesystem was mounted using the DeviceFD. mounted bool - // TODO(gvisor.dev/issue/2987): Add all the data structures needed to enqueue - // and deque requests, control synchronization and establish communication - // between the FUSE kernel module and the /dev/fuse character device. + // nextOpID is used to create new requests. + nextOpID linux.FUSEOpID + + // queue is the list of requests that need to be processed by the FUSE server. + queue requestList + + // numActiveRequests is the number of requests made by the Sentry that has + // yet to be responded to. + numActiveRequests uint64 + + // completions is used to map a request to its response. A Writer will use this + // to notify the caller of a completed response. + completions map[linux.FUSEOpID]*futureResponse + + writeCursor uint32 + + // writeBuf is the memory buffer used to copy in the FUSE out header from + // userspace. + writeBuf []byte + + // writeCursorFR current FR being copied from server. + writeCursorFR *futureResponse + + // mu protects all the queues, maps, buffers and cursors and nextOpID. + mu sync.Mutex + + // waitQueue is used to notify interested parties when the device becomes + // readable or writable. + waitQueue waiter.Queue + + // fullQueueCh is a channel used to synchronize the readers with the writers. + // Writers (inbound requests to the filesystem) block if there are too many + // unprocessed in-flight requests. + fullQueueCh chan struct{} + + // fs is the FUSE filesystem that this FD is being used for. + fs *filesystem } // Release implements vfs.FileDescriptionImpl.Release. @@ -79,7 +117,75 @@ func (fd *DeviceFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.R return 0, syserror.EPERM } - return 0, syserror.ENOSYS + // We require that any Read done on this filesystem have a sane minimum + // read buffer. It must have the capacity for the fixed parts of any request + // header (Linux uses the request header and the FUSEWriteIn header for this + // calculation) + the negotiated MaxWrite room for the data. + minBuffSize := linux.FUSE_MIN_READ_BUFFER + inHdrLen := uint32((*linux.FUSEHeaderIn)(nil).SizeBytes()) + writeHdrLen := uint32((*linux.FUSEWriteIn)(nil).SizeBytes()) + negotiatedMinBuffSize := inHdrLen + writeHdrLen + fd.fs.conn.MaxWrite + if minBuffSize < negotiatedMinBuffSize { + minBuffSize = negotiatedMinBuffSize + } + + // If the read buffer is too small, error out. + if dst.NumBytes() < int64(minBuffSize) { + return 0, syserror.EINVAL + } + + fd.mu.Lock() + defer fd.mu.Unlock() + return fd.readLocked(ctx, dst, opts) +} + +// readLocked implements the reading of the fuse device while locked with DeviceFD.mu. +func (fd *DeviceFD) readLocked(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + if fd.queue.Empty() { + return 0, syserror.ErrWouldBlock + } + + var readCursor uint32 + var bytesRead int64 + for { + req := fd.queue.Front() + if dst.NumBytes() < int64(req.hdr.Len) { + // The request is too large. Cannot process it. All requests must be smaller than the + // negotiated size as specified by Connection.MaxWrite set as part of the FUSE_INIT + // handshake. + errno := -int32(syscall.EIO) + if req.hdr.Opcode == linux.FUSE_SETXATTR { + errno = -int32(syscall.E2BIG) + } + + // Return the error to the calling task. + if err := fd.sendError(ctx, errno, req); err != nil { + return 0, err + } + + // We're done with this request. + fd.queue.Remove(req) + + // Restart the read as this request was invalid. + log.Warningf("fuse.DeviceFD.Read: request found was too large. Restarting read.") + return fd.readLocked(ctx, dst, opts) + } + + n, err := dst.CopyOut(ctx, req.data[readCursor:]) + if err != nil { + return 0, err + } + readCursor += uint32(n) + bytesRead += int64(n) + + if readCursor >= req.hdr.Len { + // Fully done with this req, remove it from the queue. + fd.queue.Remove(req) + break + } + } + + return bytesRead, nil } // PWrite implements vfs.FileDescriptionImpl.PWrite. @@ -94,12 +200,128 @@ func (fd *DeviceFD) PWrite(ctx context.Context, src usermem.IOSequence, offset i // Write implements vfs.FileDescriptionImpl.Write. func (fd *DeviceFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + fd.mu.Lock() + defer fd.mu.Unlock() + return fd.writeLocked(ctx, src, opts) +} + +// writeLocked implements writing to the fuse device while locked with DeviceFD.mu. +func (fd *DeviceFD) writeLocked(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. if !fd.mounted { return 0, syserror.EPERM } - return 0, syserror.ENOSYS + var cn, n int64 + hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes()) + + for src.NumBytes() > 0 { + if fd.writeCursorFR != nil { + // Already have common header, and we're now copying the payload. + wantBytes := fd.writeCursorFR.hdr.Len + + // Note that the FR data doesn't have the header. Copy it over if its necessary. + if fd.writeCursorFR.data == nil { + fd.writeCursorFR.data = make([]byte, wantBytes) + } + + bytesCopied, err := src.CopyIn(ctx, fd.writeCursorFR.data[fd.writeCursor:wantBytes]) + if err != nil { + return 0, err + } + src = src.DropFirst(bytesCopied) + + cn = int64(bytesCopied) + n += cn + fd.writeCursor += uint32(cn) + if fd.writeCursor == wantBytes { + // Done reading this full response. Clean up and unblock the + // initiator. + break + } + + // Check if we have more data in src. + continue + } + + // Assert that the header isn't read into the writeBuf yet. + if fd.writeCursor >= hdrLen { + return 0, syserror.EINVAL + } + + // We don't have the full common response header yet. + wantBytes := hdrLen - fd.writeCursor + bytesCopied, err := src.CopyIn(ctx, fd.writeBuf[fd.writeCursor:wantBytes]) + if err != nil { + return 0, err + } + src = src.DropFirst(bytesCopied) + + cn = int64(bytesCopied) + n += cn + fd.writeCursor += uint32(cn) + if fd.writeCursor == hdrLen { + // Have full header in the writeBuf. Use it to fetch the actual futureResponse + // from the device's completions map. + var hdr linux.FUSEHeaderOut + hdr.UnmarshalBytes(fd.writeBuf) + + // We have the header now and so the writeBuf has served its purpose. + // We could reset it manually here but instead of doing that, at the + // end of the write, the writeCursor will be set to 0 thereby allowing + // the next request to overwrite whats in the buffer, + + fut, ok := fd.completions[hdr.Unique] + if !ok { + // Server sent us a response for a request we never sent? + return 0, syserror.EINVAL + } + + delete(fd.completions, hdr.Unique) + + // Copy over the header into the future response. The rest of the payload + // will be copied over to the FR's data in the next iteration. + fut.hdr = &hdr + fd.writeCursorFR = fut + + // Next iteration will now try read the complete request, if src has + // any data remaining. Otherwise we're done. + } + } + + if fd.writeCursorFR != nil { + if err := fd.sendResponse(ctx, fd.writeCursorFR); err != nil { + return 0, err + } + + // Ready the device for the next request. + fd.writeCursorFR = nil + fd.writeCursor = 0 + } + + return n, nil +} + +// Readiness implements vfs.FileDescriptionImpl.Readiness. +func (fd *DeviceFD) Readiness(mask waiter.EventMask) waiter.EventMask { + var ready waiter.EventMask + ready |= waiter.EventOut // FD is always writable + if !fd.queue.Empty() { + // Have reqs available, FD is readable. + ready |= waiter.EventIn + } + + return ready & mask +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (fd *DeviceFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + fd.waitQueue.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (fd *DeviceFD) EventUnregister(e *waiter.Entry) { + fd.waitQueue.EventUnregister(e) } // Seek implements vfs.FileDescriptionImpl.Seek. @@ -112,22 +334,61 @@ func (fd *DeviceFD) Seek(ctx context.Context, offset int64, whence int32) (int64 return 0, syserror.ENOSYS } -// Register registers the FUSE device with vfsObj. -func Register(vfsObj *vfs.VirtualFilesystem) error { - if err := vfsObj.RegisterDevice(vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, fuseDevice{}, &vfs.RegisterDeviceOptions{ - GroupName: "misc", - }); err != nil { +// sendResponse sends a response to the waiting task (if any). +func (fd *DeviceFD) sendResponse(ctx context.Context, fut *futureResponse) error { + // See if the running task need to perform some action before returning. + // Since we just finished writing the future, we can be sure that + // getResponse generates a populated response. + if err := fd.noReceiverAction(ctx, fut.getResponse()); err != nil { return err } + // Signal that the queue is no longer full. + select { + case fd.fullQueueCh <- struct{}{}: + default: + } + fd.numActiveRequests -= 1 + + // Signal the task waiting on a response. + close(fut.ch) return nil } -// CreateDevtmpfsFile creates a device special file in devtmpfs. -func CreateDevtmpfsFile(ctx context.Context, dev *devtmpfs.Accessor) error { - if err := dev.CreateDeviceFile(ctx, "fuse", vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, 0666 /* mode */); err != nil { +// sendError sends an error response to the waiting task (if any). +func (fd *DeviceFD) sendError(ctx context.Context, errno int32, req *Request) error { + // Return the error to the calling task. + outHdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes()) + respHdr := linux.FUSEHeaderOut{ + Len: outHdrLen, + Error: errno, + Unique: req.hdr.Unique, + } + + fut, ok := fd.completions[respHdr.Unique] + if !ok { + // Server sent us a response for a request we never sent? + return syserror.EINVAL + } + delete(fd.completions, respHdr.Unique) + + fut.hdr = &respHdr + if err := fd.sendResponse(ctx, fut); err != nil { return err } return nil } + +// noReceiverAction has the calling kernel.Task do some action if its known that no +// receiver is going to be waiting on the future channel. This is to be used by: +// FUSE_INIT. +func (fd *DeviceFD) noReceiverAction(ctx context.Context, r *Response) error { + if r.opcode == linux.FUSE_INIT { + // TODO: process init response here. + // Maybe get the creds from the context? + // creds := auth.CredentialsFromContext(ctx) + } + + return nil +} diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go new file mode 100644 index 000000000..fcd77832a --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/dev_test.go @@ -0,0 +1,429 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "fmt" + "io" + "math/rand" + "testing" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" +) + +// echoTestOpcode is the Opcode used during testing. The server used in tests +// will simply echo the payload back with the appropriate headers. +const echoTestOpcode linux.FUSEOpcode = 1000 + +type testPayload struct { + data uint32 +} + +// TestFUSECommunication tests that the communication layer between the Sentry and the +// FUSE server daemon works as expected. +func TestFUSECommunication(t *testing.T) { + s := setup(t) + defer s.Destroy() + + k := kernel.KernelFromContext(s.Ctx) + creds := auth.CredentialsFromContext(s.Ctx) + + // Create test cases with different number of concurrent clients and servers. + testCases := []struct { + Name string + NumClients int + NumServers int + MaxActiveRequests uint64 + }{ + { + Name: "SingleClientSingleServer", + NumClients: 1, + NumServers: 1, + MaxActiveRequests: MaxActiveRequestsDefault, + }, + { + Name: "SingleClientMultipleServers", + NumClients: 1, + NumServers: 10, + MaxActiveRequests: MaxActiveRequestsDefault, + }, + { + Name: "MultipleClientsSingleServer", + NumClients: 10, + NumServers: 1, + MaxActiveRequests: MaxActiveRequestsDefault, + }, + { + Name: "MultipleClientsMultipleServers", + NumClients: 10, + NumServers: 10, + MaxActiveRequests: MaxActiveRequestsDefault, + }, + { + Name: "RequestCapacityFull", + NumClients: 10, + NumServers: 1, + MaxActiveRequests: 1, + }, + { + Name: "RequestCapacityContinuouslyFull", + NumClients: 100, + NumServers: 2, + MaxActiveRequests: 2, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.Name, func(t *testing.T) { + conn, fd, err := newTestConnection(s, k, testCase.MaxActiveRequests) + if err != nil { + t.Fatalf("newTestConnection: %v", err) + } + + clientsDone := make([]chan struct{}, testCase.NumClients) + serversDone := make([]chan struct{}, testCase.NumServers) + serversKill := make([]chan struct{}, testCase.NumServers) + + // FUSE clients. + for i := 0; i < testCase.NumClients; i++ { + clientsDone[i] = make(chan struct{}) + go func(i int) { + fuseClientRun(t, s, k, conn, creds, uint32(i), uint64(i), clientsDone[i]) + }(i) + } + + // FUSE servers. + for j := 0; j < testCase.NumServers; j++ { + serversDone[j] = make(chan struct{}) + serversKill[j] = make(chan struct{}, 1) // The kill command shouldn't block. + go func(j int) { + fuseServerRun(t, s, k, fd, serversDone[j], serversKill[j]) + }(j) + } + + // Tear down. + // + // Make sure all the clients are done. + for i := 0; i < testCase.NumClients; i++ { + <-clientsDone[i] + } + + // Kill any server that is potentially waiting. + for j := 0; j < testCase.NumServers; j++ { + serversKill[j] <- struct{}{} + } + + // Make sure all the servers are done. + for j := 0; j < testCase.NumServers; j++ { + <-serversDone[j] + } + }) + } +} + +// CallTest makes a request to the server and blocks the invoking +// goroutine until a server responds with a response. Doesn't block +// a kernel.Task. Analogous to Connection.Call but used for testing. +func CallTest(conn *Connection, t *kernel.Task, r *Request, i uint32) (*Response, error) { + conn.fd.mu.Lock() + + // Wait until we're certain that a new request can be processed. + for conn.fd.numActiveRequests == conn.fd.fs.opts.maxActiveRequests { + conn.fd.mu.Unlock() + select { + case <-conn.fd.fullQueueCh: + } + conn.fd.mu.Lock() + } + + fut, err := conn.callFutureLocked(t, r) // No task given. + conn.fd.mu.Unlock() + + if err != nil { + return nil, err + } + + // Resolve the response. + // + // Block without a task. + select { + case <-fut.ch: + } + + // A response is ready. Resolve and return it. + return fut.getResponse(), nil +} + +// ReadTest is analogous to vfs.FileDescription.Read and reads from the FUSE +// device. However, it does so by - not blocking the task that is calling - and +// instead just waits on a channel. The behaviour is essentially the same as +// DeviceFD.Read except it guarantees that the task is not blocked. +func ReadTest(serverTask *kernel.Task, fd *vfs.FileDescription, inIOseq usermem.IOSequence, killServer chan struct{}) (int64, bool, error) { + var err error + var n, total int64 + + dev := fd.Impl().(*DeviceFD) + + // Register for notifications. + w, ch := waiter.NewChannelEntry(nil) + dev.EventRegister(&w, waiter.EventIn) + for { + // Issue the request and break out if it completes with anything other than + // "would block". + n, err = dev.Read(serverTask, inIOseq, vfs.ReadOptions{}) + total += n + if err != syserror.ErrWouldBlock { + break + } + + // Wait for a notification that we should retry. + // Emulate the blocking for when no requests are available + select { + case <-ch: + case <-killServer: + // Server killed by the main program. + return 0, true, nil + } + } + + dev.EventUnregister(&w) + return total, false, err +} + +// fuseClientRun emulates all the actions of a normal FUSE request. It creates +// a header, a payload, calls the server, waits for the response, and processes +// the response. +func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *Connection, creds *auth.Credentials, pid uint32, inode uint64, clientDone chan struct{}) { + defer func() { clientDone <- struct{}{} }() + + tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) + clientTask, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("fuse-client-%v", pid), tc, s.MntNs, s.Root, s.Root) + if err != nil { + t.Fatal(err) + } + testObj := &testPayload{ + data: rand.Uint32(), + } + + req, err := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj) + if err != nil { + t.Fatalf("NewRequest creation failed: %v", err) + } + + // Queue up a request. + // Analogous to Call except it doesn't block on the task. + resp, err := CallTest(conn, clientTask, req, pid) + if err != nil { + t.Fatalf("CallTaskNonBlock failed: %v", err) + } + + if err = resp.Error(); err != nil { + t.Fatalf("Server responded with an error: %v", err) + } + + var respTestPayload testPayload + if err := resp.UnmarshalPayload(&respTestPayload); err != nil { + t.Fatalf("Unmarshalling payload error: %v", err) + } + + if resp.hdr.Unique != req.hdr.Unique { + t.Fatalf("got response for another request. Expected response for req %v but got response for req %v", + req.hdr.Unique, resp.hdr.Unique) + } + + if respTestPayload.data != testObj.data { + t.Fatalf("read incorrect data. Data expected: %v, but got %v", testObj.data, respTestPayload.data) + } + +} + +// fuseServerRun creates a task and emulates all the actions of a simple FUSE server +// that simply reads a request and echos the same struct back as a response using the +// appropriate headers. +func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.FileDescription, serverDone, killServer chan struct{}) { + defer func() { serverDone <- struct{}{} }() + + // Create the tasks that the server will be using. + tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) + var readPayload testPayload + + serverTask, err := testutil.CreateTask(s.Ctx, "fuse-server", tc, s.MntNs, s.Root, s.Root) + if err != nil { + t.Fatal(err) + } + + // Read the request. + for { + inHdrLen := uint32((*linux.FUSEHeaderIn)(nil).SizeBytes()) + payloadLen := uint32(readPayload.SizeBytes()) + + // The raed buffer must meet some certain size criteria. + buffSize := inHdrLen + payloadLen + if buffSize < linux.FUSE_MIN_READ_BUFFER { + buffSize = linux.FUSE_MIN_READ_BUFFER + } + inBuf := make([]byte, buffSize) + inIOseq := usermem.BytesIOSequence(inBuf) + + n, serverKilled, err := ReadTest(serverTask, fd, inIOseq, killServer) + if err != nil { + t.Fatalf("Read failed :%v", err) + } + + // Server should shut down. No new requests are going to be made. + if serverKilled { + break + } + + if n <= 0 { + t.Fatalf("Read read no bytes") + } + + var readFUSEHeaderIn linux.FUSEHeaderIn + readFUSEHeaderIn.UnmarshalUnsafe(inBuf[:inHdrLen]) + readPayload.UnmarshalUnsafe(inBuf[inHdrLen : inHdrLen+payloadLen]) + + if readFUSEHeaderIn.Opcode != echoTestOpcode { + t.Fatalf("read incorrect data. Header: %v, Payload: %v", readFUSEHeaderIn, readPayload) + } + + // Write the response. + outHdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes()) + outBuf := make([]byte, outHdrLen+payloadLen) + outHeader := linux.FUSEHeaderOut{ + Len: outHdrLen + payloadLen, + Error: 0, + Unique: readFUSEHeaderIn.Unique, + } + + // Echo the payload back. + outHeader.MarshalUnsafe(outBuf[:outHdrLen]) + readPayload.MarshalUnsafe(outBuf[outHdrLen:]) + outIOseq := usermem.BytesIOSequence(outBuf) + + n, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{}) + if err != nil { + t.Fatalf("Write failed :%v", err) + } + } +} + +func setup(t *testing.T) *testutil.System { + k, err := testutil.Boot() + if err != nil { + t.Fatalf("Error creating kernel: %v", err) + } + + ctx := k.SupervisorContext() + creds := auth.CredentialsFromContext(ctx) + + k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserList: true, + AllowUserMount: true, + }) + + mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", tmpfs.Name, &vfs.GetFilesystemOptions{}) + if err != nil { + t.Fatalf("NewMountNamespace(): %v", err) + } + + return testutil.NewSystem(ctx, t, k.VFS(), mntns) +} + +// newTestConnection creates a fuse connection that the sentry can communicate with +// and the FD for the server to communicate with. +func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveRequests uint64) (*Connection, *vfs.FileDescription, error) { + vfsObj := &vfs.VirtualFilesystem{} + fuseDev := &DeviceFD{} + + if err := vfsObj.Init(); err != nil { + return nil, nil, err + } + + vd := vfsObj.NewAnonVirtualDentry("genCountFD") + defer vd.DecRef() + if err := fuseDev.vfsfd.Init(fuseDev, linux.O_RDWR|linux.O_CREAT, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{}); err != nil { + return nil, nil, err + } + + fsopts := filesystemOptions{ + maxActiveRequests: maxActiveRequests, + } + fs, err := NewFUSEFilesystem(system.Ctx, 0, &fsopts, &fuseDev.vfsfd) + if err != nil { + return nil, nil, err + } + + return fs.conn, &fuseDev.vfsfd, nil +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (t *testPayload) SizeBytes() int { + return 4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (t *testPayload) MarshalBytes(dst []byte) { + usermem.ByteOrder.PutUint32(dst[:4], t.data) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (t *testPayload) UnmarshalBytes(src []byte) { + *t = testPayload{data: usermem.ByteOrder.Uint32(src[:4])} +} + +// Packed implements marshal.Marshallable.Packed. +func (t *testPayload) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (t *testPayload) MarshalUnsafe(dst []byte) { + t.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (t *testPayload) UnmarshalUnsafe(src []byte) { + t.UnmarshalBytes(src) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +func (t *testPayload) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) { + panic("not implemented") +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (t *testPayload) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) { + panic("not implemented") +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (t *testPayload) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) { + panic("not implemented") +} + +// WriteTo implements io.WriterTo.WriteTo. +func (t *testPayload) WriteTo(w io.Writer) (int64, error) { + panic("not implemented") +} diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index f7775fb9b..911b6f7cb 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -51,6 +51,11 @@ type filesystemOptions struct { // rootMode specifies the the file mode of the filesystem's root. rootMode linux.FileMode + + // maxActiveRequests specifies the maximum number of active requests that can + // exist at any time. Any further requests will block when trying to + // Call the server. + maxActiveRequests uint64 } // filesystem implements vfs.FilesystemImpl. @@ -58,12 +63,12 @@ type filesystem struct { kernfs.Filesystem devMinor uint32 - // fuseFD is the FD returned when opening /dev/fuse. It is used for communication - // between the FUSE server daemon and the sentry fusefs. - fuseFD *DeviceFD + // conn is used for communication between the FUSE server + // daemon and the sentry fusefs. + conn *Connection // opts is the options the fusefs is initialized with. - opts filesystemOptions + opts *filesystemOptions } // Name implements vfs.FilesystemType.Name. @@ -100,7 +105,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fuseFd := kernelTask.GetFileVFS2(int32(deviceDescriptor)) // Parse and set all the other supported FUSE mount options. - // TODO: Expand the supported mount options. + // TODO(gVisor.dev/issue/3229): Expand the supported mount options. if userIDStr, ok := mopts["user_id"]; ok { delete(mopts, "user_id") userID, err := strconv.ParseUint(userIDStr, 10, 32) @@ -134,21 +139,20 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } fsopts.rootMode = rootMode + // Set the maxInFlightRequests option. + fsopts.maxActiveRequests = MaxActiveRequestsDefault + // Check for unparsed options. if len(mopts) != 0 { log.Warningf("%s.GetFilesystem: unknown options: %v", fsType.Name(), mopts) return nil, nil, syserror.EINVAL } - // Mark the device as ready so it can be used. /dev/fuse can only be used if the FD was used to - // mount a FUSE filesystem. - fuseFD := fuseFd.Impl().(*DeviceFD) - fuseFD.mounted = true - - fs := &filesystem{ - devMinor: devMinor, - fuseFD: fuseFD, - opts: fsopts, + // Create a new FUSE filesystem. + fs, err := NewFUSEFilesystem(ctx, devMinor, &fsopts, fuseFd) + if err != nil { + log.Warningf("%s.NewFUSEFilesystem: failed with error: %v", fsType.Name(), err) + return nil, nil, err } fs.VFSFilesystem().Init(vfsObj, &fsType, fs) @@ -162,6 +166,26 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return fs.VFSFilesystem(), root.VFSDentry(), nil } +// NewFUSEFilesystem creates a new FUSE filesystem. +func NewFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOptions, device *vfs.FileDescription) (*filesystem, error) { + fs := &filesystem{ + devMinor: devMinor, + opts: opts, + } + + conn, err := NewFUSEConnection(ctx, device, opts.maxActiveRequests) + if err != nil { + log.Warningf("fuse.NewFUSEFilesystem: NewFUSEConnection failed with error: %v", err) + return nil, syserror.EINVAL + } + + fs.conn = conn + fuseFD := device.Impl().(*DeviceFD) + fuseFD.fs = fs + + return fs, nil +} + // Release implements vfs.FilesystemImpl.Release. func (fs *filesystem) Release() { fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) diff --git a/pkg/sentry/fsimpl/fuse/register.go b/pkg/sentry/fsimpl/fuse/register.go new file mode 100644 index 000000000..b5b581152 --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/register.go @@ -0,0 +1,42 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" + "gvisor.dev/gvisor/pkg/sentry/vfs" +) + +// Register registers the FUSE device with vfsObj. +func Register(vfsObj *vfs.VirtualFilesystem) error { + if err := vfsObj.RegisterDevice(vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, fuseDevice{}, &vfs.RegisterDeviceOptions{ + GroupName: "misc", + }); err != nil { + return err + } + + return nil +} + +// CreateDevtmpfsFile creates a device special file in devtmpfs. +func CreateDevtmpfsFile(ctx context.Context, dev *devtmpfs.Accessor) error { + if err := dev.CreateDeviceFile(ctx, "fuse", vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, 0666 /* mode */); err != nil { + return err + } + + return nil +} diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index 02317a133..09f142cfc 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -29,7 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" @@ -221,12 +220,12 @@ func (fd *regularFileFD) pwriteLocked(ctx context.Context, src usermem.IOSequenc return 0, syserror.EINVAL } mr := memmap.MappableRange{pgstart, pgend} - var freed []platform.FileRange + var freed []memmap.FileRange d.dataMu.Lock() cseg := d.cache.LowerBoundSegment(mr.Start) for cseg.Ok() && cseg.Start() < mr.End { cseg = d.cache.Isolate(cseg, mr) - freed = append(freed, platform.FileRange{cseg.Value(), cseg.Value() + cseg.Range().Length()}) + freed = append(freed, memmap.FileRange{cseg.Value(), cseg.Value() + cseg.Range().Length()}) cseg = d.cache.Remove(cseg).NextSegment() } d.dataMu.Unlock() @@ -821,7 +820,7 @@ func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange // InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. func (d *dentry) InvalidateUnsavable(ctx context.Context) error { - // Whether we have a host fd (and consequently what platform.File is + // Whether we have a host fd (and consequently what memmap.File is // mapped) can change across save/restore, so invalidate all translations // unconditionally. d.mapsMu.Lock() @@ -869,8 +868,8 @@ func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) { } } -// dentryPlatformFile implements platform.File. It exists solely because dentry -// cannot implement both vfs.DentryImpl.IncRef and platform.File.IncRef. +// dentryPlatformFile implements memmap.File. It exists solely because dentry +// cannot implement both vfs.DentryImpl.IncRef and memmap.File.IncRef. // // dentryPlatformFile is only used when a host FD representing the remote file // is available (i.e. dentry.handle.fd >= 0), and that FD is used for @@ -878,7 +877,7 @@ func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) { type dentryPlatformFile struct { *dentry - // fdRefs counts references on platform.File offsets. fdRefs is protected + // fdRefs counts references on memmap.File offsets. fdRefs is protected // by dentry.dataMu. fdRefs fsutil.FrameRefSet @@ -890,29 +889,29 @@ type dentryPlatformFile struct { hostFileMapperInitOnce sync.Once } -// IncRef implements platform.File.IncRef. -func (d *dentryPlatformFile) IncRef(fr platform.FileRange) { +// IncRef implements memmap.File.IncRef. +func (d *dentryPlatformFile) IncRef(fr memmap.FileRange) { d.dataMu.Lock() d.fdRefs.IncRefAndAccount(fr) d.dataMu.Unlock() } -// DecRef implements platform.File.DecRef. -func (d *dentryPlatformFile) DecRef(fr platform.FileRange) { +// DecRef implements memmap.File.DecRef. +func (d *dentryPlatformFile) DecRef(fr memmap.FileRange) { d.dataMu.Lock() d.fdRefs.DecRefAndAccount(fr) d.dataMu.Unlock() } -// MapInternal implements platform.File.MapInternal. -func (d *dentryPlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +// MapInternal implements memmap.File.MapInternal. +func (d *dentryPlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { d.handleMu.RLock() bs, err := d.hostFileMapper.MapInternal(fr, int(d.handle.fd), at.Write) d.handleMu.RUnlock() return bs, err } -// FD implements platform.File.FD. +// FD implements memmap.File.FD. func (d *dentryPlatformFile) FD() int { d.handleMu.RLock() fd := d.handle.fd diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD index e86fbe2d5..bd701bbc7 100644 --- a/pkg/sentry/fsimpl/host/BUILD +++ b/pkg/sentry/fsimpl/host/BUILD @@ -34,7 +34,6 @@ go_library( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", - "//pkg/sentry/platform", "//pkg/sentry/socket/control", "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/host/mmap.go index 8545a82f0..65d3af38c 100644 --- a/pkg/sentry/fsimpl/host/mmap.go +++ b/pkg/sentry/fsimpl/host/mmap.go @@ -19,13 +19,12 @@ import ( "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) -// inodePlatformFile implements platform.File. It exists solely because inode -// cannot implement both kernfs.Inode.IncRef and platform.File.IncRef. +// inodePlatformFile implements memmap.File. It exists solely because inode +// cannot implement both kernfs.Inode.IncRef and memmap.File.IncRef. // // inodePlatformFile should only be used if inode.canMap is true. type inodePlatformFile struct { @@ -34,7 +33,7 @@ type inodePlatformFile struct { // fdRefsMu protects fdRefs. fdRefsMu sync.Mutex - // fdRefs counts references on platform.File offsets. It is used solely for + // fdRefs counts references on memmap.File offsets. It is used solely for // memory accounting. fdRefs fsutil.FrameRefSet @@ -45,32 +44,32 @@ type inodePlatformFile struct { fileMapperInitOnce sync.Once } -// IncRef implements platform.File.IncRef. +// IncRef implements memmap.File.IncRef. // // Precondition: i.inode.canMap must be true. -func (i *inodePlatformFile) IncRef(fr platform.FileRange) { +func (i *inodePlatformFile) IncRef(fr memmap.FileRange) { i.fdRefsMu.Lock() i.fdRefs.IncRefAndAccount(fr) i.fdRefsMu.Unlock() } -// DecRef implements platform.File.DecRef. +// DecRef implements memmap.File.DecRef. // // Precondition: i.inode.canMap must be true. -func (i *inodePlatformFile) DecRef(fr platform.FileRange) { +func (i *inodePlatformFile) DecRef(fr memmap.FileRange) { i.fdRefsMu.Lock() i.fdRefs.DecRefAndAccount(fr) i.fdRefsMu.Unlock() } -// MapInternal implements platform.File.MapInternal. +// MapInternal implements memmap.File.MapInternal. // // Precondition: i.inode.canMap must be true. -func (i *inodePlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +func (i *inodePlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { return i.fileMapper.MapInternal(fr, i.hostFD, at.Write) } -// FD implements platform.File.FD. +// FD implements memmap.File.FD. func (i *inodePlatformFile) FD() int { return i.hostFD } diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index bfd779837..c211fc8d0 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -20,7 +20,6 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", "//pkg/sentry/usage", "//pkg/sync", "//pkg/syserror", diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index f66cfcc7f..55b4c2cdb 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -45,7 +45,6 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -370,7 +369,7 @@ type Shm struct { // fr is the offset into mfp.MemoryFile() that backs this contents of this // segment. Immutable. - fr platform.FileRange + fr memmap.FileRange // mu protects all fields below. mu sync.Mutex `state:"nosave"` diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go index 5f3908d8b..7c4fefb16 100644 --- a/pkg/sentry/kernel/timekeeper.go +++ b/pkg/sentry/kernel/timekeeper.go @@ -21,8 +21,8 @@ import ( "gvisor.dev/gvisor/pkg/log" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" sentrytime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sync" ) @@ -90,7 +90,7 @@ type Timekeeper struct { // NewTimekeeper does not take ownership of paramPage. // // SetClocks must be called on the returned Timekeeper before it is usable. -func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage platform.FileRange) (*Timekeeper, error) { +func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage memmap.FileRange) (*Timekeeper, error) { return &Timekeeper{ params: NewVDSOParamPage(mfp, paramPage), }, nil diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go index f1b3c212c..290c32466 100644 --- a/pkg/sentry/kernel/vdso.go +++ b/pkg/sentry/kernel/vdso.go @@ -19,8 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/usermem" ) @@ -58,7 +58,7 @@ type vdsoParams struct { type VDSOParamPage struct { // The parameter page is fr, allocated from mfp.MemoryFile(). mfp pgalloc.MemoryFileProvider - fr platform.FileRange + fr memmap.FileRange // seq is the current sequence count written to the page. // @@ -81,7 +81,7 @@ type VDSOParamPage struct { // * VDSOParamPage must be the only writer to fr. // // * mfp.MemoryFile().MapInternal(fr) must return a single safemem.Block. -func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr platform.FileRange) *VDSOParamPage { +func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *VDSOParamPage { return &VDSOParamPage{mfp: mfp, fr: fr} } diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD index a98b66de1..2c95669cd 100644 --- a/pkg/sentry/memmap/BUILD +++ b/pkg/sentry/memmap/BUILD @@ -28,9 +28,21 @@ go_template_instance( }, ) +go_template_instance( + name = "file_range", + out = "file_range.go", + package = "memmap", + prefix = "File", + template = "//pkg/segment:generic_range", + types = { + "T": "uint64", + }, +) + go_library( name = "memmap", srcs = [ + "file_range.go", "mappable_range.go", "mapping_set.go", "mapping_set_impl.go", @@ -40,7 +52,7 @@ go_library( deps = [ "//pkg/context", "//pkg/log", - "//pkg/sentry/platform", + "//pkg/safemem", "//pkg/syserror", "//pkg/usermem", ], diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go index c6db9fc8f..c188f6c29 100644 --- a/pkg/sentry/memmap/memmap.go +++ b/pkg/sentry/memmap/memmap.go @@ -19,12 +19,12 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/usermem" ) // Mappable represents a memory-mappable object, a mutable mapping from uint64 -// offsets to (platform.File, uint64 File offset) pairs. +// offsets to (File, uint64 File offset) pairs. // // See mm/mm.go for Mappable's place in the lock order. // @@ -74,7 +74,7 @@ type Mappable interface { // Translations are valid until invalidated by a callback to // MappingSpace.Invalidate or until the caller removes its mapping of the // translated range. Mappable implementations must ensure that at least one - // reference is held on all pages in a platform.File that may be the result + // reference is held on all pages in a File that may be the result // of a valid Translation. // // Preconditions: required.Length() > 0. optional.IsSupersetOf(required). @@ -100,7 +100,7 @@ type Translation struct { Source MappableRange // File is the mapped file. - File platform.File + File File // Offset is the offset into File at which this Translation begins. Offset uint64 @@ -110,9 +110,9 @@ type Translation struct { Perms usermem.AccessType } -// FileRange returns the platform.FileRange represented by t. -func (t Translation) FileRange() platform.FileRange { - return platform.FileRange{t.Offset, t.Offset + t.Source.Length()} +// FileRange returns the FileRange represented by t. +func (t Translation) FileRange() FileRange { + return FileRange{t.Offset, t.Offset + t.Source.Length()} } // CheckTranslateResult returns an error if (ts, terr) does not satisfy all @@ -361,3 +361,49 @@ type MMapOpts struct { // TODO(jamieliu): Replace entirely with MappingIdentity? Hint string } + +// File represents a host file that may be mapped into an platform.AddressSpace. +type File interface { + // All pages in a File are reference-counted. + + // IncRef increments the reference count on all pages in fr. + // + // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() > + // 0. At least one reference must be held on all pages in fr. (The File + // interface does not provide a way to acquire an initial reference; + // implementors may define mechanisms for doing so.) + IncRef(fr FileRange) + + // DecRef decrements the reference count on all pages in fr. + // + // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() > + // 0. At least one reference must be held on all pages in fr. + DecRef(fr FileRange) + + // MapInternal returns a mapping of the given file offsets in the invoking + // process' address space for reading and writing. + // + // Note that fr.Start and fr.End need not be page-aligned. + // + // Preconditions: fr.Length() > 0. At least one reference must be held on + // all pages in fr. + // + // Postconditions: The returned mapping is valid as long as at least one + // reference is held on the mapped pages. + MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error) + + // FD returns the file descriptor represented by the File. + // + // The only permitted operation on the returned file descriptor is to map + // pages from it consistent with the requirements of AddressSpace.MapFile. + FD() int +} + +// FileRange represents a range of uint64 offsets into a File. +// +// type FileRange <generated using go_generics> + +// String implements fmt.Stringer.String. +func (fr FileRange) String() string { + return fmt.Sprintf("[%#x, %#x)", fr.Start, fr.End) +} diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index a036ce53c..f9d0837a1 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -7,14 +7,14 @@ go_template_instance( name = "file_refcount_set", out = "file_refcount_set.go", imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", + "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", }, package = "mm", prefix = "fileRefcount", template = "//pkg/segment:generic_set", types = { "Key": "uint64", - "Range": "platform.FileRange", + "Range": "memmap.FileRange", "Value": "int32", "Functions": "fileRefcountSetFunctions", }, diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index 379148903..1999ec706 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -243,7 +242,7 @@ type aioMappable struct { refs.AtomicRefCount mfp pgalloc.MemoryFileProvider - fr platform.FileRange + fr memmap.FileRange } var aioRingBufferSize = uint64(usermem.Addr(linux.AIORingSize).MustRoundUp()) diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index 6db7c3d40..3e85964e4 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -25,7 +25,7 @@ // Locks taken by memmap.Mappable.Translate // mm.privateRefs.mu // platform.AddressSpace locks -// platform.File locks +// memmap.File locks // mm.aioManager.mu // mm.AIOContext.mu // @@ -396,7 +396,7 @@ type pma struct { // file is the file mapped by this pma. Only pmas for which file == // MemoryManager.mfp.MemoryFile() may be saved. pmas hold a reference to // the corresponding file range while they exist. - file platform.File `state:"nosave"` + file memmap.File `state:"nosave"` // off is the offset into file at which this pma begins. // @@ -436,7 +436,7 @@ type pma struct { private bool // If internalMappings is not empty, it is the cached return value of - // file.MapInternal for the platform.FileRange mapped by this pma. + // file.MapInternal for the memmap.FileRange mapped by this pma. internalMappings safemem.BlockSeq `state:"nosave"` } @@ -469,10 +469,10 @@ func (fileRefcountSetFunctions) MaxKey() uint64 { func (fileRefcountSetFunctions) ClearValue(_ *int32) { } -func (fileRefcountSetFunctions) Merge(_ platform.FileRange, rc1 int32, _ platform.FileRange, rc2 int32) (int32, bool) { +func (fileRefcountSetFunctions) Merge(_ memmap.FileRange, rc1 int32, _ memmap.FileRange, rc2 int32) (int32, bool) { return rc1, rc1 == rc2 } -func (fileRefcountSetFunctions) Split(_ platform.FileRange, rc int32, _ uint64) (int32, int32) { +func (fileRefcountSetFunctions) Split(_ memmap.FileRange, rc int32, _ uint64) (int32, int32) { return rc, rc } diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go index 62e4c20af..930ec895f 100644 --- a/pkg/sentry/mm/pma.go +++ b/pkg/sentry/mm/pma.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/safecopy" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -604,7 +603,7 @@ func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivat } } -// Pin returns the platform.File ranges currently mapped by addresses in ar in +// Pin returns the memmap.File ranges currently mapped by addresses in ar in // mm, acquiring a reference on the returned ranges which the caller must // release by calling Unpin. If not all addresses are mapped, Pin returns a // non-nil error. Note that Pin may return both a non-empty slice of @@ -674,15 +673,15 @@ type PinnedRange struct { Source usermem.AddrRange // File is the mapped file. - File platform.File + File memmap.File // Offset is the offset into File at which this PinnedRange begins. Offset uint64 } -// FileRange returns the platform.File offsets mapped by pr. -func (pr PinnedRange) FileRange() platform.FileRange { - return platform.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())} +// FileRange returns the memmap.File offsets mapped by pr. +func (pr PinnedRange) FileRange() memmap.FileRange { + return memmap.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())} } // Unpin releases the reference held by prs. @@ -857,7 +856,7 @@ func (mm *MemoryManager) vecInternalMappingsLocked(ars usermem.AddrRangeSeq) saf } // incPrivateRef acquires a reference on private pages in fr. -func (mm *MemoryManager) incPrivateRef(fr platform.FileRange) { +func (mm *MemoryManager) incPrivateRef(fr memmap.FileRange) { mm.privateRefs.mu.Lock() defer mm.privateRefs.mu.Unlock() refSet := &mm.privateRefs.refs @@ -878,8 +877,8 @@ func (mm *MemoryManager) incPrivateRef(fr platform.FileRange) { } // decPrivateRef releases a reference on private pages in fr. -func (mm *MemoryManager) decPrivateRef(fr platform.FileRange) { - var freed []platform.FileRange +func (mm *MemoryManager) decPrivateRef(fr memmap.FileRange) { + var freed []memmap.FileRange mm.privateRefs.mu.Lock() refSet := &mm.privateRefs.refs @@ -951,7 +950,7 @@ func (pmaSetFunctions) Merge(ar1 usermem.AddrRange, pma1 pma, ar2 usermem.AddrRa // Discard internal mappings instead of trying to merge them, since merging // them requires an allocation and getting them again from the - // platform.File might not. + // memmap.File might not. pma1.internalMappings = safemem.BlockSeq{} return pma1, true } @@ -1012,12 +1011,12 @@ func (pseg pmaIterator) getInternalMappingsLocked() error { return nil } -func (pseg pmaIterator) fileRange() platform.FileRange { +func (pseg pmaIterator) fileRange() memmap.FileRange { return pseg.fileRangeOf(pseg.Range()) } // Preconditions: pseg.Range().IsSupersetOf(ar). ar.Length != 0. -func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) platform.FileRange { +func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) memmap.FileRange { if checkInvariants { if !pseg.Ok() { panic("terminal pma iterator") @@ -1032,5 +1031,5 @@ func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) platform.FileRange { pma := pseg.ValuePtr() pstart := pseg.Start() - return platform.FileRange{pma.off + uint64(ar.Start-pstart), pma.off + uint64(ar.End-pstart)} + return memmap.FileRange{pma.off + uint64(ar.Start-pstart), pma.off + uint64(ar.End-pstart)} } diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go index 9ad52082d..0e142fb11 100644 --- a/pkg/sentry/mm/special_mappable.go +++ b/pkg/sentry/mm/special_mappable.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -35,7 +34,7 @@ type SpecialMappable struct { refs.AtomicRefCount mfp pgalloc.MemoryFileProvider - fr platform.FileRange + fr memmap.FileRange name string } @@ -44,7 +43,7 @@ type SpecialMappable struct { // SpecialMappable will use the given name in /proc/[pid]/maps. // // Preconditions: fr.Length() != 0. -func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr platform.FileRange) *SpecialMappable { +func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *SpecialMappable { m := SpecialMappable{mfp: mfp, fr: fr, name: name} m.EnableLeakCheck("mm.SpecialMappable") return &m @@ -126,7 +125,7 @@ func (m *SpecialMappable) MemoryFileProvider() pgalloc.MemoryFileProvider { // FileRange returns the offsets into MemoryFileProvider().MemoryFile() that // store the SpecialMappable's contents. -func (m *SpecialMappable) FileRange() platform.FileRange { +func (m *SpecialMappable) FileRange() memmap.FileRange { return m.fr } diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index e1fcb175f..7a3311a70 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -36,14 +36,14 @@ go_template_instance( "trackGaps": "1", }, imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", + "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", }, package = "pgalloc", prefix = "usage", template = "//pkg/segment:generic_set", types = { "Key": "uint64", - "Range": "platform.FileRange", + "Range": "memmap.FileRange", "Value": "usageInfo", "Functions": "usageSetFunctions", }, @@ -56,14 +56,14 @@ go_template_instance( "minDegree": "10", }, imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", + "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", }, package = "pgalloc", prefix = "reclaim", template = "//pkg/segment:generic_set", types = { "Key": "uint64", - "Range": "platform.FileRange", + "Range": "memmap.FileRange", "Value": "reclaimSetValue", "Functions": "reclaimSetFunctions", }, @@ -89,7 +89,7 @@ go_library( "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/hostmm", - "//pkg/sentry/platform", + "//pkg/sentry/memmap", "//pkg/sentry/usage", "//pkg/state", "//pkg/state/wire", diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index afab97c0a..3243d7214 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -33,14 +33,14 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/hostmm" - "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) -// MemoryFile is a platform.File whose pages may be allocated to arbitrary +// MemoryFile is a memmap.File whose pages may be allocated to arbitrary // users. type MemoryFile struct { // opts holds options passed to NewMemoryFile. opts is immutable. @@ -372,7 +372,7 @@ func (f *MemoryFile) Destroy() { // to Allocate. // // Preconditions: length must be page-aligned and non-zero. -func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.FileRange, error) { +func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.FileRange, error) { if length == 0 || length%usermem.PageSize != 0 { panic(fmt.Sprintf("invalid allocation length: %#x", length)) } @@ -390,7 +390,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi // Find a range in the underlying file. fr, ok := findAvailableRange(&f.usage, f.fileSize, length, alignment) if !ok { - return platform.FileRange{}, syserror.ENOMEM + return memmap.FileRange{}, syserror.ENOMEM } // Expand the file if needed. @@ -398,7 +398,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi // Round the new file size up to be chunk-aligned. newFileSize := (int64(fr.End) + chunkMask) &^ chunkMask if err := f.file.Truncate(newFileSize); err != nil { - return platform.FileRange{}, err + return memmap.FileRange{}, err } f.fileSize = newFileSize f.mappingsMu.Lock() @@ -416,7 +416,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi bs[i] = 0 } }); err != nil { - return platform.FileRange{}, err + return memmap.FileRange{}, err } } if !f.usage.Add(fr, usageInfo{ @@ -439,7 +439,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi // space for mappings to be allocated downwards. // // Precondition: alignment must be a power of 2. -func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint64) (platform.FileRange, bool) { +func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint64) (memmap.FileRange, bool) { alignmentMask := alignment - 1 // Search for space in existing gaps, starting at the current end of the @@ -461,7 +461,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6 break } if start := unalignedStart &^ alignmentMask; start >= gap.Start() { - return platform.FileRange{start, start + length}, true + return memmap.FileRange{start, start + length}, true } gap = gap.PrevLargeEnoughGap(length) @@ -475,7 +475,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6 min = (min + alignmentMask) &^ alignmentMask if min+length < min { // Overflow: allocation would exceed the range of uint64. - return platform.FileRange{}, false + return memmap.FileRange{}, false } // Determine the minimum file size required to fit this allocation at its end. @@ -484,7 +484,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6 if newFileSize <= fileSize { if fileSize != 0 { // Overflow: allocation would exceed the range of int64. - return platform.FileRange{}, false + return memmap.FileRange{}, false } newFileSize = chunkSize } @@ -496,7 +496,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6 continue } if start := unalignedStart &^ alignmentMask; start >= min { - return platform.FileRange{start, start + length}, true + return memmap.FileRange{start, start + length}, true } } } @@ -508,22 +508,22 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6 // by r.ReadToBlocks(), it returns that error. // // Preconditions: length > 0. length must be page-aligned. -func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r safemem.Reader) (platform.FileRange, error) { +func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r safemem.Reader) (memmap.FileRange, error) { fr, err := f.Allocate(length, kind) if err != nil { - return platform.FileRange{}, err + return memmap.FileRange{}, err } dsts, err := f.MapInternal(fr, usermem.Write) if err != nil { f.DecRef(fr) - return platform.FileRange{}, err + return memmap.FileRange{}, err } n, err := safemem.ReadFullToBlocks(r, dsts) un := uint64(usermem.Addr(n).RoundDown()) if un < length { // Free unused memory and update fr to contain only the memory that is // still allocated. - f.DecRef(platform.FileRange{fr.Start + un, fr.End}) + f.DecRef(memmap.FileRange{fr.Start + un, fr.End}) fr.End = fr.Start + un } return fr, err @@ -540,7 +540,7 @@ const ( // will read zeroes. // // Preconditions: fr.Length() > 0. -func (f *MemoryFile) Decommit(fr platform.FileRange) error { +func (f *MemoryFile) Decommit(fr memmap.FileRange) error { if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -560,7 +560,7 @@ func (f *MemoryFile) Decommit(fr platform.FileRange) error { return nil } -func (f *MemoryFile) markDecommitted(fr platform.FileRange) { +func (f *MemoryFile) markDecommitted(fr memmap.FileRange) { f.mu.Lock() defer f.mu.Unlock() // Since we're changing the knownCommitted attribute, we need to merge @@ -581,8 +581,8 @@ func (f *MemoryFile) markDecommitted(fr platform.FileRange) { f.usage.MergeRange(fr) } -// IncRef implements platform.File.IncRef. -func (f *MemoryFile) IncRef(fr platform.FileRange) { +// IncRef implements memmap.File.IncRef. +func (f *MemoryFile) IncRef(fr memmap.FileRange) { if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -600,8 +600,8 @@ func (f *MemoryFile) IncRef(fr platform.FileRange) { f.usage.MergeAdjacent(fr) } -// DecRef implements platform.File.DecRef. -func (f *MemoryFile) DecRef(fr platform.FileRange) { +// DecRef implements memmap.File.DecRef. +func (f *MemoryFile) DecRef(fr memmap.FileRange) { if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -637,8 +637,8 @@ func (f *MemoryFile) DecRef(fr platform.FileRange) { } } -// MapInternal implements platform.File.MapInternal. -func (f *MemoryFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +// MapInternal implements memmap.File.MapInternal. +func (f *MemoryFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { if !fr.WellFormed() || fr.Length() == 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -664,7 +664,7 @@ func (f *MemoryFile) MapInternal(fr platform.FileRange, at usermem.AccessType) ( // forEachMappingSlice invokes fn on a sequence of byte slices that // collectively map all bytes in fr. -func (f *MemoryFile) forEachMappingSlice(fr platform.FileRange, fn func([]byte)) error { +func (f *MemoryFile) forEachMappingSlice(fr memmap.FileRange, fn func([]byte)) error { mappings := f.mappings.Load().([]uintptr) for chunkStart := fr.Start &^ chunkMask; chunkStart < fr.End; chunkStart += chunkSize { chunk := int(chunkStart >> chunkShift) @@ -944,7 +944,7 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func( continue case !populated && populatedRun: // Finish the run by changing this segment. - runRange := platform.FileRange{ + runRange := memmap.FileRange{ Start: r.Start + uint64(populatedRunStart*usermem.PageSize), End: r.Start + uint64(i*usermem.PageSize), } @@ -1009,7 +1009,7 @@ func (f *MemoryFile) File() *os.File { return f.file } -// FD implements platform.File.FD. +// FD implements memmap.File.FD. func (f *MemoryFile) FD() int { return int(f.file.Fd()) } @@ -1090,13 +1090,13 @@ func (f *MemoryFile) runReclaim() { // // Note that there returned range will be removed from tracking. It // must be reclaimed (removed from f.usage) at this point. -func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) { +func (f *MemoryFile) findReclaimable() (memmap.FileRange, bool) { f.mu.Lock() defer f.mu.Unlock() for { for { if f.destroyed { - return platform.FileRange{}, false + return memmap.FileRange{}, false } if f.reclaimable { break @@ -1120,7 +1120,7 @@ func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) { } } -func (f *MemoryFile) markReclaimed(fr platform.FileRange) { +func (f *MemoryFile) markReclaimed(fr memmap.FileRange) { f.mu.Lock() defer f.mu.Unlock() seg := f.usage.FindSegment(fr.Start) @@ -1222,11 +1222,11 @@ func (usageSetFunctions) MaxKey() uint64 { func (usageSetFunctions) ClearValue(val *usageInfo) { } -func (usageSetFunctions) Merge(_ platform.FileRange, val1 usageInfo, _ platform.FileRange, val2 usageInfo) (usageInfo, bool) { +func (usageSetFunctions) Merge(_ memmap.FileRange, val1 usageInfo, _ memmap.FileRange, val2 usageInfo) (usageInfo, bool) { return val1, val1 == val2 } -func (usageSetFunctions) Split(_ platform.FileRange, val usageInfo, _ uint64) (usageInfo, usageInfo) { +func (usageSetFunctions) Split(_ memmap.FileRange, val usageInfo, _ uint64) (usageInfo, usageInfo) { return val, val } @@ -1270,10 +1270,10 @@ func (reclaimSetFunctions) MaxKey() uint64 { func (reclaimSetFunctions) ClearValue(val *reclaimSetValue) { } -func (reclaimSetFunctions) Merge(_ platform.FileRange, _ reclaimSetValue, _ platform.FileRange, _ reclaimSetValue) (reclaimSetValue, bool) { +func (reclaimSetFunctions) Merge(_ memmap.FileRange, _ reclaimSetValue, _ memmap.FileRange, _ reclaimSetValue) (reclaimSetValue, bool) { return reclaimSetValue{}, true } -func (reclaimSetFunctions) Split(_ platform.FileRange, _ reclaimSetValue, _ uint64) (reclaimSetValue, reclaimSetValue) { +func (reclaimSetFunctions) Split(_ memmap.FileRange, _ reclaimSetValue, _ uint64) (reclaimSetValue, reclaimSetValue) { return reclaimSetValue{}, reclaimSetValue{} } diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD index 453241eca..209b28053 100644 --- a/pkg/sentry/platform/BUILD +++ b/pkg/sentry/platform/BUILD @@ -1,39 +1,21 @@ load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) -go_template_instance( - name = "file_range", - out = "file_range.go", - package = "platform", - prefix = "File", - template = "//pkg/segment:generic_range", - types = { - "T": "uint64", - }, -) - go_library( name = "platform", srcs = [ "context.go", - "file_range.go", "mmap_min_addr.go", "platform.go", ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/atomicbitops", "//pkg/context", - "//pkg/log", - "//pkg/safecopy", - "//pkg/safemem", "//pkg/seccomp", "//pkg/sentry/arch", - "//pkg/sentry/usage", - "//pkg/syserror", + "//pkg/sentry/memmap", "//pkg/usermem", ], ) diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 10a10bfe2..b5d27a72a 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -47,6 +47,7 @@ go_library( "//pkg/safecopy", "//pkg/seccomp", "//pkg/sentry/arch", + "//pkg/sentry/memmap", "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", "//pkg/sentry/platform/ring0", diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go index faf1d5e1c..98a3e539d 100644 --- a/pkg/sentry/platform/kvm/address_space.go +++ b/pkg/sentry/platform/kvm/address_space.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/atomicbitops" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/sync" @@ -150,7 +151,7 @@ func (as *addressSpace) mapLocked(addr usermem.Addr, m hostMapEntry, at usermem. } // MapFile implements platform.AddressSpace.MapFile. -func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.FileRange, at usermem.AccessType, precommit bool) error { +func (as *addressSpace) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error { as.mu.Lock() defer as.mu.Unlock() diff --git a/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go index 6531bae1d..48ccf8474 100644 --- a/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go @@ -22,7 +22,8 @@ import ( ) var ( - runDataSize int + runDataSize int + hasGuestPCID bool ) func updateSystemValues(fd int) error { @@ -33,6 +34,7 @@ func updateSystemValues(fd int) error { } // Save the data. runDataSize = int(sz) + hasGuestPCID = true // Success. return nil diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 3de309c1a..ff8c068c0 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/usermem" ) @@ -156,6 +157,14 @@ func (c *vCPU) initArchState() error { return err } + // Initialize the PCID database. + if hasGuestPCID { + // Note that NewPCIDs may return a nil table here, in which + // case we simply don't use PCID support (see below). In + // practice, this should not happen, however. + c.PCIDs = pagetables.NewPCIDs(fixedKernelPCID+1, poolPCIDs) + } + c.floatingPointState = arch.NewFloatingPointData() return nil } @@ -234,6 +243,13 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) return nonCanonical(regs.Sp, int32(syscall.SIGBUS), info) } + // Assign PCIDs. + if c.PCIDs != nil { + var requireFlushPCID bool // Force a flush? + switchOpts.UserASID, requireFlushPCID = c.PCIDs.Assign(switchOpts.PageTables) + switchOpts.Flush = switchOpts.Flush || requireFlushPCID + } + var vector ring0.Vector ttbr0App := switchOpts.PageTables.TTBR0_EL1(false, 0) c.SetTtbr0App(uintptr(ttbr0App)) diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go index 171513f3f..4b13eec30 100644 --- a/pkg/sentry/platform/platform.go +++ b/pkg/sentry/platform/platform.go @@ -22,9 +22,9 @@ import ( "os" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/seccomp" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/usermem" ) @@ -207,7 +207,7 @@ type AddressSpace interface { // Preconditions: addr and fr must be page-aligned. fr.Length() > 0. // at.Any() == true. At least one reference must be held on all pages in // fr, and must continue to be held as long as pages are mapped. - MapFile(addr usermem.Addr, f File, fr FileRange, at usermem.AccessType, precommit bool) error + MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error // Unmap unmaps the given range. // @@ -310,52 +310,6 @@ func (f SegmentationFault) Error() string { return fmt.Sprintf("segmentation fault at %#x", f.Addr) } -// File represents a host file that may be mapped into an AddressSpace. -type File interface { - // All pages in a File are reference-counted. - - // IncRef increments the reference count on all pages in fr. - // - // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() > - // 0. At least one reference must be held on all pages in fr. (The File - // interface does not provide a way to acquire an initial reference; - // implementors may define mechanisms for doing so.) - IncRef(fr FileRange) - - // DecRef decrements the reference count on all pages in fr. - // - // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() > - // 0. At least one reference must be held on all pages in fr. - DecRef(fr FileRange) - - // MapInternal returns a mapping of the given file offsets in the invoking - // process' address space for reading and writing. - // - // Note that fr.Start and fr.End need not be page-aligned. - // - // Preconditions: fr.Length() > 0. At least one reference must be held on - // all pages in fr. - // - // Postconditions: The returned mapping is valid as long as at least one - // reference is held on the mapped pages. - MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error) - - // FD returns the file descriptor represented by the File. - // - // The only permitted operation on the returned file descriptor is to map - // pages from it consistent with the requirements of AddressSpace.MapFile. - FD() int -} - -// FileRange represents a range of uint64 offsets into a File. -// -// type FileRange <generated using go_generics> - -// String implements fmt.Stringer.String. -func (fr FileRange) String() string { - return fmt.Sprintf("[%#x, %#x)", fr.Start, fr.End) -} - // Requirements is used to specify platform specific requirements. type Requirements struct { // RequiresCurrentPIDNS indicates that the sandbox has to be started in the diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD index 30402c2df..29fd23cc3 100644 --- a/pkg/sentry/platform/ptrace/BUILD +++ b/pkg/sentry/platform/ptrace/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/seccomp", "//pkg/sentry/arch", "//pkg/sentry/hostcpu", + "//pkg/sentry/memmap", "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", "//pkg/sync", diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 2389423b0..c990f3454 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/procid" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" @@ -616,7 +617,7 @@ func (s *subprocess) syscall(sysno uintptr, args ...arch.SyscallArgument) (uintp } // MapFile implements platform.AddressSpace.MapFile. -func (s *subprocess) MapFile(addr usermem.Addr, f platform.File, fr platform.FileRange, at usermem.AccessType, precommit bool) error { +func (s *subprocess) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error { var flags int if precommit { flags |= syscall.MAP_POPULATE diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go index 1b2cfad7d..c576d9475 100644 --- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go @@ -62,7 +62,7 @@ func Override() { s.Table[55] = syscalls.Supported("getsockopt", GetSockOpt) s.Table[59] = syscalls.Supported("execve", Execve) s.Table[72] = syscalls.Supported("fcntl", Fcntl) - s.Table[73] = syscalls.Supported("fcntl", Flock) + s.Table[73] = syscalls.Supported("flock", Flock) s.Table[74] = syscalls.Supported("fsync", Fsync) s.Table[75] = syscalls.Supported("fdatasync", Fdatasync) s.Table[76] = syscalls.Supported("truncate", Truncate) @@ -163,6 +163,106 @@ func Override() { // Override ARM64. s = linux.ARM64 + s.Table[5] = syscalls.Supported("setxattr", Setxattr) + s.Table[6] = syscalls.Supported("lsetxattr", Lsetxattr) + s.Table[7] = syscalls.Supported("fsetxattr", Fsetxattr) + s.Table[8] = syscalls.Supported("getxattr", Getxattr) + s.Table[9] = syscalls.Supported("lgetxattr", Lgetxattr) + s.Table[10] = syscalls.Supported("fgetxattr", Fgetxattr) + s.Table[11] = syscalls.Supported("listxattr", Listxattr) + s.Table[12] = syscalls.Supported("llistxattr", Llistxattr) + s.Table[13] = syscalls.Supported("flistxattr", Flistxattr) + s.Table[14] = syscalls.Supported("removexattr", Removexattr) + s.Table[15] = syscalls.Supported("lremovexattr", Lremovexattr) + s.Table[16] = syscalls.Supported("fremovexattr", Fremovexattr) + s.Table[17] = syscalls.Supported("getcwd", Getcwd) + s.Table[19] = syscalls.Supported("eventfd2", Eventfd2) + s.Table[20] = syscalls.Supported("epoll_create1", EpollCreate1) + s.Table[21] = syscalls.Supported("epoll_ctl", EpollCtl) + s.Table[22] = syscalls.Supported("epoll_pwait", EpollPwait) + s.Table[23] = syscalls.Supported("dup", Dup) + s.Table[24] = syscalls.Supported("dup3", Dup3) + s.Table[25] = syscalls.Supported("fcntl", Fcntl) + s.Table[26] = syscalls.PartiallySupported("inotify_init1", InotifyInit1, "inotify events are only available inside the sandbox.", nil) + s.Table[27] = syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil) + s.Table[28] = syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil) + s.Table[29] = syscalls.Supported("ioctl", Ioctl) + s.Table[32] = syscalls.Supported("flock", Flock) + s.Table[33] = syscalls.Supported("mknodat", Mknodat) + s.Table[34] = syscalls.Supported("mkdirat", Mkdirat) + s.Table[35] = syscalls.Supported("unlinkat", Unlinkat) + s.Table[36] = syscalls.Supported("symlinkat", Symlinkat) + s.Table[37] = syscalls.Supported("linkat", Linkat) + s.Table[38] = syscalls.Supported("renameat", Renameat) + s.Table[39] = syscalls.Supported("umount2", Umount2) + s.Table[40] = syscalls.Supported("mount", Mount) + s.Table[43] = syscalls.Supported("statfs", Statfs) + s.Table[44] = syscalls.Supported("fstatfs", Fstatfs) + s.Table[45] = syscalls.Supported("truncate", Truncate) + s.Table[46] = syscalls.Supported("ftruncate", Ftruncate) + s.Table[48] = syscalls.Supported("faccessat", Faccessat) + s.Table[49] = syscalls.Supported("chdir", Chdir) + s.Table[50] = syscalls.Supported("fchdir", Fchdir) + s.Table[51] = syscalls.Supported("chroot", Chroot) + s.Table[52] = syscalls.Supported("fchmod", Fchmod) + s.Table[53] = syscalls.Supported("fchmodat", Fchmodat) + s.Table[54] = syscalls.Supported("fchownat", Fchownat) + s.Table[55] = syscalls.Supported("fchown", Fchown) + s.Table[56] = syscalls.Supported("openat", Openat) + s.Table[57] = syscalls.Supported("close", Close) + s.Table[59] = syscalls.Supported("pipe2", Pipe2) + s.Table[61] = syscalls.Supported("getdents64", Getdents64) + s.Table[62] = syscalls.Supported("lseek", Lseek) s.Table[63] = syscalls.Supported("read", Read) + s.Table[64] = syscalls.Supported("write", Write) + s.Table[65] = syscalls.Supported("readv", Readv) + s.Table[66] = syscalls.Supported("writev", Writev) + s.Table[67] = syscalls.Supported("pread64", Pread64) + s.Table[68] = syscalls.Supported("pwrite64", Pwrite64) + s.Table[69] = syscalls.Supported("preadv", Preadv) + s.Table[70] = syscalls.Supported("pwritev", Pwritev) + s.Table[72] = syscalls.Supported("pselect", Pselect) + s.Table[73] = syscalls.Supported("ppoll", Ppoll) + s.Table[74] = syscalls.Supported("signalfd4", Signalfd4) + s.Table[76] = syscalls.Supported("splice", Splice) + s.Table[77] = syscalls.Supported("tee", Tee) + s.Table[78] = syscalls.Supported("readlinkat", Readlinkat) + s.Table[80] = syscalls.Supported("fstat", Fstat) + s.Table[81] = syscalls.Supported("sync", Sync) + s.Table[82] = syscalls.Supported("fsync", Fsync) + s.Table[83] = syscalls.Supported("fdatasync", Fdatasync) + s.Table[84] = syscalls.Supported("sync_file_range", SyncFileRange) + s.Table[85] = syscalls.Supported("timerfd_create", TimerfdCreate) + s.Table[86] = syscalls.Supported("timerfd_settime", TimerfdSettime) + s.Table[87] = syscalls.Supported("timerfd_gettime", TimerfdGettime) + s.Table[88] = syscalls.Supported("utimensat", Utimensat) + s.Table[198] = syscalls.Supported("socket", Socket) + s.Table[199] = syscalls.Supported("socketpair", SocketPair) + s.Table[200] = syscalls.Supported("bind", Bind) + s.Table[201] = syscalls.Supported("listen", Listen) + s.Table[202] = syscalls.Supported("accept", Accept) + s.Table[203] = syscalls.Supported("connect", Connect) + s.Table[204] = syscalls.Supported("getsockname", GetSockName) + s.Table[205] = syscalls.Supported("getpeername", GetPeerName) + s.Table[206] = syscalls.Supported("sendto", SendTo) + s.Table[207] = syscalls.Supported("recvfrom", RecvFrom) + s.Table[208] = syscalls.Supported("setsockopt", SetSockOpt) + s.Table[209] = syscalls.Supported("getsockopt", GetSockOpt) + s.Table[210] = syscalls.Supported("shutdown", Shutdown) + s.Table[211] = syscalls.Supported("sendmsg", SendMsg) + s.Table[212] = syscalls.Supported("recvmsg", RecvMsg) + s.Table[221] = syscalls.Supported("execve", Execve) + s.Table[222] = syscalls.Supported("mmap", Mmap) + s.Table[242] = syscalls.Supported("accept4", Accept4) + s.Table[243] = syscalls.Supported("recvmmsg", RecvMMsg) + s.Table[267] = syscalls.Supported("syncfs", Syncfs) + s.Table[269] = syscalls.Supported("sendmmsg", SendMMsg) + s.Table[276] = syscalls.Supported("renameat2", Renameat2) + s.Table[279] = syscalls.Supported("memfd_create", MemfdCreate) + s.Table[281] = syscalls.Supported("execveat", Execveat) + s.Table[286] = syscalls.Supported("preadv2", Preadv2) + s.Table[287] = syscalls.Supported("pwritev2", Pwritev2) + s.Table[291] = syscalls.Supported("statx", Statx) + s.Init() } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index b0f57040c..31a242482 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -160,9 +160,12 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. -func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { +func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { r := &stack.Route{ - RemoteLinkAddress: header.EthernetBroadcastAddress, + RemoteLinkAddress: remoteLinkAddr, + } + if len(r.RemoteLinkAddress) == 0 { + r.RemoteLinkAddress = header.EthernetBroadcastAddress } hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize) diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 66e67429c..a35a64a0f 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -32,10 +32,14 @@ import ( ) const ( - stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") - stackAddr1 = tcpip.Address("\x0a\x00\x00\x01") - stackAddr2 = tcpip.Address("\x0a\x00\x00\x02") - stackAddrBad = tcpip.Address("\x0a\x00\x00\x03") + stackLinkAddr1 = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") + stackLinkAddr2 = tcpip.LinkAddress("\x0b\x0b\x0c\x0c\x0d\x0d") + stackAddr1 = tcpip.Address("\x0a\x00\x00\x01") + stackAddr2 = tcpip.Address("\x0a\x00\x00\x02") + stackAddrBad = tcpip.Address("\x0a\x00\x00\x03") + + defaultChannelSize = 1 + defaultMTU = 65536 ) type testContext struct { @@ -50,8 +54,7 @@ func newTestContext(t *testing.T) *testContext { TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()}, }) - const defaultMTU = 65536 - ep := channel.New(256, defaultMTU, stackLinkAddr) + ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1) wep := stack.LinkEndpoint(ep) if testing.Verbose() { @@ -119,7 +122,7 @@ func TestDirectRequest(t *testing.T) { if !rep.IsValid() { t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength()) } - if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { + if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr1; got != want { t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) } if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want { @@ -144,3 +147,44 @@ func TestDirectRequest(t *testing.T) { t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) } } + +func TestLinkAddressRequest(t *testing.T) { + tests := []struct { + name string + remoteLinkAddr tcpip.LinkAddress + expectLinkAddr tcpip.LinkAddress + }{ + { + name: "Unicast", + remoteLinkAddr: stackLinkAddr2, + expectLinkAddr: stackLinkAddr2, + }, + { + name: "Multicast", + remoteLinkAddr: "", + expectLinkAddr: header.EthernetBroadcastAddress, + }, + } + + for _, test := range tests { + p := arp.NewProtocol() + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") + } + + linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1) + if err := linkRes.LinkAddressRequest(stackAddr1, stackAddr2, test.remoteLinkAddr, linkEP); err != nil { + t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr1, stackAddr2, test.remoteLinkAddr, err) + } + + pkt, ok := linkEP.Read() + if !ok { + t.Fatal("expected to send a link address request") + } + + if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want) + } + } +} diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index ff1cb53dd..24600d877 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -504,7 +504,7 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver. -func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { +func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { snaddr := header.SolicitedNodeAddr(addr) // TODO(b/148672031): Use stack.FindRoute instead of manually creating the @@ -513,8 +513,12 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. r := &stack.Route{ LocalAddress: localAddr, RemoteAddress: snaddr, - RemoteLinkAddress: header.EthernetAddressFromMulticastIPv6Address(snaddr), + RemoteLinkAddress: remoteLinkAddr, } + if len(r.RemoteLinkAddress) == 0 { + r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(snaddr) + } + hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) pkt.SetType(header.ICMPv6NeighborSolicit) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 52a01b44e..f86aaed1d 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -34,6 +34,9 @@ const ( linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") + + defaultChannelSize = 1 + defaultMTU = 65536 ) var ( @@ -257,8 +260,7 @@ func newTestContext(t *testing.T) *testContext { }), } - const defaultMTU = 65536 - c.linkEP0 = channel.New(256, defaultMTU, linkAddr0) + c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0) wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0}) if testing.Verbose() { @@ -271,7 +273,7 @@ func newTestContext(t *testing.T) *testContext { t.Fatalf("AddAddress lladdr0: %v", err) } - c.linkEP1 = channel.New(256, defaultMTU, linkAddr1) + c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1) wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) if err := c.s1.CreateNIC(1, wrappedEP1); err != nil { t.Fatalf("CreateNIC failed: %v", err) @@ -951,3 +953,47 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { }) } } + +func TestLinkAddressRequest(t *testing.T) { + snaddr := header.SolicitedNodeAddr(lladdr0) + mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr) + + tests := []struct { + name string + remoteLinkAddr tcpip.LinkAddress + expectLinkAddr tcpip.LinkAddress + }{ + { + name: "Unicast", + remoteLinkAddr: linkAddr1, + expectLinkAddr: linkAddr1, + }, + { + name: "Multicast", + remoteLinkAddr: "", + expectLinkAddr: mcaddr, + }, + } + + for _, test := range tests { + p := NewProtocol() + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver") + } + + linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) + if err := linkRes.LinkAddressRequest(lladdr0, lladdr1, test.remoteLinkAddr, linkEP); err != nil { + t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", lladdr0, lladdr1, test.remoteLinkAddr, err) + } + + pkt, ok := linkEP.Read() + if !ok { + t.Fatal("expected to send a link address request") + } + + if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want) + } + } +} diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index bca1d940b..c962693f5 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -121,10 +121,12 @@ func (*fwdTestNetworkEndpoint) Close() {} type fwdTestNetworkProtocol struct { addrCache *linkAddrCache addrResolveDelay time.Duration - onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address) + onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) } +var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil) + func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { return fwdTestNetNumber } @@ -174,10 +176,10 @@ func (f *fwdTestNetworkProtocol) Close() {} func (f *fwdTestNetworkProtocol) Wait() {} -func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error { +func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { if f.addrCache != nil && f.onLinkAddressResolved != nil { time.AfterFunc(f.addrResolveDelay, func() { - f.onLinkAddressResolved(f.addrCache, addr) + f.onLinkAddressResolved(f.addrCache, addr, remoteLinkAddr) }) } return nil @@ -405,7 +407,7 @@ func TestForwardingWithFakeResolver(t *testing.T) { // Create a network protocol with a fake resolver. proto := &fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any address will be resolved to the link address "c". cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") }, @@ -463,7 +465,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // Create a network protocol with a fake resolver. proto := &fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Only packets to address 3 will be resolved to the // link address "c". if addr == "\x03" { @@ -515,7 +517,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { // Create a network protocol with a fake resolver. proto := &fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") }, @@ -559,7 +561,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { // Create a network protocol with a fake resolver. proto := &fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") }, @@ -616,7 +618,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // Create a network protocol with a fake resolver. proto := &fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") }, diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 403557fd7..6f73a0ce4 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -244,7 +244,7 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check // whether the request succeeded. - linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP) + linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP) select { case now := <-time.After(c.resolutionTimeout): diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 1baa498d0..b15b8d1cb 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -48,7 +48,7 @@ type testLinkAddressResolver struct { onLinkAddressRequest func() } -func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { +func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { time.AfterFunc(r.delay, func() { r.fakeRequest(addr) }) if f := r.onLinkAddressRequest; f != nil { f() diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index c477e31d8..a70792b50 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -243,7 +243,7 @@ func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements LinkAddressResolver. -func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { +func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { return nil } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 9e1b2d25f..8604c4259 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -478,12 +478,13 @@ type InjectableLinkEndpoint interface { // A LinkAddressResolver is an extension to a NetworkProtocol that // can resolve link addresses. type LinkAddressResolver interface { - // LinkAddressRequest sends a request for the LinkAddress of addr. - // The request is sent on linkEP with localAddr as the source. + // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts + // the request on the local network if remoteLinkAddr is the zero value. The + // request is sent on linkEP with localAddr as the source. // // A valid response will cause the discovery protocol's network // endpoint to call AddLinkAddress. - LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error + LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error // ResolveStaticAddress attempts to resolve address without sending // requests. It either resolves the name immediately or returns the diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 18ff89ffc..e860ee484 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -49,6 +49,7 @@ go_library( "segment_heap.go", "segment_queue.go", "segment_state.go", + "segment_unsafe.go", "snd.go", "snd_state.go", "tcp_endpoint_list.go", diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index b34e47bbd..5d6174a59 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -49,7 +49,7 @@ const ( // DefaultReceiveBufferSize is the default size of the receive buffer // for an endpoint. - DefaultReceiveBufferSize = 1 << 20 // 1MB + DefaultReceiveBufferSize = 32 << 10 // 32KB // MaxBufferSize is the largest size a receive/send buffer can grow to. MaxBufferSize = 4 << 20 // 4MB diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index dd89a292a..5e0bfe585 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -372,7 +372,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { // We only store the segment if it's within our buffer // size limit. if r.pendingBufUsed < r.pendingBufSize { - r.pendingBufUsed += s.logicalLen() + r.pendingBufUsed += seqnum.Size(s.segMemSize()) s.incRef() heap.Push(&r.pendingRcvdSegments, s) UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt) @@ -406,7 +406,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { } heap.Pop(&r.pendingRcvdSegments) - r.pendingBufUsed -= s.logicalLen() + r.pendingBufUsed -= seqnum.Size(s.segMemSize()) s.decRef() } return false, nil diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 0280892a8..bb60dc29d 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -138,6 +138,12 @@ func (s *segment) logicalLen() seqnum.Size { return l } +// segMemSize is the amount of memory used to hold the segment data and +// the associated metadata. +func (s *segment) segMemSize() int { + return segSize + s.data.Size() +} + // parse populates the sequence & ack numbers, flags, and window fields of the // segment from the TCP header stored in the data. It then updates the view to // skip the header. diff --git a/pkg/tcpip/transport/tcp/segment_unsafe.go b/pkg/tcpip/transport/tcp/segment_unsafe.go new file mode 100644 index 000000000..0ab7b8f56 --- /dev/null +++ b/pkg/tcpip/transport/tcp/segment_unsafe.go @@ -0,0 +1,23 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "unsafe" +) + +const ( + segSize = int(unsafe.Sizeof(segment{})) +) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 06fde2a79..37e7767d6 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -143,12 +143,14 @@ func New(t *testing.T, mtu uint32) *Context { TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, }) + const sendBufferSize = 1 << 20 // 1 MiB + const recvBufferSize = 1 << 20 // 1 MiB // Allow minimum send/receive buffer sizes to be 1 during tests. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: tcp.DefaultSendBufferSize, Max: 10 * tcp.DefaultSendBufferSize}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: tcp.DefaultReceiveBufferSize, Max: 10 * tcp.DefaultReceiveBufferSize}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } @@ -202,7 +204,7 @@ func New(t *testing.T, mtu uint32) *Context { t: t, s: s, linkEP: ep, - WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)), + WindowScale: uint8(tcp.FindWndScale(recvBufferSize)), } } diff --git a/pkg/test/dockerutil/BUILD b/pkg/test/dockerutil/BUILD index 83b80c8bc..a5e84658a 100644 --- a/pkg/test/dockerutil/BUILD +++ b/pkg/test/dockerutil/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,6 +10,7 @@ go_library( "dockerutil.go", "exec.go", "network.go", + "profile.go", ], visibility = ["//:sandbox"], deps = [ @@ -23,3 +24,19 @@ go_library( "@com_github_docker_go_connections//nat:go_default_library", ], ) + +go_test( + name = "profile_test", + size = "large", + srcs = [ + "profile_test.go", + ], + library = ":dockerutil", + tags = [ + # Requires docker and runsc to be configured before test runs. + # Also requires the test to be run as root. + "manual", + "local", + ], + visibility = ["//:sandbox"], +) diff --git a/pkg/test/dockerutil/README.md b/pkg/test/dockerutil/README.md new file mode 100644 index 000000000..870292096 --- /dev/null +++ b/pkg/test/dockerutil/README.md @@ -0,0 +1,86 @@ +# dockerutil + +This package is for creating and controlling docker containers for testing +runsc, gVisor's docker/kubernetes binary. A simple test may look like: + +``` + func TestSuperCool(t *testing.T) { + ctx := context.Background() + c := dockerutil.MakeContainer(ctx, t) + got, err := c.Run(ctx, dockerutil.RunOpts{ + Image: "basic/alpine" + }, "echo", "super cool") + if err != nil { + t.Fatalf("err was not nil: %v", err) + } + want := "super cool" + if !strings.Contains(got, want){ + t.Fatalf("want: %s, got: %s", want, got) + } + } +``` + +For further examples, see many of our end to end tests elsewhere in the repo, +such as those in //test/e2e or benchmarks at //test/benchmarks. + +dockerutil uses the "official" docker golang api, which is +[very powerful](https://godoc.org/github.com/docker/docker/client). dockerutil +is a thin wrapper around this API, allowing desired new use cases to be easily +implemented. + +## Profiling + +dockerutil is capable of generating profiles. Currently, the only option is to +use pprof profiles generated by `runsc debug`. The profiler will generate Block, +CPU, Heap, Goroutine, and Mutex profiles. To generate profiles: + +* Install runsc with the `--profile` flag: `make configure RUNTIME=myrunsc + ARGS="--profile"` Also add other flags with ARGS like `--platform=kvm` or + `--vfs2`. +* Restart docker: `sudo service docker restart` + +To run and generate CPU profiles run: + +``` +make sudo TARGETS=//path/to:target \ + ARGS="--runtime=myrunsc -test.v -test.bench=. --pprof-cpu" OPTIONS="-c opt" +``` + +Profiles would be at: `/tmp/profile/myrunsc/CONTAINERNAME/cpu.pprof` + +Container name in most tests and benchmarks in gVisor is usually the test name +and some random characters like so: +`BenchmarkABSL-CleanCache-JF2J2ZYF3U7SL47QAA727CSJI3C4ZAW2` + +Profiling requires root as runsc debug inspects running containers in /var/run +among other things. + +### Writing for Profiling + +The below shows an example of using profiles with dockerutil. + +``` +func TestSuperCool(t *testing.T){ + ctx := context.Background() + // profiled and using runtime from dockerutil.runtime flag + profiled := MakeContainer() + + // not profiled and using runtime runc + native := MakeNativeContainer() + + err := profiled.Spawn(ctx, RunOpts{ + Image: "some/image", + }, "sleep", "100000") + // profiling has begun here + ... + expensive setup that I don't want to profile. + ... + profiled.RestartProfiles() + // profiled activity +} +``` + +In the above example, `profiled` would be profiled and `native` would not. The +call to `RestartProfiles()` restarts the clock on profiling. This is useful if +the main activity being tested is done with `docker exec` or `container.Spawn()` +followed by one or more `container.Exec()` calls. diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go index 17acdaf6f..b59503188 100644 --- a/pkg/test/dockerutil/container.go +++ b/pkg/test/dockerutil/container.go @@ -43,15 +43,21 @@ import ( // See: https://pkg.go.dev/github.com/docker/docker. type Container struct { Name string - Runtime string + runtime string logger testutil.Logger client *client.Client id string mounts []mount.Mount links []string - cleanups []func() copyErr error + cleanups []func() + + // Profiles are profiles added to this container. They contain methods + // that are run after Creation, Start, and Cleanup of this Container, along + // a handle to restart the profile. Generally, tests/benchmarks using + // profiles need to run as root. + profiles []Profile // Stores streams attached to the container. Used by WaitForOutputSubmatch. streams types.HijackedResponse @@ -106,7 +112,19 @@ type RunOpts struct { // MakeContainer sets up the struct for a Docker container. // // Names of containers will be unique. +// Containers will check flags for profiling requests. func MakeContainer(ctx context.Context, logger testutil.Logger) *Container { + c := MakeNativeContainer(ctx, logger) + c.runtime = *runtime + if p := MakePprofFromFlags(c); p != nil { + c.AddProfile(p) + } + return c +} + +// MakeNativeContainer sets up the struct for a DockerContainer using runc. Native +// containers aren't profiled. +func MakeNativeContainer(ctx context.Context, logger testutil.Logger) *Container { // Slashes are not allowed in container names. name := testutil.RandomID(logger.Name()) name = strings.ReplaceAll(name, "/", "-") @@ -114,20 +132,33 @@ func MakeContainer(ctx context.Context, logger testutil.Logger) *Container { if err != nil { return nil } - client.NegotiateAPIVersion(ctx) - return &Container{ logger: logger, Name: name, - Runtime: *runtime, + runtime: "", client: client, } } +// AddProfile adds a profile to this container. +func (c *Container) AddProfile(p Profile) { + c.profiles = append(c.profiles, p) +} + +// RestartProfiles calls Restart on all profiles for this container. +func (c *Container) RestartProfiles() error { + for _, profile := range c.profiles { + if err := profile.Restart(c); err != nil { + return err + } + } + return nil +} + // Spawn is analogous to 'docker run -d'. func (c *Container) Spawn(ctx context.Context, r RunOpts, args ...string) error { - if err := c.create(ctx, r, args); err != nil { + if err := c.create(ctx, c.config(r, args), c.hostConfig(r), nil); err != nil { return err } return c.Start(ctx) @@ -153,7 +184,7 @@ func (c *Container) SpawnProcess(ctx context.Context, r RunOpts, args ...string) // Run is analogous to 'docker run'. func (c *Container) Run(ctx context.Context, r RunOpts, args ...string) (string, error) { - if err := c.create(ctx, r, args); err != nil { + if err := c.create(ctx, c.config(r, args), c.hostConfig(r), nil); err != nil { return "", err } @@ -181,27 +212,25 @@ func (c *Container) MakeLink(target string) string { // CreateFrom creates a container from the given configs. func (c *Container) CreateFrom(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error { - cont, err := c.client.ContainerCreate(ctx, conf, hostconf, netconf, c.Name) - if err != nil { - return err - } - c.id = cont.ID - return nil + return c.create(ctx, conf, hostconf, netconf) } // Create is analogous to 'docker create'. func (c *Container) Create(ctx context.Context, r RunOpts, args ...string) error { - return c.create(ctx, r, args) + return c.create(ctx, c.config(r, args), c.hostConfig(r), nil) } -func (c *Container) create(ctx context.Context, r RunOpts, args []string) error { - conf := c.config(r, args) - hostconf := c.hostConfig(r) +func (c *Container) create(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error { cont, err := c.client.ContainerCreate(ctx, conf, hostconf, nil, c.Name) if err != nil { return err } c.id = cont.ID + for _, profile := range c.profiles { + if err := profile.OnCreate(c); err != nil { + return fmt.Errorf("OnCreate method failed with: %v", err) + } + } return nil } @@ -227,7 +256,7 @@ func (c *Container) hostConfig(r RunOpts) *container.HostConfig { c.mounts = append(c.mounts, r.Mounts...) return &container.HostConfig{ - Runtime: c.Runtime, + Runtime: c.runtime, Mounts: c.mounts, PublishAllPorts: true, Links: r.Links, @@ -261,8 +290,15 @@ func (c *Container) Start(ctx context.Context) error { c.cleanups = append(c.cleanups, func() { c.streams.Close() }) - - return c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{}) + if err := c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{}); err != nil { + return fmt.Errorf("ContainerStart failed: %v", err) + } + for _, profile := range c.profiles { + if err := profile.OnStart(c); err != nil { + return fmt.Errorf("OnStart method failed: %v", err) + } + } + return nil } // Stop is analogous to 'docker stop'. @@ -482,6 +518,12 @@ func (c *Container) Remove(ctx context.Context) error { // CleanUp kills and deletes the container (best effort). func (c *Container) CleanUp(ctx context.Context) { + // Execute profile cleanups before the container goes down. + for _, profile := range c.profiles { + profile.OnCleanUp(c) + } + // Forget profiles. + c.profiles = nil // Kill the container. if err := c.Kill(ctx); err != nil && !strings.Contains(err.Error(), "is not running") { // Just log; can't do anything here. diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go index df09babf3..5a9dd8bd8 100644 --- a/pkg/test/dockerutil/dockerutil.go +++ b/pkg/test/dockerutil/dockerutil.go @@ -25,6 +25,7 @@ import ( "os/exec" "regexp" "strconv" + "time" "gvisor.dev/gvisor/pkg/test/testutil" ) @@ -42,6 +43,26 @@ var ( // config is the default Docker daemon configuration path. config = flag.String("config_path", "/etc/docker/daemon.json", "configuration file for reading paths") + + // The following flags are for the "pprof" profiler tool. + + // pprofBaseDir allows the user to change the directory to which profiles are + // written. By default, profiles will appear under: + // /tmp/profile/RUNTIME/CONTAINER_NAME/*.pprof. + pprofBaseDir = flag.String("pprof-dir", "/tmp/profile", "base directory in: BASEDIR/RUNTIME/CONTINER_NAME/FILENAME (e.g. /tmp/profile/runtime/mycontainer/cpu.pprof)") + + // duration is the max duration `runsc debug` will run and capture profiles. + // If the container's clean up method is called prior to duration, the + // profiling process will be killed. + duration = flag.Duration("pprof-duration", 10*time.Second, "duration to run the profile in seconds") + + // The below flags enable each type of profile. Multiple profiles can be + // enabled for each run. + pprofBlock = flag.Bool("pprof-block", false, "enables block profiling with runsc debug") + pprofCPU = flag.Bool("pprof-cpu", false, "enables CPU profiling with runsc debug") + pprofGo = flag.Bool("pprof-go", false, "enables goroutine profiling with runsc debug") + pprofHeap = flag.Bool("pprof-heap", false, "enables heap profiling with runsc debug") + pprofMutex = flag.Bool("pprof-mutex", false, "enables mutex profiling with runsc debug") ) // EnsureSupportedDockerVersion checks if correct docker is installed. diff --git a/pkg/test/dockerutil/profile.go b/pkg/test/dockerutil/profile.go new file mode 100644 index 000000000..1fab33083 --- /dev/null +++ b/pkg/test/dockerutil/profile.go @@ -0,0 +1,152 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dockerutil + +import ( + "context" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "time" +) + +// Profile represents profile-like operations on a container, +// such as running perf or pprof. It is meant to be added to containers +// such that the container type calls the Profile during its lifecycle. +type Profile interface { + // OnCreate is called just after the container is created when the container + // has a valid ID (e.g. c.ID()). + OnCreate(c *Container) error + + // OnStart is called just after the container is started when the container + // has a valid Pid (e.g. c.SandboxPid()). + OnStart(c *Container) error + + // Restart restarts the Profile on request. + Restart(c *Container) error + + // OnCleanUp is called during the container's cleanup method. + // Cleanups should just log errors if they have them. + OnCleanUp(c *Container) error +} + +// Pprof is for running profiles with 'runsc debug'. Pprof workloads +// should be run as root and ONLY against runsc sandboxes. The runtime +// should have --profile set as an option in /etc/docker/daemon.json in +// order for profiling to work with Pprof. +type Pprof struct { + BasePath string // path to put profiles + BlockProfile bool + CPUProfile bool + GoRoutineProfile bool + HeapProfile bool + MutexProfile bool + Duration time.Duration // duration to run profiler e.g. '10s' or '1m'. + shouldRun bool + cmd *exec.Cmd + stdout io.ReadCloser + stderr io.ReadCloser +} + +// MakePprofFromFlags makes a Pprof profile from flags. +func MakePprofFromFlags(c *Container) *Pprof { + if !(*pprofBlock || *pprofCPU || *pprofGo || *pprofHeap || *pprofMutex) { + return nil + } + return &Pprof{ + BasePath: filepath.Join(*pprofBaseDir, c.runtime, c.Name), + BlockProfile: *pprofBlock, + CPUProfile: *pprofCPU, + GoRoutineProfile: *pprofGo, + HeapProfile: *pprofHeap, + MutexProfile: *pprofMutex, + Duration: *duration, + } +} + +// OnCreate implements Profile.OnCreate. +func (p *Pprof) OnCreate(c *Container) error { + return os.MkdirAll(p.BasePath, 0755) +} + +// OnStart implements Profile.OnStart. +func (p *Pprof) OnStart(c *Container) error { + path, err := RuntimePath() + if err != nil { + return fmt.Errorf("failed to get runtime path: %v", err) + } + + // The root directory of this container's runtime. + root := fmt.Sprintf("--root=/var/run/docker/runtime-%s/moby", c.runtime) + // Format is `runsc --root=rootdir debug --profile-*=file --duration=* containerID`. + args := []string{root, "debug"} + args = append(args, p.makeProfileArgs(c)...) + args = append(args, c.ID()) + + // Best effort wait until container is running. + for now := time.Now(); time.Since(now) < 5*time.Second; { + if status, err := c.Status(context.Background()); err != nil { + return fmt.Errorf("failed to get status with: %v", err) + + } else if status.Running { + break + } + time.Sleep(500 * time.Millisecond) + } + p.cmd = exec.Command(path, args...) + if err := p.cmd.Start(); err != nil { + return fmt.Errorf("process failed: %v", err) + } + return nil +} + +// Restart implements Profile.Restart. +func (p *Pprof) Restart(c *Container) error { + p.OnCleanUp(c) + return p.OnStart(c) +} + +// OnCleanUp implements Profile.OnCleanup +func (p *Pprof) OnCleanUp(c *Container) error { + defer func() { p.cmd = nil }() + if p.cmd != nil && p.cmd.Process != nil && p.cmd.ProcessState != nil && !p.cmd.ProcessState.Exited() { + return p.cmd.Process.Kill() + } + return nil +} + +// makeProfileArgs turns Pprof fields into runsc debug flags. +func (p *Pprof) makeProfileArgs(c *Container) []string { + var ret []string + if p.BlockProfile { + ret = append(ret, fmt.Sprintf("--profile-block=%s", filepath.Join(p.BasePath, "block.pprof"))) + } + if p.CPUProfile { + ret = append(ret, fmt.Sprintf("--profile-cpu=%s", filepath.Join(p.BasePath, "cpu.pprof"))) + } + if p.GoRoutineProfile { + ret = append(ret, fmt.Sprintf("--profile-goroutine=%s", filepath.Join(p.BasePath, "go.pprof"))) + } + if p.HeapProfile { + ret = append(ret, fmt.Sprintf("--profile-heap=%s", filepath.Join(p.BasePath, "heap.pprof"))) + } + if p.MutexProfile { + ret = append(ret, fmt.Sprintf("--profile-mutex=%s", filepath.Join(p.BasePath, "mutex.pprof"))) + } + ret = append(ret, fmt.Sprintf("--duration=%s", p.Duration)) + return ret +} diff --git a/pkg/test/dockerutil/profile_test.go b/pkg/test/dockerutil/profile_test.go new file mode 100644 index 000000000..b7b4d7618 --- /dev/null +++ b/pkg/test/dockerutil/profile_test.go @@ -0,0 +1,117 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dockerutil + +import ( + "context" + "fmt" + "os" + "path/filepath" + "testing" + "time" +) + +type testCase struct { + name string + pprof Pprof + expectedFiles []string +} + +func TestPprof(t *testing.T) { + // Basepath and expected file names for each type of profile. + basePath := "/tmp/test/profile" + block := "block.pprof" + cpu := "cpu.pprof" + goprofle := "go.pprof" + heap := "heap.pprof" + mutex := "mutex.pprof" + + testCases := []testCase{ + { + name: "Cpu", + pprof: Pprof{ + BasePath: basePath, + CPUProfile: true, + Duration: 2 * time.Second, + }, + expectedFiles: []string{cpu}, + }, + { + name: "All", + pprof: Pprof{ + BasePath: basePath, + BlockProfile: true, + CPUProfile: true, + GoRoutineProfile: true, + HeapProfile: true, + MutexProfile: true, + Duration: 2 * time.Second, + }, + expectedFiles: []string{block, cpu, goprofle, heap, mutex}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + c := MakeContainer(ctx, t) + // Set basepath to include the container name so there are no conflicts. + tc.pprof.BasePath = filepath.Join(tc.pprof.BasePath, c.Name) + c.AddProfile(&tc.pprof) + + func() { + defer c.CleanUp(ctx) + // Start a container. + if err := c.Spawn(ctx, RunOpts{ + Image: "basic/alpine", + }, "sleep", "1000"); err != nil { + t.Fatalf("run failed with: %v", err) + } + + if status, err := c.Status(context.Background()); !status.Running { + t.Fatalf("container is not yet running: %+v err: %v", status, err) + } + + // End early if the expected files exist and have data. + for start := time.Now(); time.Since(start) < tc.pprof.Duration; time.Sleep(500 * time.Millisecond) { + if err := checkFiles(tc); err == nil { + break + } + } + }() + + // Check all expected files exist and have data. + if err := checkFiles(tc); err != nil { + t.Fatalf(err.Error()) + } + }) + } +} + +func checkFiles(tc testCase) error { + for _, file := range tc.expectedFiles { + stat, err := os.Stat(filepath.Join(tc.pprof.BasePath, file)) + if err != nil { + return fmt.Errorf("stat failed with: %v", err) + } else if stat.Size() < 1 { + return fmt.Errorf("file not written to: %+v", stat) + } + } + return nil +} + +func TestMain(m *testing.M) { + EnsureSupportedDockerVersion() + os.Exit(m.Run()) +} |