author    Neel Natu <neelnatu@google.com>     2019-06-20 12:54:40 -0700
committer gVisor bot <gvisor-bot@google.com>  2019-06-20 12:56:00 -0700
commit    0b2135072d3a6b418f87f166b58dcf877f7c2fba (patch)
tree      40c316afa59e5786fac24b7e8be940ae645ffb16
parent    b46ec3704b60bebdd63a597c62f3f471ee0d9be9 (diff)
Implement madvise(MADV_DONTFORK)
PiperOrigin-RevId: 254253777
-rw-r--r--  pkg/sentry/mm/lifecycle.go              37
-rw-r--r--  pkg/sentry/mm/mm.go                      3
-rw-r--r--  pkg/sentry/mm/syscalls.go               26
-rw-r--r--  pkg/sentry/mm/vma.go                     1
-rw-r--r--  pkg/sentry/syscalls/linux/sys_mmap.go    6
-rw-r--r--  test/syscalls/linux/BUILD                1
-rw-r--r--  test/syscalls/linux/madvise.cc         109
-rw-r--r--  test/syscalls/linux/mremap.cc           11
-rw-r--r--  test/util/memory_util.h                 11
9 files changed, 193 insertions(+), 12 deletions(-)
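
For context: madvise(MADV_DONTFORK) marks an address range so that it is not
inherited by a child across fork(2), and MADV_DOFORK restores the default.
(Linux added the flag for ranges whose physical pages must not move under
copy-on-write, e.g. buffers registered with DMA-capable hardware.) Below is a
minimal standalone sketch of the user-visible behavior this change implements;
it is an illustration, not part of the patch, and it probes mappings with
msync(MS_ASYNC) the same way the IsMapped() helper added to
test/util/memory_util.h does:

    #include <sys/mman.h>
    #include <sys/wait.h>
    #include <unistd.h>

    int main() {
      const size_t kPage = sysconf(_SC_PAGESIZE);
      // Map two anonymous pages and mark only the second MADV_DONTFORK.
      char *p = static_cast<char *>(mmap(nullptr, 2 * kPage,
                                         PROT_READ | PROT_WRITE,
                                         MAP_PRIVATE | MAP_ANONYMOUS, -1, 0));
      if (p == MAP_FAILED || madvise(p + kPage, kPage, MADV_DONTFORK) != 0) {
        return 1;
      }
      pid_t pid = fork();
      if (pid == 0) {
        // Child: the first page is inherited, the second is not.
        // msync(MS_ASYNC) fails with ENOMEM when the range is unmapped.
        bool first = msync(p, kPage, MS_ASYNC) == 0;
        bool second = msync(p + kPage, kPage, MS_ASYNC) == 0;
        _exit(first && !second ? 0 : 1);
      }
      int status;
      waitpid(pid, &status, 0);
      // Parent: both pages remain mapped here regardless.
      return WIFEXITED(status) ? WEXITSTATUS(status) : 1;
    }
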
diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go
index 06e4372ff..4e9ca1de6 100644
--- a/pkg/sentry/mm/lifecycle.go
+++ b/pkg/sentry/mm/lifecycle.go
@@ -86,10 +86,22 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) {
}
// Copy vmas.
+ dontforks := false
dstvgap := mm2.vmas.FirstGap()
for srcvseg := mm.vmas.FirstSegment(); srcvseg.Ok(); srcvseg = srcvseg.NextSegment() {
vma := srcvseg.Value() // makes a copy of the vma
vmaAR := srcvseg.Range()
+
+ if vma.dontfork {
+ length := uint64(vmaAR.Length())
+ mm2.usageAS -= length
+ if vma.isPrivateDataLocked() {
+ mm2.dataAS -= length
+ }
+ dontforks = true
+ continue
+ }
+
// Inform the Mappable, if any, of the new mapping.
if vma.mappable != nil {
if err := vma.mappable.AddMapping(ctx, mm2, vmaAR, vma.off, vma.canWriteMappableLocked()); err != nil {
@@ -118,6 +130,10 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) {
defer mm2.activeMu.Unlock()
mm.activeMu.Lock()
defer mm.activeMu.Unlock()
+ if dontforks {
+ defer mm.pmas.MergeRange(mm.applicationAddrRange())
+ }
+ srcvseg := mm.vmas.FirstSegment()
dstpgap := mm2.pmas.FirstGap()
var unmapAR usermem.AddrRange
for srcpseg := mm.pmas.FirstSegment(); srcpseg.Ok(); srcpseg = srcpseg.NextSegment() {
@@ -125,6 +141,27 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) {
if !pma.private {
continue
}
+
+ if dontforks {
+ // Find the 'vma' that contains the starting address
+ // associated with the 'pma' (there must be one).
+ srcvseg = srcvseg.seekNextLowerBound(srcpseg.Start())
+ if checkInvariants {
+ if !srcvseg.Ok() {
+ panic(fmt.Sprintf("no vma covers pma range %v", srcpseg.Range()))
+ }
+ if srcpseg.Start() < srcvseg.Start() {
+ panic(fmt.Sprintf("vma %v ran ahead of pma %v", srcvseg.Range(), srcpseg.Range()))
+ }
+ }
+
+ srcpseg = mm.pmas.Isolate(srcpseg, srcvseg.Range())
+ if srcvseg.ValuePtr().dontfork {
+ continue
+ }
+ pma = srcpseg.ValuePtr()
+ }
+
if !pma.needCOW {
pma.needCOW = true
if pma.effectivePerms.Write {
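
The fork path above works in two phases. The vma loop skips dontfork ranges
outright and debits the child's usageAS/dataAS counters, which were cloned
from the parent; the pma loop then uses Isolate to split any private pma at
vma boundaries, so a pma straddling a dontfork boundary has only its inherited
part marked for copy-on-write. The deferred MergeRange undoes that transient
splitting in the parent once the fork completes. A simplified model of the vma
phase (illustrative only; plain structs stand in for the sentry's generated
segment sets, and ForkModel is a made-up name):

    #include <cstdint>
    #include <vector>

    struct Vma {
      uint64_t start, end;  // [start, end)
      bool dontfork;        // set via madvise(MADV_DONTFORK)
    };

    struct AddressSpace {
      std::vector<Vma> vmas;
      uint64_t usageAS;  // total mapped bytes
    };

    // The child starts from a clone of the parent's counters and then
    // skips dontfork vmas, debiting the bytes it does not inherit.
    AddressSpace ForkModel(const AddressSpace &parent) {
      AddressSpace child{{}, parent.usageAS};
      for (const Vma &v : parent.vmas) {
        if (v.dontfork) {
          child.usageAS -= v.end - v.start;
          continue;
        }
        child.vmas.push_back(v);
      }
      return child;
    }

The pma phase is not modeled here; the key point is that skipped ranges never
reach the COW-marking loop, and the parent's own state is left untouched apart
from the pma splits that the deferred MergeRange re-coalesces.
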
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index 2ec2ad99b..7bb96b159 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -274,6 +274,9 @@ type vma struct {
// metag, none of which we currently support.
growsDown bool `state:"manual"`
+ // dontfork is the MADV_DONTFORK setting for this vma configured by madvise().
+ dontfork bool
+
mlockMode memmap.MLockMode
// numaPolicy is the NUMA policy for this vma set by mbind().
diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go
index 9aa39e31d..c2466c988 100644
--- a/pkg/sentry/mm/syscalls.go
+++ b/pkg/sentry/mm/syscalls.go
@@ -1026,6 +1026,32 @@ func (mm *MemoryManager) SetNumaPolicy(addr usermem.Addr, length uint64, policy
}
}
+// SetDontFork implements the semantics of madvise MADV_DONTFORK.
+func (mm *MemoryManager) SetDontFork(addr usermem.Addr, length uint64, dontfork bool) error {
+ ar, ok := addr.ToRange(length)
+ if !ok {
+ return syserror.EINVAL
+ }
+
+ mm.mappingMu.Lock()
+ defer mm.mappingMu.Unlock()
+ defer func() {
+ mm.vmas.MergeRange(ar)
+ mm.vmas.MergeAdjacent(ar)
+ }()
+
+ for vseg := mm.vmas.LowerBoundSegment(ar.Start); vseg.Ok() && vseg.Start() < ar.End; vseg = vseg.NextSegment() {
+ vseg = mm.vmas.Isolate(vseg, ar)
+ vma := vseg.ValuePtr()
+ vma.dontfork = dontfork
+ }
+
+ if mm.vmas.SpanRange(ar) != ar.Length() {
+ return syserror.ENOMEM
+ }
+ return nil
+}
+
// Decommit implements the semantics of Linux's madvise(MADV_DONTNEED).
func (mm *MemoryManager) Decommit(addr usermem.Addr, length uint64) error {
ar, ok := addr.ToRange(length)
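
SetDontFork follows the usual attribute-update pattern for vmas: split
(Isolate) each overlapping vma to the requested range, flip the flag, and let
the deferred MergeRange/MergeAdjacent re-coalesce neighbors that agree again
(the Merge predicate in vma.go below now compares dontfork). Because the merge
is deferred, it also runs on the ENOMEM path, which is taken when the vmas do
not span the whole range, i.e. there is a hole, after the flag has already
been applied to the mapped parts, matching Linux. A toy split-then-set over a
sorted interval list (illustrative; Interval and the free function are made
up):

    #include <cstdint>
    #include <vector>

    struct Interval {
      uint64_t start, end;  // [start, end)
      bool dontfork = false;
    };

    // Apply `flag` over [s, e) on a sorted, non-overlapping interval list,
    // splitting intervals at the range boundaries. Returns false (ENOMEM in
    // the real code) if the range is not fully covered; the flag is still
    // applied to the covered parts, as in Linux.
    bool SetDontFork(std::vector<Interval> &v, uint64_t s, uint64_t e,
                     bool flag) {
      uint64_t covered = 0;
      for (size_t i = 0; i < v.size(); ++i) {
        if (v[i].end <= s || v[i].start >= e) continue;
        // Split off any part before the range (the "Isolate" step).
        if (v[i].start < s) {
          Interval left = v[i];
          left.end = s;
          v[i].start = s;
          v.insert(v.begin() + i, left);
          ++i;
        }
        // Split off any part after the range.
        if (v[i].end > e) {
          Interval right = v[i];
          right.start = e;
          v[i].end = e;
          v.insert(v.begin() + i + 1, right);
        }
        v[i].dontfork = flag;
        covered += v[i].end - v[i].start;
      }
      return covered == e - s;
    }
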
diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go
index 9f846cdb8..074e2b141 100644
--- a/pkg/sentry/mm/vma.go
+++ b/pkg/sentry/mm/vma.go
@@ -439,6 +439,7 @@ func (vmaSetFunctions) Merge(ar1 usermem.AddrRange, vma1 vma, ar2 usermem.AddrRa
vma1.mlockMode != vma2.mlockMode ||
vma1.numaPolicy != vma2.numaPolicy ||
vma1.numaNodemask != vma2.numaNodemask ||
+ vma1.dontfork != vma2.dontfork ||
vma1.id != vma2.id ||
vma1.hint != vma2.hint {
return vma{}, false
diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go
index 8a45dceeb..d831833bc 100644
--- a/pkg/sentry/syscalls/linux/sys_mmap.go
+++ b/pkg/sentry/syscalls/linux/sys_mmap.go
@@ -180,6 +180,10 @@ func Madvise(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
switch adv {
case linux.MADV_DONTNEED:
return 0, nil, t.MemoryManager().Decommit(addr, length)
+ case linux.MADV_DOFORK:
+ return 0, nil, t.MemoryManager().SetDontFork(addr, length, false)
+ case linux.MADV_DONTFORK:
+ return 0, nil, t.MemoryManager().SetDontFork(addr, length, true)
case linux.MADV_HUGEPAGE, linux.MADV_NOHUGEPAGE:
fallthrough
case linux.MADV_MERGEABLE, linux.MADV_UNMERGEABLE:
@@ -191,7 +195,7 @@ func Madvise(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
case linux.MADV_NORMAL, linux.MADV_RANDOM, linux.MADV_SEQUENTIAL, linux.MADV_WILLNEED:
// Do nothing, we totally ignore the suggestions above.
return 0, nil, nil
- case linux.MADV_REMOVE, linux.MADV_DOFORK, linux.MADV_DONTFORK:
+ case linux.MADV_REMOVE:
// These "suggestions" have application-visible side effects, so we
// have to indicate that we don't support them.
return 0, nil, syserror.ENOSYS
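
Before this change both MADV_DONTFORK and MADV_DOFORK fell into the ENOSYS
bucket, since their effects are application-visible and could not be silently
ignored; they now dispatch to SetDontFork with the flag set or cleared. A
hypothetical round-trip check in the spirit of the sketch above, confirming
that MADV_DOFORK restores inheritance:

    #include <sys/mman.h>
    #include <sys/wait.h>
    #include <unistd.h>

    int main() {
      const size_t kPage = sysconf(_SC_PAGESIZE);
      void *p = mmap(nullptr, kPage, PROT_READ | PROT_WRITE,
                     MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
      if (p == MAP_FAILED) return 1;
      // Mark the page MADV_DONTFORK, then undo it with MADV_DOFORK.
      if (madvise(p, kPage, MADV_DONTFORK) != 0) return 1;
      if (madvise(p, kPage, MADV_DOFORK) != 0) return 1;
      pid_t pid = fork();
      if (pid == 0) {
        // MADV_DOFORK restored the default, so the page is inherited.
        _exit(msync(p, kPage, MS_ASYNC) == 0 ? 0 : 1);
      }
      int status;
      waitpid(pid, &status, 0);
      return WIFEXITED(status) ? WEXITSTATUS(status) : 1;
    }
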
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 7e3ad08a9..0618fea58 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -985,6 +985,7 @@ cc_binary(
"//test/util:file_descriptor",
"//test/util:logging",
"//test/util:memory_util",
+ "//test/util:multiprocess_util",
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
diff --git a/test/syscalls/linux/madvise.cc b/test/syscalls/linux/madvise.cc
index f6ad4d18b..352fcc6c4 100644
--- a/test/syscalls/linux/madvise.cc
+++ b/test/syscalls/linux/madvise.cc
@@ -29,6 +29,7 @@
#include "test/util/file_descriptor.h"
#include "test/util/logging.h"
#include "test/util/memory_util.h"
+#include "test/util/multiprocess_util.h"
#include "test/util/posix_error.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
@@ -136,6 +137,114 @@ TEST(MadviseDontneedTest, IgnoresPermissions) {
EXPECT_THAT(madvise(m.ptr(), m.len(), MADV_DONTNEED), SyscallSucceeds());
}
+TEST(MadviseDontforkTest, AddressLength) {
+ auto m =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, MAP_PRIVATE));
+ char *addr = static_cast<char *>(m.ptr());
+
+ // Address must be page aligned.
+ EXPECT_THAT(madvise(addr + 1, kPageSize, MADV_DONTFORK),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Zero length madvise always succeeds.
+ EXPECT_THAT(madvise(addr, 0, MADV_DONTFORK), SyscallSucceeds());
+
+ // Length must not roll over after rounding up.
+ size_t badlen = std::numeric_limits<std::size_t>::max() - (kPageSize / 2);
+ EXPECT_THAT(madvise(0, badlen, MADV_DONTFORK), SyscallFailsWithErrno(EINVAL));
+
+ // Length need not be page aligned - it is implicitly rounded up.
+ EXPECT_THAT(madvise(addr, 1, MADV_DONTFORK), SyscallSucceeds());
+ EXPECT_THAT(madvise(addr, kPageSize, MADV_DONTFORK), SyscallSucceeds());
+}
+
+TEST(MadviseDontforkTest, DontforkShared) {
+ // Mmap two shared file-backed pages and MADV_DONTFORK the second page.
+ TempPath f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ /* parent = */ GetAbsoluteTestTmpdir(),
+ /* content = */ std::string(kPageSize * 2, 2), TempPath::kDefaultFileMode));
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDWR));
+
+ Mapping m = ASSERT_NO_ERRNO_AND_VALUE(Mmap(
+ nullptr, kPageSize * 2, PROT_READ | PROT_WRITE, MAP_SHARED, fd.get(), 0));
+
+ const Mapping ms1 = Mapping(reinterpret_cast<void *>(m.addr()), kPageSize);
+ const Mapping ms2 =
+ Mapping(reinterpret_cast<void *>(m.addr() + kPageSize), kPageSize);
+ m.release();
+
+ ASSERT_THAT(madvise(ms2.ptr(), kPageSize, MADV_DONTFORK), SyscallSucceeds());
+
+ const auto rest = [&] {
+ // First page is mapped in child and modifications are visible to parent
+ // via the shared mapping.
+ TEST_CHECK(IsMapped(ms1.addr()));
+ ExpectAllMappingBytes(ms1, 2);
+ memset(ms1.ptr(), 1, kPageSize);
+ ExpectAllMappingBytes(ms1, 1);
+
+ // Second page must not be mapped in child.
+ TEST_CHECK(!IsMapped(ms2.addr()));
+ };
+
+ EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0));
+
+ ExpectAllMappingBytes(ms1, 1); // page contents modified by child.
+ ExpectAllMappingBytes(ms2, 2); // page contents unchanged.
+}
+
+TEST(MadviseDontforkTest, DontforkAnonPrivate) {
+ // Mmap three anonymous pages and MADV_DONTFORK the middle page.
+ Mapping m = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize * 3, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ const Mapping mp1 = Mapping(reinterpret_cast<void *>(m.addr()), kPageSize);
+ const Mapping mp2 =
+ Mapping(reinterpret_cast<void *>(m.addr() + kPageSize), kPageSize);
+ const Mapping mp3 =
+ Mapping(reinterpret_cast<void *>(m.addr() + 2 * kPageSize), kPageSize);
+ m.release();
+
+ ASSERT_THAT(madvise(mp2.ptr(), kPageSize, MADV_DONTFORK), SyscallSucceeds());
+
+ // Verify that all pages are zeroed and memset the first, second and third
+ // pages to 1, 2, and 3 respectively.
+ ExpectAllMappingBytes(mp1, 0);
+ memset(mp1.ptr(), 1, kPageSize);
+
+ ExpectAllMappingBytes(mp2, 0);
+ memset(mp2.ptr(), 2, kPageSize);
+
+ ExpectAllMappingBytes(mp3, 0);
+ memset(mp3.ptr(), 3, kPageSize);
+
+ const auto rest = [&] {
+ // Verify first page is mapped, verify its contents and then modify the
+ // page. The mapping is private so the modifications are not visible to
+ // the parent.
+ TEST_CHECK(IsMapped(mp1.addr()));
+ ExpectAllMappingBytes(mp1, 1);
+ memset(mp1.ptr(), 11, kPageSize);
+ ExpectAllMappingBytes(mp1, 11);
+
+ // Verify second page is not mapped.
+ TEST_CHECK(!IsMapped(mp2.addr()));
+
+ // Verify third page is mapped, verify its contents and then modify the
+ // page. The mapping is private so the modifications are not visible to
+ // the parent.
+ TEST_CHECK(IsMapped(mp3.addr()));
+ ExpectAllMappingBytes(mp3, 3);
+ memset(mp3.ptr(), 13, kPageSize);
+ ExpectAllMappingBytes(mp3, 13);
+ };
+ EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0));
+
+ // The fork and COW by child should not affect the parent mappings.
+ ExpectAllMappingBytes(mp1, 1);
+ ExpectAllMappingBytes(mp2, 2);
+ ExpectAllMappingBytes(mp3, 3);
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/mremap.cc b/test/syscalls/linux/mremap.cc
index 7298d4ca8..64e435cb7 100644
--- a/test/syscalls/linux/mremap.cc
+++ b/test/syscalls/linux/mremap.cc
@@ -46,17 +46,6 @@ PosixErrorOr<void*> Mremap(void* old_address, size_t old_size, size_t new_size,
return rv;
}
-// Returns true if the page containing addr is mapped.
-bool IsMapped(uintptr_t addr) {
- int const rv = msync(reinterpret_cast<void*>(addr & ~(kPageSize - 1)),
- kPageSize, MS_ASYNC);
- if (rv == 0) {
- return true;
- }
- TEST_PCHECK_MSG(errno == ENOMEM, "msync failed with unexpected errno");
- return false;
-}
-
// Fixture for mremap tests parameterized by mmap flags.
using MremapParamTest = ::testing::TestWithParam<int>;
diff --git a/test/util/memory_util.h b/test/util/memory_util.h
index 8c77778ea..190c469b5 100644
--- a/test/util/memory_util.h
+++ b/test/util/memory_util.h
@@ -118,6 +118,17 @@ inline PosixErrorOr<Mapping> MmapAnon(size_t length, int prot, int flags) {
return Mmap(nullptr, length, prot, flags | MAP_ANONYMOUS, -1, 0);
}
+// Returns true if the page containing addr is mapped.
+inline bool IsMapped(uintptr_t addr) {
+ int const rv = msync(reinterpret_cast<void*>(addr & ~(kPageSize - 1)),
+ kPageSize, MS_ASYNC);
+ if (rv == 0) {
+ return true;
+ }
+ TEST_PCHECK_MSG(errno == ENOMEM, "msync failed with unexpected errno");
+ return false;
+}
+
} // namespace testing
} // namespace gvisor
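
IsMapped() moves from mremap.cc into the shared memory_util.h header so the
new madvise tests can use it as well. It relies on the documented behavior
that msync() fails with ENOMEM when part of the given range is not mapped,
making msync(MS_ASYNC) an inexpensive way to ask whether a page is mapped
without touching its contents. A small usage sketch outside the test harness
(illustrative only):

    #include <sys/mman.h>
    #include <unistd.h>
    #include <cassert>
    #include <cerrno>

    int main() {
      const size_t kPage = sysconf(_SC_PAGESIZE);
      void *p = mmap(nullptr, kPage, PROT_READ,
                     MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
      assert(p != MAP_FAILED);

      // While mapped, the probe succeeds.
      assert(msync(p, kPage, MS_ASYNC) == 0);

      // After munmap the same probe fails with ENOMEM. (Racy in a real
      // program if another mapping could land at the same address.)
      munmap(p, kPage);
      assert(msync(p, kPage, MS_ASYNC) != 0 && errno == ENOMEM);
      return 0;
    }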