// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package safecopy

import (
	"fmt"
	"runtime"
	"unsafe"

	"golang.org/x/sys/unix"
	"gvisor.dev/gvisor/pkg/abi/linux"
)

// maxRegisterSize is the maximum register size, in bytes, used in memcpy and
// memclr. It is used to decide by how much to rewind the copy (for memcpy) or
// zeroing (for memclr) before retrying up to a faulting address.
const maxRegisterSize = 16

// memcpy copies n bytes from src to dst. If a SIGSEGV or SIGBUS signal is
// received during the copy, it returns the address that caused the fault and
// the number of the signal that was received. Otherwise, it returns an
// unspecified address and a signal number of 0.
//
// Data is copied in order, such that if a fault happens at address p, it is
// safe to assume that all data before p-maxRegisterSize has already been
// successfully copied.
//
// This function is implemented in assembly; its parameter and result names
// must stay in sync with that implementation.
//
//go:noescape
func memcpy(dst, src uintptr, n uintptr) (fault uintptr, sig int32)

// memclr sets the n bytes following ptr to zeroes. If a SIGSEGV or SIGBUS
// signal is received during the write, it returns the address that caused the
// fault and the number of the signal that was received. Otherwise, it returns
// an unspecified address and a signal number of 0.
//
// Data is written in order, such that if a fault happens at address p, it is
// safe to assume that all data before p-maxRegisterSize has already been
// successfully written.
//
//go:noescape
func memclr(ptr uintptr, n uintptr) (fault uintptr, sig int32)

// swapUint32 atomically stores new into *ptr and returns (the previous *ptr
// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the
// value of old is unspecified, and sig is the number of the signal that was
// received.
//
// Preconditions: ptr must be aligned to a 4-byte boundary.
//
//go:noescape
func swapUint32(ptr unsafe.Pointer, new uint32) (old uint32, sig int32)

// swapUint64 atomically stores new into *ptr and returns (the previous *ptr
// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the
// value of old is unspecified, and sig is the number of the signal that was
// received.
//
// Preconditions: ptr must be aligned to an 8-byte boundary.
//
//go:noescape
func swapUint64(ptr unsafe.Pointer, new uint64) (old uint64, sig int32)

// compareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns
// (the value previously stored at ptr, 0) instead of a success boolean. If a
// SIGSEGV or SIGBUS signal is received during the operation, the value of
// prev is unspecified, and sig is the number of the signal that was received.
//
// Preconditions: ptr must be aligned to a 4-byte boundary.
//
//go:noescape
func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig int32)

// loadUint32 is like sync/atomic.LoadUint32, but operates with user memory and
// returns (the loaded value, 0). If a SIGSEGV or SIGBUS signal is received
// while reading from ptr, the value of val is unspecified, and sig is the
// number of the signal that was received.
//
// Preconditions: ptr must be aligned to a 4-byte boundary.
//
//go:noescape
func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32)

// addrOfMemcpy, addrOfMemclr, addrOfSwapUint32, addrOfSwapUint64,
// addrOfCompareAndSwapUint32 and addrOfLoadUint32 return the start addresses
// of the corresponding assembly functions above.
//
// In Go 1.17+, Go references to assembly functions resolve to an ABIInternal
// wrapper function rather than the function itself. We must reference from
// assembly to get the ABI0 (i.e., primary) address.
func addrOfMemcpy() uintptr func addrOfMemclr() uintptr func addrOfSwapUint32() uintptr func addrOfSwapUint64() uintptr func addrOfCompareAndSwapUint32() uintptr func addrOfLoadUint32() uintptr // CopyIn copies len(dst) bytes from src to dst. It returns the number of bytes // copied and an error if SIGSEGV or SIGBUS is received while reading from src. func CopyIn(dst []byte, src unsafe.Pointer) (int, error) { n, err := copyIn(dst, uintptr(src)) runtime.KeepAlive(src) return n, err } // copyIn is the underlying definition for CopyIn. func copyIn(dst []byte, src uintptr) (int, error) { toCopy := uintptr(len(dst)) if len(dst) == 0 { return 0, nil } fault, sig := memcpy(uintptr(unsafe.Pointer(&dst[0])), src, toCopy) if sig == 0 { return len(dst), nil } if fault < src || fault >= src+toCopy { panic(fmt.Sprintf("CopyIn raised signal %d at %#x, which is outside source [%#x, %#x)", sig, fault, src, src+toCopy)) } // memcpy might have ended the copy up to maxRegisterSize bytes before // fault, if an instruction caused a memory access that straddled two // pages, and the second one faulted. Try to copy up to the fault. var done int if fault-src > maxRegisterSize { done = int(fault - src - maxRegisterSize) } n, err := copyIn(dst[done:int(fault-src)], src+uintptr(done)) done += n if err != nil { return done, err } return done, errorFromFaultSignal(fault, sig) } // CopyOut copies len(src) bytes from src to dst. If returns the number of // bytes done and an error if SIGSEGV or SIGBUS is received while writing to // dst. func CopyOut(dst unsafe.Pointer, src []byte) (int, error) { n, err := copyOut(uintptr(dst), src) runtime.KeepAlive(dst) return n, err } // copyOut is the underlying definition for CopyOut. 
func copyOut(dst uintptr, src []byte) (int, error) { toCopy := uintptr(len(src)) if toCopy == 0 { return 0, nil } fault, sig := memcpy(dst, uintptr(unsafe.Pointer(&src[0])), toCopy) if sig == 0 { return len(src), nil } if fault < dst || fault >= dst+toCopy { panic(fmt.Sprintf("CopyOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, fault, dst, dst+toCopy)) } // memcpy might have ended the copy up to maxRegisterSize bytes before // fault, if an instruction caused a memory access that straddled two // pages, and the second one faulted. Try to copy up to the fault. var done int if fault-dst > maxRegisterSize { done = int(fault - dst - maxRegisterSize) } n, err := copyOut(dst+uintptr(done), src[done:int(fault-dst)]) done += n if err != nil { return done, err } return done, errorFromFaultSignal(fault, sig) } // Copy copies toCopy bytes from src to dst. It returns the number of bytes // copied and an error if SIGSEGV or SIGBUS is received while reading from src // or writing to dst. // // Data is copied in order; if [src, src+toCopy) and [dst, dst+toCopy) overlap, // the resulting contents of dst are unspecified. func Copy(dst, src unsafe.Pointer, toCopy uintptr) (uintptr, error) { n, err := copyN(uintptr(dst), uintptr(src), toCopy) runtime.KeepAlive(dst) runtime.KeepAlive(src) return n, err } // copyN is the underlying definition for Copy. func copyN(dst, src uintptr, toCopy uintptr) (uintptr, error) { if toCopy == 0 { return 0, nil } fault, sig := memcpy(dst, src, toCopy) if sig == 0 { return toCopy, nil } // Did the fault occur while reading from src or writing to dst? 
faultAfterSrc := ^uintptr(0) if fault >= src { faultAfterSrc = fault - src } faultAfterDst := ^uintptr(0) if fault >= dst { faultAfterDst = fault - dst } if faultAfterSrc >= toCopy && faultAfterDst >= toCopy { panic(fmt.Sprintf("Copy raised signal %d at %#x, which is outside source [%#x, %#x) and destination [%#x, %#x)", sig, fault, src, src+toCopy, dst, dst+toCopy)) } faultedAfter := faultAfterSrc if faultedAfter > faultAfterDst { faultedAfter = faultAfterDst } // memcpy might have ended the copy up to maxRegisterSize bytes before // fault, if an instruction caused a memory access that straddled two // pages, and the second one faulted. Try to copy up to the fault. var done uintptr if faultedAfter > maxRegisterSize { done = faultedAfter - maxRegisterSize } n, err := copyN(dst+done, src+done, faultedAfter-done) done += n if err != nil { return done, err } return done, errorFromFaultSignal(fault, sig) } // ZeroOut writes toZero zero bytes to dst. It returns the number of bytes // written and an error if SIGSEGV or SIGBUS is received while writing to dst. func ZeroOut(dst unsafe.Pointer, toZero uintptr) (uintptr, error) { n, err := zeroOut(uintptr(dst), toZero) runtime.KeepAlive(dst) return n, err } // zeroOut is the underlying definition for ZeroOut. func zeroOut(dst uintptr, toZero uintptr) (uintptr, error) { if toZero == 0 { return 0, nil } fault, sig := memclr(dst, toZero) if sig == 0 { return toZero, nil } if fault < dst || fault >= dst+toZero { panic(fmt.Sprintf("ZeroOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, fault, dst, dst+toZero)) } // memclr might have ended the write up to maxRegisterSize bytes before // fault, if an instruction caused a memory access that straddled two // pages, and the second one faulted. Try to write up to the fault. 
var done uintptr if fault-dst > maxRegisterSize { done = fault - dst - maxRegisterSize } n, err := zeroOut(dst+done, fault-dst-done) done += n if err != nil { return done, err } return done, errorFromFaultSignal(fault, sig) } // SwapUint32 is equivalent to sync/atomic.SwapUint32, except that it returns // an error if SIGSEGV or SIGBUS is received while accessing ptr, or if ptr is // not aligned to a 4-byte boundary. func SwapUint32(ptr unsafe.Pointer, new uint32) (uint32, error) { if addr := uintptr(ptr); addr&3 != 0 { return 0, AlignmentError{addr, 4} } old, sig := swapUint32(ptr, new) return old, errorFromFaultSignal(uintptr(ptr), sig) } // SwapUint64 is equivalent to sync/atomic.SwapUint64, except that it returns // an error if SIGSEGV or SIGBUS is received while accessing ptr, or if ptr is // not aligned to an 8-byte boundary. func SwapUint64(ptr unsafe.Pointer, new uint64) (uint64, error) { if addr := uintptr(ptr); addr&7 != 0 { return 0, AlignmentError{addr, 8} } old, sig := swapUint64(ptr, new) return old, errorFromFaultSignal(uintptr(ptr), sig) } // CompareAndSwapUint32 is equivalent to atomicbitops.CompareAndSwapUint32, // except that it returns an error if SIGSEGV or SIGBUS is received while // accessing ptr, or if ptr is not aligned to a 4-byte boundary. func CompareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (uint32, error) { if addr := uintptr(ptr); addr&3 != 0 { return 0, AlignmentError{addr, 4} } prev, sig := compareAndSwapUint32(ptr, old, new) return prev, errorFromFaultSignal(uintptr(ptr), sig) } // LoadUint32 is like sync/atomic.LoadUint32, but operates with user memory. It // may fail with SIGSEGV or SIGBUS if it is received while reading from ptr. // // Preconditions: ptr must be aligned to a 4-byte boundary. 
func LoadUint32(ptr unsafe.Pointer) (uint32, error) { if addr := uintptr(ptr); addr&3 != 0 { return 0, AlignmentError{addr, 4} } val, sig := loadUint32(ptr) return val, errorFromFaultSignal(uintptr(ptr), sig) } func errorFromFaultSignal(addr uintptr, sig int32) error { switch sig { case 0: return nil case int32(unix.SIGSEGV): return SegvError{addr} case int32(unix.SIGBUS): return BusError{addr} default: panic(fmt.Sprintf("safecopy got unexpected signal %d at address %#x", sig, addr)) } } // ReplaceSignalHandler replaces the existing signal handler for the provided // signal with the one that handles faults in safecopy-protected functions. // // It stores the value of the previously set handler in previous. // // This function will be called on initialization in order to install safecopy // handlers for appropriate signals. These handlers will call the previous // handler however, and if this is function is being used externally then the // same courtesy is expected. func ReplaceSignalHandler(sig unix.Signal, handler uintptr, previous *uintptr) error { var sa linux.SigAction const maskLen = 8 // Get the existing signal handler information, and save the current // handler. Once we replace it, we will use this pointer to fall back to // it when we receive other signals. if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), 0, uintptr(unsafe.Pointer(&sa)), maskLen, 0, 0); e != 0 { return e } // Fail if there isn't a previous handler. if sa.Handler == 0 { return fmt.Errorf("previous handler for signal %x isn't set", sig) } *previous = uintptr(sa.Handler) // Install our own handler. sa.Handler = uint64(handler) if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 { return e } return nil }