// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package fsutil

import (
	"fmt"
	"io"
	"math"

	"gvisor.dev/gvisor/pkg/context"
	"gvisor.dev/gvisor/pkg/safemem"
	"gvisor.dev/gvisor/pkg/sentry/memmap"
	"gvisor.dev/gvisor/pkg/sentry/pgalloc"
	"gvisor.dev/gvisor/pkg/sentry/platform"
	"gvisor.dev/gvisor/pkg/sentry/usage"
	"gvisor.dev/gvisor/pkg/usermem"
)

// FileRangeSet maps offsets into a memmap.Mappable to offsets into a
// platform.File. It is used to implement Mappables that store data in
// sparsely-allocated memory.
//
// type FileRangeSet <generated by go_generics>

// FileRangeSetFunctions implements segment.Functions for FileRangeSet. Keys
// are memmap.Mappable offsets and values are the platform.File offset at
// which each segment's data begins.
type FileRangeSetFunctions struct{}

// MinKey implements segment.Functions.MinKey. The key space begins at offset
// zero.
func (FileRangeSetFunctions) MinKey() uint64 {
	return uint64(0)
}

// MaxKey implements segment.Functions.MaxKey. The key space extends to the
// largest representable uint64.
func (FileRangeSetFunctions) MaxKey() uint64 {
	return uint64(math.MaxUint64)
}

// ClearValue implements segment.Functions.ClearValue. Values are plain
// uint64 offsets, so there is nothing to release or reset.
func (FileRangeSetFunctions) ClearValue(_ *uint64) {}

// Merge implements segment.Functions.Merge. Two adjacent segments may merge
// only when their file ranges are also contiguous, i.e. the first segment's
// data ends in the file exactly where the second segment's data begins; the
// merged segment then starts at the first segment's file offset.
func (FileRangeSetFunctions) Merge(mr1 memmap.MappableRange, frstart1 uint64, _ memmap.MappableRange, frstart2 uint64) (uint64, bool) {
	firstEnd := frstart1 + mr1.Length()
	if firstEnd == frstart2 {
		return frstart1, true
	}
	return 0, false
}

// Split implements segment.Functions.Split. The left half keeps the original
// file offset; the right half's file offset is advanced by the distance from
// the start of mr to the split point.
func (FileRangeSetFunctions) Split(mr memmap.MappableRange, frstart uint64, split uint64) (uint64, uint64) {
	delta := split - mr.Start
	return frstart, frstart + delta
}

// FileRange returns the platform.FileRange backing the entire Mappable range
// covered by seg.
func (seg FileRangeIterator) FileRange() platform.FileRange {
	whole := seg.Range()
	return seg.FileRangeOf(whole)
}

// FileRangeOf returns the FileRange mapped by mr.
//
// The platform.File offset stored as seg's value corresponds to seg.Start();
// the result is that offset shifted by mr's distance from seg.Start(), and it
// covers exactly mr.Length() bytes.
//
// Preconditions: seg.Range().IsSupersetOf(mr). mr.Length() != 0.
func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) platform.FileRange {
	frstart := seg.Value() + (mr.Start - seg.Start())
	// Keyed fields: platform.FileRange is declared in another package, and
	// unkeyed literals of such types are flagged by go vet (composites).
	return platform.FileRange{Start: frstart, End: frstart + mr.Length()}
}

// Fill attempts to ensure that all memmap.Mappable offsets in required are
// mapped to a platform.File offset, by allocating from mf with the given
// memory usage kind and invoking readAt to store data into memory. (If readAt
// returns a successful partial read, Fill will call it repeatedly until all
// bytes have been read.) EOF is handled consistently with the requirements of
// mmap(2): bytes after EOF on the same page are zeroed; pages after EOF are
// invalid.
//
// Fill may read offsets outside of required, but will never read offsets
// outside of optional. It returns a non-nil error if any error occurs, even
// if the error only affects offsets in optional, but not in required.
//
// Whatever data is successfully read before an error is still inserted into
// frs, so a failed Fill may leave frs partially populated.
//
// Preconditions: required.Length() > 0. optional.IsSupersetOf(required).
// required and optional must be page-aligned.
func (frs *FileRangeSet) Fill(ctx context.Context, required, optional memmap.MappableRange, mf *pgalloc.MemoryFile, kind usage.MemoryKind, readAt func(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error)) error {
	gap := frs.LowerBoundGap(required.Start)
	for gap.Ok() && gap.Start() < required.End {
		if gap.Range().Length() == 0 {
			// Adjacent segments are separated by empty gaps; nothing to fill.
			gap = gap.NextGap()
			continue
		}
		gr := gap.Range().Intersect(optional)

		// Read data into the gap.
		fr, err := mf.AllocateAndFill(gr.Length(), kind, safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
			var done uint64
			for !dsts.IsEmpty() {
				n, err := readAt(ctx, dsts, gr.Start+done)
				done += n
				dsts = dsts.DropFirst64(n)
				if err != nil {
					if err == io.EOF {
						// MemoryFile.AllocateAndFill truncates down to a page
						// boundary, but FileRangeSet.Fill is supposed to
						// zero-fill to the end of the page in this case. The
						// tail is counted as read without writing to it.
						// NOTE(review): this relies on freshly allocated
						// memory already being zero — confirm against
						// MemoryFile.Allocate.
						donepgaddr, ok := usermem.Addr(done).RoundUp()
						if donepg := uint64(donepgaddr); ok && donepg != done {
							// DropFirst64 returns a new BlockSeq rather than
							// modifying dsts, so the result must be assigned
							// back; otherwise the IsEmpty check below sees the
							// stale sequence and EOF is reported even when the
							// zero-filled tail completes gr.
							dsts = dsts.DropFirst64(donepg - done)
							done = donepg
							if dsts.IsEmpty() {
								return done, nil
							}
						}
					}
					return done, err
				}
			}
			return done, nil
		}))

		// Store anything we managed to read into the cache.
		if done := fr.Length(); done != 0 {
			gr.End = gr.Start + done
			gap = frs.Insert(gap, gr, fr.Start).NextGap()
		}

		if err != nil {
			return err
		}
	}
	return nil
}

// Drop removes every segment covering memmap.Mappable offsets in mr. Segments
// straddling the edges of mr are first split so that only the portion inside
// mr is removed, and the platform.FileRange backing each removed portion is
// released with mf.DecRef.
//
// Preconditions: mr must be page-aligned.
func (frs *FileRangeSet) Drop(mr memmap.MappableRange, mf *pgalloc.MemoryFile) {
	for seg := frs.LowerBoundSegment(mr.Start); seg.Ok() && seg.Start() < mr.End; {
		isolated := frs.Isolate(seg, mr)
		mf.DecRef(isolated.FileRange())
		seg = frs.Remove(isolated).NextSegment()
	}
}

// DropAll removes every segment in frs. Before the set is cleared, the
// platform.FileRange backing each segment is released with mf.DecRef.
func (frs *FileRangeSet) DropAll(mf *pgalloc.MemoryFile) {
	seg := frs.FirstSegment()
	for seg.Ok() {
		mf.DecRef(seg.FileRange())
		seg = seg.NextSegment()
	}
	frs.RemoveAll()
}

// Truncate updates frs to reflect Mappable truncation to length end.
//
// Every cached page starting at or after the first page boundary >= end is
// removed from frs, and its backing memory is released with mf.DecRef. If end
// is not page-aligned and the page containing it is cached, the bytes of that
// page from end onward are zeroed, so stale data is not visible past the new
// EOF.
//
// Truncate panics if that partial page cannot be mapped or zeroed.
func (frs *FileRangeSet) Truncate(end uint64, mf *pgalloc.MemoryFile) {
	if pgendaddr, ok := usermem.Addr(end).RoundUp(); ok {
		pgend := uint64(pgendaddr)

		// Release all pages at or beyond the rounded-up EOF.
		frs.SplitAt(pgend)
		for seg := frs.LowerBoundSegment(pgend); seg.Ok(); {
			mf.DecRef(seg.FileRange())
			seg = frs.Remove(seg).NextSegment()
		}

		// A page-aligned EOF leaves no partial page to clean up.
		if pgend == end {
			return
		}
	}

	// end now lies strictly inside a page. Nothing to do unless that page
	// is cached.
	seg := frs.FindSegment(end)
	if !seg.Ok() {
		return
	}

	// Zero from end to the end of the segment containing it.
	fr := seg.FileRange()
	fr.Start += end - seg.Start()
	ims, err := mf.MapInternal(fr, usermem.Write)
	if err != nil {
		// There is no good recourse here: cached memory cannot be kept
		// consistent with the new end of file, and the caller may already
		// have updated the file size on its backing file system. Rather
		// than continue blindly in this extremely rare case, abandon ship.
		panic(fmt.Sprintf("Failed to map %v: %v", fr, err))
	}
	if _, err := safemem.ZeroSeq(ims); err != nil {
		panic(fmt.Sprintf("Zeroing %v failed: %v", fr, err))
	}
}