summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sigchainlib/sigchain.cc92
1 files changed, 72 insertions, 20 deletions
diff --git a/sigchainlib/sigchain.cc b/sigchainlib/sigchain.cc
index 65f5e92acf..a7f73f717f 100644
--- a/sigchainlib/sigchain.cc
+++ b/sigchainlib/sigchain.cc
@@ -27,6 +27,7 @@
#endif
#include <algorithm>
+#include <atomic>
#include <initializer_list>
#include <mutex>
#include <type_traits>
@@ -151,38 +152,92 @@ __attribute__((constructor)) static void InitializeSignalChain() {
});
}
-static pthread_key_t GetHandlingSignalKey() {
- static pthread_key_t key;
+template <typename T>
+static constexpr bool IsPowerOfTwo(T x) {
+ static_assert(std::is_integral_v<T>, "T must be integral");
+ static_assert(std::is_unsigned_v<T>, "T must be unsigned");
+ return (x & (x - 1)) == 0;
+}
+
+template <typename T>
+static constexpr T RoundUp(T x, T n) {
+ return (x + n - 1) & -n;
+}
+// Use a bitmap to indicate which signal is being handled so that other
+// non-blocked signals are allowed to be handled, if raised.
+static constexpr size_t kSignalSetLength = _NSIG - 1;
+static constexpr size_t kNumSignalsPerKey = std::numeric_limits<uintptr_t>::digits;
+static_assert(IsPowerOfTwo(kNumSignalsPerKey));
+static constexpr size_t kHandlingSignalKeyCount =
+ RoundUp(kSignalSetLength, kNumSignalsPerKey) / kNumSignalsPerKey;
+
+// We rely on bionic's implementation of pthread_(get/set)specific being
+// async-signal safe.
+static pthread_key_t GetHandlingSignalKey(size_t idx) {
+ static pthread_key_t key[kHandlingSignalKeyCount];
static std::once_flag once;
std::call_once(once, []() {
- int rc = pthread_key_create(&key, nullptr);
- if (rc != 0) {
- fatal("failed to create sigchain pthread key: %s", strerror(rc));
+ for (size_t i = 0; i < kHandlingSignalKeyCount; i++) {
+ int rc = pthread_key_create(&key[i], nullptr);
+ if (rc != 0) {
+ fatal("failed to create sigchain pthread key: %s", strerror(rc));
+ }
}
});
- return key;
+ return key[idx];
}
static bool GetHandlingSignal() {
- void* result = pthread_getspecific(GetHandlingSignalKey());
- return reinterpret_cast<uintptr_t>(result);
+ for (size_t i = 0; i < kHandlingSignalKeyCount; i++) {
+ void* result = pthread_getspecific(GetHandlingSignalKey(i));
+ if (reinterpret_cast<uintptr_t>(result) != 0) {
+ return true;
+ }
+ }
+ return false;
}
-static void SetHandlingSignal(bool value) {
- pthread_setspecific(GetHandlingSignalKey(),
- reinterpret_cast<void*>(static_cast<uintptr_t>(value)));
+static bool GetHandlingSignal(int signo) {
+ size_t bit_idx = signo - 1;
+ size_t key_idx = bit_idx / kNumSignalsPerKey;
+ uintptr_t bit_mask = static_cast<uintptr_t>(1) << (bit_idx % kNumSignalsPerKey);
+ uintptr_t result =
+ reinterpret_cast<uintptr_t>(pthread_getspecific(GetHandlingSignalKey(key_idx)));
+ return result & bit_mask;
+}
+
+static bool SetHandlingSignal(int signo, bool value) {
+ // Use signal-fence to ensure that compiler doesn't reorder generated code
+ // across signal handlers.
+ size_t bit_idx = signo - 1;
+ size_t key_idx = bit_idx / kNumSignalsPerKey;
+ uintptr_t bit_mask = static_cast<uintptr_t>(1) << (bit_idx % kNumSignalsPerKey);
+ pthread_key_t key = GetHandlingSignalKey(key_idx);
+ std::atomic_signal_fence(std::memory_order_seq_cst);
+ uintptr_t bitmap = reinterpret_cast<uintptr_t>(pthread_getspecific(key));
+ bool ret = bitmap & bit_mask;
+ if (value) {
+ bitmap |= bit_mask;
+ } else {
+ bitmap &= ~bit_mask;
+ }
+ pthread_setspecific(key, reinterpret_cast<void*>(bitmap));
+ std::atomic_signal_fence(std::memory_order_seq_cst);
+ return ret;
}
class ScopedHandlingSignal {
public:
- ScopedHandlingSignal() : original_value_(GetHandlingSignal()) {
- }
+ ScopedHandlingSignal(int signo, bool set)
+ : signo_(signo),
+ original_value_(set ? SetHandlingSignal(signo, true) : GetHandlingSignal(signo)) {}
~ScopedHandlingSignal() {
- SetHandlingSignal(original_value_);
+ SetHandlingSignal(signo_, original_value_);
}
private:
+ int signo_;
bool original_value_;
};
@@ -338,7 +393,7 @@ class SignalChain {
// _NSIG is 1 greater than the highest valued signal, but signals start from 1.
// Leave an empty element at index 0 for convenience.
-static SignalChain chains[_NSIG + 1];
+static SignalChain chains[_NSIG];
static bool is_signal_hook_debuggable = false;
@@ -351,7 +406,7 @@ __attribute__((weak)) extern "C" bool android_handle_signal(int signal_number,
void SignalChain::Handler(int signo, siginfo_t* siginfo, void* ucontext_raw) {
// Try the special handlers first.
// If one of them crashes, we'll reenter this handler and pass that crash onto the user handler.
- if (!GetHandlingSignal()) {
+ if (!GetHandlingSignal(signo)) {
for (const auto& handler : chains[signo].special_handlers_) {
if (handler.sc_sigaction == nullptr) {
break;
@@ -364,10 +419,7 @@ void SignalChain::Handler(int signo, siginfo_t* siginfo, void* ucontext_raw) {
sigset_t previous_mask;
linked_sigprocmask(SIG_SETMASK, &handler.sc_mask, &previous_mask);
- ScopedHandlingSignal restorer;
- if (!handler_noreturn) {
- SetHandlingSignal(true);
- }
+ ScopedHandlingSignal restorer(signo, !handler_noreturn);
if (handler.sc_sigaction(signo, siginfo, ucontext_raw)) {
return;