| author | Googler &lt;noreply@google.com&gt; | 2018-04-27 10:37:02 -0700 |
| --- | --- | --- |
| committer | Adin Scannell &lt;ascannell@google.com&gt; | 2018-04-28 01:44:26 -0400 |
| commit | d02b74a5dcfed4bfc8f2f8e545bca4d2afabb296 (patch) | |
| tree | 54f95eef73aee6bacbfc736fffc631be2605ed53 /pkg/sentry/platform | |
| parent | f70210e742919f40aa2f0934a22f1c9ba6dada62 (diff) | |
Check in gVisor.
PiperOrigin-RevId: 194583126
Change-Id: Ica1d8821a90f74e7e745962d71801c598c652463
Diffstat (limited to 'pkg/sentry/platform')
85 files changed, 12199 insertions, 0 deletions
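For orientation before reading the patch, here is a small usage sketch of the filemem allocator that this change introduces. It is hypothetical and not part of the commit: it assumes the import paths used in this change, picks `usage.System` as an arbitrary `MemoryKind`, and uses a made-up memfd name; only functions that appear in the diff below (`New`, `Allocate`, `MapInternal`, `DecRef`, `Destroy`) are called.

```go
// Hypothetical sketch of driving filemem; not part of this change.
package main

import (
	"log"

	"gvisor.googlesource.com/gvisor/pkg/sentry/platform/filemem"
	"gvisor.googlesource.com/gvisor/pkg/sentry/usage"
	"gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
)

func main() {
	// Create a memfd-backed FileMem ("example-memory" is an arbitrary name).
	fm, err := filemem.New("example-memory")
	if err != nil {
		log.Fatalf("filemem.New: %v", err)
	}
	defer fm.Destroy()

	// Allocate four pages; the returned FileRange starts with one reference.
	fr, err := fm.Allocate(4*usermem.PageSize, usage.System)
	if err != nil {
		log.Fatalf("Allocate: %v", err)
	}

	// Map the range into this process and write through the mapping. The
	// write commits the previously uncommitted pages in the backing file.
	bs, err := fm.MapInternal(fr, usermem.AccessType{Read: true, Write: true})
	if err != nil {
		log.Fatalf("MapInternal: %v", err)
	}
	for !bs.IsEmpty() {
		s := bs.Head().ToSlice()
		for i := range s {
			s[i] = 0xAA
		}
		bs = bs.Tail()
	}

	// Dropping the last reference makes the range reclaimable; the reclaimer
	// goroutine will eventually decommit it back to the host.
	fm.DecRef(fr)
}
```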
diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD new file mode 100644 index 000000000..d5be81f8d --- /dev/null +++ b/pkg/sentry/platform/BUILD @@ -0,0 +1,51 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_stateify:defs.bzl", "go_stateify") + +go_stateify( + name = "platform_state", + srcs = [ + "file_range.go", + ], + out = "platform_state.go", + package = "platform", +) + +go_template_instance( + name = "file_range", + out = "file_range.go", + package = "platform", + prefix = "File", + template = "//pkg/segment:generic_range", + types = { + "T": "uint64", + }, +) + +go_library( + name = "platform", + srcs = [ + "context.go", + "file_range.go", + "mmap_min_addr.go", + "platform.go", + "platform_state.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/platform", + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/atomicbitops", + "//pkg/log", + "//pkg/sentry/arch", + "//pkg/sentry/context", + "//pkg/sentry/platform/safecopy", + "//pkg/sentry/safemem", + "//pkg/sentry/usage", + "//pkg/sentry/usermem", + "//pkg/state", + "//pkg/syserror", + ], +) diff --git a/pkg/sentry/platform/context.go b/pkg/sentry/platform/context.go new file mode 100644 index 000000000..0d200a5e2 --- /dev/null +++ b/pkg/sentry/platform/context.go @@ -0,0 +1,36 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package platform + +import ( + "gvisor.googlesource.com/gvisor/pkg/sentry/context" +) + +// contextID is the auth package's type for context.Context.Value keys. +type contextID int + +const ( + // CtxPlatform is a Context.Value key for a Platform. + CtxPlatform contextID = iota +) + +// FromContext returns the Platform that is used to execute ctx's application +// code, or nil if no such Platform exists. 
+func FromContext(ctx context.Context) Platform { + if v := ctx.Value(CtxPlatform); v != nil { + return v.(Platform) + } + return nil +} diff --git a/pkg/sentry/platform/filemem/BUILD b/pkg/sentry/platform/filemem/BUILD new file mode 100644 index 000000000..3c4d5b0b6 --- /dev/null +++ b/pkg/sentry/platform/filemem/BUILD @@ -0,0 +1,69 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_stateify:defs.bzl", "go_stateify") + +go_stateify( + name = "filemem_autogen_state", + srcs = [ + "filemem.go", + "filemem_state.go", + "usage_set.go", + ], + out = "filemem_autogen_state.go", + package = "filemem", +) + +go_template_instance( + name = "usage_set", + out = "usage_set.go", + consts = { + "minDegree": "10", + }, + imports = { + "platform": "gvisor.googlesource.com/gvisor/pkg/sentry/platform", + }, + package = "filemem", + prefix = "usage", + template = "//pkg/segment:generic_set", + types = { + "Key": "uint64", + "Range": "platform.FileRange", + "Value": "usageInfo", + "Functions": "usageSetFunctions", + }, +) + +go_library( + name = "filemem", + srcs = [ + "filemem.go", + "filemem_autogen_state.go", + "filemem_state.go", + "filemem_unsafe.go", + "usage_set.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/platform/filemem", + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/log", + "//pkg/sentry/arch", + "//pkg/sentry/context", + "//pkg/sentry/memutil", + "//pkg/sentry/platform", + "//pkg/sentry/safemem", + "//pkg/sentry/usage", + "//pkg/sentry/usermem", + "//pkg/state", + "//pkg/syserror", + ], +) + +go_test( + name = "filemem_test", + size = "small", + srcs = ["filemem_test.go"], + embed = [":filemem"], + deps = ["//pkg/sentry/usermem"], +) diff --git a/pkg/sentry/platform/filemem/filemem.go b/pkg/sentry/platform/filemem/filemem.go new file mode 100644 index 000000000..d79c3c7f1 --- /dev/null +++ b/pkg/sentry/platform/filemem/filemem.go @@ -0,0 +1,838 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package filemem provides a reusable implementation of platform.Memory. +// +// It enables memory to be sourced from a memfd file. +// +// Lock order: +// +// filemem.FileMem.mu +// filemem.FileMem.mappingsMu +package filemem + +import ( + "fmt" + "math" + "os" + "sync" + "sync/atomic" + "syscall" + "time" + + "gvisor.googlesource.com/gvisor/pkg/log" + "gvisor.googlesource.com/gvisor/pkg/sentry/context" + "gvisor.googlesource.com/gvisor/pkg/sentry/memutil" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform" + "gvisor.googlesource.com/gvisor/pkg/sentry/safemem" + "gvisor.googlesource.com/gvisor/pkg/sentry/usage" + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" + "gvisor.googlesource.com/gvisor/pkg/syserror" +) + +// FileMem is a platform.Memory that allocates from a host file that it owns. 
+type FileMem struct { + // Filemem models the backing file as follows: + // + // Each page in the file can be committed or uncommitted. A page is + // committed if the host kernel is spending resources to store its contents + // and uncommitted otherwise. This definition includes pages that the host + // kernel has swapped; this is intentional, to ensure that accounting does + // not change even if host kernel swapping behavior changes, and that + // memory used by pseudo-swap mechanisms like zswap is still accounted. + // + // The initial contents of uncommitted pages are implicitly zero bytes. A + // read or write to the contents of an uncommitted page causes it to be + // committed. This is the only event that can cause a uncommitted page to + // be committed. + // + // fallocate(FALLOC_FL_PUNCH_HOLE) (FileMem.Decommit) causes committed + // pages to be uncommitted. This is the only event that can cause a + // committed page to be uncommitted. + // + // Filemem's accounting is based on identifying the set of committed pages. + // Since filemem does not have direct access to the MMU, tracking reads and + // writes to uncommitted pages to detect commitment would introduce + // additional page faults, which would be prohibitively expensive. Instead, + // filemem queries the host kernel to determine which pages are committed. + + // file is the backing memory file. The file pointer is immutable. + file *os.File + + mu sync.Mutex + + // usage maps each page in the file to metadata for that page. Pages for + // which no segment exists in usage are both unallocated (not in use) and + // uncommitted. + // + // Since usage stores usageInfo objects by value, clients should usually + // use usageIterator.ValuePtr() instead of usageIterator.Value() to get a + // pointer to the usageInfo rather than a copy. + // + // usage must be kept maximally merged (that is, there should never be two + // adjacent segments with the same values). At least markReclaimed depends + // on this property. + // + // usage is protected by mu. + usage usageSet + + // The UpdateUsage function scans all segments with knownCommitted set + // to false, sees which pages are committed and creates corresponding + // segments with knownCommitted set to true. + // + // In order to avoid unnecessary scans, usageExpected tracks the total + // file blocks expected. This is used to elide the scan when this + // matches the underlying file blocks. + // + // To track swapped pages, usageSwapped tracks the discrepency between + // what is observed in core and what is reported by the file. When + // usageSwapped is non-zero, a sweep will be performed at least every + // second. The start of the last sweep is recorded in usageLast. + // + // All usage attributes are all protected by mu. + usageExpected uint64 + usageSwapped uint64 + usageLast time.Time + + // fileSize is the size of the backing memory file in bytes. fileSize is + // always a power-of-two multiple of chunkSize. + // + // fileSize is protected by mu. + fileSize int64 + + // destroyed is set by Destroy to instruct the reclaimer goroutine to + // release resources and exit. destroyed is protected by mu. + destroyed bool + + // reclaimable is true if usage may contain reclaimable pages. reclaimable + // is protected by mu. + reclaimable bool + + // reclaimCond is signaled (with mu locked) when reclaimable or destroyed + // transitions from false to true. 
+ reclaimCond sync.Cond + + // Filemem pages are mapped into the local address space on the granularity + // of large pieces called chunks. mappings is a []uintptr that stores, for + // each chunk, the start address of a mapping of that chunk in the current + // process' address space, or 0 if no such mapping exists. Once a chunk is + // mapped, it is never remapped or unmapped until the filemem is destroyed. + // + // Mutating the mappings slice or its contents requires both holding + // mappingsMu and using atomic memory operations. (The slice is mutated + // whenever the file is expanded. Per the above, the only permitted + // mutation of the slice's contents is the assignment of a mapping to a + // chunk that was previously unmapped.) Reading the slice or its contents + // only requires *either* holding mappingsMu or using atomic memory + // operations. This allows FileMem.AccessPhysical to avoid locking in the + // common case where chunk mappings already exist. + + mappingsMu sync.Mutex + mappings atomic.Value +} + +// usage tracks usage information. +type usageInfo struct { + // kind is the usage kind. + kind usage.MemoryKind + + // knownCommitted indicates whether this region is known to be + // committed. If this is false, then the region may or may not have + // been touched. If it is true however, then mincore (below) has + // indicated that the page is present at least once. + knownCommitted bool + + refs uint64 +} + +func (u *usageInfo) incRef() { + u.refs++ +} + +func (u *usageInfo) decRef() { + if u.refs == 0 { + panic("DecRef at 0 refs!") + } + u.refs-- +} + +const ( + chunkShift = 24 + chunkSize = 1 << chunkShift // 16 MB + chunkMask = chunkSize - 1 + + initialSize = chunkSize +) + +// newFromFile creates a FileMem backed by the given file. +func newFromFile(file *os.File) (*FileMem, error) { + if err := file.Truncate(initialSize); err != nil { + return nil, err + } + f := &FileMem{ + fileSize: initialSize, + file: file, + } + f.reclaimCond.L = &f.mu + f.mappings.Store(make([]uintptr, initialSize/chunkSize)) + go f.runReclaim() // S/R-SAFE: f.mu + + // The Linux kernel contains an optional feature called "Integrity + // Measurement Architecture" (IMA). If IMA is enabled, it will checksum + // binaries the first time they are mapped PROT_EXEC. This is bad news for + // executable pages mapped from FileMem, which can grow to terabytes in + // (sparse) size. If IMA attempts to checksum a file that large, it will + // allocate all of the sparse pages and quickly exhaust all memory. + // + // Work around IMA by immediately creating a temporary PROT_EXEC mapping, + // while FileMem is still small. IMA will ignore any future mappings. + m, _, errno := syscall.Syscall6( + syscall.SYS_MMAP, + 0, + usermem.PageSize, + syscall.PROT_EXEC, + syscall.MAP_SHARED, + f.file.Fd(), + 0) + if errno != 0 { + // This isn't fatal to filemem (IMA may not even be in use). Log the + // error, but don't return it. + log.Warningf("Failed to pre-map FileMem PROT_EXEC: %v", errno) + } else { + syscall.Syscall( + syscall.SYS_MUNMAP, + m, + usermem.PageSize, + 0) + } + + return f, nil +} + +// New creates a FileMem backed by a memfd file. +func New(name string) (*FileMem, error) { + fd, err := memutil.CreateMemFD(name, 0) + if err != nil { + return nil, err + } + return newFromFile(os.NewFile(uintptr(fd), name)) +} + +// Destroy implements platform.Memory.Destroy. 
+func (f *FileMem) Destroy() { + f.mu.Lock() + defer f.mu.Unlock() + f.destroyed = true + f.reclaimCond.Signal() +} + +// Allocate implements platform.Memory.Allocate. +func (f *FileMem) Allocate(length uint64, kind usage.MemoryKind) (platform.FileRange, error) { + if length == 0 || length%usermem.PageSize != 0 { + panic(fmt.Sprintf("invalid allocation length: %#x", length)) + } + + f.mu.Lock() + defer f.mu.Unlock() + + // Align hugepage-and-larger allocations on hugepage boundaries to try + // to take advantage of hugetmpfs. + alignment := uint64(usermem.PageSize) + if length >= usermem.HugePageSize { + alignment = usermem.HugePageSize + } + + start := findUnallocatedRange(&f.usage, length, alignment) + end := start + length + // File offsets are int64s. Since length must be strictly positive, end + // cannot legitimately be 0. + if end < start || int64(end) <= 0 { + return platform.FileRange{}, syserror.ENOMEM + } + + // Expand the file if needed. Double the file size on each expansion; + // uncommitted pages have effectively no cost. + fileSize := f.fileSize + for int64(end) > fileSize { + if fileSize >= 2*fileSize { + // fileSize overflow. + return platform.FileRange{}, syserror.ENOMEM + } + fileSize *= 2 + } + if fileSize > f.fileSize { + if err := f.file.Truncate(fileSize); err != nil { + return platform.FileRange{}, err + } + f.fileSize = fileSize + f.mappingsMu.Lock() + oldMappings := f.mappings.Load().([]uintptr) + newMappings := make([]uintptr, fileSize>>chunkShift) + copy(newMappings, oldMappings) + f.mappings.Store(newMappings) + f.mappingsMu.Unlock() + } + + // Mark selected pages as in use. + fr := platform.FileRange{start, end} + if !f.usage.Add(fr, usageInfo{ + kind: kind, + refs: 1, + }) { + panic(fmt.Sprintf("allocating %v: failed to insert into f.usage:\n%v", fr, &f.usage)) + } + return fr, nil +} + +func findUnallocatedRange(usage *usageSet, length, alignment uint64) uint64 { + alignMask := alignment - 1 + var start uint64 + for seg := usage.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + r := seg.Range() + if start >= r.End { + // start was rounded up to an alignment boundary from the end + // of a previous segment. + continue + } + // This segment represents allocated or reclaimable pages; only the + // range from start to the segment's beginning is allocatable, and the + // next allocatable range begins after the segment. + if r.Start > start && r.Start-start >= length { + break + } + start = (r.End + alignMask) &^ alignMask + } + return start +} + +// fallocate(2) modes, defined in Linux's include/uapi/linux/falloc.h. +const ( + _FALLOC_FL_KEEP_SIZE = 1 + _FALLOC_FL_PUNCH_HOLE = 2 +) + +// Decommit implements platform.Memory.Decommit. +func (f *FileMem) Decommit(fr platform.FileRange) error { + if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 { + panic(fmt.Sprintf("invalid range: %v", fr)) + } + + // "After a successful call, subsequent reads from this range will + // return zeroes. The FALLOC_FL_PUNCH_HOLE flag must be ORed with + // FALLOC_FL_KEEP_SIZE in mode ..." 
- fallocate(2) + err := syscall.Fallocate( + int(f.file.Fd()), + _FALLOC_FL_PUNCH_HOLE|_FALLOC_FL_KEEP_SIZE, + int64(fr.Start), + int64(fr.Length())) + if err != nil { + return err + } + f.markDecommitted(fr) + return nil +} + +func (f *FileMem) markDecommitted(fr platform.FileRange) { + f.mu.Lock() + defer f.mu.Unlock() + // Since we're changing the knownCommitted attribute, we need to merge + // across the entire range to ensure that the usage tree is minimal. + gap := f.usage.ApplyContiguous(fr, func(seg usageIterator) { + val := seg.ValuePtr() + if val.knownCommitted { + // Drop the usageExpected appropriately. + amount := seg.Range().Length() + usage.MemoryAccounting.Dec(amount, val.kind) + f.usageExpected -= amount + val.knownCommitted = false + } + }) + if gap.Ok() { + panic(fmt.Sprintf("Decommit(%v): attempted to decommit unallocated pages %v:\n%v", fr, gap.Range(), &f.usage)) + } + f.usage.MergeRange(fr) +} + +// runReclaim implements the reclaimer goroutine, which continuously decommits +// reclaimable frames in order to reduce memory usage. +func (f *FileMem) runReclaim() { + for { + fr, ok := f.findReclaimable() + if !ok { + break + } + + if err := f.Decommit(fr); err != nil { + log.Warningf("Reclaim failed to decommit %v: %v", fr, err) + // Zero the frames manually. This won't reduce memory usage, but at + // least ensures that the frames will be zero when reallocated. + f.forEachMappingSlice(fr, func(bs []byte) { + for i := range bs { + bs[i] = 0 + } + }) + // Pretend the frames were decommitted even though they weren't, + // since the memory accounting implementation has no idea how to + // deal with this. + f.markDecommitted(fr) + } + f.markReclaimed(fr) + } + // We only get here if findReclaimable finds f.destroyed set and returns + // false. + f.mu.Lock() + defer f.mu.Unlock() + if !f.destroyed { + panic("findReclaimable broke out of reclaim loop, but f.destroyed is no longer set") + } + f.file.Close() + // Ensure that any attempts to use f.file.Fd() fail instead of getting a fd + // that has possibly been reassigned. + f.file = nil + mappings := f.mappings.Load().([]uintptr) + for i, m := range mappings { + if m != 0 { + _, _, errno := syscall.Syscall(syscall.SYS_MUNMAP, m, chunkSize, 0) + if errno != 0 { + log.Warningf("Failed to unmap mapping %#x for filemem chunk %d: %v", m, i, errno) + } + } + } + // Similarly, invalidate f.mappings. (atomic.Value.Store(nil) panics.) + f.mappings.Store([]uintptr{}) +} + +func (f *FileMem) findReclaimable() (platform.FileRange, bool) { + f.mu.Lock() + defer f.mu.Unlock() + for { + for { + if f.destroyed { + return platform.FileRange{}, false + } + if f.reclaimable { + break + } + f.reclaimCond.Wait() + } + // Allocate returns the first usable range in offset order and is + // currently a linear scan, so reclaiming from the beginning of the + // file minimizes the expected latency of Allocate. + for seg := f.usage.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + if seg.ValuePtr().refs == 0 { + return seg.Range(), true + } + } + f.reclaimable = false + } +} + +func (f *FileMem) markReclaimed(fr platform.FileRange) { + f.mu.Lock() + defer f.mu.Unlock() + seg := f.usage.FindSegment(fr.Start) + // All of fr should be mapped to a single uncommitted reclaimable segment + // accounted to System. 
+ if !seg.Ok() { + panic(fmt.Sprintf("Reclaimed pages %v include unreferenced pages:\n%v", fr, &f.usage)) + } + if !seg.Range().IsSupersetOf(fr) { + panic(fmt.Sprintf("Reclaimed pages %v are not entirely contained in segment %v with state %v:\n%v", fr, seg.Range(), seg.Value(), &f.usage)) + } + if got, want := seg.Value(), (usageInfo{ + kind: usage.System, + knownCommitted: false, + refs: 0, + }); got != want { + panic(fmt.Sprintf("Reclaimed pages %v in segment %v has incorrect state %v, wanted %v:\n%v", fr, seg.Range(), got, want, &f.usage)) + } + // Deallocate reclaimed pages. Even though all of seg is reclaimable, the + // caller of markReclaimed may not have decommitted it, so we can only mark + // fr as reclaimed. + f.usage.Remove(f.usage.Isolate(seg, fr)) +} + +// MapInto implements platform.File.MapInto. +func (f *FileMem) MapInto(as platform.AddressSpace, addr usermem.Addr, fr platform.FileRange, at usermem.AccessType, precommit bool) error { + if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 { + panic(fmt.Sprintf("invalid range: %v", fr)) + } + return as.MapFile(addr, int(f.file.Fd()), fr, at, precommit) +} + +// MapInternal implements platform.File.MapInternal. +func (f *FileMem) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { + if !fr.WellFormed() || fr.Length() == 0 { + panic(fmt.Sprintf("invalid range: %v", fr)) + } + if at.Execute { + return safemem.BlockSeq{}, syserror.EACCES + } + + chunks := ((fr.End + chunkMask) >> chunkShift) - (fr.Start >> chunkShift) + if chunks == 1 { + // Avoid an unnecessary slice allocation. + var seq safemem.BlockSeq + err := f.forEachMappingSlice(fr, func(bs []byte) { + seq = safemem.BlockSeqOf(safemem.BlockFromSafeSlice(bs)) + }) + return seq, err + } + blocks := make([]safemem.Block, 0, chunks) + err := f.forEachMappingSlice(fr, func(bs []byte) { + blocks = append(blocks, safemem.BlockFromSafeSlice(bs)) + }) + return safemem.BlockSeqFromSlice(blocks), err +} + +// IncRef implements platform.File.IncRef. +func (f *FileMem) IncRef(fr platform.FileRange) { + if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 { + panic(fmt.Sprintf("invalid range: %v", fr)) + } + + f.mu.Lock() + defer f.mu.Unlock() + + gap := f.usage.ApplyContiguous(fr, func(seg usageIterator) { + seg.ValuePtr().incRef() + }) + if gap.Ok() { + panic(fmt.Sprintf("IncRef(%v): attempted to IncRef on unallocated pages %v:\n%v", fr, gap.Range(), &f.usage)) + } +} + +// DecRef implements platform.File.DecRef. +func (f *FileMem) DecRef(fr platform.FileRange) { + if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 { + panic(fmt.Sprintf("invalid range: %v", fr)) + } + + var freed bool + + f.mu.Lock() + defer f.mu.Unlock() + + for seg := f.usage.FindSegment(fr.Start); seg.Ok() && seg.Start() < fr.End; seg = seg.NextSegment() { + seg = f.usage.Isolate(seg, fr) + val := seg.ValuePtr() + val.decRef() + if val.refs == 0 { + freed = true + // Reclassify memory as System, until it's freed by the reclaim + // goroutine. + if val.knownCommitted { + usage.MemoryAccounting.Move(seg.Range().Length(), usage.System, val.kind) + } + val.kind = usage.System + } + } + f.usage.MergeAdjacent(fr) + + if freed { + f.reclaimable = true + f.reclaimCond.Signal() + } +} + +// Flush implements platform.Mappable.Flush. 
+func (f *FileMem) Flush(ctx context.Context) error { + return nil +} + +// forEachMappingSlice invokes fn on a sequence of byte slices that +// collectively map all bytes in fr. +func (f *FileMem) forEachMappingSlice(fr platform.FileRange, fn func([]byte)) error { + mappings := f.mappings.Load().([]uintptr) + for chunkStart := fr.Start &^ chunkMask; chunkStart < fr.End; chunkStart += chunkSize { + chunk := int(chunkStart >> chunkShift) + m := atomic.LoadUintptr(&mappings[chunk]) + if m == 0 { + var err error + mappings, m, err = f.getChunkMapping(chunk) + if err != nil { + return err + } + } + startOff := uint64(0) + if chunkStart < fr.Start { + startOff = fr.Start - chunkStart + } + endOff := uint64(chunkSize) + if chunkStart+chunkSize > fr.End { + endOff = fr.End - chunkStart + } + fn(unsafeSlice(m, chunkSize)[startOff:endOff]) + } + return nil +} + +func (f *FileMem) getChunkMapping(chunk int) ([]uintptr, uintptr, error) { + f.mappingsMu.Lock() + defer f.mappingsMu.Unlock() + // Another thread may have replaced f.mappings altogether due to file + // expansion. + mappings := f.mappings.Load().([]uintptr) + // Another thread may have already mapped the chunk. + if m := mappings[chunk]; m != 0 { + return mappings, m, nil + } + m, _, errno := syscall.Syscall6( + syscall.SYS_MMAP, + 0, + chunkSize, + syscall.PROT_READ|syscall.PROT_WRITE, + syscall.MAP_SHARED, + f.file.Fd(), + uintptr(chunk<<chunkShift)) + if errno != 0 { + return nil, 0, errno + } + atomic.StoreUintptr(&mappings[chunk], m) + return mappings, m, nil +} + +// UpdateUsage implements platform.Memory.UpdateUsage. +func (f *FileMem) UpdateUsage() error { + f.mu.Lock() + defer f.mu.Unlock() + + // If the underlying usage matches where the usage tree already + // represents, then we can just avoid the entire scan (we know it's + // accurate). + currentUsage, err := f.TotalUsage() + if err != nil { + return err + } + if currentUsage == f.usageExpected && f.usageSwapped == 0 { + log.Debugf("UpdateUsage: skipped with usageSwapped=0.") + return nil + } + // If the current usage matches the expected but there's swap + // accounting, then ensure a scan takes place at least every second + // (when requested). + if currentUsage == f.usageExpected+f.usageSwapped && time.Now().Before(f.usageLast.Add(time.Second)) { + log.Debugf("UpdateUsage: skipped with usageSwapped!=0.") + return nil + } + + f.usageLast = time.Now() + err = f.updateUsageLocked(currentUsage, mincore) + log.Debugf("UpdateUsage: currentUsage=%d, usageExpected=%d, usageSwapped=%d.", + currentUsage, f.usageExpected, f.usageSwapped) + log.Debugf("UpdateUsage: took %v.", time.Since(f.usageLast)) + return err +} + +// updateUsageLocked attempts to detect commitment of previous-uncommitted +// pages by invoking checkCommitted, which is a function that, for each page i +// in bs, sets committed[i] to 1 if the page is committed and 0 otherwise. +// +// Precondition: f.mu must be held. +func (f *FileMem) updateUsageLocked(currentUsage uint64, checkCommitted func(bs []byte, committed []byte) error) error { + // Track if anything changed to elide the merge. In the common case, we + // expect all segments to be committed and no merge to occur. + changedAny := false + defer func() { + if changedAny { + f.usage.MergeAll() + } + + // Adjust the swap usage to reflect reality. + if f.usageExpected < currentUsage { + // Since no pages may be decommitted while we hold usageMu, we + // know that usage may have only increased since we got the + // last current usage. 
Therefore, if usageExpected is still + // short of currentUsage, we must assume that the difference is + // in pages that have been swapped. + newUsageSwapped := currentUsage - f.usageExpected + if f.usageSwapped < newUsageSwapped { + usage.MemoryAccounting.Inc(newUsageSwapped-f.usageSwapped, usage.System) + } else { + usage.MemoryAccounting.Dec(f.usageSwapped-newUsageSwapped, usage.System) + } + f.usageSwapped = newUsageSwapped + } else if f.usageSwapped != 0 { + // We have more usage accounted for than the file itself. + // That's fine, we probably caught a race where pages were + // being committed while the above loop was running. Just + // report the higher number that we found and ignore swap. + usage.MemoryAccounting.Dec(f.usageSwapped, usage.System) + f.usageSwapped = 0 + } + }() + + // Reused mincore buffer, will generally be <= 4096 bytes. + var buf []byte + + // Iterate over all usage data. There will only be usage segments + // present when there is an associated reference. + for seg := f.usage.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + val := seg.Value() + + // Already known to be committed; ignore. + if val.knownCommitted { + continue + } + + // Assume that reclaimable pages (that aren't already known to be + // committed) are not committed. This isn't necessarily true, even + // after the reclaimer does Decommit(), because the kernel may + // subsequently back the hugepage-sized region containing the + // decommitted page with a hugepage. However, it's consistent with our + // treatment of unallocated pages, which have the same property. + if val.refs == 0 { + continue + } + + // Get the range for this segment. As we touch slices, the + // Start value will be walked along. + r := seg.Range() + + var checkErr error + err := f.forEachMappingSlice(r, func(s []byte) { + if checkErr != nil { + return + } + + // Ensure that we have sufficient buffer for the call + // (one byte per page). The length of each slice must + // be page-aligned. + bufLen := len(s) / usermem.PageSize + if len(buf) < bufLen { + buf = make([]byte, bufLen) + } + + // Query for new pages in core. + if err := checkCommitted(s, buf); err != nil { + checkErr = err + return + } + + // Scan each page and switch out segments. + populatedRun := false + populatedRunStart := 0 + for i := 0; i <= bufLen; i++ { + // We run past the end of the slice here to + // simplify the logic and only set populated if + // we're still looking at elements. + populated := false + if i < bufLen { + populated = buf[i]&0x1 != 0 + } + + switch { + case populated == populatedRun: + // Keep the run going. + continue + case populated && !populatedRun: + // Begin the run. + populatedRun = true + populatedRunStart = i + // Keep going. + continue + case !populated && populatedRun: + // Finish the run by changing this segment. + runRange := platform.FileRange{ + Start: r.Start + uint64(populatedRunStart*usermem.PageSize), + End: r.Start + uint64(i*usermem.PageSize), + } + seg = f.usage.Isolate(seg, runRange) + seg.ValuePtr().knownCommitted = true + // Advance the segment only if we still + // have work to do in the context of + // the original segment from the for + // loop. Otherwise, the for loop itself + // will advance the segment + // appropriately. + if runRange.End != r.End { + seg = seg.NextSegment() + } + amount := runRange.Length() + usage.MemoryAccounting.Inc(amount, val.kind) + f.usageExpected += amount + changedAny = true + populatedRun = false + } + } + + // Advance r.Start. 
+ r.Start += uint64(len(s)) + }) + if checkErr != nil { + return checkErr + } + if err != nil { + return err + } + } + + return nil +} + +// TotalUsage implements platform.Memory.TotalUsage. +func (f *FileMem) TotalUsage() (uint64, error) { + // Stat the underlying file to discover the underlying usage. stat(2) + // always reports the allocated block count in units of 512 bytes. This + // includes pages in the page cache and swapped pages. + var stat syscall.Stat_t + if err := syscall.Fstat(int(f.file.Fd()), &stat); err != nil { + return 0, err + } + return uint64(stat.Blocks * 512), nil +} + +// TotalSize implements platform.Memory.TotalSize. +func (f *FileMem) TotalSize() uint64 { + f.mu.Lock() + defer f.mu.Unlock() + return uint64(f.fileSize) +} + +// File returns the memory file used by f. +func (f *FileMem) File() *os.File { + return f.file +} + +// String implements fmt.Stringer.String. +// +// Note that because f.String locks f.mu, calling f.String internally +// (including indirectly through the fmt package) risks recursive locking. +// Within the filemem package, use f.usage directly instead. +func (f *FileMem) String() string { + f.mu.Lock() + defer f.mu.Unlock() + return f.usage.String() +} + +type usageSetFunctions struct{} + +func (usageSetFunctions) MinKey() uint64 { + return 0 +} + +func (usageSetFunctions) MaxKey() uint64 { + return math.MaxUint64 +} + +func (usageSetFunctions) ClearValue(val *usageInfo) { +} + +func (usageSetFunctions) Merge(_ platform.FileRange, val1 usageInfo, _ platform.FileRange, val2 usageInfo) (usageInfo, bool) { + return val1, val1 == val2 +} + +func (usageSetFunctions) Split(_ platform.FileRange, val usageInfo, _ uint64) (usageInfo, usageInfo) { + return val, val +} diff --git a/pkg/sentry/platform/filemem/filemem_state.go b/pkg/sentry/platform/filemem/filemem_state.go new file mode 100644 index 000000000..5dace8fec --- /dev/null +++ b/pkg/sentry/platform/filemem/filemem_state.go @@ -0,0 +1,170 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package filemem + +import ( + "bytes" + "fmt" + "io" + "runtime" + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/log" + "gvisor.googlesource.com/gvisor/pkg/sentry/usage" + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" + "gvisor.googlesource.com/gvisor/pkg/state" +) + +// SaveTo implements platform.Memory.SaveTo. +func (f *FileMem) SaveTo(w io.Writer) error { + // Wait for reclaim. + f.mu.Lock() + defer f.mu.Unlock() + for f.reclaimable { + f.reclaimCond.Signal() + f.mu.Unlock() + runtime.Gosched() + f.mu.Lock() + } + + // Ensure that all pages that contain data have knownCommitted set, since + // we only store knownCommitted pages below. 
+ zeroPage := make([]byte, usermem.PageSize) + err := f.updateUsageLocked(0, func(bs []byte, committed []byte) error { + for pgoff := 0; pgoff < len(bs); pgoff += usermem.PageSize { + i := pgoff / usermem.PageSize + pg := bs[pgoff : pgoff+usermem.PageSize] + if !bytes.Equal(pg, zeroPage) { + committed[i] = 1 + continue + } + committed[i] = 0 + // Reading the page caused it to be committed; decommit it to + // reduce memory usage. + // + // "MADV_REMOVE [...] Free up a given range of pages and its + // associated backing store. This is equivalent to punching a hole + // in the corresponding byte range of the backing store (see + // fallocate(2))." - madvise(2) + if err := syscall.Madvise(pg, syscall.MADV_REMOVE); err != nil { + // This doesn't impact the correctness of saved memory, it + // just means that we're incrementally more likely to OOM. + // Complain, but don't abort saving. + log.Warningf("Decommitting page %p while saving failed: %v", pg, err) + } + } + return nil + }) + if err != nil { + return err + } + + // Save metadata. + if err := state.Save(w, &f.fileSize, nil); err != nil { + return err + } + if err := state.Save(w, &f.usage, nil); err != nil { + return err + } + + // Dump out committed pages. + for seg := f.usage.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + if !seg.Value().knownCommitted { + continue + } + // Write a header to distinguish from objects. + if err := state.WriteHeader(w, uint64(seg.Range().Length()), false); err != nil { + return err + } + // Write out data. + var ioErr error + err := f.forEachMappingSlice(seg.Range(), func(s []byte) { + if ioErr != nil { + return + } + _, ioErr = w.Write(s) + }) + if ioErr != nil { + return ioErr + } + if err != nil { + return err + } + + // Update accounting for restored pages. We need to do this here since + // these segments are marked as "known committed", and will be skipped + // over on accounting scans. + usage.MemoryAccounting.Inc(seg.Range().Length(), seg.Value().kind) + } + + return nil +} + +// LoadFrom implements platform.Memory.LoadFrom. +func (f *FileMem) LoadFrom(r io.Reader) error { + // Load metadata. + if err := state.Load(r, &f.fileSize, nil); err != nil { + return err + } + if err := f.file.Truncate(f.fileSize); err != nil { + return err + } + newMappings := make([]uintptr, f.fileSize>>chunkShift) + f.mappings.Store(newMappings) + if err := state.Load(r, &f.usage, nil); err != nil { + return err + } + + // Load committed pages. + for seg := f.usage.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + if !seg.Value().knownCommitted { + continue + } + // Verify header. + length, object, err := state.ReadHeader(r) + if err != nil { + return err + } + if object { + // Not expected. + return fmt.Errorf("unexpected object") + } + if expected := uint64(seg.Range().Length()); length != expected { + // Size mismatch. + return fmt.Errorf("mismatched segment: expected %d, got %d", expected, length) + } + // Read data. + var ioErr error + err = f.forEachMappingSlice(seg.Range(), func(s []byte) { + if ioErr != nil { + return + } + _, ioErr = io.ReadFull(r, s) + }) + if ioErr != nil { + return ioErr + } + if err != nil { + return err + } + + // Update accounting for restored pages. We need to do this here since + // these segments are marked as "known committed", and will be skipped + // over on accounting scans. 
+ usage.MemoryAccounting.Inc(seg.End()-seg.Start(), seg.Value().kind) + } + + return nil +} diff --git a/pkg/sentry/platform/filemem/filemem_test.go b/pkg/sentry/platform/filemem/filemem_test.go new file mode 100644 index 000000000..46ffcf116 --- /dev/null +++ b/pkg/sentry/platform/filemem/filemem_test.go @@ -0,0 +1,122 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package filemem + +import ( + "testing" + + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +const ( + page = usermem.PageSize + hugepage = usermem.HugePageSize +) + +func TestFindUnallocatedRange(t *testing.T) { + for _, test := range []struct { + desc string + usage *usageSegmentDataSlices + length uint64 + alignment uint64 + start uint64 + }{ + { + desc: "Initial allocation succeeds", + usage: &usageSegmentDataSlices{}, + length: page, + alignment: page, + start: 0, + }, + { + desc: "Allocation begins at start of file", + usage: &usageSegmentDataSlices{ + Start: []uint64{page}, + End: []uint64{2 * page}, + Values: []usageInfo{{refs: 1}}, + }, + length: page, + alignment: page, + start: 0, + }, + { + desc: "In-use frames are not allocatable", + usage: &usageSegmentDataSlices{ + Start: []uint64{0, page}, + End: []uint64{page, 2 * page}, + Values: []usageInfo{{refs: 1}, {refs: 2}}, + }, + length: page, + alignment: page, + start: 2 * page, + }, + { + desc: "Reclaimable frames are not allocatable", + usage: &usageSegmentDataSlices{ + Start: []uint64{0, page, 2 * page}, + End: []uint64{page, 2 * page, 3 * page}, + Values: []usageInfo{{refs: 1}, {refs: 0}, {refs: 1}}, + }, + length: page, + alignment: page, + start: 3 * page, + }, + { + desc: "Gaps between in-use frames are allocatable", + usage: &usageSegmentDataSlices{ + Start: []uint64{0, 2 * page}, + End: []uint64{page, 3 * page}, + Values: []usageInfo{{refs: 1}, {refs: 1}}, + }, + length: page, + alignment: page, + start: page, + }, + { + desc: "Inadequately-sized gaps are rejected", + usage: &usageSegmentDataSlices{ + Start: []uint64{0, 2 * page}, + End: []uint64{page, 3 * page}, + Values: []usageInfo{{refs: 1}, {refs: 1}}, + }, + length: 2 * page, + alignment: page, + start: 3 * page, + }, + { + desc: "Hugepage alignment is honored", + usage: &usageSegmentDataSlices{ + Start: []uint64{0, hugepage + page}, + // Hugepage-sized gap here that shouldn't be allocated from + // since it's incorrectly aligned. 
+ End: []uint64{page, hugepage + 2*page}, + Values: []usageInfo{{refs: 1}, {refs: 1}}, + }, + length: hugepage, + alignment: hugepage, + start: 2 * hugepage, + }, + } { + t.Run(test.desc, func(t *testing.T) { + var usage usageSet + if err := usage.ImportSortedSlices(test.usage); err != nil { + t.Fatalf("Failed to initialize usage from %v: %v", test.usage, err) + } + if got, want := findUnallocatedRange(&usage, test.length, test.alignment), test.start; got != want { + t.Errorf("findUnallocatedRange(%v, %d, %d): got %d, wanted %d", test.usage, test.length, test.alignment, got, want) + } + }) + } +} diff --git a/pkg/sentry/platform/filemem/filemem_unsafe.go b/pkg/sentry/platform/filemem/filemem_unsafe.go new file mode 100644 index 000000000..a23b9825a --- /dev/null +++ b/pkg/sentry/platform/filemem/filemem_unsafe.go @@ -0,0 +1,40 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package filemem + +import ( + "reflect" + "syscall" + "unsafe" +) + +func unsafeSlice(addr uintptr, length int) (slice []byte) { + sh := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) + sh.Data = addr + sh.Len = length + sh.Cap = length + return +} + +func mincore(s []byte, buf []byte) error { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_MINCORE, + uintptr(unsafe.Pointer(&s[0])), + uintptr(len(s)), + uintptr(unsafe.Pointer(&buf[0]))); errno != 0 { + return errno + } + return nil +} diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD new file mode 100644 index 000000000..33dde2a31 --- /dev/null +++ b/pkg/sentry/platform/interrupt/BUILD @@ -0,0 +1,19 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "interrupt", + srcs = [ + "interrupt.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/platform/interrupt", + visibility = ["//pkg/sentry:internal"], +) + +go_test( + name = "interrupt_test", + size = "small", + srcs = ["interrupt_test.go"], + embed = [":interrupt"], +) diff --git a/pkg/sentry/platform/interrupt/interrupt.go b/pkg/sentry/platform/interrupt/interrupt.go new file mode 100644 index 000000000..ca4f42087 --- /dev/null +++ b/pkg/sentry/platform/interrupt/interrupt.go @@ -0,0 +1,96 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package interrupt provides an interrupt helper. 
+package interrupt + +import ( + "fmt" + "sync" +) + +// Receiver receives interrupt notifications from a Forwarder. +type Receiver interface { + // NotifyInterrupt is called when the Receiver receives an interrupt. + NotifyInterrupt() +} + +// Forwarder is a helper for delivering delayed signal interruptions. +// +// This helps platform implementations with Interrupt semantics. +type Forwarder struct { + // mu protects the below. + mu sync.Mutex + + // dst is the function to be called when NotifyInterrupt() is called. If + // dst is nil, pending will be set instead, causing the next call to + // Enable() to return false. + dst Receiver + pending bool +} + +// Enable attempts to enable interrupt forwarding to r. If f has already +// received an interrupt, Enable does nothing and returns false. Otherwise, +// future calls to f.NotifyInterrupt() cause r.NotifyInterrupt() to be called, +// and Enable returns true. +// +// Usage: +// +// if !f.Enable(r) { +// // There was an interrupt. +// return +// } +// defer f.Disable() +// +// Preconditions: r must not be nil. f must not already be forwarding +// interrupts to a Receiver. +func (f *Forwarder) Enable(r Receiver) bool { + if r == nil { + panic("nil Receiver") + } + f.mu.Lock() + if f.dst != nil { + f.mu.Unlock() + panic(fmt.Sprintf("already forwarding interrupts to %+v", f.dst)) + } + if f.pending { + f.pending = false + f.mu.Unlock() + return false + } + f.dst = r + f.mu.Unlock() + return true +} + +// Disable stops interrupt forwarding. If interrupt forwarding is already +// disabled, Disable is a no-op. +func (f *Forwarder) Disable() { + f.mu.Lock() + f.dst = nil + f.mu.Unlock() +} + +// NotifyInterrupt implements Receiver.NotifyInterrupt. If interrupt forwarding +// is enabled, the configured Receiver will be notified. Otherwise the +// interrupt will be delivered to the next call to Enable. +func (f *Forwarder) NotifyInterrupt() { + f.mu.Lock() + if f.dst != nil { + f.dst.NotifyInterrupt() + } else { + f.pending = true + } + f.mu.Unlock() +} diff --git a/pkg/sentry/platform/interrupt/interrupt_test.go b/pkg/sentry/platform/interrupt/interrupt_test.go new file mode 100644 index 000000000..7c49eeea6 --- /dev/null +++ b/pkg/sentry/platform/interrupt/interrupt_test.go @@ -0,0 +1,99 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package interrupt + +import ( + "testing" +) + +type countingReceiver struct { + interrupts int +} + +// NotifyInterrupt implements Receiver.NotifyInterrupt. +func (r *countingReceiver) NotifyInterrupt() { + r.interrupts++ +} + +func TestSingleInterruptBeforeEnable(t *testing.T) { + var ( + f Forwarder + r countingReceiver + ) + f.NotifyInterrupt() + // The interrupt should cause the first Enable to fail. + if f.Enable(&r) { + f.Disable() + t.Fatalf("Enable: got true, wanted false") + } + // The failing Enable "acknowledges" the interrupt, allowing future Enables + // to succeed. 
+ if !f.Enable(&r) { + t.Fatalf("Enable: got false, wanted true") + } + f.Disable() +} + +func TestMultipleInterruptsBeforeEnable(t *testing.T) { + var ( + f Forwarder + r countingReceiver + ) + f.NotifyInterrupt() + f.NotifyInterrupt() + // The interrupts should cause the first Enable to fail. + if f.Enable(&r) { + f.Disable() + t.Fatalf("Enable: got true, wanted false") + } + // Interrupts are deduplicated while the Forwarder is disabled, so the + // failing Enable "acknowledges" all interrupts, allowing future Enables to + // succeed. + if !f.Enable(&r) { + t.Fatalf("Enable: got false, wanted true") + } + f.Disable() +} + +func TestSingleInterruptAfterEnable(t *testing.T) { + var ( + f Forwarder + r countingReceiver + ) + if !f.Enable(&r) { + t.Fatalf("Enable: got false, wanted true") + } + defer f.Disable() + f.NotifyInterrupt() + if r.interrupts != 1 { + t.Errorf("interrupts: got %d, wanted 1", r.interrupts) + } +} + +func TestMultipleInterruptsAfterEnable(t *testing.T) { + var ( + f Forwarder + r countingReceiver + ) + if !f.Enable(&r) { + t.Fatalf("Enable: got false, wanted true") + } + defer f.Disable() + f.NotifyInterrupt() + f.NotifyInterrupt() + if r.interrupts != 2 { + t.Errorf("interrupts: got %d, wanted 2", r.interrupts) + } +} diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD new file mode 100644 index 000000000..d902e344a --- /dev/null +++ b/pkg/sentry/platform/kvm/BUILD @@ -0,0 +1,90 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +go_template_instance( + name = "host_map_set", + out = "host_map_set.go", + consts = { + "minDegree": "15", + }, + imports = { + "usermem": "gvisor.googlesource.com/gvisor/pkg/sentry/usermem", + }, + package = "kvm", + prefix = "hostMap", + template = "//pkg/segment:generic_set", + types = { + "Key": "usermem.Addr", + "Range": "usermem.AddrRange", + "Value": "uintptr", + "Functions": "hostMapSetFunctions", + }, +) + +go_library( + name = "kvm", + srcs = [ + "address_space.go", + "bluepill.go", + "bluepill_amd64.go", + "bluepill_amd64.s", + "bluepill_amd64_unsafe.go", + "bluepill_fault.go", + "bluepill_unsafe.go", + "context.go", + "host_map.go", + "host_map_set.go", + "kvm.go", + "kvm_amd64.go", + "kvm_amd64_unsafe.go", + "kvm_const.go", + "machine.go", + "machine_amd64.go", + "machine_amd64_unsafe.go", + "machine_unsafe.go", + "physical_map.go", + "virtual_map.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/platform/kvm", + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/cpuid", + "//pkg/log", + "//pkg/sentry/arch", + "//pkg/sentry/platform", + "//pkg/sentry/platform/filemem", + "//pkg/sentry/platform/interrupt", + "//pkg/sentry/platform/procid", + "//pkg/sentry/platform/ring0", + "//pkg/sentry/platform/ring0/pagetables", + "//pkg/sentry/platform/safecopy", + "//pkg/sentry/time", + "//pkg/sentry/usermem", + "//pkg/tmutex", + ], +) + +go_test( + name = "kvm_test", + size = "small", + srcs = [ + "kvm_test.go", + "virtual_map_test.go", + ], + embed = [":kvm"], + tags = [ + "nogotsan", + "requires-kvm", + ], + deps = [ + "//pkg/sentry/arch", + "//pkg/sentry/platform", + "//pkg/sentry/platform/kvm/testutil", + "//pkg/sentry/platform/ring0", + "//pkg/sentry/platform/ring0/pagetables", + "//pkg/sentry/usermem", + ], +) diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go new file mode 100644 index 
000000000..791f038b0 --- /dev/null +++ b/pkg/sentry/platform/kvm/address_space.go @@ -0,0 +1,207 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +import ( + "reflect" + "sync" + "sync/atomic" + + "gvisor.googlesource.com/gvisor/pkg/sentry/platform" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/filemem" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0/pagetables" + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +// addressSpace is a wrapper for PageTables. +type addressSpace struct { + platform.NoAddressSpaceIO + + // filemem is the memory instance. + filemem *filemem.FileMem + + // machine is the underlying machine. + machine *machine + + // pageTables are for this particular address space. + pageTables *pagetables.PageTables + + // dirtySet is the set of dirty vCPUs. + // + // The key is the vCPU, the value is a shared uint32 pointer that + // indicates whether or not the context is clean. A zero here indicates + // that the context should be cleaned prior to re-entry. + dirtySet sync.Map + + // files contains files mapped in the host address space. + files hostMap +} + +// Invalidate interrupts all dirty contexts. +func (as *addressSpace) Invalidate() { + as.dirtySet.Range(func(key, value interface{}) bool { + c := key.(*vCPU) + v := value.(*uint32) + atomic.StoreUint32(v, 0) // Invalidation required. + c.Bounce() // Force a kernel transition. + return true // Keep iterating. + }) +} + +// Touch adds the given vCPU to the dirty list. +func (as *addressSpace) Touch(c *vCPU) *uint32 { + value, ok := as.dirtySet.Load(c) + if !ok { + value, _ = as.dirtySet.LoadOrStore(c, new(uint32)) + } + return value.(*uint32) +} + +func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.AccessType) (inv bool) { + for m.length > 0 { + physical, length, ok := TranslateToPhysical(m.addr) + if !ok { + panic("unable to translate segment") + } + if length > m.length { + length = m.length + } + + // Ensure that this map has physical mappings. If the page does + // not have physical mappings, the KVM module may inject + // spurious exceptions when emulation fails (i.e. it tries to + // emulate because the RIP is pointed at those pages). + as.machine.mapPhysical(physical, length) + + // Install the page table mappings. Note that the ordering is + // important; if the pagetable mappings were installed before + // ensuring the physical pages were available, then some other + // thread could theoretically access them. + prev := as.pageTables.Map(addr, length, true /* user */, at, physical) + inv = inv || prev + m.addr += length + m.length -= length + addr += usermem.Addr(length) + } + + return inv +} + +func (as *addressSpace) mapHostFile(addr usermem.Addr, fd int, fr platform.FileRange, at usermem.AccessType) error { + // Create custom host mappings. 
+ ms, err := as.files.CreateMappings(usermem.AddrRange{ + Start: addr, + End: addr + usermem.Addr(fr.End-fr.Start), + }, at, fd, fr.Start) + if err != nil { + return err + } + + inv := false + for _, m := range ms { + // The host mapped slices are guaranteed to be aligned. + inv = inv || as.mapHost(addr, m, at) + addr += usermem.Addr(m.length) + } + if inv { + as.Invalidate() + } + + return nil +} + +func (as *addressSpace) mapFilemem(addr usermem.Addr, fr platform.FileRange, at usermem.AccessType, precommit bool) error { + // TODO: Lock order at the platform level is not sufficiently + // well-defined to guarantee that the caller (FileMem.MapInto) is not + // holding any locks that FileMem.MapInternal may take. + + // Retrieve mappings for the underlying filemem. Note that the + // permissions here are largely irrelevant, since it corresponds to + // physical memory for the guest. We enforce the given access type + // below, in the guest page tables. + bs, err := as.filemem.MapInternal(fr, usermem.AccessType{ + Read: true, + Write: true, + }) + if err != nil { + return err + } + + // Save the original range for invalidation. + orig := usermem.AddrRange{ + Start: addr, + End: addr + usermem.Addr(fr.End-fr.Start), + } + + inv := false + for !bs.IsEmpty() { + b := bs.Head() + bs = bs.Tail() + // Since fr was page-aligned, b should also be page-aligned. We do the + // lookup in our host page tables for this translation. + s := b.ToSlice() + if precommit { + for i := 0; i < len(s); i += usermem.PageSize { + _ = s[i] // Touch to commit. + } + } + inv = inv || as.mapHost(addr, hostMapEntry{ + addr: reflect.ValueOf(&s[0]).Pointer(), + length: uintptr(len(s)), + }, at) + addr += usermem.Addr(len(s)) + } + if inv { + as.Invalidate() + as.files.DeleteMapping(orig) + } + + return nil +} + +// MapFile implements platform.AddressSpace.MapFile. +func (as *addressSpace) MapFile(addr usermem.Addr, fd int, fr platform.FileRange, at usermem.AccessType, precommit bool) error { + // Create an appropriate mapping. If this is filemem, we don't create + // custom mappings for each in-application mapping. For files however, + // we create distinct mappings for each address space. Unfortunately, + // there's not a better way to manage this here. The file underlying + // this fd can change at any time, so we can't actually index the file + // and share between address space. Oh well. It's all refering to the + // same physical pages, hopefully we don't run out of address space. + if fd != int(as.filemem.File().Fd()) { + // N.B. precommit is ignored for host files. + return as.mapHostFile(addr, fd, fr, at) + } + + return as.mapFilemem(addr, fr, at, precommit) +} + +// Unmap unmaps the given range by calling pagetables.PageTables.Unmap. +func (as *addressSpace) Unmap(addr usermem.Addr, length uint64) { + if prev := as.pageTables.Unmap(addr, uintptr(length)); prev { + as.Invalidate() + as.files.DeleteMapping(usermem.AddrRange{ + Start: addr, + End: addr + usermem.Addr(length), + }) + } +} + +// Release releases the page tables. +func (as *addressSpace) Release() error { + as.Unmap(0, ^uint64(0)) + as.pageTables.Release() + return nil +} diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go new file mode 100644 index 000000000..ecc33d7dd --- /dev/null +++ b/pkg/sentry/platform/kvm/bluepill.go @@ -0,0 +1,41 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +import ( + "fmt" + "reflect" + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/safecopy" +) + +// bluepill enters guest mode. +func bluepill(*vCPU) + +// sighandler is the signal entry point. +func sighandler() + +// savedHandler is a pointer to the previous handler. +// +// This is called by bluepillHandler. +var savedHandler uintptr + +func init() { + // Install the handler. + if err := safecopy.ReplaceSignalHandler(syscall.SIGSEGV, reflect.ValueOf(sighandler).Pointer(), &savedHandler); err != nil { + panic(fmt.Sprintf("Unable to set handler for signal %d: %v", syscall.SIGSEGV, err)) + } +} diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go new file mode 100644 index 000000000..a2baefb7d --- /dev/null +++ b/pkg/sentry/platform/kvm/bluepill_amd64.go @@ -0,0 +1,143 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package kvm + +import ( + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/sentry/arch" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0" +) + +var ( + // bounceSignal is the signal used for bouncing KVM. + // + // We use SIGCHLD because it is not masked by the runtime, and + // it will be ignored properly by other parts of the kernel. + bounceSignal = syscall.SIGCHLD + + // bounceSignalMask has only bounceSignal set. + bounceSignalMask = uint64(1 << (uint64(bounceSignal) - 1)) + + // bounce is the interrupt vector used to return to the kernel. + bounce = uint32(ring0.VirtualizationException) +) + +// redpill on amd64 invokes a syscall with -1. +// +//go:nosplit +func redpill() { + syscall.RawSyscall(^uintptr(0), 0, 0, 0) +} + +// bluepillArchEnter is called during bluepillEnter. 
+// +//go:nosplit +func bluepillArchEnter(context *arch.SignalContext64) (c *vCPU) { + c = vCPUPtr(uintptr(context.Rax)) + regs := c.CPU.Registers() + regs.R8 = context.R8 + regs.R9 = context.R9 + regs.R10 = context.R10 + regs.R11 = context.R11 + regs.R12 = context.R12 + regs.R13 = context.R13 + regs.R14 = context.R14 + regs.R15 = context.R15 + regs.Rdi = context.Rdi + regs.Rsi = context.Rsi + regs.Rbp = context.Rbp + regs.Rbx = context.Rbx + regs.Rdx = context.Rdx + regs.Rax = context.Rax + regs.Rcx = context.Rcx + regs.Rsp = context.Rsp + regs.Rip = context.Rip + regs.Eflags = context.Eflags + regs.Eflags &^= uint64(ring0.KernelFlagsClear) + regs.Eflags |= ring0.KernelFlagsSet + regs.Cs = uint64(ring0.Kcode) + regs.Ds = uint64(ring0.Udata) + regs.Es = uint64(ring0.Udata) + regs.Fs = uint64(ring0.Udata) + regs.Ss = uint64(ring0.Kdata) + + // ring0 uses GS exclusively, so we use GS_base to store the location + // of the floating point address. + // + // The address will be restored directly after running the VCPU, and + // will be saved again prior to halting. We rely on the fact that the + // SaveFloatingPointer/LoadFloatingPoint functions use the most + // efficient mechanism available (including compression) so the state + // size is guaranteed to be less than what's pointed to here. + regs.Gs_base = uint64(context.Fpstate) + return +} + +// bluepillSyscall handles kernel syscalls. +// +//go:nosplit +func bluepillSyscall() { + regs := ring0.Current().Registers() + if regs.Rax != ^uint64(0) { + regs.Rip -= 2 // Rewind. + } + ring0.SaveFloatingPoint(bytePtr(uintptr(regs.Gs_base))) + ring0.Halt() + ring0.LoadFloatingPoint(bytePtr(uintptr(regs.Gs_base))) +} + +// bluepillException handles kernel exceptions. +// +//go:nosplit +func bluepillException(vector ring0.Vector) { + regs := ring0.Current().Registers() + if vector == ring0.Vector(bounce) { + // These should not interrupt kernel execution; point the Rip + // to zero to ensure that we get a reasonable panic when we + // attempt to return. + regs.Rip = 0 + } + ring0.SaveFloatingPoint(bytePtr(uintptr(regs.Gs_base))) + ring0.Halt() + ring0.LoadFloatingPoint(bytePtr(uintptr(regs.Gs_base))) +} + +// bluepillArchExit is called during bluepillEnter. +// +//go:nosplit +func bluepillArchExit(c *vCPU, context *arch.SignalContext64) { + regs := c.CPU.Registers() + context.R8 = regs.R8 + context.R9 = regs.R9 + context.R10 = regs.R10 + context.R11 = regs.R11 + context.R12 = regs.R12 + context.R13 = regs.R13 + context.R14 = regs.R14 + context.R15 = regs.R15 + context.Rdi = regs.Rdi + context.Rsi = regs.Rsi + context.Rbp = regs.Rbp + context.Rbx = regs.Rbx + context.Rdx = regs.Rdx + context.Rax = regs.Rax + context.Rcx = regs.Rcx + context.Rsp = regs.Rsp + context.Rip = regs.Rip + context.Eflags = regs.Eflags +} diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s new file mode 100644 index 000000000..0881bd5f5 --- /dev/null +++ b/pkg/sentry/platform/kvm/bluepill_amd64.s @@ -0,0 +1,87 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// VCPU_CPU is the location of the CPU in the vCPU struct. +// +// This is guaranteed to be zero. +#define VCPU_CPU 0x0 + +// CPU_SELF is the self reference in ring0's percpu. +// +// This is guaranteed to be zero. +#define CPU_SELF 0x0 + +// Context offsets. +// +// Only limited use of the context is done in the assembly stub below, most is +// done in the Go handlers. However, the RIP must be examined. +#define CONTEXT_RAX 0x90 +#define CONTEXT_RIP 0xa8 +#define CONTEXT_FP 0xe0 + +// CLI is the literal byte for the disable interrupts instruction. +// +// This is checked as the source of the fault. +#define CLI $0xfa + +// See bluepill.go. +TEXT ·bluepill(SB),NOSPLIT,$0 +begin: + MOVQ vcpu+0(FP), AX + LEAQ VCPU_CPU(AX), BX + BYTE CLI; +check_vcpu: + MOVQ CPU_SELF(GS), CX + CMPQ BX, CX + JE right_vCPU +wrong_vcpu: + CALL ·redpill(SB) + JMP begin +right_vCPU: + RET + +// sighandler: see bluepill.go for documentation. +// +// The arguments are the following: +// +// DI - The signal number. +// SI - Pointer to siginfo_t structure. +// DX - Pointer to ucontext structure. +// +TEXT ·sighandler(SB),NOSPLIT,$0 + // Check if the signal is from the kernel. + MOVQ $0x80, CX + CMPL CX, 0x8(SI) + JNE fallback + + // Check if RIP is disable interrupts. + MOVQ CONTEXT_RIP(DX), CX + CMPQ CX, $0x0 + JE fallback + CMPB 0(CX), CLI + JNE fallback + + // Call the bluepillHandler. + PUSHQ DX // First argument (context). + CALL ·bluepillHandler(SB) // Call the handler. + POPQ DX // Discard the argument. + RET + +fallback: + // Jump to the previous signal handler. + XORQ CX, CX + MOVQ ·savedHandler(SB), AX + JMP AX diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go new file mode 100644 index 000000000..61ca61dcb --- /dev/null +++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go @@ -0,0 +1,28 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package kvm + +import ( + "unsafe" + + "gvisor.googlesource.com/gvisor/pkg/sentry/arch" +) + +// bluepillArchContext returns the arch-specific context. +func bluepillArchContext(context unsafe.Pointer) *arch.SignalContext64 { + return &((*arch.UContext64)(context).MContext) +} diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go new file mode 100644 index 000000000..7c8c7bc37 --- /dev/null +++ b/pkg/sentry/platform/kvm/bluepill_fault.go @@ -0,0 +1,127 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +import ( + "sync/atomic" + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +const ( + // faultBlockSize is the size used for servicing memory faults. + // + // This should be large enough to avoid frequent faults and avoid using + // all available KVM slots (~512), but small enough that KVM does not + // complain about slot sizes (~4GB). See handleBluepillFault for how + // this block is used. + faultBlockSize = 2 << 30 + + // faultBlockMask is the mask for the fault blocks. + // + // This must be typed to avoid overflow complaints (ugh). + faultBlockMask = ^uintptr(faultBlockSize - 1) +) + +// yield yields the CPU. +// +//go:nosplit +func yield() { + syscall.RawSyscall(syscall.SYS_SCHED_YIELD, 0, 0, 0) +} + +// calculateBluepillFault calculates the fault address range. +// +//go:nosplit +func calculateBluepillFault(m *machine, physical uintptr) (virtualStart, physicalStart, length uintptr, ok bool) { + alignedPhysical := physical &^ uintptr(usermem.PageSize-1) + for _, pr := range physicalRegions { + end := pr.physical + pr.length + if physical < pr.physical || physical >= end { + continue + } + + // Adjust the block to match our size. + physicalStart = alignedPhysical & faultBlockMask + if physicalStart < pr.physical { + // Bound the starting point to the start of the region. + physicalStart = pr.physical + } + virtualStart = pr.virtual + (physicalStart - pr.physical) + physicalEnd := physicalStart + faultBlockSize + if physicalEnd > end { + physicalEnd = end + } + length = physicalEnd - physicalStart + return virtualStart, physicalStart, length, true + } + + return 0, 0, 0, false +} + +// handleBluepillFault handles a physical fault. +// +// The corresponding virtual address is returned. This may throw on error. +// +//go:nosplit +func handleBluepillFault(m *machine, physical uintptr) (uintptr, bool) { + // Paging fault: we need to map the underlying physical pages for this + // fault. This all has to be done in this function because we're in a + // signal handler context. (We can't call any functions that might + // split the stack.) + virtualStart, physicalStart, length, ok := calculateBluepillFault(m, physical) + if !ok { + return 0, false + } + + // Set the KVM slot. + // + // First, we need to acquire the exclusive right to set a slot. See + // machine.nextSlot for information about the protocol. + slot := atomic.SwapUint32(&m.nextSlot, ^uint32(0)) + for slot == ^uint32(0) { + yield() // Race with another call. + slot = atomic.SwapUint32(&m.nextSlot, ^uint32(0)) + } + errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart) + if errno == 0 { + // Successfully added region; we can increment nextSlot and + // allow another set to proceed here. + atomic.StoreUint32(&m.nextSlot, slot+1) + return virtualStart + (physical - physicalStart), true + } + + // Release our slot (still available). + atomic.StoreUint32(&m.nextSlot, slot) + + switch errno { + case syscall.EEXIST: + // The region already exists. It's possible that we raced with + // another vCPU here. 
We just revert nextSlot and return true, + // because this must have been satisfied by some other vCPU. + return virtualStart + (physical - physicalStart), true + case syscall.EINVAL: + throw("set memory region failed; out of slots") + case syscall.ENOMEM: + throw("set memory region failed: out of memory") + case syscall.EFAULT: + throw("set memory region failed: invalid physical range") + default: + throw("set memory region failed: unknown reason") + } + + panic("unreachable") +} diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go new file mode 100644 index 000000000..85703ff18 --- /dev/null +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -0,0 +1,175 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +import ( + "sync/atomic" + "syscall" + "unsafe" +) + +//go:linkname throw runtime.throw +func throw(string) + +// vCPUPtr returns a CPU for the given address. +// +//go:nosplit +func vCPUPtr(addr uintptr) *vCPU { + return (*vCPU)(unsafe.Pointer(addr)) +} + +// bytePtr returns a bytePtr for the given address. +// +//go:nosplit +func bytePtr(addr uintptr) *byte { + return (*byte)(unsafe.Pointer(addr)) +} + +// bluepillHandler is called from the signal stub. +// +// The world may be stopped while this is executing, and it executes on the +// signal stack. It should only execute raw system calls and functions that are +// explicitly marked go:nosplit. +// +//go:nosplit +func bluepillHandler(context unsafe.Pointer) { + // Sanitize the registers; interrupts must always be disabled. + c := bluepillArchEnter(bluepillArchContext(context)) + + // Increment the number of switches. + atomic.AddUint32(&c.switches, 1) + + // Store vCPUGuest. + // + // This is fine even if we're not in guest mode yet. In this signal + // handler, we'll already have all the relevant signals blocked, so an + // interrupt is only deliverable when we actually execute the KVM_RUN. + // + // The state will be returned to vCPUReady by Phase2. + if state := atomic.SwapUintptr(&c.state, vCPUGuest); state != vCPUReady { + throw("vCPU not in ready state") + } + + for { + _, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(c.fd), _KVM_RUN, 0) + if errno == syscall.EINTR { + // First, we process whatever pending signal + // interrupted KVM. Since we're in a signal handler + // currently, all signals are masked and the signal + // must have been delivered directly to this thread. + sig, _, errno := syscall.RawSyscall6( + syscall.SYS_RT_SIGTIMEDWAIT, + uintptr(unsafe.Pointer(&bounceSignalMask)), + 0, // siginfo. + 0, // timeout. + 8, // sigset size. + 0, 0) + if errno != 0 { + throw("error waiting for pending signal") + } + if sig != uintptr(bounceSignal) { + throw("unexpected signal") + } + + // Check whether the current state of the vCPU is ready + // for interrupt injection. Because we don't have a + // PIC, we can't inject an interrupt while they are + // masked. We need to request a window if it's not + // ready. 
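As an illustrative expansion of the fault-block arithmetic in calculateBluepillFault above (the physical address here is made up), the typed mask simply rounds down to a 2GB slot boundary:

        physical := uintptr(0x123456789)
        blockStart := physical & faultBlockMask // faultBlockSize is 2<<30, so this is 0x100000000.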
+ if c.runData.readyForInterruptInjection == 0 { + c.runData.requestInterruptWindow = 1 + continue // Rerun vCPU. + } else { + // Force injection below; the vCPU is ready. + c.runData.exitReason = _KVM_EXIT_IRQ_WINDOW_OPEN + } + } else if errno != 0 { + throw("run failed") + } + + switch c.runData.exitReason { + case _KVM_EXIT_EXCEPTION: + throw("exception") + case _KVM_EXIT_IO: + throw("I/O") + case _KVM_EXIT_INTERNAL_ERROR: + throw("internal error") + case _KVM_EXIT_HYPERCALL: + throw("hypercall") + case _KVM_EXIT_DEBUG: + throw("debug") + case _KVM_EXIT_HLT: + // Copy out registers. + bluepillArchExit(c, bluepillArchContext(context)) + + // Notify any waiters. + switch state := atomic.SwapUintptr(&c.state, vCPUReady); state { + case vCPUGuest: + case vCPUWaiter: + c.notify() // Safe from handler. + default: + throw("invalid state") + } + return + case _KVM_EXIT_MMIO: + // Increment the fault count. + atomic.AddUint32(&c.faults, 1) + + // For MMIO, the physical address is the first data item. + virtual, ok := handleBluepillFault(c.machine, uintptr(c.runData.data[0])) + if !ok { + throw("physical address not valid") + } + + // We now need to fill in the data appropriately. KVM + // expects us to provide the result of the given MMIO + // operation in the runData struct. This is safe + // because, if a fault occurs here, the same fault + // would have occurred in guest mode. The kernel should + // not create invalid page table mappings. + data := (*[8]byte)(unsafe.Pointer(&c.runData.data[1])) + length := (uintptr)((uint32)(c.runData.data[2])) + write := (uint8)((c.runData.data[2] >> 32 & 0xff)) != 0 + for i := uintptr(0); i < length; i++ { + b := bytePtr(uintptr(virtual) + i) + if write { + // Write to the given address. + *b = data[i] + } else { + // Read from the given address. + data[i] = *b + } + } + case _KVM_EXIT_IRQ_WINDOW_OPEN: + // Interrupt: we must have requested an interrupt + // window; set the interrupt line. + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_INTERRUPT, + uintptr(unsafe.Pointer(&bounce))); errno != 0 { + throw("interrupt injection failed") + } + // Clear previous injection request. + c.runData.requestInterruptWindow = 0 + case _KVM_EXIT_SHUTDOWN: + throw("shutdown") + case _KVM_EXIT_FAIL_ENTRY: + throw("entry failed") + default: + throw("unknown failure") + } + } +} diff --git a/pkg/sentry/platform/kvm/context.go b/pkg/sentry/platform/kvm/context.go new file mode 100644 index 000000000..fd04a2c47 --- /dev/null +++ b/pkg/sentry/platform/kvm/context.go @@ -0,0 +1,81 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
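For the MMIO exit handled above, a worked decode of the packed length/write word using a made-up exit value; the layout (length in the low 32 bits, write flag in bits 32-39) is exactly what the handler reads out of runData.data[2]:

        data2 := uint64(4) | uint64(1)<<32  // A 4-byte write.
        length := uintptr(uint32(data2))    // 4
        write := uint8(data2>>32&0xff) != 0 // true
        _, _ = length, write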
+ +package kvm + +import ( + "sync/atomic" + + "gvisor.googlesource.com/gvisor/pkg/sentry/arch" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/interrupt" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0" + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +// context is an implementation of the platform context. +// +// This is a thin wrapper around the machine. +type context struct { + // machine is the parent machine, and is immutable. + machine *machine + + // interrupt is the interrupt context. + interrupt interrupt.Forwarder +} + +// Switch runs the provided context in the given address space. +func (c *context) Switch(as platform.AddressSpace, ac arch.Context, _ int32) (*arch.SignalInfo, usermem.AccessType, error) { + // Extract data. + localAS := as.(*addressSpace) + regs := &ac.StateData().Regs + fp := (*byte)(ac.FloatingPointData()) + + // Grab a vCPU. + cpu, err := c.machine.Get() + if err != nil { + return nil, usermem.NoAccess, err + } + + // Enable interrupts (i.e. calls to vCPU.Notify). + if !c.interrupt.Enable(cpu) { + c.machine.Put(cpu) // Already preempted. + return nil, usermem.NoAccess, platform.ErrContextInterrupt + } + + // Mark the address space as dirty. + flags := ring0.Flags(0) + dirty := localAS.Touch(cpu) + if v := atomic.SwapUint32(dirty, 1); v == 0 { + flags |= ring0.FlagFlush + } + if ac.FullRestore() { + flags |= ring0.FlagFull + } + + // Take the blue pill. + si, at, err := cpu.SwitchToUser(regs, fp, localAS.pageTables, flags) + + // Release resources. + c.machine.Put(cpu) + + // All done. + c.interrupt.Disable() + return si, at, err +} + +// Interrupt interrupts the running context. +func (c *context) Interrupt() { + c.interrupt.NotifyInterrupt() +} diff --git a/pkg/sentry/platform/kvm/host_map.go b/pkg/sentry/platform/kvm/host_map.go new file mode 100644 index 000000000..357f8c92e --- /dev/null +++ b/pkg/sentry/platform/kvm/host_map.go @@ -0,0 +1,168 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +import ( + "fmt" + "sync" + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +type hostMap struct { + // mu protects below. + mu sync.RWMutex + + // set contains host mappings. + set hostMapSet +} + +type hostMapEntry struct { + addr uintptr + length uintptr +} + +func (hm *hostMap) forEachEntry(r usermem.AddrRange, fn func(offset uint64, m hostMapEntry)) { + for seg := hm.set.FindSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + length := uintptr(seg.Range().Length()) + segOffset := uint64(0) // Adjusted below. 
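A worked example of the clipping applied just below in forEachEntry, with made-up ranges: a request r = [0x4000, 0x8000) against a segment covering [0x5000, 0x9000) is trimmed at the tail and offset at the head:

        length := uintptr(0x9000 - 0x5000)   // Full segment: 0x4000.
        length -= uintptr(0x9000 - 0x8000)   // Tail clipped at r.End: 0x3000 remain.
        segOffset := uint64(0x5000 - 0x4000) // The entry's data begins 0x1000 into r.
        _, _ = length, segOffset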
+ if seg.End() > r.End { + length -= uintptr(seg.End() - r.End) + } + if seg.Start() < r.Start { + length -= uintptr(r.Start - seg.Start()) + } else { + segOffset = uint64(seg.Start() - r.Start) + } + fn(segOffset, hostMapEntry{ + addr: seg.Value(), + length: length, + }) + } +} + +func (hm *hostMap) createMappings(r usermem.AddrRange, at usermem.AccessType, fd int, offset uint64) (ms []hostMapEntry, err error) { + // Replace any existing mappings. + hm.forEachEntry(r, func(segOffset uint64, m hostMapEntry) { + _, _, errno := syscall.RawSyscall6( + syscall.SYS_MMAP, + m.addr, + m.length, + uintptr(at.Prot()), + syscall.MAP_FIXED|syscall.MAP_SHARED, + uintptr(fd), + uintptr(offset+segOffset)) + if errno != 0 && err == nil { + err = errno + } + }) + if err != nil { + return nil, err + } + + // Add in necessary new mappings. + for gap := hm.set.FindGap(r.Start); gap.Ok() && gap.Start() < r.End; { + length := uintptr(gap.Range().Length()) + gapOffset := uint64(0) // Adjusted below. + if gap.End() > r.End { + length -= uintptr(gap.End() - r.End) + } + if gap.Start() < r.Start { + length -= uintptr(r.Start - gap.Start()) + } else { + gapOffset = uint64(gap.Start() - r.Start) + } + + // Map the host file memory. + hostAddr, _, errno := syscall.RawSyscall6( + syscall.SYS_MMAP, + 0, + length, + uintptr(at.Prot()), + syscall.MAP_SHARED, + uintptr(fd), + uintptr(offset+gapOffset)) + if errno != 0 { + return nil, errno + } + + // Insert into the host set and move to the next gap. + gap = hm.set.Insert(gap, gap.Range().Intersect(r), hostAddr).NextGap() + } + + // Collect all slices. + hm.forEachEntry(r, func(_ uint64, m hostMapEntry) { + ms = append(ms, m) + }) + + return ms, nil +} + +// CreateMappings creates a new set of host mapping entries. +func (hm *hostMap) CreateMappings(r usermem.AddrRange, at usermem.AccessType, fd int, offset uint64) (ms []hostMapEntry, err error) { + hm.mu.Lock() + ms, err = hm.createMappings(r, at, fd, offset) + hm.mu.Unlock() + return +} + +func (hm *hostMap) deleteMapping(r usermem.AddrRange) { + // Remove all the existing mappings. + hm.forEachEntry(r, func(_ uint64, m hostMapEntry) { + _, _, errno := syscall.RawSyscall( + syscall.SYS_MUNMAP, + m.addr, + m.length, + 0) + if errno != 0 { + // Should never happen. + panic(fmt.Sprintf("unmap error: %v", errno)) + } + }) + + // Knock the range out. + hm.set.RemoveRange(r) +} + +// DeleteMapping deletes the given range. +func (hm *hostMap) DeleteMapping(r usermem.AddrRange) { + hm.mu.Lock() + hm.deleteMapping(r) + hm.mu.Unlock() +} + +// hostMapSetFunctions is used in the implementation of mapSet. +type hostMapSetFunctions struct{} + +func (hostMapSetFunctions) MinKey() usermem.Addr { return 0 } +func (hostMapSetFunctions) MaxKey() usermem.Addr { return ^usermem.Addr(0) } +func (hostMapSetFunctions) ClearValue(val *uintptr) { *val = 0 } + +func (hostMapSetFunctions) Merge(r1 usermem.AddrRange, addr1 uintptr, r2 usermem.AddrRange, addr2 uintptr) (uintptr, bool) { + if addr1+uintptr(r1.Length()) != addr2 { + return 0, false + } + + // Since the two regions are contiguous in both the key space and the + // value space, we can just store a single segment with the first host + // virtual address; the logic above operates based on the size of the + // segments. 
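A small illustration of the contiguity condition above, with made-up addresses; two adjacent guest ranges merge only if their host backings are also adjacent:

        left := usermem.AddrRange{Start: 0xA000, End: 0xB000}  // Host 0x7f0000000000.
        right := usermem.AddrRange{Start: 0xB000, End: 0xC000} // Host 0x7f0000001000.
        _, ok := hostMapSetFunctions{}.Merge(left, 0x7f0000000000, right, 0x7f0000001000)
        // ok is true; with any gap or overlap on the host side the segments stay separate.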
+ return addr1, true +} + +func (hostMapSetFunctions) Split(r usermem.AddrRange, hostAddr uintptr, split usermem.Addr) (uintptr, uintptr) { + return hostAddr, hostAddr + uintptr(split-r.Start) +} diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go new file mode 100644 index 000000000..31928c9f0 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm.go @@ -0,0 +1,149 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package kvm provides a kvm-based implementation of the platform interface. +package kvm + +import ( + "fmt" + "runtime" + "sync" + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/cpuid" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/filemem" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0" + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +// KVM represents a lightweight VM context. +type KVM struct { + platform.NoCPUPreemptionDetection + + // filemem is our memory source. + *filemem.FileMem + + // machine is the backing VM. + machine *machine +} + +var ( + globalOnce sync.Once + globalErr error +) + +// New returns a new KVM-based implementation of the platform interface. +func New() (*KVM, error) { + // Allocate physical memory for the vCPUs. + fm, err := filemem.New("kvm-memory") + if err != nil { + return nil, err + } + + // Try opening KVM. + fd, err := syscall.Open("/dev/kvm", syscall.O_RDWR, 0) + if err != nil { + return nil, fmt.Errorf("opening /dev/kvm: %v", err) + } + defer syscall.Close(fd) + + // Ensure global initialization is done. + globalOnce.Do(func() { + physicalInit() + globalErr = updateSystemValues(fd) + ring0.Init(cpuid.HostFeatureSet()) + }) + if globalErr != nil { + return nil, err + } + + // Create a new VM fd. + vm, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(fd), _KVM_CREATE_VM, 0) + if errno != 0 { + return nil, fmt.Errorf("creating VM: %v", errno) + } + + // Create a VM context. + machine, err := newMachine(int(vm), runtime.NumCPU()) + if err != nil { + return nil, err + } + + // All set. + return &KVM{ + FileMem: fm, + machine: machine, + }, nil +} + +// SupportsAddressSpaceIO implements platform.Platform.SupportsAddressSpaceIO. +func (*KVM) SupportsAddressSpaceIO() bool { + return false +} + +// CooperativelySchedulesAddressSpace implements platform.Platform.CooperativelySchedulesAddressSpace. +func (*KVM) CooperativelySchedulesAddressSpace() bool { + return false +} + +// MapUnit implements platform.Platform.MapUnit. +func (*KVM) MapUnit() uint64 { + // We greedily creates PTEs in MapFile, so extremely large mappings can + // be expensive. Not _that_ expensive since we allow super pages, but + // even though can get out of hand if you're creating multi-terabyte + // mappings. For this reason, we limit mappings to an arbitrary 16MB. + return 16 << 20 +} + +// MinUserAddress returns the lowest available address. 
+func (*KVM) MinUserAddress() usermem.Addr { + return usermem.PageSize +} + +// MaxUserAddress returns the first address that may not be used. +func (*KVM) MaxUserAddress() usermem.Addr { + return usermem.Addr(ring0.MaximumUserAddress) +} + +// NewAddressSpace returns a new pagetable root. +func (k *KVM) NewAddressSpace(_ interface{}) (platform.AddressSpace, <-chan struct{}, error) { + // Allocate page tables and install system mappings. + pageTables := k.machine.kernel.PageTables.New() + applyPhysicalRegions(func(pr physicalRegion) bool { + // Map the kernel in the upper half. + kernelVirtual := usermem.Addr(ring0.KernelStartAddress | pr.virtual) + pageTables.Map(kernelVirtual, pr.length, false /* kernel */, usermem.AnyAccess, pr.physical) + return true // Keep iterating. + }) + + // Return the new address space. + return &addressSpace{ + filemem: k.FileMem, + machine: k.machine, + pageTables: pageTables, + }, nil, nil +} + +// NewContext returns an interruptible context. +func (k *KVM) NewContext() platform.Context { + return &context{ + machine: k.machine, + } +} + +// Memory returns the platform memory used to do allocations. +func (k *KVM) Memory() platform.Memory { + return k.FileMem +} diff --git a/pkg/sentry/platform/kvm/kvm_amd64.go b/pkg/sentry/platform/kvm/kvm_amd64.go new file mode 100644 index 000000000..3d56ed895 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_amd64.go @@ -0,0 +1,213 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package kvm + +import ( + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0" +) + +// userMemoryRegion is a region of physical memory. +// +// This mirrors kvm_memory_region. +type userMemoryRegion struct { + slot uint32 + flags uint32 + guestPhysAddr uint64 + memorySize uint64 + userspaceAddr uint64 +} + +// userRegs represents KVM user registers. +// +// This mirrors kvm_regs. +type userRegs struct { + RAX uint64 + RBX uint64 + RCX uint64 + RDX uint64 + RSI uint64 + RDI uint64 + RSP uint64 + RBP uint64 + R8 uint64 + R9 uint64 + R10 uint64 + R11 uint64 + R12 uint64 + R13 uint64 + R14 uint64 + R15 uint64 + RIP uint64 + RFLAGS uint64 +} + +// systemRegs represents KVM system registers. +// +// This mirrors kvm_sregs. +type systemRegs struct { + CS segment + DS segment + ES segment + FS segment + GS segment + SS segment + TR segment + LDT segment + GDT descriptor + IDT descriptor + CR0 uint64 + CR2 uint64 + CR3 uint64 + CR4 uint64 + CR8 uint64 + EFER uint64 + apicBase uint64 + interruptBitmap [(_KVM_NR_INTERRUPTS + 63) / 64]uint64 +} + +// segment is the expanded form of a segment register. +// +// This mirrors kvm_segment. +type segment struct { + base uint64 + limit uint32 + selector uint16 + typ uint8 + present uint8 + DPL uint8 + DB uint8 + S uint8 + L uint8 + G uint8 + AVL uint8 + unusable uint8 + _ uint8 +} + +// Clear clears the segment and marks it unusable. +func (s *segment) Clear() { + *s = segment{unusable: 1} +} + +// selector is a segment selector. 
+type selector uint16 + +// tobool is a simple helper. +func tobool(x ring0.SegmentDescriptorFlags) uint8 { + if x != 0 { + return 1 + } + return 0 +} + +// Load loads the segment described by d into the segment s. +// +// The argument sel is recorded as the segment selector index. +func (s *segment) Load(d *ring0.SegmentDescriptor, sel ring0.Selector) { + flag := d.Flags() + if flag&ring0.SegmentDescriptorPresent == 0 { + s.Clear() + return + } + s.base = uint64(d.Base()) + s.limit = d.Limit() + s.typ = uint8((flag>>8)&0xF) | 1 + s.S = tobool(flag & ring0.SegmentDescriptorSystem) + s.DPL = uint8(d.DPL()) + s.present = tobool(flag & ring0.SegmentDescriptorPresent) + s.AVL = tobool(flag & ring0.SegmentDescriptorAVL) + s.L = tobool(flag & ring0.SegmentDescriptorLong) + s.DB = tobool(flag & ring0.SegmentDescriptorDB) + s.G = tobool(flag & ring0.SegmentDescriptorG) + if s.L != 0 { + s.limit = 0xffffffff + } + s.unusable = 0 + s.selector = uint16(sel) +} + +// descriptor describes a region of physical memory. +// +// It corresponds to the pseudo-descriptor used in the x86 LGDT and LIDT +// instructions, and mirrors kvm_dtable. +type descriptor struct { + base uint64 + limit uint16 + _ [3]uint16 +} + +// modelControlRegister is an MSR entry. +// +// This mirrors kvm_msr_entry. +type modelControlRegister struct { + index uint32 + _ uint32 + data uint64 +} + +// modelControlRegisers is a collection of MSRs. +// +// This mirrors kvm_msrs. +type modelControlRegisters struct { + nmsrs uint32 + _ uint32 + entries [16]modelControlRegister +} + +// runData is the run structure. This may be mapped for synchronous register +// access (although that doesn't appear to be supported by my kernel at least). +// +// This mirrors kvm_run. +type runData struct { + requestInterruptWindow uint8 + _ [7]uint8 + + exitReason uint32 + readyForInterruptInjection uint8 + ifFlag uint8 + _ [2]uint8 + + cr8 uint64 + apicBase uint64 + + // This is the union data for exits. Interpretation depends entirely on + // the exitReason above (see vCPU code for more information). + data [32]uint64 +} + +// cpuidEntry is a single CPUID entry. +// +// This mirrors kvm_cpuid_entry2. +type cpuidEntry struct { + function uint32 + index uint32 + flags uint32 + eax uint32 + ebx uint32 + ecx uint32 + edx uint32 + _ [3]uint32 +} + +// cpuidEntries is a collection of CPUID entries. +// +// This mirrors kvm_cpuid2. +type cpuidEntries struct { + nr uint32 + _ uint32 + entries [_KVM_NR_CPUID_ENTRIES]cpuidEntry +} diff --git a/pkg/sentry/platform/kvm/kvm_amd64_unsafe.go b/pkg/sentry/platform/kvm/kvm_amd64_unsafe.go new file mode 100644 index 000000000..389412d87 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_amd64_unsafe.go @@ -0,0 +1,93 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// +build amd64 + +package kvm + +import ( + "fmt" + "syscall" + "unsafe" + + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0/pagetables" +) + +var ( + runDataSize int + hasGuestPCID bool + hasGuestINVPCID bool + pagetablesOpts pagetables.Opts + cpuidSupported = cpuidEntries{nr: _KVM_NR_CPUID_ENTRIES} +) + +func updateSystemValues(fd int) error { + // Extract the mmap size. + sz, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(fd), _KVM_GET_VCPU_MMAP_SIZE, 0) + if errno != 0 { + return fmt.Errorf("getting VCPU mmap size: %v", errno) + } + + // Save the data. + runDataSize = int(sz) + + // Must do the dance to figure out the number of entries. + _, _, errno = syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(fd), + _KVM_GET_SUPPORTED_CPUID, + uintptr(unsafe.Pointer(&cpuidSupported))) + if errno != 0 && errno != syscall.ENOMEM { + // Some other error occurred. + return fmt.Errorf("getting supported CPUID: %v", errno) + } + + // The number should now be correct. + _, _, errno = syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(fd), + _KVM_GET_SUPPORTED_CPUID, + uintptr(unsafe.Pointer(&cpuidSupported))) + if errno != 0 { + // Didn't work with the right number. + return fmt.Errorf("getting supported CPUID (2nd attempt): %v", errno) + } + + // Calculate whether guestPCID is supported. + // + // FIXME: These should go through the much more pleasant + // cpuid package interfaces, once a way to accept raw kvm CPUID entries + // is plumbed (or some rough equivalent). + for i := 0; i < int(cpuidSupported.nr); i++ { + entry := cpuidSupported.entries[i] + if entry.function == 1 && entry.index == 0 && entry.ecx&(1<<17) != 0 { + hasGuestPCID = true // Found matching PCID in guest feature set. + } + if entry.function == 7 && entry.index == 0 && entry.ebx&(1<<10) != 0 { + hasGuestINVPCID = true // Found matching INVPCID in guest feature set. + } + } + + // A basic sanity check: ensure that we don't attempt to + // invpcid if guest PCIDs are not supported; it's not clear + // what the semantics of this would be (or why some CPU or + // hypervisor would export this particular combination). + hasGuestINVPCID = hasGuestPCID && hasGuestINVPCID + + // Set the pagetables to use PCID if it's available. + pagetablesOpts.EnablePCID = hasGuestPCID + + // Success. + return nil +} diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go new file mode 100644 index 000000000..0ec6a4a00 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_const.go @@ -0,0 +1,56 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +// KVM ioctls. +// +// Only the ioctls we need in Go appear here; some additional ioctls are used +// within the assembly stubs (KVM_INTERRUPT, etc.). 
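As a sanity check on the constants below (illustrative only), the 0x4xxxaexx values follow the standard Linux _IOC encoding, dir<<30 | size<<16 | type<<8 | nr, with the KVM ioctl type 0xae:

        const kvmInterrupt = 1<<30 | 4<<16 | 0xae<<8 | 0x86 // Write direction, 4-byte argument: 0x4004ae86.
        const kvmCreateVM = 0xae<<8 | 0x01                  // No payload (_IO style): 0xae01.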
+const ( + _KVM_CREATE_VM = 0xae01 + _KVM_GET_VCPU_MMAP_SIZE = 0xae04 + _KVM_CREATE_VCPU = 0xae41 + _KVM_SET_TSS_ADDR = 0xae47 + _KVM_RUN = 0xae80 + _KVM_INTERRUPT = 0x4004ae86 + _KVM_SET_MSRS = 0x4008ae89 + _KVM_SET_USER_MEMORY_REGION = 0x4020ae46 + _KVM_SET_REGS = 0x4090ae82 + _KVM_SET_SREGS = 0x4138ae84 + _KVM_GET_SUPPORTED_CPUID = 0xc008ae05 + _KVM_SET_CPUID2 = 0x4008ae90 + _KVM_SET_SIGNAL_MASK = 0x4004ae8b +) + +// KVM exit reasons. +const ( + _KVM_EXIT_EXCEPTION = 0x1 + _KVM_EXIT_IO = 0x2 + _KVM_EXIT_HYPERCALL = 0x3 + _KVM_EXIT_DEBUG = 0x4 + _KVM_EXIT_HLT = 0x5 + _KVM_EXIT_MMIO = 0x6 + _KVM_EXIT_IRQ_WINDOW_OPEN = 0x7 + _KVM_EXIT_SHUTDOWN = 0x8 + _KVM_EXIT_FAIL_ENTRY = 0x9 + _KVM_EXIT_INTERNAL_ERROR = 0x11 +) + +// KVM limits. +const ( + _KVM_NR_VCPUS = 0x100 + _KVM_NR_INTERRUPTS = 0x100 + _KVM_NR_CPUID_ENTRIES = 0x100 +) diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go new file mode 100644 index 000000000..61cfdd8fd --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_test.go @@ -0,0 +1,415 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +import ( + "math/rand" + "reflect" + "syscall" + "testing" + "time" + + "gvisor.googlesource.com/gvisor/pkg/sentry/arch" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/kvm/testutil" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0/pagetables" + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +var dummyFPState = (*byte)(arch.NewFloatingPointData()) + +type testHarness interface { + Errorf(format string, args ...interface{}) + Fatalf(format string, args ...interface{}) +} + +func kvmTest(t testHarness, setup func(*KVM), fn func(*vCPU) bool) { + // Create the machine. + k, err := New() + if err != nil { + t.Fatalf("error creating KVM instance: %v", err) + } + defer k.machine.Destroy() + defer k.FileMem.Destroy() + + // Call additional setup. + if setup != nil { + setup(k) + } + + var c *vCPU // For recovery. + defer func() { + redpill() + if c != nil { + k.machine.Put(c) + } + }() + for { + c, err = k.machine.Get() + if err != nil { + t.Fatalf("error getting vCPU: %v", err) + } + if !fn(c) { + break + } + + // We put the vCPU here and clear the value so that the + // deferred recovery will not re-put it above. + k.machine.Put(c) + c = nil + } +} + +func bluepillTest(t testHarness, fn func(*vCPU)) { + kvmTest(t, nil, func(c *vCPU) bool { + bluepill(c) + fn(c) + return false + }) +} + +func TestKernelSyscall(t *testing.T) { + bluepillTest(t, func(c *vCPU) { + redpill() // Leave guest mode. + if got := c.State(); got != vCPUReady { + t.Errorf("vCPU not in ready state: got %v", got) + } + }) +} + +func hostFault() { + defer func() { + recover() + }() + var foo *int + *foo = 0 +} + +func TestKernelFault(t *testing.T) { + hostFault() // Ensure recovery works. 
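hostFault above leans on the Go runtime turning a nil dereference into a recoverable run-time panic, which is why the deferred recover keeps the test alive; the same pattern in isolation:

        defer func() {
                _ = recover() // runtime error: invalid memory address or nil pointer dereference.
        }()
        var p *int
        *p = 0 // Panics; swallowed above, so the caller simply continues.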
+ bluepillTest(t, func(c *vCPU) { + hostFault() + if got := c.State(); got != vCPUReady { + t.Errorf("vCPU not in ready state: got %v", got) + } + }) +} + +func TestKernelFloatingPoint(t *testing.T) { + bluepillTest(t, func(c *vCPU) { + if !testutil.FloatingPointWorks() { + t.Errorf("floating point does not work, and it should!") + } + }) +} + +func applicationTest(t testHarness, useHostMappings bool, target func(), fn func(*vCPU, *syscall.PtraceRegs, *pagetables.PageTables) bool) { + // Initialize registers & page tables. + var ( + regs syscall.PtraceRegs + pt *pagetables.PageTables + ) + testutil.SetTestTarget(®s, target) + defer func() { + if pt != nil { + pt.Release() + } + }() + + kvmTest(t, func(k *KVM) { + // Create new page tables. + as, _, err := k.NewAddressSpace(nil /* invalidator */) + if err != nil { + t.Fatalf("can't create new address space: %v", err) + } + pt = as.(*addressSpace).pageTables + + if useHostMappings { + // Apply the physical mappings to these page tables. + // (This is normally dangerous, since they point to + // physical pages that may not exist. This shouldn't be + // done for regular user code, but is fine for test + // purposes.) + applyPhysicalRegions(func(pr physicalRegion) bool { + pt.Map(usermem.Addr(pr.virtual), pr.length, true /* user */, usermem.AnyAccess, pr.physical) + return true // Keep iterating. + }) + } + }, func(c *vCPU) bool { + // Invoke the function with the extra data. + return fn(c, ®s, pt) + }) +} + +func TestApplicationSyscall(t *testing.T) { + applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + if _, _, err := c.SwitchToUser(regs, dummyFPState, pt, ring0.FlagFull); err != nil { + t.Errorf("application syscall with full restore failed: %v", err) + } + return false + }) + applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + if _, _, err := c.SwitchToUser(regs, dummyFPState, pt, 0); err != nil { + t.Errorf("application syscall with partial restore failed: %v", err) + } + return false + }) +} + +func TestApplicationFault(t *testing.T) { + applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + testutil.SetTouchTarget(regs, nil) // Cause fault. + if si, _, err := c.SwitchToUser(regs, dummyFPState, pt, ring0.FlagFull); err != platform.ErrContextSignal || (si != nil && si.Signo != int32(syscall.SIGSEGV)) { + t.Errorf("application fault with full restore got (%v, %v), expected (%v, SIGSEGV)", err, si, platform.ErrContextSignal) + } + return false + }) + applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + testutil.SetTouchTarget(regs, nil) // Cause fault. + if si, _, err := c.SwitchToUser(regs, dummyFPState, pt, 0); err != platform.ErrContextSignal || (si != nil && si.Signo != int32(syscall.SIGSEGV)) { + t.Errorf("application fault with partial restore got (%v, %v), expected (%v, SIGSEGV)", err, si, platform.ErrContextSignal) + } + return false + }) +} + +func TestRegistersSyscall(t *testing.T) { + applicationTest(t, true, testutil.TwiddleRegsSyscall, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + testutil.SetTestRegs(regs) // Fill values for all registers. 
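Many of the application tests in this file run each scenario twice, once with flags of zero and once with ring0.FlagFull, mirroring context.Switch where FlagFull is set when ac.FullRestore() is true; a sketch of the pair, assuming the harness variables above:

        _, _, errPartial := c.SwitchToUser(regs, dummyFPState, pt, 0)
        _, _, errFull := c.SwitchToUser(regs, dummyFPState, pt, ring0.FlagFull)
        _, _ = errPartial, errFull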
+ if _, _, err := c.SwitchToUser(regs, dummyFPState, pt, 0); err != nil { + t.Errorf("application register check with partial restore got unexpected error: %v", err) + } + if err := testutil.CheckTestRegs(regs, false); err != nil { + t.Errorf("application register check with partial restore failed: %v", err) + } + return false + }) +} + +func TestRegistersFault(t *testing.T) { + applicationTest(t, true, testutil.TwiddleRegsFault, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + testutil.SetTestRegs(regs) // Fill values for all registers. + if si, _, err := c.SwitchToUser(regs, dummyFPState, pt, ring0.FlagFull); err != platform.ErrContextSignal || si.Signo != int32(syscall.SIGSEGV) { + t.Errorf("application register check with full restore got unexpected error: %v", err) + } + if err := testutil.CheckTestRegs(regs, true); err != nil { + t.Errorf("application register check with full restore failed: %v", err) + } + return false + }) +} + +func TestSegments(t *testing.T) { + applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + testutil.SetTestSegments(regs) + if _, _, err := c.SwitchToUser(regs, dummyFPState, pt, ring0.FlagFull); err != nil { + t.Errorf("application segment check with full restore got unexpected error: %v", err) + } + if err := testutil.CheckTestSegments(regs); err != nil { + t.Errorf("application segment check with full restore failed: %v", err) + } + return false + }) +} + +func TestBounce(t *testing.T) { + applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + go func() { + time.Sleep(time.Millisecond) + c.Bounce() + }() + if _, _, err := c.SwitchToUser(regs, dummyFPState, pt, 0); err != platform.ErrContextInterrupt { + t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextInterrupt) + } + return false + }) + applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + go func() { + time.Sleep(time.Millisecond) + c.Bounce() + }() + if _, _, err := c.SwitchToUser(regs, dummyFPState, pt, ring0.FlagFull); err != platform.ErrContextInterrupt { + t.Errorf("application full restore: got %v, wanted %v", err, platform.ErrContextInterrupt) + } + return false + }) +} + +func TestBounceStress(t *testing.T) { + applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + randomSleep := func() { + // O(hundreds of microseconds) is appropriate to ensure + // different overlaps and different schedules. + if n := rand.Intn(1000); n > 100 { + time.Sleep(time.Duration(n) * time.Microsecond) + } + } + for i := 0; i < 1000; i++ { + // Start an asynchronously executing goroutine that + // calls Bounce at pseudo-random point in time. + // This should wind up calling Bounce when the + // kernel is in various stages of the switch. + go func() { + randomSleep() + c.Bounce() + }() + randomSleep() + // Execute the switch. + if _, _, err := c.SwitchToUser(regs, dummyFPState, pt, 0); err != platform.ErrContextInterrupt { + t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextInterrupt) + } + // Simulate work. + c.Unlock() + randomSleep() + c.Lock() + } + return false + }) +} + +func TestInvalidate(t *testing.T) { + var data uintptr // Used below. 
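TestInvalidate just below unmaps the page containing data by masking its address down to a page boundary; a worked example of that rounding with a made-up address (PageSize is 4096):

        addr := uintptr(0x7ffd1234)
        page := addr &^ uintptr(usermem.PageSize-1) // 0x7ffd1000, the containing page.
        _ = page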
+ applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + testutil.SetTouchTarget(regs, &data) // Read legitimate value. + if _, _, err := c.SwitchToUser(regs, dummyFPState, pt, 0); err != nil { + t.Errorf("application partial restore: got %v, wanted nil", err) + } + // Unmap the page containing data & invalidate. + pt.Unmap(usermem.Addr(reflect.ValueOf(&data).Pointer() & ^uintptr(usermem.PageSize-1)), usermem.PageSize) + c.Invalidate() // Ensure invalidation. + if _, _, err := c.SwitchToUser(regs, dummyFPState, pt, 0); err != platform.ErrContextSignal { + t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextSignal) + } + return false + }) +} + +// IsFault returns true iff the given signal represents a fault. +func IsFault(err error, si *arch.SignalInfo) bool { + return err == platform.ErrContextSignal && si.Signo == int32(syscall.SIGSEGV) +} + +func TestEmptyAddressSpace(t *testing.T) { + applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + if si, _, err := c.SwitchToUser(regs, dummyFPState, pt, 0); !IsFault(err, si) { + t.Errorf("first fault with partial restore failed got %v", err) + t.Logf("registers: %#v", ®s) + } + return false + }) + applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + if si, _, err := c.SwitchToUser(regs, dummyFPState, pt, ring0.FlagFull); !IsFault(err, si) { + t.Errorf("first fault with full restore failed got %v", err) + t.Logf("registers: %#v", ®s) + } + return false + }) +} + +func TestWrongVCPU(t *testing.T) { + kvmTest(t, nil, func(c1 *vCPU) bool { + kvmTest(t, nil, func(c2 *vCPU) bool { + // Basic test, one then the other. + bluepill(c1) + bluepill(c2) + if c2.switches == 0 { + // Don't allow the test to proceed if this fails. + t.Fatalf("wrong vCPU#2 switches: vCPU1=%+v,vCPU2=%+v", c1, c2) + } + + // Alternate vCPUs; we expect to need to trigger the + // wrong vCPU path on each switch. + for i := 0; i < 100; i++ { + bluepill(c1) + bluepill(c2) + } + if count := c1.switches; count < 90 { + t.Errorf("wrong vCPU#1 switches: vCPU1=%+v,vCPU2=%+v", c1, c2) + } + if count := c2.switches; count < 90 { + t.Errorf("wrong vCPU#2 switches: vCPU1=%+v,vCPU2=%+v", c1, c2) + } + return false + }) + return false + }) + kvmTest(t, nil, func(c1 *vCPU) bool { + kvmTest(t, nil, func(c2 *vCPU) bool { + bluepill(c1) + bluepill(c2) + return false + }) + return false + }) +} + +func BenchmarkApplicationSyscall(b *testing.B) { + var ( + i int // Iteration includes machine.Get() / machine.Put(). + a int // Count for ErrContextInterrupt. + ) + applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + if _, _, err := c.SwitchToUser(regs, dummyFPState, pt, 0); err != nil { + if err == platform.ErrContextInterrupt { + a++ + return true // Ignore. + } + b.Fatalf("benchmark failed: %v", err) + } + i++ + return i < b.N + }) + if a != 0 { + b.Logf("ErrContextInterrupt occurred %d times (in %d iterations).", a, a+i) + } +} + +func BenchmarkKernelSyscall(b *testing.B) { + // Note that the target passed here is irrelevant, we never execute SwitchToUser. + applicationTest(b, true, testutil.Getpid, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + // iteration does not include machine.Get() / machine.Put(). 
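        // Note on iteration scope (summary, not part of this change):
        // BenchmarkApplicationSyscall returns i < b.N from its callback, so kvmTest
        // re-acquires and releases a vCPU on every pass and Get/Put is measured;
        // the benchmark below runs its whole b.N loop inside a single callback and
        // returns false, so one Get/Put brackets the entire measurement.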
+ for i := 0; i < b.N; i++ { + testutil.Getpid() + } + return false + }) +} + +func BenchmarkWorldSwitchToUserRoundtrip(b *testing.B) { + // see BenchmarkApplicationSyscall. + var ( + i int + a int + ) + applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + if _, _, err := c.SwitchToUser(regs, dummyFPState, pt, 0); err != nil { + if err == platform.ErrContextInterrupt { + a++ + return true // Ignore. + } + b.Fatalf("benchmark failed: %v", err) + } + // This will intentionally cause the world switch. By executing + // a host syscall here, we force the transition between guest + // and host mode. + testutil.Getpid() + i++ + return i < b.N + }) + if a != 0 { + b.Logf("EAGAIN occurred %d times (in %d iterations).", a, a+i) + } +} diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go new file mode 100644 index 000000000..a5be0cee3 --- /dev/null +++ b/pkg/sentry/platform/kvm/machine.go @@ -0,0 +1,412 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +import ( + "fmt" + "runtime" + "sync" + "sync/atomic" + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/procid" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0/pagetables" + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" + "gvisor.googlesource.com/gvisor/pkg/tmutex" +) + +// machine contains state associated with the VM as a whole. +type machine struct { + // fd is the vm fd. + fd int + + // nextSlot is the next slot for setMemoryRegion. + // + // This must be accessed atomically. If nextSlot is ^uint32(0), then + // slots are currently being updated, and the caller should retry. + nextSlot uint32 + + // kernel is the set of global structures. + kernel *ring0.Kernel + + // mappingCache is used for mapPhysical. + mappingCache sync.Map + + // mu protects vCPUs. + mu sync.Mutex + + // vCPUs are the machine vCPUs. + // + // This is eventually keyed by system TID, but is initially indexed by + // the negative vCPU id. This is merely an optimization, so while + // collisions here are not possible, it wouldn't matter anyways. + vCPUs map[uint64]*vCPU +} + +const ( + // vCPUReady is the lock value for an available vCPU. + // + // Legal transitions: vCPUGuest (bluepill). + vCPUReady uintptr = iota + + // vCPUGuest indicates the vCPU is in guest mode. + // + // Legal transition: vCPUReady (bluepill), vCPUWaiter (wait). + vCPUGuest + + // vCPUWaiter indicates that the vCPU should be released. + // + // Legal transition: vCPUReady (bluepill). + vCPUWaiter +) + +// vCPU is a single KVM vCPU. +type vCPU struct { + // CPU is the kernel CPU data. + // + // This must be the first element of this structure, it is referenced + // by the bluepill code (see bluepill_amd64.s). + ring0.CPU + + // fd is the vCPU fd. + fd int + + // tid is the last set tid. 
+ tid uint64 + + // switches is a count of world switches (informational only). + switches uint32 + + // faults is a count of world faults (informational only). + faults uint32 + + // state is the vCPU state; all are described above. + state uintptr + + // runData for this vCPU. + runData *runData + + // machine associated with this vCPU. + machine *machine + + // mu applies across get/put; it does not protect the above. + mu tmutex.Mutex +} + +// newMachine returns a new VM context. +func newMachine(vm int, vCPUs int) (*machine, error) { + // Create the machine. + m := &machine{ + fd: vm, + vCPUs: make(map[uint64]*vCPU), + } + if vCPUs > _KVM_NR_VCPUS { + // Hard cap at KVM's limit. + vCPUs = _KVM_NR_VCPUS + } + if n := 2 * runtime.NumCPU(); vCPUs > n { + // Cap at twice the number of physical cores. Otherwise we're + // just wasting memory and thrashing. (There may be scheduling + // issues when you've got > n active threads.) + vCPUs = n + } + m.kernel = ring0.New(ring0.KernelOpts{ + PageTables: pagetables.New(m, pagetablesOpts), + }) + + // Initialize architecture state. + if err := m.initArchState(vCPUs); err != nil { + m.Destroy() + return nil, err + } + + // Create all the vCPUs. + for id := 0; id < vCPUs; id++ { + // Create the vCPU. + fd, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(vm), _KVM_CREATE_VCPU, uintptr(id)) + if errno != 0 { + m.Destroy() + return nil, fmt.Errorf("error creating VCPU: %v", errno) + } + c := &vCPU{ + fd: int(fd), + machine: m, + } + c.mu.Init() + c.CPU.Init(m.kernel) + c.CPU.KernelSyscall = bluepillSyscall + c.CPU.KernelException = bluepillException + m.vCPUs[uint64(-id)] = c // See above. + + // Ensure the signal mask is correct. + if err := c.setSignalMask(); err != nil { + m.Destroy() + return nil, err + } + + // Initialize architecture state. + if err := c.initArchState(); err != nil { + m.Destroy() + return nil, err + } + + // Map the run data. + runData, err := mapRunData(int(fd)) + if err != nil { + m.Destroy() + return nil, err + } + c.runData = runData + } + + // Apply the physical mappings. Note that these mappings may point to + // guest physical addresses that are not actually available. These + // physical pages are mapped on demand, see kernel_unsafe.go. + applyPhysicalRegions(func(pr physicalRegion) bool { + // Map everything in the lower half. + m.kernel.PageTables.Map(usermem.Addr(pr.virtual), pr.length, false /* kernel */, usermem.AnyAccess, pr.physical) + // And keep everything in the upper half. + kernelAddr := usermem.Addr(ring0.KernelStartAddress | pr.virtual) + m.kernel.PageTables.Map(kernelAddr, pr.length, false /* kernel */, usermem.AnyAccess, pr.physical) + return true // Keep iterating. + }) + + // Ensure that the currently mapped virtual regions are actually + // available in the VM. Note that this doesn't guarantee no future + // faults, however it should guarantee that everything is available to + // ensure successful vCPU entry. + applyVirtualRegions(func(vr virtualRegion) { + if excludeVirtualRegion(vr) { + return // skip region. + } + for virtual := vr.virtual; virtual < vr.virtual+vr.length; { + physical, length, ok := TranslateToPhysical(virtual) + if !ok { + // This must be an invalid region that was + // knocked out by creation of the physical map. + return + } + if virtual+length > vr.virtual+vr.length { + // Cap the length to the end of the area. + length = vr.virtual + vr.length - virtual + } + + // Ensure the physical range is mapped. 
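newMachine above caps the vCPU count twice: first at KVM's hard limit, then at twice the host core count. A tiny sketch of that policy; the limit constant here is a placeholder, not the real _KVM_NR_VCPUS value:

// Cap a requested vCPU count the same way newMachine does.
package main

import (
    "fmt"
    "runtime"
)

const kvmNRVCPUsAssumed = 255 // placeholder for _KVM_NR_VCPUS

func capVCPUs(requested int) int {
    if requested > kvmNRVCPUsAssumed {
        requested = kvmNRVCPUsAssumed // hard cap at KVM's limit
    }
    if n := 2 * runtime.NumCPU(); requested > n {
        requested = n // avoid thrashing beyond twice the physical cores
    }
    return requested
}

func main() {
    fmt.Println(capVCPUs(1024)) // capped to 2 * NumCPU on most hosts
}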
+ m.mapPhysical(physical, length) + virtual += length + } + }) + + // Ensure the machine is cleaned up properly. + runtime.SetFinalizer(m, (*machine).Destroy) + return m, nil +} + +// mapPhysical checks for the mapping of a physical range, and installs one if +// not available. This attempts to be efficient for calls in the hot path. +// +// This panics on error. +func (m *machine) mapPhysical(physical, length uintptr) { + for end := physical + length; physical < end; { + _, physicalStart, length, ok := calculateBluepillFault(m, physical) + if !ok { + // Should never happen. + panic("mapPhysical on unknown physical address") + } + + if _, ok := m.mappingCache.LoadOrStore(physicalStart, true); !ok { + // Not present in the cache; requires setting the slot. + if _, ok := handleBluepillFault(m, physical); !ok { + panic("handleBluepillFault failed") + } + } + + // Move to the next chunk. + physical = physicalStart + length + } +} + +// Destroy frees associated resources. +// +// Destroy should only be called once all active users of the machine are gone. +// The machine object should not be used after calling Destroy. +// +// Precondition: all vCPUs must be returned to the machine. +func (m *machine) Destroy() { + runtime.SetFinalizer(m, nil) + + // Destroy vCPUs. + for _, c := range m.vCPUs { + // Ensure the vCPU is not still running in guest mode. This is + // possible iff teardown has been done by other threads, and + // somehow a single thread has not executed any system calls. + c.wait() + + // Teardown the vCPU itself. + switch state := c.State(); state { + case vCPUReady: + // Note that the runData may not be mapped if an error + // occurs during the middle of initialization. + if c.runData != nil { + if err := unmapRunData(c.runData); err != nil { + panic(fmt.Sprintf("error unmapping rundata: %v", err)) + } + } + if err := syscall.Close(int(c.fd)); err != nil { + panic(fmt.Sprintf("error closing vCPU fd: %v", err)) + } + case vCPUGuest, vCPUWaiter: + // Should never happen; waited above. + panic("vCPU disposed in guest state") + default: + // Should never happen; not a valid state. + panic(fmt.Sprintf("vCPU in invalid state: %v", state)) + } + } + + // Release host mappings. + if m.kernel.PageTables != nil { + m.kernel.PageTables.Release() + } + + // vCPUs are gone: teardown machine state. + if err := syscall.Close(m.fd); err != nil { + panic(fmt.Sprintf("error closing VM fd: %v", err)) + } +} + +// Get gets an available vCPU. +func (m *machine) Get() (*vCPU, error) { + runtime.LockOSThread() + tid := procid.Current() + m.mu.Lock() + + for { + // Check for an exact match. + if c := m.vCPUs[tid]; c != nil && c.mu.TryLock() { + m.mu.Unlock() + return c, nil + } + + // Scan for an available vCPU. + for origTID, c := range m.vCPUs { + if c.LockInState(vCPUReady) { + delete(m.vCPUs, origTID) + m.vCPUs[tid] = c + m.mu.Unlock() + + // We need to reload thread-local segments as + // we have origTID != tid and the vCPU state + // may be stale. + c.loadSegments() + atomic.StoreUint64(&c.tid, tid) + return c, nil + } + } + + // Everything is busy executing user code (locked). + // + // We hold the pool lock here, so we should be able to kick something + // out of kernel mode and have it bounce into host mode when it tries + // to grab the vCPU again. + for _, c := range m.vCPUs { + if c.State() != vCPUWaiter { + c.Bounce() + } + } + + // Give other threads an opportunity to run. + yield() + } +} + +// Put puts the current vCPU. 
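Get() above prefers a vCPU already keyed to the calling thread, then steals any idle vCPU and re-keys it, and otherwise yields and retries. A simplified generic pool that follows the same acquisition order (not the gVisor machine type; requires sync.Mutex.TryLock from Go 1.18+):

// Acquisition order: exact key match, then any idle entry, then yield.
package main

import (
    "fmt"
    "runtime"
    "sync"
)

type entry struct{ mu sync.Mutex }

type pool struct {
    mu      sync.Mutex
    entries map[uint64]*entry
}

func (p *pool) get(tid uint64) *entry {
    p.mu.Lock()
    for {
        // Check for an exact match first.
        if e := p.entries[tid]; e != nil && e.mu.TryLock() {
            p.mu.Unlock()
            return e
        }
        // Otherwise steal any idle entry and re-key it to this thread.
        for oldTID, e := range p.entries {
            if e.mu.TryLock() {
                delete(p.entries, oldTID)
                p.entries[tid] = e
                p.mu.Unlock()
                return e
            }
        }
        runtime.Gosched() // everything is busy; give others a chance
    }
}

func main() {
    p := &pool{entries: map[uint64]*entry{1: {}, 2: {}}}
    e := p.get(42)
    e.mu.Unlock()
    fmt.Println("acquired and released an entry")
}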
+func (m *machine) Put(c *vCPU) { + c.Unlock() + runtime.UnlockOSThread() +} + +// State returns the current state. +func (c *vCPU) State() uintptr { + return atomic.LoadUintptr(&c.state) +} + +// Lock locks the vCPU. +func (c *vCPU) Lock() { + c.mu.Lock() +} + +// Invalidate invalidates caches. +func (c *vCPU) Invalidate() { +} + +// LockInState locks the vCPU if it is in the given state and TryLock succeeds. +func (c *vCPU) LockInState(state uintptr) bool { + if c.State() == state && c.mu.TryLock() { + if c.State() != state { + c.mu.Unlock() + return false + } + return true + } + return false +} + +// Unlock unlocks the given vCPU. +func (c *vCPU) Unlock() { + // Ensure we're out of guest mode, if necessary. + if c.State() == vCPUWaiter { + redpill() // Force guest mode exit. + } + c.mu.Unlock() +} + +// NotifyInterrupt implements interrupt.Receiver.NotifyInterrupt. +func (c *vCPU) NotifyInterrupt() { + c.Bounce() +} + +// pid is used below in bounce. +var pid = syscall.Getpid() + +// Bounce ensures that the vCPU bounces back to the kernel. +// +// In practice, this means returning EAGAIN from running user code. The vCPU +// will be unlocked and relock, and the kernel is guaranteed to check for +// interrupt notifications (e.g. injected via Notify) and invalidations. +func (c *vCPU) Bounce() { + for { + if c.mu.TryLock() { + // We know that the vCPU must be in the kernel already, + // because the lock was not acquired. We specifically + // don't want to call bounce in this case, because it's + // not necessary to knock the vCPU out of guest mode. + c.mu.Unlock() + return + } + + if state := c.State(); state == vCPUGuest || state == vCPUWaiter { + // We know that the vCPU was in guest mode, so a single signal + // interruption will guarantee that a transition takes place. + syscall.Tgkill(pid, int(atomic.LoadUint64(&c.tid)), bounceSignal) + return + } + + // Someone holds the lock, but the vCPU is not yet transitioned + // into guest mode. It's in the critical section; give it time. + yield() + } +} diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go new file mode 100644 index 000000000..dfa691e88 --- /dev/null +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -0,0 +1,168 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package kvm + +import ( + "fmt" + "reflect" + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/sentry/arch" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0/pagetables" + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +// initArchState initializes architecture-specific state. +func (m *machine) initArchState(vCPUs int) error { + // Set the legacy TSS address. This address is covered by the reserved + // range (up to 4GB). In fact, this is a main reason it exists. 
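LockInState above re-reads the state after TryLock succeeds, since the state may have changed between the check and the lock. A small generic sketch of that check/lock/re-check discipline:

// Double-checked TryLock: only keep the lock if the state still matches.
package main

import (
    "fmt"
    "sync"
    "sync/atomic"
)

type guarded struct {
    mu    sync.Mutex
    state uintptr
}

func (g *guarded) lockInState(want uintptr) bool {
    if atomic.LoadUintptr(&g.state) == want && g.mu.TryLock() {
        if atomic.LoadUintptr(&g.state) != want {
            g.mu.Unlock() // state changed while we were locking
            return false
        }
        return true
    }
    return false
}

func main() {
    g := &guarded{}
    fmt.Println(g.lockInState(0)) // true: state matches and lock is free
    g.mu.Unlock()
}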
+ if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(m.fd), + _KVM_SET_TSS_ADDR, + uintptr(reservedMemory-(3*usermem.PageSize))); errno != 0 { + return errno + } + return nil +} + +// initArchState initializes architecture-specific state. +func (c *vCPU) initArchState() error { + var ( + kernelSystemRegs systemRegs + kernelUserRegs userRegs + ) + + // Set base control registers. + kernelSystemRegs.CR0 = c.CR0() + kernelSystemRegs.CR4 = c.CR4() + kernelSystemRegs.EFER = c.EFER() + + // Set the IDT & GDT in the registers. + kernelSystemRegs.IDT.base, kernelSystemRegs.IDT.limit = c.IDT() + kernelSystemRegs.GDT.base, kernelSystemRegs.GDT.limit = c.GDT() + kernelSystemRegs.CS.Load(&ring0.KernelCodeSegment, ring0.Kcode) + kernelSystemRegs.DS.Load(&ring0.UserDataSegment, ring0.Udata) + kernelSystemRegs.ES.Load(&ring0.UserDataSegment, ring0.Udata) + kernelSystemRegs.SS.Load(&ring0.KernelDataSegment, ring0.Kdata) + kernelSystemRegs.FS.Load(&ring0.UserDataSegment, ring0.Udata) + kernelSystemRegs.GS.Load(&ring0.UserDataSegment, ring0.Udata) + tssBase, tssLimit, tss := c.TSS() + kernelSystemRegs.TR.Load(tss, ring0.Tss) + kernelSystemRegs.TR.base = tssBase + kernelSystemRegs.TR.limit = uint32(tssLimit) + + // Point to kernel page tables. + kernelSystemRegs.CR3 = c.machine.kernel.PageTables.FlushCR3() + + // Set the CPUID; this is required before setting system registers, + // since KVM will reject several CR4 bits if the CPUID does not + // indicate the support is available. + if err := c.setCPUID(); err != nil { + return err + } + + // Set the entrypoint for the kernel. + kernelUserRegs.RIP = uint64(reflect.ValueOf(ring0.Start).Pointer()) + kernelUserRegs.RAX = uint64(reflect.ValueOf(&c.CPU).Pointer()) + kernelUserRegs.RFLAGS = ring0.KernelFlagsSet + + // Set the system registers. + if err := c.setSystemRegisters(&kernelSystemRegs); err != nil { + return err + } + + // Set the user registers. + if err := c.setUserRegisters(&kernelUserRegs); err != nil { + return err + } + + // Set the time offset to the host native time. + return c.setSystemTime() +} + +// SwitchToUser unpacks architectural-details. +func (c *vCPU) SwitchToUser(regs *syscall.PtraceRegs, fpState *byte, pt *pagetables.PageTables, flags ring0.Flags) (*arch.SignalInfo, usermem.AccessType, error) { + // See below. + var vector ring0.Vector + + // Past this point, stack growth can cause system calls (and a break + // from guest mode). So we need to ensure that between the bluepill + // call here and the switch call immediately below, no additional + // allocations occur. + entersyscall() + bluepill(c) + vector = c.CPU.SwitchToUser(regs, fpState, pt, flags) + exitsyscall() + + // Free and clear. + switch vector { + case ring0.Debug, ring0.Breakpoint: + info := &arch.SignalInfo{Signo: int32(syscall.SIGTRAP)} + return info, usermem.AccessType{}, platform.ErrContextSignal + + case ring0.PageFault: + bluepill(c) // Probably no-op, but may not be. + faultAddr := ring0.ReadCR2() + code, user := c.ErrorCode() + if !user { + // The last fault serviced by this CPU was not a user + // fault, so we can't reliably trust the faultAddr or + // the code provided here. We need to re-execute. 
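The vCPU entrypoint setup above derives code and data addresses with the reflect package (RIP from a function value, RAX from a struct pointer). A standalone illustration of that technique:

// Taking code and data addresses via reflect, as the register setup does.
package main

import (
    "fmt"
    "reflect"
)

func target() {}

func main() {
    var data uint64
    fmt.Printf("code at %#x\n", reflect.ValueOf(target).Pointer()) // like kernelUserRegs.RIP
    fmt.Printf("data at %#x\n", reflect.ValueOf(&data).Pointer())  // like kernelUserRegs.RAX
}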
+ return nil, usermem.NoAccess, platform.ErrContextInterrupt + } + info := &arch.SignalInfo{Signo: int32(syscall.SIGSEGV)} + info.SetAddr(uint64(faultAddr)) + accessType := usermem.AccessType{ + Read: code&(1<<1) == 0, + Write: code&(1<<1) != 0, + Execute: code&(1<<4) != 0, + } + return info, accessType, platform.ErrContextSignal + + case ring0.GeneralProtectionFault: + if !ring0.IsCanonical(regs.Rip) { + // If the RIP is non-canonical, it's a SEGV. + info := &arch.SignalInfo{Signo: int32(syscall.SIGSEGV)} + return info, usermem.AccessType{}, platform.ErrContextSignal + } + // Otherwise, we deliver a SIGBUS. + info := &arch.SignalInfo{Signo: int32(syscall.SIGBUS)} + return info, usermem.AccessType{}, platform.ErrContextSignal + + case ring0.InvalidOpcode: + info := &arch.SignalInfo{Signo: int32(syscall.SIGILL)} + return info, usermem.AccessType{}, platform.ErrContextSignal + + case ring0.X87FloatingPointException: + info := &arch.SignalInfo{Signo: int32(syscall.SIGFPE)} + return info, usermem.AccessType{}, platform.ErrContextSignal + + case ring0.Vector(bounce): + redpill() // Bail and reacqire. + return nil, usermem.NoAccess, platform.ErrContextInterrupt + + case ring0.Syscall, ring0.SyscallInt80: + // System call executed. + return nil, usermem.NoAccess, nil + + default: + panic(fmt.Sprintf("unexpected vector: 0x%x", vector)) + } +} diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go new file mode 100644 index 000000000..c2bcb3a47 --- /dev/null +++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go @@ -0,0 +1,156 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package kvm + +import ( + "fmt" + "syscall" + "unsafe" + + "gvisor.googlesource.com/gvisor/pkg/abi/linux" + "gvisor.googlesource.com/gvisor/pkg/sentry/time" +) + +// setMemoryRegion initializes a region. +// +// This may be called from bluepillHandler, and therefore returns an errno +// directly (instead of wrapping in an error) to avoid allocations. +// +//go:nosplit +func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr) syscall.Errno { + userRegion := userMemoryRegion{ + slot: uint32(slot), + flags: 0, + guestPhysAddr: uint64(physical), + memorySize: uint64(length), + userspaceAddr: uint64(virtual), + } + + // Set the region. + _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(m.fd), + _KVM_SET_USER_MEMORY_REGION, + uintptr(unsafe.Pointer(&userRegion))) + return errno +} + +// loadSegments copies the current segments. +// +// This may be called from within the signal context and throws on error. 
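The PageFault case above decodes the x86 page-fault error code: bit 1 distinguishes writes from reads, and bit 4 marks instruction fetches. A minimal sketch of that decoding with a local access-type struct:

// Decode the x86 page-fault error code bits used above.
package main

import "fmt"

type accessType struct{ Read, Write, Execute bool }

func decodePageFault(code uint64) accessType {
    return accessType{
        Read:    code&(1<<1) == 0,
        Write:   code&(1<<1) != 0,
        Execute: code&(1<<4) != 0,
    }
}

func main() {
    fmt.Printf("%+v\n", decodePageFault(1<<1)) // a write fault
    fmt.Printf("%+v\n", decodePageFault(1<<4)) // an instruction fetch
}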
+// +//go:nosplit +func (c *vCPU) loadSegments() { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_ARCH_PRCTL, + linux.ARCH_GET_FS, + uintptr(unsafe.Pointer(&c.CPU.Registers().Fs_base)), + 0); errno != 0 { + throw("getting FS segment") + } + if _, _, errno := syscall.RawSyscall( + syscall.SYS_ARCH_PRCTL, + linux.ARCH_GET_GS, + uintptr(unsafe.Pointer(&c.CPU.Registers().Gs_base)), + 0); errno != 0 { + throw("getting GS segment") + } +} + +// setUserRegisters sets user registers in the vCPU. +func (c *vCPU) setUserRegisters(uregs *userRegs) error { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_REGS, + uintptr(unsafe.Pointer(uregs))); errno != 0 { + return fmt.Errorf("error setting user registers: %v", errno) + } + return nil +} + +// setSystemRegisters sets system registers. +func (c *vCPU) setSystemRegisters(sregs *systemRegs) error { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_SREGS, + uintptr(unsafe.Pointer(sregs))); errno != 0 { + return fmt.Errorf("error setting system registers: %v", errno) + } + return nil +} + +// setCPUID sets the CPUID to be used by the guest. +func (c *vCPU) setCPUID() error { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_CPUID2, + uintptr(unsafe.Pointer(&cpuidSupported))); errno != 0 { + return fmt.Errorf("error setting CPUID: %v", errno) + } + return nil +} + +// setSystemTime sets the TSC for the vCPU. +// +// FIXME: This introduces a slight TSC offset between host and +// guest, which may vary per vCPU. +func (c *vCPU) setSystemTime() error { + const _MSR_IA32_TSC = 0x00000010 + registers := modelControlRegisters{ + nmsrs: 1, + } + registers.entries[0] = modelControlRegister{ + index: _MSR_IA32_TSC, + data: uint64(time.Rdtsc()), + } + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_MSRS, + uintptr(unsafe.Pointer(®isters))); errno != 0 { + return fmt.Errorf("error setting system time: %v", errno) + } + return nil +} + +// setSignalMask sets the vCPU signal mask. +// +// This must be called prior to running the vCPU. +func (c *vCPU) setSignalMask() error { + // The layout of this structure implies that it will not necessarily be + // the same layout chosen by the Go compiler. It gets fudged here. + var data struct { + length uint32 + mask1 uint32 + mask2 uint32 + _ uint32 + } + data.length = 8 // Fixed sigset size. + data.mask1 = ^uint32(bounceSignalMask & 0xffffffff) + data.mask2 = ^uint32(bounceSignalMask >> 32) + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_SIGNAL_MASK, + uintptr(unsafe.Pointer(&data))); errno != 0 { + return fmt.Errorf("error setting signal mask: %v", errno) + } + return nil +} diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go new file mode 100644 index 000000000..da67e23f6 --- /dev/null +++ b/pkg/sentry/platform/kvm/machine_unsafe.go @@ -0,0 +1,112 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
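setSignalMask above hand-packs a 64-bit signal mask into two 32-bit halves, blocking everything except the bounce signal. A tiny sketch of that packing; the signal chosen here is only an example, not necessarily the real bounceSignal:

// Split a 64-bit "keep unblocked" mask into the two inverted 32-bit halves.
package main

import "fmt"

func packSigmask(keepUnblocked uint64) (mask1, mask2 uint32) {
    mask1 = ^uint32(keepUnblocked & 0xffffffff)
    mask2 = ^uint32(keepUnblocked >> 32)
    return mask1, mask2
}

func main() {
    // Pretend the bounce signal is signal 17 (SIGCHLD), i.e. bit 16 of the set.
    m1, m2 := packSigmask(1 << 16)
    fmt.Printf("%08x %08x\n", m1, m2) // fffeffff ffffffff
}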
+// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +import ( + "fmt" + "sync/atomic" + "syscall" + "unsafe" + + "gvisor.googlesource.com/gvisor/pkg/abi/linux" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0/pagetables" +) + +//go:linkname entersyscall runtime.entersyscall +func entersyscall() + +//go:linkname exitsyscall runtime.exitsyscall +func exitsyscall() + +// TranslateToVirtual implements pagetables.Translater.TranslateToPhysical. +func (m *machine) TranslateToPhysical(ptes *pagetables.PTEs) uintptr { + // The length doesn't matter because all these translations require + // only a single page, which is guaranteed to be satisfied. + physical, _, ok := TranslateToPhysical(uintptr(unsafe.Pointer(ptes))) + if !ok { + panic("unable to translate pagetables.Node to physical address") + } + return physical +} + +// mapRunData maps the vCPU run data. +func mapRunData(fd int) (*runData, error) { + r, _, errno := syscall.RawSyscall6( + syscall.SYS_MMAP, + 0, + uintptr(runDataSize), + syscall.PROT_READ|syscall.PROT_WRITE, + syscall.MAP_SHARED, + uintptr(fd), + 0) + if errno != 0 { + return nil, fmt.Errorf("error mapping runData: %v", errno) + } + return (*runData)(unsafe.Pointer(r)), nil +} + +// unmapRunData unmaps the vCPU run data. +func unmapRunData(r *runData) error { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_MUNMAP, + uintptr(unsafe.Pointer(r)), + uintptr(runDataSize), + 0); errno != 0 { + return fmt.Errorf("error unmapping runData: %v", errno) + } + return nil +} + +// notify notifies that the vCPU has returned to host mode. +// +// This may be called by a signal handler and therefore throws on error. +// +//go:nosplit +func (c *vCPU) notify() { + _, _, errno := syscall.RawSyscall6( + syscall.SYS_FUTEX, + uintptr(unsafe.Pointer(&c.state)), + linux.FUTEX_WAKE, + ^uintptr(0), // Number of waiters. + 0, 0, 0) + if errno != 0 { + throw("futex wake error") + } +} + +// wait waits for the vCPU to return to host mode. +// +// This panics on error. +func (c *vCPU) wait() { + if !atomic.CompareAndSwapUintptr(&c.state, vCPUGuest, vCPUWaiter) { + return // Nothing to wait for. + } + for { + _, _, errno := syscall.Syscall6( + syscall.SYS_FUTEX, + uintptr(unsafe.Pointer(&c.state)), + linux.FUTEX_WAIT, + uintptr(vCPUWaiter), // Expected value. + 0, 0, 0) + if errno == syscall.EINTR { + continue + } else if errno == syscall.EAGAIN { + break + } else if errno != 0 { + panic("futex wait error") + } + break + } +} diff --git a/pkg/sentry/platform/kvm/physical_map.go b/pkg/sentry/platform/kvm/physical_map.go new file mode 100644 index 000000000..5d55c9486 --- /dev/null +++ b/pkg/sentry/platform/kvm/physical_map.go @@ -0,0 +1,221 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
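wait() and notify() above coordinate through futex: FUTEX_WAIT returns EAGAIN when the expected value is already stale, and FUTEX_WAKE with no waiters is a harmless no-op. A self-contained, Linux-only sketch of both calls using the standard futex operation values:

// Futex wait/wake handshake on a local word.
package main

import (
    "fmt"
    "syscall"
    "unsafe"
)

const (
    futexWait = 0 // FUTEX_WAIT
    futexWake = 1 // FUTEX_WAKE
)

func main() {
    var state uint32 = 1

    // Waiting with a stale expected value returns EAGAIN immediately,
    // which is how wait() detects that the state already moved on.
    _, _, errno := syscall.Syscall6(
        syscall.SYS_FUTEX,
        uintptr(unsafe.Pointer(&state)),
        futexWait,
        uintptr(2), // expected value does not match state (1)
        0, 0, 0)
    fmt.Println("wait with stale value:", errno == syscall.EAGAIN)

    // Waking with no waiters simply returns the number woken (0).
    woken, _, _ := syscall.RawSyscall6(
        syscall.SYS_FUTEX,
        uintptr(unsafe.Pointer(&state)),
        futexWake,
        ^uintptr(0), // wake as many waiters as exist
        0, 0, 0)
    fmt.Println("woken:", woken)
}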
+ +package kvm + +import ( + "fmt" + "sort" + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/log" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0" + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +const ( + // reservedMemory is a chunk of physical memory reserved starting at + // physical address zero. There are some special pages in this region, + // so we just call the whole thing off. + // + // Other architectures may define this to be zero. + reservedMemory = 0x100000000 +) + +type region struct { + virtual uintptr + length uintptr +} + +type physicalRegion struct { + region + physical uintptr +} + +// physicalRegions contains a list of available physical regions. +// +// The physical value used in physicalRegions is a number indicating the +// physical offset, aligned appropriately and starting above reservedMemory. +var physicalRegions []physicalRegion + +// fillAddressSpace fills the host address space with PROT_NONE mappings until +// the number of available bits until we have a host address space size that is +// equal to the physical address space. +// +// The excluded regions are returned. +func fillAddressSpace() (excludedRegions []region) { + // We can cut vSize in half, because the kernel will be using the top + // half and we ignore it while constructing mappings. It's as if we've + // already excluded half the possible addresses. + vSize := uintptr(1) << ring0.VirtualAddressBits() + vSize = vSize >> 1 + + // We exclude reservedMemory below from our physical memory size, so it + // needs to be dropped here as well. Otherwise, we could end up with + // physical addresses that are beyond what is mapped. + pSize := uintptr(1) << ring0.PhysicalAddressBits() + pSize -= reservedMemory + + // Sanity check. + if vSize < pSize { + panic(fmt.Sprintf("vSize (%x) < pSize (%x)", vSize, pSize)) + } + + // Add specifically excluded regions; see excludeVirtualRegion. + applyVirtualRegions(func(vr virtualRegion) { + if excludeVirtualRegion(vr) { + excludedRegions = append(excludedRegions, vr.region) + vSize -= vr.length + log.Infof("excluded: virtual [%x,%x)", vr.virtual, vr.virtual+vr.length) + } + }) + + // Calculate the required space and fill it. + // + // Note carefully that we add faultBlockSize to required up front, and + // on each iteration of the loop below (i.e. each new physical region + // we define), we add faultBlockSize again. This is done because the + // computation of physical regions will ensure proper alignments with + // faultBlockSize, potentially causing up to faultBlockSize bytes in + // internal fragmentation for each physical region. So we need to + // account for this properly during allocation. + requiredAddr, ok := usermem.Addr(vSize - pSize + faultBlockSize).RoundUp() + if !ok { + panic(fmt.Sprintf( + "overflow for vSize (%x) - pSize (%x) + faultBlockSize (%x)", + vSize, pSize, faultBlockSize)) + } + required := uintptr(requiredAddr) + current := required // Attempted mmap size. + for filled := uintptr(0); filled < required && current > 0; { + addr, _, errno := syscall.RawSyscall6( + syscall.SYS_MMAP, + 0, // Suggested address. + current, + syscall.PROT_NONE, + syscall.MAP_ANONYMOUS|syscall.MAP_PRIVATE|syscall.MAP_NORESERVE, + 0, 0) + if errno != 0 { + // Attempt half the size; overflow not possible. + currentAddr, _ := usermem.Addr(current >> 1).RoundUp() + current = uintptr(currentAddr) + continue + } + // We filled a block. 
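fillAddressSpace above reserves address space with PROT_NONE mappings and halves the request whenever mmap refuses. A small standalone sketch of that reserve-with-backoff loop, using a deliberately tiny size:

// Reserve address space with PROT_NONE, halving the request on failure.
package main

import (
    "fmt"
    "syscall"
)

func reserve(size uintptr) (addr, got uintptr) {
    for size > 0 {
        a, _, errno := syscall.RawSyscall6(
            syscall.SYS_MMAP,
            0, // let the kernel pick an address
            size,
            syscall.PROT_NONE,
            syscall.MAP_ANONYMOUS|syscall.MAP_PRIVATE|syscall.MAP_NORESERVE,
            0, 0)
        if errno == 0 {
            return a, size
        }
        size >>= 1 // back off and retry with half the size
    }
    return 0, 0
}

func main() {
    addr, got := reserve(1 << 20) // 1 MiB is plenty for a demo
    if got == 0 {
        fmt.Println("reservation failed")
        return
    }
    fmt.Printf("reserved %d bytes at %#x\n", got, addr)
    syscall.RawSyscall(syscall.SYS_MUNMAP, addr, got, 0)
}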
+ filled += current + excludedRegions = append(excludedRegions, region{ + virtual: addr, + length: current, + }) + // See comment above. + if filled != required { + required += faultBlockSize + } + } + if current == 0 { + panic("filling address space failed") + } + sort.Slice(excludedRegions, func(i, j int) bool { + return excludedRegions[i].virtual < excludedRegions[j].virtual + }) + for _, r := range excludedRegions { + log.Infof("region: virtual [%x,%x)", r.virtual, r.virtual+r.length) + } + return excludedRegions +} + +// computePhysicalRegions computes physical regions. +func computePhysicalRegions(excludedRegions []region) (physicalRegions []physicalRegion) { + physical := uintptr(reservedMemory) + addValidRegion := func(virtual, length uintptr) { + if length == 0 { + return + } + if virtual == 0 { + virtual += usermem.PageSize + length -= usermem.PageSize + } + if end := virtual + length; end > ring0.MaximumUserAddress { + length -= (end - ring0.MaximumUserAddress) + } + if length == 0 { + return + } + // Round physical up to the same alignment as the virtual + // address (with respect to faultBlockSize). + if offset := virtual &^ faultBlockMask; physical&^faultBlockMask != offset { + if newPhysical := (physical & faultBlockMask) + offset; newPhysical > physical { + physical = newPhysical // Round up by only a little bit. + } else { + physical = ((physical + faultBlockSize) & faultBlockMask) + offset + } + } + physicalRegions = append(physicalRegions, physicalRegion{ + region: region{ + virtual: virtual, + length: length, + }, + physical: physical, + }) + physical += length + } + lastExcludedEnd := uintptr(0) + for _, r := range excludedRegions { + addValidRegion(lastExcludedEnd, r.virtual-lastExcludedEnd) + lastExcludedEnd = r.virtual + r.length + } + addValidRegion(lastExcludedEnd, ring0.MaximumUserAddress-lastExcludedEnd) + + // Dump our all physical regions. + for _, r := range physicalRegions { + log.Infof("physicalRegion: virtual [%x,%x) => physical [%x,%x)", + r.virtual, r.virtual+r.length, r.physical, r.physical+r.length) + } + return physicalRegions +} + +// physicalInit initializes physical address mappings. +func physicalInit() { + physicalRegions = computePhysicalRegions(fillAddressSpace()) +} + +// applyPhysicalRegions applies the given function on physical regions. +// +// Iteration continues as long as true is returned. The return value is the +// return from the last call to fn, or true if there are no entries. +// +// Precondition: physicalInit must have been called. +func applyPhysicalRegions(fn func(pr physicalRegion) bool) bool { + for _, pr := range physicalRegions { + if !fn(pr) { + return false + } + } + return true +} + +// TranslateToPhysical translates the given virtual address. +// +// Precondition: physicalInit must have been called. 
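TranslateToPhysical walks the physical regions and applies the matching region's offset. A minimal sketch of that region walk with hand-rolled types:

// Virtual-to-physical translation over a region list.
package main

import "fmt"

type physRegion struct {
    virtual, length, physical uintptr
}

func translate(regions []physRegion, v uintptr) (p uintptr, ok bool) {
    for _, r := range regions {
        if r.virtual <= v && v < r.virtual+r.length {
            return r.physical + (v - r.virtual), true
        }
    }
    return 0, false
}

func main() {
    regions := []physRegion{{virtual: 0x4000, length: 0x2000, physical: 0x100000000}}
    p, ok := translate(regions, 0x5000)
    fmt.Printf("%#x %v\n", p, ok) // 0x100001000 true
}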
+func TranslateToPhysical(virtual uintptr) (physical uintptr, length uintptr, ok bool) { + ok = !applyPhysicalRegions(func(pr physicalRegion) bool { + if pr.virtual <= virtual && virtual < pr.virtual+pr.length { + physical = pr.physical + (virtual - pr.virtual) + length = pr.length - (virtual - pr.virtual) + return false + } + return true + }) + return +} diff --git a/pkg/sentry/platform/kvm/testutil/BUILD b/pkg/sentry/platform/kvm/testutil/BUILD new file mode 100644 index 000000000..8533a8d89 --- /dev/null +++ b/pkg/sentry/platform/kvm/testutil/BUILD @@ -0,0 +1,15 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "testutil", + testonly = 1, + srcs = [ + "testutil.go", + "testutil_amd64.go", + "testutil_amd64.s", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/platform/kvm/testutil", + visibility = ["//pkg/sentry/platform/kvm:__pkg__"], +) diff --git a/pkg/sentry/platform/kvm/testutil/testutil.go b/pkg/sentry/platform/kvm/testutil/testutil.go new file mode 100644 index 000000000..8a614e25d --- /dev/null +++ b/pkg/sentry/platform/kvm/testutil/testutil.go @@ -0,0 +1,75 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package testutil provides common assembly stubs for testing. +package testutil + +import ( + "fmt" + "strings" +) + +// Getpid executes a trivial system call. +func Getpid() + +// Touch touches the value in the first register. +func Touch() + +// SyscallLoop executes a syscall and loops. +func SyscallLoop() + +// SpinLoop spins on the CPU. +func SpinLoop() + +// HaltLoop immediately halts and loops. +func HaltLoop() + +// TwiddleRegsFault twiddles registers then faults. +func TwiddleRegsFault() + +// TwiddleRegsSyscall twiddles registers then executes a syscall. +func TwiddleRegsSyscall() + +// TwiddleSegments reads segments into known registers. +func TwiddleSegments() + +// FloatingPointWorks is a floating point test. +// +// It returns true or false. +func FloatingPointWorks() bool + +// RegisterMismatchError is used for checking registers. +type RegisterMismatchError []string + +// Error returns a human-readable error. +func (r RegisterMismatchError) Error() string { + return strings.Join([]string(r), ";") +} + +// addRegisterMisatch allows simple chaining of register mismatches. +func addRegisterMismatch(err error, reg string, got, expected interface{}) error { + errStr := fmt.Sprintf("%s got %08x, expected %08x", reg, got, expected) + switch r := err.(type) { + case nil: + // Return a new register mismatch. + return RegisterMismatchError{errStr} + case RegisterMismatchError: + // Append the error. + r = append(r, errStr) + return r + default: + // Leave as is. 
+ return err + } +} diff --git a/pkg/sentry/platform/kvm/testutil/testutil_amd64.go b/pkg/sentry/platform/kvm/testutil/testutil_amd64.go new file mode 100644 index 000000000..39286a0af --- /dev/null +++ b/pkg/sentry/platform/kvm/testutil/testutil_amd64.go @@ -0,0 +1,135 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package testutil + +import ( + "reflect" + "syscall" +) + +// SetTestTarget sets the rip appropriately. +func SetTestTarget(regs *syscall.PtraceRegs, fn func()) { + regs.Rip = uint64(reflect.ValueOf(fn).Pointer()) +} + +// SetTouchTarget sets rax appropriately. +func SetTouchTarget(regs *syscall.PtraceRegs, target *uintptr) { + if target != nil { + regs.Rax = uint64(reflect.ValueOf(target).Pointer()) + } else { + regs.Rax = 0 + } +} + +// RewindSyscall rewinds a syscall RIP. +func RewindSyscall(regs *syscall.PtraceRegs) { + regs.Rip -= 2 +} + +// SetTestRegs initializes registers to known values. +func SetTestRegs(regs *syscall.PtraceRegs) { + regs.R15 = 0x15 + regs.R14 = 0x14 + regs.R13 = 0x13 + regs.R12 = 0x12 + regs.Rbp = 0xb9 + regs.Rbx = 0xb4 + regs.R11 = 0x11 + regs.R10 = 0x10 + regs.R9 = 0x09 + regs.R8 = 0x08 + regs.Rax = 0x44 + regs.Rcx = 0xc4 + regs.Rdx = 0xd4 + regs.Rsi = 0x51 + regs.Rdi = 0xd1 + regs.Rsp = 0x59 +} + +// CheckTestRegs checks that registers were twiddled per TwiddleRegs. +func CheckTestRegs(regs *syscall.PtraceRegs, full bool) (err error) { + if need := ^uint64(0x15); regs.R15 != need { + err = addRegisterMismatch(err, "R15", regs.R15, need) + } + if need := ^uint64(0x14); regs.R14 != need { + err = addRegisterMismatch(err, "R14", regs.R14, need) + } + if need := ^uint64(0x13); regs.R13 != need { + err = addRegisterMismatch(err, "R13", regs.R13, need) + } + if need := ^uint64(0x12); regs.R12 != need { + err = addRegisterMismatch(err, "R12", regs.R12, need) + } + if need := ^uint64(0xb9); regs.Rbp != need { + err = addRegisterMismatch(err, "Rbp", regs.Rbp, need) + } + if need := ^uint64(0xb4); regs.Rbx != need { + err = addRegisterMismatch(err, "Rbx", regs.Rbx, need) + } + if need := ^uint64(0x10); regs.R10 != need { + err = addRegisterMismatch(err, "R10", regs.R10, need) + } + if need := ^uint64(0x09); regs.R9 != need { + err = addRegisterMismatch(err, "R9", regs.R9, need) + } + if need := ^uint64(0x08); regs.R8 != need { + err = addRegisterMismatch(err, "R8", regs.R8, need) + } + if need := ^uint64(0x44); regs.Rax != need { + err = addRegisterMismatch(err, "Rax", regs.Rax, need) + } + if need := ^uint64(0xd4); regs.Rdx != need { + err = addRegisterMismatch(err, "Rdx", regs.Rdx, need) + } + if need := ^uint64(0x51); regs.Rsi != need { + err = addRegisterMismatch(err, "Rsi", regs.Rsi, need) + } + if need := ^uint64(0xd1); regs.Rdi != need { + err = addRegisterMismatch(err, "Rdi", regs.Rdi, need) + } + if need := ^uint64(0x59); regs.Rsp != need { + err = addRegisterMismatch(err, "Rsp", regs.Rsp, need) + } + // Rcx & R11 are ignored if !full is set. 
+ if need := ^uint64(0x11); full && regs.R11 != need { + err = addRegisterMismatch(err, "R11", regs.R11, need) + } + if need := ^uint64(0xc4); full && regs.Rcx != need { + err = addRegisterMismatch(err, "Rcx", regs.Rcx, need) + } + return +} + +var fsData uint64 = 0x55 +var gsData uint64 = 0x85 + +// SetTestSegments initializes segments to known values. +func SetTestSegments(regs *syscall.PtraceRegs) { + regs.Fs_base = uint64(reflect.ValueOf(&fsData).Pointer()) + regs.Gs_base = uint64(reflect.ValueOf(&gsData).Pointer()) +} + +// CheckTestSegments checks that registers were twiddled per TwiddleSegments. +func CheckTestSegments(regs *syscall.PtraceRegs) (err error) { + if regs.Rax != fsData { + err = addRegisterMismatch(err, "Rax", regs.Rax, fsData) + } + if regs.Rbx != gsData { + err = addRegisterMismatch(err, "Rbx", regs.Rcx, gsData) + } + return +} diff --git a/pkg/sentry/platform/kvm/testutil/testutil_amd64.s b/pkg/sentry/platform/kvm/testutil/testutil_amd64.s new file mode 100644 index 000000000..3b5ad8817 --- /dev/null +++ b/pkg/sentry/platform/kvm/testutil/testutil_amd64.s @@ -0,0 +1,98 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +// test_util_amd64.s provides AMD64 test functions. + +#include "funcdata.h" +#include "textflag.h" + +TEXT ·Getpid(SB),NOSPLIT,$0 + NO_LOCAL_POINTERS + MOVQ $39, AX // getpid + SYSCALL + RET + +TEXT ·Touch(SB),NOSPLIT,$0 +start: + MOVQ 0(AX), BX // deref AX + MOVQ $39, AX // getpid + SYSCALL + JMP start + +TEXT ·HaltLoop(SB),NOSPLIT,$0 +start: + HLT + JMP start + +TEXT ·SyscallLoop(SB),NOSPLIT,$0 +start: + SYSCALL + JMP start + +TEXT ·SpinLoop(SB),NOSPLIT,$0 +start: + JMP start + +TEXT ·FloatingPointWorks(SB),NOSPLIT,$0-8 + NO_LOCAL_POINTERS + MOVQ $1, AX + MOVQ AX, X0 + MOVQ $39, AX // getpid + SYSCALL + MOVQ X0, AX + CMPQ AX, $1 + SETEQ ret+0(FP) + RET + +#define TWIDDLE_REGS() \ + NOTQ R15; \ + NOTQ R14; \ + NOTQ R13; \ + NOTQ R12; \ + NOTQ BP; \ + NOTQ BX; \ + NOTQ R11; \ + NOTQ R10; \ + NOTQ R9; \ + NOTQ R8; \ + NOTQ AX; \ + NOTQ CX; \ + NOTQ DX; \ + NOTQ SI; \ + NOTQ DI; \ + NOTQ SP; + +TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0 + TWIDDLE_REGS() + SYSCALL + RET // never reached + +TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0 + TWIDDLE_REGS() + JMP AX // must fault + RET // never reached + +#define READ_FS() BYTE $0x64; BYTE $0x48; BYTE $0x8b; BYTE $0x00; +#define READ_GS() BYTE $0x65; BYTE $0x48; BYTE $0x8b; BYTE $0x00; + +TEXT ·TwiddleSegments(SB),NOSPLIT,$0 + MOVQ $0x0, AX + READ_GS() + MOVQ AX, BX + MOVQ $0x0, AX + READ_FS() + SYSCALL + RET // never reached diff --git a/pkg/sentry/platform/kvm/virtual_map.go b/pkg/sentry/platform/kvm/virtual_map.go new file mode 100644 index 000000000..0d3fbe043 --- /dev/null +++ b/pkg/sentry/platform/kvm/virtual_map.go @@ -0,0 +1,113 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +import ( + "bufio" + "fmt" + "io" + "os" + "regexp" + "strconv" + + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +type virtualRegion struct { + region + accessType usermem.AccessType + shared bool + offset uintptr + filename string +} + +// mapsLine matches a single line from /proc/PID/maps. +var mapsLine = regexp.MustCompile("([0-9a-f]+)-([0-9a-f]+) ([r-][w-][x-][sp]) ([0-9a-f]+) [0-9a-f]{2}:[0-9a-f]{2,} [0-9]+\\s+(.*)") + +// excludeRegion returns true if these regions should be excluded from the +// physical map. Virtual regions need to be excluded if get_user_pages will +// fail on those addresses, preventing KVM from satisfying EPT faults. +// +// This includes the VVAR page because the VVAR page may be mapped as I/O +// memory. And the VDSO page is knocked out because the VVAR page is not even +// recorded in /proc/self/maps on older kernels; knocking out the VDSO page +// prevents code in the VDSO from accessing the VVAR address. +// +// This is called by the physical map functions, not applyVirtualRegions. +func excludeVirtualRegion(r virtualRegion) bool { + return r.filename == "[vvar]" || r.filename == "[vdso]" +} + +// applyVirtualRegions parses the process maps file. +// +// Unlike mappedRegions, these are not consistent over time. +func applyVirtualRegions(fn func(vr virtualRegion)) error { + // Open /proc/self/maps. + f, err := os.Open("/proc/self/maps") + if err != nil { + return err + } + defer f.Close() + + // Parse all entries. + r := bufio.NewReader(f) + for { + b, err := r.ReadBytes('\n') + if b != nil && len(b) > 0 { + m := mapsLine.FindSubmatch(b) + if m == nil { + // This should not happen: kernel bug? + return fmt.Errorf("badly formed line: %v", string(b)) + } + start, err := strconv.ParseUint(string(m[1]), 16, 64) + if err != nil { + return fmt.Errorf("bad start address: %v", string(b)) + } + end, err := strconv.ParseUint(string(m[2]), 16, 64) + if err != nil { + return fmt.Errorf("bad end address: %v", string(b)) + } + read := m[3][0] == 'r' + write := m[3][1] == 'w' + execute := m[3][2] == 'x' + shared := m[3][3] == 's' + offset, err := strconv.ParseUint(string(m[4]), 16, 64) + if err != nil { + return fmt.Errorf("bad offset: %v", string(b)) + } + fn(virtualRegion{ + region: region{ + virtual: uintptr(start), + length: uintptr(end - start), + }, + accessType: usermem.AccessType{ + Read: read, + Write: write, + Execute: execute, + }, + shared: shared, + offset: uintptr(offset), + filename: string(m[5]), + }) + } + if err != nil && err == io.EOF { + break + } else if err != nil { + return err + } + } + + return nil +} diff --git a/pkg/sentry/platform/kvm/virtual_map_test.go b/pkg/sentry/platform/kvm/virtual_map_test.go new file mode 100644 index 000000000..31e5b0e61 --- /dev/null +++ b/pkg/sentry/platform/kvm/virtual_map_test.go @@ -0,0 +1,78 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
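applyVirtualRegions above parses /proc/self/maps line by line with the mapsLine regular expression. A self-contained sketch that runs the same pattern against a canned maps line:

// Parse one /proc/<pid>/maps line with a mapsLine-style pattern.
package main

import (
    "fmt"
    "regexp"
    "strconv"
)

var line = regexp.MustCompile(`([0-9a-f]+)-([0-9a-f]+) ([r-][w-][x-][sp]) ([0-9a-f]+) [0-9a-f]{2}:[0-9a-f]{2,} [0-9]+\s+(.*)`)

func main() {
    sample := "7ffd4a2b0000-7ffd4a2d1000 rw-p 00000000 00:00 0          [stack]"
    m := line.FindStringSubmatch(sample)
    if m == nil {
        fmt.Println("no match")
        return
    }
    start, _ := strconv.ParseUint(m[1], 16, 64)
    end, _ := strconv.ParseUint(m[2], 16, 64)
    fmt.Printf("[%#x, %#x) %s %q\n", start, end, m[3], m[5])
}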
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +import ( + "syscall" + "testing" + + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +type checker struct { + ok bool +} + +func (c *checker) Contains(addr uintptr) func(virtualRegion) { + c.ok = false // Reset for below calls. + return func(vr virtualRegion) { + if vr.virtual <= addr && addr < vr.virtual+vr.length { + c.ok = true + } + } +} + +func TestParseMaps(t *testing.T) { + c := new(checker) + + // Simple test. + if err := applyVirtualRegions(c.Contains(0)); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // MMap a new page. + addr, _, errno := syscall.RawSyscall6( + syscall.SYS_MMAP, 0, usermem.PageSize, + syscall.PROT_READ|syscall.PROT_WRITE, + syscall.MAP_ANONYMOUS|syscall.MAP_PRIVATE, 0, 0) + if errno != 0 { + t.Fatalf("unexpected map error: %v", errno) + } + + // Re-parse maps. + if err := applyVirtualRegions(c.Contains(addr)); err != nil { + syscall.RawSyscall(syscall.SYS_MUNMAP, addr, usermem.PageSize, 0) + t.Fatalf("unexpected error: %v", err) + } + + // Assert that it now does contain the region. + if !c.ok { + syscall.RawSyscall(syscall.SYS_MUNMAP, addr, usermem.PageSize, 0) + t.Fatalf("updated map does not contain 0x%08x, expected true", addr) + } + + // Unmap the region. + syscall.RawSyscall(syscall.SYS_MUNMAP, addr, usermem.PageSize, 0) + + // Re-parse maps. + if err := applyVirtualRegions(c.Contains(addr)); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Assert that it once again does _not_ contain the region. + if c.ok { + t.Fatalf("final map does contain 0x%08x, expected false", addr) + } +} diff --git a/pkg/sentry/platform/mmap_min_addr.go b/pkg/sentry/platform/mmap_min_addr.go new file mode 100644 index 000000000..6398e5e01 --- /dev/null +++ b/pkg/sentry/platform/mmap_min_addr.go @@ -0,0 +1,60 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package platform + +import ( + "fmt" + "io/ioutil" + "strconv" + "strings" + + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +// systemMMapMinAddrSource is the source file. +const systemMMapMinAddrSource = "/proc/sys/vm/mmap_min_addr" + +// systemMMapMinAddr is the system's minimum map address. +var systemMMapMinAddr uint64 + +// SystemMMapMinAddr returns the minimum system address. +func SystemMMapMinAddr() usermem.Addr { + return usermem.Addr(systemMMapMinAddr) +} + +// MMapMinAddr is a size zero struct that implements MinUserAddress based on +// the system minimum address. It is suitable for embedding in platforms that +// rely on the system mmap, and thus require the system minimum. 
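mmap_min_addr.go reads /proc/sys/vm/mmap_min_addr once and parses it as a decimal value. A standalone sketch of that read-and-parse step, returning an error instead of panicking so it is friendlier to run by hand:

// Read and parse the system's minimum mmap address.
package main

import (
    "fmt"
    "os"
    "strconv"
    "strings"
)

func readMmapMinAddr() (uint64, error) {
    b, err := os.ReadFile("/proc/sys/vm/mmap_min_addr")
    if err != nil {
        return 0, err
    }
    return strconv.ParseUint(strings.TrimSpace(string(b)), 10, 64)
}

func main() {
    addr, err := readMmapMinAddr()
    if err != nil {
        fmt.Println("could not read mmap_min_addr:", err)
        return
    }
    fmt.Printf("mmap_min_addr = %#x\n", addr)
}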
+type MMapMinAddr struct { +} + +// MinUserAddress implements platform.MinUserAddresss. +func (*MMapMinAddr) MinUserAddress() usermem.Addr { + return SystemMMapMinAddr() +} + +func init() { + // Open the source file. + b, err := ioutil.ReadFile(systemMMapMinAddrSource) + if err != nil { + panic(fmt.Sprintf("couldn't open %s: %v", systemMMapMinAddrSource, err)) + } + + // Parse the result. + systemMMapMinAddr, err = strconv.ParseUint(strings.TrimSpace(string(b)), 10, 64) + if err != nil { + panic(fmt.Sprintf("couldn't parse %s from %s: %v", string(b), systemMMapMinAddrSource, err)) + } +} diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go new file mode 100644 index 000000000..6219dada7 --- /dev/null +++ b/pkg/sentry/platform/platform.go @@ -0,0 +1,428 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package platform provides a Platform abstraction. +// +// See Platform for more information. +package platform + +import ( + "fmt" + "io" + + "gvisor.googlesource.com/gvisor/pkg/abi/linux" + "gvisor.googlesource.com/gvisor/pkg/sentry/arch" + "gvisor.googlesource.com/gvisor/pkg/sentry/safemem" + "gvisor.googlesource.com/gvisor/pkg/sentry/usage" + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +// Platform provides abstractions for execution contexts (Context) and memory +// management (Memory, AddressSpace). +type Platform interface { + // SupportsAddressSpaceIO returns true if AddressSpaces returned by this + // Platform support AddressSpaceIO methods. + // + // The value returned by SupportsAddressSpaceIO is guaranteed to remain + // unchanged over the lifetime of the Platform. + SupportsAddressSpaceIO() bool + + // CooperativelySchedulesAddressSpace returns true if the Platform has a + // limited number of AddressSpaces, such that mm.MemoryManager.Deactivate + // should call AddressSpace.Release when there are no goroutines that + // require the mm.MemoryManager to have an active AddressSpace. + // + // The value returned by CooperativelySchedulesAddressSpace is guaranteed + // to remain unchanged over the lifetime of the Platform. + CooperativelySchedulesAddressSpace() bool + + // DetectsCPUPreemption returns true if Contexts returned by the Platform + // can reliably return ErrContextCPUPreempted. + DetectsCPUPreemption() bool + + // MapUnit returns the alignment used for optional mappings into this + // platform's AddressSpaces. Higher values indicate lower per-page + // costs for AddressSpace.MapInto. As a special case, a MapUnit of 0 + // indicates that the cost of AddressSpace.MapInto is effectively + // independent of the number of pages mapped. If MapUnit is non-zero, + // it must be a power-of-2 multiple of usermem.PageSize. + MapUnit() uint64 + + // MinUserAddress returns the minimum mappable address on this + // platform. + MinUserAddress() usermem.Addr + + // MaxUserAddress returns the maximum mappable address on this + // platform. 
+ MaxUserAddress() usermem.Addr + + // NewAddressSpace returns a new memory context for this platform. + // + // If mappingsID is not nil, the platform may assume that (1) all calls + // to NewAddressSpace with the same mappingsID represent the same + // (mutable) set of mappings, and (2) the set of mappings has not + // changed since the last time AddressSpace.Release was called on an + // AddressSpace returned by a call to NewAddressSpace with the same + // mappingsID. + // + // If a new AddressSpace cannot be created immediately, a nil + // AddressSpace is returned, along with channel that is closed when + // the caller should retry a call to NewAddressSpace. + // + // In general, this blocking behavior only occurs when + // CooperativelySchedulesAddressSpace (above) returns false. + NewAddressSpace(mappingsID interface{}) (AddressSpace, <-chan struct{}, error) + + // NewContext returns a new execution context. + NewContext() Context + + // Memory returns memory for allocations. + Memory() Memory + + // PreemptAllCPUs causes all concurrent calls to Context.Switch(), as well + // as the first following call to Context.Switch() for each Context, to + // return ErrContextCPUPreempted. + // + // PreemptAllCPUs is only supported if DetectsCPUPremption() == true. + // Platforms for which this does not hold may panic if PreemptAllCPUs is + // called. + PreemptAllCPUs() error +} + +// NoCPUPreemptionDetection implements Platform.DetectsCPUPreemption and +// dependent methods for Platforms that do not support this feature. +type NoCPUPreemptionDetection struct{} + +// DetectsCPUPreemption implements Platform.DetectsCPUPreemption. +func (NoCPUPreemptionDetection) DetectsCPUPreemption() bool { + return false +} + +// PreemptAllCPUs implements Platform.PreemptAllCPUs. +func (NoCPUPreemptionDetection) PreemptAllCPUs() error { + panic("This platform does not support CPU preemption detection") +} + +// Context represents the execution context for a single thread. +type Context interface { + // Switch resumes execution of the thread specified by the arch.Context + // in the provided address space. This call will block while the thread + // is executing. + // + // If cpu is non-negative, and it is not the number of the CPU that the + // thread executes on, Context should return ErrContextCPUPreempted. cpu + // can only be non-negative if Platform.DetectsCPUPreemption() is true; + // Contexts from Platforms for which this does not hold may ignore cpu, or + // panic if cpu is non-negative. + // + // Switch may return one of the following special errors: + // + // - nil: The Context invoked a system call. + // + // - ErrContextSignal: The Context was interrupted by a signal. The + // returned *arch.SignalInfo contains information about the signal. If + // arch.SignalInfo.Signo == SIGSEGV, the returned usermem.AccessType + // contains the access type of the triggering fault. + // + // - ErrContextInterrupt: The Context was interrupted by a call to + // Interrupt(). Switch() may return ErrContextInterrupt spuriously. In + // particular, most implementations of Interrupt() will cause the first + // following call to Switch() to return ErrContextInterrupt if there is no + // concurrent call to Switch(). + // + // - ErrContextCPUPreempted: See the definition of that error for details. + Switch(as AddressSpace, ac arch.Context, cpu int32) (*arch.SignalInfo, usermem.AccessType, error) + + // Interrupt interrupts a concurrent call to Switch(), causing it to return + // ErrContextInterrupt. 
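The Switch() contract above enumerates several distinct return conditions. A hedged sketch of how a caller might dispatch on them; the errors and the switch function here are stand-ins, not the real platform package:

// Dispatch on the kinds of results Switch() can return.
package main

import (
    "errors"
    "fmt"
)

var (
    errSignal    = errors.New("interrupted by signal")
    errInterrupt = errors.New("interrupted by Interrupt()")
    errPreempted = errors.New("interrupted by CPU preemption")
)

// switchOnce stands in for Context.Switch.
func switchOnce(step int) error {
    switch step {
    case 0:
        return errInterrupt
    case 1:
        return errPreempted
    default:
        return nil // the application made a system call
    }
}

func main() {
    for step := 0; ; step++ {
        switch err := switchOnce(step); {
        case err == nil:
            fmt.Println("handle syscall, then resume")
            return
        case errors.Is(err, errInterrupt), errors.Is(err, errPreempted):
            fmt.Println("retry:", err) // both are resumable conditions
        case errors.Is(err, errSignal):
            fmt.Println("deliver signal")
            return
        }
    }
}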
+ Interrupt() +} + +var ( + // ErrContextSignal is returned by Context.Switch() to indicate that the + // Context was interrupted by a signal. + ErrContextSignal = fmt.Errorf("interrupted by signal") + + // ErrContextInterrupt is returned by Context.Switch() to indicate that the + // Context was interrupted by a call to Context.Interrupt(). + ErrContextInterrupt = fmt.Errorf("interrupted by platform.Context.Interrupt()") + + // ErrContextCPUPreempted is returned by Context.Switch() to indicate that + // one of the following occurred: + // + // - The CPU executing the Context is not the CPU passed to + // Context.Switch(). + // + // - The CPU executing the Context may have executed another Context since + // the last time it executed this one; or the CPU has previously executed + // another Context, and has never executed this one. + // + // - Platform.PreemptAllCPUs() was called since the last return from + // Context.Switch(). + ErrContextCPUPreempted = fmt.Errorf("interrupted by CPU preemption") +) + +// SignalInterrupt is a signal reserved for use by implementations of +// Context.Interrupt(). The sentry guarantees that it will ignore delivery of +// this signal both to Contexts and to the sentry itself, under the assumption +// that they originate from races with Context.Interrupt(). +// +// NOTE: The Go runtime only guarantees that a small subset +// of signals will be always be unblocked on all threads, one of which +// is SIGCHLD. +const SignalInterrupt = linux.SIGCHLD + +// AddressSpace represents a virtual address space in which a Context can +// execute. +type AddressSpace interface { + // MapFile creates a shared mapping of offsets in fr, from the file + // with file descriptor fd, at address addr. Any existing overlapping + // mappings are silently replaced. + // + // If precommit is true, host memory should be committed to the mapping + // when MapFile returns when possible. The precommit flag is advisory + // and implementations may choose to ignore it. + // + // Preconditions: addr and fr must be page-aligned. length > 0. + // at.Any() == true. + MapFile(addr usermem.Addr, fd int, fr FileRange, at usermem.AccessType, precommit bool) error + + // Unmap unmaps the given range. + // + // Preconditions: addr is page-aligned. length > 0. + Unmap(addr usermem.Addr, length uint64) + + // Release releases this address space. After releasing, a new AddressSpace + // must be acquired via platform.NewAddressSpace(). + Release() error + + // AddressSpaceIO methods are supported iff the associated platform's + // Platform.SupportsAddressSpaceIO() == true. AddressSpaces for which this + // does not hold may panic if AddressSpaceIO methods are invoked. + AddressSpaceIO +} + +// AddressSpaceIO supports IO through the memory mappings installed in an +// AddressSpace. +// +// AddressSpaceIO implementors are responsible for ensuring that address ranges +// are application-mappable. +type AddressSpaceIO interface { + // CopyOut copies len(src) bytes from src to the memory mapped at addr. It + // returns the number of bytes copied. If the number of bytes copied is < + // len(src), it returns a non-nil error explaining why. + CopyOut(addr usermem.Addr, src []byte) (int, error) + + // CopyIn copies len(dst) bytes from the memory mapped at addr to dst. + // It returns the number of bytes copied. If the number of bytes copied is + // < len(dst), it returns a non-nil error explaining why. + CopyIn(addr usermem.Addr, dst []byte) (int, error) + + // ZeroOut sets toZero bytes to 0, starting at addr. 
It returns the number + // of bytes zeroed. If the number of bytes zeroed is < toZero, it returns a + // non-nil error explaining why. + ZeroOut(addr usermem.Addr, toZero uintptr) (uintptr, error) + + // SwapUint32 atomically sets the uint32 value at addr to new and returns + // the previous value. + // + // Preconditions: addr must be aligned to a 4-byte boundary. + SwapUint32(addr usermem.Addr, new uint32) (uint32, error) + + // CompareAndSwapUint32 atomically compares the uint32 value at addr to + // old; if they are equal, the value in memory is replaced by new. In + // either case, the previous value stored in memory is returned. + // + // Preconditions: addr must be aligned to a 4-byte boundary. + CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error) +} + +// NoAddressSpaceIO implements AddressSpaceIO methods by panicing. +type NoAddressSpaceIO struct{} + +// CopyOut implements AddressSpaceIO.CopyOut. +func (NoAddressSpaceIO) CopyOut(addr usermem.Addr, src []byte) (int, error) { + panic("This platform does not support AddressSpaceIO") +} + +// CopyIn implements AddressSpaceIO.CopyIn. +func (NoAddressSpaceIO) CopyIn(addr usermem.Addr, dst []byte) (int, error) { + panic("This platform does not support AddressSpaceIO") +} + +// ZeroOut implements AddressSpaceIO.ZeroOut. +func (NoAddressSpaceIO) ZeroOut(addr usermem.Addr, toZero uintptr) (uintptr, error) { + panic("This platform does not support AddressSpaceIO") +} + +// SwapUint32 implements AddressSpaceIO.SwapUint32. +func (NoAddressSpaceIO) SwapUint32(addr usermem.Addr, new uint32) (uint32, error) { + panic("This platform does not support AddressSpaceIO") +} + +// CompareAndSwapUint32 implements AddressSpaceIO.CompareAndSwapUint32. +func (NoAddressSpaceIO) CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error) { + panic("This platform does not support AddressSpaceIO") +} + +// SegmentationFault is an error returned by AddressSpaceIO methods when IO +// fails due to access of an unmapped page, or a mapped page with insufficient +// permissions. +type SegmentationFault struct { + // Addr is the address at which the fault occurred. + Addr usermem.Addr +} + +// Error implements error.Error. +func (f SegmentationFault) Error() string { + return fmt.Sprintf("segmentation fault at %#x", f.Addr) +} + +// File represents a host file that may be mapped into an AddressSpace. +type File interface { + // MapInto maps fr into as, starting at addr, for accesses of type at. + // + // If precommit is true, the platform should eagerly commit resources (e.g. + // physical memory) to the mapping. The precommit flag is advisory and + // implementations may choose to ignore it. + // + // Note that there is no File.Unmap; clients should use as.Unmap directly. + // + // Preconditions: fr.Start and fr.End must be page-aligned. + // fr.Length() > 0. at.Any() == true. Implementors may define + // additional requirements. + MapInto(as AddressSpace, addr usermem.Addr, fr FileRange, at usermem.AccessType, precommit bool) error + + // MapInternal returns a mapping of the given file offsets in the invoking + // process' address space for reading and writing. The lifetime of the + // returned mapping is implementation-defined. + // + // Note that fr.Start and fr.End need not be page-aligned. + // + // Preconditions: fr.Length() > 0. Implementors may define additional + // requirements. 
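Because SegmentationFault is returned by value, callers can detect partial copies with a plain type assertion. A hedged sketch (as, addr, and buf come from the caller; reportFault is a hypothetical helper):

func copyOutChecked(as AddressSpace, addr usermem.Addr, buf []byte) (int, error) {
	n, err := as.CopyOut(addr, buf)
	if sf, ok := err.(SegmentationFault); ok {
		// Only n bytes were written before reaching the unmapped (or
		// insufficiently mapped) page at sf.Addr.
		reportFault(sf.Addr)
	}
	return n, err
}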
+ MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error) + + // IncRef signals that a region in the file is actively referenced through a + // memory map. Implementors must ensure that the contents of a referenced + // region remain consistent. Specifically, mappings returned by MapInternal + // must refer to the same underlying contents. If the implementor also + // implements the Memory interface, the file range must not be reused in a + // different allocation while it has active references. + // + // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() > 0. + IncRef(fr FileRange) + + // DecRef reduces the frame ref count on the range specified by fr. + // + // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() > + // 0. DecRef()s on a region must match earlier IncRef()s. + DecRef(fr FileRange) +} + +// FileRange represents a range of uint64 offsets into a File. +// +// type FileRange <generated using go_generics> + +// String implements fmt.Stringer.String. +func (fr FileRange) String() string { + return fmt.Sprintf("[%#x, %#x)", fr.Start, fr.End) +} + +// Memory represents an allocatable File that may be mapped into any +// AddressSpace associated with the same Platform. +type Memory interface { + // Memory implements File methods with the following properties: + // + // - Pages mapped by MapInto must be allocated, and must be unmapped from + // all AddressSpaces before they are freed. + // + // - Pages mapped by MapInternal must be allocated. Returned mappings are + // guaranteed to be valid until the mapped pages are freed. + File + + // Allocate returns a range of pages of the given length, owned by the + // caller and with the given accounting kind. Allocated memory initially has + // a single reference and will automatically be freed when no references to + // them remain. See File.IncRef and File.DecRef. + // + // Preconditions: length must be page-aligned and non-zero. + Allocate(length uint64, kind usage.MemoryKind) (FileRange, error) + + // Decommit releases resources associated with maintaining the contents of + // the given frames. If Decommit succeeds, future accesses of the + // decommitted frames will read zeroes. + // + // Preconditions: fr.Length() > 0. + Decommit(fr FileRange) error + + // UpdateUsage updates the memory usage statistics. This must be called + // before the relevant memory statistics in usage.MemoryAccounting can + // be considered accurate. + UpdateUsage() error + + // TotalUsage returns an aggregate usage for all memory statistics + // except Mapped (which is external to the Memory implementation). This + // is generally much cheaper than UpdateUsage, but will not provide a + // fine-grained breakdown. + TotalUsage() (uint64, error) + + // TotalSize returns the current maximum size of the Memory in bytes. The + // value returned by TotalSize is permitted to change. + TotalSize() uint64 + + // Destroy releases all resources associated with the Memory. + // + // Preconditions: There are no remaining uses of any of the freed memory's + // frames. + // + // Postconditions: None of the Memory's methods may be called after Destroy. + Destroy() + + // SaveTo saves the memory state to the given stream, which will + // generally be a statefile. + SaveTo(w io.Writer) error + + // LoadFrom loads the memory state from the given stream, which will + // generally be a statefile. 
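Putting the File and Memory pieces together, allocation followed by mapping commonly looks like the sketch below. It is illustrative only: mem, as, and the page-aligned addr are assumed to be supplied by the caller, and usage.Anonymous is used merely as a plausible accounting kind.

func allocateAndMap(mem Memory, as AddressSpace, addr usermem.Addr) (FileRange, error) {
	fr, err := mem.Allocate(usermem.PageSize, usage.Anonymous)
	if err != nil {
		return FileRange{}, err
	}
	if err := mem.MapInto(as, addr, fr, usermem.ReadWrite, false /* precommit */); err != nil {
		// Drop the allocation's reference so the pages are freed.
		mem.DecRef(fr)
		return FileRange{}, err
	}
	return fr, nil
}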
+ LoadFrom(r io.Reader) error +} + +// AllocateAndFill allocates memory of the given kind from mem and fills it by +// calling r.ReadToBlocks() repeatedly until either length bytes are read or a +// non-nil error is returned. It returns the memory filled by r, truncated down +// to the nearest page. If this is shorter than length bytes due to an error +// returned by r.ReadToBlocks(), it returns that error. +// +// Preconditions: length > 0. length must be page-aligned. +func AllocateAndFill(mem Memory, length uint64, kind usage.MemoryKind, r safemem.Reader) (FileRange, error) { + fr, err := mem.Allocate(length, kind) + if err != nil { + return FileRange{}, err + } + dsts, err := mem.MapInternal(fr, usermem.Write) + if err != nil { + mem.DecRef(fr) + return FileRange{}, err + } + n, err := safemem.ReadFullToBlocks(r, dsts) + un := uint64(usermem.Addr(n).RoundDown()) + if un < length { + // Free unused memory and update fr to contain only the memory that is + // still allocated. + mem.DecRef(FileRange{fr.Start + un, fr.End}) + fr.End = fr.Start + un + } + return fr, err +} diff --git a/pkg/sentry/platform/procid/BUILD b/pkg/sentry/platform/procid/BUILD new file mode 100644 index 000000000..5db4f6261 --- /dev/null +++ b/pkg/sentry/platform/procid/BUILD @@ -0,0 +1,32 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "procid", + srcs = [ + "procid.go", + "procid_amd64.s", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/platform/procid", + visibility = ["//pkg/sentry:internal"], +) + +go_test( + name = "procid_test", + size = "small", + srcs = [ + "procid_test.go", + ], + embed = [":procid"], +) + +go_test( + name = "procid_net_test", + size = "small", + srcs = [ + "procid_net_test.go", + "procid_test.go", + ], + embed = [":procid"], +) diff --git a/pkg/sentry/platform/procid/procid.go b/pkg/sentry/platform/procid/procid.go new file mode 100644 index 000000000..5f861908f --- /dev/null +++ b/pkg/sentry/platform/procid/procid.go @@ -0,0 +1,21 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package procid provides a way to get the current system thread identifier. +package procid + +// Current returns the current system thread identifier. +// +// Precondition: This should only be called with the runtime OS thread locked. +func Current() uint64 diff --git a/pkg/sentry/platform/procid/procid_amd64.s b/pkg/sentry/platform/procid/procid_amd64.s new file mode 100644 index 000000000..ead4e3d91 --- /dev/null +++ b/pkg/sentry/platform/procid/procid_amd64.s @@ -0,0 +1,30 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
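Referring back to AllocateAndFill in platform.go above: a caller holding a plain byte slice can wrap it as a safemem.Reader before calling it. This sketch assumes that safemem exposes BlockSeqReader, BlockSeqOf, and BlockFromSafeSlice helpers (not shown in this diff) and that data spans a whole number of pages, per AllocateAndFill's precondition.

func fillFromBytes(mem Memory, data []byte) (FileRange, error) {
	// data is assumed to be page-aligned in length.
	r := safemem.BlockSeqReader{safemem.BlockSeqOf(safemem.BlockFromSafeSlice(data))}
	return AllocateAndFill(mem, uint64(len(data)), usage.Anonymous, r)
}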
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 +// +build go1.8 +// +build !go1.11 + +#include "textflag.h" + +TEXT ·Current(SB),NOSPLIT,$0-8 + // The offset specified here is the m_procid offset for Go1.8+. + // Changes to this offset should be caught by the tests, and major + // version changes require an explicit tag change above. + MOVQ TLS, AX + MOVQ 0(AX)(TLS*1), AX + MOVQ 48(AX), AX // g_m (may change in future versions) + MOVQ 72(AX), AX // m_procid (may change in future versions) + MOVQ AX, ret+0(FP) + RET diff --git a/pkg/sentry/platform/procid/procid_net_test.go b/pkg/sentry/platform/procid/procid_net_test.go new file mode 100644 index 000000000..2d1605a08 --- /dev/null +++ b/pkg/sentry/platform/procid/procid_net_test.go @@ -0,0 +1,21 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package procid + +// This file is just to force the inclusion of the "net" package, which will +// make the test binary a cgo one. +import ( + _ "net" +) diff --git a/pkg/sentry/platform/procid/procid_test.go b/pkg/sentry/platform/procid/procid_test.go new file mode 100644 index 000000000..5e44da36f --- /dev/null +++ b/pkg/sentry/platform/procid/procid_test.go @@ -0,0 +1,85 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package procid + +import ( + "os" + "runtime" + "sync" + "syscall" + "testing" +) + +// runOnMain is used to send functions to run on the main (initial) thread. +var runOnMain = make(chan func(), 10) + +func checkProcid(t *testing.T, start *sync.WaitGroup, done *sync.WaitGroup) { + defer done.Done() + + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + start.Done() + start.Wait() + + procID := Current() + tid := syscall.Gettid() + + if procID != uint64(tid) { + t.Logf("Bad procid: expected %v, got %v", tid, procID) + t.Fail() + } +} + +func TestProcidInitialized(t *testing.T) { + var start sync.WaitGroup + var done sync.WaitGroup + + count := 100 + start.Add(count + 1) + done.Add(count + 1) + + // Run the check on the main thread. 
+ // + // When cgo is not included, the only case when procid isn't initialized + // is in the main (initial) thread, so we have to test this case + // specifically. + runOnMain <- func() { + checkProcid(t, &start, &done) + } + + // Run the check on a number of different threads. + for i := 0; i < count; i++ { + go checkProcid(t, &start, &done) + } + + done.Wait() +} + +func TestMain(m *testing.M) { + // Make sure we remain at the main (initial) thread. + runtime.LockOSThread() + + // Start running tests in a different goroutine. + go func() { + os.Exit(m.Run()) + }() + + // Execute any functions that have been sent for execution in the main + // thread. + for f := range runOnMain { + f() + } +} diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD new file mode 100644 index 000000000..16b0b3c69 --- /dev/null +++ b/pkg/sentry/platform/ptrace/BUILD @@ -0,0 +1,31 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "ptrace", + srcs = [ + "ptrace.go", + "ptrace_unsafe.go", + "stub_amd64.s", + "stub_unsafe.go", + "subprocess.go", + "subprocess_amd64.go", + "subprocess_linux.go", + "subprocess_linux_amd64_unsafe.go", + "subprocess_unsafe.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ptrace", + visibility = ["//:sandbox"], + deps = [ + "//pkg/abi/linux", + "//pkg/sentry/arch", + "//pkg/sentry/platform", + "//pkg/sentry/platform/filemem", + "//pkg/sentry/platform/interrupt", + "//pkg/sentry/platform/procid", + "//pkg/sentry/platform/safecopy", + "//pkg/sentry/usermem", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go new file mode 100644 index 000000000..05f8b1d05 --- /dev/null +++ b/pkg/sentry/platform/ptrace/ptrace.go @@ -0,0 +1,242 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ptrace provides a ptrace-based implementation of the platform +// interface. This is useful for development and testing purposes primarily, +// and runs on stock kernels without special permissions. +// +// In a nutshell, it works as follows: +// +// The creation of a new address space creates a new child processes with a +// single thread which is traced by a single goroutine. +// +// A context is just a collection of temporary variables. Calling Switch on a +// context does the following: +// +// Locks the runtime thread. +// +// Looks up a traced subprocess thread for the current runtime thread. If +// none exists, the dedicated goroutine is asked to create a new stopped +// thread in the subprocess. This stopped subprocess thread is then traced +// by the current thread and this information is stored for subsequent +// switches. +// +// The context is then bound with information about the subprocess thread +// so that the context may be appropriately interrupted via a signal. 
+// +// The requested operation is performed in the traced subprocess thread +// (e.g. set registers, execute, return). +// +// FIXME: This package is currently sloppy with cleanup. +// +// Lock order: +// +// subprocess.mu +// context.mu +package ptrace + +import ( + "sync" + + "gvisor.googlesource.com/gvisor/pkg/abi/linux" + "gvisor.googlesource.com/gvisor/pkg/sentry/arch" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/filemem" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/interrupt" + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +var ( + // stubStart is the link address for our stub, and determines the + // maximum user address. This is valid only after a call to stubInit. + // + // We attempt to link the stub here, and adjust downward as needed. + stubStart uintptr = 0x7fffffff0000 + + // stubEnd is the first byte past the end of the stub, as with + // stubStart this is valid only after a call to stubInit. + stubEnd uintptr + + // stubInitialized controls one-time stub initialization. + stubInitialized sync.Once +) + +type context struct { + // signalInfo is the signal info, if and when a signal is received. + signalInfo arch.SignalInfo + + // interrupt is the interrupt context. + interrupt interrupt.Forwarder + + // mu protects the following fields. + mu sync.Mutex + + // If lastFaultSP is non-nil, the last context switch was due to a fault + // received while executing lastFaultSP. Only context.Switch may set + // lastFaultSP to a non-nil value. + lastFaultSP *subprocess + + // lastFaultAddr is the last faulting address; this is only meaningful if + // lastFaultSP is non-nil. + lastFaultAddr usermem.Addr + + // lastFaultIP is the address of the last faulting instruction; + // this is also only meaningful if lastFaultSP is non-nil. + lastFaultIP usermem.Addr +} + +// Switch runs the provided context in the given address space. +func (c *context) Switch(as platform.AddressSpace, ac arch.Context, cpu int32) (*arch.SignalInfo, usermem.AccessType, error) { + s := as.(*subprocess) + isSyscall := s.switchToApp(c, ac) + + var faultSP *subprocess + var faultAddr usermem.Addr + var faultIP usermem.Addr + if !isSyscall && linux.Signal(c.signalInfo.Signo) == linux.SIGSEGV { + faultSP = s + faultAddr = usermem.Addr(c.signalInfo.Addr()) + faultIP = usermem.Addr(ac.IP()) + } + + // Update the context to reflect the outcome of this context switch. + c.mu.Lock() + lastFaultSP := c.lastFaultSP + lastFaultAddr := c.lastFaultAddr + lastFaultIP := c.lastFaultIP + // At this point, c may not yet be in s.contexts, so c.lastFaultSP won't be + // updated by s.Unmap(). This is fine; we only need to synchronize with + // calls to s.Unmap() that occur after the handling of this fault. + c.lastFaultSP = faultSP + c.lastFaultAddr = faultAddr + c.lastFaultIP = faultIP + c.mu.Unlock() + + // Update subprocesses to reflect the outcome of this context switch. + if lastFaultSP != faultSP { + if lastFaultSP != nil { + lastFaultSP.mu.Lock() + delete(lastFaultSP.contexts, c) + lastFaultSP.mu.Unlock() + } + if faultSP != nil { + faultSP.mu.Lock() + faultSP.contexts[c] = struct{}{} + faultSP.mu.Unlock() + } + } + + if isSyscall { + return nil, usermem.NoAccess, nil + } + if faultSP == nil { + // Non-fault signal. + return &c.signalInfo, usermem.NoAccess, platform.ErrContextSignal + } + + // Got a page fault. Ideally, we'd get real fault type here, but ptrace + // doesn't expose this information. 
Instead, we use a simple heuristic: + // + // It was an instruction fault iff the faulting addr == instruction + // pointer. + // + // It was a write fault if the fault is immediately repeated. + at := usermem.Read + if faultAddr == faultIP { + at.Execute = true + } + if lastFaultSP == faultSP && + lastFaultAddr == faultAddr && + lastFaultIP == faultIP { + at.Write = true + } + return &c.signalInfo, at, platform.ErrContextSignal +} + +// Interrupt interrupts the running guest application associated with this context. +func (c *context) Interrupt() { + c.interrupt.NotifyInterrupt() +} + +// PTrace represents a collection of ptrace subprocesses. +type PTrace struct { + platform.MMapMinAddr + platform.NoCPUPreemptionDetection + *filemem.FileMem +} + +// New returns a new ptrace-based implementation of the platform interface. +func New() (*PTrace, error) { + stubInitialized.Do(func() { + // Initialize the stub. + stubInit() + + // Create the master process for the global pool. This must be + // done before initializing any other processes. + master, err := newSubprocess(createStub) + if err != nil { + // Should never happen. + panic("unable to initialize ptrace master: " + err.Error()) + } + + // Set the master on the globalPool. + globalPool.master = master + }) + + fm, err := filemem.New("ptrace-memory") + if err != nil { + return nil, err + } + + return &PTrace{FileMem: fm}, nil +} + +// SupportsAddressSpaceIO implements platform.Platform.SupportsAddressSpaceIO. +func (*PTrace) SupportsAddressSpaceIO() bool { + return false +} + +// CooperativelySchedulesAddressSpace implements platform.Platform.CooperativelySchedulesAddressSpace. +func (*PTrace) CooperativelySchedulesAddressSpace() bool { + return false +} + +// MapUnit implements platform.Platform.MapUnit. +func (*PTrace) MapUnit() uint64 { + // The host kernel manages page tables and arbitrary-sized mappings + // have effectively the same cost. + return 0 +} + +// MaxUserAddress returns the first address that may not be used by user +// applications. +func (*PTrace) MaxUserAddress() usermem.Addr { + return usermem.Addr(stubStart) +} + +// NewAddressSpace returns a new subprocess. +func (p *PTrace) NewAddressSpace(_ interface{}) (platform.AddressSpace, <-chan struct{}, error) { + as, err := newSubprocess(globalPool.master.createStub) + return as, nil, err +} + +// NewContext returns an interruptible context. +func (*PTrace) NewContext() platform.Context { + return &context{} +} + +// Memory returns the platform memory used to do allocations. +func (p *PTrace) Memory() platform.Memory { + return p.FileMem +} diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go new file mode 100644 index 000000000..b55b2795a --- /dev/null +++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go @@ -0,0 +1,166 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
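Taken together, the exported pieces of ptrace.go above compose roughly as follows. This is a hedged, illustrative sketch only: ac is assumed to come from the sentry's task machinery, error handling is reduced to panics, and the platform, arch, and ptrace imports are implied.

func runOnPtrace(ac arch.Context) {
	p, err := ptrace.New()
	if err != nil {
		panic(err)
	}
	// The ptrace platform never asks the caller to retry, so the returned
	// channel is ignored.
	as, _, err := p.NewAddressSpace(nil)
	if err != nil {
		panic(err)
	}
	defer as.Release()

	ctx := p.NewContext()
	// cpu == -1: this platform does not detect CPU preemption.
	si, at, err := ctx.Switch(as, ac, -1)
	_, _, _ = si, at, err
}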
+ +package ptrace + +import ( + "syscall" + "unsafe" + + "gvisor.googlesource.com/gvisor/pkg/sentry/arch" + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +// GETREGSET/SETREGSET register set types. +// +// See include/uapi/linux/elf.h. +const ( + // _NT_PRFPREG is for x86 floating-point state without using xsave. + _NT_PRFPREG = 0x2 + + // _NT_X86_XSTATE is for x86 extended state using xsave. + _NT_X86_XSTATE = 0x202 +) + +// fpRegSet returns the GETREGSET/SETREGSET register set type to be used. +func fpRegSet(useXsave bool) uintptr { + if useXsave { + return _NT_X86_XSTATE + } + return _NT_PRFPREG +} + +// getRegs sets the regular register set. +func (t *thread) getRegs(regs *syscall.PtraceRegs) error { + _, _, errno := syscall.RawSyscall6( + syscall.SYS_PTRACE, + syscall.PTRACE_GETREGS, + uintptr(t.tid), + 0, + uintptr(unsafe.Pointer(regs)), + 0, 0) + if errno != 0 { + return errno + } + return nil +} + +// setRegs sets the regular register set. +func (t *thread) setRegs(regs *syscall.PtraceRegs) error { + _, _, errno := syscall.RawSyscall6( + syscall.SYS_PTRACE, + syscall.PTRACE_SETREGS, + uintptr(t.tid), + 0, + uintptr(unsafe.Pointer(regs)), + 0, 0) + if errno != 0 { + return errno + } + return nil +} + +// getFPRegs gets the floating-point data via the GETREGSET ptrace syscall. +func (t *thread) getFPRegs(fpState *arch.FloatingPointData, fpLen uint64, useXsave bool) error { + iovec := syscall.Iovec{ + Base: (*byte)(fpState), + Len: fpLen, + } + _, _, errno := syscall.RawSyscall6( + syscall.SYS_PTRACE, + syscall.PTRACE_GETREGSET, + uintptr(t.tid), + fpRegSet(useXsave), + uintptr(unsafe.Pointer(&iovec)), + 0, 0) + if errno != 0 { + return errno + } + return nil +} + +// setFPRegs sets the floating-point data via the SETREGSET ptrace syscall. +func (t *thread) setFPRegs(fpState *arch.FloatingPointData, fpLen uint64, useXsave bool) error { + iovec := syscall.Iovec{ + Base: (*byte)(fpState), + Len: fpLen, + } + _, _, errno := syscall.RawSyscall6( + syscall.SYS_PTRACE, + syscall.PTRACE_SETREGSET, + uintptr(t.tid), + fpRegSet(useXsave), + uintptr(unsafe.Pointer(&iovec)), + 0, 0) + if errno != 0 { + return errno + } + return nil +} + +// getSignalInfo retrieves information about the signal that caused the stop. +func (t *thread) getSignalInfo(si *arch.SignalInfo) error { + _, _, errno := syscall.RawSyscall6( + syscall.SYS_PTRACE, + syscall.PTRACE_GETSIGINFO, + uintptr(t.tid), + 0, + uintptr(unsafe.Pointer(si)), + 0, 0) + if errno != 0 { + return errno + } + return nil +} + +// clone creates a new thread from this one. +// +// The returned thread will be stopped and available for any system thread to +// call attach on it. +// +// Precondition: the OS thread must be locked and own t. +func (t *thread) clone(initRegs *syscall.PtraceRegs) (*thread, error) { + r, ok := usermem.Addr(initRegs.Rsp).RoundUp() + if !ok { + return nil, syscall.EINVAL + } + rval, err := t.syscallIgnoreInterrupt( + initRegs, + syscall.SYS_CLONE, + arch.SyscallArgument{Value: uintptr( + syscall.CLONE_FILES | + syscall.CLONE_FS | + syscall.CLONE_SIGHAND | + syscall.CLONE_THREAD | + syscall.CLONE_PTRACE | + syscall.CLONE_VM)}, + // The stack pointer is just made up, but we have it be + // something sensible so the kernel doesn't think we're + // up to no good. Which we are. + arch.SyscallArgument{Value: uintptr(r)}, + arch.SyscallArgument{}, + arch.SyscallArgument{}, + // We use these registers initially, but really they + // could be anything. We're going to stop immediately. 
+ arch.SyscallArgument{Value: uintptr(unsafe.Pointer(initRegs))}) + if err != nil { + return nil, err + } + + return &thread{ + tgid: t.tgid, + tid: int32(rval), + cpu: ^uint32(0), + }, nil +} diff --git a/pkg/sentry/platform/ptrace/stub_amd64.s b/pkg/sentry/platform/ptrace/stub_amd64.s new file mode 100644 index 000000000..9bf87b6f6 --- /dev/null +++ b/pkg/sentry/platform/ptrace/stub_amd64.s @@ -0,0 +1,114 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "funcdata.h" +#include "textflag.h" + +#define SYS_GETPID 39 +#define SYS_EXIT 60 +#define SYS_KILL 62 +#define SYS_GETPPID 110 +#define SYS_PRCTL 157 + +#define SIGKILL 9 +#define SIGSTOP 19 + +#define PR_SET_PDEATHSIG 1 + +// stub bootstraps the child and sends itself SIGSTOP to wait for attach. +// +// R15 contains the expected PPID. R15 is used instead of a more typical DI +// since syscalls will clobber DI and createStub wants to pass a new PPID to +// grandchildren. +// +// This should not be used outside the context of a new ptrace child (as the +// function is otherwise a bunch of nonsense). +TEXT ·stub(SB),NOSPLIT,$0 +begin: + // N.B. This loop only executes in the context of a single-threaded + // fork child. + + MOVQ $SYS_PRCTL, AX + MOVQ $PR_SET_PDEATHSIG, DI + MOVQ $SIGKILL, SI + SYSCALL + + CMPQ AX, $0 + JNE error + + // If the parent already died before we called PR_SET_DEATHSIG then + // we'll have an unexpected PPID. + MOVQ $SYS_GETPPID, AX + SYSCALL + + CMPQ AX, $0 + JL error + + CMPQ AX, R15 + JNE parent_dead + + MOVQ $SYS_GETPID, AX + SYSCALL + + CMPQ AX, $0 + JL error + + // SIGSTOP to wait for attach. + // + // The SYSCALL instruction will be used for future syscall injection by + // thread.syscall. + MOVQ AX, DI + MOVQ $SYS_KILL, AX + MOVQ $SIGSTOP, SI + SYSCALL + + // The tracer may "detach" and/or allow code execution here in three cases: + // + // 1. New (traced) stub threads are explicitly detached by the + // goroutine in newSubprocess. However, they are detached while in + // group-stop, so they do not execute code here. + // + // 2. If a tracer thread exits, it implicitly detaches from the stub, + // potentially allowing code execution here. However, the Go runtime + // never exits individual threads, so this case never occurs. + // + // 3. subprocess.createStub clones a new stub process that is untraced, + // thus executing this code. We setup the PDEATHSIG before SIGSTOPing + // ourselves for attach by the tracer. + // + // R15 has been updated with the expected PPID. + JMP begin + +error: + // Exit with -errno. + MOVQ AX, DI + NEGQ DI + MOVQ $SYS_EXIT, AX + SYSCALL + HLT + +parent_dead: + MOVQ $SYS_EXIT, AX + MOVQ $1, DI + SYSCALL + HLT + +// stubCall calls the stub function at the given address with the given PPID. +// +// This is a distinct function because stub, above, may be mapped at any +// arbitrary location, and stub has a specific binary API (see above). 
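For readers who do not read amd64 assembly, the stub's bootstrap loop above corresponds roughly to the Go-flavored sketch below. It is illustrative only: the real stub must avoid the Go runtime entirely (no stack growth, no scheduler), which is why it is written in assembly, and it exits with -errno rather than a fixed status. expectedPPID stands in for the value carried in R15, and the os and golang.org/x/sys/unix imports are assumed.

func stubSketch(expectedPPID int) {
	for {
		// Ask the kernel to SIGKILL us if our parent (the tracer's
		// process) dies.
		if err := unix.Prctl(unix.PR_SET_PDEATHSIG, uintptr(unix.SIGKILL), 0, 0, 0); err != nil {
			os.Exit(1)
		}
		// If the parent died before PDEATHSIG was installed, getppid
		// reports someone else; give up.
		if unix.Getppid() != expectedPPID {
			os.Exit(1)
		}
		// Stop so the tracer can attach. Execution resumes here only if
		// a tracer later detaches and lets the stub run again, with a
		// new expected PPID.
		unix.Kill(unix.Getpid(), unix.SIGSTOP)
	}
}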
+TEXT ·stubCall(SB),NOSPLIT,$0-16 + MOVQ addr+0(FP), AX + MOVQ pid+8(FP), R15 + JMP AX diff --git a/pkg/sentry/platform/ptrace/stub_unsafe.go b/pkg/sentry/platform/ptrace/stub_unsafe.go new file mode 100644 index 000000000..c868a2d68 --- /dev/null +++ b/pkg/sentry/platform/ptrace/stub_unsafe.go @@ -0,0 +1,98 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ptrace + +import ( + "reflect" + "syscall" + "unsafe" + + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/safecopy" + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +// stub is defined in arch-specific assembly. +func stub() + +// stubCall calls the stub at the given address with the given pid. +func stubCall(addr, pid uintptr) + +// unsafeSlice returns a slice for the given address and length. +func unsafeSlice(addr uintptr, length int) (slice []byte) { + sh := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) + sh.Data = addr + sh.Len = length + sh.Cap = length + return +} + +// stubInit initializes the stub. +func stubInit() { + // Grab the existing stub. + stubBegin := reflect.ValueOf(stub).Pointer() + stubLen := int(safecopy.FindEndAddress(stubBegin) - stubBegin) + stubSlice := unsafeSlice(stubBegin, stubLen) + mapLen := uintptr(stubLen) + if offset := mapLen % usermem.PageSize; offset != 0 { + mapLen += usermem.PageSize - offset + } + + for stubStart > 0 { + // Map the target address for the stub. + // + // We don't use FIXED here because we don't want to unmap + // something that may have been there already. We just walk + // down the address space until we find a place where the stub + // can be placed. + addr, _, errno := syscall.RawSyscall6( + syscall.SYS_MMAP, + stubStart, + mapLen, + syscall.PROT_WRITE|syscall.PROT_READ, + syscall.MAP_PRIVATE|syscall.MAP_ANONYMOUS, + 0 /* fd */, 0 /* offset */) + if addr != stubStart || errno != 0 { + if addr != 0 { + // Unmap the region we've mapped accidentally. + syscall.RawSyscall(syscall.SYS_MUNMAP, addr, mapLen, 0) + } + + // Attempt to begin at a lower address. + stubStart -= uintptr(usermem.PageSize) + continue + } + + // Copy the stub to the address. + targetSlice := unsafeSlice(addr, stubLen) + copy(targetSlice, stubSlice) + + // Make the stub executable. + if _, _, errno := syscall.RawSyscall( + syscall.SYS_MPROTECT, + stubStart, + mapLen, + syscall.PROT_EXEC|syscall.PROT_READ); errno != 0 { + panic("mprotect failed: " + errno.Error()) + } + + // Set the end. + stubEnd = stubStart + mapLen + return + } + + // This will happen only if we exhaust the entire address + // space, and it will take a long, long time. + panic("failed to map stub") +} diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go new file mode 100644 index 000000000..0d6a38f15 --- /dev/null +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -0,0 +1,559 @@ +// Copyright 2018 Google Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ptrace + +import ( + "fmt" + "os" + "runtime" + "sync" + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/sentry/arch" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/procid" + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +// globalPool exists to solve two distinct problems: +// +// 1) Subprocesses can't always be killed properly (see Release). +// +// 2) Any seccomp filters that have been installed will apply to subprocesses +// created here. Therefore we use the intermediary (master), which is created +// on initialization of the platform. +var globalPool struct { + mu sync.Mutex + master *subprocess + available []*subprocess +} + +// thread is a traced thread; it is a thread identifier. +// +// This is a convenience type for defining ptrace operations. +type thread struct { + tgid int32 + tid int32 + cpu uint32 +} + +// threadPool is a collection of threads. +type threadPool struct { + // mu protects below. + mu sync.Mutex + + // threads is the collection of threads. + // + // This map is indexed by system TID (the calling thread); which will + // be the tracer for the given *thread, and therefore capable of using + // relevant ptrace calls. + threads map[int32]*thread +} + +// lookupOrCreate looks up a given thread or creates one. +// +// newThread will generally be subprocess.newThread. +// +// Precondition: the runtime OS thread must be locked. +func (tp *threadPool) lookupOrCreate(currentTID int32, newThread func() *thread) *thread { + tp.mu.Lock() + t, ok := tp.threads[currentTID] + if !ok { + // Before creating a new thread, see if we can find a thread + // whose system tid has disappeared. + // + // TODO: Other parts of this package depend on + // threads never exiting. + for origTID, t := range tp.threads { + // Signal zero is an easy existence check. + if err := syscall.Tgkill(syscall.Getpid(), int(origTID), 0); err != nil { + // This thread has been abandoned; reuse it. + delete(tp.threads, origTID) + tp.threads[currentTID] = t + tp.mu.Unlock() + return t + } + } + + // Create a new thread. + t = newThread() + tp.threads[currentTID] = t + } + tp.mu.Unlock() + return t +} + +// subprocess is a collection of threads being traced. +type subprocess struct { + platform.NoAddressSpaceIO + + // initRegs are the initial registers for the first thread. + // + // These are used for the register set for system calls. + initRegs syscall.PtraceRegs + + // requests is used to signal creation of new threads. + requests chan chan *thread + + // sysemuThreads are reserved for emulation. + sysemuThreads threadPool + + // syscallThreads are reserved for syscalls (except clone, which is + // handled in the dedicated goroutine corresponding to requests above). + syscallThreads threadPool + + // mu protects the following fields. + mu sync.Mutex + + // contexts is the set of contexts for which it's possible that + // context.lastFaultSP == this subprocess. 
+ contexts map[*context]struct{} +} + +// newSubprocess returns a useable subprocess. +// +// This will either be a newly created subprocess, or one from the global pool. +// The create function will be called in the latter case, which is guaranteed +// to happen with the runtime thread locked. +func newSubprocess(create func() (*thread, error)) (*subprocess, error) { + // See Release. + globalPool.mu.Lock() + if len(globalPool.available) > 0 { + sp := globalPool.available[len(globalPool.available)-1] + globalPool.available = globalPool.available[:len(globalPool.available)-1] + globalPool.mu.Unlock() + return sp, nil + } + globalPool.mu.Unlock() + + // The following goroutine is responsible for creating the first traced + // thread, and responding to requests to make additional threads in the + // traced process. The process will be killed and reaped when the + // request channel is closed, which happens in Release below. + var initRegs syscall.PtraceRegs + errChan := make(chan error) + requests := make(chan chan *thread) + go func() { // S/R-SAFE: Platform-related. + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + // Initialize the first thread. + firstThread, err := create() + if err != nil { + errChan <- err + return + } + + // Grab registers. + // + // Note that we adjust the current register RIP value to be + // just before the current system call executed. This depends + // on the definition of the stub itself. + if err := firstThread.getRegs(&initRegs); err != nil { + panic(fmt.Sprintf("ptrace get regs failed: %v", err)) + } + initRegs.Rip -= initRegsRipAdjustment + + // Ready to handle requests. + errChan <- nil + + // Wait for requests to create threads. + for r := range requests { + t, err := firstThread.clone(&initRegs) + if err != nil { + // Should not happen: not recoverable. + panic(fmt.Sprintf("error initializing first thread: %v", err)) + } + + // Since the new thread was created with + // clone(CLONE_PTRACE), it will begin execution with + // SIGSTOP pending and with this thread as its tracer. + // (Hopefully nobody tgkilled it with a signal < + // SIGSTOP before the SIGSTOP was delivered, in which + // case that signal would be delivered before SIGSTOP.) + if sig := t.wait(); sig != syscall.SIGSTOP { + panic(fmt.Sprintf("error waiting for new clone: expected SIGSTOP, got %v", sig)) + } + + // Detach the thread without suppressing the SIGSTOP, + // causing it to enter group-stop. + if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_DETACH, uintptr(t.tid), 0, uintptr(syscall.SIGSTOP), 0, 0); errno != 0 { + panic(fmt.Sprintf("can't detach new clone: %v", errno)) + } + + // Return the thread. + r <- t + } + + // Requests should never be closed. + panic("unreachable") + }() + + // Wait until error or readiness. + if err := <-errChan; err != nil { + return nil, err + } + + // Ready. + sp := &subprocess{ + initRegs: initRegs, + requests: requests, + sysemuThreads: threadPool{ + threads: make(map[int32]*thread), + }, + syscallThreads: threadPool{ + threads: make(map[int32]*thread), + }, + contexts: make(map[*context]struct{}), + } + + sp.unmap() + return sp, nil +} + +// unmap unmaps non-stub regions of the process. +// +// This will panic on failure (which should never happen). +func (s *subprocess) unmap() { + s.Unmap(0, uint64(stubStart)) + if maximumUserAddress != stubEnd { + s.Unmap(usermem.Addr(stubEnd), uint64(maximumUserAddress-stubEnd)) + } +} + +// Release kills the subprocess. +// +// Just kidding! 
We can't safely co-ordinate the detaching of all the +// tracees (since the tracers are random runtime threads, and the process +// won't exit until tracers have been notifier). +// +// Therefore we simply unmap everything in the subprocess and return it to the +// globalPool. This has the added benefit of reducing creation time for new +// subprocesses. +func (s *subprocess) Release() error { + go func() { // S/R-SAFE: Platform. + s.unmap() + globalPool.mu.Lock() + globalPool.available = append(globalPool.available, s) + globalPool.mu.Unlock() + }() + return nil +} + +// newThread creates a new traced thread. +// +// Precondition: the OS thread must be locked. +func (s *subprocess) newThread() *thread { + // Ask the first thread to create a new one. + r := make(chan *thread) + s.requests <- r + t := <-r + + // Attach the subprocess to this one. + t.attach() + + // Return the new thread, which is now bound. + return t +} + +// attach attachs to the thread. +func (t *thread) attach() { + if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_ATTACH, uintptr(t.tid), 0); errno != 0 { + panic(fmt.Sprintf("unable to attach: %v", errno)) + } + + // PTRACE_ATTACH sends SIGSTOP, and wakes the tracee if it was already + // stopped from the SIGSTOP queued by CLONE_PTRACE (see inner loop of + // newSubprocess), so we always expect to see signal-delivery-stop with + // SIGSTOP. + if sig := t.wait(); sig != syscall.SIGSTOP { + panic(fmt.Sprintf("wait failed: expected SIGSTOP, got %v", sig)) + } + + // Initialize options. + t.init() +} + +// wait waits for a stop event. +func (t *thread) wait() syscall.Signal { + var status syscall.WaitStatus + + for { + r, err := syscall.Wait4(int(t.tid), &status, syscall.WALL|syscall.WUNTRACED, nil) + if err == syscall.EINTR || err == syscall.EAGAIN { + // Wait was interrupted; wait again. + continue + } else if err != nil { + panic(fmt.Sprintf("ptrace wait failed: %v", err)) + } + if int(r) != int(t.tid) { + panic(fmt.Sprintf("ptrace wait returned %v, expected %v", r, t.tid)) + } + if !status.Stopped() { + panic(fmt.Sprintf("ptrace status unexpected: got %v, wanted stopped", status)) + } + if status.StopSignal() == 0 { + continue // Spurious stop. + } + return status.StopSignal() + } +} + +// init initializes trace options. +func (t *thread) init() { + // Set our TRACESYSGOOD option to differeniate real SIGTRAP. + _, _, errno := syscall.RawSyscall6( + syscall.SYS_PTRACE, + syscall.PTRACE_SETOPTIONS, + uintptr(t.tid), + 0, + syscall.PTRACE_O_TRACESYSGOOD, + 0, 0) + if errno != 0 { + panic(fmt.Sprintf("ptrace set options failed: %v", errno)) + } +} + +// syscall executes a system call cycle in the traced context. +// +// This is _not_ for use by application system calls, rather it is for use when +// a system call must be injected into the remote context (e.g. mmap, munmap). +// Note that clones are handled separately. +func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) { + // Set registers. + if err := t.setRegs(regs); err != nil { + panic(fmt.Sprintf("ptrace set regs failed: %v", err)) + } + + for { + // Execute the syscall instruction. + if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0); errno != 0 { + panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno)) + } + + sig := t.wait() + if sig == (0x80 | syscall.SIGTRAP) { + // Reached syscall-enter-stop. + break + } else { + // Some other signal caused a thread stop; ignore. 
+ continue + } + } + + // Complete the actual system call. + if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0); errno != 0 { + panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno)) + } + + // Wait for syscall-exit-stop. "[Signal-delivery-stop] never happens + // between syscall-enter-stop and syscall-exit-stop; it happens *after* + // syscall-exit-stop.)" - ptrace(2), "Syscall-stops" + if sig := t.wait(); sig != (0x80 | syscall.SIGTRAP) { + panic(fmt.Sprintf("wait failed: expected SIGTRAP, got %v [%d]", sig, sig)) + } + + // Grab registers. + if err := t.getRegs(regs); err != nil { + panic(fmt.Sprintf("ptrace get regs failed: %v", err)) + } + + return syscallReturnValue(regs) +} + +// syscallIgnoreInterrupt ignores interrupts on the system call thread and +// restarts the syscall if the kernel indicates that should happen. +func (t *thread) syscallIgnoreInterrupt( + initRegs *syscall.PtraceRegs, + sysno uintptr, + args ...arch.SyscallArgument) (uintptr, error) { + for { + regs := createSyscallRegs(initRegs, sysno, args...) + rval, err := t.syscall(®s) + switch err { + case ERESTARTSYS: + continue + case ERESTARTNOINTR: + continue + case ERESTARTNOHAND: + continue + default: + return rval, err + } + } +} + +// NotifyInterrupt implements interrupt.Receiver.NotifyInterrupt. +func (t *thread) NotifyInterrupt() { + syscall.Tgkill(int(t.tgid), int(t.tid), syscall.Signal(platform.SignalInterrupt)) +} + +// switchToApp is called from the main SwitchToApp entrypoint. +// +// This function returns true on a system call, false on a signal. +func (s *subprocess) switchToApp(c *context, ac arch.Context) bool { + regs := &ac.StateData().Regs + s.resetSysemuRegs(regs) + + // Extract floating point state. + fpState := ac.FloatingPointData() + fpLen, _ := ac.FeatureSet().ExtendedStateSize() + useXsave := ac.FeatureSet().UseXsave() + + // Lock the thread for ptrace operations. + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + // Grab our thread from the pool. + currentTID := int32(procid.Current()) + t := s.sysemuThreads.lookupOrCreate(currentTID, s.newThread) + + // Check for interrupts, and ensure that future interrupts will signal t. + if !c.interrupt.Enable(t) { + // Pending interrupt; simulate. + c.signalInfo = arch.SignalInfo{Signo: int32(platform.SignalInterrupt)} + return false + } + defer c.interrupt.Disable() + + // Ensure that the CPU set is bound appropriately; this makes the + // emulation below several times faster, presumably by avoiding + // interprocessor wakeups and by simplifying the schedule. + t.bind() + + // Set registers. + if err := t.setRegs(regs); err != nil { + panic(fmt.Sprintf("ptrace set regs failed: %v", err)) + } + if err := t.setFPRegs(fpState, uint64(fpLen), useXsave); err != nil { + panic(fmt.Sprintf("ptrace set fpregs failed: %v", err)) + } + + for { + // Start running until the next system call. + if isSingleStepping(regs) { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_PTRACE, + syscall.PTRACE_SYSEMU_SINGLESTEP, + uintptr(t.tid), 0); errno != 0 { + panic(fmt.Sprintf("ptrace sysemu failed: %v", errno)) + } + } else { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_PTRACE, + syscall.PTRACE_SYSEMU, + uintptr(t.tid), 0); errno != 0 { + panic(fmt.Sprintf("ptrace sysemu failed: %v", errno)) + } + } + + // Wait for the syscall-enter stop. + sig := t.wait() + + // Refresh all registers. 
+ if err := t.getRegs(regs); err != nil { + panic(fmt.Sprintf("ptrace get regs failed: %v", err)) + } + if err := t.getFPRegs(fpState, uint64(fpLen), useXsave); err != nil { + panic(fmt.Sprintf("ptrace get fpregs failed: %v", err)) + } + + // Is it a system call? + if sig == (0x80 | syscall.SIGTRAP) { + // Ensure registers are sane. + updateSyscallRegs(regs) + return true + } + + if sig == syscall.SIGSTOP { + // SIGSTOP was delivered to another thread in the same thread + // group, which initiated another group stop. Just ignore it. + continue + } + + // Grab signal information. + if err := t.getSignalInfo(&c.signalInfo); err != nil { + // Should never happen. + panic(fmt.Sprintf("ptrace get signal info failed: %v", err)) + } + + // We have a signal. We verify however, that the signal was + // either delivered from the kernel or from this process. We + // don't respect other signals. + if c.signalInfo.Code > 0 { + return false // kernel. + } else if c.signalInfo.Code <= 0 && c.signalInfo.Pid() == int32(os.Getpid()) { + return false // this process. + } + } +} + +// syscall executes the given system call without handling interruptions. +func (s *subprocess) syscall(sysno uintptr, args ...arch.SyscallArgument) (uintptr, error) { + // Grab a thread. + runtime.LockOSThread() + defer runtime.UnlockOSThread() + currentTID := int32(procid.Current()) + t := s.syscallThreads.lookupOrCreate(currentTID, s.newThread) + + return t.syscallIgnoreInterrupt(&s.initRegs, sysno, args...) +} + +// MapFile implements platform.AddressSpace.MapFile. +func (s *subprocess) MapFile(addr usermem.Addr, fd int, fr platform.FileRange, at usermem.AccessType, precommit bool) error { + var flags int + if precommit { + flags |= syscall.MAP_POPULATE + } + _, err := s.syscall( + syscall.SYS_MMAP, + arch.SyscallArgument{Value: uintptr(addr)}, + arch.SyscallArgument{Value: uintptr(fr.Length())}, + arch.SyscallArgument{Value: uintptr(at.Prot())}, + arch.SyscallArgument{Value: uintptr(flags | syscall.MAP_SHARED | syscall.MAP_FIXED)}, + arch.SyscallArgument{Value: uintptr(fd)}, + arch.SyscallArgument{Value: uintptr(fr.Start)}) + return err +} + +// Unmap implements platform.AddressSpace.Unmap. +func (s *subprocess) Unmap(addr usermem.Addr, length uint64) { + ar, ok := addr.ToRange(length) + if !ok { + panic(fmt.Sprintf("addr %#x + length %#x overflows", addr, length)) + } + s.mu.Lock() + for c := range s.contexts { + c.mu.Lock() + if c.lastFaultSP == s && ar.Contains(c.lastFaultAddr) { + // Forget the last fault so that if c faults again, the fault isn't + // incorrectly reported as a write fault. If this is being called + // due to munmap() of the corresponding vma, handling of the second + // fault will fail anyway. + c.lastFaultSP = nil + delete(s.contexts, c) + } + c.mu.Unlock() + } + s.mu.Unlock() + _, err := s.syscall( + syscall.SYS_MUNMAP, + arch.SyscallArgument{Value: uintptr(addr)}, + arch.SyscallArgument{Value: uintptr(length)}) + if err != nil { + // We never expect this to happen. + panic(fmt.Sprintf("munmap(%x, %x)) failed: %v", addr, length, err)) + } +} diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go new file mode 100644 index 000000000..8211215df --- /dev/null +++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go @@ -0,0 +1,104 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package ptrace + +import ( + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/sentry/arch" +) + +const ( + // maximumUserAddress is the largest possible user address. + maximumUserAddress = 0x7ffffffff000 + + // initRegsRipAdjustment is the size of the syscall instruction. + initRegsRipAdjustment = 2 +) + +// Linux kernel errnos which "should never be seen by user programs", but will +// be revealed to ptrace syscall exit tracing. +// +// These constants are used in subprocess.go. +const ( + ERESTARTSYS = syscall.Errno(512) + ERESTARTNOINTR = syscall.Errno(513) + ERESTARTNOHAND = syscall.Errno(514) +) + +// resetSysemuRegs sets up emulation registers. +// +// This should be called prior to calling sysemu. +func (s *subprocess) resetSysemuRegs(regs *syscall.PtraceRegs) { + regs.Cs = s.initRegs.Cs + regs.Ss = s.initRegs.Ss + regs.Ds = s.initRegs.Ds + regs.Es = s.initRegs.Es + regs.Fs = s.initRegs.Fs + regs.Gs = s.initRegs.Gs +} + +// createSyscallRegs sets up syscall registers. +// +// This should be called to generate registers for a system call. +func createSyscallRegs(initRegs *syscall.PtraceRegs, sysno uintptr, args ...arch.SyscallArgument) syscall.PtraceRegs { + // Copy initial registers (RIP, segments, etc.). + regs := *initRegs + + // Set our syscall number. + regs.Rax = uint64(sysno) + if len(args) >= 1 { + regs.Rdi = args[0].Uint64() + } + if len(args) >= 2 { + regs.Rsi = args[1].Uint64() + } + if len(args) >= 3 { + regs.Rdx = args[2].Uint64() + } + if len(args) >= 4 { + regs.R10 = args[3].Uint64() + } + if len(args) >= 5 { + regs.R8 = args[4].Uint64() + } + if len(args) >= 6 { + regs.R9 = args[5].Uint64() + } + + return regs +} + +// isSingleStepping determines if the registers indicate single-stepping. +func isSingleStepping(regs *syscall.PtraceRegs) bool { + return (regs.Eflags & arch.X86TrapFlag) != 0 +} + +// updateSyscallRegs updates registers after finishing sysemu. +func updateSyscallRegs(regs *syscall.PtraceRegs) { + // Ptrace puts -ENOSYS in rax on syscall-enter-stop. + regs.Rax = regs.Orig_rax +} + +// syscallReturnValue extracts a sensible return from registers. +func syscallReturnValue(regs *syscall.PtraceRegs) (uintptr, error) { + rval := int64(regs.Rax) + if rval < 0 { + return 0, syscall.Errno(-rval) + } + return uintptr(rval), nil +} diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go new file mode 100644 index 000000000..227dd4882 --- /dev/null +++ b/pkg/sentry/platform/ptrace/subprocess_linux.go @@ -0,0 +1,146 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package ptrace + +import ( + "fmt" + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/sentry/arch" + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/procid" +) + +// createStub creates a fresh stub processes. +// +// Precondition: the runtime OS thread must be locked. +func createStub() (*thread, error) { + // Declare all variables up front in order to ensure that there's no + // need for allocations between beforeFork & afterFork. + var ( + pid uintptr + ppid uintptr + errno syscall.Errno + ) + + // Remember the current ppid for the pdeathsig race. + ppid, _, _ = syscall.RawSyscall(syscall.SYS_GETPID, 0, 0, 0) + + // Among other things, beforeFork masks all signals. + beforeFork() + pid, _, errno = syscall.RawSyscall6(syscall.SYS_CLONE, uintptr(syscall.SIGCHLD)|syscall.CLONE_FILES, 0, 0, 0, 0, 0) + if errno != 0 { + afterFork() + return nil, errno + } + + // Is this the parent? + if pid != 0 { + // Among other things, restore signal mask. + afterFork() + + // Initialize the first thread. + t := &thread{ + tgid: int32(pid), + tid: int32(pid), + cpu: ^uint32(0), + } + if sig := t.wait(); sig != syscall.SIGSTOP { + return nil, fmt.Errorf("wait failed: expected SIGSTOP, got %v", sig) + } + t.attach() + + return t, nil + } + + // afterForkInChild resets all signals to their default dispositions + // and restores the signal mask to its pre-fork state. + afterForkInChild() + + // Explicitly unmask all signals to ensure that the tracer can see + // them. + errno = unmaskAllSignals() + if errno != 0 { + syscall.RawSyscall(syscall.SYS_EXIT, uintptr(errno), 0, 0) + } + + // Call the stub; should not return. + stubCall(stubStart, ppid) + panic("unreachable") +} + +// createStub creates a stub processes as a child of an existing subprocesses. +// +// Precondition: the runtime OS thread must be locked. +func (s *subprocess) createStub() (*thread, error) { + // There's no need to lock the runtime thread here, as this can only be + // called from a context that is already locked. + currentTID := int32(procid.Current()) + t := s.syscallThreads.lookupOrCreate(currentTID, s.newThread) + + // Pass the expected PPID to the child via R15. + regs := s.initRegs + regs.R15 = uint64(t.tgid) + + // Call fork in a subprocess. + // + // The new child must set up PDEATHSIG to ensure it dies if this + // process dies. Since this process could die at any time, this cannot + // be done via instrumentation from here. + // + // Instead, we create the child untraced, which will do the PDEATHSIG + // setup and then SIGSTOP itself for our attach below. + pid, err := t.syscallIgnoreInterrupt( + ®s, + syscall.SYS_CLONE, + arch.SyscallArgument{Value: uintptr(syscall.SIGCHLD | syscall.CLONE_FILES)}, + arch.SyscallArgument{Value: 0}, + arch.SyscallArgument{Value: 0}, + arch.SyscallArgument{Value: 0}, + arch.SyscallArgument{Value: 0}, + arch.SyscallArgument{Value: 0}) + if err != nil { + return nil, err + } + + // Wait for child to enter group-stop, so we don't stop its + // bootstrapping work with t.attach below. + // + // We unfortunately don't have a handy part of memory to write the wait + // status. If the wait succeeds, we'll assume that it was the SIGSTOP. + // If the child actually exited, the attach below will fail. 
+ _, err = t.syscallIgnoreInterrupt( + &s.initRegs, + syscall.SYS_WAIT4, + arch.SyscallArgument{Value: uintptr(pid)}, + arch.SyscallArgument{Value: 0}, + arch.SyscallArgument{Value: syscall.WUNTRACED}, + arch.SyscallArgument{Value: 0}, + arch.SyscallArgument{Value: 0}, + arch.SyscallArgument{Value: 0}) + if err != nil { + return nil, err + } + + childT := &thread{ + tgid: int32(pid), + tid: int32(pid), + cpu: ^uint32(0), + } + childT.attach() + + return childT, nil +} diff --git a/pkg/sentry/platform/ptrace/subprocess_linux_amd64_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_linux_amd64_unsafe.go new file mode 100644 index 000000000..697431472 --- /dev/null +++ b/pkg/sentry/platform/ptrace/subprocess_linux_amd64_unsafe.go @@ -0,0 +1,109 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 linux + +package ptrace + +import ( + "sync" + "sync/atomic" + "syscall" + "unsafe" + + "golang.org/x/sys/unix" + "gvisor.googlesource.com/gvisor/pkg/abi/linux" +) + +// maskPool contains reusable CPU masks for setting affinity. Unfortunately, +// runtime.NumCPU doesn't actually record the number of CPUs on the system, it +// just records the number of CPUs available in the scheduler affinity set at +// startup. This may a) change over time and b) gives a number far lower than +// the maximum indexable CPU. To prevent lots of allocation in the hot path, we +// use a pool to store large masks that we can reuse during bind. +var maskPool = sync.Pool{ + New: func() interface{} { + const maxCPUs = 1024 // Not a hard limit; see below. + return make([]uintptr, maxCPUs/64) + }, +} + +// unmaskAllSignals unmasks all signals on the current thread. +// +//go:nosplit +func unmaskAllSignals() syscall.Errno { + var set linux.SignalSet + _, _, errno := syscall.RawSyscall6(syscall.SYS_RT_SIGPROCMASK, linux.SIG_SETMASK, uintptr(unsafe.Pointer(&set)), 0, linux.SignalSetSize, 0, 0) + return errno +} + +// getCPU gets the current CPU. +// +// Precondition: the current runtime thread should be locked. +func getCPU() (uint32, error) { + var cpu uintptr + if _, _, errno := syscall.RawSyscall( + unix.SYS_GETCPU, + uintptr(unsafe.Pointer(&cpu)), + 0, 0); errno != 0 { + return 0, errno + } + return uint32(cpu), nil +} + +// setCPU sets the CPU affinity. +func (t *thread) setCPU(cpu uint32) error { + mask := maskPool.Get().([]uintptr) + n := int(cpu / 64) + v := uintptr(1 << uintptr(cpu%64)) + if n >= len(mask) { + // See maskPool note above. We've actually exceeded the number + // of available cores. Grow the mask and return it. + mask = make([]uintptr, n+1) + } + mask[n] |= v + if _, _, errno := syscall.RawSyscall( + unix.SYS_SCHED_SETAFFINITY, + uintptr(t.tid), + uintptr(len(mask)*8), + uintptr(unsafe.Pointer(&mask[0]))); errno != 0 { + return errno + } + mask[n] &^= v + maskPool.Put(mask) + return nil +} + +// bind attempts to ensure that the thread is on the same CPU as the current +// thread. 
This provides no guarantees as it is fundamentally a racy operation: +// CPU sets may change and we may be rescheduled in the middle of this +// operation. As a result, no failures are reported. +// +// Precondition: the current runtime thread should be locked. +func (t *thread) bind() { + currentCPU, err := getCPU() + if err != nil { + return + } + if oldCPU := atomic.SwapUint32(&t.cpu, currentCPU); oldCPU != currentCPU { + // Set the affinity on the thread and save the CPU for next + // round; we don't expect CPUs to bounce around too frequently. + // + // (It's worth noting that we could move CPUs between this point + // and when the tracee finishes executing. But that would be + // roughly the status quo anyways -- we're just maximizing our + // chances of colocation, not guaranteeing it.) + t.setCPU(currentCPU) + } +} diff --git a/pkg/sentry/platform/ptrace/subprocess_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_unsafe.go new file mode 100644 index 000000000..fe41641d3 --- /dev/null +++ b/pkg/sentry/platform/ptrace/subprocess_unsafe.go @@ -0,0 +1,28 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ptrace + +import ( + _ "unsafe" // required for go:linkname. +) + +//go:linkname beforeFork syscall.runtime_BeforeFork +func beforeFork() + +//go:linkname afterFork syscall.runtime_AfterFork +func afterFork() + +//go:linkname afterForkInChild syscall.runtime_AfterForkInChild +func afterForkInChild() diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD new file mode 100644 index 000000000..2df232a64 --- /dev/null +++ b/pkg/sentry/platform/ring0/BUILD @@ -0,0 +1,52 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") + +go_template( + name = "defs", + srcs = [ + "defs.go", + "defs_amd64.go", + "offsets_amd64.go", + "x86.go", + ], + visibility = [":__subpackages__"], +) + +go_template_instance( + name = "defs_impl", + out = "defs_impl.go", + package = "ring0", + template = ":defs", +) + +genrule( + name = "entry_impl_amd64", + srcs = ["entry_amd64.s"], + outs = ["entry_impl_amd64.s"], + cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", + tools = ["//pkg/sentry/platform/ring0/gen_offsets"], +) + +go_library( + name = "ring0", + srcs = [ + "defs_impl.go", + "entry_amd64.go", + "entry_impl_amd64.s", + "kernel.go", + "kernel_amd64.go", + "kernel_unsafe.go", + "lib_amd64.go", + "lib_amd64.s", + "ring0.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0", + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/cpuid", + "//pkg/sentry/platform/ring0/pagetables", + "//pkg/sentry/usermem", + ], +) diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go new file mode 100644 index 000000000..9d947b73d --- /dev/null +++ b/pkg/sentry/platform/ring0/defs.go @@ -0,0 +1,93 @@ 
+// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ring0 + +import ( + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +var ( + // UserspaceSize is the total size of userspace. + UserspaceSize = uintptr(1) << (VirtualAddressBits() - 1) + + // MaximumUserAddress is the largest possible user address. + MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1) + + // KernelStartAddress is the starting kernel address. + KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) +) + +// Kernel is a global kernel object. +// +// This contains global state, shared by multiple CPUs. +type Kernel struct { + KernelArchState +} + +// CPU is the per-CPU struct. +type CPU struct { + // self is a self reference. + // + // This is always guaranteed to be at offset zero. + self *CPU + + // kernel is reference to the kernel that this CPU was initialized + // with. This reference is kept for garbage collection purposes: CPU + // registers may refer to objects within the Kernel object that cannot + // be safely freed. + kernel *Kernel + + // CPUArchState is architecture-specific state. + CPUArchState + + // registers is a set of registers; these may be used on kernel system + // calls and exceptions via the Registers function. + registers syscall.PtraceRegs + + // KernelException handles an exception during kernel execution. + // + // Return from this call will restore registers and return to the kernel: the + // registers must be modified directly. + // + // If this function is not provided, a kernel exception results in halt. + // + // This must be go:nosplit, as this will be on the interrupt stack. + // Closures are permitted, as the pointer to the closure frame is not + // passed on the stack. + KernelException func(Vector) + + // KernelSyscall is called for kernel system calls. + // + // Return from this call will restore registers and return to the kernel: the + // registers must be modified directly. + // + // If this function is not provided, a kernel exception results in halt. + // + // This must be go:nosplit, as this will be on the interrupt stack. + // Closures are permitted, as the pointer to the closure frame is not + // passed on the stack. + KernelSyscall func() +} + +// Registers returns a modifiable-copy of the kernel registers. +// +// This is explicitly safe to call during KernelException and KernelSyscall. +// +//go:nosplit +func (c *CPU) Registers() *syscall.PtraceRegs { + return &c.registers +} diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go new file mode 100644 index 000000000..bb3420125 --- /dev/null +++ b/pkg/sentry/platform/ring0/defs_amd64.go @@ -0,0 +1,113 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package ring0 + +import ( + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0/pagetables" +) + +// Segment indices and Selectors. +const ( + // Index into GDT array. + _ = iota // Null descriptor first. + _ // Reserved (Linux is kernel 32). + segKcode // Kernel code (64-bit). + segKdata // Kernel data. + segUcode32 // User code (32-bit). + segUdata // User data. + segUcode64 // User code (64-bit). + segTss // Task segment descriptor. + segTssHi // Upper bits for TSS. + segLast // Last segment (terminal, not included). +) + +// Selectors. +const ( + Kcode Selector = segKcode << 3 + Kdata Selector = segKdata << 3 + Ucode32 Selector = (segUcode32 << 3) | 3 + Udata Selector = (segUdata << 3) | 3 + Ucode64 Selector = (segUcode64 << 3) | 3 + Tss Selector = segTss << 3 +) + +// Standard segments. +var ( + UserCodeSegment32 SegmentDescriptor + UserDataSegment SegmentDescriptor + UserCodeSegment64 SegmentDescriptor + KernelCodeSegment SegmentDescriptor + KernelDataSegment SegmentDescriptor +) + +// KernelOpts has initialization options for the kernel. +type KernelOpts struct { + // PageTables are the kernel pagetables; this must be provided. + PageTables *pagetables.PageTables +} + +// KernelArchState contains architecture-specific state. +type KernelArchState struct { + KernelOpts + + // globalIDT is our set of interrupt gates. + globalIDT idt64 +} + +// CPUArchState contains CPU-specific arch state. +type CPUArchState struct { + // stack is the stack used for interrupts on this CPU. + stack [256]byte + + // errorCode is the error code from the last exception. + errorCode uintptr + + // errorType indicates the type of error code here, it is always set + // along with the errorCode value above. + // + // It will either by 1, which indicates a user error, or 0 indicating a + // kernel error. If the error code below returns false (kernel error), + // then it cannot provide relevant information about the last + // exception. + errorType uintptr + + // gdt is the CPU's descriptor table. + gdt descriptorTable + + // tss is the CPU's task state. + tss TaskState64 +} + +// ErrorCode returns the last error code. +// +// The returned boolean indicates whether the error code corresponds to the +// last user error or not. If it does not, then fault information must be +// ignored. This is generally the result of a kernel fault while servicing a +// user fault. +// +//go:nosplit +func (c *CPU) ErrorCode() (value uintptr, user bool) { + return c.errorCode, c.errorType != 0 +} + +func init() { + KernelCodeSegment.setCode64(0, 0, 0) + KernelDataSegment.setData(0, 0xffffffff, 0) + UserCodeSegment32.setCode64(0, 0, 3) + UserDataSegment.setData(0, 0xffffffff, 3) + UserCodeSegment64.setCode64(0, 0, 3) +} diff --git a/pkg/sentry/platform/ring0/entry_amd64.go b/pkg/sentry/platform/ring0/entry_amd64.go new file mode 100644 index 000000000..a3e992e0d --- /dev/null +++ b/pkg/sentry/platform/ring0/entry_amd64.go @@ -0,0 +1,128 @@ +// Copyright 2018 Google Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package ring0 + +import ( + "syscall" +) + +// This is an assembly function. +// +// The sysenter function is invoked in two situations: +// +// (1) The guest kernel has executed a system call. +// (2) The guest application has executed a system call. +// +// The interrupt flag is examined to determine whether the system call was +// executed from kernel mode or not and the appropriate stub is called. +func sysenter() + +// swapgs swaps the current GS value. +// +// This must be called prior to sysret/iret. +func swapgs() + +// sysret returns to userspace from a system call. +// +// The return code is the vector that interrupted execution. +// +// See stubs.go for a note regarding the frame size of this function. +func sysret(*CPU, *syscall.PtraceRegs) Vector + +// "iret is the cadillac of CPL switching." +// +// -- Neel Natu +// +// iret is nearly identical to sysret, except an iret is used to fully restore +// all user state. This must be called in cases where all registers need to be +// restored. +func iret(*CPU, *syscall.PtraceRegs) Vector + +// exception is the generic exception entry. +// +// This is called by the individual stub definitions. +func exception() + +// resume is a stub that restores the CPU kernel registers. +// +// This is used when processing kernel exceptions and syscalls. +func resume() + +// Start is the CPU entrypoint. +// +// The following start conditions must be satisfied: +// +// * AX should contain the CPU pointer. +// * c.GDT() should be loaded as the GDT. +// * c.IDT() should be loaded as the IDT. +// * c.CR0() should be the current CR0 value. +// * c.CR3() should be set to the kernel PageTables. +// * c.CR4() should be the current CR4 value. +// * c.EFER() should be the current EFER value. +// +// The CPU state will be set to c.Registers(). +func Start() + +// Exception stubs. +func divideByZero() +func debug() +func nmi() +func breakpoint() +func overflow() +func boundRangeExceeded() +func invalidOpcode() +func deviceNotAvailable() +func doubleFault() +func coprocessorSegmentOverrun() +func invalidTSS() +func segmentNotPresent() +func stackSegmentFault() +func generalProtectionFault() +func pageFault() +func x87FloatingPointException() +func alignmentCheck() +func machineCheck() +func simdFloatingPointException() +func virtualizationException() +func securityException() +func syscallInt80() + +// Exception handler index. 
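// (Editor's note: these stubs are wired into the IDT by (*Kernel).init in
// kernel_amd64.go later in this change; every vector is routed through IST 1,
// i.e. the per-CPU interrupt stack configured in the TSS by (*CPU).init.)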
+var handlers = map[Vector]func(){ + DivideByZero: divideByZero, + Debug: debug, + NMI: nmi, + Breakpoint: breakpoint, + Overflow: overflow, + BoundRangeExceeded: boundRangeExceeded, + InvalidOpcode: invalidOpcode, + DeviceNotAvailable: deviceNotAvailable, + DoubleFault: doubleFault, + CoprocessorSegmentOverrun: coprocessorSegmentOverrun, + InvalidTSS: invalidTSS, + SegmentNotPresent: segmentNotPresent, + StackSegmentFault: stackSegmentFault, + GeneralProtectionFault: generalProtectionFault, + PageFault: pageFault, + X87FloatingPointException: x87FloatingPointException, + AlignmentCheck: alignmentCheck, + MachineCheck: machineCheck, + SIMDFloatingPointException: simdFloatingPointException, + VirtualizationException: virtualizationException, + SecurityException: securityException, + SyscallInt80: syscallInt80, +} diff --git a/pkg/sentry/platform/ring0/entry_amd64.s b/pkg/sentry/platform/ring0/entry_amd64.s new file mode 100644 index 000000000..e8638133b --- /dev/null +++ b/pkg/sentry/platform/ring0/entry_amd64.s @@ -0,0 +1,334 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "funcdata.h" +#include "textflag.h" + +// NB: Offsets are programatically generated (see BUILD). +// +// This file is concatenated with the definitions. + +// Saves a register set. +// +// This is a macro because it may need to executed in contents where a stack is +// not available for calls. +// +// The following registers are not saved: AX, SP, IP, FLAGS, all segments. +#define REGISTERS_SAVE(reg, offset) \ + MOVQ R15, offset+PTRACE_R15(reg); \ + MOVQ R14, offset+PTRACE_R14(reg); \ + MOVQ R13, offset+PTRACE_R13(reg); \ + MOVQ R12, offset+PTRACE_R12(reg); \ + MOVQ BP, offset+PTRACE_RBP(reg); \ + MOVQ BX, offset+PTRACE_RBX(reg); \ + MOVQ CX, offset+PTRACE_RCX(reg); \ + MOVQ DX, offset+PTRACE_RDX(reg); \ + MOVQ R11, offset+PTRACE_R11(reg); \ + MOVQ R10, offset+PTRACE_R10(reg); \ + MOVQ R9, offset+PTRACE_R9(reg); \ + MOVQ R8, offset+PTRACE_R8(reg); \ + MOVQ SI, offset+PTRACE_RSI(reg); \ + MOVQ DI, offset+PTRACE_RDI(reg); + +// Loads a register set. +// +// This is a macro because it may need to executed in contents where a stack is +// not available for calls. +// +// The following registers are not loaded: AX, SP, IP, FLAGS, all segments. +#define REGISTERS_LOAD(reg, offset) \ + MOVQ offset+PTRACE_R15(reg), R15; \ + MOVQ offset+PTRACE_R14(reg), R14; \ + MOVQ offset+PTRACE_R13(reg), R13; \ + MOVQ offset+PTRACE_R12(reg), R12; \ + MOVQ offset+PTRACE_RBP(reg), BP; \ + MOVQ offset+PTRACE_RBX(reg), BX; \ + MOVQ offset+PTRACE_RCX(reg), CX; \ + MOVQ offset+PTRACE_RDX(reg), DX; \ + MOVQ offset+PTRACE_R11(reg), R11; \ + MOVQ offset+PTRACE_R10(reg), R10; \ + MOVQ offset+PTRACE_R9(reg), R9; \ + MOVQ offset+PTRACE_R8(reg), R8; \ + MOVQ offset+PTRACE_RSI(reg), SI; \ + MOVQ offset+PTRACE_RDI(reg), DI; + +// SWAP_GS swaps the kernel GS (CPU). +#define SWAP_GS() \ + BYTE $0x0F; BYTE $0x01; BYTE $0xf8; + +// IRET returns from an interrupt frame. 
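// (Editor's note: the BYTE sequences in these macros are hand-assembled
// encodings, presumably because the Go assembler lacks mnemonics for these
// privileged instructions: 0x0f 0x01 0xf8 above is SWAPGS, 0x48 0xcf below is
// IRETQ (REX.W + IRET), and 0x48 0x0f 0x07 in SYSRET64 is SYSRETQ.)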
+#define IRET() \ + BYTE $0x48; BYTE $0xcf; + +// SYSRET64 executes the sysret instruction. +#define SYSRET64() \ + BYTE $0x48; BYTE $0x0f; BYTE $0x07; + +// LOAD_KERNEL_ADDRESS loads a kernel address. +#define LOAD_KERNEL_ADDRESS(from, to) \ + MOVQ from, to; \ + ORQ ·KernelStartAddress(SB), to; + +// LOAD_KERNEL_STACK loads the kernel stack. +#define LOAD_KERNEL_STACK(from) \ + LOAD_KERNEL_ADDRESS(CPU_SELF(from), SP); \ + LEAQ CPU_STACK_TOP(SP), SP; + +// See kernel.go. +TEXT ·Halt(SB),NOSPLIT,$0 + HLT + RET + +// See kernel.go. +TEXT ·Current(SB),NOSPLIT,$0-8 + MOVQ CPU_SELF(GS), AX + MOVQ AX, ret+0(FP) + RET + +// See entry_amd64.go. +TEXT ·swapgs(SB),NOSPLIT,$0 + SWAP_GS() + RET + +// See entry_amd64.go. +TEXT ·sysret(SB),NOSPLIT,$0-24 + // Save original state. + LOAD_KERNEL_ADDRESS(cpu+0(FP), BX) + LOAD_KERNEL_ADDRESS(regs+8(FP), AX) + MOVQ SP, CPU_REGISTERS+PTRACE_RSP(BX) + MOVQ BP, CPU_REGISTERS+PTRACE_RBP(BX) + MOVQ AX, CPU_REGISTERS+PTRACE_RAX(BX) + + // Restore user register state. + REGISTERS_LOAD(AX, 0) + MOVQ PTRACE_RIP(AX), CX // Needed for SYSRET. + MOVQ PTRACE_FLAGS(AX), R11 // Needed for SYSRET. + MOVQ PTRACE_RSP(AX), SP // Restore the stack directly. + MOVQ PTRACE_RAX(AX), AX // Restore AX (scratch). + SYSRET64() + +// See entry_amd64.go. +TEXT ·iret(SB),NOSPLIT,$0-24 + // Save original state. + LOAD_KERNEL_ADDRESS(cpu+0(FP), BX) + LOAD_KERNEL_ADDRESS(regs+8(FP), AX) + MOVQ SP, CPU_REGISTERS+PTRACE_RSP(BX) + MOVQ BP, CPU_REGISTERS+PTRACE_RBP(BX) + MOVQ AX, CPU_REGISTERS+PTRACE_RAX(BX) + + // Build an IRET frame & restore state. + LOAD_KERNEL_STACK(BX) + MOVQ PTRACE_SS(AX), BX; PUSHQ BX + MOVQ PTRACE_RSP(AX), CX; PUSHQ CX + MOVQ PTRACE_FLAGS(AX), DX; PUSHQ DX + MOVQ PTRACE_CS(AX), DI; PUSHQ DI + MOVQ PTRACE_RIP(AX), SI; PUSHQ SI + REGISTERS_LOAD(AX, 0) // Restore most registers. + MOVQ PTRACE_RAX(AX), AX // Restore AX (scratch). + IRET() + +// See entry_amd64.go. +TEXT ·resume(SB),NOSPLIT,$0 + // See iret, above. + MOVQ CPU_REGISTERS+PTRACE_SS(GS), BX; PUSHQ BX + MOVQ CPU_REGISTERS+PTRACE_RSP(GS), CX; PUSHQ CX + MOVQ CPU_REGISTERS+PTRACE_FLAGS(GS), DX; PUSHQ DX + MOVQ CPU_REGISTERS+PTRACE_CS(GS), DI; PUSHQ DI + MOVQ CPU_REGISTERS+PTRACE_RIP(GS), SI; PUSHQ SI + REGISTERS_LOAD(GS, CPU_REGISTERS) + MOVQ CPU_REGISTERS+PTRACE_RAX(GS), AX + IRET() + +// See entry_amd64.go. +TEXT ·Start(SB),NOSPLIT,$0 + LOAD_KERNEL_STACK(AX) // Set the stack. + PUSHQ $0x0 // Previous frame pointer. + MOVQ SP, BP // Set frame pointer. + PUSHQ AX // First argument (CPU). + CALL ·start(SB) // Call Go hook. + JMP ·resume(SB) // Restore to registers. + +// See entry_amd64.go. +TEXT ·sysenter(SB),NOSPLIT,$0 + // Interrupts are always disabled while we're executing in kernel mode + // and always enabled while executing in user mode. Therefore, we can + // reliably look at the flags in R11 to determine where this syscall + // was from. + TESTL $_RFLAGS_IF, R11 + JZ kernel + +user: + SWAP_GS() + XCHGQ CPU_REGISTERS+PTRACE_RSP(GS), SP // Swap stacks. + XCHGQ CPU_REGISTERS+PTRACE_RAX(GS), AX // Swap for AX (regs). + REGISTERS_SAVE(AX, 0) // Save all except IP, FLAGS, SP, AX. + MOVQ CPU_REGISTERS+PTRACE_RAX(GS), BX // Load saved AX value. + MOVQ BX, PTRACE_RAX(AX) // Save everything else. + MOVQ BX, PTRACE_ORIGRAX(AX) + MOVQ CX, PTRACE_RIP(AX) + MOVQ R11, PTRACE_FLAGS(AX) + MOVQ CPU_REGISTERS+PTRACE_RSP(GS), BX; MOVQ BX, PTRACE_RSP(AX) + MOVQ $0, CPU_ERROR_CODE(GS) // Clear error code. + MOVQ $1, CPU_ERROR_TYPE(GS) // Set error type to user. 
+ + // Return to the kernel, where the frame is: + // + // vector (sp+24) + // regs (sp+16) + // cpu (sp+8) + // vcpu.Switch (sp+0) + // + MOVQ CPU_REGISTERS+PTRACE_RBP(GS), BP // Original base pointer. + MOVQ $Syscall, 24(SP) // Output vector. + RET + +kernel: + // We can't restore the original stack, but we can access the registers + // in the CPU state directly. No need for temporary juggling. + MOVQ AX, CPU_REGISTERS+PTRACE_ORIGRAX(GS) + MOVQ AX, CPU_REGISTERS+PTRACE_RAX(GS) + REGISTERS_SAVE(GS, CPU_REGISTERS) + MOVQ CX, CPU_REGISTERS+PTRACE_RIP(GS) + MOVQ R11, CPU_REGISTERS+PTRACE_FLAGS(GS) + MOVQ SP, CPU_REGISTERS+PTRACE_RSP(GS) + MOVQ $0, CPU_ERROR_CODE(GS) // Clear error code. + MOVQ $0, CPU_ERROR_TYPE(GS) // Set error type to kernel. + + // Load the function stored in KernelSyscall. + // + // Note that this function needs to be executed on the stack in case + // the runtime decides to make use of the redzone (grumble). This also + // protects against any functions that might not be go:nosplit, since + // this will cause a failure immediately. + LOAD_KERNEL_STACK(GS) + MOVQ CPU_KERNEL_SYSCALL(GS), DX // Function data. + MOVQ 0(DX), AX // Function pointer. + PUSHQ BP // Push the frame pointer. + MOVQ SP, BP // Set frame pointer value. + CALL *AX // Call the function. + POPQ BP // Restore the frame pointer. + JMP ·resume(SB) + +// exception is a generic exception handler. +// +// There are two cases handled: +// +// 1) An exception in kernel mode: this results in saving the state at the time +// of the exception and calling the defined hook. +// +// 2) An exception in guest mode: the original kernel frame is restored, and +// the vector & error codes are pushed as return values. +// +// See below for the stubs that call exception. +TEXT ·exception(SB),NOSPLIT,$0 + // Determine whether the exception occurred in kernel mode or user + // mode, based on the flags. We expect the following stack: + // + // SS (sp+48) + // SP (sp+40) + // FLAGS (sp+32) + // CS (sp+24) + // IP (sp+16) + // ERROR_CODE (sp+8) + // VECTOR (sp+0) + // + TESTL $_RFLAGS_IF, 32(SP) + JZ kernel + +user: + SWAP_GS() + XCHGQ CPU_REGISTERS+PTRACE_RAX(GS), AX // Swap for AX (regs). + REGISTERS_SAVE(AX, 0) // Save all except IP, FLAGS, SP, AX. + MOVQ CPU_REGISTERS+PTRACE_RAX(GS), BX // Load saved AX value. + MOVQ BX, PTRACE_RAX(AX) // Save everything else. + MOVQ BX, PTRACE_ORIGRAX(AX) + MOVQ 16(SP), BX; MOVQ BX, PTRACE_RIP(AX) + MOVQ 24(SP), CX; MOVQ CX, PTRACE_CS(AX) + MOVQ 32(SP), DX; MOVQ DX, PTRACE_FLAGS(AX) + MOVQ 40(SP), DI; MOVQ DI, PTRACE_RSP(AX) + MOVQ 48(SP), SI; MOVQ SI, PTRACE_SS(AX) + + // Copy out and return. + MOVQ 0(SP), BX // Load vector. + MOVQ 8(SP), CX // Load error code. + MOVQ CPU_REGISTERS+PTRACE_RSP(GS), SP // Original stack (kernel version). + MOVQ CPU_REGISTERS+PTRACE_RBP(GS), BP // Original base pointer. + MOVQ CX, CPU_ERROR_CODE(GS) // Set error code. + MOVQ $1, CPU_ERROR_TYPE(GS) // Set error type to user. + MOVQ BX, 24(SP) // Output vector. + RET + +kernel: + // As per above, we can save directly. + MOVQ AX, CPU_REGISTERS+PTRACE_RAX(GS) + MOVQ AX, CPU_REGISTERS+PTRACE_ORIGRAX(GS) + REGISTERS_SAVE(GS, CPU_REGISTERS) + MOVQ 16(SP), AX; MOVQ AX, CPU_REGISTERS+PTRACE_RIP(GS) + MOVQ 32(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_FLAGS(GS) + MOVQ 40(SP), CX; MOVQ CX, CPU_REGISTERS+PTRACE_RSP(GS) + + // Set the error code and adjust the stack. + MOVQ 8(SP), AX // Load the error code. + MOVQ AX, CPU_ERROR_CODE(GS) // Copy out to the CPU. + MOVQ $0, CPU_ERROR_TYPE(GS) // Set error type to kernel. 
+ MOVQ 0(SP), BX // BX contains the vector. + ADDQ $48, SP // Drop the exception frame. + + // Load the function stored in KernelException. + // + // See note above re: the kernel stack. + LOAD_KERNEL_STACK(GS) + MOVQ CPU_KERNEL_EXCEPTION(GS), DX // Function data. + MOVQ 0(DX), AX // Function pointer. + PUSHQ BP // Push the frame pointer. + MOVQ SP, BP // Set frame pointer value. + PUSHQ BX // First argument (vector). + CALL *AX // Call the function. + POPQ BX // Discard the argument. + POPQ BP // Restore the frame pointer. + JMP ·resume(SB) + +#define EXCEPTION_WITH_ERROR(value, symbol) \ +TEXT symbol,NOSPLIT,$0; \ + PUSHQ $value; \ + JMP ·exception(SB); + +#define EXCEPTION_WITHOUT_ERROR(value, symbol) \ +TEXT symbol,NOSPLIT,$0; \ + PUSHQ $0x0; \ + PUSHQ $value; \ + JMP ·exception(SB); + +EXCEPTION_WITHOUT_ERROR(DivideByZero, ·divideByZero(SB)) +EXCEPTION_WITHOUT_ERROR(Debug, ·debug(SB)) +EXCEPTION_WITHOUT_ERROR(NMI, ·nmi(SB)) +EXCEPTION_WITHOUT_ERROR(Breakpoint, ·breakpoint(SB)) +EXCEPTION_WITHOUT_ERROR(Overflow, ·overflow(SB)) +EXCEPTION_WITHOUT_ERROR(BoundRangeExceeded, ·boundRangeExceeded(SB)) +EXCEPTION_WITHOUT_ERROR(InvalidOpcode, ·invalidOpcode(SB)) +EXCEPTION_WITHOUT_ERROR(DeviceNotAvailable, ·deviceNotAvailable(SB)) +EXCEPTION_WITH_ERROR(DoubleFault, ·doubleFault(SB)) +EXCEPTION_WITHOUT_ERROR(CoprocessorSegmentOverrun, ·coprocessorSegmentOverrun(SB)) +EXCEPTION_WITH_ERROR(InvalidTSS, ·invalidTSS(SB)) +EXCEPTION_WITH_ERROR(SegmentNotPresent, ·segmentNotPresent(SB)) +EXCEPTION_WITH_ERROR(StackSegmentFault, ·stackSegmentFault(SB)) +EXCEPTION_WITH_ERROR(GeneralProtectionFault, ·generalProtectionFault(SB)) +EXCEPTION_WITH_ERROR(PageFault, ·pageFault(SB)) +EXCEPTION_WITHOUT_ERROR(X87FloatingPointException, ·x87FloatingPointException(SB)) +EXCEPTION_WITH_ERROR(AlignmentCheck, ·alignmentCheck(SB)) +EXCEPTION_WITHOUT_ERROR(MachineCheck, ·machineCheck(SB)) +EXCEPTION_WITHOUT_ERROR(SIMDFloatingPointException, ·simdFloatingPointException(SB)) +EXCEPTION_WITHOUT_ERROR(VirtualizationException, ·virtualizationException(SB)) +EXCEPTION_WITH_ERROR(SecurityException, ·securityException(SB)) +EXCEPTION_WITHOUT_ERROR(SyscallInt80, ·syscallInt80(SB)) diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD new file mode 100644 index 000000000..3bce56985 --- /dev/null +++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD @@ -0,0 +1,25 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +go_template_instance( + name = "defs_impl", + out = "defs_impl.go", + package = "main", + template = "//pkg/sentry/platform/ring0:defs", +) + +go_binary( + name = "gen_offsets", + srcs = [ + "defs_impl.go", + "main.go", + ], + visibility = ["//pkg/sentry/platform/ring0:__pkg__"], + deps = [ + "//pkg/cpuid", + "//pkg/sentry/platform/ring0/pagetables", + "//pkg/sentry/usermem", + ], +) diff --git a/pkg/sentry/platform/ring0/gen_offsets/main.go b/pkg/sentry/platform/ring0/gen_offsets/main.go new file mode 100644 index 000000000..ffa7eaf77 --- /dev/null +++ b/pkg/sentry/platform/ring0/gen_offsets/main.go @@ -0,0 +1,24 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary gen_offsets is a helper for generating offset headers. +package main + +import ( + "os" +) + +func main() { + Emit(os.Stdout) +} diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go new file mode 100644 index 000000000..b0471ab9a --- /dev/null +++ b/pkg/sentry/platform/ring0/kernel.go @@ -0,0 +1,71 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ring0 + +// New creates a new kernel. +// +// N.B. that constraints on KernelOpts must be satisfied. +// +// Init must have been called. +func New(opts KernelOpts) *Kernel { + k := new(Kernel) + k.init(opts) + return k +} + +// NewCPU creates a new CPU associated with this Kernel. +// +// Note that execution of the new CPU must begin at Start, with constraints as +// documented. Initialization is not completed by this method alone. +// +// See also Init. +func (k *Kernel) NewCPU() *CPU { + c := new(CPU) + c.Init(k) + return c +} + +// Halt halts execution. +func Halt() + +// Current returns the current CPU. +// +// Its use is only legal in the KernelSyscall and KernelException contexts, +// which must all be guarded go:nosplit. +func Current() *CPU + +// defaultSyscall is the default syscall hook. +// +//go:nosplit +func defaultSyscall() { Halt() } + +// defaultException is the default exception hook. +// +//go:nosplit +func defaultException(Vector) { Halt() } + +// Init allows the initialization of a CPU from a kernel without allocation. +// The same constraints as NewCPU apply. +// +// Init allows embedding in other objects. +func (c *CPU) Init(k *Kernel) { + c.self = c // Set self reference. + c.kernel = k // Set kernel reference. + c.init() // Perform architectural init. + + // Defaults. + c.KernelSyscall = defaultSyscall + c.KernelException = defaultException +} diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go new file mode 100644 index 000000000..c82613a9c --- /dev/null +++ b/pkg/sentry/platform/ring0/kernel_amd64.go @@ -0,0 +1,280 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
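// Editor's sketch, not part of this change: roughly how the kernel.go API
// above is intended to be wired together by a platform. The page tables,
// register state, floating point buffer and cpuid feature set are assumed to
// be prepared elsewhere, and a real platform must also arrange for Start to
// run on the target CPU (with GS, GDT, IDT, CR3, etc. loaded as documented in
// entry_amd64.go), which is omitted here.
package example

import (
	"syscall"

	"gvisor.googlesource.com/gvisor/pkg/cpuid"
	"gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0"
	"gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0/pagetables"
)

func runApplication(featureSet *cpuid.FeatureSet, pt *pagetables.PageTables, regs *syscall.PtraceRegs, fpState *byte) {
	// One-time architectural setup; selects XSAVE/FSGSBASE variants.
	ring0.Init(featureSet)

	// A Kernel holds the state shared by all CPUs (IDT, root page tables).
	k := ring0.New(ring0.KernelOpts{PageTables: pt})

	// Hooks for kernel-mode syscalls and exceptions; shown here simply
	// halting, which matches the defaults installed by Init.
	c := k.NewCPU()
	c.KernelSyscall = func() { ring0.Halt() }
	c.KernelException = func(ring0.Vector) { ring0.Halt() }

	// Run the application until it traps back into the kernel.
	switch vector := c.SwitchToUser(regs, fpState, pt, 0); vector {
	case ring0.Syscall:
		// regs now holds the syscall number and arguments.
	case ring0.PageFault:
		_ = ring0.ReadCR2() // Faulting address.
	default:
		// Any other exception is reported via the returned vector.
	}
}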
+// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package ring0 + +import ( + "encoding/binary" + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0/pagetables" +) + +const ( + // KernelFlagsSet should always be set in the kernel. + KernelFlagsSet = _RFLAGS_RESERVED + + // UserFlagsSet are always set in userspace. + UserFlagsSet = _RFLAGS_RESERVED | _RFLAGS_IF + + // KernelFlagsClear should always be clear in the kernel. + KernelFlagsClear = _RFLAGS_IF | _RFLAGS_NT | _RFLAGS_IOPL + + // UserFlagsClear are always cleared in userspace. + UserFlagsClear = _RFLAGS_NT | _RFLAGS_IOPL +) + +// init initializes architecture-specific state. +func (k *Kernel) init(opts KernelOpts) { + // Save the root page tables. + k.PageTables = opts.PageTables + + // Setup the IDT, which is uniform. + for v, handler := range handlers { + // Note that we set all traps to use the interrupt stack, this + // is defined below when setting up the TSS. + k.globalIDT[v].setInterrupt(Kcode, uint64(kernelFunc(handler)), 0 /* dpl */, 1 /* ist */) + } +} + +// init initializes architecture-specific state. +func (c *CPU) init() { + // Null segment. + c.gdt[0].setNull() + + // Kernel & user segments. + c.gdt[segKcode] = KernelCodeSegment + c.gdt[segKdata] = KernelDataSegment + c.gdt[segUcode32] = UserCodeSegment32 + c.gdt[segUdata] = UserDataSegment + c.gdt[segUcode64] = UserCodeSegment64 + + // The task segment, this spans two entries. + tssBase, tssLimit, _ := c.TSS() + c.gdt[segTss].set( + uint32(tssBase), + uint32(tssLimit), + 0, // Privilege level zero. + SegmentDescriptorPresent| + SegmentDescriptorAccess| + SegmentDescriptorWrite| + SegmentDescriptorExecute) + c.gdt[segTssHi].setHi(uint32((tssBase) >> 32)) + + // Set the kernel stack pointer in the TSS (virtual address). + stackAddr := c.StackTop() + c.tss.rsp0Lo = uint32(stackAddr) + c.tss.rsp0Hi = uint32(stackAddr >> 32) + c.tss.ist1Lo = uint32(stackAddr) + c.tss.ist1Hi = uint32(stackAddr >> 32) + + // Permanently set the kernel segments. + c.registers.Cs = uint64(Kcode) + c.registers.Ds = uint64(Kdata) + c.registers.Es = uint64(Kdata) + c.registers.Ss = uint64(Kdata) + c.registers.Fs = uint64(Kdata) + c.registers.Gs = uint64(Kdata) +} + +// StackTop returns the kernel's stack address. +// +//go:nosplit +func (c *CPU) StackTop() uint64 { + return uint64(kernelAddr(&c.stack[0])) + uint64(len(c.stack)) +} + +// IDT returns the CPU's IDT base and limit. +// +//go:nosplit +func (c *CPU) IDT() (uint64, uint16) { + return uint64(kernelAddr(&c.kernel.globalIDT[0])), uint16(binary.Size(&c.kernel.globalIDT) - 1) +} + +// GDT returns the CPU's GDT base and limit. +// +//go:nosplit +func (c *CPU) GDT() (uint64, uint16) { + return uint64(kernelAddr(&c.gdt[0])), uint16(8*segLast - 1) +} + +// TSS returns the CPU's TSS base, limit and value. +// +//go:nosplit +func (c *CPU) TSS() (uint64, uint16, *SegmentDescriptor) { + return uint64(kernelAddr(&c.tss)), uint16(binary.Size(&c.tss) - 1), &c.gdt[segTss] +} + +// CR0 returns the CPU's CR0 value. +// +//go:nosplit +func (c *CPU) CR0() uint64 { + return _CR0_PE | _CR0_PG | _CR0_ET +} + +// CR4 returns the CPU's CR4 value. 
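// (Editor's note: CR0 above enables protected mode and paging (PE, PG, plus
// ET); CR4 below is assembled from the features probed by Init in
// lib_amd64.go, turning on PCIDE, OSXSAVE, SMEP and FSGSBASE only when the
// host reports the corresponding cpuid features.)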
+// +//go:nosplit +func (c *CPU) CR4() uint64 { + cr4 := uint64(_CR4_PAE | _CR4_PSE | _CR4_OSFXSR | _CR4_OSXMMEXCPT) + if hasPCID { + cr4 |= _CR4_PCIDE + } + if hasXSAVE { + cr4 |= _CR4_OSXSAVE + } + if hasSMEP { + cr4 |= _CR4_SMEP + } + if hasFSGSBASE { + cr4 |= _CR4_FSGSBASE + } + return cr4 +} + +// EFER returns the CPU's EFER value. +// +//go:nosplit +func (c *CPU) EFER() uint64 { + return _EFER_LME | _EFER_SCE | _EFER_NX +} + +// IsCanonical indicates whether addr is canonical per the amd64 spec. +// +//go:nosplit +func IsCanonical(addr uint64) bool { + return addr <= 0x00007fffffffffff || addr > 0xffff800000000000 +} + +// Flags contains flags related to switch. +type Flags uintptr + +const ( + // FlagFull indicates that a full restore should be not, not a fast + // restore (on the syscall return path.) + FlagFull = 1 << iota + + // FlagFlush indicates that a full TLB flush is required. + FlagFlush +) + +// SwitchToUser performs either a sysret or an iret. +// +// The return value is the vector that interrupted execution. +// +// This function will not split the stack. Callers will probably want to call +// runtime.entersyscall (and pair with a call to runtime.exitsyscall) prior to +// calling this function. +// +// When this is done, this region is quite sensitive to things like system +// calls. After calling entersyscall, any memory used must have been allocated +// and no function calls without go:nosplit are permitted. Any calls made here +// are protected appropriately (e.g. IsCanonical and CR3). +// +// Also note that this function transitively depends on the compiler generating +// code that uses IP-relative addressing inside of absolute addresses. That's +// the case for amd64, but may not be the case for other architectures. +// +//go:nosplit +func (c *CPU) SwitchToUser(regs *syscall.PtraceRegs, fpState *byte, pt *pagetables.PageTables, flags Flags) (vector Vector) { + // Check for canonical addresses. + if !IsCanonical(regs.Rip) || !IsCanonical(regs.Rsp) || !IsCanonical(regs.Fs_base) || !IsCanonical(regs.Gs_base) { + return GeneralProtectionFault + } + + var ( + userCR3 uint64 + kernelCR3 uint64 + ) + + // Sanitize registers. + if flags&FlagFlush != 0 { + userCR3 = pt.FlushCR3() + } else { + userCR3 = pt.CR3() + } + regs.Eflags &= ^uint64(UserFlagsClear) + regs.Eflags |= UserFlagsSet + regs.Cs = uint64(Ucode64) // Required for iret. + regs.Ss = uint64(Udata) // Ditto. + kernelCR3 = c.kernel.PageTables.CR3() + + // Perform the switch. + swapgs() // GS will be swapped on return. + wrfs(uintptr(regs.Fs_base)) // Set application FS. + wrgs(uintptr(regs.Gs_base)) // Set application GS. + LoadFloatingPoint(fpState) // Copy in floating point. + jumpToKernel() // Switch to upper half. + writeCR3(uintptr(userCR3)) // Change to user address space. + if flags&FlagFull != 0 { + vector = iret(c, regs) + } else { + vector = sysret(c, regs) + } + writeCR3(uintptr(kernelCR3)) // Return to kernel address space. + jumpToUser() // Return to lower half. + SaveFloatingPoint(fpState) // Copy out floating point. + wrfs(uintptr(c.registers.Fs_base)) // Restore kernel FS. + return +} + +// start is the CPU entrypoint. +// +// This is called from the Start asm stub (see entry_amd64.go); on return the +// registers in c.registers will be restored (not segments). +// +//go:nosplit +func start(c *CPU) { + // Save per-cpu & FS segment. + wrgs(kernelAddr(c)) + wrfs(uintptr(c.Registers().Fs_base)) + + // Initialize floating point. 
+ // + // Note that on skylake, the valid XCR0 mask reported seems to be 0xff. + // This breaks down as: + // + // bit0 - x87 + // bit1 - SSE + // bit2 - AVX + // bit3-4 - MPX + // bit5-7 - AVX512 + // + // For some reason, enabled MPX & AVX512 on platforms that report them + // seems to be cause a general protection fault. (Maybe there are some + // virtualization issues and these aren't exported to the guest cpuid.) + // This needs further investigation, but we can limit the floating + // point operations to x87, SSE & AVX for now. + fninit() + xsetbv(0, validXCR0Mask&0x7) + + // Set the syscall target. + wrmsr(_MSR_LSTAR, kernelFunc(sysenter)) + wrmsr(_MSR_SYSCALL_MASK, _RFLAGS_STEP|_RFLAGS_IF|_RFLAGS_DF|_RFLAGS_IOPL|_RFLAGS_AC|_RFLAGS_NT) + + // NOTE: This depends on having the 64-bit segments immediately + // following the 32-bit user segments. This is simply the way the + // sysret instruction is designed to work (it assumes they follow). + wrmsr(_MSR_STAR, uintptr(uint64(Kcode)<<32|uint64(Ucode32)<<48)) + wrmsr(_MSR_CSTAR, kernelFunc(sysenter)) +} + +// ReadCR2 reads the current CR2 value. +// +//go:nosplit +func ReadCR2() uintptr { + return readCR2() +} diff --git a/pkg/sentry/platform/ring0/kernel_unsafe.go b/pkg/sentry/platform/ring0/kernel_unsafe.go new file mode 100644 index 000000000..cfb3ad853 --- /dev/null +++ b/pkg/sentry/platform/ring0/kernel_unsafe.go @@ -0,0 +1,41 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ring0 + +import ( + "unsafe" +) + +// eface mirrors runtime.eface. +type eface struct { + typ uintptr + data unsafe.Pointer +} + +// kernelAddr returns the kernel virtual address for the given object. +// +//go:nosplit +func kernelAddr(obj interface{}) uintptr { + e := (*eface)(unsafe.Pointer(&obj)) + return KernelStartAddress | uintptr(e.data) +} + +// kernelFunc returns the address of the given function. +// +//go:nosplit +func kernelFunc(fn func()) uintptr { + fnptr := (**uintptr)(unsafe.Pointer(&fn)) + return KernelStartAddress | **fnptr +} diff --git a/pkg/sentry/platform/ring0/lib_amd64.go b/pkg/sentry/platform/ring0/lib_amd64.go new file mode 100644 index 000000000..f1ed5bfb4 --- /dev/null +++ b/pkg/sentry/platform/ring0/lib_amd64.go @@ -0,0 +1,128 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+// +build amd64
+
+package ring0
+
+import (
+	"gvisor.googlesource.com/gvisor/pkg/cpuid"
+)
+
+// LoadFloatingPoint loads floating point state by the most efficient mechanism
+// available (set by Init).
+var LoadFloatingPoint func(*byte)
+
+// SaveFloatingPoint saves floating point state by the most efficient mechanism
+// available (set by Init).
+var SaveFloatingPoint func(*byte)
+
+// fxrstor uses fxrstor64 to load floating point state.
+func fxrstor(*byte)
+
+// xrstor uses xrstor to load floating point state.
+func xrstor(*byte)
+
+// fxsave uses fxsave64 to save floating point state.
+func fxsave(*byte)
+
+// xsave uses xsave to save floating point state.
+func xsave(*byte)
+
+// xsaveopt uses xsaveopt to save floating point state.
+func xsaveopt(*byte)
+
+// wrfs sets the FS address (set by init).
+var wrfs func(addr uintptr)
+
+// wrfsbase writes to the FS base address.
+func wrfsbase(addr uintptr)
+
+// wrfsmsr writes to the FS_BASE MSR.
+func wrfsmsr(addr uintptr)
+
+// wrgs sets the GS address (set by init).
+var wrgs func(addr uintptr)
+
+// wrgsbase writes to the GS base address.
+func wrgsbase(addr uintptr)
+
+// wrgsmsr writes to the GS_BASE MSR.
+func wrgsmsr(addr uintptr)
+
+// writeCR3 writes the CR3 value.
+func writeCR3(phys uintptr)
+
+// readCR2 reads the current CR2 value.
+func readCR2() uintptr
+
+// jumpToKernel jumps to the kernel version of the current RIP.
+func jumpToKernel()
+
+// jumpToUser jumps to the user version of the current RIP.
+func jumpToUser()
+
+// fninit initializes the floating point unit.
+func fninit()
+
+// xsetbv writes to an extended control register.
+func xsetbv(reg, value uintptr)
+
+// xgetbv reads an extended control register.
+func xgetbv(reg uintptr) uintptr
+
+// wrmsr writes the given value to the given MSR.
+func wrmsr(reg, value uintptr)
+
+// rdmsr reads the given MSR.
+func rdmsr(reg uintptr) uintptr
+
+// Mostly-constants set by Init.
+var (
+	hasSMEP       bool
+	hasPCID       bool
+	hasXSAVEOPT   bool
+	hasXSAVE      bool
+	hasFSGSBASE   bool
+	validXCR0Mask uintptr
+)
+
+// Init sets function pointers based on architectural features.
+//
+// This must be called prior to using ring0.
+func Init(featureSet *cpuid.FeatureSet) {
+	hasSMEP = featureSet.HasFeature(cpuid.X86FeatureSMEP)
+	hasPCID = featureSet.HasFeature(cpuid.X86FeaturePCID)
+	hasXSAVEOPT = featureSet.UseXsaveopt()
+	hasXSAVE = featureSet.UseXsave()
+	hasFSGSBASE = featureSet.HasFeature(cpuid.X86FeatureFSGSBase)
+	validXCR0Mask = uintptr(featureSet.ValidXCR0Mask())
+	if hasXSAVEOPT {
+		SaveFloatingPoint = xsaveopt
+		LoadFloatingPoint = xrstor
+	} else if hasXSAVE {
+		SaveFloatingPoint = xsave
+		LoadFloatingPoint = xrstor
+	} else {
+		SaveFloatingPoint = fxsave
+		LoadFloatingPoint = fxrstor
+	}
+	if hasFSGSBASE {
+		wrfs = wrfsbase
+		wrgs = wrgsbase
+	} else {
+		wrfs = wrfsmsr
+		wrgs = wrgsmsr
+	}
+}
diff --git a/pkg/sentry/platform/ring0/lib_amd64.s b/pkg/sentry/platform/ring0/lib_amd64.s
new file mode 100644
index 000000000..6f143ea5a
--- /dev/null
+++ b/pkg/sentry/platform/ring0/lib_amd64.s
@@ -0,0 +1,247 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "funcdata.h" +#include "textflag.h" + +// fxrstor loads floating point state. +// +// The code corresponds to: +// +// fxrstor64 (%rbx) +// +TEXT ·fxrstor(SB),NOSPLIT,$0-8 + MOVQ addr+0(FP), BX + MOVL $0xffffffff, AX + MOVL $0xffffffff, DX + BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x0b; + RET + +// xrstor loads floating point state. +// +// The code corresponds to: +// +// xrstor (%rdi) +// +TEXT ·xrstor(SB),NOSPLIT,$0-8 + MOVQ addr+0(FP), DI + MOVL $0xffffffff, AX + MOVL $0xffffffff, DX + BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x2f; + RET + +// fxsave saves floating point state. +// +// The code corresponds to: +// +// fxsave64 (%rbx) +// +TEXT ·fxsave(SB),NOSPLIT,$0-8 + MOVQ addr+0(FP), BX + MOVL $0xffffffff, AX + MOVL $0xffffffff, DX + BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x03; + RET + +// xsave saves floating point state. +// +// The code corresponds to: +// +// xsave (%rdi) +// +TEXT ·xsave(SB),NOSPLIT,$0-8 + MOVQ addr+0(FP), DI + MOVL $0xffffffff, AX + MOVL $0xffffffff, DX + BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x27; + RET + +// xsaveopt saves floating point state. +// +// The code corresponds to: +// +// xsaveopt (%rdi) +// +TEXT ·xsaveopt(SB),NOSPLIT,$0-8 + MOVQ addr+0(FP), DI + MOVL $0xffffffff, AX + MOVL $0xffffffff, DX + BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x37; + RET + +// wrfsbase writes to the FS base. +// +// The code corresponds to: +// +// wrfsbase %rax +// +TEXT ·wrfsbase(SB),NOSPLIT,$0-8 + MOVQ addr+0(FP), AX + BYTE $0xf3; BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0xd0; + RET + +// wrfsmsr writes to the FSBASE MSR. +// +// The code corresponds to: +// +// wrmsr (writes EDX:EAX to the MSR in ECX) +// +TEXT ·wrfsmsr(SB),NOSPLIT,$0-8 + MOVQ addr+0(FP), AX + MOVQ AX, DX + SHRQ $32, DX + MOVQ $0xc0000100, CX // MSR_FS_BASE + BYTE $0x0f; BYTE $0x30; + RET + +// wrgsbase writes to the GS base. +// +// The code corresponds to: +// +// wrgsbase %rax +// +TEXT ·wrgsbase(SB),NOSPLIT,$0-8 + MOVQ addr+0(FP), AX + BYTE $0xf3; BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0xd8; + RET + +// wrgsmsr writes to the GSBASE MSR. +// +// See wrfsmsr. +TEXT ·wrgsmsr(SB),NOSPLIT,$0-8 + MOVQ addr+0(FP), AX + MOVQ AX, DX + SHRQ $32, DX + MOVQ $0xc0000101, CX // MSR_GS_BASE + BYTE $0x0f; BYTE $0x30; // WRMSR + RET + +// jumpToUser changes execution to the user address. +// +// This works by changing the return value to the user version. +TEXT ·jumpToUser(SB),NOSPLIT,$0 + MOVQ 0(SP), AX + MOVQ ·KernelStartAddress(SB), BX + NOTQ BX + ANDQ BX, SP // Switch the stack. + ANDQ BX, BP // Switch the frame pointer. + ANDQ BX, AX // Future return value. + MOVQ AX, 0(SP) + RET + +// jumpToKernel changes execution to the kernel address space. +// +// This works by changing the return value to the kernel version. +TEXT ·jumpToKernel(SB),NOSPLIT,$0 + MOVQ 0(SP), AX + MOVQ ·KernelStartAddress(SB), BX + ORQ BX, SP // Switch the stack. + ORQ BX, BP // Switch the frame pointer. + ORQ BX, AX // Future return value. + MOVQ AX, 0(SP) + RET + +// writeCR3 writes the given CR3 value. +// +// The code corresponds to: +// +// mov %rax, %cr3 +// +TEXT ·writeCR3(SB),NOSPLIT,$0-8 + MOVQ cr3+0(FP), AX + BYTE $0x0f; BYTE $0x22; BYTE $0xd8; + RET + +// readCR3 reads the current CR3 value. +// +// The code corresponds to: +// +// mov %cr3, %rax +// +TEXT ·readCR3(SB),NOSPLIT,$0-8 + BYTE $0x0f; BYTE $0x20; BYTE $0xd8; + MOVQ AX, ret+0(FP) + RET + +// readCR2 reads the current CR2 value. 
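// (Editor's note: jumpToKernel and jumpToUser above do not branch anywhere
// directly; they OR KernelStartAddress into (or mask it out of) SP, BP and
// the saved return address, so the following RET lands in the other alias of
// the same code. This presumes the kernel text and stacks are mapped at both
// the lower- and upper-half addresses, the same aliasing that
// LOAD_KERNEL_ADDRESS in entry_amd64.s depends on.)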
+// +// The code corresponds to: +// +// mov %cr2, %rax +// +TEXT ·readCR2(SB),NOSPLIT,$0-8 + BYTE $0x0f; BYTE $0x20; BYTE $0xd0; + MOVQ AX, ret+0(FP) + RET + +// fninit initializes the floating point unit. +// +// The code corresponds to: +// +// fninit +TEXT ·fninit(SB),NOSPLIT,$0 + BYTE $0xdb; BYTE $0xe3; + RET + +// xsetbv writes to an extended control register. +// +// The code corresponds to: +// +// xsetbv +// +TEXT ·xsetbv(SB),NOSPLIT,$0-16 + MOVL reg+0(FP), CX + MOVL value+8(FP), AX + MOVL value+12(FP), DX + BYTE $0x0f; BYTE $0x01; BYTE $0xd1; + RET + +// xgetbv reads an extended control register. +// +// The code corresponds to: +// +// xgetbv +// +TEXT ·xgetbv(SB),NOSPLIT,$0-16 + MOVL reg+0(FP), CX + BYTE $0x0f; BYTE $0x01; BYTE $0xd0; + MOVL AX, ret+8(FP) + MOVL DX, ret+12(FP) + RET + +// wrmsr writes to a control register. +// +// The code corresponds to: +// +// wrmsr +// +TEXT ·wrmsr(SB),NOSPLIT,$0-16 + MOVL reg+0(FP), CX + MOVL value+8(FP), AX + MOVL value+12(FP), DX + BYTE $0x0f; BYTE $0x30; + RET + +// rdmsr reads a control register. +// +// The code corresponds to: +// +// rdmsr +// +TEXT ·rdmsr(SB),NOSPLIT,$0-16 + MOVL reg+0(FP), CX + BYTE $0x0f; BYTE $0x32; + MOVL AX, ret+8(FP) + MOVL DX, ret+12(FP) + RET diff --git a/pkg/sentry/platform/ring0/offsets_amd64.go b/pkg/sentry/platform/ring0/offsets_amd64.go new file mode 100644 index 000000000..9acd442ba --- /dev/null +++ b/pkg/sentry/platform/ring0/offsets_amd64.go @@ -0,0 +1,93 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package ring0 + +import ( + "fmt" + "io" + "reflect" + "syscall" +) + +// Emit prints architecture-specific offsets. 
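// (Editor's note: this output is the header that the genrule in the ring0
// BUILD file prepends to entry_amd64.s to produce entry_impl_amd64.s. For
// illustration only, it begins roughly as follows, with the actual values
// depending on the Go struct layouts at build time; CPU_SELF is 0x00 by
// construction, since CPU.self is documented to live at offset zero:
//
//	// Automatically generated, do not edit.
//
//	// CPU offsets.
//	#define CPU_SELF       0x00
//	#define CPU_REGISTERS  0x..
//	#define CPU_STACK_TOP  0x..
//	...)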
+func Emit(w io.Writer) { + fmt.Fprintf(w, "// Automatically generated, do not edit.\n") + + c := &CPU{} + fmt.Fprintf(w, "\n// CPU offsets.\n") + fmt.Fprintf(w, "#define CPU_SELF 0x%02x\n", reflect.ValueOf(&c.self).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_STACK_TOP 0x%02x\n", reflect.ValueOf(&c.stack[0]).Pointer()-reflect.ValueOf(c).Pointer()+uintptr(len(c.stack))) + fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_KERNEL_EXCEPTION 0x%02x\n", reflect.ValueOf(&c.KernelException).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_KERNEL_SYSCALL 0x%02x\n", reflect.ValueOf(&c.KernelSyscall).Pointer()-reflect.ValueOf(c).Pointer()) + + fmt.Fprintf(w, "\n// Bits.\n") + fmt.Fprintf(w, "#define _RFLAGS_IF 0x%02x\n", _RFLAGS_IF) + + fmt.Fprintf(w, "\n// Vectors.\n") + fmt.Fprintf(w, "#define DivideByZero 0x%02x\n", DivideByZero) + fmt.Fprintf(w, "#define Debug 0x%02x\n", Debug) + fmt.Fprintf(w, "#define NMI 0x%02x\n", NMI) + fmt.Fprintf(w, "#define Breakpoint 0x%02x\n", Breakpoint) + fmt.Fprintf(w, "#define Overflow 0x%02x\n", Overflow) + fmt.Fprintf(w, "#define BoundRangeExceeded 0x%02x\n", BoundRangeExceeded) + fmt.Fprintf(w, "#define InvalidOpcode 0x%02x\n", InvalidOpcode) + fmt.Fprintf(w, "#define DeviceNotAvailable 0x%02x\n", DeviceNotAvailable) + fmt.Fprintf(w, "#define DoubleFault 0x%02x\n", DoubleFault) + fmt.Fprintf(w, "#define CoprocessorSegmentOverrun 0x%02x\n", CoprocessorSegmentOverrun) + fmt.Fprintf(w, "#define InvalidTSS 0x%02x\n", InvalidTSS) + fmt.Fprintf(w, "#define SegmentNotPresent 0x%02x\n", SegmentNotPresent) + fmt.Fprintf(w, "#define StackSegmentFault 0x%02x\n", StackSegmentFault) + fmt.Fprintf(w, "#define GeneralProtectionFault 0x%02x\n", GeneralProtectionFault) + fmt.Fprintf(w, "#define PageFault 0x%02x\n", PageFault) + fmt.Fprintf(w, "#define X87FloatingPointException 0x%02x\n", X87FloatingPointException) + fmt.Fprintf(w, "#define AlignmentCheck 0x%02x\n", AlignmentCheck) + fmt.Fprintf(w, "#define MachineCheck 0x%02x\n", MachineCheck) + fmt.Fprintf(w, "#define SIMDFloatingPointException 0x%02x\n", SIMDFloatingPointException) + fmt.Fprintf(w, "#define VirtualizationException 0x%02x\n", VirtualizationException) + fmt.Fprintf(w, "#define SecurityException 0x%02x\n", SecurityException) + fmt.Fprintf(w, "#define SyscallInt80 0x%02x\n", SyscallInt80) + fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall) + + p := &syscall.PtraceRegs{} + fmt.Fprintf(w, "\n// Ptrace registers.\n") + fmt.Fprintf(w, "#define PTRACE_R15 0x%02x\n", reflect.ValueOf(&p.R15).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R14 0x%02x\n", reflect.ValueOf(&p.R14).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R13 0x%02x\n", reflect.ValueOf(&p.R13).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R12 0x%02x\n", reflect.ValueOf(&p.R12).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RBP 0x%02x\n", reflect.ValueOf(&p.Rbp).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RBX 0x%02x\n", reflect.ValueOf(&p.Rbx).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R11 0x%02x\n", 
reflect.ValueOf(&p.R11).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R10 0x%02x\n", reflect.ValueOf(&p.R10).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R9 0x%02x\n", reflect.ValueOf(&p.R9).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R8 0x%02x\n", reflect.ValueOf(&p.R8).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RAX 0x%02x\n", reflect.ValueOf(&p.Rax).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RCX 0x%02x\n", reflect.ValueOf(&p.Rcx).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RDX 0x%02x\n", reflect.ValueOf(&p.Rdx).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RSI 0x%02x\n", reflect.ValueOf(&p.Rsi).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RDI 0x%02x\n", reflect.ValueOf(&p.Rdi).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_ORIGRAX 0x%02x\n", reflect.ValueOf(&p.Orig_rax).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RIP 0x%02x\n", reflect.ValueOf(&p.Rip).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_CS 0x%02x\n", reflect.ValueOf(&p.Cs).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_FLAGS 0x%02x\n", reflect.ValueOf(&p.Eflags).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RSP 0x%02x\n", reflect.ValueOf(&p.Rsp).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_SS 0x%02x\n", reflect.ValueOf(&p.Ss).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_FS 0x%02x\n", reflect.ValueOf(&p.Fs_base).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_GS 0x%02x\n", reflect.ValueOf(&p.Gs_base).Pointer()-reflect.ValueOf(p).Pointer()) +} diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD new file mode 100644 index 000000000..c0c481ab3 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/BUILD @@ -0,0 +1,32 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "pagetables", + srcs = [ + "pagetables.go", + "pagetables_amd64.go", + "pagetables_unsafe.go", + "pagetables_x86.go", + "pcids_x86.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0/pagetables", + visibility = [ + "//pkg/sentry/platform/kvm:__subpackages__", + "//pkg/sentry/platform/ring0:__subpackages__", + ], + deps = ["//pkg/sentry/usermem"], +) + +go_test( + name = "pagetables_test", + size = "small", + srcs = [ + "pagetables_test.go", + "pagetables_x86_test.go", + "pcids_x86_test.go", + ], + embed = [":pagetables"], + deps = ["//pkg/sentry/usermem"], +) diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go new file mode 100644 index 000000000..3cbf0bfa5 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go @@ -0,0 +1,193 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package pagetables provides a generic implementation of pagetables. +package pagetables + +import ( + "sync" + + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +// Node is a single node within a set of page tables. +type Node struct { + // unalignedData has unaligned data. Unfortunately, we can't really + // rely on the allocator to give us what we want here. So we just throw + // it at the wall and use the portion that matches. Gross. This may be + // changed in the future to use a different allocation mechanism. + // + // Access must happen via functions found in pagetables_unsafe.go. + unalignedData [(2 * usermem.PageSize) - 1]byte + + // physical is the translated address of these entries. + // + // This is filled in at creation time. + physical uintptr +} + +// PageTables is a set of page tables. +type PageTables struct { + mu sync.Mutex + + // root is the pagetable root. + root *Node + + // translater is the translater passed at creation. + translater Translater + + // archPageTables includes architecture-specific features. + archPageTables + + // allNodes is a set of nodes indexed by translater address. + allNodes map[uintptr]*Node +} + +// Translater translates to guest physical addresses. +type Translater interface { + // TranslateToPhysical translates the given pointer object into a + // "physical" address. We do not require that it translates back, the + // reverse mapping is maintained internally. + TranslateToPhysical(*PTEs) uintptr +} + +// New returns new PageTables. +func New(t Translater, opts Opts) *PageTables { + p := &PageTables{ + translater: t, + allNodes: make(map[uintptr]*Node), + } + p.root = p.allocNode() + p.init(opts) + return p +} + +// New returns a new set of PageTables derived from the given one. +// +// This function should always be preferred to New if there are existing +// pagetables, as this function preserves architectural constraints relevant to +// managing multiple sets of pagetables. +func (p *PageTables) New() *PageTables { + np := &PageTables{ + translater: p.translater, + allNodes: make(map[uintptr]*Node), + } + np.root = np.allocNode() + np.initFrom(&p.archPageTables) + return np +} + +// setPageTable sets the given index as a page table. +func (p *PageTables) setPageTable(n *Node, index int, child *Node) { + phys := p.translater.TranslateToPhysical(child.PTEs()) + p.allNodes[phys] = child + pte := &n.PTEs()[index] + pte.setPageTable(phys) +} + +// clearPageTable clears the given entry. +func (p *PageTables) clearPageTable(n *Node, index int) { + pte := &n.PTEs()[index] + physical := pte.Address() + pte.Clear() + delete(p.allNodes, physical) +} + +// getPageTable returns the page table entry. +func (p *PageTables) getPageTable(n *Node, index int) *Node { + pte := &n.PTEs()[index] + physical := pte.Address() + child := p.allNodes[physical] + return child +} + +// Map installs a mapping with the given physical address. +// +// True is returned iff there was a previous mapping in the range. +// +// Precondition: addr & length must be aligned, their sum must not overflow. 
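+//
+// A minimal usage sketch (the address, permissions and physical value below
+// are illustrative; real callers obtain the physical address from their
+// Translater):
+//
+//	prev := pt.Map(0x400000, usermem.PageSize, true, usermem.ReadWrite, phys)
+//	if prev {
+//		// An existing mapping in the range was replaced.
+//	}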
+func (p *PageTables) Map(addr usermem.Addr, length uintptr, user bool, at usermem.AccessType, physical uintptr) bool { + if at == usermem.NoAccess { + return p.Unmap(addr, length) + } + prev := false + p.mu.Lock() + end, ok := addr.AddLength(uint64(length)) + if !ok { + panic("pagetables.Map: overflow") + } + p.iterateRange(uintptr(addr), uintptr(end), true, func(s, e uintptr, pte *PTE, align uintptr) { + p := physical + (s - uintptr(addr)) + prev = prev || (pte.Valid() && (p != pte.Address() || at.Write != pte.Writeable() || at.Execute != pte.Executable())) + if p&align != 0 { + // We will install entries at a smaller granulaity if + // we don't install a valid entry here, however we must + // zap any existing entry to ensure this happens. + pte.Clear() + return + } + pte.Set(p, at.Write, at.Execute, user) + }) + p.mu.Unlock() + return prev +} + +// Unmap unmaps the given range. +// +// True is returned iff there was a previous mapping in the range. +func (p *PageTables) Unmap(addr usermem.Addr, length uintptr) bool { + p.mu.Lock() + count := 0 + p.iterateRange(uintptr(addr), uintptr(addr)+length, false, func(s, e uintptr, pte *PTE, align uintptr) { + pte.Clear() + count++ + }) + p.mu.Unlock() + return count > 0 +} + +// Release releases this address space. +// +// This must be called to release the PCID. +func (p *PageTables) Release() { + // Clear all pages. + p.Unmap(0, ^uintptr(0)) + p.release() +} + +// Lookup returns the physical address for the given virtual address. +func (p *PageTables) Lookup(addr usermem.Addr) (physical uintptr, accessType usermem.AccessType) { + mask := uintptr(usermem.PageSize - 1) + off := uintptr(addr) & mask + addr = addr &^ usermem.Addr(mask) + p.iterateRange(uintptr(addr), uintptr(addr+usermem.PageSize), false, func(s, e uintptr, pte *PTE, align uintptr) { + if !pte.Valid() { + return + } + physical = pte.Address() + (s - uintptr(addr)) + off + accessType = usermem.AccessType{ + Read: true, + Write: pte.Writeable(), + Execute: pte.Executable(), + } + }) + return physical, accessType +} + +// allocNode allocates a new page. +func (p *PageTables) allocNode() *Node { + n := new(Node) + n.physical = p.translater.TranslateToPhysical(n.PTEs()) + return n +} diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go new file mode 100644 index 000000000..b89665c96 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go @@ -0,0 +1,397 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package pagetables + +import ( + "fmt" + "sync/atomic" +) + +// Address constraints. +// +// The lowerTop and upperBottom currently apply to four-level pagetables; +// additional refactoring would be necessary to support five-level pagetables. 
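+//
+// Addresses strictly between lowerTop and upperBottom fall in the
+// non-canonical hole and are never mapped. For a given virtual address addr,
+// the walker below selects the root entry (addr >> pgdShift) & 0x1ff, then
+// (addr >> pudShift) & 0x1ff, (addr >> pmdShift) & 0x1ff and finally
+// (addr >> pteShift) & 0x1ff.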
+const ( + lowerTop = 0x00007fffffffffff + upperBottom = 0xffff800000000000 + + pteShift = 12 + pmdShift = 21 + pudShift = 30 + pgdShift = 39 + + pteMask = 0x1ff << pteShift + pmdMask = 0x1ff << pmdShift + pudMask = 0x1ff << pudShift + pgdMask = 0x1ff << pgdShift + + pteSize = 1 << pteShift + pmdSize = 1 << pmdShift + pudSize = 1 << pudShift + pgdSize = 1 << pgdShift +) + +// Bits in page table entries. +const ( + present = 0x001 + writable = 0x002 + user = 0x004 + writeThrough = 0x008 + cacheDisable = 0x010 + accessed = 0x020 + dirty = 0x040 + super = 0x080 + executeDisable = 1 << 63 +) + +// PTE is a page table entry. +type PTE uint64 + +// Clear clears this PTE, including super page information. +func (p *PTE) Clear() { + atomic.StoreUint64((*uint64)(p), 0) +} + +// Valid returns true iff this entry is valid. +func (p *PTE) Valid() bool { + return atomic.LoadUint64((*uint64)(p))&present != 0 +} + +// Writeable returns true iff the page is writable. +func (p *PTE) Writeable() bool { + return atomic.LoadUint64((*uint64)(p))&writable != 0 +} + +// User returns true iff the page is user-accessible. +func (p *PTE) User() bool { + return atomic.LoadUint64((*uint64)(p))&user != 0 +} + +// Executable returns true iff the page is executable. +func (p *PTE) Executable() bool { + return atomic.LoadUint64((*uint64)(p))&executeDisable == 0 +} + +// SetSuper sets this page as a super page. +// +// The page must not be valid or a panic will result. +func (p *PTE) SetSuper() { + if p.Valid() { + // This is not allowed. + panic("SetSuper called on valid page!") + } + atomic.StoreUint64((*uint64)(p), super) +} + +// IsSuper returns true iff this page is a super page. +func (p *PTE) IsSuper() bool { + return atomic.LoadUint64((*uint64)(p))&super != 0 +} + +// Set sets this PTE value. +func (p *PTE) Set(addr uintptr, write, execute bool, userAccessible bool) { + v := uint64(addr)&^uint64(0xfff) | present | accessed + if userAccessible { + v |= user + } + if !execute { + v |= executeDisable + } + if write { + v |= writable | dirty + } + if p.IsSuper() { + v |= super + } + atomic.StoreUint64((*uint64)(p), v) +} + +// setPageTable sets this PTE value and forces the write bit and super bit to +// be cleared. This is used explicitly for breaking super pages. +func (p *PTE) setPageTable(addr uintptr) { + v := uint64(addr)&^uint64(0xfff) | present | user | writable | accessed | dirty + atomic.StoreUint64((*uint64)(p), v) +} + +// Address extracts the address. This should only be used if Valid returns true. +func (p *PTE) Address() uintptr { + return uintptr(atomic.LoadUint64((*uint64)(p)) & ^uint64(executeDisable|0xfff)) +} + +// entriesPerPage is the number of PTEs per page. +const entriesPerPage = 512 + +// PTEs is a collection of entries. +type PTEs [entriesPerPage]PTE + +// next returns the next address quantized by the given size. +func next(start uint64, size uint64) uint64 { + start &= ^(size - 1) + start += size + return start +} + +// iterateRange iterates over all appropriate levels of page tables for the given range. +// +// If alloc is set, then Set _must_ be called on all given PTEs. The exception +// is super pages. If a valid super page cannot be installed, then the walk +// will continue to individual entries. +// +// This algorithm will attempt to maximize the use of super pages whenever +// possible. Whether a super page is provided will be clear through the range +// provided in the callback. +// +// Note that if alloc set, then no gaps will be present. 
However, if alloc is +// not set, then the iteration will likely be full of gaps. +// +// Note that this function should generally be avoided in favor of Map, Unmap, +// etc. when not necessary. +// +// Precondition: startAddr and endAddr must be page-aligned. +// +// Precondition: startStart must be less than endAddr. +// +// Precondition: If alloc is set, then startAddr and endAddr should not span +// non-canonical ranges. If they do, a panic will result. +func (p *PageTables) iterateRange(startAddr, endAddr uintptr, alloc bool, fn func(s, e uintptr, pte *PTE, align uintptr)) { + start := uint64(startAddr) + end := uint64(endAddr) + if start%pteSize != 0 { + panic(fmt.Sprintf("unaligned start: %v", start)) + } + if start > end { + panic(fmt.Sprintf("start > end (%v > %v))", start, end)) + } + + // Deal with cases where we traverse the "gap". + // + // These are all explicitly disallowed if alloc is set, and we must + // traverse an entry for each address explicitly. + switch { + case start < lowerTop && end > lowerTop && end < upperBottom: + if alloc { + panic(fmt.Sprintf("alloc [%x, %x) spans non-canonical range", start, end)) + } + p.iterateRange(startAddr, lowerTop, false, fn) + return + case start < lowerTop && end > lowerTop: + if alloc { + panic(fmt.Sprintf("alloc [%x, %x) spans non-canonical range", start, end)) + } + p.iterateRange(startAddr, lowerTop, false, fn) + p.iterateRange(upperBottom, endAddr, false, fn) + return + case start > lowerTop && end < upperBottom: + if alloc { + panic(fmt.Sprintf("alloc [%x, %x) spans non-canonical range", start, end)) + } + return + case start > lowerTop && start < upperBottom && end > upperBottom: + if alloc { + panic(fmt.Sprintf("alloc [%x, %x) spans non-canonical range", start, end)) + } + p.iterateRange(upperBottom, endAddr, false, fn) + return + } + + for pgdIndex := int((start & pgdMask) >> pgdShift); start < end && pgdIndex < entriesPerPage; pgdIndex++ { + pgdEntry := &p.root.PTEs()[pgdIndex] + if !pgdEntry.Valid() { + if !alloc { + // Skip over this entry. + start = next(start, pgdSize) + continue + } + + // Allocate a new pgd. + p.setPageTable(p.root, pgdIndex, p.allocNode()) + } + + // Map the next level. + pudNode := p.getPageTable(p.root, pgdIndex) + clearPUDEntries := 0 + + for pudIndex := int((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ { + pudEntry := &(pudNode.PTEs()[pudIndex]) + if !pudEntry.Valid() { + if !alloc { + // Skip over this entry. + clearPUDEntries++ + start = next(start, pudSize) + continue + } + + // This level has 1-GB super pages. Is this + // entire region contained in a single PUD + // entry? If so, we can skip allocating a new + // page for the pmd. + if start&(pudSize-1) == 0 && end-start >= pudSize { + pudEntry.SetSuper() + fn(uintptr(start), uintptr(start+pudSize), pudEntry, pudSize-1) + if pudEntry.Valid() { + start = next(start, pudSize) + continue + } + } + + // Allocate a new pud. + p.setPageTable(pudNode, pudIndex, p.allocNode()) + + } else if pudEntry.IsSuper() { + // Does this page need to be split? + if start&(pudSize-1) != 0 || end < next(start, pudSize) { + currentAddr := uint64(pudEntry.Address()) + writeable := pudEntry.Writeable() + executable := pudEntry.Executable() + user := pudEntry.User() + + // Install the relevant entries. 
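+ // Break the 1-GB mapping into 512 2-MB super
+ // pages covering the same physical range with
+ // the same permissions, so that the requested
+ // sub-range can be modified independently.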
+ pmdNode := p.allocNode() + pmdEntries := pmdNode.PTEs() + for index := 0; index < entriesPerPage; index++ { + pmdEntry := &pmdEntries[index] + pmdEntry.SetSuper() + pmdEntry.Set(uintptr(currentAddr), writeable, executable, user) + currentAddr += pmdSize + } + + // Reset to point to the new page. + p.setPageTable(pudNode, pudIndex, pmdNode) + } else { + // A super page to be checked directly. + fn(uintptr(start), uintptr(start+pudSize), pudEntry, pudSize-1) + + // Might have been cleared. + if !pudEntry.Valid() { + clearPUDEntries++ + } + + // Note that the super page was changed. + start = next(start, pudSize) + continue + } + } + + // Map the next level, since this is valid. + pmdNode := p.getPageTable(pudNode, pudIndex) + clearPMDEntries := 0 + + for pmdIndex := int((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ { + pmdEntry := &pmdNode.PTEs()[pmdIndex] + if !pmdEntry.Valid() { + if !alloc { + // Skip over this entry. + clearPMDEntries++ + start = next(start, pmdSize) + continue + } + + // This level has 2-MB huge pages. If this + // region is contined in a single PMD entry? + // As above, we can skip allocating a new page. + if start&(pmdSize-1) == 0 && end-start >= pmdSize { + pmdEntry.SetSuper() + fn(uintptr(start), uintptr(start+pmdSize), pmdEntry, pmdSize-1) + if pmdEntry.Valid() { + start = next(start, pmdSize) + continue + } + } + + // Allocate a new pmd. + p.setPageTable(pmdNode, pmdIndex, p.allocNode()) + + } else if pmdEntry.IsSuper() { + // Does this page need to be split? + if start&(pmdSize-1) != 0 || end < next(start, pmdSize) { + currentAddr := uint64(pmdEntry.Address()) + writeable := pmdEntry.Writeable() + executable := pmdEntry.Executable() + user := pmdEntry.User() + + // Install the relevant entries. + pteNode := p.allocNode() + pteEntries := pteNode.PTEs() + for index := 0; index < entriesPerPage; index++ { + pteEntry := &pteEntries[index] + pteEntry.Set(uintptr(currentAddr), writeable, executable, user) + currentAddr += pteSize + } + + // Reset to point to the new page. + p.setPageTable(pmdNode, pmdIndex, pteNode) + } else { + // A huge page to be checked directly. + fn(uintptr(start), uintptr(start+pmdSize), pmdEntry, pmdSize-1) + + // Might have been cleared. + if !pmdEntry.Valid() { + clearPMDEntries++ + } + + // Note that the huge page was changed. + start = next(start, pmdSize) + continue + } + } + + // Map the next level, since this is valid. + pteNode := p.getPageTable(pmdNode, pmdIndex) + clearPTEEntries := 0 + + for pteIndex := int((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ { + pteEntry := &pteNode.PTEs()[pteIndex] + if !pteEntry.Valid() && !alloc { + clearPTEEntries++ + start += pteSize + continue + } + + // At this point, we are guaranteed that start%pteSize == 0. + fn(uintptr(start), uintptr(start+pteSize), pteEntry, pteSize-1) + if !pteEntry.Valid() { + if alloc { + panic("PTE not set after iteration with alloc=true!") + } + clearPTEEntries++ + } + + // Note that the pte was changed. + start += pteSize + continue + } + + // Check if we no longer need this page. + if clearPTEEntries == entriesPerPage { + p.clearPageTable(pmdNode, pmdIndex) + clearPMDEntries++ + } + } + + // Check if we no longer need this page. + if clearPMDEntries == entriesPerPage { + p.clearPageTable(pudNode, pudIndex) + clearPUDEntries++ + } + } + + // Check if we no longer need this page. 
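+ // If every PUD entry beneath this PGD entry has been cleared,
+ // the intermediate page-table page is no longer needed and is
+ // removed from the root.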
+ if clearPUDEntries == entriesPerPage { + p.clearPageTable(p.root, pgdIndex) + } + } +} diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_test.go b/pkg/sentry/platform/ring0/pagetables/pagetables_test.go new file mode 100644 index 000000000..9cbc0e3b0 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_test.go @@ -0,0 +1,161 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagetables + +import ( + "reflect" + "testing" + + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +type reflectTranslater struct{} + +func (r reflectTranslater) TranslateToPhysical(ptes *PTEs) uintptr { + return reflect.ValueOf(ptes).Pointer() +} + +type mapping struct { + start uintptr + length uintptr + addr uintptr + writeable bool +} + +func checkMappings(t *testing.T, pt *PageTables, m []mapping) { + var ( + current int + found []mapping + failed string + ) + + // Iterate over all the mappings. + pt.iterateRange(0, ^uintptr(0), false, func(s, e uintptr, pte *PTE, align uintptr) { + found = append(found, mapping{ + start: s, + length: e - s, + addr: pte.Address(), + writeable: pte.Writeable(), + }) + if failed != "" { + // Don't keep looking for errors. + return + } + + if current >= len(m) { + failed = "more mappings than expected" + } else if m[current].start != s { + failed = "start didn't match expected" + } else if m[current].length != (e - s) { + failed = "end didn't match expected" + } else if m[current].addr != pte.Address() { + failed = "address didn't match expected" + } else if m[current].writeable != pte.Writeable() { + failed = "writeable didn't match" + } + current++ + }) + + // Were we expected additional mappings? + if failed == "" && current != len(m) { + failed = "insufficient mappings found" + } + + // Emit a meaningful error message on failure. + if failed != "" { + t.Errorf("%s; got %#v, wanted %#v", failed, found, m) + } +} + +func TestAllocFree(t *testing.T) { + pt := New(reflectTranslater{}, Opts{}) + pt.Release() +} + +func TestUnmap(t *testing.T) { + pt := New(reflectTranslater{}, Opts{}) + + // Map and unmap one entry. + pt.Map(0x400000, pteSize, true, usermem.ReadWrite, pteSize*42) + pt.Unmap(0x400000, pteSize) + + checkMappings(t, pt, nil) + pt.Release() +} + +func TestReadOnly(t *testing.T) { + pt := New(reflectTranslater{}, Opts{}) + + // Map one entry. + pt.Map(0x400000, pteSize, true, usermem.Read, pteSize*42) + + checkMappings(t, pt, []mapping{ + {0x400000, pteSize, pteSize * 42, false}, + }) + pt.Release() +} + +func TestReadWrite(t *testing.T) { + pt := New(reflectTranslater{}, Opts{}) + + // Map one entry. + pt.Map(0x400000, pteSize, true, usermem.ReadWrite, pteSize*42) + + checkMappings(t, pt, []mapping{ + {0x400000, pteSize, pteSize * 42, true}, + }) + pt.Release() +} + +func TestSerialEntries(t *testing.T) { + pt := New(reflectTranslater{}, Opts{}) + + // Map two sequential entries. 
+ pt.Map(0x400000, pteSize, true, usermem.ReadWrite, pteSize*42) + pt.Map(0x401000, pteSize, true, usermem.ReadWrite, pteSize*47) + + checkMappings(t, pt, []mapping{ + {0x400000, pteSize, pteSize * 42, true}, + {0x401000, pteSize, pteSize * 47, true}, + }) + pt.Release() +} + +func TestSpanningEntries(t *testing.T) { + pt := New(reflectTranslater{}, Opts{}) + + // Span a pgd with two pages. + pt.Map(0x00007efffffff000, 2*pteSize, true, usermem.Read, pteSize*42) + + checkMappings(t, pt, []mapping{ + {0x00007efffffff000, pteSize, pteSize * 42, false}, + {0x00007f0000000000, pteSize, pteSize * 43, false}, + }) + pt.Release() +} + +func TestSparseEntries(t *testing.T) { + pt := New(reflectTranslater{}, Opts{}) + + // Map two entries in different pgds. + pt.Map(0x400000, pteSize, true, usermem.ReadWrite, pteSize*42) + pt.Map(0x00007f0000000000, pteSize, true, usermem.Read, pteSize*47) + + checkMappings(t, pt, []mapping{ + {0x400000, pteSize, pteSize * 42, true}, + {0x00007f0000000000, pteSize, pteSize * 47, false}, + }) + pt.Release() +} diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_unsafe.go b/pkg/sentry/platform/ring0/pagetables/pagetables_unsafe.go new file mode 100644 index 000000000..a2b44fb79 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_unsafe.go @@ -0,0 +1,31 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagetables + +import ( + "unsafe" + + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +// PTEs returns aligned PTE entries. +func (n *Node) PTEs() *PTEs { + addr := uintptr(unsafe.Pointer(&n.unalignedData[0])) + offset := addr & (usermem.PageSize - 1) + if offset != 0 { + offset = usermem.PageSize - offset + } + return (*PTEs)(unsafe.Pointer(addr + offset)) +} diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go b/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go new file mode 100644 index 000000000..dac66373f --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go @@ -0,0 +1,79 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build i386 amd64 + +package pagetables + +// Opts are pagetable options. +type Opts struct { + EnablePCID bool +} + +// archPageTables has x86-specific features. +type archPageTables struct { + // pcids is the PCID database. + pcids *PCIDs + + // pcid is the globally unique identifier, or zero if none were + // available or pcids is nil. 
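+ //
+ // A zero pcid disables the PCID optimization: CR3 is then built from
+ // the root physical address alone, without the no-flush bit.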
+ pcid uint16 +} + +// init initializes arch-specific features. +func (a *archPageTables) init(opts Opts) { + if opts.EnablePCID { + a.pcids = NewPCIDs() + a.pcid = a.pcids.allocate() + } +} + +// initFrom initializes arch-specific features from an existing entry.' +func (a *archPageTables) initFrom(other *archPageTables) { + a.pcids = other.pcids // Refer to the same PCID database. + if a.pcids != nil { + a.pcid = a.pcids.allocate() + } +} + +// release is called from Release. +func (a *archPageTables) release() { + // Return the PCID. + if a.pcids != nil { + a.pcids.free(a.pcid) + } +} + +// CR3 returns the CR3 value for these tables. +// +// This may be called in interrupt contexts. +// +//go:nosplit +func (p *PageTables) CR3() uint64 { + // Bit 63 is set to avoid flushing the PCID (per SDM 4.10.4.1). + const noFlushBit uint64 = 0x8000000000000000 + if p.pcid != 0 { + return noFlushBit | uint64(p.root.physical) | uint64(p.pcid) + } + return uint64(p.root.physical) +} + +// FlushCR3 returns the CR3 value that flushes the TLB. +// +// This may be called in interrupt contexts. +// +//go:nosplit +func (p *PageTables) FlushCR3() uint64 { + return uint64(p.root.physical) | uint64(p.pcid) +} diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_x86_test.go b/pkg/sentry/platform/ring0/pagetables/pagetables_x86_test.go new file mode 100644 index 000000000..1fc403c48 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_x86_test.go @@ -0,0 +1,79 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build i386 amd64 + +package pagetables + +import ( + "testing" + + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +func Test2MAnd4K(t *testing.T) { + pt := New(reflectTranslater{}, Opts{}) + + // Map a small page and a huge page. + pt.Map(0x400000, pteSize, true, usermem.ReadWrite, pteSize*42) + pt.Map(0x00007f0000000000, 1<<21, true, usermem.Read, pmdSize*47) + + checkMappings(t, pt, []mapping{ + {0x400000, pteSize, pteSize * 42, true}, + {0x00007f0000000000, pmdSize, pmdSize * 47, false}, + }) + pt.Release() +} + +func Test1GAnd4K(t *testing.T) { + pt := New(reflectTranslater{}, Opts{}) + + // Map a small page and a super page. + pt.Map(0x400000, pteSize, true, usermem.ReadWrite, pteSize*42) + pt.Map(0x00007f0000000000, pudSize, true, usermem.Read, pudSize*47) + + checkMappings(t, pt, []mapping{ + {0x400000, pteSize, pteSize * 42, true}, + {0x00007f0000000000, pudSize, pudSize * 47, false}, + }) + pt.Release() +} + +func TestSplit1GPage(t *testing.T) { + pt := New(reflectTranslater{}, Opts{}) + + // Map a super page and knock out the middle. 
+ pt.Map(0x00007f0000000000, pudSize, true, usermem.Read, pudSize*42) + pt.Unmap(usermem.Addr(0x00007f0000000000+pteSize), pudSize-(2*pteSize)) + + checkMappings(t, pt, []mapping{ + {0x00007f0000000000, pteSize, pudSize * 42, false}, + {0x00007f0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, false}, + }) + pt.Release() +} + +func TestSplit2MPage(t *testing.T) { + pt := New(reflectTranslater{}, Opts{}) + + // Map a huge page and knock out the middle. + pt.Map(0x00007f0000000000, pmdSize, true, usermem.Read, pmdSize*42) + pt.Unmap(usermem.Addr(0x00007f0000000000+pteSize), pmdSize-(2*pteSize)) + + checkMappings(t, pt, []mapping{ + {0x00007f0000000000, pteSize, pmdSize * 42, false}, + {0x00007f0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, false}, + }) + pt.Release() +} diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go b/pkg/sentry/platform/ring0/pagetables/pcids_x86.go new file mode 100644 index 000000000..509e8c0d9 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/pcids_x86.go @@ -0,0 +1,74 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build i386 amd64 + +package pagetables + +import ( + "sync" +) + +// maxPCID is the maximum allowed PCID. +const maxPCID = 4095 + +// PCIDs is a simple PCID database. +type PCIDs struct { + mu sync.Mutex + + // last is the last fresh PCID given out (not including the available + // pool). If last >= maxPCID, then the only PCIDs available in the + // available pool below. + last uint16 + + // available are PCIDs that have been freed. + available map[uint16]struct{} +} + +// NewPCIDs returns a new PCID set. +func NewPCIDs() *PCIDs { + return &PCIDs{ + available: make(map[uint16]struct{}), + } +} + +// allocate returns an unused PCID, or zero if all are taken. +func (p *PCIDs) allocate() uint16 { + p.mu.Lock() + defer p.mu.Unlock() + if len(p.available) > 0 { + for id := range p.available { + delete(p.available, id) + return id + } + } + if id := p.last + 1; id <= maxPCID { + p.last = id + return id + } + // Nothing available. + return 0 +} + +// free returns a PCID to the pool. +// +// It is safe to call free with a zero pcid. That is, you may always call free +// with anything returned by allocate. +func (p *PCIDs) free(id uint16) { + p.mu.Lock() + defer p.mu.Unlock() + if id != 0 { + p.available[id] = struct{}{} + } +} diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_x86_test.go b/pkg/sentry/platform/ring0/pagetables/pcids_x86_test.go new file mode 100644 index 000000000..0b555cd76 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/pcids_x86_test.go @@ -0,0 +1,65 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build i386 amd64 + +package pagetables + +import ( + "testing" +) + +func TestMaxPCID(t *testing.T) { + p := NewPCIDs() + for i := 0; i < maxPCID; i++ { + if id := p.allocate(); id != uint16(i+1) { + t.Errorf("got %d, expected %d", id, i+1) + } + } + if id := p.allocate(); id != 0 { + if id != 0 { + t.Errorf("got %d, expected 0", id) + } + } +} + +func TestFirstPCID(t *testing.T) { + p := NewPCIDs() + if id := p.allocate(); id != 1 { + t.Errorf("got %d, expected 1", id) + } +} + +func TestFreePCID(t *testing.T) { + p := NewPCIDs() + p.free(0) + if id := p.allocate(); id != 1 { + t.Errorf("got %d, expected 1 (not zero)", id) + } +} + +func TestReusePCID(t *testing.T) { + p := NewPCIDs() + id := p.allocate() + if id != 1 { + t.Errorf("got %d, expected 1", id) + } + p.free(id) + if id := p.allocate(); id != 1 { + t.Errorf("got %d, expected 1", id) + } + if id := p.allocate(); id != 2 { + t.Errorf("got %d, expected 2", id) + } +} diff --git a/pkg/sentry/platform/ring0/ring0.go b/pkg/sentry/platform/ring0/ring0.go new file mode 100644 index 000000000..4991031c5 --- /dev/null +++ b/pkg/sentry/platform/ring0/ring0.go @@ -0,0 +1,16 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ring0 provides basic operating system-level stubs. +package ring0 diff --git a/pkg/sentry/platform/ring0/x86.go b/pkg/sentry/platform/ring0/x86.go new file mode 100644 index 000000000..e16f6c599 --- /dev/null +++ b/pkg/sentry/platform/ring0/x86.go @@ -0,0 +1,242 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build i386 amd64 + +package ring0 + +import ( + "gvisor.googlesource.com/gvisor/pkg/cpuid" +) + +// Useful bits. 
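+//
+// These are the CR0 and CR4 control-register bits, RFLAGS bits, EFER bits
+// and MSR addresses referenced elsewhere in this package.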
+const ( + _CR0_PE = 1 << 0 + _CR0_ET = 1 << 4 + _CR0_PG = 1 << 31 + + _CR4_PSE = 1 << 4 + _CR4_PAE = 1 << 5 + _CR4_PGE = 1 << 7 + _CR4_OSFXSR = 1 << 9 + _CR4_OSXMMEXCPT = 1 << 10 + _CR4_FSGSBASE = 1 << 16 + _CR4_PCIDE = 1 << 17 + _CR4_OSXSAVE = 1 << 18 + _CR4_SMEP = 1 << 20 + + _RFLAGS_AC = 1 << 18 + _RFLAGS_NT = 1 << 14 + _RFLAGS_IOPL = 3 << 12 + _RFLAGS_DF = 1 << 10 + _RFLAGS_IF = 1 << 9 + _RFLAGS_STEP = 1 << 8 + _RFLAGS_RESERVED = 1 << 1 + + _EFER_SCE = 0x001 + _EFER_LME = 0x100 + _EFER_NX = 0x800 + + _MSR_STAR = 0xc0000081 + _MSR_LSTAR = 0xc0000082 + _MSR_CSTAR = 0xc0000083 + _MSR_SYSCALL_MASK = 0xc0000084 +) + +// Vector is an exception vector. +type Vector uintptr + +// Exception vectors. +const ( + DivideByZero Vector = iota + Debug + NMI + Breakpoint + Overflow + BoundRangeExceeded + InvalidOpcode + DeviceNotAvailable + DoubleFault + CoprocessorSegmentOverrun + InvalidTSS + SegmentNotPresent + StackSegmentFault + GeneralProtectionFault + PageFault + _ + X87FloatingPointException + AlignmentCheck + MachineCheck + SIMDFloatingPointException + VirtualizationException + SecurityException = 0x1e + SyscallInt80 = 0x80 + _NR_INTERRUPTS = SyscallInt80 + 1 +) + +// System call vectors. +const ( + Syscall Vector = _NR_INTERRUPTS +) + +// VirtualAddressBits returns the number bits available for virtual addresses. +// +// Note that sign-extension semantics apply to the highest order bit. +// +// FIXME: This should use the cpuid passed to Init. +func VirtualAddressBits() uint32 { + ax, _, _, _ := cpuid.HostID(0x80000008, 0) + return (ax >> 8) & 0xff +} + +// PhysicalAddressBits returns the number of bits available for physical addresses. +// +// FIXME: This should use the cpuid passed to Init. +func PhysicalAddressBits() uint32 { + ax, _, _, _ := cpuid.HostID(0x80000008, 0) + return ax & 0xff +} + +// Selector is a segment Selector. +type Selector uint16 + +// SegmentDescriptor is a segment descriptor. +type SegmentDescriptor struct { + bits [2]uint32 +} + +// descriptorTable is a collection of descriptors. +type descriptorTable [32]SegmentDescriptor + +// SegmentDescriptorFlags are typed flags within a descriptor. +type SegmentDescriptorFlags uint32 + +// SegmentDescriptorFlag declarations. +const ( + SegmentDescriptorAccess SegmentDescriptorFlags = 1 << 8 // Access bit (always set). + SegmentDescriptorWrite = 1 << 9 // Write permission. + SegmentDescriptorExpandDown = 1 << 10 // Grows down, not used. + SegmentDescriptorExecute = 1 << 11 // Execute permission. + SegmentDescriptorSystem = 1 << 12 // Zero => system, 1 => user code/data. + SegmentDescriptorPresent = 1 << 15 // Present. + SegmentDescriptorAVL = 1 << 20 // Available. + SegmentDescriptorLong = 1 << 21 // Long mode. + SegmentDescriptorDB = 1 << 22 // 16 or 32-bit. + SegmentDescriptorG = 1 << 23 // Granularity: page or byte. +) + +// Base returns the descriptor's base linear address. +func (d *SegmentDescriptor) Base() uint32 { + return d.bits[1]&0xFF000000 | (d.bits[1]&0x000000FF)<<16 | d.bits[0]>>16 +} + +// Limit returns the descriptor size. +func (d *SegmentDescriptor) Limit() uint32 { + l := d.bits[0]&0xFFFF | d.bits[1]&0xF0000 + if d.bits[1]&uint32(SegmentDescriptorG) != 0 { + l <<= 12 + l |= 0xFFF + } + return l +} + +// Flags returns descriptor flags. +func (d *SegmentDescriptor) Flags() SegmentDescriptorFlags { + return SegmentDescriptorFlags(d.bits[1] & 0x00F09F00) +} + +// DPL returns the descriptor privilege level. 
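+//
+// For example, a descriptor written with set(0, 0xffffffff, 3, flags) (a
+// hypothetical user segment) reports DPL() == 3.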
+func (d *SegmentDescriptor) DPL() int { + return int((d.bits[1] >> 13) & 3) +} + +func (d *SegmentDescriptor) setNull() { + d.bits[0] = 0 + d.bits[1] = 0 +} + +func (d *SegmentDescriptor) set(base, limit uint32, dpl int, flags SegmentDescriptorFlags) { + flags |= SegmentDescriptorPresent + if limit>>12 != 0 { + limit >>= 12 + flags |= SegmentDescriptorG + } + d.bits[0] = base<<16 | limit&0xFFFF + d.bits[1] = base&0xFF000000 | (base>>16)&0xFF | limit&0x000F0000 | uint32(flags) | uint32(dpl)<<13 +} + +func (d *SegmentDescriptor) setCode32(base, limit uint32, dpl int) { + d.set(base, limit, dpl, + SegmentDescriptorDB| + SegmentDescriptorExecute| + SegmentDescriptorSystem) +} + +func (d *SegmentDescriptor) setCode64(base, limit uint32, dpl int) { + d.set(base, limit, dpl, + SegmentDescriptorG| + SegmentDescriptorLong| + SegmentDescriptorExecute| + SegmentDescriptorSystem) +} + +func (d *SegmentDescriptor) setData(base, limit uint32, dpl int) { + d.set(base, limit, dpl, + SegmentDescriptorWrite| + SegmentDescriptorSystem) +} + +// setHi is only used for the TSS segment, which is magically 64-bits. +func (d *SegmentDescriptor) setHi(base uint32) { + d.bits[0] = base + d.bits[1] = 0 +} + +// Gate64 is a 64-bit task, trap, or interrupt gate. +type Gate64 struct { + bits [4]uint32 +} + +// idt64 is a 64-bit interrupt descriptor table. +type idt64 [_NR_INTERRUPTS]Gate64 + +func (g *Gate64) setInterrupt(cs Selector, rip uint64, dpl int, ist int) { + g.bits[0] = uint32(cs)<<16 | uint32(rip)&0xFFFF + g.bits[1] = uint32(rip)&0xFFFF0000 | SegmentDescriptorPresent | uint32(dpl)<<13 | 14<<8 | uint32(ist)&0x7 + g.bits[2] = uint32(rip >> 32) +} + +func (g *Gate64) setTrap(cs Selector, rip uint64, dpl int, ist int) { + g.setInterrupt(cs, rip, dpl, ist) + g.bits[1] |= 1 << 8 +} + +// TaskState64 is a 64-bit task state structure. +type TaskState64 struct { + _ uint32 + rsp0Lo, rsp0Hi uint32 + rsp1Lo, rsp1Hi uint32 + rsp2Lo, rsp2Hi uint32 + _ [2]uint32 + ist1Lo, ist1Hi uint32 + ist2Lo, ist2Hi uint32 + ist3Lo, ist3Hi uint32 + ist4Lo, ist4Hi uint32 + ist5Lo, ist5Hi uint32 + ist6Lo, ist6Hi uint32 + ist7Lo, ist7Hi uint32 + _ [2]uint32 + _ uint16 + ioPerm uint16 +} diff --git a/pkg/sentry/platform/safecopy/BUILD b/pkg/sentry/platform/safecopy/BUILD new file mode 100644 index 000000000..8b9f29403 --- /dev/null +++ b/pkg/sentry/platform/safecopy/BUILD @@ -0,0 +1,28 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "safecopy", + srcs = [ + "atomic_amd64.s", + "memclr_amd64.s", + "memcpy_amd64.s", + "safecopy.go", + "safecopy_unsafe.go", + "sighandler_amd64.s", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/platform/safecopy", + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/syserror", + ], +) + +go_test( + name = "safecopy_test", + srcs = [ + "safecopy_test.go", + ], + embed = [":safecopy"], +) diff --git a/pkg/sentry/platform/safecopy/atomic_amd64.s b/pkg/sentry/platform/safecopy/atomic_amd64.s new file mode 100644 index 000000000..69947dec3 --- /dev/null +++ b/pkg/sentry/platform/safecopy/atomic_amd64.s @@ -0,0 +1,108 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// handleSwapUint32Fault returns the value stored in DI. Control is transferred +// to it when swapUint32 below receives SIGSEGV or SIGBUS, with the signal +// number stored in DI. +// +// It must have the same frame configuration as swapUint32 so that it can undo +// any potential call frame set up by the assembler. +TEXT handleSwapUint32Fault(SB), NOSPLIT, $0-24 + MOVL DI, sig+20(FP) + RET + +// swapUint32 atomically stores new into *addr and returns (the previous *addr +// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the +// value of old is unspecified, and sig is the number of the signal that was +// received. +// +// Preconditions: addr must be aligned to a 4-byte boundary. +// +//func swapUint32(ptr unsafe.Pointer, new uint32) (old uint32, sig int32) +TEXT ·swapUint32(SB), NOSPLIT, $0-24 + // Store 0 as the returned signal number. If we run to completion, + // this is the value the caller will see; if a signal is received, + // handleSwapUint32Fault will store a different value in this address. + MOVL $0, sig+20(FP) + + MOVQ addr+0(FP), DI + MOVL new+8(FP), AX + XCHGL AX, 0(DI) + MOVL AX, old+16(FP) + RET + +// handleSwapUint64Fault returns the value stored in DI. Control is transferred +// to it when swapUint64 below receives SIGSEGV or SIGBUS, with the signal +// number stored in DI. +// +// It must have the same frame configuration as swapUint64 so that it can undo +// any potential call frame set up by the assembler. +TEXT handleSwapUint64Fault(SB), NOSPLIT, $0-28 + MOVL DI, sig+24(FP) + RET + +// swapUint64 atomically stores new into *addr and returns (the previous *addr +// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the +// value of old is unspecified, and sig is the number of the signal that was +// received. +// +// Preconditions: addr must be aligned to a 8-byte boundary. +// +//func swapUint64(ptr unsafe.Pointer, new uint64) (old uint64, sig int32) +TEXT ·swapUint64(SB), NOSPLIT, $0-28 + // Store 0 as the returned signal number. If we run to completion, + // this is the value the caller will see; if a signal is received, + // handleSwapUint64Fault will store a different value in this address. + MOVL $0, sig+24(FP) + + MOVQ addr+0(FP), DI + MOVQ new+8(FP), AX + XCHGQ AX, 0(DI) + MOVQ AX, old+16(FP) + RET + +// handleCompareAndSwapUint32Fault returns the value stored in DI. Control is +// transferred to it when swapUint64 below receives SIGSEGV or SIGBUS, with the +// signal number stored in DI. +// +// It must have the same frame configuration as compareAndSwapUint32 so that it +// can undo any potential call frame set up by the assembler. +TEXT handleCompareAndSwapUint32Fault(SB), NOSPLIT, $0-24 + MOVL DI, sig+20(FP) + RET + +// compareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns +// (the value previously stored at addr, 0). If a SIGSEGV or SIGBUS signal is +// received during the operation, the value of prev is unspecified, and sig is +// the number of the signal that was received. +// +// Preconditions: addr must be aligned to a 4-byte boundary. 
+// +//func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig int32) +TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24 + // Store 0 as the returned signal number. If we run to completion, this is + // the value the caller will see; if a signal is received, + // handleCompareAndSwapUint32Fault will store a different value in this + // address. + MOVL $0, sig+20(FP) + + MOVQ addr+0(FP), DI + MOVL old+8(FP), AX + MOVL new+12(FP), DX + LOCK + CMPXCHGL DX, 0(DI) + MOVL AX, prev+16(FP) + RET diff --git a/pkg/sentry/platform/safecopy/memclr_amd64.s b/pkg/sentry/platform/safecopy/memclr_amd64.s new file mode 100644 index 000000000..7d1019f60 --- /dev/null +++ b/pkg/sentry/platform/safecopy/memclr_amd64.s @@ -0,0 +1,157 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// handleMemclrFault returns (the value stored in AX, the value stored in DI). +// Control is transferred to it when memclr below receives SIGSEGV or SIGBUS, +// with the faulting address stored in AX and the signal number stored in DI. +// +// It must have the same frame configuration as memclr so that it can undo any +// potential call frame set up by the assembler. +TEXT handleMemclrFault(SB), NOSPLIT, $0-28 + MOVQ AX, addr+16(FP) + MOVL DI, sig+24(FP) + RET + +// memclr sets the n bytes following ptr to zeroes. If a SIGSEGV or SIGBUS +// signal is received during the write, it returns the address that caused the +// fault and the number of the signal that was received. Otherwise, it returns +// an unspecified address and a signal number of 0. +// +// Data is written in order, such that if a fault happens at address p, it is +// safe to assume that all data before p-maxRegisterSize has already been +// successfully written. +// +// The code is derived from runtime.memclrNoHeapPointers. +// +// func memclr(ptr unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32) +TEXT ·memclr(SB), NOSPLIT, $0-28 + // Store 0 as the returned signal number. If we run to completion, + // this is the value the caller will see; if a signal is received, + // handleMemclrFault will store a different value in this address. + MOVL $0, sig+24(FP) + + MOVQ ptr+0(FP), DI + MOVQ n+8(FP), BX + XORQ AX, AX + + // MOVOU seems always faster than REP STOSQ. +tail: + TESTQ BX, BX + JEQ _0 + CMPQ BX, $2 + JBE _1or2 + CMPQ BX, $4 + JBE _3or4 + CMPQ BX, $8 + JB _5through7 + JE _8 + CMPQ BX, $16 + JBE _9through16 + PXOR X0, X0 + CMPQ BX, $32 + JBE _17through32 + CMPQ BX, $64 + JBE _33through64 + CMPQ BX, $128 + JBE _65through128 + CMPQ BX, $256 + JBE _129through256 + // TODO: use branch table and BSR to make this just a single dispatch + // TODO: for really big clears, use MOVNTDQ, even without AVX2. 
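+ // The loop below clears 256 bytes per iteration using unrolled MOVOU
+ // stores of X0 (zeroed by the PXOR above); any remainder falls back
+ // through the tail cases.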
+ +loop: + MOVOU X0, 0(DI) + MOVOU X0, 16(DI) + MOVOU X0, 32(DI) + MOVOU X0, 48(DI) + MOVOU X0, 64(DI) + MOVOU X0, 80(DI) + MOVOU X0, 96(DI) + MOVOU X0, 112(DI) + MOVOU X0, 128(DI) + MOVOU X0, 144(DI) + MOVOU X0, 160(DI) + MOVOU X0, 176(DI) + MOVOU X0, 192(DI) + MOVOU X0, 208(DI) + MOVOU X0, 224(DI) + MOVOU X0, 240(DI) + SUBQ $256, BX + ADDQ $256, DI + CMPQ BX, $256 + JAE loop + JMP tail + +_1or2: + MOVB AX, (DI) + MOVB AX, -1(DI)(BX*1) + RET +_0: + RET +_3or4: + MOVW AX, (DI) + MOVW AX, -2(DI)(BX*1) + RET +_5through7: + MOVL AX, (DI) + MOVL AX, -4(DI)(BX*1) + RET +_8: + // We need a separate case for 8 to make sure we clear pointers atomically. + MOVQ AX, (DI) + RET +_9through16: + MOVQ AX, (DI) + MOVQ AX, -8(DI)(BX*1) + RET +_17through32: + MOVOU X0, (DI) + MOVOU X0, -16(DI)(BX*1) + RET +_33through64: + MOVOU X0, (DI) + MOVOU X0, 16(DI) + MOVOU X0, -32(DI)(BX*1) + MOVOU X0, -16(DI)(BX*1) + RET +_65through128: + MOVOU X0, (DI) + MOVOU X0, 16(DI) + MOVOU X0, 32(DI) + MOVOU X0, 48(DI) + MOVOU X0, -64(DI)(BX*1) + MOVOU X0, -48(DI)(BX*1) + MOVOU X0, -32(DI)(BX*1) + MOVOU X0, -16(DI)(BX*1) + RET +_129through256: + MOVOU X0, (DI) + MOVOU X0, 16(DI) + MOVOU X0, 32(DI) + MOVOU X0, 48(DI) + MOVOU X0, 64(DI) + MOVOU X0, 80(DI) + MOVOU X0, 96(DI) + MOVOU X0, 112(DI) + MOVOU X0, -128(DI)(BX*1) + MOVOU X0, -112(DI)(BX*1) + MOVOU X0, -96(DI)(BX*1) + MOVOU X0, -80(DI)(BX*1) + MOVOU X0, -64(DI)(BX*1) + MOVOU X0, -48(DI)(BX*1) + MOVOU X0, -32(DI)(BX*1) + MOVOU X0, -16(DI)(BX*1) + RET diff --git a/pkg/sentry/platform/safecopy/memcpy_amd64.s b/pkg/sentry/platform/safecopy/memcpy_amd64.s new file mode 100644 index 000000000..96ef2eefc --- /dev/null +++ b/pkg/sentry/platform/safecopy/memcpy_amd64.s @@ -0,0 +1,242 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// handleMemcpyFault returns (the value stored in AX, the value stored in DI). +// Control is transferred to it when memcpy below receives SIGSEGV or SIGBUS, +// with the faulting address stored in AX and the signal number stored in DI. +// +// It must have the same frame configuration as memcpy so that it can undo any +// potential call frame set up by the assembler. +TEXT handleMemcpyFault(SB), NOSPLIT, $0-36 + MOVQ AX, addr+24(FP) + MOVL DI, sig+32(FP) + RET + +// memcpy copies data from src to dst. If a SIGSEGV or SIGBUS signal is received +// during the copy, it returns the address that caused the fault and the number +// of the signal that was received. Otherwise, it returns an unspecified address +// and a signal number of 0. +// +// Data is copied in order, such that if a fault happens at address p, it is +// safe to assume that all data before p-maxRegisterSize has already been +// successfully copied. +// +// The code is derived from the forward copying part of runtime.memmove. +// +// func memcpy(dst, src unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32) +TEXT ·memcpy(SB), NOSPLIT, $0-36 + // Store 0 as the returned signal number. 
If we run to completion, + // this is the value the caller will see; if a signal is received, + // handleMemcpyFault will store a different value in this address. + MOVL $0, sig+32(FP) + + MOVQ to+0(FP), DI + MOVQ from+8(FP), SI + MOVQ n+16(FP), BX + + // REP instructions have a high startup cost, so we handle small sizes + // with some straightline code. The REP MOVSQ instruction is really fast + // for large sizes. The cutover is approximately 2K. +tail: + // move_129through256 or smaller work whether or not the source and the + // destination memory regions overlap because they load all data into + // registers before writing it back. move_256through2048 on the other + // hand can be used only when the memory regions don't overlap or the copy + // direction is forward. + TESTQ BX, BX + JEQ move_0 + CMPQ BX, $2 + JBE move_1or2 + CMPQ BX, $4 + JBE move_3or4 + CMPQ BX, $8 + JB move_5through7 + JE move_8 + CMPQ BX, $16 + JBE move_9through16 + CMPQ BX, $32 + JBE move_17through32 + CMPQ BX, $64 + JBE move_33through64 + CMPQ BX, $128 + JBE move_65through128 + CMPQ BX, $256 + JBE move_129through256 + // TODO: use branch table and BSR to make this just a single dispatch + +/* + * forward copy loop + */ + CMPQ BX, $2048 + JLS move_256through2048 + + // Check alignment + MOVL SI, AX + ORL DI, AX + TESTL $7, AX + JEQ fwdBy8 + + // Do 1 byte at a time + MOVQ BX, CX + REP; MOVSB + RET + +fwdBy8: + // Do 8 bytes at a time + MOVQ BX, CX + SHRQ $3, CX + ANDQ $7, BX + REP; MOVSQ + JMP tail + +move_1or2: + MOVB (SI), AX + MOVB AX, (DI) + MOVB -1(SI)(BX*1), CX + MOVB CX, -1(DI)(BX*1) + RET +move_0: + RET +move_3or4: + MOVW (SI), AX + MOVW AX, (DI) + MOVW -2(SI)(BX*1), CX + MOVW CX, -2(DI)(BX*1) + RET +move_5through7: + MOVL (SI), AX + MOVL AX, (DI) + MOVL -4(SI)(BX*1), CX + MOVL CX, -4(DI)(BX*1) + RET +move_8: + // We need a separate case for 8 to make sure we write pointers atomically. 
+ MOVQ (SI), AX + MOVQ AX, (DI) + RET +move_9through16: + MOVQ (SI), AX + MOVQ AX, (DI) + MOVQ -8(SI)(BX*1), CX + MOVQ CX, -8(DI)(BX*1) + RET +move_17through32: + MOVOU (SI), X0 + MOVOU X0, (DI) + MOVOU -16(SI)(BX*1), X1 + MOVOU X1, -16(DI)(BX*1) + RET +move_33through64: + MOVOU (SI), X0 + MOVOU X0, (DI) + MOVOU 16(SI), X1 + MOVOU X1, 16(DI) + MOVOU -32(SI)(BX*1), X2 + MOVOU X2, -32(DI)(BX*1) + MOVOU -16(SI)(BX*1), X3 + MOVOU X3, -16(DI)(BX*1) + RET +move_65through128: + MOVOU (SI), X0 + MOVOU X0, (DI) + MOVOU 16(SI), X1 + MOVOU X1, 16(DI) + MOVOU 32(SI), X2 + MOVOU X2, 32(DI) + MOVOU 48(SI), X3 + MOVOU X3, 48(DI) + MOVOU -64(SI)(BX*1), X4 + MOVOU X4, -64(DI)(BX*1) + MOVOU -48(SI)(BX*1), X5 + MOVOU X5, -48(DI)(BX*1) + MOVOU -32(SI)(BX*1), X6 + MOVOU X6, -32(DI)(BX*1) + MOVOU -16(SI)(BX*1), X7 + MOVOU X7, -16(DI)(BX*1) + RET +move_129through256: + MOVOU (SI), X0 + MOVOU X0, (DI) + MOVOU 16(SI), X1 + MOVOU X1, 16(DI) + MOVOU 32(SI), X2 + MOVOU X2, 32(DI) + MOVOU 48(SI), X3 + MOVOU X3, 48(DI) + MOVOU 64(SI), X4 + MOVOU X4, 64(DI) + MOVOU 80(SI), X5 + MOVOU X5, 80(DI) + MOVOU 96(SI), X6 + MOVOU X6, 96(DI) + MOVOU 112(SI), X7 + MOVOU X7, 112(DI) + MOVOU -128(SI)(BX*1), X8 + MOVOU X8, -128(DI)(BX*1) + MOVOU -112(SI)(BX*1), X9 + MOVOU X9, -112(DI)(BX*1) + MOVOU -96(SI)(BX*1), X10 + MOVOU X10, -96(DI)(BX*1) + MOVOU -80(SI)(BX*1), X11 + MOVOU X11, -80(DI)(BX*1) + MOVOU -64(SI)(BX*1), X12 + MOVOU X12, -64(DI)(BX*1) + MOVOU -48(SI)(BX*1), X13 + MOVOU X13, -48(DI)(BX*1) + MOVOU -32(SI)(BX*1), X14 + MOVOU X14, -32(DI)(BX*1) + MOVOU -16(SI)(BX*1), X15 + MOVOU X15, -16(DI)(BX*1) + RET +move_256through2048: + SUBQ $256, BX + MOVOU (SI), X0 + MOVOU X0, (DI) + MOVOU 16(SI), X1 + MOVOU X1, 16(DI) + MOVOU 32(SI), X2 + MOVOU X2, 32(DI) + MOVOU 48(SI), X3 + MOVOU X3, 48(DI) + MOVOU 64(SI), X4 + MOVOU X4, 64(DI) + MOVOU 80(SI), X5 + MOVOU X5, 80(DI) + MOVOU 96(SI), X6 + MOVOU X6, 96(DI) + MOVOU 112(SI), X7 + MOVOU X7, 112(DI) + MOVOU 128(SI), X8 + MOVOU X8, 128(DI) + MOVOU 144(SI), X9 + MOVOU X9, 144(DI) + MOVOU 160(SI), X10 + MOVOU X10, 160(DI) + MOVOU 176(SI), X11 + MOVOU X11, 176(DI) + MOVOU 192(SI), X12 + MOVOU X12, 192(DI) + MOVOU 208(SI), X13 + MOVOU X13, 208(DI) + MOVOU 224(SI), X14 + MOVOU X14, 224(DI) + MOVOU 240(SI), X15 + MOVOU X15, 240(DI) + CMPQ BX, $256 + LEAQ 256(SI), SI + LEAQ 256(DI), DI + JGE move_256through2048 + JMP tail diff --git a/pkg/sentry/platform/safecopy/safecopy.go b/pkg/sentry/platform/safecopy/safecopy.go new file mode 100644 index 000000000..90a2aad7b --- /dev/null +++ b/pkg/sentry/platform/safecopy/safecopy.go @@ -0,0 +1,140 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package safecopy provides an efficient implementation of functions to access +// memory that may result in SIGSEGV or SIGBUS being sent to the accessor. +package safecopy + +import ( + "fmt" + "reflect" + "runtime" + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/syserror" +) + +// SegvError is returned when a safecopy function receives SIGSEGV. 
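+//
+// A minimal sketch of how a caller might react to a fault (buf and ptr are
+// illustrative):
+//
+//	if _, err := CopyIn(buf, ptr); err != nil {
+//		if e, ok := err.(SegvError); ok {
+//			fmt.Printf("faulted at %#x\n", e.Addr)
+//		}
+//	}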
+type SegvError struct {
+ // Addr is the address at which the SIGSEGV occurred.
+ Addr uintptr
+}
+
+// Error implements error.Error.
+func (e SegvError) Error() string {
+ return fmt.Sprintf("SIGSEGV at %#x", e.Addr)
+}
+
+// BusError is returned when a safecopy function receives SIGBUS.
+type BusError struct {
+ // Addr is the address at which the SIGBUS occurred.
+ Addr uintptr
+}
+
+// Error implements error.Error.
+func (e BusError) Error() string {
+ return fmt.Sprintf("SIGBUS at %#x", e.Addr)
+}
+
+// AlignmentError is returned when a safecopy function is passed an address
+// that does not meet alignment requirements.
+type AlignmentError struct {
+ // Addr is the invalid address.
+ Addr uintptr
+
+ // Alignment is the required alignment.
+ Alignment uintptr
+}
+
+// Error implements error.Error.
+func (e AlignmentError) Error() string {
+ return fmt.Sprintf("address %#x is not aligned to a %d-byte boundary", e.Addr, e.Alignment)
+}
+
+var (
+ // The begin and end addresses below are for the functions that are
+ // checked by the signal handler.
+ memcpyBegin uintptr
+ memcpyEnd uintptr
+ memclrBegin uintptr
+ memclrEnd uintptr
+ swapUint32Begin uintptr
+ swapUint32End uintptr
+ swapUint64Begin uintptr
+ swapUint64End uintptr
+ compareAndSwapUint32Begin uintptr
+ compareAndSwapUint32End uintptr
+
+ // savedSigSegVHandler is a pointer to the SIGSEGV handler that was
+ // configured before we replaced it with our own. We still call into it
+ // when we get a SIGSEGV that is not interesting to us.
+ savedSigSegVHandler uintptr
+
+ // Same as above, but for SIGBUS signals.
+ savedSigBusHandler uintptr
+)
+
+// signalHandler is our replacement signal handler for SIGSEGV and SIGBUS
+// signals.
+func signalHandler()
+
+// FindEndAddress returns the end address (one byte beyond the last) of the
+// function that contains the specified address (begin).
+func FindEndAddress(begin uintptr) uintptr {
+ f := runtime.FuncForPC(begin)
+ if f != nil {
+ for p := begin; ; p++ {
+ g := runtime.FuncForPC(p)
+ if f != g {
+ return p
+ }
+ }
+ }
+ return begin
+}
+
+// initializeAddresses initializes the addresses used by the signal handler.
+func initializeAddresses() {
+ // The following functions are written in assembly language, so they won't
+ // be inlined by the existing compiler/linker. Tests will fail if this
+ // assumption is violated. 
+ memcpyBegin = reflect.ValueOf(memcpy).Pointer() + memcpyEnd = FindEndAddress(memcpyBegin) + memclrBegin = reflect.ValueOf(memclr).Pointer() + memclrEnd = FindEndAddress(memclrBegin) + swapUint32Begin = reflect.ValueOf(swapUint32).Pointer() + swapUint32End = FindEndAddress(swapUint32Begin) + swapUint64Begin = reflect.ValueOf(swapUint64).Pointer() + swapUint64End = FindEndAddress(swapUint64Begin) + compareAndSwapUint32Begin = reflect.ValueOf(compareAndSwapUint32).Pointer() + compareAndSwapUint32End = FindEndAddress(compareAndSwapUint32Begin) +} + +func init() { + initializeAddresses() + if err := ReplaceSignalHandler(syscall.SIGSEGV, reflect.ValueOf(signalHandler).Pointer(), &savedSigSegVHandler); err != nil { + panic(fmt.Sprintf("Unable to set handler for SIGSEGV: %v", err)) + } + if err := ReplaceSignalHandler(syscall.SIGBUS, reflect.ValueOf(signalHandler).Pointer(), &savedSigBusHandler); err != nil { + panic(fmt.Sprintf("Unable to set handler for SIGBUS: %v", err)) + } + syserror.AddErrorUnwrapper(func(e error) (syscall.Errno, bool) { + switch e.(type) { + case SegvError, BusError, AlignmentError: + return syscall.EFAULT, true + default: + return 0, false + } + }) +} diff --git a/pkg/sentry/platform/safecopy/safecopy_test.go b/pkg/sentry/platform/safecopy/safecopy_test.go new file mode 100644 index 000000000..67df36121 --- /dev/null +++ b/pkg/sentry/platform/safecopy/safecopy_test.go @@ -0,0 +1,617 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package safecopy + +import ( + "bytes" + "fmt" + "io/ioutil" + "math/rand" + "os" + "runtime/debug" + "syscall" + "testing" + "unsafe" +) + +// Size of a page in bytes. Cloned from usermem.PageSize to avoid a circular +// dependency. +const pageSize = 4096 + +func initRandom(b []byte) { + for i := range b { + b[i] = byte(rand.Intn(256)) + } +} + +func randBuf(size int) []byte { + b := make([]byte, size) + initRandom(b) + return b +} + +func TestCopyInSuccess(t *testing.T) { + // Test that CopyIn does not return an error when all pages are accessible. + const bufLen = 8192 + a := randBuf(bufLen) + b := make([]byte, bufLen) + + n, err := CopyIn(b, unsafe.Pointer(&a[0])) + if n != bufLen { + t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if !bytes.Equal(a, b) { + t.Errorf("Buffers are not equal when they should be: %v %v", a, b) + } +} + +func TestCopyOutSuccess(t *testing.T) { + // Test that CopyOut does not return an error when all pages are + // accessible. 
+ const bufLen = 8192 + a := randBuf(bufLen) + b := make([]byte, bufLen) + + n, err := CopyOut(unsafe.Pointer(&b[0]), a) + if n != bufLen { + t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if !bytes.Equal(a, b) { + t.Errorf("Buffers are not equal when they should be: %v %v", a, b) + } +} + +func TestCopySuccess(t *testing.T) { + // Test that Copy does not return an error when all pages are accessible. + const bufLen = 8192 + a := randBuf(bufLen) + b := make([]byte, bufLen) + + n, err := Copy(unsafe.Pointer(&b[0]), unsafe.Pointer(&a[0]), bufLen) + if n != bufLen { + t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if !bytes.Equal(a, b) { + t.Errorf("Buffers are not equal when they should be: %v %v", a, b) + } +} + +func TestZeroOutSuccess(t *testing.T) { + // Test that ZeroOut does not return an error when all pages are + // accessible. + const bufLen = 8192 + a := make([]byte, bufLen) + b := randBuf(bufLen) + + n, err := ZeroOut(unsafe.Pointer(&b[0]), bufLen) + if n != bufLen { + t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if !bytes.Equal(a, b) { + t.Errorf("Buffers are not equal when they should be: %v %v", a, b) + } +} + +func TestSwapUint32Success(t *testing.T) { + // Test that SwapUint32 does not return an error when the page is + // accessible. + before := uint32(rand.Int31()) + after := uint32(rand.Int31()) + val := before + + old, err := SwapUint32(unsafe.Pointer(&val), after) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if old != before { + t.Errorf("Unexpected old value: got %v, want %v", old, before) + } + if val != after { + t.Errorf("Unexpected new value: got %v, want %v", val, after) + } +} + +func TestSwapUint32AlignmentError(t *testing.T) { + // Test that SwapUint32 returns an AlignmentError when passed an unaligned + // address. + data := new(struct{ val uint64 }) + addr := uintptr(unsafe.Pointer(&data.val)) + 1 + want := AlignmentError{Addr: addr, Alignment: 4} + if _, err := SwapUint32(unsafe.Pointer(addr), 1); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } +} + +func TestSwapUint64Success(t *testing.T) { + // Test that SwapUint64 does not return an error when the page is + // accessible. + before := uint64(rand.Int63()) + after := uint64(rand.Int63()) + // "The first word in ... an allocated struct or slice can be relied upon + // to be 64-bit aligned." - sync/atomic docs + data := new(struct{ val uint64 }) + data.val = before + + old, err := SwapUint64(unsafe.Pointer(&data.val), after) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if old != before { + t.Errorf("Unexpected old value: got %v, want %v", old, before) + } + if data.val != after { + t.Errorf("Unexpected new value: got %v, want %v", data.val, after) + } +} + +func TestSwapUint64AlignmentError(t *testing.T) { + // Test that SwapUint64 returns an AlignmentError when passed an unaligned + // address. 
+ data := new(struct{ val1, val2 uint64 }) + addr := uintptr(unsafe.Pointer(&data.val1)) + 1 + want := AlignmentError{Addr: addr, Alignment: 8} + if _, err := SwapUint64(unsafe.Pointer(addr), 1); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } +} + +func TestCompareAndSwapUint32Success(t *testing.T) { + // Test that CompareAndSwapUint32 does not return an error when the page is + // accessible. + before := uint32(rand.Int31()) + after := uint32(rand.Int31()) + val := before + + old, err := CompareAndSwapUint32(unsafe.Pointer(&val), before, after) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if old != before { + t.Errorf("Unexpected old value: got %v, want %v", old, before) + } + if val != after { + t.Errorf("Unexpected new value: got %v, want %v", val, after) + } +} + +func TestCompareAndSwapUint32AlignmentError(t *testing.T) { + // Test that CompareAndSwapUint32 returns an AlignmentError when passed an + // unaligned address. + data := new(struct{ val uint64 }) + addr := uintptr(unsafe.Pointer(&data.val)) + 1 + want := AlignmentError{Addr: addr, Alignment: 4} + if _, err := CompareAndSwapUint32(unsafe.Pointer(addr), 0, 1); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } +} + +// withSegvErrorTestMapping calls fn with a two-page mapping. The first page +// contains random data, and the second page generates SIGSEGV when accessed. +func withSegvErrorTestMapping(t *testing.T, fn func(m []byte)) { + mapping, err := syscall.Mmap(-1, 0, 2*pageSize, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_ANONYMOUS|syscall.MAP_PRIVATE) + if err != nil { + t.Fatalf("Mmap failed: %v", err) + } + defer syscall.Munmap(mapping) + if err := syscall.Mprotect(mapping[pageSize:], syscall.PROT_NONE); err != nil { + t.Fatalf("Mprotect failed: %v", err) + } + initRandom(mapping[:pageSize]) + + fn(mapping) +} + +// withBusErrorTestMapping calls fn with a two-page mapping. The first page +// contains random data, and the second page generates SIGBUS when accessed. +func withBusErrorTestMapping(t *testing.T, fn func(m []byte)) { + f, err := ioutil.TempFile("", "sigbus_test") + if err != nil { + t.Fatalf("TempFile failed: %v", err) + } + defer f.Close() + if err := f.Truncate(pageSize); err != nil { + t.Fatalf("Truncate failed: %v", err) + } + mapping, err := syscall.Mmap(int(f.Fd()), 0, 2*pageSize, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED) + if err != nil { + t.Fatalf("Mmap failed: %v", err) + } + defer syscall.Munmap(mapping) + initRandom(mapping[:pageSize]) + + fn(mapping) +} + +func TestCopyInSegvError(t *testing.T) { + // Test that CopyIn returns a SegvError when reaching a page that signals + // SIGSEGV. 
+ for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + dst := randBuf(pageSize) + n, err := CopyIn(dst, src) + if n != bytesBeforeFault { + t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) + } + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { + t.Errorf("Buffers are not equal when they should be: %v %v", got, want) + } + }) + }) + } +} + +func TestCopyInBusError(t *testing.T) { + // Test that CopyIn returns a BusError when reaching a page that signals + // SIGBUS. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { + withBusErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + dst := randBuf(pageSize) + n, err := CopyIn(dst, src) + if n != bytesBeforeFault { + t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) + } + if want := (BusError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { + t.Errorf("Buffers are not equal when they should be: %v %v", got, want) + } + }) + }) + } +} + +func TestCopyOutSegvError(t *testing.T) { + // Test that CopyOut returns a SegvError when reaching a page that signals + // SIGSEGV. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + src := randBuf(pageSize) + n, err := CopyOut(dst, src) + if n != bytesBeforeFault { + t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) + } + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { + t.Errorf("Buffers are not equal when they should be: %v %v", got, want) + } + }) + }) + } +} + +func TestCopyOutBusError(t *testing.T) { + // Test that CopyOut returns a BusError when reaching a page that signals + // SIGBUS. 
+ for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
+ t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) {
+ withBusErrorTestMapping(t, func(mapping []byte) {
+ secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
+ dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ src := randBuf(pageSize)
+ n, err := CopyOut(dst, src)
+ if n != bytesBeforeFault {
+ t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault)
+ }
+ if want := (BusError{secondPage}); err != want {
+ t.Errorf("Unexpected error: got %v, want %v", err, want)
+ }
+ if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) {
+ t.Errorf("Buffers are not equal when they should be: %v %v", got, want)
+ }
+ })
+ })
+ }
+}
+
+func TestCopySourceSegvError(t *testing.T) {
+ // Test that Copy returns a SegvError when copying from a page that signals
+ // SIGSEGV.
+ for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
+ t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
+ withSegvErrorTestMapping(t, func(mapping []byte) {
+ secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
+ src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ dst := randBuf(pageSize)
+ n, err := Copy(unsafe.Pointer(&dst[0]), src, pageSize)
+ if n != uintptr(bytesBeforeFault) {
+ t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault)
+ }
+ if want := (SegvError{secondPage}); err != want {
+ t.Errorf("Unexpected error: got %v, want %v", err, want)
+ }
+ if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) {
+ t.Errorf("Buffers are not equal when they should be: %v %v", got, want)
+ }
+ })
+ })
+ }
+}
+
+func TestCopySourceBusError(t *testing.T) {
+ // Test that Copy returns a BusError when copying from a page that signals
+ // SIGBUS.
+ for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
+ t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) {
+ withBusErrorTestMapping(t, func(mapping []byte) {
+ secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
+ src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ dst := randBuf(pageSize)
+ n, err := Copy(unsafe.Pointer(&dst[0]), src, pageSize)
+ if n != uintptr(bytesBeforeFault) {
+ t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault)
+ }
+ if want := (BusError{secondPage}); err != want {
+ t.Errorf("Unexpected error: got %v, want %v", err, want)
+ }
+ if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) {
+ t.Errorf("Buffers are not equal when they should be: %v %v", got, want)
+ }
+ })
+ })
+ }
+}
+
+func TestCopyDestinationSegvError(t *testing.T) {
+ // Test that Copy returns a SegvError when copying to a page that signals
+ // SIGSEGV. 
+ for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + src := randBuf(pageSize) + n, err := Copy(dst, unsafe.Pointer(&src[0]), pageSize) + if n != uintptr(bytesBeforeFault) { + t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) + } + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { + t.Errorf("Buffers are not equal when they should be: %v %v", got, want) + } + }) + }) + } +} + +func TestCopyDestinationBusError(t *testing.T) { + // Test that Copy returns a BusError when copying to a page that signals + // SIGBUS. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { + withBusErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + src := randBuf(pageSize) + n, err := Copy(dst, unsafe.Pointer(&src[0]), pageSize) + if n != uintptr(bytesBeforeFault) { + t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) + } + if want := (BusError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { + t.Errorf("Buffers are not equal when they should be: %v %v", got, want) + } + }) + }) + } +} + +func TestZeroOutSegvError(t *testing.T) { + // Test that ZeroOut returns a SegvError when reaching a page that signals + // SIGSEGV. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting write %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + n, err := ZeroOut(dst, pageSize) + if n != uintptr(bytesBeforeFault) { + t.Errorf("Unexpected write length: got %v, want %v", n, bytesBeforeFault) + } + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := mapping[pageSize-bytesBeforeFault:pageSize], make([]byte, bytesBeforeFault); !bytes.Equal(got, want) { + t.Errorf("Non-zero bytes in written part of mapping: %v", got) + } + }) + }) + } +} + +func TestZeroOutBusError(t *testing.T) { + // Test that ZeroOut returns a BusError when reaching a page that signals + // SIGBUS. 
+ for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting write %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { + withBusErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + n, err := ZeroOut(dst, pageSize) + if n != uintptr(bytesBeforeFault) { + t.Errorf("Unexpected write length: got %v, want %v", n, bytesBeforeFault) + } + if want := (BusError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := mapping[pageSize-bytesBeforeFault:pageSize], make([]byte, bytesBeforeFault); !bytes.Equal(got, want) { + t.Errorf("Non-zero bytes in written part of mapping: %v", got) + } + }) + }) + } +} + +func TestSwapUint32SegvError(t *testing.T) { + // Test that SwapUint32 returns a SegvError when reaching a page that + // signals SIGSEGV. + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + _, err := SwapUint32(unsafe.Pointer(secondPage), 1) + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + }) +} + +func TestSwapUint32BusError(t *testing.T) { + // Test that SwapUint32 returns a BusError when reaching a page that + // signals SIGBUS. + withBusErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + _, err := SwapUint32(unsafe.Pointer(secondPage), 1) + if want := (BusError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + }) +} + +func TestSwapUint64SegvError(t *testing.T) { + // Test that SwapUint64 returns a SegvError when reaching a page that + // signals SIGSEGV. + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + _, err := SwapUint64(unsafe.Pointer(secondPage), 1) + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + }) +} + +func TestSwapUint64BusError(t *testing.T) { + // Test that SwapUint64 returns a BusError when reaching a page that + // signals SIGBUS. + withBusErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + _, err := SwapUint64(unsafe.Pointer(secondPage), 1) + if want := (BusError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + }) +} + +func TestCompareAndSwapUint32SegvError(t *testing.T) { + // Test that CompareAndSwapUint32 returns a SegvError when reaching a page + // that signals SIGSEGV. + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + _, err := CompareAndSwapUint32(unsafe.Pointer(secondPage), 0, 1) + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + }) +} + +func TestCompareAndSwapUint32BusError(t *testing.T) { + // Test that CompareAndSwapUint32 returns a BusError when reaching a page + // that signals SIGBUS. 
+ withBusErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + _, err := CompareAndSwapUint32(unsafe.Pointer(secondPage), 0, 1) + if want := (BusError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + }) +} + +func testCopy(dst, src []byte) (panicked bool) { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + debug.SetPanicOnFault(true) + copy(dst, src) + return +} + +func TestSegVOnMemmove(t *testing.T) { + // Test that SIGSEGVs received by runtime.memmove when *not* doing + // CopyIn or CopyOut work gets propagated to the runtime. + const bufLen = pageSize + a, err := syscall.Mmap(-1, 0, bufLen, syscall.PROT_NONE, syscall.MAP_ANON|syscall.MAP_PRIVATE) + if err != nil { + t.Fatalf("Mmap failed: %v", err) + + } + defer syscall.Munmap(a) + b := randBuf(bufLen) + + if !testCopy(b, a) { + t.Fatalf("testCopy didn't panic when it should have") + } + + if !testCopy(a, b) { + t.Fatalf("testCopy didn't panic when it should have") + } +} + +func TestSigbusOnMemmove(t *testing.T) { + // Test that SIGBUS received by runtime.memmove when *not* doing + // CopyIn or CopyOut work gets propagated to the runtime. + const bufLen = pageSize + f, err := ioutil.TempFile("", "sigbus_test") + if err != nil { + t.Fatalf("TempFile failed: %v", err) + } + os.Remove(f.Name()) + defer f.Close() + + a, err := syscall.Mmap(int(f.Fd()), 0, bufLen, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED) + if err != nil { + t.Fatalf("Mmap failed: %v", err) + + } + defer syscall.Munmap(a) + b := randBuf(bufLen) + + if !testCopy(b, a) { + t.Fatalf("testCopy didn't panic when it should have") + } + + if !testCopy(a, b) { + t.Fatalf("testCopy didn't panic when it should have") + } +} diff --git a/pkg/sentry/platform/safecopy/safecopy_unsafe.go b/pkg/sentry/platform/safecopy/safecopy_unsafe.go new file mode 100644 index 000000000..72f243f8d --- /dev/null +++ b/pkg/sentry/platform/safecopy/safecopy_unsafe.go @@ -0,0 +1,315 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package safecopy + +import ( + "fmt" + "syscall" + "unsafe" +) + +// maxRegisterSize is the maximum register size used in memcpy and memclr. It +// is used to decide by how much to rewind the copy (for memcpy) or zeroing +// (for memclr) before proceeding. +const maxRegisterSize = 16 + +// memcpy copies data from src to dst. If a SIGSEGV or SIGBUS signal is received +// during the copy, it returns the address that caused the fault and the number +// of the signal that was received. Otherwise, it returns an unspecified address +// and a signal number of 0. +// +// Data is copied in order, such that if a fault happens at address p, it is +// safe to assume that all data before p-maxRegisterSize has already been +// successfully copied. +// +//go:noescape +func memcpy(dst, src unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32) + +// memclr sets the n bytes following ptr to zeroes. 
If a SIGSEGV or SIGBUS
+// signal is received during the write, it returns the address that caused the
+// fault and the number of the signal that was received. Otherwise, it returns
+// an unspecified address and a signal number of 0.
+//
+// Data is written in order, such that if a fault happens at address p, it is
+// safe to assume that all data before p-maxRegisterSize has already been
+// successfully written.
+//
+//go:noescape
+func memclr(ptr unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32)
+
+// swapUint32 atomically stores new into *ptr and returns (the previous *ptr
+// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the
+// value of old is unspecified, and sig is the number of the signal that was
+// received.
+//
+// Preconditions: ptr must be aligned to a 4-byte boundary.
+//
+//go:noescape
+func swapUint32(ptr unsafe.Pointer, new uint32) (old uint32, sig int32)
+
+// swapUint64 atomically stores new into *ptr and returns (the previous *ptr
+// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the
+// value of old is unspecified, and sig is the number of the signal that was
+// received.
+//
+// Preconditions: ptr must be aligned to an 8-byte boundary.
+//
+//go:noescape
+func swapUint64(ptr unsafe.Pointer, new uint64) (old uint64, sig int32)
+
+// compareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns
+// (the value previously stored at ptr, 0). If a SIGSEGV or SIGBUS signal is
+// received during the operation, the value of prev is unspecified, and sig is
+// the number of the signal that was received.
+//
+// Preconditions: ptr must be aligned to a 4-byte boundary.
+//
+//go:noescape
+func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig int32)
+
+// CopyIn copies len(dst) bytes from src to dst. It returns the number of bytes
+// copied and an error if SIGSEGV or SIGBUS is received while reading from src.
+func CopyIn(dst []byte, src unsafe.Pointer) (int, error) {
+ toCopy := uintptr(len(dst))
+ if len(dst) == 0 {
+ return 0, nil
+ }
+
+ fault, sig := memcpy(unsafe.Pointer(&dst[0]), src, toCopy)
+ if sig == 0 {
+ return len(dst), nil
+ }
+
+ if faultN, srcN := uintptr(fault), uintptr(src); faultN < srcN || faultN >= srcN+toCopy {
+ panic(fmt.Sprintf("CopyIn faulted at %#x, which is outside source [%#x, %#x)", faultN, srcN, srcN+toCopy))
+ }
+
+ // memcpy might have ended the copy up to maxRegisterSize bytes before
+ // fault, if an instruction caused a memory access that straddled two
+ // pages, and the second one faulted. Try to copy up to the fault.
+ faultN, srcN := uintptr(fault), uintptr(src)
+ var done int
+ if faultN-srcN > maxRegisterSize {
+ done = int(faultN - srcN - maxRegisterSize)
+ }
+ n, err := CopyIn(dst[done:int(faultN-srcN)], unsafe.Pointer(srcN+uintptr(done)))
+ done += n
+ if err != nil {
+ return done, err
+ }
+ return done, errorFromFaultSignal(fault, sig)
+}
+
+// CopyOut copies len(src) bytes from src to dst. It returns the number of
+// bytes copied and an error if SIGSEGV or SIGBUS is received while writing to
+// dst. 
+func CopyOut(dst unsafe.Pointer, src []byte) (int, error) {
+ toCopy := uintptr(len(src))
+ if toCopy == 0 {
+ return 0, nil
+ }
+
+ fault, sig := memcpy(dst, unsafe.Pointer(&src[0]), toCopy)
+ if sig == 0 {
+ return len(src), nil
+ }
+
+ if faultN, dstN := uintptr(fault), uintptr(dst); faultN < dstN || faultN >= dstN+toCopy {
+ panic(fmt.Sprintf("CopyOut faulted at %#x, which is outside destination [%#x, %#x)", faultN, dstN, dstN+toCopy))
+ }
+
+ // memcpy might have ended the copy up to maxRegisterSize bytes before
+ // fault, if an instruction caused a memory access that straddled two
+ // pages, and the second one faulted. Try to copy up to the fault.
+ faultN, dstN := uintptr(fault), uintptr(dst)
+ var done int
+ if faultN-dstN > maxRegisterSize {
+ done = int(faultN - dstN - maxRegisterSize)
+ }
+ n, err := CopyOut(unsafe.Pointer(dstN+uintptr(done)), src[done:int(faultN-dstN)])
+ done += n
+ if err != nil {
+ return done, err
+ }
+ return done, errorFromFaultSignal(fault, sig)
+}
+
+// Copy copies toCopy bytes from src to dst. It returns the number of bytes
+// copied and an error if SIGSEGV or SIGBUS is received while reading from src
+// or writing to dst.
+//
+// Data is copied in order; if [src, src+toCopy) and [dst, dst+toCopy) overlap,
+// the resulting contents of dst are unspecified.
+func Copy(dst, src unsafe.Pointer, toCopy uintptr) (uintptr, error) {
+ if toCopy == 0 {
+ return 0, nil
+ }
+
+ fault, sig := memcpy(dst, src, toCopy)
+ if sig == 0 {
+ return toCopy, nil
+ }
+
+ // Did the fault occur while reading from src or writing to dst?
+ faultN, srcN, dstN := uintptr(fault), uintptr(src), uintptr(dst)
+ faultAfterSrc := ^uintptr(0)
+ if faultN >= srcN {
+ faultAfterSrc = faultN - srcN
+ }
+ faultAfterDst := ^uintptr(0)
+ if faultN >= dstN {
+ faultAfterDst = faultN - dstN
+ }
+ if faultAfterSrc >= toCopy && faultAfterDst >= toCopy {
+ panic(fmt.Sprintf("Copy faulted at %#x, which is outside source [%#x, %#x) and destination [%#x, %#x)", faultN, srcN, srcN+toCopy, dstN, dstN+toCopy))
+ }
+ faultedAfter := faultAfterSrc
+ if faultedAfter > faultAfterDst {
+ faultedAfter = faultAfterDst
+ }
+
+ // memcpy might have ended the copy up to maxRegisterSize bytes before
+ // fault, if an instruction caused a memory access that straddled two
+ // pages, and the second one faulted. Try to copy up to the fault.
+ var done uintptr
+ if faultedAfter > maxRegisterSize {
+ done = faultedAfter - maxRegisterSize
+ }
+ n, err := Copy(unsafe.Pointer(dstN+done), unsafe.Pointer(srcN+done), faultedAfter-done)
+ done += n
+ if err != nil {
+ return done, err
+ }
+ return done, errorFromFaultSignal(fault, sig)
+}
+
+// ZeroOut writes toZero zero bytes to dst. It returns the number of bytes
+// written and an error if SIGSEGV or SIGBUS is received while writing to dst.
+func ZeroOut(dst unsafe.Pointer, toZero uintptr) (uintptr, error) {
+ if toZero == 0 {
+ return 0, nil
+ }
+
+ fault, sig := memclr(dst, toZero)
+ if sig == 0 {
+ return toZero, nil
+ }
+
+ if faultN, dstN := uintptr(fault), uintptr(dst); faultN < dstN || faultN >= dstN+toZero {
+ panic(fmt.Sprintf("ZeroOut faulted at %#x, which is outside destination [%#x, %#x)", faultN, dstN, dstN+toZero))
+ }
+
+ // memclr might have ended the write up to maxRegisterSize bytes before
+ // fault, if an instruction caused a memory access that straddled two
+ // pages, and the second one faulted. Try to write up to the fault. 
+ faultN, dstN := uintptr(fault), uintptr(dst)
+ var done uintptr
+ if faultN-dstN > maxRegisterSize {
+ done = faultN - dstN - maxRegisterSize
+ }
+ n, err := ZeroOut(unsafe.Pointer(dstN+done), faultN-dstN-done)
+ done += n
+ if err != nil {
+ return done, err
+ }
+ return done, errorFromFaultSignal(fault, sig)
+}
+
+// SwapUint32 is equivalent to sync/atomic.SwapUint32, except that it returns
+// an error if SIGSEGV or SIGBUS is received while accessing ptr, or if ptr is
+// not aligned to a 4-byte boundary.
+func SwapUint32(ptr unsafe.Pointer, new uint32) (uint32, error) {
+ if addr := uintptr(ptr); addr&3 != 0 {
+ return 0, AlignmentError{addr, 4}
+ }
+ old, sig := swapUint32(ptr, new)
+ return old, errorFromFaultSignal(ptr, sig)
+}
+
+// SwapUint64 is equivalent to sync/atomic.SwapUint64, except that it returns
+// an error if SIGSEGV or SIGBUS is received while accessing ptr, or if ptr is
+// not aligned to an 8-byte boundary.
+func SwapUint64(ptr unsafe.Pointer, new uint64) (uint64, error) {
+ if addr := uintptr(ptr); addr&7 != 0 {
+ return 0, AlignmentError{addr, 8}
+ }
+ old, sig := swapUint64(ptr, new)
+ return old, errorFromFaultSignal(ptr, sig)
+}
+
+// CompareAndSwapUint32 is equivalent to atomicbitops.CompareAndSwapUint32,
+// except that it returns an error if SIGSEGV or SIGBUS is received while
+// accessing ptr, or if ptr is not aligned to a 4-byte boundary.
+func CompareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (uint32, error) {
+ if addr := uintptr(ptr); addr&3 != 0 {
+ return 0, AlignmentError{addr, 4}
+ }
+ prev, sig := compareAndSwapUint32(ptr, old, new)
+ return prev, errorFromFaultSignal(ptr, sig)
+}
+
+func errorFromFaultSignal(addr unsafe.Pointer, sig int32) error {
+ switch sig {
+ case 0:
+ return nil
+ case int32(syscall.SIGSEGV):
+ return SegvError{uintptr(addr)}
+ case int32(syscall.SIGBUS):
+ return BusError{uintptr(addr)}
+ default:
+ panic(fmt.Sprintf("safecopy got unexpected signal %d at address %#x", sig, addr))
+ }
+}
+
+// ReplaceSignalHandler replaces the existing signal handler for the provided
+// signal with the one that handles faults in safecopy-protected functions.
+//
+// It stores the value of the previously set handler in previous.
+//
+// This function will be called on initialization in order to install safecopy
+// handlers for appropriate signals. These handlers will call the previous
+// handler however, and if this function is being used externally then the
+// same courtesy is expected.
+func ReplaceSignalHandler(sig syscall.Signal, handler uintptr, previous *uintptr) error {
+ var sa struct {
+ handler uintptr
+ flags uint64
+ restorer uintptr
+ mask uint64
+ }
+ const maskLen = 8
+
+ // Get the existing signal handler information, and save the current
+ // handler. Once we replace it, we will use this pointer to fall back to
+ // it when we receive other signals.
+ if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(sig), 0, uintptr(unsafe.Pointer(&sa)), maskLen, 0, 0); e != 0 {
+ return e
+ }
+
+ // Fail if there isn't a previous handler.
+ if sa.handler == 0 {
+ return fmt.Errorf("previous handler for signal %x isn't set", sig)
+ }
+
+ *previous = sa.handler
+
+ // Install our own handler. 
+ sa.handler = handler + if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 { + return e + } + + return nil +} diff --git a/pkg/sentry/platform/safecopy/sighandler_amd64.s b/pkg/sentry/platform/safecopy/sighandler_amd64.s new file mode 100644 index 000000000..a65cb0c26 --- /dev/null +++ b/pkg/sentry/platform/safecopy/sighandler_amd64.s @@ -0,0 +1,124 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// The signals handled by sigHandler. +#define SIGBUS 7 +#define SIGSEGV 11 + +// Offsets to the registers in context->uc_mcontext.gregs[]. +#define REG_RDI 0x68 +#define REG_RAX 0x90 +#define REG_IP 0xa8 + +// Offset to the si_addr field of siginfo. +#define SI_CODE 0x08 +#define SI_ADDR 0x10 + +// signalHandler is the signal handler for SIGSEGV and SIGBUS signals. It must +// not be set up as a handler to any other signals. +// +// If the instruction causing the signal is within a safecopy-protected +// function, the signal is handled such that execution resumes in the +// appropriate fault handling stub with AX containing the faulting address and +// DI containing the signal number. Otherwise control is transferred to the +// previously configured signal handler (savedSigSegvHandler or +// savedSigBusHandler). +// +// This function cannot be written in go because it runs whenever a signal is +// received by the thread (preempting whatever was running), which includes when +// garbage collector has stopped or isn't expecting any interactions (like +// barriers). +// +// The arguments are the following: +// DI - The signal number. +// SI - Pointer to siginfo_t structure. +// DX - Pointer to ucontext structure. +TEXT ·signalHandler(SB),NOSPLIT,$0 + // Check if the signal is from the kernel. + MOVQ $0x0, CX + CMPL CX, SI_CODE(SI) + JGE original_handler + + // Check if RIP is within the area we care about. + MOVQ REG_IP(DX), CX + CMPQ CX, ·memcpyBegin(SB) + JB not_memcpy + CMPQ CX, ·memcpyEnd(SB) + JAE not_memcpy + + // Modify the context such that execution will resume in the fault + // handler. 
+ LEAQ handleMemcpyFault(SB), CX + JMP handle_fault + +not_memcpy: + CMPQ CX, ·memclrBegin(SB) + JB not_memclr + CMPQ CX, ·memclrEnd(SB) + JAE not_memclr + + LEAQ handleMemclrFault(SB), CX + JMP handle_fault + +not_memclr: + CMPQ CX, ·swapUint32Begin(SB) + JB not_swapuint32 + CMPQ CX, ·swapUint32End(SB) + JAE not_swapuint32 + + LEAQ handleSwapUint32Fault(SB), CX + JMP handle_fault + +not_swapuint32: + CMPQ CX, ·swapUint64Begin(SB) + JB not_swapuint64 + CMPQ CX, ·swapUint64End(SB) + JAE not_swapuint64 + + LEAQ handleSwapUint64Fault(SB), CX + JMP handle_fault + +not_swapuint64: + CMPQ CX, ·compareAndSwapUint32Begin(SB) + JB not_casuint32 + CMPQ CX, ·compareAndSwapUint32End(SB) + JAE not_casuint32 + + LEAQ handleCompareAndSwapUint32Fault(SB), CX + JMP handle_fault + +not_casuint32: +original_handler: + // Jump to the previous signal handler, which is likely the golang one. + XORQ CX, CX + MOVQ ·savedSigBusHandler(SB), AX + CMPL DI, $SIGSEGV + CMOVQEQ ·savedSigSegVHandler(SB), AX + JMP AX + +handle_fault: + // Entered with the address of the fault handler in RCX; store it in + // RIP. + MOVQ CX, REG_IP(DX) + + // Store the faulting address in RAX. + MOVQ SI_ADDR(SI), CX + MOVQ CX, REG_RAX(DX) + + // Store the signal number in EDI. + MOVL DI, REG_RDI(DX) + + RET |
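As a rough illustration of how the safecopy API added above might be used by client code, the sketch below mirrors the withSegvErrorTestMapping helper from the tests: it maps two pages, revokes access to the second, and then lets CopyIn run into the fault so that the error surfaces as a SegvError instead of crashing the process. It is only a sketch, not part of this change; the import path is assumed from the importpath convention in the BUILD files of this patch, and the 4 KiB page size is assumed as in the tests.

package main

import (
	"fmt"
	"syscall"
	"unsafe"

	"gvisor.googlesource.com/gvisor/pkg/sentry/platform/safecopy"
)

// pageSize is the assumed x86-64 page size, matching the constant used in
// the safecopy tests.
const pageSize = 4096

func main() {
	// Map two pages and make the second one inaccessible, so that any
	// access crossing the page boundary raises SIGSEGV.
	m, err := syscall.Mmap(-1, 0, 2*pageSize, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_ANONYMOUS|syscall.MAP_PRIVATE)
	if err != nil {
		panic(err)
	}
	defer syscall.Munmap(m)
	if err := syscall.Mprotect(m[pageSize:], syscall.PROT_NONE); err != nil {
		panic(err)
	}

	// Start reading 32 bytes before the protected page. CopyIn copies the
	// accessible bytes and then reports the fault address as a SegvError.
	src := unsafe.Pointer(&m[pageSize-32])
	dst := make([]byte, 64)
	n, err := safecopy.CopyIn(dst, src)
	fmt.Printf("copied %d bytes\n", n) // expected: 32, per TestCopyInSegvError
	if se, ok := err.(safecopy.SegvError); ok {
		fmt.Printf("faulted at %#x\n", se.Addr)
	}
}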