From 0b2135072d3a6b418f87f166b58dcf877f7c2fba Mon Sep 17 00:00:00 2001 From: Neel Natu Date: Thu, 20 Jun 2019 12:54:40 -0700 Subject: Implement madvise(MADV_DONTFORK) PiperOrigin-RevId: 254253777 --- pkg/sentry/mm/lifecycle.go | 37 ++++++++++++ pkg/sentry/mm/mm.go | 3 + pkg/sentry/mm/syscalls.go | 26 ++++++++ pkg/sentry/mm/vma.go | 1 + pkg/sentry/syscalls/linux/sys_mmap.go | 6 +- test/syscalls/linux/BUILD | 1 + test/syscalls/linux/madvise.cc | 109 ++++++++++++++++++++++++++++++++++ test/syscalls/linux/mremap.cc | 11 ---- test/util/memory_util.h | 11 ++++ 9 files changed, 193 insertions(+), 12 deletions(-) diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go index 06e4372ff..4e9ca1de6 100644 --- a/pkg/sentry/mm/lifecycle.go +++ b/pkg/sentry/mm/lifecycle.go @@ -86,10 +86,22 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) { } // Copy vmas. + dontforks := false dstvgap := mm2.vmas.FirstGap() for srcvseg := mm.vmas.FirstSegment(); srcvseg.Ok(); srcvseg = srcvseg.NextSegment() { vma := srcvseg.Value() // makes a copy of the vma vmaAR := srcvseg.Range() + + if vma.dontfork { + length := uint64(vmaAR.Length()) + mm2.usageAS -= length + if vma.isPrivateDataLocked() { + mm2.dataAS -= length + } + dontforks = true + continue + } + // Inform the Mappable, if any, of the new mapping. 
if vma.mappable != nil { if err := vma.mappable.AddMapping(ctx, mm2, vmaAR, vma.off, vma.canWriteMappableLocked()); err != nil { @@ -118,6 +130,10 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) { defer mm2.activeMu.Unlock() mm.activeMu.Lock() defer mm.activeMu.Unlock() + if dontforks { + defer mm.pmas.MergeRange(mm.applicationAddrRange()) + } + srcvseg := mm.vmas.FirstSegment() dstpgap := mm2.pmas.FirstGap() var unmapAR usermem.AddrRange for srcpseg := mm.pmas.FirstSegment(); srcpseg.Ok(); srcpseg = srcpseg.NextSegment() { @@ -125,6 +141,27 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) { if !pma.private { continue } + + if dontforks { + // Find the 'vma' that contains the starting address + // associated with the 'pma' (there must be one). + srcvseg = srcvseg.seekNextLowerBound(srcpseg.Start()) + if checkInvariants { + if !srcvseg.Ok() { + panic(fmt.Sprintf("no vma covers pma range %v", srcpseg.Range())) + } + if srcpseg.Start() < srcvseg.Start() { + panic(fmt.Sprintf("vma %v ran ahead of pma %v", srcvseg.Range(), srcpseg.Range())) + } + } + + srcpseg = mm.pmas.Isolate(srcpseg, srcvseg.Range()) + if srcvseg.ValuePtr().dontfork { + continue + } + pma = srcpseg.ValuePtr() + } + if !pma.needCOW { pma.needCOW = true if pma.effectivePerms.Write { diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index 2ec2ad99b..7bb96b159 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -274,6 +274,9 @@ type vma struct { // metag, none of which we currently support. growsDown bool `state:"manual"` + // dontfork is the MADV_DONTFORK setting for this vma configured by madvise(). + dontfork bool + mlockMode memmap.MLockMode // numaPolicy is the NUMA policy for this vma set by mbind(). 
diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go index 9aa39e31d..c2466c988 100644 --- a/pkg/sentry/mm/syscalls.go +++ b/pkg/sentry/mm/syscalls.go @@ -1026,6 +1026,32 @@ func (mm *MemoryManager) SetNumaPolicy(addr usermem.Addr, length uint64, policy } } +// SetDontFork implements the semantics of madvise MADV_DONTFORK. +func (mm *MemoryManager) SetDontFork(addr usermem.Addr, length uint64, dontfork bool) error { + ar, ok := addr.ToRange(length) + if !ok { + return syserror.EINVAL + } + + mm.mappingMu.Lock() + defer mm.mappingMu.Unlock() + defer func() { + mm.vmas.MergeRange(ar) + mm.vmas.MergeAdjacent(ar) + }() + + for vseg := mm.vmas.LowerBoundSegment(ar.Start); vseg.Ok() && vseg.Start() < ar.End; vseg = vseg.NextSegment() { + vseg = mm.vmas.Isolate(vseg, ar) + vma := vseg.ValuePtr() + vma.dontfork = dontfork + } + + if mm.vmas.SpanRange(ar) != ar.Length() { + return syserror.ENOMEM + } + return nil +} + // Decommit implements the semantics of Linux's madvise(MADV_DONTNEED). 
func (mm *MemoryManager) Decommit(addr usermem.Addr, length uint64) error { ar, ok := addr.ToRange(length) diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go index 9f846cdb8..074e2b141 100644 --- a/pkg/sentry/mm/vma.go +++ b/pkg/sentry/mm/vma.go @@ -439,6 +439,7 @@ func (vmaSetFunctions) Merge(ar1 usermem.AddrRange, vma1 vma, ar2 usermem.AddrRa vma1.mlockMode != vma2.mlockMode || vma1.numaPolicy != vma2.numaPolicy || vma1.numaNodemask != vma2.numaNodemask || + vma1.dontfork != vma2.dontfork || vma1.id != vma2.id || vma1.hint != vma2.hint { return vma{}, false diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go index 8a45dceeb..d831833bc 100644 --- a/pkg/sentry/syscalls/linux/sys_mmap.go +++ b/pkg/sentry/syscalls/linux/sys_mmap.go @@ -180,6 +180,10 @@ func Madvise(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca switch adv { case linux.MADV_DONTNEED: return 0, nil, t.MemoryManager().Decommit(addr, length) + case linux.MADV_DOFORK: + return 0, nil, t.MemoryManager().SetDontFork(addr, length, false) + case linux.MADV_DONTFORK: + return 0, nil, t.MemoryManager().SetDontFork(addr, length, true) case linux.MADV_HUGEPAGE, linux.MADV_NOHUGEPAGE: fallthrough case linux.MADV_MERGEABLE, linux.MADV_UNMERGEABLE: @@ -191,7 +195,7 @@ func Madvise(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca case linux.MADV_NORMAL, linux.MADV_RANDOM, linux.MADV_SEQUENTIAL, linux.MADV_WILLNEED: // Do nothing, we totally ignore the suggestions above. return 0, nil, nil - case linux.MADV_REMOVE, linux.MADV_DOFORK, linux.MADV_DONTFORK: + case linux.MADV_REMOVE: // These "suggestions" have application-visible side effects, so we // have to indicate that we don't support them. 
return 0, nil, syserror.ENOSYS diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 7e3ad08a9..0618fea58 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -985,6 +985,7 @@ cc_binary( "//test/util:file_descriptor", "//test/util:logging", "//test/util:memory_util", + "//test/util:multiprocess_util", "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", diff --git a/test/syscalls/linux/madvise.cc b/test/syscalls/linux/madvise.cc index f6ad4d18b..352fcc6c4 100644 --- a/test/syscalls/linux/madvise.cc +++ b/test/syscalls/linux/madvise.cc @@ -29,6 +29,7 @@ #include "test/util/file_descriptor.h" #include "test/util/logging.h" #include "test/util/memory_util.h" +#include "test/util/multiprocess_util.h" #include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -136,6 +137,114 @@ TEST(MadviseDontneedTest, IgnoresPermissions) { EXPECT_THAT(madvise(m.ptr(), m.len(), MADV_DONTNEED), SyscallSucceeds()); } +TEST(MadviseDontforkTest, AddressLength) { + auto m = + ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, MAP_PRIVATE)); + char *addr = static_cast<char *>(m.ptr()); + + // Address must be page aligned. + EXPECT_THAT(madvise(addr + 1, kPageSize, MADV_DONTFORK), + SyscallFailsWithErrno(EINVAL)); + + // Zero length madvise always succeeds. + EXPECT_THAT(madvise(addr, 0, MADV_DONTFORK), SyscallSucceeds()); + + // Length must not roll over after rounding up. + size_t badlen = std::numeric_limits<std::size_t>::max() - (kPageSize / 2); + EXPECT_THAT(madvise(0, badlen, MADV_DONTFORK), SyscallFailsWithErrno(EINVAL)); + + // Length need not be page aligned - it is implicitly rounded up. + EXPECT_THAT(madvise(addr, 1, MADV_DONTFORK), SyscallSucceeds()); + EXPECT_THAT(madvise(addr, kPageSize, MADV_DONTFORK), SyscallSucceeds()); +} + +TEST(MadviseDontforkTest, DontforkShared) { + // Mmap two shared file-backed pages and MADV_DONTFORK the second page. 
+ TempPath f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + /* parent = */ GetAbsoluteTestTmpdir(), + /* content = */ std::string(kPageSize * 2, 2), TempPath::kDefaultFileMode)); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDWR)); + + Mapping m = ASSERT_NO_ERRNO_AND_VALUE(Mmap( + nullptr, kPageSize * 2, PROT_READ | PROT_WRITE, MAP_SHARED, fd.get(), 0)); + + const Mapping ms1 = Mapping(reinterpret_cast<void *>(m.addr()), kPageSize); + const Mapping ms2 = + Mapping(reinterpret_cast<void *>(m.addr() + kPageSize), kPageSize); + m.release(); + + ASSERT_THAT(madvise(ms2.ptr(), kPageSize, MADV_DONTFORK), SyscallSucceeds()); + + const auto rest = [&] { + // First page is mapped in child and modifications are visible to parent + // via the shared mapping. + TEST_CHECK(IsMapped(ms1.addr())); + ExpectAllMappingBytes(ms1, 2); + memset(ms1.ptr(), 1, kPageSize); + ExpectAllMappingBytes(ms1, 1); + + // Second page must not be mapped in child. + TEST_CHECK(!IsMapped(ms2.addr())); + }; + + EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); + + ExpectAllMappingBytes(ms1, 1); // page contents modified by child. + ExpectAllMappingBytes(ms2, 2); // page contents unchanged. +} + +TEST(MadviseDontforkTest, DontforkAnonPrivate) { + // Mmap three anonymous pages and MADV_DONTFORK the middle page. + Mapping m = ASSERT_NO_ERRNO_AND_VALUE( + MmapAnon(kPageSize * 3, PROT_READ | PROT_WRITE, MAP_PRIVATE)); + const Mapping mp1 = Mapping(reinterpret_cast<void *>(m.addr()), kPageSize); + const Mapping mp2 = + Mapping(reinterpret_cast<void *>(m.addr() + kPageSize), kPageSize); + const Mapping mp3 = + Mapping(reinterpret_cast<void *>(m.addr() + 2 * kPageSize), kPageSize); + m.release(); + + ASSERT_THAT(madvise(mp2.ptr(), kPageSize, MADV_DONTFORK), SyscallSucceeds()); + + // Verify that all pages are zeroed and memset the first, second and third + // pages to 1, 2, and 3 respectively. 
+ ExpectAllMappingBytes(mp1, 0); + memset(mp1.ptr(), 1, kPageSize); + + ExpectAllMappingBytes(mp2, 0); + memset(mp2.ptr(), 2, kPageSize); + + ExpectAllMappingBytes(mp3, 0); + memset(mp3.ptr(), 3, kPageSize); + + const auto rest = [&] { + // Verify first page is mapped, verify its contents and then modify the + // page. The mapping is private so the modifications are not visible to + // the parent. + TEST_CHECK(IsMapped(mp1.addr())); + ExpectAllMappingBytes(mp1, 1); + memset(mp1.ptr(), 11, kPageSize); + ExpectAllMappingBytes(mp1, 11); + + // Verify second page is not mapped. + TEST_CHECK(!IsMapped(mp2.addr())); + + // Verify third page is mapped, verify its contents and then modify the + // page. The mapping is private so the modifications are not visible to + // the parent. + TEST_CHECK(IsMapped(mp3.addr())); + ExpectAllMappingBytes(mp3, 3); + memset(mp3.ptr(), 13, kPageSize); + ExpectAllMappingBytes(mp3, 13); + }; + EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); + + // The fork and COW by child should not affect the parent mappings. + ExpectAllMappingBytes(mp1, 1); + ExpectAllMappingBytes(mp2, 2); + ExpectAllMappingBytes(mp3, 3); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/mremap.cc b/test/syscalls/linux/mremap.cc index 7298d4ca8..64e435cb7 100644 --- a/test/syscalls/linux/mremap.cc +++ b/test/syscalls/linux/mremap.cc @@ -46,17 +46,6 @@ PosixErrorOr<void*> Mremap(void* old_address, size_t old_size, size_t new_size, return rv; } -// Returns true if the page containing addr is mapped. -bool IsMapped(uintptr_t addr) { - int const rv = msync(reinterpret_cast<void*>(addr & ~(kPageSize - 1)), - kPageSize, MS_ASYNC); - if (rv == 0) { - return true; - } - TEST_PCHECK_MSG(errno == ENOMEM, "msync failed with unexpected errno"); - return false; -} - // Fixture for mremap tests parameterized by mmap flags. 
using MremapParamTest = ::testing::TestWithParam<int>; diff --git a/test/util/memory_util.h b/test/util/memory_util.h index 8c77778ea..190c469b5 100644 --- a/test/util/memory_util.h +++ b/test/util/memory_util.h @@ -118,6 +118,17 @@ inline PosixErrorOr<Mapping> MmapAnon(size_t length, int prot, int flags) { return Mmap(nullptr, length, prot, flags | MAP_ANONYMOUS, -1, 0); } +// Returns true if the page containing addr is mapped. +inline bool IsMapped(uintptr_t addr) { + int const rv = msync(reinterpret_cast<void*>(addr & ~(kPageSize - 1)), + kPageSize, MS_ASYNC); + if (rv == 0) { + return true; + } + TEST_PCHECK_MSG(errno == ENOMEM, "msync failed with unexpected errno"); + return false; +} + } // namespace testing } // namespace gvisor -- cgit v1.2.3