diff options
author | Googler <noreply@google.com> | 2018-04-27 10:37:02 -0700 |
---|---|---|
committer | Adin Scannell <ascannell@google.com> | 2018-04-28 01:44:26 -0400 |
commit | d02b74a5dcfed4bfc8f2f8e545bca4d2afabb296 (patch) | |
tree | 54f95eef73aee6bacbfc736fffc631be2605ed53 /pkg/sentry/platform/safecopy | |
parent | f70210e742919f40aa2f0934a22f1c9ba6dada62 (diff) |
Check in gVisor.
PiperOrigin-RevId: 194583126
Change-Id: Ica1d8821a90f74e7e745962d71801c598c652463
Diffstat (limited to 'pkg/sentry/platform/safecopy')
-rw-r--r-- | pkg/sentry/platform/safecopy/BUILD | 28 | ||||
-rw-r--r-- | pkg/sentry/platform/safecopy/atomic_amd64.s | 108 | ||||
-rw-r--r-- | pkg/sentry/platform/safecopy/memclr_amd64.s | 157 | ||||
-rw-r--r-- | pkg/sentry/platform/safecopy/memcpy_amd64.s | 242 | ||||
-rw-r--r-- | pkg/sentry/platform/safecopy/safecopy.go | 140 | ||||
-rw-r--r-- | pkg/sentry/platform/safecopy/safecopy_test.go | 617 | ||||
-rw-r--r-- | pkg/sentry/platform/safecopy/safecopy_unsafe.go | 315 | ||||
-rw-r--r-- | pkg/sentry/platform/safecopy/sighandler_amd64.s | 124 |
8 files changed, 1731 insertions, 0 deletions
diff --git a/pkg/sentry/platform/safecopy/BUILD b/pkg/sentry/platform/safecopy/BUILD new file mode 100644 index 000000000..8b9f29403 --- /dev/null +++ b/pkg/sentry/platform/safecopy/BUILD @@ -0,0 +1,28 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "safecopy", + srcs = [ + "atomic_amd64.s", + "memclr_amd64.s", + "memcpy_amd64.s", + "safecopy.go", + "safecopy_unsafe.go", + "sighandler_amd64.s", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/platform/safecopy", + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/syserror", + ], +) + +go_test( + name = "safecopy_test", + srcs = [ + "safecopy_test.go", + ], + embed = [":safecopy"], +) diff --git a/pkg/sentry/platform/safecopy/atomic_amd64.s b/pkg/sentry/platform/safecopy/atomic_amd64.s new file mode 100644 index 000000000..69947dec3 --- /dev/null +++ b/pkg/sentry/platform/safecopy/atomic_amd64.s @@ -0,0 +1,108 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// handleSwapUint32Fault returns the value stored in DI. Control is transferred +// to it when swapUint32 below receives SIGSEGV or SIGBUS, with the signal +// number stored in DI. +// +// It must have the same frame configuration as swapUint32 so that it can undo +// any potential call frame set up by the assembler. +TEXT handleSwapUint32Fault(SB), NOSPLIT, $0-24 + MOVL DI, sig+20(FP) + RET + +// swapUint32 atomically stores new into *addr and returns (the previous *addr +// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the +// value of old is unspecified, and sig is the number of the signal that was +// received. +// +// Preconditions: addr must be aligned to a 4-byte boundary. +// +//func swapUint32(ptr unsafe.Pointer, new uint32) (old uint32, sig int32) +TEXT ·swapUint32(SB), NOSPLIT, $0-24 + // Store 0 as the returned signal number. If we run to completion, + // this is the value the caller will see; if a signal is received, + // handleSwapUint32Fault will store a different value in this address. + MOVL $0, sig+20(FP) + + MOVQ addr+0(FP), DI + MOVL new+8(FP), AX + XCHGL AX, 0(DI) + MOVL AX, old+16(FP) + RET + +// handleSwapUint64Fault returns the value stored in DI. Control is transferred +// to it when swapUint64 below receives SIGSEGV or SIGBUS, with the signal +// number stored in DI. +// +// It must have the same frame configuration as swapUint64 so that it can undo +// any potential call frame set up by the assembler. +TEXT handleSwapUint64Fault(SB), NOSPLIT, $0-28 + MOVL DI, sig+24(FP) + RET + +// swapUint64 atomically stores new into *addr and returns (the previous *addr +// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the +// value of old is unspecified, and sig is the number of the signal that was +// received. +// +// Preconditions: addr must be aligned to a 8-byte boundary. +// +//func swapUint64(ptr unsafe.Pointer, new uint64) (old uint64, sig int32) +TEXT ·swapUint64(SB), NOSPLIT, $0-28 + // Store 0 as the returned signal number. If we run to completion, + // this is the value the caller will see; if a signal is received, + // handleSwapUint64Fault will store a different value in this address. + MOVL $0, sig+24(FP) + + MOVQ addr+0(FP), DI + MOVQ new+8(FP), AX + XCHGQ AX, 0(DI) + MOVQ AX, old+16(FP) + RET + +// handleCompareAndSwapUint32Fault returns the value stored in DI. Control is +// transferred to it when swapUint64 below receives SIGSEGV or SIGBUS, with the +// signal number stored in DI. +// +// It must have the same frame configuration as compareAndSwapUint32 so that it +// can undo any potential call frame set up by the assembler. +TEXT handleCompareAndSwapUint32Fault(SB), NOSPLIT, $0-24 + MOVL DI, sig+20(FP) + RET + +// compareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns +// (the value previously stored at addr, 0). If a SIGSEGV or SIGBUS signal is +// received during the operation, the value of prev is unspecified, and sig is +// the number of the signal that was received. +// +// Preconditions: addr must be aligned to a 4-byte boundary. +// +//func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig int32) +TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24 + // Store 0 as the returned signal number. If we run to completion, this is + // the value the caller will see; if a signal is received, + // handleCompareAndSwapUint32Fault will store a different value in this + // address. + MOVL $0, sig+20(FP) + + MOVQ addr+0(FP), DI + MOVL old+8(FP), AX + MOVL new+12(FP), DX + LOCK + CMPXCHGL DX, 0(DI) + MOVL AX, prev+16(FP) + RET diff --git a/pkg/sentry/platform/safecopy/memclr_amd64.s b/pkg/sentry/platform/safecopy/memclr_amd64.s new file mode 100644 index 000000000..7d1019f60 --- /dev/null +++ b/pkg/sentry/platform/safecopy/memclr_amd64.s @@ -0,0 +1,157 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// handleMemclrFault returns (the value stored in AX, the value stored in DI). +// Control is transferred to it when memclr below receives SIGSEGV or SIGBUS, +// with the faulting address stored in AX and the signal number stored in DI. +// +// It must have the same frame configuration as memclr so that it can undo any +// potential call frame set up by the assembler. +TEXT handleMemclrFault(SB), NOSPLIT, $0-28 + MOVQ AX, addr+16(FP) + MOVL DI, sig+24(FP) + RET + +// memclr sets the n bytes following ptr to zeroes. If a SIGSEGV or SIGBUS +// signal is received during the write, it returns the address that caused the +// fault and the number of the signal that was received. Otherwise, it returns +// an unspecified address and a signal number of 0. +// +// Data is written in order, such that if a fault happens at address p, it is +// safe to assume that all data before p-maxRegisterSize has already been +// successfully written. +// +// The code is derived from runtime.memclrNoHeapPointers. +// +// func memclr(ptr unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32) +TEXT ·memclr(SB), NOSPLIT, $0-28 + // Store 0 as the returned signal number. If we run to completion, + // this is the value the caller will see; if a signal is received, + // handleMemclrFault will store a different value in this address. + MOVL $0, sig+24(FP) + + MOVQ ptr+0(FP), DI + MOVQ n+8(FP), BX + XORQ AX, AX + + // MOVOU seems always faster than REP STOSQ. +tail: + TESTQ BX, BX + JEQ _0 + CMPQ BX, $2 + JBE _1or2 + CMPQ BX, $4 + JBE _3or4 + CMPQ BX, $8 + JB _5through7 + JE _8 + CMPQ BX, $16 + JBE _9through16 + PXOR X0, X0 + CMPQ BX, $32 + JBE _17through32 + CMPQ BX, $64 + JBE _33through64 + CMPQ BX, $128 + JBE _65through128 + CMPQ BX, $256 + JBE _129through256 + // TODO: use branch table and BSR to make this just a single dispatch + // TODO: for really big clears, use MOVNTDQ, even without AVX2. + +loop: + MOVOU X0, 0(DI) + MOVOU X0, 16(DI) + MOVOU X0, 32(DI) + MOVOU X0, 48(DI) + MOVOU X0, 64(DI) + MOVOU X0, 80(DI) + MOVOU X0, 96(DI) + MOVOU X0, 112(DI) + MOVOU X0, 128(DI) + MOVOU X0, 144(DI) + MOVOU X0, 160(DI) + MOVOU X0, 176(DI) + MOVOU X0, 192(DI) + MOVOU X0, 208(DI) + MOVOU X0, 224(DI) + MOVOU X0, 240(DI) + SUBQ $256, BX + ADDQ $256, DI + CMPQ BX, $256 + JAE loop + JMP tail + +_1or2: + MOVB AX, (DI) + MOVB AX, -1(DI)(BX*1) + RET +_0: + RET +_3or4: + MOVW AX, (DI) + MOVW AX, -2(DI)(BX*1) + RET +_5through7: + MOVL AX, (DI) + MOVL AX, -4(DI)(BX*1) + RET +_8: + // We need a separate case for 8 to make sure we clear pointers atomically. + MOVQ AX, (DI) + RET +_9through16: + MOVQ AX, (DI) + MOVQ AX, -8(DI)(BX*1) + RET +_17through32: + MOVOU X0, (DI) + MOVOU X0, -16(DI)(BX*1) + RET +_33through64: + MOVOU X0, (DI) + MOVOU X0, 16(DI) + MOVOU X0, -32(DI)(BX*1) + MOVOU X0, -16(DI)(BX*1) + RET +_65through128: + MOVOU X0, (DI) + MOVOU X0, 16(DI) + MOVOU X0, 32(DI) + MOVOU X0, 48(DI) + MOVOU X0, -64(DI)(BX*1) + MOVOU X0, -48(DI)(BX*1) + MOVOU X0, -32(DI)(BX*1) + MOVOU X0, -16(DI)(BX*1) + RET +_129through256: + MOVOU X0, (DI) + MOVOU X0, 16(DI) + MOVOU X0, 32(DI) + MOVOU X0, 48(DI) + MOVOU X0, 64(DI) + MOVOU X0, 80(DI) + MOVOU X0, 96(DI) + MOVOU X0, 112(DI) + MOVOU X0, -128(DI)(BX*1) + MOVOU X0, -112(DI)(BX*1) + MOVOU X0, -96(DI)(BX*1) + MOVOU X0, -80(DI)(BX*1) + MOVOU X0, -64(DI)(BX*1) + MOVOU X0, -48(DI)(BX*1) + MOVOU X0, -32(DI)(BX*1) + MOVOU X0, -16(DI)(BX*1) + RET diff --git a/pkg/sentry/platform/safecopy/memcpy_amd64.s b/pkg/sentry/platform/safecopy/memcpy_amd64.s new file mode 100644 index 000000000..96ef2eefc --- /dev/null +++ b/pkg/sentry/platform/safecopy/memcpy_amd64.s @@ -0,0 +1,242 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// handleMemcpyFault returns (the value stored in AX, the value stored in DI). +// Control is transferred to it when memcpy below receives SIGSEGV or SIGBUS, +// with the faulting address stored in AX and the signal number stored in DI. +// +// It must have the same frame configuration as memcpy so that it can undo any +// potential call frame set up by the assembler. +TEXT handleMemcpyFault(SB), NOSPLIT, $0-36 + MOVQ AX, addr+24(FP) + MOVL DI, sig+32(FP) + RET + +// memcpy copies data from src to dst. If a SIGSEGV or SIGBUS signal is received +// during the copy, it returns the address that caused the fault and the number +// of the signal that was received. Otherwise, it returns an unspecified address +// and a signal number of 0. +// +// Data is copied in order, such that if a fault happens at address p, it is +// safe to assume that all data before p-maxRegisterSize has already been +// successfully copied. +// +// The code is derived from the forward copying part of runtime.memmove. +// +// func memcpy(dst, src unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32) +TEXT ·memcpy(SB), NOSPLIT, $0-36 + // Store 0 as the returned signal number. If we run to completion, + // this is the value the caller will see; if a signal is received, + // handleMemcpyFault will store a different value in this address. + MOVL $0, sig+32(FP) + + MOVQ to+0(FP), DI + MOVQ from+8(FP), SI + MOVQ n+16(FP), BX + + // REP instructions have a high startup cost, so we handle small sizes + // with some straightline code. The REP MOVSQ instruction is really fast + // for large sizes. The cutover is approximately 2K. +tail: + // move_129through256 or smaller work whether or not the source and the + // destination memory regions overlap because they load all data into + // registers before writing it back. move_256through2048 on the other + // hand can be used only when the memory regions don't overlap or the copy + // direction is forward. + TESTQ BX, BX + JEQ move_0 + CMPQ BX, $2 + JBE move_1or2 + CMPQ BX, $4 + JBE move_3or4 + CMPQ BX, $8 + JB move_5through7 + JE move_8 + CMPQ BX, $16 + JBE move_9through16 + CMPQ BX, $32 + JBE move_17through32 + CMPQ BX, $64 + JBE move_33through64 + CMPQ BX, $128 + JBE move_65through128 + CMPQ BX, $256 + JBE move_129through256 + // TODO: use branch table and BSR to make this just a single dispatch + +/* + * forward copy loop + */ + CMPQ BX, $2048 + JLS move_256through2048 + + // Check alignment + MOVL SI, AX + ORL DI, AX + TESTL $7, AX + JEQ fwdBy8 + + // Do 1 byte at a time + MOVQ BX, CX + REP; MOVSB + RET + +fwdBy8: + // Do 8 bytes at a time + MOVQ BX, CX + SHRQ $3, CX + ANDQ $7, BX + REP; MOVSQ + JMP tail + +move_1or2: + MOVB (SI), AX + MOVB AX, (DI) + MOVB -1(SI)(BX*1), CX + MOVB CX, -1(DI)(BX*1) + RET +move_0: + RET +move_3or4: + MOVW (SI), AX + MOVW AX, (DI) + MOVW -2(SI)(BX*1), CX + MOVW CX, -2(DI)(BX*1) + RET +move_5through7: + MOVL (SI), AX + MOVL AX, (DI) + MOVL -4(SI)(BX*1), CX + MOVL CX, -4(DI)(BX*1) + RET +move_8: + // We need a separate case for 8 to make sure we write pointers atomically. + MOVQ (SI), AX + MOVQ AX, (DI) + RET +move_9through16: + MOVQ (SI), AX + MOVQ AX, (DI) + MOVQ -8(SI)(BX*1), CX + MOVQ CX, -8(DI)(BX*1) + RET +move_17through32: + MOVOU (SI), X0 + MOVOU X0, (DI) + MOVOU -16(SI)(BX*1), X1 + MOVOU X1, -16(DI)(BX*1) + RET +move_33through64: + MOVOU (SI), X0 + MOVOU X0, (DI) + MOVOU 16(SI), X1 + MOVOU X1, 16(DI) + MOVOU -32(SI)(BX*1), X2 + MOVOU X2, -32(DI)(BX*1) + MOVOU -16(SI)(BX*1), X3 + MOVOU X3, -16(DI)(BX*1) + RET +move_65through128: + MOVOU (SI), X0 + MOVOU X0, (DI) + MOVOU 16(SI), X1 + MOVOU X1, 16(DI) + MOVOU 32(SI), X2 + MOVOU X2, 32(DI) + MOVOU 48(SI), X3 + MOVOU X3, 48(DI) + MOVOU -64(SI)(BX*1), X4 + MOVOU X4, -64(DI)(BX*1) + MOVOU -48(SI)(BX*1), X5 + MOVOU X5, -48(DI)(BX*1) + MOVOU -32(SI)(BX*1), X6 + MOVOU X6, -32(DI)(BX*1) + MOVOU -16(SI)(BX*1), X7 + MOVOU X7, -16(DI)(BX*1) + RET +move_129through256: + MOVOU (SI), X0 + MOVOU X0, (DI) + MOVOU 16(SI), X1 + MOVOU X1, 16(DI) + MOVOU 32(SI), X2 + MOVOU X2, 32(DI) + MOVOU 48(SI), X3 + MOVOU X3, 48(DI) + MOVOU 64(SI), X4 + MOVOU X4, 64(DI) + MOVOU 80(SI), X5 + MOVOU X5, 80(DI) + MOVOU 96(SI), X6 + MOVOU X6, 96(DI) + MOVOU 112(SI), X7 + MOVOU X7, 112(DI) + MOVOU -128(SI)(BX*1), X8 + MOVOU X8, -128(DI)(BX*1) + MOVOU -112(SI)(BX*1), X9 + MOVOU X9, -112(DI)(BX*1) + MOVOU -96(SI)(BX*1), X10 + MOVOU X10, -96(DI)(BX*1) + MOVOU -80(SI)(BX*1), X11 + MOVOU X11, -80(DI)(BX*1) + MOVOU -64(SI)(BX*1), X12 + MOVOU X12, -64(DI)(BX*1) + MOVOU -48(SI)(BX*1), X13 + MOVOU X13, -48(DI)(BX*1) + MOVOU -32(SI)(BX*1), X14 + MOVOU X14, -32(DI)(BX*1) + MOVOU -16(SI)(BX*1), X15 + MOVOU X15, -16(DI)(BX*1) + RET +move_256through2048: + SUBQ $256, BX + MOVOU (SI), X0 + MOVOU X0, (DI) + MOVOU 16(SI), X1 + MOVOU X1, 16(DI) + MOVOU 32(SI), X2 + MOVOU X2, 32(DI) + MOVOU 48(SI), X3 + MOVOU X3, 48(DI) + MOVOU 64(SI), X4 + MOVOU X4, 64(DI) + MOVOU 80(SI), X5 + MOVOU X5, 80(DI) + MOVOU 96(SI), X6 + MOVOU X6, 96(DI) + MOVOU 112(SI), X7 + MOVOU X7, 112(DI) + MOVOU 128(SI), X8 + MOVOU X8, 128(DI) + MOVOU 144(SI), X9 + MOVOU X9, 144(DI) + MOVOU 160(SI), X10 + MOVOU X10, 160(DI) + MOVOU 176(SI), X11 + MOVOU X11, 176(DI) + MOVOU 192(SI), X12 + MOVOU X12, 192(DI) + MOVOU 208(SI), X13 + MOVOU X13, 208(DI) + MOVOU 224(SI), X14 + MOVOU X14, 224(DI) + MOVOU 240(SI), X15 + MOVOU X15, 240(DI) + CMPQ BX, $256 + LEAQ 256(SI), SI + LEAQ 256(DI), DI + JGE move_256through2048 + JMP tail diff --git a/pkg/sentry/platform/safecopy/safecopy.go b/pkg/sentry/platform/safecopy/safecopy.go new file mode 100644 index 000000000..90a2aad7b --- /dev/null +++ b/pkg/sentry/platform/safecopy/safecopy.go @@ -0,0 +1,140 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package safecopy provides an efficient implementation of functions to access +// memory that may result in SIGSEGV or SIGBUS being sent to the accessor. +package safecopy + +import ( + "fmt" + "reflect" + "runtime" + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/syserror" +) + +// SegvError is returned when a safecopy function receives SIGSEGV. +type SegvError struct { + // Addr is the address at which the SIGSEGV occurred. + Addr uintptr +} + +// Error implements error.Error. +func (e SegvError) Error() string { + return fmt.Sprintf("SIGSEGV at %#x", e.Addr) +} + +// BusError is returned when a safecopy function receives SIGBUS. +type BusError struct { + // Addr is the address at which the SIGBUS occurred. + Addr uintptr +} + +// Error implements error.Error. +func (e BusError) Error() string { + return fmt.Sprintf("SIGBUS at %#x", e.Addr) +} + +// AlignmentError is returned when a safecopy function is passed an address +// that does not meet alignment requirements. +type AlignmentError struct { + // Addr is the invalid address. + Addr uintptr + + // Alignment is the required alignment. + Alignment uintptr +} + +// Error implements error.Error. +func (e AlignmentError) Error() string { + return fmt.Sprintf("address %#x is not aligned to a %d-byte boundary", e.Addr, e.Alignment) +} + +var ( + // The begin and end addresses below are for the functions that are + // checked by the signal handler. + memcpyBegin uintptr + memcpyEnd uintptr + memclrBegin uintptr + memclrEnd uintptr + swapUint32Begin uintptr + swapUint32End uintptr + swapUint64Begin uintptr + swapUint64End uintptr + compareAndSwapUint32Begin uintptr + compareAndSwapUint32End uintptr + + // savedSigSegVHandler is a pointer to the SIGSEGV handler that was + // configured before we replaced it with our own. We still call into it + // when we get a SIGSEGV that is not interesting to us. + savedSigSegVHandler uintptr + + // same a above, but for SIGBUS signals. + savedSigBusHandler uintptr +) + +// signalHandler is our replacement signal handler for SIGSEGV and SIGBUS +// signals. +func signalHandler() + +// FindEndAddress returns the end address (one byte beyond the last) of the +// function that contains the specified address (begin). +func FindEndAddress(begin uintptr) uintptr { + f := runtime.FuncForPC(begin) + if f != nil { + for p := begin; ; p++ { + g := runtime.FuncForPC(p) + if f != g { + return p + } + } + } + return begin +} + +// initializeAddresses initializes the addresses used by the signal handler. +func initializeAddresses() { + // The following functions are written in assembly language, so they won't + // be inlined by the existing compiler/linker. Tests will fail if this + // assumption is violated. + memcpyBegin = reflect.ValueOf(memcpy).Pointer() + memcpyEnd = FindEndAddress(memcpyBegin) + memclrBegin = reflect.ValueOf(memclr).Pointer() + memclrEnd = FindEndAddress(memclrBegin) + swapUint32Begin = reflect.ValueOf(swapUint32).Pointer() + swapUint32End = FindEndAddress(swapUint32Begin) + swapUint64Begin = reflect.ValueOf(swapUint64).Pointer() + swapUint64End = FindEndAddress(swapUint64Begin) + compareAndSwapUint32Begin = reflect.ValueOf(compareAndSwapUint32).Pointer() + compareAndSwapUint32End = FindEndAddress(compareAndSwapUint32Begin) +} + +func init() { + initializeAddresses() + if err := ReplaceSignalHandler(syscall.SIGSEGV, reflect.ValueOf(signalHandler).Pointer(), &savedSigSegVHandler); err != nil { + panic(fmt.Sprintf("Unable to set handler for SIGSEGV: %v", err)) + } + if err := ReplaceSignalHandler(syscall.SIGBUS, reflect.ValueOf(signalHandler).Pointer(), &savedSigBusHandler); err != nil { + panic(fmt.Sprintf("Unable to set handler for SIGBUS: %v", err)) + } + syserror.AddErrorUnwrapper(func(e error) (syscall.Errno, bool) { + switch e.(type) { + case SegvError, BusError, AlignmentError: + return syscall.EFAULT, true + default: + return 0, false + } + }) +} diff --git a/pkg/sentry/platform/safecopy/safecopy_test.go b/pkg/sentry/platform/safecopy/safecopy_test.go new file mode 100644 index 000000000..67df36121 --- /dev/null +++ b/pkg/sentry/platform/safecopy/safecopy_test.go @@ -0,0 +1,617 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package safecopy + +import ( + "bytes" + "fmt" + "io/ioutil" + "math/rand" + "os" + "runtime/debug" + "syscall" + "testing" + "unsafe" +) + +// Size of a page in bytes. Cloned from usermem.PageSize to avoid a circular +// dependency. +const pageSize = 4096 + +func initRandom(b []byte) { + for i := range b { + b[i] = byte(rand.Intn(256)) + } +} + +func randBuf(size int) []byte { + b := make([]byte, size) + initRandom(b) + return b +} + +func TestCopyInSuccess(t *testing.T) { + // Test that CopyIn does not return an error when all pages are accessible. + const bufLen = 8192 + a := randBuf(bufLen) + b := make([]byte, bufLen) + + n, err := CopyIn(b, unsafe.Pointer(&a[0])) + if n != bufLen { + t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if !bytes.Equal(a, b) { + t.Errorf("Buffers are not equal when they should be: %v %v", a, b) + } +} + +func TestCopyOutSuccess(t *testing.T) { + // Test that CopyOut does not return an error when all pages are + // accessible. + const bufLen = 8192 + a := randBuf(bufLen) + b := make([]byte, bufLen) + + n, err := CopyOut(unsafe.Pointer(&b[0]), a) + if n != bufLen { + t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if !bytes.Equal(a, b) { + t.Errorf("Buffers are not equal when they should be: %v %v", a, b) + } +} + +func TestCopySuccess(t *testing.T) { + // Test that Copy does not return an error when all pages are accessible. + const bufLen = 8192 + a := randBuf(bufLen) + b := make([]byte, bufLen) + + n, err := Copy(unsafe.Pointer(&b[0]), unsafe.Pointer(&a[0]), bufLen) + if n != bufLen { + t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if !bytes.Equal(a, b) { + t.Errorf("Buffers are not equal when they should be: %v %v", a, b) + } +} + +func TestZeroOutSuccess(t *testing.T) { + // Test that ZeroOut does not return an error when all pages are + // accessible. + const bufLen = 8192 + a := make([]byte, bufLen) + b := randBuf(bufLen) + + n, err := ZeroOut(unsafe.Pointer(&b[0]), bufLen) + if n != bufLen { + t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if !bytes.Equal(a, b) { + t.Errorf("Buffers are not equal when they should be: %v %v", a, b) + } +} + +func TestSwapUint32Success(t *testing.T) { + // Test that SwapUint32 does not return an error when the page is + // accessible. + before := uint32(rand.Int31()) + after := uint32(rand.Int31()) + val := before + + old, err := SwapUint32(unsafe.Pointer(&val), after) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if old != before { + t.Errorf("Unexpected old value: got %v, want %v", old, before) + } + if val != after { + t.Errorf("Unexpected new value: got %v, want %v", val, after) + } +} + +func TestSwapUint32AlignmentError(t *testing.T) { + // Test that SwapUint32 returns an AlignmentError when passed an unaligned + // address. + data := new(struct{ val uint64 }) + addr := uintptr(unsafe.Pointer(&data.val)) + 1 + want := AlignmentError{Addr: addr, Alignment: 4} + if _, err := SwapUint32(unsafe.Pointer(addr), 1); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } +} + +func TestSwapUint64Success(t *testing.T) { + // Test that SwapUint64 does not return an error when the page is + // accessible. + before := uint64(rand.Int63()) + after := uint64(rand.Int63()) + // "The first word in ... an allocated struct or slice can be relied upon + // to be 64-bit aligned." - sync/atomic docs + data := new(struct{ val uint64 }) + data.val = before + + old, err := SwapUint64(unsafe.Pointer(&data.val), after) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if old != before { + t.Errorf("Unexpected old value: got %v, want %v", old, before) + } + if data.val != after { + t.Errorf("Unexpected new value: got %v, want %v", data.val, after) + } +} + +func TestSwapUint64AlignmentError(t *testing.T) { + // Test that SwapUint64 returns an AlignmentError when passed an unaligned + // address. + data := new(struct{ val1, val2 uint64 }) + addr := uintptr(unsafe.Pointer(&data.val1)) + 1 + want := AlignmentError{Addr: addr, Alignment: 8} + if _, err := SwapUint64(unsafe.Pointer(addr), 1); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } +} + +func TestCompareAndSwapUint32Success(t *testing.T) { + // Test that CompareAndSwapUint32 does not return an error when the page is + // accessible. + before := uint32(rand.Int31()) + after := uint32(rand.Int31()) + val := before + + old, err := CompareAndSwapUint32(unsafe.Pointer(&val), before, after) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if old != before { + t.Errorf("Unexpected old value: got %v, want %v", old, before) + } + if val != after { + t.Errorf("Unexpected new value: got %v, want %v", val, after) + } +} + +func TestCompareAndSwapUint32AlignmentError(t *testing.T) { + // Test that CompareAndSwapUint32 returns an AlignmentError when passed an + // unaligned address. + data := new(struct{ val uint64 }) + addr := uintptr(unsafe.Pointer(&data.val)) + 1 + want := AlignmentError{Addr: addr, Alignment: 4} + if _, err := CompareAndSwapUint32(unsafe.Pointer(addr), 0, 1); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } +} + +// withSegvErrorTestMapping calls fn with a two-page mapping. The first page +// contains random data, and the second page generates SIGSEGV when accessed. +func withSegvErrorTestMapping(t *testing.T, fn func(m []byte)) { + mapping, err := syscall.Mmap(-1, 0, 2*pageSize, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_ANONYMOUS|syscall.MAP_PRIVATE) + if err != nil { + t.Fatalf("Mmap failed: %v", err) + } + defer syscall.Munmap(mapping) + if err := syscall.Mprotect(mapping[pageSize:], syscall.PROT_NONE); err != nil { + t.Fatalf("Mprotect failed: %v", err) + } + initRandom(mapping[:pageSize]) + + fn(mapping) +} + +// withBusErrorTestMapping calls fn with a two-page mapping. The first page +// contains random data, and the second page generates SIGBUS when accessed. +func withBusErrorTestMapping(t *testing.T, fn func(m []byte)) { + f, err := ioutil.TempFile("", "sigbus_test") + if err != nil { + t.Fatalf("TempFile failed: %v", err) + } + defer f.Close() + if err := f.Truncate(pageSize); err != nil { + t.Fatalf("Truncate failed: %v", err) + } + mapping, err := syscall.Mmap(int(f.Fd()), 0, 2*pageSize, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED) + if err != nil { + t.Fatalf("Mmap failed: %v", err) + } + defer syscall.Munmap(mapping) + initRandom(mapping[:pageSize]) + + fn(mapping) +} + +func TestCopyInSegvError(t *testing.T) { + // Test that CopyIn returns a SegvError when reaching a page that signals + // SIGSEGV. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + dst := randBuf(pageSize) + n, err := CopyIn(dst, src) + if n != bytesBeforeFault { + t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) + } + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { + t.Errorf("Buffers are not equal when they should be: %v %v", got, want) + } + }) + }) + } +} + +func TestCopyInBusError(t *testing.T) { + // Test that CopyIn returns a BusError when reaching a page that signals + // SIGBUS. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { + withBusErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + dst := randBuf(pageSize) + n, err := CopyIn(dst, src) + if n != bytesBeforeFault { + t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) + } + if want := (BusError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { + t.Errorf("Buffers are not equal when they should be: %v %v", got, want) + } + }) + }) + } +} + +func TestCopyOutSegvError(t *testing.T) { + // Test that CopyOut returns a SegvError when reaching a page that signals + // SIGSEGV. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + src := randBuf(pageSize) + n, err := CopyOut(dst, src) + if n != bytesBeforeFault { + t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) + } + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { + t.Errorf("Buffers are not equal when they should be: %v %v", got, want) + } + }) + }) + } +} + +func TestCopyOutBusError(t *testing.T) { + // Test that CopyOut returns a BusError when reaching a page that signals + // SIGBUS. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { + withBusErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + src := randBuf(pageSize) + n, err := CopyOut(dst, src) + if n != bytesBeforeFault { + t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) + } + if want := (BusError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { + t.Errorf("Buffers are not equal when they should be: %v %v", got, want) + } + }) + }) + } +} + +func TestCopySourceSegvError(t *testing.T) { + // Test that Copy returns a SegvError when copying from a page that signals + // SIGSEGV. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + dst := randBuf(pageSize) + n, err := Copy(unsafe.Pointer(&dst[0]), src, pageSize) + if n != uintptr(bytesBeforeFault) { + t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) + } + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { + t.Errorf("Buffers are not equal when they should be: %v %v", got, want) + } + }) + }) + } +} + +func TestCopySourceBusError(t *testing.T) { + // Test that Copy returns a BusError when copying from a page that signals + // SIGBUS. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { + withBusErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + dst := randBuf(pageSize) + n, err := Copy(unsafe.Pointer(&dst[0]), src, pageSize) + if n != uintptr(bytesBeforeFault) { + t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) + } + if want := (BusError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { + t.Errorf("Buffers are not equal when they should be: %v %v", got, want) + } + }) + }) + } +} + +func TestCopyDestinationSegvError(t *testing.T) { + // Test that Copy returns a SegvError when copying to a page that signals + // SIGSEGV. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + src := randBuf(pageSize) + n, err := Copy(dst, unsafe.Pointer(&src[0]), pageSize) + if n != uintptr(bytesBeforeFault) { + t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) + } + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { + t.Errorf("Buffers are not equal when they should be: %v %v", got, want) + } + }) + }) + } +} + +func TestCopyDestinationBusError(t *testing.T) { + // Test that Copy returns a BusError when copying to a page that signals + // SIGBUS. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { + withBusErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + src := randBuf(pageSize) + n, err := Copy(dst, unsafe.Pointer(&src[0]), pageSize) + if n != uintptr(bytesBeforeFault) { + t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) + } + if want := (BusError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { + t.Errorf("Buffers are not equal when they should be: %v %v", got, want) + } + }) + }) + } +} + +func TestZeroOutSegvError(t *testing.T) { + // Test that ZeroOut returns a SegvError when reaching a page that signals + // SIGSEGV. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting write %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + n, err := ZeroOut(dst, pageSize) + if n != uintptr(bytesBeforeFault) { + t.Errorf("Unexpected write length: got %v, want %v", n, bytesBeforeFault) + } + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := mapping[pageSize-bytesBeforeFault:pageSize], make([]byte, bytesBeforeFault); !bytes.Equal(got, want) { + t.Errorf("Non-zero bytes in written part of mapping: %v", got) + } + }) + }) + } +} + +func TestZeroOutBusError(t *testing.T) { + // Test that ZeroOut returns a BusError when reaching a page that signals + // SIGBUS. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting write %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { + withBusErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + n, err := ZeroOut(dst, pageSize) + if n != uintptr(bytesBeforeFault) { + t.Errorf("Unexpected write length: got %v, want %v", n, bytesBeforeFault) + } + if want := (BusError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := mapping[pageSize-bytesBeforeFault:pageSize], make([]byte, bytesBeforeFault); !bytes.Equal(got, want) { + t.Errorf("Non-zero bytes in written part of mapping: %v", got) + } + }) + }) + } +} + +func TestSwapUint32SegvError(t *testing.T) { + // Test that SwapUint32 returns a SegvError when reaching a page that + // signals SIGSEGV. + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + _, err := SwapUint32(unsafe.Pointer(secondPage), 1) + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + }) +} + +func TestSwapUint32BusError(t *testing.T) { + // Test that SwapUint32 returns a BusError when reaching a page that + // signals SIGBUS. + withBusErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + _, err := SwapUint32(unsafe.Pointer(secondPage), 1) + if want := (BusError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + }) +} + +func TestSwapUint64SegvError(t *testing.T) { + // Test that SwapUint64 returns a SegvError when reaching a page that + // signals SIGSEGV. + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + _, err := SwapUint64(unsafe.Pointer(secondPage), 1) + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + }) +} + +func TestSwapUint64BusError(t *testing.T) { + // Test that SwapUint64 returns a BusError when reaching a page that + // signals SIGBUS. + withBusErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + _, err := SwapUint64(unsafe.Pointer(secondPage), 1) + if want := (BusError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + }) +} + +func TestCompareAndSwapUint32SegvError(t *testing.T) { + // Test that CompareAndSwapUint32 returns a SegvError when reaching a page + // that signals SIGSEGV. + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + _, err := CompareAndSwapUint32(unsafe.Pointer(secondPage), 0, 1) + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + }) +} + +func TestCompareAndSwapUint32BusError(t *testing.T) { + // Test that CompareAndSwapUint32 returns a BusError when reaching a page + // that signals SIGBUS. + withBusErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + _, err := CompareAndSwapUint32(unsafe.Pointer(secondPage), 0, 1) + if want := (BusError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + }) +} + +func testCopy(dst, src []byte) (panicked bool) { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + debug.SetPanicOnFault(true) + copy(dst, src) + return +} + +func TestSegVOnMemmove(t *testing.T) { + // Test that SIGSEGVs received by runtime.memmove when *not* doing + // CopyIn or CopyOut work gets propagated to the runtime. + const bufLen = pageSize + a, err := syscall.Mmap(-1, 0, bufLen, syscall.PROT_NONE, syscall.MAP_ANON|syscall.MAP_PRIVATE) + if err != nil { + t.Fatalf("Mmap failed: %v", err) + + } + defer syscall.Munmap(a) + b := randBuf(bufLen) + + if !testCopy(b, a) { + t.Fatalf("testCopy didn't panic when it should have") + } + + if !testCopy(a, b) { + t.Fatalf("testCopy didn't panic when it should have") + } +} + +func TestSigbusOnMemmove(t *testing.T) { + // Test that SIGBUS received by runtime.memmove when *not* doing + // CopyIn or CopyOut work gets propagated to the runtime. + const bufLen = pageSize + f, err := ioutil.TempFile("", "sigbus_test") + if err != nil { + t.Fatalf("TempFile failed: %v", err) + } + os.Remove(f.Name()) + defer f.Close() + + a, err := syscall.Mmap(int(f.Fd()), 0, bufLen, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED) + if err != nil { + t.Fatalf("Mmap failed: %v", err) + + } + defer syscall.Munmap(a) + b := randBuf(bufLen) + + if !testCopy(b, a) { + t.Fatalf("testCopy didn't panic when it should have") + } + + if !testCopy(a, b) { + t.Fatalf("testCopy didn't panic when it should have") + } +} diff --git a/pkg/sentry/platform/safecopy/safecopy_unsafe.go b/pkg/sentry/platform/safecopy/safecopy_unsafe.go new file mode 100644 index 000000000..72f243f8d --- /dev/null +++ b/pkg/sentry/platform/safecopy/safecopy_unsafe.go @@ -0,0 +1,315 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package safecopy + +import ( + "fmt" + "syscall" + "unsafe" +) + +// maxRegisterSize is the maximum register size used in memcpy and memclr. It +// is used to decide by how much to rewind the copy (for memcpy) or zeroing +// (for memclr) before proceeding. +const maxRegisterSize = 16 + +// memcpy copies data from src to dst. If a SIGSEGV or SIGBUS signal is received +// during the copy, it returns the address that caused the fault and the number +// of the signal that was received. Otherwise, it returns an unspecified address +// and a signal number of 0. +// +// Data is copied in order, such that if a fault happens at address p, it is +// safe to assume that all data before p-maxRegisterSize has already been +// successfully copied. +// +//go:noescape +func memcpy(dst, src unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32) + +// memclr sets the n bytes following ptr to zeroes. If a SIGSEGV or SIGBUS +// signal is received during the write, it returns the address that caused the +// fault and the number of the signal that was received. Otherwise, it returns +// an unspecified address and a signal number of 0. +// +// Data is written in order, such that if a fault happens at address p, it is +// safe to assume that all data before p-maxRegisterSize has already been +// successfully written. +// +//go:noescape +func memclr(ptr unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32) + +// swapUint32 atomically stores new into *ptr and returns (the previous *ptr +// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the +// value of old is unspecified, and sig is the number of the signal that was +// received. +// +// Preconditions: ptr must be aligned to a 4-byte boundary. +// +//go:noescape +func swapUint32(ptr unsafe.Pointer, new uint32) (old uint32, sig int32) + +// swapUint64 atomically stores new into *ptr and returns (the previous *ptr +// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the +// value of old is unspecified, and sig is the number of the signal that was +// received. +// +// Preconditions: ptr must be aligned to a 8-byte boundary. +// +//go:noescape +func swapUint64(ptr unsafe.Pointer, new uint64) (old uint64, sig int32) + +// compareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns +// (the value previously stored at ptr, 0). If a SIGSEGV or SIGBUS signal is +// received during the operation, the value of prev is unspecified, and sig is +// the number of the signal that was received. +// +// Preconditions: ptr must be aligned to a 4-byte boundary. +// +//go:noescape +func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig int32) + +// CopyIn copies len(dst) bytes from src to dst. It returns the number of bytes +// copied and an error if SIGSEGV or SIGBUS is received while reading from src. +func CopyIn(dst []byte, src unsafe.Pointer) (int, error) { + toCopy := uintptr(len(dst)) + if len(dst) == 0 { + return 0, nil + } + + fault, sig := memcpy(unsafe.Pointer(&dst[0]), src, toCopy) + if sig == 0 { + return len(dst), nil + } + + if faultN, srcN := uintptr(fault), uintptr(src); faultN < srcN && faultN >= srcN+toCopy { + panic(fmt.Sprintf("CopyIn faulted at %#x, which is outside source [%#x, %#x)", faultN, srcN, srcN+toCopy)) + } + + // memcpy might have ended the copy up to maxRegisterSize bytes before + // fault, if an instruction caused a memory access that straddled two + // pages, and the second one faulted. Try to copy up to the fault. + faultN, srcN := uintptr(fault), uintptr(src) + var done int + if faultN-srcN > maxRegisterSize { + done = int(faultN - srcN - maxRegisterSize) + } + n, err := CopyIn(dst[done:int(faultN-srcN)], unsafe.Pointer(srcN+uintptr(done))) + done += n + if err != nil { + return done, err + } + return done, errorFromFaultSignal(fault, sig) +} + +// CopyOut copies len(src) bytes from src to dst. If returns the number of +// bytes done and an error if SIGSEGV or SIGBUS is received while writing to +// dst. +func CopyOut(dst unsafe.Pointer, src []byte) (int, error) { + toCopy := uintptr(len(src)) + if toCopy == 0 { + return 0, nil + } + + fault, sig := memcpy(dst, unsafe.Pointer(&src[0]), toCopy) + if sig == 0 { + return len(src), nil + } + + if faultN, dstN := uintptr(fault), uintptr(dst); faultN < dstN && faultN >= dstN+toCopy { + panic(fmt.Sprintf("CopyOut faulted at %#x, which is outside destination [%#x, %#x)", faultN, dstN, dstN+toCopy)) + } + + // memcpy might have ended the copy up to maxRegisterSize bytes before + // fault, if an instruction caused a memory access that straddled two + // pages, and the second one faulted. Try to copy up to the fault. + faultN, dstN := uintptr(fault), uintptr(dst) + var done int + if faultN-dstN > maxRegisterSize { + done = int(faultN - dstN - maxRegisterSize) + } + n, err := CopyOut(unsafe.Pointer(dstN+uintptr(done)), src[done:int(faultN-dstN)]) + done += n + if err != nil { + return done, err + } + return done, errorFromFaultSignal(fault, sig) +} + +// Copy copies toCopy bytes from src to dst. It returns the number of bytes +// copied and an error if SIGSEGV or SIGBUS is received while reading from src +// or writing to dst. +// +// Data is copied in order; if [src, src+toCopy) and [dst, dst+toCopy) overlap, +// the resulting contents of dst are unspecified. +func Copy(dst, src unsafe.Pointer, toCopy uintptr) (uintptr, error) { + if toCopy == 0 { + return 0, nil + } + + fault, sig := memcpy(dst, src, toCopy) + if sig == 0 { + return toCopy, nil + } + + // Did the fault occur while reading from src or writing to dst? + faultN, srcN, dstN := uintptr(fault), uintptr(src), uintptr(dst) + faultAfterSrc := ^uintptr(0) + if faultN >= srcN { + faultAfterSrc = faultN - srcN + } + faultAfterDst := ^uintptr(0) + if faultN >= dstN { + faultAfterDst = faultN - dstN + } + if faultAfterSrc >= toCopy && faultAfterDst >= toCopy { + panic(fmt.Sprintf("Copy faulted at %#x, which is outside source [%#x, %#x) and destination [%#x, %#x)", faultN, srcN, srcN+toCopy, dstN, dstN+toCopy)) + } + faultedAfter := faultAfterSrc + if faultedAfter > faultAfterDst { + faultedAfter = faultAfterDst + } + + // memcpy might have ended the copy up to maxRegisterSize bytes before + // fault, if an instruction caused a memory access that straddled two + // pages, and the second one faulted. Try to copy up to the fault. + var done uintptr + if faultedAfter > maxRegisterSize { + done = faultedAfter - maxRegisterSize + } + n, err := Copy(unsafe.Pointer(dstN+done), unsafe.Pointer(srcN+done), faultedAfter-done) + done += n + if err != nil { + return done, err + } + return done, errorFromFaultSignal(fault, sig) +} + +// ZeroOut writes toZero zero bytes to dst. It returns the number of bytes +// written and an error if SIGSEGV or SIGBUS is received while writing to dst. +func ZeroOut(dst unsafe.Pointer, toZero uintptr) (uintptr, error) { + if toZero == 0 { + return 0, nil + } + + fault, sig := memclr(dst, toZero) + if sig == 0 { + return toZero, nil + } + + if faultN, dstN := uintptr(fault), uintptr(dst); faultN < dstN && faultN >= dstN+toZero { + panic(fmt.Sprintf("ZeroOut faulted at %#x, which is outside destination [%#x, %#x)", faultN, dstN, dstN+toZero)) + } + + // memclr might have ended the write up to maxRegisterSize bytes before + // fault, if an instruction caused a memory access that straddled two + // pages, and the second one faulted. Try to write up to the fault. + faultN, dstN := uintptr(fault), uintptr(dst) + var done uintptr + if faultN-dstN > maxRegisterSize { + done = faultN - dstN - maxRegisterSize + } + n, err := ZeroOut(unsafe.Pointer(dstN+done), faultN-dstN-done) + done += n + if err != nil { + return done, err + } + return done, errorFromFaultSignal(fault, sig) +} + +// SwapUint32 is equivalent to sync/atomic.SwapUint32, except that it returns +// an error if SIGSEGV or SIGBUS is received while accessing ptr, or if ptr is +// not aligned to a 4-byte boundary. +func SwapUint32(ptr unsafe.Pointer, new uint32) (uint32, error) { + if addr := uintptr(ptr); addr&3 != 0 { + return 0, AlignmentError{addr, 4} + } + old, sig := swapUint32(ptr, new) + return old, errorFromFaultSignal(ptr, sig) +} + +// SwapUint64 is equivalent to sync/atomic.SwapUint64, except that it returns +// an error if SIGSEGV or SIGBUS is received while accessing ptr, or if ptr is +// not aligned to an 8-byte boundary. +func SwapUint64(ptr unsafe.Pointer, new uint64) (uint64, error) { + if addr := uintptr(ptr); addr&7 != 0 { + return 0, AlignmentError{addr, 8} + } + old, sig := swapUint64(ptr, new) + return old, errorFromFaultSignal(ptr, sig) +} + +// CompareAndSwapUint32 is equivalent to atomicbitops.CompareAndSwapUint32, +// except that it returns an error if SIGSEGV or SIGBUS is received while +// accessing ptr, or if ptr is not aligned to a 4-byte boundary. +func CompareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (uint32, error) { + if addr := uintptr(ptr); addr&3 != 0 { + return 0, AlignmentError{addr, 4} + } + prev, sig := compareAndSwapUint32(ptr, old, new) + return prev, errorFromFaultSignal(ptr, sig) +} + +func errorFromFaultSignal(addr unsafe.Pointer, sig int32) error { + switch sig { + case 0: + return nil + case int32(syscall.SIGSEGV): + return SegvError{uintptr(addr)} + case int32(syscall.SIGBUS): + return BusError{uintptr(addr)} + default: + panic(fmt.Sprintf("safecopy got unexpected signal %d at address %#x", sig, addr)) + } +} + +// ReplaceSignalHandler replaces the existing signal handler for the provided +// signal with the one that handles faults in safecopy-protected functions. +// +// It stores the value of the previously set handler in previous. +// +// This function will be called on initialization in order to install safecopy +// handlers for appropriate signals. These handlers will call the previous +// handler however, and if this is function is being used externally then the +// same courtesy is expected. +func ReplaceSignalHandler(sig syscall.Signal, handler uintptr, previous *uintptr) error { + var sa struct { + handler uintptr + flags uint64 + restorer uintptr + mask uint64 + } + const maskLen = 8 + + // Get the existing signal handler information, and save the current + // handler. Once we replace it, we will use this pointer to fall back to + // it when we receive other signals. + if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(sig), 0, uintptr(unsafe.Pointer(&sa)), maskLen, 0, 0); e != 0 { + return e + } + + // Fail if there isn't a previous handler. + if sa.handler == 0 { + return fmt.Errorf("previous handler for signal %x isn't set", sig) + } + + *previous = sa.handler + + // Install our own handler. + sa.handler = handler + if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 { + return e + } + + return nil +} diff --git a/pkg/sentry/platform/safecopy/sighandler_amd64.s b/pkg/sentry/platform/safecopy/sighandler_amd64.s new file mode 100644 index 000000000..a65cb0c26 --- /dev/null +++ b/pkg/sentry/platform/safecopy/sighandler_amd64.s @@ -0,0 +1,124 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// The signals handled by sigHandler. +#define SIGBUS 7 +#define SIGSEGV 11 + +// Offsets to the registers in context->uc_mcontext.gregs[]. +#define REG_RDI 0x68 +#define REG_RAX 0x90 +#define REG_IP 0xa8 + +// Offset to the si_addr field of siginfo. +#define SI_CODE 0x08 +#define SI_ADDR 0x10 + +// signalHandler is the signal handler for SIGSEGV and SIGBUS signals. It must +// not be set up as a handler to any other signals. +// +// If the instruction causing the signal is within a safecopy-protected +// function, the signal is handled such that execution resumes in the +// appropriate fault handling stub with AX containing the faulting address and +// DI containing the signal number. Otherwise control is transferred to the +// previously configured signal handler (savedSigSegvHandler or +// savedSigBusHandler). +// +// This function cannot be written in go because it runs whenever a signal is +// received by the thread (preempting whatever was running), which includes when +// garbage collector has stopped or isn't expecting any interactions (like +// barriers). +// +// The arguments are the following: +// DI - The signal number. +// SI - Pointer to siginfo_t structure. +// DX - Pointer to ucontext structure. +TEXT ·signalHandler(SB),NOSPLIT,$0 + // Check if the signal is from the kernel. + MOVQ $0x0, CX + CMPL CX, SI_CODE(SI) + JGE original_handler + + // Check if RIP is within the area we care about. + MOVQ REG_IP(DX), CX + CMPQ CX, ·memcpyBegin(SB) + JB not_memcpy + CMPQ CX, ·memcpyEnd(SB) + JAE not_memcpy + + // Modify the context such that execution will resume in the fault + // handler. + LEAQ handleMemcpyFault(SB), CX + JMP handle_fault + +not_memcpy: + CMPQ CX, ·memclrBegin(SB) + JB not_memclr + CMPQ CX, ·memclrEnd(SB) + JAE not_memclr + + LEAQ handleMemclrFault(SB), CX + JMP handle_fault + +not_memclr: + CMPQ CX, ·swapUint32Begin(SB) + JB not_swapuint32 + CMPQ CX, ·swapUint32End(SB) + JAE not_swapuint32 + + LEAQ handleSwapUint32Fault(SB), CX + JMP handle_fault + +not_swapuint32: + CMPQ CX, ·swapUint64Begin(SB) + JB not_swapuint64 + CMPQ CX, ·swapUint64End(SB) + JAE not_swapuint64 + + LEAQ handleSwapUint64Fault(SB), CX + JMP handle_fault + +not_swapuint64: + CMPQ CX, ·compareAndSwapUint32Begin(SB) + JB not_casuint32 + CMPQ CX, ·compareAndSwapUint32End(SB) + JAE not_casuint32 + + LEAQ handleCompareAndSwapUint32Fault(SB), CX + JMP handle_fault + +not_casuint32: +original_handler: + // Jump to the previous signal handler, which is likely the golang one. + XORQ CX, CX + MOVQ ·savedSigBusHandler(SB), AX + CMPL DI, $SIGSEGV + CMOVQEQ ·savedSigSegVHandler(SB), AX + JMP AX + +handle_fault: + // Entered with the address of the fault handler in RCX; store it in + // RIP. + MOVQ CX, REG_IP(DX) + + // Store the faulting address in RAX. + MOVQ SI_ADDR(SI), CX + MOVQ CX, REG_RAX(DX) + + // Store the signal number in EDI. + MOVL DI, REG_RDI(DX) + + RET |