diff options
| -rw-r--r-- | sigchainlib/sigchain.cc | 92 |
1 files changed, 72 insertions, 20 deletions
diff --git a/sigchainlib/sigchain.cc b/sigchainlib/sigchain.cc index 65f5e92acf..a7f73f717f 100644 --- a/sigchainlib/sigchain.cc +++ b/sigchainlib/sigchain.cc @@ -27,6 +27,7 @@ #endif #include <algorithm> +#include <atomic> #include <initializer_list> #include <mutex> #include <type_traits> @@ -151,38 +152,92 @@ __attribute__((constructor)) static void InitializeSignalChain() { }); } -static pthread_key_t GetHandlingSignalKey() { - static pthread_key_t key; +template <typename T> +static constexpr bool IsPowerOfTwo(T x) { + static_assert(std::is_integral_v<T>, "T must be integral"); + static_assert(std::is_unsigned_v<T>, "T must be unsigned"); + return (x & (x - 1)) == 0; +} + +template <typename T> +static constexpr T RoundUp(T x, T n) { + return (x + n - 1) & -n; +} +// Use a bitmap to indicate which signal is being handled so that other +// non-blocked signals are allowed to be handled, if raised. +static constexpr size_t kSignalSetLength = _NSIG - 1; +static constexpr size_t kNumSignalsPerKey = std::numeric_limits<uintptr_t>::digits; +static_assert(IsPowerOfTwo(kNumSignalsPerKey)); +static constexpr size_t kHandlingSignalKeyCount = + RoundUp(kSignalSetLength, kNumSignalsPerKey) / kNumSignalsPerKey; + +// We rely on bionic's implementation of pthread_(get/set)specific being +// async-signal safe. +static pthread_key_t GetHandlingSignalKey(size_t idx) { + static pthread_key_t key[kHandlingSignalKeyCount]; static std::once_flag once; std::call_once(once, []() { - int rc = pthread_key_create(&key, nullptr); - if (rc != 0) { - fatal("failed to create sigchain pthread key: %s", strerror(rc)); + for (size_t i = 0; i < kHandlingSignalKeyCount; i++) { + int rc = pthread_key_create(&key[i], nullptr); + if (rc != 0) { + fatal("failed to create sigchain pthread key: %s", strerror(rc)); + } } }); - return key; + return key[idx]; } static bool GetHandlingSignal() { - void* result = pthread_getspecific(GetHandlingSignalKey()); - return reinterpret_cast<uintptr_t>(result); + for (size_t i = 0; i < kHandlingSignalKeyCount; i++) { + void* result = pthread_getspecific(GetHandlingSignalKey(i)); + if (reinterpret_cast<uintptr_t>(result) != 0) { + return true; + } + } + return false; } -static void SetHandlingSignal(bool value) { - pthread_setspecific(GetHandlingSignalKey(), - reinterpret_cast<void*>(static_cast<uintptr_t>(value))); +static bool GetHandlingSignal(int signo) { + size_t bit_idx = signo - 1; + size_t key_idx = bit_idx / kNumSignalsPerKey; + uintptr_t bit_mask = static_cast<uintptr_t>(1) << (bit_idx % kNumSignalsPerKey); + uintptr_t result = + reinterpret_cast<uintptr_t>(pthread_getspecific(GetHandlingSignalKey(key_idx))); + return result & bit_mask; +} + +static bool SetHandlingSignal(int signo, bool value) { + // Use signal-fence to ensure that compiler doesn't reorder generated code + // across signal handlers. + size_t bit_idx = signo - 1; + size_t key_idx = bit_idx / kNumSignalsPerKey; + uintptr_t bit_mask = static_cast<uintptr_t>(1) << (bit_idx % kNumSignalsPerKey); + pthread_key_t key = GetHandlingSignalKey(key_idx); + std::atomic_signal_fence(std::memory_order_seq_cst); + uintptr_t bitmap = reinterpret_cast<uintptr_t>(pthread_getspecific(key)); + bool ret = bitmap & bit_mask; + if (value) { + bitmap |= bit_mask; + } else { + bitmap &= ~bit_mask; + } + pthread_setspecific(key, reinterpret_cast<void*>(bitmap)); + std::atomic_signal_fence(std::memory_order_seq_cst); + return ret; } class ScopedHandlingSignal { public: - ScopedHandlingSignal() : original_value_(GetHandlingSignal()) { - } + ScopedHandlingSignal(int signo, bool set) + : signo_(signo), + original_value_(set ? SetHandlingSignal(signo, true) : GetHandlingSignal(signo)) {} ~ScopedHandlingSignal() { - SetHandlingSignal(original_value_); + SetHandlingSignal(signo_, original_value_); } private: + int signo_; bool original_value_; }; @@ -338,7 +393,7 @@ class SignalChain { // _NSIG is 1 greater than the highest valued signal, but signals start from 1. // Leave an empty element at index 0 for convenience. -static SignalChain chains[_NSIG + 1]; +static SignalChain chains[_NSIG]; static bool is_signal_hook_debuggable = false; @@ -351,7 +406,7 @@ __attribute__((weak)) extern "C" bool android_handle_signal(int signal_number, void SignalChain::Handler(int signo, siginfo_t* siginfo, void* ucontext_raw) { // Try the special handlers first. // If one of them crashes, we'll reenter this handler and pass that crash onto the user handler. - if (!GetHandlingSignal()) { + if (!GetHandlingSignal(signo)) { for (const auto& handler : chains[signo].special_handlers_) { if (handler.sc_sigaction == nullptr) { break; @@ -364,10 +419,7 @@ void SignalChain::Handler(int signo, siginfo_t* siginfo, void* ucontext_raw) { sigset_t previous_mask; linked_sigprocmask(SIG_SETMASK, &handler.sc_mask, &previous_mask); - ScopedHandlingSignal restorer; - if (!handler_noreturn) { - SetHandlingSignal(true); - } + ScopedHandlingSignal restorer(signo, !handler_noreturn); if (handler.sc_sigaction(signo, siginfo, ucontext_raw)) { return; |