From 70d7c52bd7583393d39177a7935cca57372d67f1 Mon Sep 17 00:00:00 2001
From: Nicolas Lacasse <nlacasse@google.com>
Date: Thu, 16 Jan 2020 13:58:25 -0800
Subject: Implement tmpfs.SetStat with a size argument.

This is similar to 'Truncate' in vfs1.

Updates https://github.com/google/gvisor/issues/1197

PiperOrigin-RevId: 290139140
---
 pkg/sentry/fsimpl/tmpfs/regular_file.go      |  35 ++++++++
 pkg/sentry/fsimpl/tmpfs/regular_file_test.go | 121 +++++++++++++++++++++++++++
 pkg/sentry/fsimpl/tmpfs/tmpfs.go             |  54 ++++++++++--
 3 files changed, 205 insertions(+), 5 deletions(-)

(limited to 'pkg/sentry/fsimpl/tmpfs')

diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index f200e767d..5fa70cc6d 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -63,6 +63,41 @@ func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMod
 	return &file.inode
 }
 
+// truncate grows or shrinks the file to the given size. It returns true if the
+// file size was updated.
+func (rf *regularFile) truncate(size uint64) (bool, error) {
+	rf.mu.Lock()
+	defer rf.mu.Unlock()
+
+	if size == rf.size {
+		// Nothing to do.
+		return false, nil
+	}
+
+	if size > rf.size {
+		// Growing the file.
+		if rf.seals&linux.F_SEAL_GROW != 0 {
+			// Seal does not allow growth.
+			return false, syserror.EPERM
+		}
+		rf.size = size
+		return true, nil
+	}
+
+	// Shrinking the file
+	if rf.seals&linux.F_SEAL_SHRINK != 0 {
+		// Seal does not allow shrink.
+		return false, syserror.EPERM
+	}
+
+	// TODO(gvisor.dev/issues/1197): Invalidate mappings once we have
+	// mappings.
+
+	rf.data.Truncate(size, rf.memFile)
+	rf.size = size
+	return true, nil
+}
+
 type regularFileFD struct {
 	fileDescription
 
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
index 7b0a962f0..034a29fdb 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
@@ -313,3 +313,124 @@ func TestPRead(t *testing.T) {
 		}
 	}
 }
+
+func TestTruncate(t *testing.T) {
+	ctx := contexttest.Context(t)
+	fd, cleanup, err := newFileFD(ctx, 0644)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer cleanup()
+
+	// Fill the file with some data.
+	data := bytes.Repeat([]byte("gVisor is awsome"), 100)
+	written, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+	if err != nil {
+		t.Fatalf("fd.Write failed: %v", err)
+	}
+
+	// Size should be same as written.
+	sizeStatOpts := vfs.StatOptions{Mask: linux.STATX_SIZE}
+	stat, err := fd.Stat(ctx, sizeStatOpts)
+	if err != nil {
+		t.Fatalf("fd.Stat failed: %v", err)
+	}
+	if got, want := int64(stat.Size), written; got != want {
+		t.Errorf("fd.Stat got size %d, want %d", got, want)
+	}
+
+	// Truncate down.
+	newSize := uint64(10)
+	if err := fd.SetStat(ctx, vfs.SetStatOptions{
+		Stat: linux.Statx{
+			Mask: linux.STATX_SIZE,
+			Size: newSize,
+		},
+	}); err != nil {
+		t.Errorf("fd.Truncate failed: %v", err)
+	}
+	// Size should be updated.
+	statAfterTruncateDown, err := fd.Stat(ctx, sizeStatOpts)
+	if err != nil {
+		t.Fatalf("fd.Stat failed: %v", err)
+	}
+	if got, want := statAfterTruncateDown.Size, newSize; got != want {
+		t.Errorf("fd.Stat got size %d, want %d", got, want)
+	}
+	// We should only read newSize worth of data.
+	buf := make([]byte, 1000)
+	if n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0, vfs.ReadOptions{}); err != nil && err != io.EOF {
+		t.Fatalf("fd.PRead failed: %v", err)
+	} else if uint64(n) != newSize {
+		t.Errorf("fd.PRead got size %d, want %d", n, newSize)
+	}
+	// Mtime and Ctime should be bumped.
+	if got := statAfterTruncateDown.Mtime.ToNsec(); got <= stat.Mtime.ToNsec() {
+		t.Errorf("fd.Stat got Mtime %v, want > %v", got, stat.Mtime)
+	}
+	if got := statAfterTruncateDown.Ctime.ToNsec(); got <= stat.Ctime.ToNsec() {
+		t.Errorf("fd.Stat got Ctime %v, want > %v", got, stat.Ctime)
+	}
+
+	// Truncate up.
+	newSize = 100
+	if err := fd.SetStat(ctx, vfs.SetStatOptions{
+		Stat: linux.Statx{
+			Mask: linux.STATX_SIZE,
+			Size: newSize,
+		},
+	}); err != nil {
+		t.Errorf("fd.Truncate failed: %v", err)
+	}
+	// Size should be updated.
+	statAfterTruncateUp, err := fd.Stat(ctx, sizeStatOpts)
+	if err != nil {
+		t.Fatalf("fd.Stat failed: %v", err)
+	}
+	if got, want := statAfterTruncateUp.Size, newSize; got != want {
+		t.Errorf("fd.Stat got size %d, want %d", got, want)
+	}
+	// We should read newSize worth of data.
+	buf = make([]byte, 1000)
+	if n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0, vfs.ReadOptions{}); err != nil && err != io.EOF {
+		t.Fatalf("fd.PRead failed: %v", err)
+	} else if uint64(n) != newSize {
+		t.Errorf("fd.PRead got size %d, want %d", n, newSize)
+	}
+	// Bytes should be null after 10, since we previously truncated to 10.
+	for i := uint64(10); i < newSize; i++ {
+		if buf[i] != 0 {
+			t.Errorf("fd.PRead got byte %d=%x, want 0", i, buf[i])
+			break
+		}
+	}
+	// Mtime and Ctime should be bumped.
+	if got := statAfterTruncateUp.Mtime.ToNsec(); got <= statAfterTruncateDown.Mtime.ToNsec() {
+		t.Errorf("fd.Stat got Mtime %v, want > %v", got, statAfterTruncateDown.Mtime)
+	}
+	if got := statAfterTruncateUp.Ctime.ToNsec(); got <= statAfterTruncateDown.Ctime.ToNsec() {
+		t.Errorf("fd.Stat got Ctime %v, want > %v", got, stat.Ctime)
+	}
+
+	// Truncate to the current size.
+	newSize = statAfterTruncateUp.Size
+	if err := fd.SetStat(ctx, vfs.SetStatOptions{
+		Stat: linux.Statx{
+			Mask: linux.STATX_SIZE,
+			Size: newSize,
+		},
+	}); err != nil {
+		t.Errorf("fd.Truncate failed: %v", err)
+	}
+	statAfterTruncateNoop, err := fd.Stat(ctx, sizeStatOpts)
+	if err != nil {
+		t.Fatalf("fd.Stat failed: %v", err)
+	}
+	// Mtime and Ctime should not be bumped, since operation is a noop.
+	if got := statAfterTruncateNoop.Mtime.ToNsec(); got != statAfterTruncateUp.Mtime.ToNsec() {
+		t.Errorf("fd.Stat got Mtime %v, want %v", got, statAfterTruncateUp.Mtime)
+	}
+	if got := statAfterTruncateNoop.Ctime.ToNsec(); got != statAfterTruncateUp.Ctime.ToNsec() {
+		t.Errorf("fd.Stat got Ctime %v, want %v", got, statAfterTruncateUp.Ctime)
+	}
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index d6960ee47..1d4889c89 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -35,6 +35,7 @@ import (
 	"gvisor.dev/gvisor/pkg/sentry/pgalloc"
 	"gvisor.dev/gvisor/pkg/sentry/vfs"
 	"gvisor.dev/gvisor/pkg/sync"
+	"gvisor.dev/gvisor/pkg/syserror"
 )
 
 // FilesystemType implements vfs.FilesystemType.
@@ -121,6 +122,9 @@ func (d *dentry) DecRef() {
 
 // inode represents a filesystem object.
 type inode struct {
+	// clock is a realtime clock used to set timestamps in file operations.
+	clock time.Clock
+
 	// refs is a reference count. refs is accessed using atomic memory
 	// operations.
 	//
@@ -151,13 +155,14 @@ type inode struct {
 const maxLinks = math.MaxUint32
 
 func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode linux.FileMode) {
-	now := fs.clock.Now().Nanoseconds()
+	i.clock = fs.clock
 	i.refs = 1
 	i.mode = uint32(mode)
 	i.uid = uint32(creds.EffectiveKUID)
 	i.gid = uint32(creds.EffectiveKGID)
 	i.ino = atomic.AddUint64(&fs.nextInoMinusOne, 1)
 	// Tmpfs creation sets atime, ctime, and mtime to current time.
+	now := i.clock.Now().Nanoseconds()
 	i.atime = now
 	i.ctime = now
 	i.mtime = now
@@ -270,30 +275,69 @@ func (i *inode) statTo(stat *linux.Statx) {
 }
 
 func (i *inode) setStat(stat linux.Statx) error {
-	// TODO(gvisor.dev/issues/1197): Handle stat.Size by growing/shrinking
-	// the file.
 	if stat.Mask == 0 {
 		return nil
 	}
 	i.mu.Lock()
+	var (
+		needsMtimeBump bool
+		needsCtimeBump bool
+	)
 	mask := stat.Mask
 	if mask&linux.STATX_MODE != 0 {
 		atomic.StoreUint32(&i.mode, uint32(stat.Mode))
+		needsCtimeBump = true
 	}
 	if mask&linux.STATX_UID != 0 {
 		atomic.StoreUint32(&i.uid, stat.UID)
+		needsCtimeBump = true
 	}
 	if mask&linux.STATX_GID != 0 {
 		atomic.StoreUint32(&i.gid, stat.GID)
+		needsCtimeBump = true
+	}
+	if mask&linux.STATX_SIZE != 0 {
+		switch impl := i.impl.(type) {
+		case *regularFile:
+			updated, err := impl.truncate(stat.Size)
+			if err != nil {
+				return err
+			}
+			if updated {
+				needsMtimeBump = true
+				needsCtimeBump = true
+			}
+		case *directory:
+			return syserror.EISDIR
+		case *symlink:
+			return syserror.EINVAL
+		case *namedPipe:
+			// Nothing.
+		default:
+			panic(fmt.Sprintf("unknown inode type: %T", i.impl))
+		}
 	}
 	if mask&linux.STATX_ATIME != 0 {
 		atomic.StoreInt64(&i.atime, stat.Atime.ToNsecCapped())
+		needsCtimeBump = true
+	}
+	if mask&linux.STATX_MTIME != 0 {
+		atomic.StoreInt64(&i.mtime, stat.Mtime.ToNsecCapped())
+		needsCtimeBump = true
+		// Ignore the mtime bump, since we just set it ourselves.
+		needsMtimeBump = false
 	}
 	if mask&linux.STATX_CTIME != 0 {
 		atomic.StoreInt64(&i.ctime, stat.Ctime.ToNsecCapped())
+		// Ignore the ctime bump, since we just set it ourselves.
+		needsCtimeBump = false
 	}
-	if mask&linux.STATX_MTIME != 0 {
-		atomic.StoreInt64(&i.mtime, stat.Mtime.ToNsecCapped())
+	now := i.clock.Now().Nanoseconds()
+	if needsMtimeBump {
+		atomic.StoreInt64(&i.mtime, now)
+	}
+	if needsCtimeBump {
+		atomic.StoreInt64(&i.ctime, now)
 	}
 	i.mu.Unlock()
 	return nil
-- 
cgit v1.2.3