summaryrefslogtreecommitdiffhomepage
path: root/pkg/safecopy
diff options
context:
space:
mode:
authorAdin Scannell <ascannell@google.com>2020-02-04 12:39:58 -0800
committergVisor bot <gvisor-bot@google.com>2020-02-04 12:47:06 -0800
commitf5072caaf85b9f067d737a874804c04e2b9039b8 (patch)
treef9c73cecd76a5ec7e0e006fa85def238b4ac1b48 /pkg/safecopy
parentdcffddf0cae026411e7e678744a1e39dc2b513cf (diff)
Fix safecopy test.
This is failing in Go1.14 due to new checkptr constraints. This version should avoid hitting this constraints by doing "dangerous" pointer math less dangerously? PiperOrigin-RevId: 293205764
Diffstat (limited to 'pkg/safecopy')
-rw-r--r--pkg/safecopy/safecopy_test.go88
-rw-r--r--pkg/safecopy/safecopy_unsafe.go98
2 files changed, 112 insertions, 74 deletions
diff --git a/pkg/safecopy/safecopy_test.go b/pkg/safecopy/safecopy_test.go
index 5818f7f9b..7f7f69d61 100644
--- a/pkg/safecopy/safecopy_test.go
+++ b/pkg/safecopy/safecopy_test.go
@@ -138,10 +138,14 @@ func TestSwapUint32Success(t *testing.T) {
func TestSwapUint32AlignmentError(t *testing.T) {
// Test that SwapUint32 returns an AlignmentError when passed an unaligned
// address.
- data := new(struct{ val uint64 })
- addr := uintptr(unsafe.Pointer(&data.val)) + 1
- want := AlignmentError{Addr: addr, Alignment: 4}
- if _, err := SwapUint32(unsafe.Pointer(addr), 1); err != want {
+ data := make([]byte, 8) // 2 * sizeof(uint32).
+ alignedIndex := uintptr(0)
+ if offset := uintptr(unsafe.Pointer(&data[0])) % 4; offset != 0 {
+ alignedIndex = 4 - offset
+ }
+ ptr := unsafe.Pointer(&data[alignedIndex+1])
+ want := AlignmentError{Addr: uintptr(ptr), Alignment: 4}
+ if _, err := SwapUint32(ptr, 1); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
}
}
@@ -171,10 +175,14 @@ func TestSwapUint64Success(t *testing.T) {
func TestSwapUint64AlignmentError(t *testing.T) {
// Test that SwapUint64 returns an AlignmentError when passed an unaligned
// address.
- data := new(struct{ val1, val2 uint64 })
- addr := uintptr(unsafe.Pointer(&data.val1)) + 1
- want := AlignmentError{Addr: addr, Alignment: 8}
- if _, err := SwapUint64(unsafe.Pointer(addr), 1); err != want {
+ data := make([]byte, 16) // 2 * sizeof(uint64).
+ alignedIndex := uintptr(0)
+ if offset := uintptr(unsafe.Pointer(&data[0])) % 8; offset != 0 {
+ alignedIndex = 8 - offset
+ }
+ ptr := unsafe.Pointer(&data[alignedIndex+1])
+ want := AlignmentError{Addr: uintptr(ptr), Alignment: 8}
+ if _, err := SwapUint64(ptr, 1); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
}
}
@@ -201,10 +209,14 @@ func TestCompareAndSwapUint32Success(t *testing.T) {
func TestCompareAndSwapUint32AlignmentError(t *testing.T) {
// Test that CompareAndSwapUint32 returns an AlignmentError when passed an
// unaligned address.
- data := new(struct{ val uint64 })
- addr := uintptr(unsafe.Pointer(&data.val)) + 1
- want := AlignmentError{Addr: addr, Alignment: 4}
- if _, err := CompareAndSwapUint32(unsafe.Pointer(addr), 0, 1); err != want {
+ data := make([]byte, 8) // 2 * sizeof(uint32).
+ alignedIndex := uintptr(0)
+ if offset := uintptr(unsafe.Pointer(&data[0])) % 4; offset != 0 {
+ alignedIndex = 4 - offset
+ }
+ ptr := unsafe.Pointer(&data[alignedIndex+1])
+ want := AlignmentError{Addr: uintptr(ptr), Alignment: 4}
+ if _, err := CompareAndSwapUint32(ptr, 0, 1); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
}
}
@@ -252,8 +264,8 @@ func TestCopyInSegvError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
dst := randBuf(pageSize)
n, err := CopyIn(dst, src)
if n != bytesBeforeFault {
@@ -276,8 +288,8 @@ func TestCopyInBusError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) {
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
dst := randBuf(pageSize)
n, err := CopyIn(dst, src)
if n != bytesBeforeFault {
@@ -300,8 +312,8 @@ func TestCopyOutSegvError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
src := randBuf(pageSize)
n, err := CopyOut(dst, src)
if n != bytesBeforeFault {
@@ -324,8 +336,8 @@ func TestCopyOutBusError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
src := randBuf(pageSize)
n, err := CopyOut(dst, src)
if n != bytesBeforeFault {
@@ -348,8 +360,8 @@ func TestCopySourceSegvError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
dst := randBuf(pageSize)
n, err := Copy(unsafe.Pointer(&dst[0]), src, pageSize)
if n != uintptr(bytesBeforeFault) {
@@ -372,8 +384,8 @@ func TestCopySourceBusError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) {
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
dst := randBuf(pageSize)
n, err := Copy(unsafe.Pointer(&dst[0]), src, pageSize)
if n != uintptr(bytesBeforeFault) {
@@ -396,8 +408,8 @@ func TestCopyDestinationSegvError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
src := randBuf(pageSize)
n, err := Copy(dst, unsafe.Pointer(&src[0]), pageSize)
if n != uintptr(bytesBeforeFault) {
@@ -420,8 +432,8 @@ func TestCopyDestinationBusError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) {
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
src := randBuf(pageSize)
n, err := Copy(dst, unsafe.Pointer(&src[0]), pageSize)
if n != uintptr(bytesBeforeFault) {
@@ -444,8 +456,8 @@ func TestZeroOutSegvError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting write %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
n, err := ZeroOut(dst, pageSize)
if n != uintptr(bytesBeforeFault) {
t.Errorf("Unexpected write length: got %v, want %v", n, bytesBeforeFault)
@@ -467,8 +479,8 @@ func TestZeroOutBusError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting write %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) {
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
n, err := ZeroOut(dst, pageSize)
if n != uintptr(bytesBeforeFault) {
t.Errorf("Unexpected write length: got %v, want %v", n, bytesBeforeFault)
@@ -488,7 +500,7 @@ func TestSwapUint32SegvError(t *testing.T) {
// Test that SwapUint32 returns a SegvError when reaching a page that
// signals SIGSEGV.
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
_, err := SwapUint32(unsafe.Pointer(secondPage), 1)
if want := (SegvError{secondPage}); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
@@ -500,7 +512,7 @@ func TestSwapUint32BusError(t *testing.T) {
// Test that SwapUint32 returns a BusError when reaching a page that
// signals SIGBUS.
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
_, err := SwapUint32(unsafe.Pointer(secondPage), 1)
if want := (BusError{secondPage}); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
@@ -512,7 +524,7 @@ func TestSwapUint64SegvError(t *testing.T) {
// Test that SwapUint64 returns a SegvError when reaching a page that
// signals SIGSEGV.
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
_, err := SwapUint64(unsafe.Pointer(secondPage), 1)
if want := (SegvError{secondPage}); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
@@ -524,7 +536,7 @@ func TestSwapUint64BusError(t *testing.T) {
// Test that SwapUint64 returns a BusError when reaching a page that
// signals SIGBUS.
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
_, err := SwapUint64(unsafe.Pointer(secondPage), 1)
if want := (BusError{secondPage}); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
@@ -536,7 +548,7 @@ func TestCompareAndSwapUint32SegvError(t *testing.T) {
// Test that CompareAndSwapUint32 returns a SegvError when reaching a page
// that signals SIGSEGV.
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
_, err := CompareAndSwapUint32(unsafe.Pointer(secondPage), 0, 1)
if want := (SegvError{secondPage}); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
@@ -548,7 +560,7 @@ func TestCompareAndSwapUint32BusError(t *testing.T) {
// Test that CompareAndSwapUint32 returns a BusError when reaching a page
// that signals SIGBUS.
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
_, err := CompareAndSwapUint32(unsafe.Pointer(secondPage), 0, 1)
if want := (BusError{secondPage}); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
diff --git a/pkg/safecopy/safecopy_unsafe.go b/pkg/safecopy/safecopy_unsafe.go
index eef028e68..41dd567f3 100644
--- a/pkg/safecopy/safecopy_unsafe.go
+++ b/pkg/safecopy/safecopy_unsafe.go
@@ -16,6 +16,7 @@ package safecopy
import (
"fmt"
+ "runtime"
"syscall"
"unsafe"
)
@@ -35,7 +36,7 @@ const maxRegisterSize = 16
// successfully copied.
//
//go:noescape
-func memcpy(dst, src unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32)
+func memcpy(dst, src uintptr, n uintptr) (fault uintptr, sig int32)
// memclr sets the n bytes following ptr to zeroes. If a SIGSEGV or SIGBUS
// signal is received during the write, it returns the address that caused the
@@ -47,7 +48,7 @@ func memcpy(dst, src unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32
// successfully written.
//
//go:noescape
-func memclr(ptr unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32)
+func memclr(ptr uintptr, n uintptr) (fault uintptr, sig int32)
// swapUint32 atomically stores new into *ptr and returns (the previous *ptr
// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the
@@ -90,29 +91,35 @@ func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32)
// CopyIn copies len(dst) bytes from src to dst. It returns the number of bytes
// copied and an error if SIGSEGV or SIGBUS is received while reading from src.
func CopyIn(dst []byte, src unsafe.Pointer) (int, error) {
+ n, err := copyIn(dst, uintptr(src))
+ runtime.KeepAlive(src)
+ return n, err
+}
+
+// copyIn is the underlying definition for CopyIn.
+func copyIn(dst []byte, src uintptr) (int, error) {
toCopy := uintptr(len(dst))
if len(dst) == 0 {
return 0, nil
}
- fault, sig := memcpy(unsafe.Pointer(&dst[0]), src, toCopy)
+ fault, sig := memcpy(uintptr(unsafe.Pointer(&dst[0])), src, toCopy)
if sig == 0 {
return len(dst), nil
}
- faultN, srcN := uintptr(fault), uintptr(src)
- if faultN < srcN || faultN >= srcN+toCopy {
- panic(fmt.Sprintf("CopyIn raised signal %d at %#x, which is outside source [%#x, %#x)", sig, faultN, srcN, srcN+toCopy))
+ if fault < src || fault >= src+toCopy {
+ panic(fmt.Sprintf("CopyIn raised signal %d at %#x, which is outside source [%#x, %#x)", sig, fault, src, src+toCopy))
}
// memcpy might have ended the copy up to maxRegisterSize bytes before
// fault, if an instruction caused a memory access that straddled two
// pages, and the second one faulted. Try to copy up to the fault.
var done int
- if faultN-srcN > maxRegisterSize {
- done = int(faultN - srcN - maxRegisterSize)
+ if fault-src > maxRegisterSize {
+ done = int(fault - src - maxRegisterSize)
}
- n, err := CopyIn(dst[done:int(faultN-srcN)], unsafe.Pointer(srcN+uintptr(done)))
+ n, err := copyIn(dst[done:int(fault-src)], src+uintptr(done))
done += n
if err != nil {
return done, err
@@ -124,29 +131,35 @@ func CopyIn(dst []byte, src unsafe.Pointer) (int, error) {
// bytes done and an error if SIGSEGV or SIGBUS is received while writing to
// dst.
func CopyOut(dst unsafe.Pointer, src []byte) (int, error) {
+ n, err := copyOut(uintptr(dst), src)
+ runtime.KeepAlive(dst)
+ return n, err
+}
+
+// copyOut is the underlying definition for CopyOut.
+func copyOut(dst uintptr, src []byte) (int, error) {
toCopy := uintptr(len(src))
if toCopy == 0 {
return 0, nil
}
- fault, sig := memcpy(dst, unsafe.Pointer(&src[0]), toCopy)
+ fault, sig := memcpy(dst, uintptr(unsafe.Pointer(&src[0])), toCopy)
if sig == 0 {
return len(src), nil
}
- faultN, dstN := uintptr(fault), uintptr(dst)
- if faultN < dstN || faultN >= dstN+toCopy {
- panic(fmt.Sprintf("CopyOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, faultN, dstN, dstN+toCopy))
+ if fault < dst || fault >= dst+toCopy {
+ panic(fmt.Sprintf("CopyOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, fault, dst, dst+toCopy))
}
// memcpy might have ended the copy up to maxRegisterSize bytes before
// fault, if an instruction caused a memory access that straddled two
// pages, and the second one faulted. Try to copy up to the fault.
var done int
- if faultN-dstN > maxRegisterSize {
- done = int(faultN - dstN - maxRegisterSize)
+ if fault-dst > maxRegisterSize {
+ done = int(fault - dst - maxRegisterSize)
}
- n, err := CopyOut(unsafe.Pointer(dstN+uintptr(done)), src[done:int(faultN-dstN)])
+ n, err := copyOut(dst+uintptr(done), src[done:int(fault-dst)])
done += n
if err != nil {
return done, err
@@ -161,6 +174,14 @@ func CopyOut(dst unsafe.Pointer, src []byte) (int, error) {
// Data is copied in order; if [src, src+toCopy) and [dst, dst+toCopy) overlap,
// the resulting contents of dst are unspecified.
func Copy(dst, src unsafe.Pointer, toCopy uintptr) (uintptr, error) {
+ n, err := copyN(uintptr(dst), uintptr(src), toCopy)
+ runtime.KeepAlive(dst)
+ runtime.KeepAlive(src)
+ return n, err
+}
+
+// copyN is the underlying definition for Copy.
+func copyN(dst, src uintptr, toCopy uintptr) (uintptr, error) {
if toCopy == 0 {
return 0, nil
}
@@ -171,17 +192,16 @@ func Copy(dst, src unsafe.Pointer, toCopy uintptr) (uintptr, error) {
}
// Did the fault occur while reading from src or writing to dst?
- faultN, srcN, dstN := uintptr(fault), uintptr(src), uintptr(dst)
faultAfterSrc := ^uintptr(0)
- if faultN >= srcN {
- faultAfterSrc = faultN - srcN
+ if fault >= src {
+ faultAfterSrc = fault - src
}
faultAfterDst := ^uintptr(0)
- if faultN >= dstN {
- faultAfterDst = faultN - dstN
+ if fault >= dst {
+ faultAfterDst = fault - dst
}
if faultAfterSrc >= toCopy && faultAfterDst >= toCopy {
- panic(fmt.Sprintf("Copy raised signal %d at %#x, which is outside source [%#x, %#x) and destination [%#x, %#x)", sig, faultN, srcN, srcN+toCopy, dstN, dstN+toCopy))
+ panic(fmt.Sprintf("Copy raised signal %d at %#x, which is outside source [%#x, %#x) and destination [%#x, %#x)", sig, fault, src, src+toCopy, dst, dst+toCopy))
}
faultedAfter := faultAfterSrc
if faultedAfter > faultAfterDst {
@@ -195,7 +215,7 @@ func Copy(dst, src unsafe.Pointer, toCopy uintptr) (uintptr, error) {
if faultedAfter > maxRegisterSize {
done = faultedAfter - maxRegisterSize
}
- n, err := Copy(unsafe.Pointer(dstN+done), unsafe.Pointer(srcN+done), faultedAfter-done)
+ n, err := copyN(dst+done, src+done, faultedAfter-done)
done += n
if err != nil {
return done, err
@@ -206,6 +226,13 @@ func Copy(dst, src unsafe.Pointer, toCopy uintptr) (uintptr, error) {
// ZeroOut writes toZero zero bytes to dst. It returns the number of bytes
// written and an error if SIGSEGV or SIGBUS is received while writing to dst.
func ZeroOut(dst unsafe.Pointer, toZero uintptr) (uintptr, error) {
+ n, err := zeroOut(uintptr(dst), toZero)
+ runtime.KeepAlive(dst)
+ return n, err
+}
+
+// zeroOut is the underlying definition for ZeroOut.
+func zeroOut(dst uintptr, toZero uintptr) (uintptr, error) {
if toZero == 0 {
return 0, nil
}
@@ -215,19 +242,18 @@ func ZeroOut(dst unsafe.Pointer, toZero uintptr) (uintptr, error) {
return toZero, nil
}
- faultN, dstN := uintptr(fault), uintptr(dst)
- if faultN < dstN || faultN >= dstN+toZero {
- panic(fmt.Sprintf("ZeroOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, faultN, dstN, dstN+toZero))
+ if fault < dst || fault >= dst+toZero {
+ panic(fmt.Sprintf("ZeroOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, fault, dst, dst+toZero))
}
// memclr might have ended the write up to maxRegisterSize bytes before
// fault, if an instruction caused a memory access that straddled two
// pages, and the second one faulted. Try to write up to the fault.
var done uintptr
- if faultN-dstN > maxRegisterSize {
- done = faultN - dstN - maxRegisterSize
+ if fault-dst > maxRegisterSize {
+ done = fault - dst - maxRegisterSize
}
- n, err := ZeroOut(unsafe.Pointer(dstN+done), faultN-dstN-done)
+ n, err := zeroOut(dst+done, fault-dst-done)
done += n
if err != nil {
return done, err
@@ -243,7 +269,7 @@ func SwapUint32(ptr unsafe.Pointer, new uint32) (uint32, error) {
return 0, AlignmentError{addr, 4}
}
old, sig := swapUint32(ptr, new)
- return old, errorFromFaultSignal(ptr, sig)
+ return old, errorFromFaultSignal(uintptr(ptr), sig)
}
// SwapUint64 is equivalent to sync/atomic.SwapUint64, except that it returns
@@ -254,7 +280,7 @@ func SwapUint64(ptr unsafe.Pointer, new uint64) (uint64, error) {
return 0, AlignmentError{addr, 8}
}
old, sig := swapUint64(ptr, new)
- return old, errorFromFaultSignal(ptr, sig)
+ return old, errorFromFaultSignal(uintptr(ptr), sig)
}
// CompareAndSwapUint32 is equivalent to atomicbitops.CompareAndSwapUint32,
@@ -265,7 +291,7 @@ func CompareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (uint32, error) {
return 0, AlignmentError{addr, 4}
}
prev, sig := compareAndSwapUint32(ptr, old, new)
- return prev, errorFromFaultSignal(ptr, sig)
+ return prev, errorFromFaultSignal(uintptr(ptr), sig)
}
// LoadUint32 is like sync/atomic.LoadUint32, but operates with user memory. It
@@ -277,17 +303,17 @@ func LoadUint32(ptr unsafe.Pointer) (uint32, error) {
return 0, AlignmentError{addr, 4}
}
val, sig := loadUint32(ptr)
- return val, errorFromFaultSignal(ptr, sig)
+ return val, errorFromFaultSignal(uintptr(ptr), sig)
}
-func errorFromFaultSignal(addr unsafe.Pointer, sig int32) error {
+func errorFromFaultSignal(addr uintptr, sig int32) error {
switch sig {
case 0:
return nil
case int32(syscall.SIGSEGV):
- return SegvError{uintptr(addr)}
+ return SegvError{addr}
case int32(syscall.SIGBUS):
- return BusError{uintptr(addr)}
+ return BusError{addr}
default:
panic(fmt.Sprintf("safecopy got unexpected signal %d at address %#x", sig, addr))
}