summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/sentry/fsimpl/gofer/time.go3
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go41
-rw-r--r--pkg/sentry/kernel/auth/credentials.go28
-rw-r--r--pkg/sentry/kernel/task_exec.go4
-rw-r--r--pkg/sentry/pgalloc/BUILD22
-rw-r--r--pkg/sentry/pgalloc/pgalloc.go215
-rw-r--r--pkg/sentry/pgalloc/pgalloc_test.go198
-rw-r--r--pkg/sentry/platform/kvm/BUILD2
-rw-r--r--pkg/sentry/platform/kvm/address_space.go73
-rw-r--r--pkg/sentry/platform/kvm/bluepill_allocator.go (renamed from pkg/sentry/platform/kvm/allocator.go)52
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64.go12
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.go4
-rw-r--r--pkg/sentry/platform/kvm/bluepill_unsafe.go9
-rw-r--r--pkg/sentry/platform/kvm/machine.go52
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go34
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64_unsafe.go2
-rw-r--r--pkg/sentry/platform/kvm/machine_unsafe.go2
-rw-r--r--pkg/sentry/platform/ring0/kernel.go24
-rw-r--r--pkg/sentry/platform/ring0/kernel_amd64.go12
-rw-r--r--pkg/sentry/platform/ring0/pagetables/allocator.go11
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables.go8
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go36
-rw-r--r--pkg/sentry/socket/netstack/BUILD1
-rw-r--r--pkg/sentry/socket/netstack/netstack.go9
-rw-r--r--pkg/sentry/socket/netstack/stack.go9
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/BUILD1
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/mount.go145
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/vfs2.go4
-rw-r--r--pkg/sentry/vfs/genericfstree/genericfstree.go3
-rw-r--r--pkg/sentry/vfs/mount.go34
-rw-r--r--pkg/sentry/vfs/options.go4
-rw-r--r--pkg/sentry/vfs/vfs.go2
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD1
-rw-r--r--pkg/tcpip/stack/iptables.go42
-rw-r--r--pkg/tcpip/stack/iptables_types.go15
-rw-r--r--pkg/tcpip/stack/stack.go23
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go6
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go5
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go5
-rw-r--r--pkg/tcpip/transport/tcp/BUILD4
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go42
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go51
-rw-r--r--pkg/tcpip/transport/tcp/snd.go31
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go5
-rw-r--r--pkg/tmutex/BUILD17
-rw-r--r--pkg/tmutex/tmutex.go81
-rw-r--r--pkg/tmutex/tmutex_test.go258
47 files changed, 872 insertions, 770 deletions
diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go
index 2608e7e1d..1d5aa82dc 100644
--- a/pkg/sentry/fsimpl/gofer/time.go
+++ b/pkg/sentry/fsimpl/gofer/time.go
@@ -38,6 +38,9 @@ func statxTimestampFromDentry(ns int64) linux.StatxTimestamp {
// Preconditions: fs.interop != InteropModeShared.
func (d *dentry) touchAtime(mnt *vfs.Mount) {
+ if mnt.Flags.NoATime {
+ return
+ }
if err := mnt.CheckBeginWrite(); err != nil {
return
}
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index f0e098702..3777ebdf2 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -30,6 +30,7 @@ package tmpfs
import (
"fmt"
"math"
+ "strconv"
"strings"
"sync/atomic"
@@ -124,14 +125,45 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
fs.vfsfs.Init(vfsObj, newFSType, &fs)
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+
+ defaultMode := linux.FileMode(0777)
+ if modeStr, ok := mopts["mode"]; ok {
+ mode, err := strconv.ParseUint(modeStr, 8, 32)
+ if err != nil {
+ return nil, nil, fmt.Errorf("Mount option \"mode='%v'\" not parsable: %v", modeStr, err)
+ }
+ defaultMode = linux.FileMode(mode)
+ }
+
+ defaultOwnerCreds := creds.Fork()
+ if uidStr, ok := mopts["uid"]; ok {
+ uid, err := strconv.ParseInt(uidStr, 10, 32)
+ if err != nil {
+ return nil, nil, fmt.Errorf("Mount option \"uid='%v'\" not parsable: %v", uidStr, err)
+ }
+ if err := defaultOwnerCreds.SetUID(auth.UID(uid)); err != nil {
+ return nil, nil, fmt.Errorf("Error using mount option \"uid='%v'\": %v", uidStr, err)
+ }
+ }
+ if gidStr, ok := mopts["gid"]; ok {
+ gid, err := strconv.ParseInt(gidStr, 10, 32)
+ if err != nil {
+ return nil, nil, fmt.Errorf("Mount option \"gid='%v'\" not parsable: %v", gidStr, err)
+ }
+ if err := defaultOwnerCreds.SetGID(auth.GID(gid)); err != nil {
+ return nil, nil, fmt.Errorf("Error using mount option \"gid='%v'\": %v", gidStr, err)
+ }
+ }
+
var root *dentry
switch rootFileType {
case linux.S_IFREG:
- root = fs.newDentry(fs.newRegularFile(creds, 0777))
+ root = fs.newDentry(fs.newRegularFile(defaultOwnerCreds, defaultMode))
case linux.S_IFLNK:
- root = fs.newDentry(fs.newSymlink(creds, tmpfsOpts.RootSymlinkTarget))
+ root = fs.newDentry(fs.newSymlink(defaultOwnerCreds, tmpfsOpts.RootSymlinkTarget))
case linux.S_IFDIR:
- root = &fs.newDirectory(creds, 01777).dentry
+ root = &fs.newDirectory(defaultOwnerCreds, defaultMode).dentry
default:
fs.vfsfs.DecRef()
return nil, nil, fmt.Errorf("invalid tmpfs root file type: %#o", rootFileType)
@@ -562,6 +594,9 @@ func (i *inode) isDir() bool {
}
func (i *inode) touchAtime(mnt *vfs.Mount) {
+ if mnt.Flags.NoATime {
+ return
+ }
if err := mnt.CheckBeginWrite(); err != nil {
return
}
diff --git a/pkg/sentry/kernel/auth/credentials.go b/pkg/sentry/kernel/auth/credentials.go
index e057d2c6d..6862f2ef5 100644
--- a/pkg/sentry/kernel/auth/credentials.go
+++ b/pkg/sentry/kernel/auth/credentials.go
@@ -232,3 +232,31 @@ func (c *Credentials) UseGID(gid GID) (KGID, error) {
}
return NoID, syserror.EPERM
}
+
+// SetUID translates the provided uid to the root user namespace and updates c's
+// uids to it. This performs no permissions or capabilities checks, the caller
+// is responsible for ensuring the calling context is permitted to modify c.
+func (c *Credentials) SetUID(uid UID) error {
+ kuid := c.UserNamespace.MapToKUID(uid)
+ if !kuid.Ok() {
+ return syserror.EINVAL
+ }
+ c.RealKUID = kuid
+ c.EffectiveKUID = kuid
+ c.SavedKUID = kuid
+ return nil
+}
+
+// SetGID translates the provided gid to the root user namespace and updates c's
+// gids to it. This performs no permissions or capabilities checks, the caller
+// is responsible for ensuring the calling context is permitted to modify c.
+func (c *Credentials) SetGID(gid GID) error {
+ kgid := c.UserNamespace.MapToKGID(gid)
+ if !kgid.Ok() {
+ return syserror.EINVAL
+ }
+ c.RealKGID = kgid
+ c.EffectiveKGID = kgid
+ c.SavedKGID = kgid
+ return nil
+}
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
index 00c425cca..9b69f3cbe 100644
--- a/pkg/sentry/kernel/task_exec.go
+++ b/pkg/sentry/kernel/task_exec.go
@@ -198,6 +198,10 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
t.tg.oldRSeqCritical.Store(&OldRSeqCriticalRegion{})
t.tg.pidns.owner.mu.Unlock()
+ oldFDTable := t.fdTable
+ t.fdTable = t.fdTable.Fork()
+ oldFDTable.DecRef()
+
// Remove FDs with the CloseOnExec flag set.
t.fdTable.RemoveIf(func(_ *fs.File, _ *vfs.FileDescription, flags FDFlags) bool {
return flags.CloseOnExec
diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD
index 1eeb9f317..a9836ba71 100644
--- a/pkg/sentry/pgalloc/BUILD
+++ b/pkg/sentry/pgalloc/BUILD
@@ -33,6 +33,7 @@ go_template_instance(
out = "usage_set.go",
consts = {
"minDegree": "10",
+ "trackGaps": "1",
},
imports = {
"platform": "gvisor.dev/gvisor/pkg/sentry/platform",
@@ -48,6 +49,26 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "reclaim_set",
+ out = "reclaim_set.go",
+ consts = {
+ "minDegree": "10",
+ },
+ imports = {
+ "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ },
+ package = "pgalloc",
+ prefix = "reclaim",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "uint64",
+ "Range": "platform.FileRange",
+ "Value": "reclaimSetValue",
+ "Functions": "reclaimSetFunctions",
+ },
+)
+
go_library(
name = "pgalloc",
srcs = [
@@ -56,6 +77,7 @@ go_library(
"evictable_range_set.go",
"pgalloc.go",
"pgalloc_unsafe.go",
+ "reclaim_set.go",
"save_restore.go",
"usage_set.go",
],
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
index 2b11ea4ae..c8d9facc2 100644
--- a/pkg/sentry/pgalloc/pgalloc.go
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -108,12 +108,6 @@ type MemoryFile struct {
usageSwapped uint64
usageLast time.Time
- // minUnallocatedPage is the minimum page that may be unallocated.
- // i.e., there are no unallocated pages below minUnallocatedPage.
- //
- // minUnallocatedPage is protected by mu.
- minUnallocatedPage uint64
-
// fileSize is the size of the backing memory file in bytes. fileSize is
// always a power-of-two multiple of chunkSize.
//
@@ -146,11 +140,9 @@ type MemoryFile struct {
// is protected by mu.
reclaimable bool
- // minReclaimablePage is the minimum page that may be reclaimable.
- // i.e., all reclaimable pages are >= minReclaimablePage.
- //
- // minReclaimablePage is protected by mu.
- minReclaimablePage uint64
+ // relcaim is the collection of regions for reclaim. relcaim is protected
+ // by mu.
+ reclaim reclaimSet
// reclaimCond is signaled (with mu locked) when reclaimable or destroyed
// transitions from false to true.
@@ -273,12 +265,10 @@ type evictableMemoryUserInfo struct {
}
const (
- chunkShift = 24
- chunkSize = 1 << chunkShift // 16 MB
+ chunkShift = 30
+ chunkSize = 1 << chunkShift // 1 GB
chunkMask = chunkSize - 1
- initialSize = chunkSize
-
// maxPage is the highest 64-bit page.
maxPage = math.MaxUint64 &^ (usermem.PageSize - 1)
)
@@ -302,19 +292,12 @@ func NewMemoryFile(file *os.File, opts MemoryFileOpts) (*MemoryFile, error) {
if err := file.Truncate(0); err != nil {
return nil, err
}
- if err := file.Truncate(initialSize); err != nil {
- return nil, err
- }
f := &MemoryFile{
- opts: opts,
- fileSize: initialSize,
- file: file,
- // No pages are reclaimable. DecRef will always be able to
- // decrease minReclaimablePage from this point.
- minReclaimablePage: maxPage,
- evictable: make(map[EvictableMemoryUser]*evictableMemoryUserInfo),
+ opts: opts,
+ file: file,
+ evictable: make(map[EvictableMemoryUser]*evictableMemoryUserInfo),
}
- f.mappings.Store(make([]uintptr, initialSize/chunkSize))
+ f.mappings.Store(make([]uintptr, 0))
f.reclaimCond.L = &f.mu
if f.opts.DelayedEviction == DelayedEvictionEnabled && f.opts.UseHostMemcgPressure {
@@ -404,39 +387,28 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi
alignment = usermem.HugePageSize
}
- start, minUnallocatedPage := findUnallocatedRange(&f.usage, f.minUnallocatedPage, length, alignment)
- end := start + length
- // File offsets are int64s. Since length must be strictly positive, end
- // cannot legitimately be 0.
- if end < start || int64(end) <= 0 {
+ // Find a range in the underlying file.
+ fr, ok := findAvailableRange(&f.usage, f.fileSize, length, alignment)
+ if !ok {
return platform.FileRange{}, syserror.ENOMEM
}
- // Expand the file if needed. Double the file size on each expansion;
- // uncommitted pages have effectively no cost.
- fileSize := f.fileSize
- for int64(end) > fileSize {
- if fileSize >= 2*fileSize {
- // fileSize overflow.
- return platform.FileRange{}, syserror.ENOMEM
- }
- fileSize *= 2
- }
- if fileSize > f.fileSize {
- if err := f.file.Truncate(fileSize); err != nil {
+ // Expand the file if needed. Note that findAvailableRange will
+ // appropriately double the fileSize when required.
+ if int64(fr.End) > f.fileSize {
+ if err := f.file.Truncate(int64(fr.End)); err != nil {
return platform.FileRange{}, err
}
- f.fileSize = fileSize
+ f.fileSize = int64(fr.End)
f.mappingsMu.Lock()
oldMappings := f.mappings.Load().([]uintptr)
- newMappings := make([]uintptr, fileSize>>chunkShift)
+ newMappings := make([]uintptr, f.fileSize>>chunkShift)
copy(newMappings, oldMappings)
f.mappings.Store(newMappings)
f.mappingsMu.Unlock()
}
// Mark selected pages as in use.
- fr := platform.FileRange{start, end}
if f.opts.ManualZeroing {
if err := f.forEachMappingSlice(fr, func(bs []byte) {
for i := range bs {
@@ -453,49 +425,71 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi
panic(fmt.Sprintf("allocating %v: failed to insert into usage set:\n%v", fr, &f.usage))
}
- if minUnallocatedPage < start {
- f.minUnallocatedPage = minUnallocatedPage
- } else {
- // start was the first unallocated page. The next must be
- // somewhere beyond end.
- f.minUnallocatedPage = end
- }
-
return fr, nil
}
-// findUnallocatedRange returns the first unallocated page in usage of the
-// specified length and alignment beginning at page start and the first single
-// unallocated page.
-func findUnallocatedRange(usage *usageSet, start, length, alignment uint64) (uint64, uint64) {
- // Only searched until the first page is found.
- firstPage := start
- foundFirstPage := false
- alignMask := alignment - 1
- for seg := usage.LowerBoundSegment(start); seg.Ok(); seg = seg.NextSegment() {
- r := seg.Range()
+// findAvailableRange returns an available range in the usageSet.
+//
+// Note that scanning for available slots takes place from end first backwards,
+// then forwards. This heuristic has important consequence for how sequential
+// mappings can be merged in the host VMAs, given that addresses for both
+// application and sentry mappings are allocated top-down (from higher to
+// lower addresses). The file is also grown expoentially in order to create
+// space for mappings to be allocated downwards.
+//
+// Precondition: alignment must be a power of 2.
+func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint64) (platform.FileRange, bool) {
+ alignmentMask := alignment - 1
+ for gap := usage.UpperBoundGap(uint64(fileSize)); gap.Ok(); gap = gap.PrevLargeEnoughGap(length) {
+ // Start searching only at end of file.
+ end := gap.End()
+ if end > uint64(fileSize) {
+ end = uint64(fileSize)
+ }
- if !foundFirstPage && r.Start > firstPage {
- foundFirstPage = true
+ // Start at the top and align downwards.
+ start := end - length
+ if start > end {
+ break // Underflow.
}
+ start &^= alignmentMask
- if start >= r.End {
- // start was rounded up to an alignment boundary from the end
- // of a previous segment and is now beyond r.End.
+ // Is the gap still sufficient?
+ if start < gap.Start() {
continue
}
- // This segment represents allocated or reclaimable pages; only the
- // range from start to the segment's beginning is allocatable, and the
- // next allocatable range begins after the segment.
- if r.Start > start && r.Start-start >= length {
- break
+
+ // Allocate in the given gap.
+ return platform.FileRange{start, start + length}, true
+ }
+
+ // Check that it's possible to fit this allocation at the end of a file of any size.
+ min := usage.LastGap().Start()
+ min = (min + alignmentMask) &^ alignmentMask
+ if min+length < min {
+ // Overflow.
+ return platform.FileRange{}, false
+ }
+
+ // Determine the minimum file size required to fit this allocation at its end.
+ for {
+ if fileSize >= 2*fileSize {
+ // Is this because it's initially empty?
+ if fileSize == 0 {
+ fileSize += chunkSize
+ } else {
+ // fileSize overflow.
+ return platform.FileRange{}, false
+ }
+ } else {
+ // Double the current fileSize.
+ fileSize *= 2
}
- start = (r.End + alignMask) &^ alignMask
- if !foundFirstPage {
- firstPage = r.End
+ start := (uint64(fileSize) - length) &^ alignmentMask
+ if start >= min {
+ return platform.FileRange{start, start + length}, true
}
}
- return start, firstPage
}
// AllocateAndFill allocates memory of the given kind and fills it by calling
@@ -616,6 +610,7 @@ func (f *MemoryFile) DecRef(fr platform.FileRange) {
}
val.refs--
if val.refs == 0 {
+ f.reclaim.Add(seg.Range(), reclaimSetValue{})
freed = true
// Reclassify memory as System, until it's freed by the reclaim
// goroutine.
@@ -628,10 +623,6 @@ func (f *MemoryFile) DecRef(fr platform.FileRange) {
f.usage.MergeAdjacent(fr)
if freed {
- if fr.Start < f.minReclaimablePage {
- // We've freed at least one lower page.
- f.minReclaimablePage = fr.Start
- }
f.reclaimable = true
f.reclaimCond.Signal()
}
@@ -1030,6 +1021,7 @@ func (f *MemoryFile) String() string {
// for allocation.
func (f *MemoryFile) runReclaim() {
for {
+ // N.B. We must call f.markReclaimed on the returned FrameRange.
fr, ok := f.findReclaimable()
if !ok {
break
@@ -1085,6 +1077,10 @@ func (f *MemoryFile) runReclaim() {
}
}
+// findReclaimable finds memory that has been marked for reclaim.
+//
+// Note that there returned range will be removed from tracking. It
+// must be reclaimed (removed from f.usage) at this point.
func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) {
f.mu.Lock()
defer f.mu.Unlock()
@@ -1103,18 +1099,15 @@ func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) {
}
f.reclaimCond.Wait()
}
- // Allocate returns the first usable range in offset order and is
- // currently a linear scan, so reclaiming from the beginning of the
- // file minimizes the expected latency of Allocate.
- for seg := f.usage.LowerBoundSegment(f.minReclaimablePage); seg.Ok(); seg = seg.NextSegment() {
- if seg.ValuePtr().refs == 0 {
- f.minReclaimablePage = seg.End()
- return seg.Range(), true
- }
+ // Allocate works from the back of the file inwards, so reclaim
+ // preserves this order to minimize the cost of the search.
+ if seg := f.reclaim.LastSegment(); seg.Ok() {
+ fr := seg.Range()
+ f.reclaim.Remove(seg)
+ return fr, true
}
- // No pages are reclaimable.
+ // Nothing is reclaimable.
f.reclaimable = false
- f.minReclaimablePage = maxPage
}
}
@@ -1122,8 +1115,8 @@ func (f *MemoryFile) markReclaimed(fr platform.FileRange) {
f.mu.Lock()
defer f.mu.Unlock()
seg := f.usage.FindSegment(fr.Start)
- // All of fr should be mapped to a single uncommitted reclaimable segment
- // accounted to System.
+ // All of fr should be mapped to a single uncommitted reclaimable
+ // segment accounted to System.
if !seg.Ok() {
panic(fmt.Sprintf("reclaimed pages %v include unreferenced pages:\n%v", fr, &f.usage))
}
@@ -1137,14 +1130,10 @@ func (f *MemoryFile) markReclaimed(fr platform.FileRange) {
}); got != want {
panic(fmt.Sprintf("reclaimed pages %v in segment %v has incorrect state %v, wanted %v:\n%v", fr, seg.Range(), got, want, &f.usage))
}
- // Deallocate reclaimed pages. Even though all of seg is reclaimable, the
- // caller of markReclaimed may not have decommitted it, so we can only mark
- // fr as reclaimed.
+ // Deallocate reclaimed pages. Even though all of seg is reclaimable,
+ // the caller of markReclaimed may not have decommitted it, so we can
+ // only mark fr as reclaimed.
f.usage.Remove(f.usage.Isolate(seg, fr))
- if fr.Start < f.minUnallocatedPage {
- // We've deallocated at least one lower page.
- f.minUnallocatedPage = fr.Start
- }
}
// StartEvictions requests that f evict all evictable allocations. It does not
@@ -1255,3 +1244,27 @@ func (evictableRangeSetFunctions) Merge(_ EvictableRange, _ evictableRangeSetVal
func (evictableRangeSetFunctions) Split(_ EvictableRange, _ evictableRangeSetValue, _ uint64) (evictableRangeSetValue, evictableRangeSetValue) {
return evictableRangeSetValue{}, evictableRangeSetValue{}
}
+
+// reclaimSetValue is the value type of reclaimSet.
+type reclaimSetValue struct{}
+
+type reclaimSetFunctions struct{}
+
+func (reclaimSetFunctions) MinKey() uint64 {
+ return 0
+}
+
+func (reclaimSetFunctions) MaxKey() uint64 {
+ return math.MaxUint64
+}
+
+func (reclaimSetFunctions) ClearValue(val *reclaimSetValue) {
+}
+
+func (reclaimSetFunctions) Merge(_ platform.FileRange, _ reclaimSetValue, _ platform.FileRange, _ reclaimSetValue) (reclaimSetValue, bool) {
+ return reclaimSetValue{}, true
+}
+
+func (reclaimSetFunctions) Split(_ platform.FileRange, _ reclaimSetValue, _ uint64) (reclaimSetValue, reclaimSetValue) {
+ return reclaimSetValue{}, reclaimSetValue{}
+}
diff --git a/pkg/sentry/pgalloc/pgalloc_test.go b/pkg/sentry/pgalloc/pgalloc_test.go
index 293f22c6b..b5b68eb52 100644
--- a/pkg/sentry/pgalloc/pgalloc_test.go
+++ b/pkg/sentry/pgalloc/pgalloc_test.go
@@ -23,39 +23,49 @@ import (
const (
page = usermem.PageSize
hugepage = usermem.HugePageSize
+ topPage = (1 << 63) - page
)
func TestFindUnallocatedRange(t *testing.T) {
for _, test := range []struct {
- desc string
- usage *usageSegmentDataSlices
- start uint64
- length uint64
- alignment uint64
- unallocated uint64
- minUnallocated uint64
+ desc string
+ usage *usageSegmentDataSlices
+ fileSize int64
+ length uint64
+ alignment uint64
+ start uint64
+ expectFail bool
}{
{
- desc: "Initial allocation succeeds",
- usage: &usageSegmentDataSlices{},
- start: 0,
- length: page,
- alignment: page,
- unallocated: 0,
- minUnallocated: 0,
+ desc: "Initial allocation succeeds",
+ usage: &usageSegmentDataSlices{},
+ length: page,
+ alignment: page,
+ start: chunkSize - page, // Grows by chunkSize, allocate down.
},
{
- desc: "Allocation begins at start of file",
+ desc: "Allocation finds empty space at start of file",
usage: &usageSegmentDataSlices{
Start: []uint64{page},
End: []uint64{2 * page},
Values: []usageInfo{{refs: 1}},
},
- start: 0,
- length: page,
- alignment: page,
- unallocated: 0,
- minUnallocated: 0,
+ fileSize: 2 * page,
+ length: page,
+ alignment: page,
+ start: 0,
+ },
+ {
+ desc: "Allocation finds empty space at end of file",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{0},
+ End: []uint64{page},
+ Values: []usageInfo{{refs: 1}},
+ },
+ fileSize: 2 * page,
+ length: page,
+ alignment: page,
+ start: page,
},
{
desc: "In-use frames are not allocatable",
@@ -64,11 +74,10 @@ func TestFindUnallocatedRange(t *testing.T) {
End: []uint64{page, 2 * page},
Values: []usageInfo{{refs: 1}, {refs: 2}},
},
- start: 0,
- length: page,
- alignment: page,
- unallocated: 2 * page,
- minUnallocated: 2 * page,
+ fileSize: 2 * page,
+ length: page,
+ alignment: page,
+ start: 3 * page, // Double fileSize, allocate top-down.
},
{
desc: "Reclaimable frames are not allocatable",
@@ -77,11 +86,10 @@ func TestFindUnallocatedRange(t *testing.T) {
End: []uint64{page, 2 * page, 3 * page},
Values: []usageInfo{{refs: 1}, {refs: 0}, {refs: 1}},
},
- start: 0,
- length: page,
- alignment: page,
- unallocated: 3 * page,
- minUnallocated: 3 * page,
+ fileSize: 3 * page,
+ length: page,
+ alignment: page,
+ start: 5 * page, // Double fileSize, grow down.
},
{
desc: "Gaps between in-use frames are allocatable",
@@ -90,11 +98,10 @@ func TestFindUnallocatedRange(t *testing.T) {
End: []uint64{page, 3 * page},
Values: []usageInfo{{refs: 1}, {refs: 1}},
},
- start: 0,
- length: page,
- alignment: page,
- unallocated: page,
- minUnallocated: page,
+ fileSize: 3 * page,
+ length: page,
+ alignment: page,
+ start: page,
},
{
desc: "Inadequately-sized gaps are rejected",
@@ -103,14 +110,13 @@ func TestFindUnallocatedRange(t *testing.T) {
End: []uint64{page, 3 * page},
Values: []usageInfo{{refs: 1}, {refs: 1}},
},
- start: 0,
- length: 2 * page,
- alignment: page,
- unallocated: 3 * page,
- minUnallocated: page,
+ fileSize: 3 * page,
+ length: 2 * page,
+ alignment: page,
+ start: 4 * page, // Double fileSize, grow down.
},
{
- desc: "Hugepage alignment is honored",
+ desc: "Alignment is honored at end of file",
usage: &usageSegmentDataSlices{
Start: []uint64{0, hugepage + page},
// Hugepage-sized gap here that shouldn't be allocated from
@@ -118,37 +124,95 @@ func TestFindUnallocatedRange(t *testing.T) {
End: []uint64{page, hugepage + 2*page},
Values: []usageInfo{{refs: 1}, {refs: 1}},
},
- start: 0,
- length: hugepage,
- alignment: hugepage,
- unallocated: 2 * hugepage,
- minUnallocated: page,
+ fileSize: hugepage + 2*page,
+ length: hugepage,
+ alignment: hugepage,
+ start: 3 * hugepage, // Double fileSize until alignment is satisfied, grow down.
+ },
+ {
+ desc: "Alignment is honored before end of file",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{0, 2*hugepage + page},
+ // Page will need to be shifted down from top.
+ End: []uint64{page, 2*hugepage + 2*page},
+ Values: []usageInfo{{refs: 1}, {refs: 1}},
+ },
+ fileSize: 2*hugepage + 2*page,
+ length: hugepage,
+ alignment: hugepage,
+ start: hugepage,
},
{
- desc: "Pages before start ignored",
+ desc: "Allocations are compact if possible",
usage: &usageSegmentDataSlices{
Start: []uint64{page, 3 * page},
End: []uint64{2 * page, 4 * page},
Values: []usageInfo{{refs: 1}, {refs: 2}},
},
- start: page,
- length: page,
- alignment: page,
- unallocated: 2 * page,
- minUnallocated: 2 * page,
+ fileSize: 4 * page,
+ length: page,
+ alignment: page,
+ start: 2 * page,
+ },
+ {
+ desc: "Top-down allocation within one gap",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page, 4 * page, 7 * page},
+ End: []uint64{2 * page, 5 * page, 8 * page},
+ Values: []usageInfo{{refs: 1}, {refs: 2}, {refs: 1}},
+ },
+ fileSize: 8 * page,
+ length: page,
+ alignment: page,
+ start: 6 * page,
+ },
+ {
+ desc: "Top-down allocation between multiple gaps",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page, 3 * page, 5 * page},
+ End: []uint64{2 * page, 4 * page, 6 * page},
+ Values: []usageInfo{{refs: 1}, {refs: 2}, {refs: 1}},
+ },
+ fileSize: 6 * page,
+ length: page,
+ alignment: page,
+ start: 4 * page,
},
{
- desc: "start may be in the middle of segment",
+ desc: "Top-down allocation with large top gap",
usage: &usageSegmentDataSlices{
- Start: []uint64{0, 3 * page},
+ Start: []uint64{page, 3 * page},
End: []uint64{2 * page, 4 * page},
Values: []usageInfo{{refs: 1}, {refs: 2}},
},
- start: page,
- length: page,
- alignment: page,
- unallocated: 2 * page,
- minUnallocated: 2 * page,
+ fileSize: 8 * page,
+ length: page,
+ alignment: page,
+ start: 7 * page,
+ },
+ {
+ desc: "Gaps found with possible overflow",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page, topPage - page},
+ End: []uint64{2 * page, topPage},
+ Values: []usageInfo{{refs: 1}, {refs: 1}},
+ },
+ fileSize: topPage,
+ length: page,
+ alignment: page,
+ start: topPage - 2*page,
+ },
+ {
+ desc: "Overflow detected",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page},
+ End: []uint64{topPage},
+ Values: []usageInfo{{refs: 1}},
+ },
+ fileSize: topPage,
+ length: 2 * page,
+ alignment: page,
+ expectFail: true,
},
} {
t.Run(test.desc, func(t *testing.T) {
@@ -156,12 +220,18 @@ func TestFindUnallocatedRange(t *testing.T) {
if err := usage.ImportSortedSlices(test.usage); err != nil {
t.Fatalf("Failed to initialize usage from %v: %v", test.usage, err)
}
- unallocated, minUnallocated := findUnallocatedRange(&usage, test.start, test.length, test.alignment)
- if unallocated != test.unallocated {
- t.Errorf("findUnallocatedRange(%v, %x, %x, %x): got unallocated %x, wanted %x", test.usage, test.start, test.length, test.alignment, unallocated, test.unallocated)
+ fr, ok := findAvailableRange(&usage, test.fileSize, test.length, test.alignment)
+ if !test.expectFail && !ok {
+ t.Fatalf("findAvailableRange(%v, %x, %x, %x): got %x, false wanted %x, true", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start)
+ }
+ if test.expectFail && ok {
+ t.Fatalf("findAvailableRange(%v, %x, %x, %x): got %x, true wanted %x, false", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start)
+ }
+ if ok && fr.Start != test.start {
+ t.Errorf("findAvailableRange(%v, %x, %x, %x): got start=%x, wanted %x", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start)
}
- if minUnallocated != test.minUnallocated {
- t.Errorf("findUnallocatedRange(%v, %x, %x, %x): got minUnallocated %x, wanted %x", test.usage, test.start, test.length, test.alignment, minUnallocated, test.minUnallocated)
+ if ok && fr.End != test.start+test.length {
+ t.Errorf("findAvailableRange(%v, %x, %x, %x): got end=%x, wanted %x", test.usage, test.fileSize, test.length, test.alignment, fr.End, test.start+test.length)
}
})
}
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 159f7eafd..4792454c4 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -6,8 +6,8 @@ go_library(
name = "kvm",
srcs = [
"address_space.go",
- "allocator.go",
"bluepill.go",
+ "bluepill_allocator.go",
"bluepill_amd64.go",
"bluepill_amd64.s",
"bluepill_amd64_unsafe.go",
diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go
index be213bfe8..faf1d5e1c 100644
--- a/pkg/sentry/platform/kvm/address_space.go
+++ b/pkg/sentry/platform/kvm/address_space.go
@@ -26,16 +26,15 @@ import (
// dirtySet tracks vCPUs for invalidation.
type dirtySet struct {
- vCPUs []uint64
+ vCPUMasks []uint64
}
// forEach iterates over all CPUs in the dirty set.
+//
+//go:nosplit
func (ds *dirtySet) forEach(m *machine, fn func(c *vCPU)) {
- m.mu.RLock()
- defer m.mu.RUnlock()
-
- for index := range ds.vCPUs {
- mask := atomic.SwapUint64(&ds.vCPUs[index], 0)
+ for index := range ds.vCPUMasks {
+ mask := atomic.SwapUint64(&ds.vCPUMasks[index], 0)
if mask != 0 {
for bit := 0; bit < 64; bit++ {
if mask&(1<<uint64(bit)) == 0 {
@@ -54,7 +53,7 @@ func (ds *dirtySet) mark(c *vCPU) bool {
index := uint64(c.id) / 64
bit := uint64(1) << uint(c.id%64)
- oldValue := atomic.LoadUint64(&ds.vCPUs[index])
+ oldValue := atomic.LoadUint64(&ds.vCPUMasks[index])
if oldValue&bit != 0 {
return false // Not clean.
}
@@ -62,7 +61,7 @@ func (ds *dirtySet) mark(c *vCPU) bool {
// Set the bit unilaterally, and ensure that a flush takes place. Note
// that it's possible for races to occur here, but since the flush is
// taking place long after these lines there's no race in practice.
- atomicbitops.OrUint64(&ds.vCPUs[index], bit)
+ atomicbitops.OrUint64(&ds.vCPUMasks[index], bit)
return true // Previously clean.
}
@@ -113,7 +112,12 @@ type hostMapEntry struct {
length uintptr
}
-func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.AccessType) (inv bool) {
+// mapLocked maps the given host entry.
+//
+// +checkescape:hard,stack
+//
+//go:nosplit
+func (as *addressSpace) mapLocked(addr usermem.Addr, m hostMapEntry, at usermem.AccessType) (inv bool) {
for m.length > 0 {
physical, length, ok := translateToPhysical(m.addr)
if !ok {
@@ -133,18 +137,10 @@ func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.Ac
// important; if the pagetable mappings were installed before
// ensuring the physical pages were available, then some other
// thread could theoretically access them.
- //
- // Due to the way KVM's shadow paging implementation works,
- // modifications to the page tables while in host mode may not
- // be trapped, leading to the shadow pages being out of sync.
- // Therefore, we need to ensure that we are in guest mode for
- // page table modifications. See the call to bluepill, below.
- as.machine.retryInGuest(func() {
- inv = as.pageTables.Map(addr, length, pagetables.MapOpts{
- AccessType: at,
- User: true,
- }, physical) || inv
- })
+ inv = as.pageTables.Map(addr, length, pagetables.MapOpts{
+ AccessType: at,
+ User: true,
+ }, physical) || inv
m.addr += length
m.length -= length
addr += usermem.Addr(length)
@@ -176,6 +172,10 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.
return err
}
+ // See block in mapLocked.
+ as.pageTables.Allocator.(*allocator).cpu = as.machine.Get()
+ defer as.machine.Put(as.pageTables.Allocator.(*allocator).cpu)
+
// Map the mappings in the sentry's address space (guest physical memory)
// into the application's address space (guest virtual memory).
inv := false
@@ -190,7 +190,12 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.
_ = s[i] // Touch to commit.
}
}
- prev := as.mapHost(addr, hostMapEntry{
+
+ // See bluepill_allocator.go.
+ bluepill(as.pageTables.Allocator.(*allocator).cpu)
+
+ // Perform the mapping.
+ prev := as.mapLocked(addr, hostMapEntry{
addr: b.Addr(),
length: uintptr(b.Len()),
}, at)
@@ -204,17 +209,27 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.
return nil
}
+// unmapLocked is an escape-checked wrapped around Unmap.
+//
+// +checkescape:hard,stack
+//
+//go:nosplit
+func (as *addressSpace) unmapLocked(addr usermem.Addr, length uint64) bool {
+ return as.pageTables.Unmap(addr, uintptr(length))
+}
+
// Unmap unmaps the given range by calling pagetables.PageTables.Unmap.
func (as *addressSpace) Unmap(addr usermem.Addr, length uint64) {
as.mu.Lock()
defer as.mu.Unlock()
- // See above re: retryInGuest.
- var prev bool
- as.machine.retryInGuest(func() {
- prev = as.pageTables.Unmap(addr, uintptr(length)) || prev
- })
- if prev {
+ // See above & bluepill_allocator.go.
+ as.pageTables.Allocator.(*allocator).cpu = as.machine.Get()
+ defer as.machine.Put(as.pageTables.Allocator.(*allocator).cpu)
+ bluepill(as.pageTables.Allocator.(*allocator).cpu)
+
+ if prev := as.unmapLocked(addr, length); prev {
+ // Invalidate all active vCPUs.
as.invalidate()
// Recycle any freed intermediate pages.
@@ -227,7 +242,7 @@ func (as *addressSpace) Release() {
as.Unmap(0, ^uint64(0))
// Free all pages from the allocator.
- as.pageTables.Allocator.(allocator).base.Drain()
+ as.pageTables.Allocator.(*allocator).base.Drain()
// Drop all cached machine references.
as.machine.dropPageTables(as.pageTables)
diff --git a/pkg/sentry/platform/kvm/allocator.go b/pkg/sentry/platform/kvm/bluepill_allocator.go
index 3f35414bb..9485e1301 100644
--- a/pkg/sentry/platform/kvm/allocator.go
+++ b/pkg/sentry/platform/kvm/bluepill_allocator.go
@@ -21,56 +21,80 @@ import (
)
type allocator struct {
- base *pagetables.RuntimeAllocator
+ base pagetables.RuntimeAllocator
+
+ // cpu must be set prior to any pagetable operation.
+ //
+ // Due to the way KVM's shadow paging implementation works,
+ // modifications to the page tables while in host mode may not be
+ // trapped, leading to the shadow pages being out of sync. Therefore,
+ // we need to ensure that we are in guest mode for page table
+ // modifications. See the call to bluepill, below.
+ cpu *vCPU
}
// newAllocator is used to define the allocator.
-func newAllocator() allocator {
- return allocator{
- base: pagetables.NewRuntimeAllocator(),
- }
+func newAllocator() *allocator {
+ a := new(allocator)
+ a.base.Init()
+ return a
}
// NewPTEs implements pagetables.Allocator.NewPTEs.
//
+// +checkescape:all
+//
//go:nosplit
-func (a allocator) NewPTEs() *pagetables.PTEs {
- return a.base.NewPTEs()
+func (a *allocator) NewPTEs() *pagetables.PTEs {
+ ptes := a.base.NewPTEs() // escapes: bluepill below.
+ if a.cpu != nil {
+ bluepill(a.cpu)
+ }
+ return ptes
}
// PhysicalFor returns the physical address for a set of PTEs.
//
+// +checkescape:all
+//
//go:nosplit
-func (a allocator) PhysicalFor(ptes *pagetables.PTEs) uintptr {
+func (a *allocator) PhysicalFor(ptes *pagetables.PTEs) uintptr {
virtual := a.base.PhysicalFor(ptes)
physical, _, ok := translateToPhysical(virtual)
if !ok {
- panic(fmt.Sprintf("PhysicalFor failed for %p", ptes))
+ panic(fmt.Sprintf("PhysicalFor failed for %p", ptes)) // escapes: panic.
}
return physical
}
// LookupPTEs implements pagetables.Allocator.LookupPTEs.
//
+// +checkescape:all
+//
//go:nosplit
-func (a allocator) LookupPTEs(physical uintptr) *pagetables.PTEs {
+func (a *allocator) LookupPTEs(physical uintptr) *pagetables.PTEs {
virtualStart, physicalStart, _, ok := calculateBluepillFault(physical, physicalRegions)
if !ok {
- panic(fmt.Sprintf("LookupPTEs failed for 0x%x", physical))
+ panic(fmt.Sprintf("LookupPTEs failed for 0x%x", physical)) // escapes: panic.
}
return a.base.LookupPTEs(virtualStart + (physical - physicalStart))
}
// FreePTEs implements pagetables.Allocator.FreePTEs.
//
+// +checkescape:all
+//
//go:nosplit
-func (a allocator) FreePTEs(ptes *pagetables.PTEs) {
- a.base.FreePTEs(ptes)
+func (a *allocator) FreePTEs(ptes *pagetables.PTEs) {
+ a.base.FreePTEs(ptes) // escapes: bluepill below.
+ if a.cpu != nil {
+ bluepill(a.cpu)
+ }
}
// Recycle implements pagetables.Allocator.Recycle.
//
//go:nosplit
-func (a allocator) Recycle() {
+func (a *allocator) Recycle() {
a.base.Recycle()
}
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go
index 133c2203d..ddc1554d5 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.go
@@ -63,6 +63,8 @@ func bluepillArchEnter(context *arch.SignalContext64) *vCPU {
// KernelSyscall handles kernel syscalls.
//
+// +checkescape:all
+//
//go:nosplit
func (c *vCPU) KernelSyscall() {
regs := c.Registers()
@@ -72,13 +74,15 @@ func (c *vCPU) KernelSyscall() {
// We only trigger a bluepill entry in the bluepill function, and can
// therefore be guaranteed that there is no floating point state to be
// loaded on resuming from halt. We only worry about saving on exit.
- ring0.SaveFloatingPoint((*byte)(c.floatingPointState))
+ ring0.SaveFloatingPoint((*byte)(c.floatingPointState)) // escapes: no.
ring0.Halt()
- ring0.WriteFS(uintptr(regs.Fs_base)) // Reload host segment.
+ ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no, reload host segment.
}
// KernelException handles kernel exceptions.
//
+// +checkescape:all
+//
//go:nosplit
func (c *vCPU) KernelException(vector ring0.Vector) {
regs := c.Registers()
@@ -89,9 +93,9 @@ func (c *vCPU) KernelException(vector ring0.Vector) {
regs.Rip = 0
}
// See above.
- ring0.SaveFloatingPoint((*byte)(c.floatingPointState))
+ ring0.SaveFloatingPoint((*byte)(c.floatingPointState)) // escapes: no.
ring0.Halt()
- ring0.WriteFS(uintptr(regs.Fs_base)) // Reload host segment.
+ ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no; reload host segment.
}
// bluepillArchExit is called during bluepillEnter.
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go
index c215d443c..83643c602 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.go
@@ -66,6 +66,8 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) {
// KernelSyscall handles kernel syscalls.
//
+// +checkescape:all
+//
//go:nosplit
func (c *vCPU) KernelSyscall() {
regs := c.Registers()
@@ -88,6 +90,8 @@ func (c *vCPU) KernelSyscall() {
// KernelException handles kernel exceptions.
//
+// +checkescape:all
+//
//go:nosplit
func (c *vCPU) KernelException(vector ring0.Vector) {
regs := c.Registers()
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
index 2407014e9..c025aa0bb 100644
--- a/pkg/sentry/platform/kvm/bluepill_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -64,6 +64,8 @@ func bluepillArchContext(context unsafe.Pointer) *arch.SignalContext64 {
// signal stack. It should only execute raw system calls and functions that are
// explicitly marked go:nosplit.
//
+// +checkescape:all
+//
//go:nosplit
func bluepillHandler(context unsafe.Pointer) {
// Sanitize the registers; interrupts must always be disabled.
@@ -82,7 +84,8 @@ func bluepillHandler(context unsafe.Pointer) {
}
for {
- switch _, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(c.fd), _KVM_RUN, 0); errno {
+ _, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(c.fd), _KVM_RUN, 0) // escapes: no.
+ switch errno {
case 0: // Expected case.
case syscall.EINTR:
// First, we process whatever pending signal
@@ -90,7 +93,7 @@ func bluepillHandler(context unsafe.Pointer) {
// currently, all signals are masked and the signal
// must have been delivered directly to this thread.
timeout := syscall.Timespec{}
- sig, _, errno := syscall.RawSyscall6(
+ sig, _, errno := syscall.RawSyscall6( // escapes: no.
syscall.SYS_RT_SIGTIMEDWAIT,
uintptr(unsafe.Pointer(&bounceSignalMask)),
0, // siginfo.
@@ -125,7 +128,7 @@ func bluepillHandler(context unsafe.Pointer) {
// MMIO exit we receive EFAULT from the run ioctl. We
// always inject an NMI here since we may be in kernel
// mode and have interrupts disabled.
- if _, _, errno := syscall.RawSyscall(
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
syscall.SYS_IOCTL,
uintptr(c.fd),
_KVM_NMI, 0); errno != 0 {
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index f1afc74dc..6c54712d1 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -52,16 +52,19 @@ type machine struct {
// available is notified when vCPUs are available.
available sync.Cond
- // vCPUs are the machine vCPUs.
+ // vCPUsByTID are the machine vCPUs.
//
// These are populated dynamically.
- vCPUs map[uint64]*vCPU
+ vCPUsByTID map[uint64]*vCPU
// vCPUsByID are the machine vCPUs, can be indexed by the vCPU's ID.
- vCPUsByID map[int]*vCPU
+ vCPUsByID []*vCPU
// maxVCPUs is the maximum number of vCPUs supported by the machine.
maxVCPUs int
+
+ // nextID is the next vCPU ID.
+ nextID uint32
}
const (
@@ -137,9 +140,8 @@ type dieState struct {
//
// Precondition: mu must be held.
func (m *machine) newVCPU() *vCPU {
- id := len(m.vCPUs)
-
// Create the vCPU.
+ id := int(atomic.AddUint32(&m.nextID, 1) - 1)
fd, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CREATE_VCPU, uintptr(id))
if errno != 0 {
panic(fmt.Sprintf("error creating new vCPU: %v", errno))
@@ -176,11 +178,7 @@ func (m *machine) newVCPU() *vCPU {
// newMachine returns a new VM context.
func newMachine(vm int) (*machine, error) {
// Create the machine.
- m := &machine{
- fd: vm,
- vCPUs: make(map[uint64]*vCPU),
- vCPUsByID: make(map[int]*vCPU),
- }
+ m := &machine{fd: vm}
m.available.L = &m.mu
m.kernel.Init(ring0.KernelOpts{
PageTables: pagetables.New(newAllocator()),
@@ -194,6 +192,10 @@ func newMachine(vm int) (*machine, error) {
}
log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs)
+ // Create the vCPUs map/slices.
+ m.vCPUsByTID = make(map[uint64]*vCPU)
+ m.vCPUsByID = make([]*vCPU, m.maxVCPUs)
+
// Apply the physical mappings. Note that these mappings may point to
// guest physical addresses that are not actually available. These
// physical pages are mapped on demand, see kernel_unsafe.go.
@@ -274,6 +276,8 @@ func newMachine(vm int) (*machine, error) {
// not available. This attempts to be efficient for calls in the hot path.
//
// This panics on error.
+//
+//go:nosplit
func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalRegion, flags uint32) {
for end := physical + length; physical < end; {
_, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions)
@@ -304,7 +308,11 @@ func (m *machine) Destroy() {
runtime.SetFinalizer(m, nil)
// Destroy vCPUs.
- for _, c := range m.vCPUs {
+ for _, c := range m.vCPUsByID {
+ if c == nil {
+ continue
+ }
+
// Ensure the vCPU is not still running in guest mode. This is
// possible iff teardown has been done by other threads, and
// somehow a single thread has not executed any system calls.
@@ -337,7 +345,7 @@ func (m *machine) Get() *vCPU {
tid := procid.Current()
// Check for an exact match.
- if c := m.vCPUs[tid]; c != nil {
+ if c := m.vCPUsByTID[tid]; c != nil {
c.lock()
m.mu.RUnlock()
return c
@@ -356,7 +364,7 @@ func (m *machine) Get() *vCPU {
tid = procid.Current()
// Recheck for an exact match.
- if c := m.vCPUs[tid]; c != nil {
+ if c := m.vCPUsByTID[tid]; c != nil {
c.lock()
m.mu.Unlock()
return c
@@ -364,10 +372,10 @@ func (m *machine) Get() *vCPU {
for {
// Scan for an available vCPU.
- for origTID, c := range m.vCPUs {
+ for origTID, c := range m.vCPUsByTID {
if atomic.CompareAndSwapUint32(&c.state, vCPUReady, vCPUUser) {
- delete(m.vCPUs, origTID)
- m.vCPUs[tid] = c
+ delete(m.vCPUsByTID, origTID)
+ m.vCPUsByTID[tid] = c
m.mu.Unlock()
c.loadSegments(tid)
return c
@@ -375,17 +383,17 @@ func (m *machine) Get() *vCPU {
}
// Create a new vCPU (maybe).
- if len(m.vCPUs) < m.maxVCPUs {
+ if int(m.nextID) < m.maxVCPUs {
c := m.newVCPU()
c.lock()
- m.vCPUs[tid] = c
+ m.vCPUsByTID[tid] = c
m.mu.Unlock()
c.loadSegments(tid)
return c
}
// Scan for something not in user mode.
- for origTID, c := range m.vCPUs {
+ for origTID, c := range m.vCPUsByTID {
if !atomic.CompareAndSwapUint32(&c.state, vCPUGuest, vCPUGuest|vCPUWaiter) {
continue
}
@@ -403,8 +411,8 @@ func (m *machine) Get() *vCPU {
}
// Steal the vCPU.
- delete(m.vCPUs, origTID)
- m.vCPUs[tid] = c
+ delete(m.vCPUsByTID, origTID)
+ m.vCPUsByTID[tid] = c
m.mu.Unlock()
c.loadSegments(tid)
return c
@@ -431,7 +439,7 @@ func (m *machine) Put(c *vCPU) {
// newDirtySet returns a new dirty set.
func (m *machine) newDirtySet() *dirtySet {
return &dirtySet{
- vCPUs: make([]uint64, (m.maxVCPUs+63)/64, (m.maxVCPUs+63)/64),
+ vCPUMasks: make([]uint64, (m.maxVCPUs+63)/64, (m.maxVCPUs+63)/64),
}
}
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index 923ce3909..acc823ba6 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -51,9 +51,10 @@ func (m *machine) initArchState() error {
recover()
debug.SetPanicOnFault(old)
}()
- m.retryInGuest(func() {
- ring0.SetCPUIDFaulting(true)
- })
+ c := m.Get()
+ defer m.Put(c)
+ bluepill(c)
+ ring0.SetCPUIDFaulting(true)
return nil
}
@@ -89,8 +90,8 @@ func (m *machine) dropPageTables(pt *pagetables.PageTables) {
defer m.mu.Unlock()
// Clear from all PCIDs.
- for _, c := range m.vCPUs {
- if c.PCIDs != nil {
+ for _, c := range m.vCPUsByID {
+ if c != nil && c.PCIDs != nil {
c.PCIDs.Drop(pt)
}
}
@@ -335,29 +336,6 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
}
}
-// retryInGuest runs the given function in guest mode.
-//
-// If the function does not complete in guest mode (due to execution of a
-// system call due to a GC stall, for example), then it will be retried. The
-// given function must be idempotent as a result of the retry mechanism.
-func (m *machine) retryInGuest(fn func()) {
- c := m.Get()
- defer m.Put(c)
- for {
- c.ClearErrorCode() // See below.
- bluepill(c) // Force guest mode.
- fn() // Execute the given function.
- _, user := c.ErrorCode()
- if user {
- // If user is set, then we haven't bailed back to host
- // mode via a kernel exception or system call. We
- // consider the full function to have executed in guest
- // mode and we can return.
- break
- }
- }
-}
-
// On x86 platform, the flags for "setMemoryRegion" can always be set as 0.
// There is no need to return read-only physicalRegions.
func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
index 7156c245f..290f035dd 100644
--- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
@@ -154,7 +154,7 @@ func (c *vCPU) setUserRegisters(uregs *userRegs) error {
//
//go:nosplit
func (c *vCPU) getUserRegisters(uregs *userRegs) syscall.Errno {
- if _, _, errno := syscall.RawSyscall(
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
syscall.SYS_IOCTL,
uintptr(c.fd),
_KVM_GET_REGS,
diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go
index de7df4f80..9f86f6a7a 100644
--- a/pkg/sentry/platform/kvm/machine_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_unsafe.go
@@ -115,7 +115,7 @@ func (a *atomicAddressSpace) get() *addressSpace {
//
//go:nosplit
func (c *vCPU) notify() {
- _, _, errno := syscall.RawSyscall6(
+ _, _, errno := syscall.RawSyscall6( // escapes: no.
syscall.SYS_FUTEX,
uintptr(unsafe.Pointer(&c.state)),
linux.FUTEX_WAKE|linux.FUTEX_PRIVATE_FLAG,
diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go
index 900c0bba7..021693791 100644
--- a/pkg/sentry/platform/ring0/kernel.go
+++ b/pkg/sentry/platform/ring0/kernel.go
@@ -31,23 +31,39 @@ type defaultHooks struct{}
// KernelSyscall implements Hooks.KernelSyscall.
//
+// +checkescape:all
+//
//go:nosplit
-func (defaultHooks) KernelSyscall() { Halt() }
+func (defaultHooks) KernelSyscall() {
+ Halt()
+}
// KernelException implements Hooks.KernelException.
//
+// +checkescape:all
+//
//go:nosplit
-func (defaultHooks) KernelException(Vector) { Halt() }
+func (defaultHooks) KernelException(Vector) {
+ Halt()
+}
// kernelSyscall is a trampoline.
//
+// +checkescape:hard,stack
+//
//go:nosplit
-func kernelSyscall(c *CPU) { c.hooks.KernelSyscall() }
+func kernelSyscall(c *CPU) {
+ c.hooks.KernelSyscall()
+}
// kernelException is a trampoline.
//
+// +checkescape:hard,stack
+//
//go:nosplit
-func kernelException(c *CPU, vector Vector) { c.hooks.KernelException(vector) }
+func kernelException(c *CPU, vector Vector) {
+ c.hooks.KernelException(vector)
+}
// Init initializes a new CPU.
//
diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go
index 0feff8778..d37981dbf 100644
--- a/pkg/sentry/platform/ring0/kernel_amd64.go
+++ b/pkg/sentry/platform/ring0/kernel_amd64.go
@@ -178,6 +178,8 @@ func IsCanonical(addr uint64) bool {
//
// Precondition: the Rip, Rsp, Fs and Gs registers must be canonical.
//
+// +checkescape:all
+//
//go:nosplit
func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
userCR3 := switchOpts.PageTables.CR3(!switchOpts.Flush, switchOpts.UserPCID)
@@ -192,9 +194,9 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
// Perform the switch.
swapgs() // GS will be swapped on return.
- WriteFS(uintptr(regs.Fs_base)) // Set application FS.
- WriteGS(uintptr(regs.Gs_base)) // Set application GS.
- LoadFloatingPoint(switchOpts.FloatingPointState) // Copy in floating point.
+ WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS.
+ WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS.
+ LoadFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy in floating point.
jumpToKernel() // Switch to upper half.
writeCR3(uintptr(userCR3)) // Change to user address space.
if switchOpts.FullRestore {
@@ -204,8 +206,8 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
}
writeCR3(uintptr(kernelCR3)) // Return to kernel address space.
jumpToUser() // Return to lower half.
- SaveFloatingPoint(switchOpts.FloatingPointState) // Copy out floating point.
- WriteFS(uintptr(c.registers.Fs_base)) // Restore kernel FS.
+ SaveFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy out floating point.
+ WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS.
return
}
diff --git a/pkg/sentry/platform/ring0/pagetables/allocator.go b/pkg/sentry/platform/ring0/pagetables/allocator.go
index 23fd5c352..8d75b7599 100644
--- a/pkg/sentry/platform/ring0/pagetables/allocator.go
+++ b/pkg/sentry/platform/ring0/pagetables/allocator.go
@@ -53,9 +53,14 @@ type RuntimeAllocator struct {
// NewRuntimeAllocator returns an allocator that uses runtime allocation.
func NewRuntimeAllocator() *RuntimeAllocator {
- return &RuntimeAllocator{
- used: make(map[*PTEs]struct{}),
- }
+ r := new(RuntimeAllocator)
+ r.Init()
+ return r
+}
+
+// Init initializes a RuntimeAllocator.
+func (r *RuntimeAllocator) Init() {
+ r.used = make(map[*PTEs]struct{})
}
// Recycle returns freed pages to the pool.
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go
index 87e88e97d..7f18ac296 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go
@@ -86,6 +86,8 @@ func (*mapVisitor) requiresSplit() bool { return true }
//
// Precondition: addr & length must be page-aligned, their sum must not overflow.
//
+// +checkescape:hard,stack
+//
//go:nosplit
func (p *PageTables) Map(addr usermem.Addr, length uintptr, opts MapOpts, physical uintptr) bool {
if !opts.AccessType.Any() {
@@ -128,6 +130,8 @@ func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) {
//
// Precondition: addr & length must be page-aligned.
//
+// +checkescape:hard,stack
+//
//go:nosplit
func (p *PageTables) Unmap(addr usermem.Addr, length uintptr) bool {
w := unmapWalker{
@@ -162,6 +166,8 @@ func (v *emptyVisitor) visit(start uintptr, pte *PTE, align uintptr) {
//
// Precondition: addr & length must be page-aligned.
//
+// +checkescape:hard,stack
+//
//go:nosplit
func (p *PageTables) IsEmpty(addr usermem.Addr, length uintptr) bool {
w := emptyWalker{
@@ -197,6 +203,8 @@ func (*lookupVisitor) requiresSplit() bool { return false }
// Lookup returns the physical address for the given virtual address.
//
+// +checkescape:hard,stack
+//
//go:nosplit
func (p *PageTables) Lookup(addr usermem.Addr) (physical uintptr, opts MapOpts) {
mask := uintptr(usermem.PageSize - 1)
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 47ff48c00..66015e2bc 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -144,31 +144,27 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen
}
func findTable(stk *stack.Stack, tablename linux.TableName) (stack.Table, error) {
- ipt := stk.IPTables()
- table, ok := ipt.Tables[tablename.String()]
+ table, ok := stk.IPTables().GetTable(tablename.String())
if !ok {
return stack.Table{}, fmt.Errorf("couldn't find table %q", tablename)
}
return table, nil
}
-// FillDefaultIPTables sets stack's IPTables to the default tables and
-// populates them with metadata.
-func FillDefaultIPTables(stk *stack.Stack) {
- ipt := stack.DefaultTables()
-
- // In order to fill in the metadata, we have to translate ipt from its
- // netstack format to Linux's giant-binary-blob format.
- for name, table := range ipt.Tables {
- _, metadata, err := convertNetstackToBinary(name, table)
- if err != nil {
- panic(fmt.Errorf("Unable to set default IP tables: %v", err))
+// FillIPTablesMetadata populates stack's IPTables with metadata.
+func FillIPTablesMetadata(stk *stack.Stack) {
+ stk.IPTables().ModifyTables(func(tables map[string]stack.Table) {
+ // In order to fill in the metadata, we have to translate ipt from its
+ // netstack format to Linux's giant-binary-blob format.
+ for name, table := range tables {
+ _, metadata, err := convertNetstackToBinary(name, table)
+ if err != nil {
+ panic(fmt.Errorf("Unable to set default IP tables: %v", err))
+ }
+ table.SetMetadata(metadata)
+ tables[name] = table
}
- table.SetMetadata(metadata)
- ipt.Tables[name] = table
- }
-
- stk.SetIPTables(ipt)
+ })
}
// convertNetstackToBinary converts the iptables as stored in netstack to the
@@ -573,15 +569,13 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// - There are no chains without an unconditional final rule.
// - There are no chains without an unconditional underflow rule.
- ipt := stk.IPTables()
table.SetMetadata(metadata{
HookEntry: replace.HookEntry,
Underflow: replace.Underflow,
NumEntries: replace.NumEntries,
Size: replace.Size,
})
- ipt.Tables[replace.Name.String()] = table
- stk.SetIPTables(ipt)
+ stk.IPTables().ReplaceTable(replace.Name.String(), table)
return nil
}
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index 333e0042e..8f0f5466e 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -50,5 +50,6 @@ go_library(
"//pkg/tcpip/transport/udp",
"//pkg/usermem",
"//pkg/waiter",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 60df51dae..e1e0c5931 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -33,6 +33,7 @@ import (
"syscall"
"time"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/binary"
@@ -719,6 +720,14 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool
defer s.EventUnregister(&e)
if err := s.Endpoint.Connect(addr); err != tcpip.ErrConnectStarted && err != tcpip.ErrAlreadyConnecting {
+ if (s.family == unix.AF_INET || s.family == unix.AF_INET6) && s.skType == linux.SOCK_STREAM {
+ // TCP unlike UDP returns EADDRNOTAVAIL when it can't
+ // find an available local ephemeral port.
+ if err == tcpip.ErrNoPortAvailable {
+ return syserr.ErrAddressNotAvailable
+ }
+ }
+
return syserr.TranslateNetstackError(err)
}
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index f5fa18136..9b44c2b89 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -362,14 +362,13 @@ func (s *Stack) RouteTable() []inet.Route {
}
// IPTables returns the stack's iptables.
-func (s *Stack) IPTables() (stack.IPTables, error) {
+func (s *Stack) IPTables() (*stack.IPTables, error) {
return s.Stack.IPTables(), nil
}
-// FillDefaultIPTables sets the stack's iptables to the default tables, which
-// allow and do not modify all traffic.
-func (s *Stack) FillDefaultIPTables() {
- netfilter.FillDefaultIPTables(s.Stack)
+// FillIPTablesMetadata populates stack's IPTables with metadata.
+func (s *Stack) FillIPTablesMetadata() {
+ netfilter.FillIPTablesMetadata(s.Stack)
}
// Resume implements inet.Stack.Resume.
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index 9c8b44f64..c0d005247 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -16,6 +16,7 @@ go_library(
"ioctl.go",
"memfd.go",
"mmap.go",
+ "mount.go",
"path.go",
"pipe.go",
"poll.go",
diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go
new file mode 100644
index 000000000..adeaa39cc
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/mount.go
@@ -0,0 +1,145 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Mount implements Linux syscall mount(2).
+func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ sourceAddr := args[0].Pointer()
+ targetAddr := args[1].Pointer()
+ typeAddr := args[2].Pointer()
+ flags := args[3].Uint64()
+ dataAddr := args[4].Pointer()
+
+ // For null-terminated strings related to mount(2), Linux copies in at most
+ // a page worth of data. See fs/namespace.c:copy_mount_string().
+ fsType, err := t.CopyInString(typeAddr, usermem.PageSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ source, err := t.CopyInString(sourceAddr, usermem.PageSize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ targetPath, err := copyInPath(t, targetAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ data := ""
+ if dataAddr != 0 {
+ // In Linux, a full page is always copied in regardless of null
+ // character placement, and the address is passed to each file system.
+ // Most file systems always treat this data as a string, though, and so
+ // do all of the ones we implement.
+ data, err = t.CopyInString(dataAddr, usermem.PageSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+
+ // Ignore magic value that was required before Linux 2.4.
+ if flags&linux.MS_MGC_MSK == linux.MS_MGC_VAL {
+ flags = flags &^ linux.MS_MGC_MSK
+ }
+
+ // Must have CAP_SYS_ADMIN in the current mount namespace's associated user
+ // namespace.
+ creds := t.Credentials()
+ if !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.MountNamespaceVFS2().Owner) {
+ return 0, nil, syserror.EPERM
+ }
+
+ const unsupportedOps = linux.MS_REMOUNT | linux.MS_BIND |
+ linux.MS_SHARED | linux.MS_PRIVATE | linux.MS_SLAVE |
+ linux.MS_UNBINDABLE | linux.MS_MOVE
+
+ // Silently allow MS_NOSUID, since we don't implement set-id bits
+ // anyway.
+ const unsupportedFlags = linux.MS_NODEV |
+ linux.MS_NODIRATIME | linux.MS_STRICTATIME
+
+ // Linux just allows passing any flags to mount(2) - it won't fail when
+ // unknown or unsupported flags are passed. Since we don't implement
+ // everything, we fail explicitly on flags that are unimplemented.
+ if flags&(unsupportedOps|unsupportedFlags) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var opts vfs.MountOptions
+ if flags&linux.MS_NOATIME == linux.MS_NOATIME {
+ opts.Flags.NoATime = true
+ }
+ if flags&linux.MS_NOEXEC == linux.MS_NOEXEC {
+ opts.Flags.NoExec = true
+ }
+ if flags&linux.MS_RDONLY == linux.MS_RDONLY {
+ opts.ReadOnly = true
+ }
+ opts.GetFilesystemOptions.Data = data
+
+ target, err := getTaskPathOperation(t, linux.AT_FDCWD, targetPath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer target.Release()
+
+ return 0, nil, t.Kernel().VFS().MountAt(t, creds, source, &target.pop, fsType, &opts)
+}
+
+// Umount2 implements Linux syscall umount2(2).
+func Umount2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := args[1].Int()
+
+ // Must have CAP_SYS_ADMIN in the mount namespace's associated user
+ // namespace.
+ //
+ // Currently, this is always the init task's user namespace.
+ creds := t.Credentials()
+ if !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.MountNamespaceVFS2().Owner) {
+ return 0, nil, syserror.EPERM
+ }
+
+ const unsupported = linux.MNT_FORCE | linux.MNT_EXPIRE
+ if flags&unsupported != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ opts := vfs.UmountOptions{
+ Flags: uint32(flags),
+ }
+
+ return 0, nil, t.Kernel().VFS().UmountAt(t, creds, &tpop.pop, &opts)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
index ef8358b8a..7b6e7571a 100644
--- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go
+++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
@@ -90,8 +90,8 @@ func Override() {
s.Table[138] = syscalls.Supported("fstatfs", Fstatfs)
s.Table[161] = syscalls.Supported("chroot", Chroot)
s.Table[162] = syscalls.Supported("sync", Sync)
- delete(s.Table, 165) // mount
- delete(s.Table, 166) // umount2
+ s.Table[165] = syscalls.Supported("mount", Mount)
+ s.Table[166] = syscalls.Supported("umount2", Umount2)
delete(s.Table, 187) // readahead
s.Table[188] = syscalls.Supported("setxattr", Setxattr)
s.Table[189] = syscalls.Supported("lsetxattr", Lsetxattr)
diff --git a/pkg/sentry/vfs/genericfstree/genericfstree.go b/pkg/sentry/vfs/genericfstree/genericfstree.go
index 286510195..8882fa84a 100644
--- a/pkg/sentry/vfs/genericfstree/genericfstree.go
+++ b/pkg/sentry/vfs/genericfstree/genericfstree.go
@@ -43,7 +43,7 @@ type Dentry struct {
// IsAncestorDentry returns true if d is an ancestor of d2; that is, d is
// either d2's parent or an ancestor of d2's parent.
func IsAncestorDentry(d, d2 *Dentry) bool {
- for {
+ for d2 != nil { // Stop at root, where d2.parent == nil.
if d2.parent == d {
return true
}
@@ -52,6 +52,7 @@ func IsAncestorDentry(d, d2 *Dentry) bool {
}
d2 = d2.parent
}
+ return false
}
// ParentOrSelf returns d.parent. If d.parent is nil, ParentOrSelf returns d.
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 3adb7c97d..32f901bd8 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -55,6 +55,10 @@ type Mount struct {
// ID is the immutable mount ID.
ID uint64
+ // Flags contains settings as specified for mount(2), e.g. MS_NOEXEC, except
+ // for MS_RDONLY which is tracked in "writers". Immutable.
+ Flags MountFlags
+
// key is protected by VirtualFilesystem.mountMu and
// VirtualFilesystem.mounts.seq, and may be nil. References are held on
// key.parent and key.point if they are not nil.
@@ -81,10 +85,6 @@ type Mount struct {
// umounted is true. umounted is protected by VirtualFilesystem.mountMu.
umounted bool
- // flags contains settings as specified for mount(2), e.g. MS_NOEXEC, except
- // for MS_RDONLY which is tracked in "writers".
- flags MountFlags
-
// The lower 63 bits of writers is the number of calls to
// Mount.CheckBeginWrite() that have not yet been paired with a call to
// Mount.EndWrite(). The MSB of writers is set if MS_RDONLY is in effect.
@@ -95,10 +95,10 @@ type Mount struct {
func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *MountNamespace, opts *MountOptions) *Mount {
mnt := &Mount{
ID: atomic.AddUint64(&vfs.lastMountID, 1),
+ Flags: opts.Flags,
vfs: vfs,
fs: fs,
root: root,
- flags: opts.Flags,
ns: mntns,
refs: 1,
}
@@ -113,13 +113,12 @@ func (mnt *Mount) Options() MountOptions {
mnt.vfs.mountMu.Lock()
defer mnt.vfs.mountMu.Unlock()
return MountOptions{
- Flags: mnt.flags,
+ Flags: mnt.Flags,
ReadOnly: mnt.readOnly(),
}
}
-// A MountNamespace is a collection of Mounts.
-//
+// A MountNamespace is a collection of Mounts.//
// MountNamespaces are reference-counted. Unless otherwise specified, all
// MountNamespace methods require that a reference is held.
//
@@ -127,6 +126,9 @@ func (mnt *Mount) Options() MountOptions {
//
// +stateify savable
type MountNamespace struct {
+ // Owner is the usernamespace that owns this mount namespace.
+ Owner *auth.UserNamespace
+
// root is the MountNamespace's root mount. root is immutable.
root *Mount
@@ -163,6 +165,7 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth
return nil, err
}
mntns := &MountNamespace{
+ Owner: creds.UserNamespace,
refs: 1,
mountpoints: make(map[*Dentry]uint32),
}
@@ -279,6 +282,9 @@ func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credenti
}
// MNT_FORCE is currently unimplemented except for the permission check.
+ // Force unmounting specifically requires CAP_SYS_ADMIN in the root user
+ // namespace, and not in the owner user namespace for the target mount. See
+ // fs/namespace.c:SYSCALL_DEFINE2(umount, ...)
if opts.Flags&linux.MNT_FORCE != 0 && creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, creds.UserNamespace.Root()) {
return syserror.EPERM
}
@@ -753,7 +759,10 @@ func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDi
if mnt.readOnly() {
opts = "ro"
}
- if mnt.flags.NoExec {
+ if mnt.Flags.NoATime {
+ opts = ",noatime"
+ }
+ if mnt.Flags.NoExec {
opts += ",noexec"
}
@@ -838,11 +847,12 @@ func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRoo
if mnt.readOnly() {
opts = "ro"
}
- if mnt.flags.NoExec {
+ if mnt.Flags.NoATime {
+ opts = ",noatime"
+ }
+ if mnt.Flags.NoExec {
opts += ",noexec"
}
- // TODO(gvisor.dev/issue/1193): Add "noatime" if MS_NOATIME is
- // set.
fmt.Fprintf(buf, "%s ", opts)
// (7) Optional fields: zero or more fields of the form "tag[:value]".
diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go
index 53d364c5c..f223aeda8 100644
--- a/pkg/sentry/vfs/options.go
+++ b/pkg/sentry/vfs/options.go
@@ -75,6 +75,10 @@ type MknodOptions struct {
type MountFlags struct {
// NoExec is equivalent to MS_NOEXEC.
NoExec bool
+
+ // NoATime is equivalent to MS_NOATIME and indicates that the
+ // filesystem should not update access time in-place.
+ NoATime bool
}
// MountOptions contains options to VirtualFilesystem.MountAt().
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 52643a7c5..9acca8bc7 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -405,7 +405,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
vfs.putResolvingPath(rp)
if opts.FileExec {
- if fd.Mount().flags.NoExec {
+ if fd.Mount().Flags.NoExec {
fd.DecRef()
return nil, syserror.EACCES
}
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
index e57d45f2a..a984f1712 100644
--- a/pkg/tcpip/adapters/gonet/BUILD
+++ b/pkg/tcpip/adapters/gonet/BUILD
@@ -22,7 +22,6 @@ go_test(
size = "small",
srcs = ["gonet_test.go"],
library = ":gonet",
- tags = ["flaky"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index d989dbe91..4e9b404c8 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -43,11 +43,11 @@ const HookUnset = -1
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
-func DefaultTables() IPTables {
+func DefaultTables() *IPTables {
// TODO(gvisor.dev/issue/170): We may be able to swap out some strings for
// iotas.
- return IPTables{
- Tables: map[string]Table{
+ return &IPTables{
+ tables: map[string]Table{
TablenameNat: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
@@ -106,7 +106,7 @@ func DefaultTables() IPTables {
UserChains: map[string]int{},
},
},
- Priorities: map[Hook][]string{
+ priorities: map[Hook][]string{
Input: []string{TablenameNat, TablenameFilter},
Prerouting: []string{TablenameMangle, TablenameNat},
Output: []string{TablenameMangle, TablenameNat, TablenameFilter},
@@ -158,6 +158,36 @@ func EmptyNatTable() Table {
}
}
+// GetTable returns table by name.
+func (it *IPTables) GetTable(name string) (Table, bool) {
+ it.mu.RLock()
+ defer it.mu.RUnlock()
+ t, ok := it.tables[name]
+ return t, ok
+}
+
+// ReplaceTable replaces or inserts table by name.
+func (it *IPTables) ReplaceTable(name string, table Table) {
+ it.mu.Lock()
+ defer it.mu.Unlock()
+ it.tables[name] = table
+}
+
+// ModifyTables acquires write-lock and calls fn with internal name-to-table
+// map. This function can be used to update multiple tables atomically.
+func (it *IPTables) ModifyTables(fn func(map[string]Table)) {
+ it.mu.Lock()
+ defer it.mu.Unlock()
+ fn(it.tables)
+}
+
+// GetPriorities returns slice of priorities associated with hook.
+func (it *IPTables) GetPriorities(hook Hook) []string {
+ it.mu.RLock()
+ defer it.mu.RUnlock()
+ return it.priorities[hook]
+}
+
// A chainVerdict is what a table decides should be done with a packet.
type chainVerdict int
@@ -184,8 +214,8 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
it.connections.HandlePacket(pkt, hook, gso, r)
// Go through each table containing the hook.
- for _, tablename := range it.Priorities[hook] {
- table := it.Tables[tablename]
+ for _, tablename := range it.GetPriorities(hook) {
+ table, _ := it.GetTable(tablename)
ruleIdx := table.BuiltinChains[hook]
switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
// If the table returns Accept, move on to the next table.
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index af72b9c46..4a6a5c6f1 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -16,6 +16,7 @@ package stack
import (
"strings"
+ "sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -78,13 +79,17 @@ const (
// IPTables holds all the tables for a netstack.
type IPTables struct {
- // Tables maps table names to tables. User tables have arbitrary names.
- Tables map[string]Table
+ // mu protects tables and priorities.
+ mu sync.RWMutex
- // Priorities maps each hook to a list of table names. The order of the
+ // tables maps table names to tables. User tables have arbitrary names. mu
+ // needs to be locked for accessing.
+ tables map[string]Table
+
+ // priorities maps each hook to a list of table names. The order of the
// list is the order in which each table should be visited for that
- // hook.
- Priorities map[Hook][]string
+ // hook. mu needs to be locked for accessing.
+ priorities map[Hook][]string
connections ConnTrackTable
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 8af06cb9a..294ce8775 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -424,12 +424,8 @@ type Stack struct {
// handleLocal allows non-loopback interfaces to loop packets.
handleLocal bool
- // tablesMu protects iptables.
- tablesMu sync.RWMutex
-
- // tables are the iptables packet filtering and manipulation rules. The are
- // protected by tablesMu.`
- tables IPTables
+ // tables are the iptables packet filtering and manipulation rules.
+ tables *IPTables
// resumableEndpoints is a list of endpoints that need to be resumed if the
// stack is being restored.
@@ -676,6 +672,7 @@ func New(opts Options) *Stack {
clock: clock,
stats: opts.Stats.FillIn(),
handleLocal: opts.HandleLocal,
+ tables: DefaultTables(),
icmpRateLimiter: NewICMPRateLimiter(),
seed: generateRandUint32(),
ndpConfigs: opts.NDPConfigs,
@@ -1741,18 +1738,8 @@ func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool,
}
// IPTables returns the stack's iptables.
-func (s *Stack) IPTables() IPTables {
- s.tablesMu.RLock()
- t := s.tables
- s.tablesMu.RUnlock()
- return t
-}
-
-// SetIPTables sets the stack's iptables.
-func (s *Stack) SetIPTables(ipt IPTables) {
- s.tablesMu.Lock()
- s.tables = ipt
- s.tablesMu.Unlock()
+func (s *Stack) IPTables() *IPTables {
+ return s.tables
}
// ICMPLimit returns the maximum number of ICMP messages that can be sent
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 29ff68df3..57e0a069b 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -140,11 +140,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (stack.IPTables, error) {
- return e.stack.IPTables(), nil
-}
-
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
@@ -511,6 +506,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
nicID := addr.NIC
localPort := uint16(0)
switch e.state {
+ case stateInitial:
case stateBound, stateConnected:
localPort = e.ID.LocalPort
if e.BindNICID == 0 {
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index bab2d63ae..baf08eda6 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -132,11 +132,6 @@ func (ep *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (ep *endpoint) ModerateRecvBuf(copied int) {}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (ep *endpoint) IPTables() (stack.IPTables, error) {
- return ep.stack.IPTables(), nil
-}
-
// Read implements tcpip.Endpoint.Read.
func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
ep.rcvMu.Lock()
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 25a17940d..21c34fac2 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -166,11 +166,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (stack.IPTables, error) {
- return e.stack.IPTables(), nil
-}
-
// Read implements tcpip.Endpoint.Read.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
if !e.associated {
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index f38eb6833..e26f01fae 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -86,10 +86,6 @@ go_test(
"tcp_test.go",
"tcp_timestamp_test.go",
],
- # FIXME(b/68809571)
- tags = [
- "flaky",
- ],
deps = [
":tcp",
"//pkg/sync",
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index d048ef90c..19f7bf449 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -63,7 +63,8 @@ const (
StateClosing
)
-// connected is the set of states where an endpoint is connected to a peer.
+// connected returns true when s is one of the states representing an
+// endpoint connected to a peer.
func (s EndpointState) connected() bool {
switch s {
case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
@@ -73,6 +74,40 @@ func (s EndpointState) connected() bool {
}
}
+// connecting returns true when s is one of the states representing a
+// connection in progress, but not yet fully established.
+func (s EndpointState) connecting() bool {
+ switch s {
+ case StateConnecting, StateSynSent, StateSynRecv:
+ return true
+ default:
+ return false
+ }
+}
+
+// handshake returns true when s is one of the states representing an endpoint
+// in the middle of a TCP handshake.
+func (s EndpointState) handshake() bool {
+ switch s {
+ case StateSynSent, StateSynRecv:
+ return true
+ default:
+ return false
+ }
+}
+
+// closed returns true when s is one of the states an endpoint transitions to
+// when closed or when it encounters an error. This is distinct from a newly
+// initialized endpoint that was never connected.
+func (s EndpointState) closed() bool {
+ switch s {
+ case StateClose, StateError:
+ return true
+ default:
+ return false
+ }
+}
+
// String implements fmt.Stringer.String.
func (s EndpointState) String() string {
switch s {
@@ -1172,11 +1207,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (stack.IPTables, error) {
- return e.stack.IPTables(), nil
-}
-
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.LockUser()
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index fc43c11e2..cbb779666 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -49,11 +49,10 @@ func (e *endpoint) beforeSave() {
e.mu.Lock()
defer e.mu.Unlock()
- switch e.EndpointState() {
- case StateInitial, StateBound:
- // TODO(b/138137272): this enumeration duplicates
- // EndpointState.connected. remove it.
- case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
+ epState := e.EndpointState()
+ switch {
+ case epState == StateInitial || epState == StateBound:
+ case epState.connected() || epState.handshake():
if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 {
if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 {
panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)})
@@ -69,15 +68,16 @@ func (e *endpoint) beforeSave() {
break
}
fallthrough
- case StateListen, StateConnecting:
+ case epState == StateListen || epState == StateConnecting:
e.drainSegmentLocked()
- if e.EndpointState() != StateClose && e.EndpointState() != StateError {
+ // Refresh epState, since drainSegmentLocked may have changed it.
+ epState = e.EndpointState()
+ if !epState.closed() {
if !e.workerRunning {
panic("endpoint has no worker running in listen, connecting, or connected state")
}
- break
}
- case StateError, StateClose:
+ case epState.closed():
for e.workerRunning {
e.mu.Unlock()
time.Sleep(100 * time.Millisecond)
@@ -148,23 +148,23 @@ var connectingLoading sync.WaitGroup
// Bound endpoint loading happens last.
// loadState is invoked by stateify.
-func (e *endpoint) loadState(state EndpointState) {
+func (e *endpoint) loadState(epState EndpointState) {
// This is to ensure that the loading wait groups include all applicable
// endpoints before any asynchronous calls to the Wait() methods.
// For restore purposes we treat TimeWait like a connected endpoint.
- if state.connected() || state == StateTimeWait {
+ if epState.connected() || epState == StateTimeWait {
connectedLoading.Add(1)
}
- switch state {
- case StateListen:
+ switch {
+ case epState == StateListen:
listenLoading.Add(1)
- case StateConnecting, StateSynSent, StateSynRecv:
+ case epState.connecting():
connectingLoading.Add(1)
}
// Directly update the state here rather than using e.setEndpointState
// as the endpoint is still being loaded and the stack reference is not
// yet initialized.
- atomic.StoreUint32((*uint32)(&e.state), uint32(state))
+ atomic.StoreUint32((*uint32)(&e.state), uint32(epState))
}
// afterLoad is invoked by stateify.
@@ -183,8 +183,8 @@ func (e *endpoint) afterLoad() {
func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
e.segmentQueue.setLimit(MaxUnprocessedSegments)
- state := e.origEndpointState
- switch state {
+ epState := e.origEndpointState
+ switch epState {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
var ss SendBufferSizeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
@@ -208,8 +208,8 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
}
- switch state {
- case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
+ switch {
+ case epState.connected():
bind()
if len(e.connectingAddress) == 0 {
e.connectingAddress = e.ID.RemoteAddress
@@ -232,13 +232,13 @@ func (e *endpoint) Resume(s *stack.Stack) {
closed := e.closed
e.mu.Unlock()
e.notifyProtocolGoroutine(notifyTickleWorker)
- if state == StateFinWait2 && closed {
+ if epState == StateFinWait2 && closed {
// If the endpoint has been closed then make sure we notify so
// that the FIN_WAIT2 timer is started after a restore.
e.notifyProtocolGoroutine(notifyClose)
}
connectedLoading.Done()
- case StateListen:
+ case epState == StateListen:
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
@@ -255,7 +255,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
listenLoading.Done()
tcpip.AsyncLoading.Done()
}()
- case StateConnecting, StateSynSent, StateSynRecv:
+ case epState.connecting():
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
@@ -267,7 +267,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
connectingLoading.Done()
tcpip.AsyncLoading.Done()
}()
- case StateBound:
+ case epState == StateBound:
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
@@ -276,7 +276,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
bind()
tcpip.AsyncLoading.Done()
}()
- case StateClose:
+ case epState == StateClose:
if e.isPortReserved {
tcpip.AsyncLoading.Add(1)
go func() {
@@ -291,12 +291,11 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.state = StateClose
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
- case StateError:
+ case epState == StateError:
e.state = StateError
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
-
}
// saveLastError is invoked by stateify.
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 3a19c4468..acacb42e4 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -618,6 +618,20 @@ func (s *sender) splitSeg(seg *segment, size int) {
nSeg.data.TrimFront(size)
nSeg.sequenceNumber.UpdateForward(seqnum.Size(size))
s.writeList.InsertAfter(seg, nSeg)
+
+ // The segment being split does not carry PUSH flag because it is
+ // followed by the newly split segment.
+ // RFC1122 section 4.2.2.2: MUST set the PSH bit in the last buffered
+ // segment (i.e., when there is no more queued data to be sent).
+ // Linux removes PSH flag only when the segment is being split over MSS
+ // and retains it when we are splitting the segment over lack of sender
+ // window space.
+ // ref: net/ipv4/tcp_output.c::tcp_write_xmit(), tcp_mss_split_point()
+ // ref: net/ipv4/tcp_output.c::tcp_write_wakeup(), tcp_snd_wnd_test()
+ if seg.data.Size() > s.maxPayloadSize {
+ seg.flags ^= header.TCPFlagPsh
+ }
+
seg.data.CapLength(size)
}
@@ -739,7 +753,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
if !s.isAssignedSequenceNumber(seg) {
// Merge segments if allowed.
if seg.data.Size() != 0 {
- available := int(seg.sequenceNumber.Size(end))
+ available := int(s.sndNxt.Size(end))
if available > limit {
available = limit
}
@@ -782,8 +796,11 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
// sent all at once.
return false
}
- if atomic.LoadUint32(&s.ep.cork) != 0 {
- // Hold back the segment until full.
+ // With TCP_CORK, hold back until minimum of the available
+ // send space and MSS.
+ // TODO(gvisor.dev/issue/2833): Drain the held segments after a
+ // timeout.
+ if seg.data.Size() < s.maxPayloadSize && atomic.LoadUint32(&s.ep.cork) != 0 {
return false
}
}
@@ -843,9 +860,17 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
if available == 0 {
return false
}
+
+ // The segment size limit is computed as a function of sender congestion
+ // window and MSS. When sender congestion window is > 1, this limit can
+ // be larger than MSS. Ensure that the currently available send space
+ // is not greater than minimum of this limit and MSS.
if available > limit {
available = limit
}
+ if available > s.maxPayloadSize {
+ available = s.maxPayloadSize
+ }
if seg.data.Size() > available {
s.splitSeg(seg, available)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 79faa7869..663af8fec 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -247,11 +247,6 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (stack.IPTables, error) {
- return e.stack.IPTables(), nil
-}
-
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD
deleted file mode 100644
index 2dcba84ae..000000000
--- a/pkg/tmutex/BUILD
+++ /dev/null
@@ -1,17 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "tmutex",
- srcs = ["tmutex.go"],
- visibility = ["//:sandbox"],
-)
-
-go_test(
- name = "tmutex_test",
- size = "medium",
- srcs = ["tmutex_test.go"],
- library = ":tmutex",
- deps = ["//pkg/sync"],
-)
diff --git a/pkg/tmutex/tmutex.go b/pkg/tmutex/tmutex.go
deleted file mode 100644
index c4685020d..000000000
--- a/pkg/tmutex/tmutex.go
+++ /dev/null
@@ -1,81 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package tmutex provides the implementation of a mutex that implements an
-// efficient TryLock function in addition to Lock and Unlock.
-package tmutex
-
-import (
- "sync/atomic"
-)
-
-// Mutex is a mutual exclusion primitive that implements TryLock in addition
-// to Lock and Unlock.
-type Mutex struct {
- v int32
- ch chan struct{}
-}
-
-// Init initializes the mutex.
-func (m *Mutex) Init() {
- m.v = 1
- m.ch = make(chan struct{}, 1)
-}
-
-// Lock acquires the mutex. If it is currently held by another goroutine, Lock
-// will wait until it has a chance to acquire it.
-func (m *Mutex) Lock() {
- // Uncontended case.
- if atomic.AddInt32(&m.v, -1) == 0 {
- return
- }
-
- for {
- // Try to acquire the mutex again, at the same time making sure
- // that m.v is negative, which indicates to the owner of the
- // lock that it is contended, which will force it to try to wake
- // someone up when it releases the mutex.
- if v := atomic.LoadInt32(&m.v); v >= 0 && atomic.SwapInt32(&m.v, -1) == 1 {
- return
- }
-
- // Wait for the mutex to be released before trying again.
- <-m.ch
- }
-}
-
-// TryLock attempts to acquire the mutex without blocking. If the mutex is
-// currently held by another goroutine, it fails to acquire it and returns
-// false.
-func (m *Mutex) TryLock() bool {
- v := atomic.LoadInt32(&m.v)
- if v <= 0 {
- return false
- }
- return atomic.CompareAndSwapInt32(&m.v, 1, 0)
-}
-
-// Unlock releases the mutex.
-func (m *Mutex) Unlock() {
- if atomic.SwapInt32(&m.v, 1) == 0 {
- // There were no pending waiters.
- return
- }
-
- // Wake some waiter up.
- select {
- case m.ch <- struct{}{}:
- default:
- }
-}
diff --git a/pkg/tmutex/tmutex_test.go b/pkg/tmutex/tmutex_test.go
deleted file mode 100644
index 05540696a..000000000
--- a/pkg/tmutex/tmutex_test.go
+++ /dev/null
@@ -1,258 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tmutex
-
-import (
- "fmt"
- "runtime"
- "sync/atomic"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/sync"
-)
-
-func TestBasicLock(t *testing.T) {
- var m Mutex
- m.Init()
-
- m.Lock()
-
- // Try blocking lock the mutex from a different goroutine. This must
- // not block because the mutex is held.
- ch := make(chan struct{}, 1)
- go func() {
- m.Lock()
- ch <- struct{}{}
- m.Unlock()
- ch <- struct{}{}
- }()
-
- select {
- case <-ch:
- t.Fatalf("Lock succeeded on locked mutex")
- case <-time.After(100 * time.Millisecond):
- }
-
- // Unlock the mutex and make sure that the goroutine waiting on Lock()
- // unblocks and succeeds.
- m.Unlock()
-
- select {
- case <-ch:
- case <-time.After(100 * time.Millisecond):
- t.Fatalf("Lock failed to acquire unlocked mutex")
- }
-
- // Make sure we can lock and unlock again.
- m.Lock()
- m.Unlock()
-}
-
-func TestTryLock(t *testing.T) {
- var m Mutex
- m.Init()
-
- // Try to lock. It should succeed.
- if !m.TryLock() {
- t.Fatalf("TryLock failed on unlocked mutex")
- }
-
- // Try to lock again, it should now fail.
- if m.TryLock() {
- t.Fatalf("TryLock succeeded on locked mutex")
- }
-
- // Try blocking lock the mutex from a different goroutine. This must
- // not block because the mutex is held.
- ch := make(chan struct{}, 1)
- go func() {
- m.Lock()
- ch <- struct{}{}
- m.Unlock()
- }()
-
- select {
- case <-ch:
- t.Fatalf("Lock succeeded on locked mutex")
- case <-time.After(100 * time.Millisecond):
- }
-
- // Unlock the mutex and make sure that the goroutine waiting on Lock()
- // unblocks and succeeds.
- m.Unlock()
-
- select {
- case <-ch:
- case <-time.After(100 * time.Millisecond):
- t.Fatalf("Lock failed to acquire unlocked mutex")
- }
-}
-
-func TestMutualExclusion(t *testing.T) {
- var m Mutex
- m.Init()
-
- // Test mutual exclusion by running "gr" goroutines concurrently, and
- // have each one increment a counter "iters" times within the critical
- // section established by the mutex.
- //
- // If at the end the counter is not gr * iters, then we know that
- // goroutines ran concurrently within the critical section.
- //
- // If one of the goroutines doesn't complete, it's likely a bug that
- // causes to it to wait forever.
- const gr = 1000
- const iters = 100000
- v := 0
- var wg sync.WaitGroup
- for i := 0; i < gr; i++ {
- wg.Add(1)
- go func() {
- for j := 0; j < iters; j++ {
- m.Lock()
- v++
- m.Unlock()
- }
- wg.Done()
- }()
- }
-
- wg.Wait()
-
- if v != gr*iters {
- t.Fatalf("Bad count: got %v, want %v", v, gr*iters)
- }
-}
-
-func TestMutualExclusionWithTryLock(t *testing.T) {
- var m Mutex
- m.Init()
-
- // Similar to the previous, with the addition of some goroutines that
- // only increment the count if TryLock succeeds.
- const gr = 1000
- const iters = 100000
- total := int64(gr * iters)
- var tryTotal int64
- v := int64(0)
- var wg sync.WaitGroup
- for i := 0; i < gr; i++ {
- wg.Add(2)
- go func() {
- for j := 0; j < iters; j++ {
- m.Lock()
- v++
- m.Unlock()
- }
- wg.Done()
- }()
- go func() {
- local := int64(0)
- for j := 0; j < iters; j++ {
- if m.TryLock() {
- v++
- m.Unlock()
- local++
- }
- }
- atomic.AddInt64(&tryTotal, local)
- wg.Done()
- }()
- }
-
- wg.Wait()
-
- t.Logf("tryTotal = %d", tryTotal)
- total += tryTotal
-
- if v != total {
- t.Fatalf("Bad count: got %v, want %v", v, total)
- }
-}
-
-// BenchmarkTmutex is equivalent to TestMutualExclusion, with the following
-// differences:
-//
-// - The number of goroutines is variable, with the maximum value depending on
-// GOMAXPROCS.
-//
-// - The number of iterations per benchmark is controlled by the benchmarking
-// framework.
-//
-// - Care is taken to ensure that all goroutines participating in the benchmark
-// have been created before the benchmark begins.
-func BenchmarkTmutex(b *testing.B) {
- for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 {
- b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
- var m Mutex
- m.Init()
-
- var ready sync.WaitGroup
- begin := make(chan struct{})
- var end sync.WaitGroup
- for i := 0; i < n; i++ {
- ready.Add(1)
- end.Add(1)
- go func() {
- ready.Done()
- <-begin
- for j := 0; j < b.N; j++ {
- m.Lock()
- m.Unlock()
- }
- end.Done()
- }()
- }
-
- ready.Wait()
- b.ResetTimer()
- close(begin)
- end.Wait()
- })
- }
-}
-
-// BenchmarkSyncMutex is equivalent to BenchmarkTmutex, but uses sync.Mutex as
-// a comparison point.
-func BenchmarkSyncMutex(b *testing.B) {
- for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 {
- b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
- var m sync.Mutex
-
- var ready sync.WaitGroup
- begin := make(chan struct{})
- var end sync.WaitGroup
- for i := 0; i < n; i++ {
- ready.Add(1)
- end.Add(1)
- go func() {
- ready.Done()
- <-begin
- for j := 0; j < b.N; j++ {
- m.Lock()
- m.Unlock()
- }
- end.Done()
- }()
- }
-
- ready.Wait()
- b.ResetTimer()
- close(begin)
- end.Wait()
- })
- }
-}