From 56a9a13976ad800a8a34b194d35f0169d0a0bb23 Mon Sep 17 00:00:00 2001
From: Andrei Vagin <avagin@google.com>
Date: Tue, 23 Mar 2021 18:44:38 -0700
Subject: Move the code that manages floating-point state to a separate package

This change is inspired by Adin's cl/355256448.

PiperOrigin-RevId: 364695931
---
 pkg/sentry/platform/kvm/BUILD                   |  2 ++
 pkg/sentry/platform/kvm/bluepill_amd64.go       |  6 ++---
 pkg/sentry/platform/kvm/bluepill_arm64.go       | 10 ++++----
 pkg/sentry/platform/kvm/context.go              |  2 +-
 pkg/sentry/platform/kvm/kvm_amd64_test.go       |  2 +-
 pkg/sentry/platform/kvm/kvm_test.go             | 33 +++++++++++++------------
 pkg/sentry/platform/kvm/machine_amd64.go        |  5 ++--
 pkg/sentry/platform/kvm/machine_arm64.go        |  3 ++-
 pkg/sentry/platform/kvm/machine_arm64_unsafe.go |  3 ++-
 9 files changed, 36 insertions(+), 30 deletions(-)

(limited to 'pkg/sentry/platform/kvm')

diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 4f9e781af..03a76eb9b 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -50,6 +50,7 @@ go_library(
         "//pkg/safecopy",
         "//pkg/seccomp",
         "//pkg/sentry/arch",
+        "//pkg/sentry/arch/fpu",
         "//pkg/sentry/memmap",
         "//pkg/sentry/platform",
         "//pkg/sentry/platform/interrupt",
@@ -78,6 +79,7 @@ go_test(
         "//pkg/ring0",
         "//pkg/ring0/pagetables",
         "//pkg/sentry/arch",
+        "//pkg/sentry/arch/fpu",
         "//pkg/sentry/platform",
         "//pkg/sentry/platform/kvm/testutil",
         "//pkg/sentry/time",
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go
index f4b9a5321..d761bbdee 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.go
@@ -73,7 +73,7 @@ func (c *vCPU) KernelSyscall() {
 	// We only trigger a bluepill entry in the bluepill function, and can
 	// therefore be guaranteed that there is no floating point state to be
 	// loaded on resuming from halt. We only worry about saving on exit.
-	ring0.SaveFloatingPoint((*byte)(c.floatingPointState)) // escapes: no.
+	ring0.SaveFloatingPoint(c.floatingPointState.BytePointer()) // escapes: no.
 	ring0.Halt()
 	ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no, reload host segment.
 }
@@ -92,7 +92,7 @@ func (c *vCPU) KernelException(vector ring0.Vector) {
 		regs.Rip = 0
 	}
 	// See above.
-	ring0.SaveFloatingPoint((*byte)(c.floatingPointState)) // escapes: no.
+	ring0.SaveFloatingPoint(c.floatingPointState.BytePointer()) // escapes: no.
 	ring0.Halt()
 	ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no; reload host segment.
 }
@@ -124,5 +124,5 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) {
 	// Set the context pointer to the saved floating point state. This is
 	// where the guest data has been serialized, the kernel will restore
 	// from this new pointer value.
-	context.Fpstate = uint64(uintptrValue((*byte)(c.floatingPointState)))
+	context.Fpstate = uint64(uintptrValue(c.floatingPointState.BytePointer()))
 }
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go
index e26b7da8d..578852c3f 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.go
@@ -92,7 +92,7 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) {
 
 	lazyVfp := c.GetLazyVFP()
 	if lazyVfp != 0 {
-		fpsimd := fpsimdPtr((*byte)(c.floatingPointState))
+		fpsimd := fpsimdPtr(c.floatingPointState.BytePointer())
 		context.Fpsimd64.Fpsr = fpsimd.Fpsr
 		context.Fpsimd64.Fpcr = fpsimd.Fpcr
 		context.Fpsimd64.Vregs = fpsimd.Vregs
@@ -112,12 +112,12 @@ func (c *vCPU) KernelSyscall() {
 
 	fpDisableTrap := ring0.CPACREL1()
 	if fpDisableTrap != 0 {
-		fpsimd := fpsimdPtr((*byte)(c.floatingPointState))
+		fpsimd := fpsimdPtr(c.floatingPointState.BytePointer())
 		fpcr := ring0.GetFPCR()
 		fpsr := ring0.GetFPSR()
 		fpsimd.Fpcr = uint32(fpcr)
 		fpsimd.Fpsr = uint32(fpsr)
-		ring0.SaveVRegs((*byte)(c.floatingPointState))
+		ring0.SaveVRegs(c.floatingPointState.BytePointer())
 	}
 
 	ring0.Halt()
@@ -136,12 +136,12 @@ func (c *vCPU) KernelException(vector ring0.Vector) {
 
 	fpDisableTrap := ring0.CPACREL1()
 	if fpDisableTrap != 0 {
-		fpsimd := fpsimdPtr((*byte)(c.floatingPointState))
+		fpsimd := fpsimdPtr(c.floatingPointState.BytePointer())
 		fpcr := ring0.GetFPCR()
 		fpsr := ring0.GetFPSR()
 		fpsimd.Fpcr = uint32(fpcr)
 		fpsimd.Fpsr = uint32(fpsr)
-		ring0.SaveVRegs((*byte)(c.floatingPointState))
+		ring0.SaveVRegs(c.floatingPointState.BytePointer())
 	}
 
 	ring0.Halt()
diff --git a/pkg/sentry/platform/kvm/context.go b/pkg/sentry/platform/kvm/context.go
index aeae01dbd..706fa53dc 100644
--- a/pkg/sentry/platform/kvm/context.go
+++ b/pkg/sentry/platform/kvm/context.go
@@ -65,7 +65,7 @@ func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac a
 	// Prepare switch options.
 	switchOpts := ring0.SwitchOpts{
 		Registers:          &ac.StateData().Regs,
-		FloatingPointState: (*byte)(ac.FloatingPointData()),
+		FloatingPointState: ac.FloatingPointData(),
 		PageTables:         localAS.pageTables,
 		Flush:              localAS.Touch(cpu),
 		FullRestore:        ac.FullRestore(),
diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go
index 76fc594a0..e44e995a0 100644
--- a/pkg/sentry/platform/kvm/kvm_amd64_test.go
+++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go
@@ -33,7 +33,7 @@ func TestSegments(t *testing.T) {
 			var si arch.SignalInfo
 			if _, err := c.SwitchToUser(ring0.SwitchOpts{
 				Registers:          regs,
-				FloatingPointState: dummyFPState,
+				FloatingPointState: &dummyFPState,
 				PageTables:         pt,
 				FullRestore:        true,
 			}, &si); err == platform.ErrContextInterrupt {
diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go
index 6243b9a04..5bce16dde 100644
--- a/pkg/sentry/platform/kvm/kvm_test.go
+++ b/pkg/sentry/platform/kvm/kvm_test.go
@@ -25,13 +25,14 @@ import (
 	"gvisor.dev/gvisor/pkg/ring0"
 	"gvisor.dev/gvisor/pkg/ring0/pagetables"
 	"gvisor.dev/gvisor/pkg/sentry/arch"
+	"gvisor.dev/gvisor/pkg/sentry/arch/fpu"
 	"gvisor.dev/gvisor/pkg/sentry/platform"
 	"gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil"
 	ktime "gvisor.dev/gvisor/pkg/sentry/time"
 	"gvisor.dev/gvisor/pkg/usermem"
 )
 
-var dummyFPState = (*byte)(arch.NewFloatingPointData())
+var dummyFPState = fpu.NewState()
 
 type testHarness interface {
 	Errorf(format string, args ...interface{})
@@ -159,7 +160,7 @@ func TestApplicationSyscall(t *testing.T) {
 		var si arch.SignalInfo
 		if _, err := c.SwitchToUser(ring0.SwitchOpts{
 			Registers:          regs,
-			FloatingPointState: dummyFPState,
+			FloatingPointState: &dummyFPState,
 			PageTables:         pt,
 			FullRestore:        true,
 		}, &si); err == platform.ErrContextInterrupt {
@@ -173,7 +174,7 @@ func TestApplicationSyscall(t *testing.T) {
 		var si arch.SignalInfo
 		if _, err := c.SwitchToUser(ring0.SwitchOpts{
 			Registers:          regs,
-			FloatingPointState: dummyFPState,
+			FloatingPointState: &dummyFPState,
 			PageTables:         pt,
 		}, &si); err == platform.ErrContextInterrupt {
 			return true // Retry.
@@ -190,7 +191,7 @@ func TestApplicationFault(t *testing.T) {
 		var si arch.SignalInfo
 		if _, err := c.SwitchToUser(ring0.SwitchOpts{
 			Registers:          regs,
-			FloatingPointState: dummyFPState,
+			FloatingPointState: &dummyFPState,
 			PageTables:         pt,
 			FullRestore:        true,
 		}, &si); err == platform.ErrContextInterrupt {
@@ -205,7 +206,7 @@ func TestApplicationFault(t *testing.T) {
 		var si arch.SignalInfo
 		if _, err := c.SwitchToUser(ring0.SwitchOpts{
 			Registers:          regs,
-			FloatingPointState: dummyFPState,
+			FloatingPointState: &dummyFPState,
 			PageTables:         pt,
 		}, &si); err == platform.ErrContextInterrupt {
 			return true // Retry.
@@ -223,7 +224,7 @@ func TestRegistersSyscall(t *testing.T) {
 			var si arch.SignalInfo
 			if _, err := c.SwitchToUser(ring0.SwitchOpts{
 				Registers:          regs,
-				FloatingPointState: dummyFPState,
+				FloatingPointState: &dummyFPState,
 				PageTables:         pt,
 			}, &si); err == platform.ErrContextInterrupt {
 				continue // Retry.
@@ -246,7 +247,7 @@ func TestRegistersFault(t *testing.T) {
 			var si arch.SignalInfo
 			if _, err := c.SwitchToUser(ring0.SwitchOpts{
 				Registers:          regs,
-				FloatingPointState: dummyFPState,
+				FloatingPointState: &dummyFPState,
 				PageTables:         pt,
 				FullRestore:        true,
 			}, &si); err == platform.ErrContextInterrupt {
@@ -272,7 +273,7 @@ func TestBounce(t *testing.T) {
 		var si arch.SignalInfo
 		if _, err := c.SwitchToUser(ring0.SwitchOpts{
 			Registers:          regs,
-			FloatingPointState: dummyFPState,
+			FloatingPointState: &dummyFPState,
 			PageTables:         pt,
 		}, &si); err != platform.ErrContextInterrupt {
 			t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextInterrupt)
@@ -287,7 +288,7 @@ func TestBounce(t *testing.T) {
 		var si arch.SignalInfo
 		if _, err := c.SwitchToUser(ring0.SwitchOpts{
 			Registers:          regs,
-			FloatingPointState: dummyFPState,
+			FloatingPointState: &dummyFPState,
 			PageTables:         pt,
 			FullRestore:        true,
 		}, &si); err != platform.ErrContextInterrupt {
@@ -319,7 +320,7 @@ func TestBounceStress(t *testing.T) {
 			var si arch.SignalInfo
 			if _, err := c.SwitchToUser(ring0.SwitchOpts{
 				Registers:          regs,
-				FloatingPointState: dummyFPState,
+				FloatingPointState: &dummyFPState,
 				PageTables:         pt,
 			}, &si); err != platform.ErrContextInterrupt {
 				t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextInterrupt)
@@ -340,7 +341,7 @@ func TestInvalidate(t *testing.T) {
 			var si arch.SignalInfo
 			if _, err := c.SwitchToUser(ring0.SwitchOpts{
 				Registers:          regs,
-				FloatingPointState: dummyFPState,
+				FloatingPointState: &dummyFPState,
 				PageTables:         pt,
 			}, &si); err == platform.ErrContextInterrupt {
 				continue // Retry.
@@ -355,7 +356,7 @@ func TestInvalidate(t *testing.T) {
 			var si arch.SignalInfo
 			if _, err := c.SwitchToUser(ring0.SwitchOpts{
 				Registers:          regs,
-				FloatingPointState: dummyFPState,
+				FloatingPointState: &dummyFPState,
 				PageTables:         pt,
 				Flush:              true,
 			}, &si); err == platform.ErrContextInterrupt {
@@ -379,7 +380,7 @@ func TestEmptyAddressSpace(t *testing.T) {
 		var si arch.SignalInfo
 		if _, err := c.SwitchToUser(ring0.SwitchOpts{
 			Registers:          regs,
-			FloatingPointState: dummyFPState,
+			FloatingPointState: &dummyFPState,
 			PageTables:         pt,
 		}, &si); err == platform.ErrContextInterrupt {
 			return true // Retry.
@@ -393,7 +394,7 @@ func TestEmptyAddressSpace(t *testing.T) {
 		var si arch.SignalInfo
 		if _, err := c.SwitchToUser(ring0.SwitchOpts{
 			Registers:          regs,
-			FloatingPointState: dummyFPState,
+			FloatingPointState: &dummyFPState,
 			PageTables:         pt,
 			FullRestore:        true,
 		}, &si); err == platform.ErrContextInterrupt {
@@ -469,7 +470,7 @@ func BenchmarkApplicationSyscall(b *testing.B) {
 		var si arch.SignalInfo
 		if _, err := c.SwitchToUser(ring0.SwitchOpts{
 			Registers:          regs,
-			FloatingPointState: dummyFPState,
+			FloatingPointState: &dummyFPState,
 			PageTables:         pt,
 		}, &si); err == platform.ErrContextInterrupt {
 			a++
@@ -506,7 +507,7 @@ func BenchmarkWorldSwitchToUserRoundtrip(b *testing.B) {
 		var si arch.SignalInfo
 		if _, err := c.SwitchToUser(ring0.SwitchOpts{
 			Registers:          regs,
-			FloatingPointState: dummyFPState,
+			FloatingPointState: &dummyFPState,
 			PageTables:         pt,
 		}, &si); err == platform.ErrContextInterrupt {
 			a++
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index 6e583baa3..8f2c82e73 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -27,6 +27,7 @@ import (
 	"gvisor.dev/gvisor/pkg/ring0"
 	"gvisor.dev/gvisor/pkg/ring0/pagetables"
 	"gvisor.dev/gvisor/pkg/sentry/arch"
+	"gvisor.dev/gvisor/pkg/sentry/arch/fpu"
 	"gvisor.dev/gvisor/pkg/sentry/platform"
 	ktime "gvisor.dev/gvisor/pkg/sentry/time"
 	"gvisor.dev/gvisor/pkg/usermem"
@@ -70,7 +71,7 @@ type vCPUArchState struct {
 
 	// floatingPointState is the floating point state buffer used in guest
 	// to host transitions. See usage in bluepill_amd64.go.
-	floatingPointState *arch.FloatingPointData
+	floatingPointState fpu.State
 }
 
 const (
@@ -151,7 +152,7 @@ func (c *vCPU) initArchState() error {
 	// This will be saved prior to leaving the guest, and we restore from
 	// this always. We cannot use the pointer in the context alone because
 	// we don't know how large the area there is in reality.
-	c.floatingPointState = arch.NewFloatingPointData()
+	c.floatingPointState = fpu.NewState()
 
 	// Set the time offset to the host native time.
 	return c.setSystemTime()
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 7d7857067..2edc9d1b2 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -20,6 +20,7 @@ import (
 	"gvisor.dev/gvisor/pkg/ring0"
 	"gvisor.dev/gvisor/pkg/ring0/pagetables"
 	"gvisor.dev/gvisor/pkg/sentry/arch"
+	"gvisor.dev/gvisor/pkg/sentry/arch/fpu"
 	"gvisor.dev/gvisor/pkg/sentry/platform"
 	"gvisor.dev/gvisor/pkg/usermem"
 )
@@ -32,7 +33,7 @@ type vCPUArchState struct {
 
 	// floatingPointState is the floating point state buffer used in guest
 	// to host transitions. See usage in bluepill_arm64.go.
-	floatingPointState *arch.FloatingPointData
+	floatingPointState fpu.State
 }
 
 const (
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 059aa43d0..e7d5f3193 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -26,6 +26,7 @@ import (
 	"gvisor.dev/gvisor/pkg/ring0"
 	"gvisor.dev/gvisor/pkg/ring0/pagetables"
 	"gvisor.dev/gvisor/pkg/sentry/arch"
+	"gvisor.dev/gvisor/pkg/sentry/arch/fpu"
 	"gvisor.dev/gvisor/pkg/sentry/platform"
 	"gvisor.dev/gvisor/pkg/usermem"
 )
@@ -150,7 +151,7 @@ func (c *vCPU) initArchState() error {
 		c.PCIDs = pagetables.NewPCIDs(fixedKernelPCID+1, poolPCIDs)
 	}
 
-	c.floatingPointState = arch.NewFloatingPointData()
+	c.floatingPointState = fpu.NewState()
 
 	return c.setSystemTime()
 }
-- 
cgit v1.2.3