Ensure we can Mutex::Wait on a recursively locked Mutex
Change-Id: I4cec5575976892933beebf3c8c01ebf2b0ccde23
diff --git a/src/mutex.cc b/src/mutex.cc
index 8fea616..c5551bd 100644
--- a/src/mutex.cc
+++ b/src/mutex.cc
@@ -190,7 +190,9 @@
void ConditionVariable::Wait(Mutex& mutex) {
CheckSafeToWait(mutex.rank_);
+ uint unlock_depth = UnlockBeforeWait(mutex);
CHECK_MUTEX_CALL(pthread_cond_wait, (&cond_, &mutex.mutex_));
+ RelockAfterWait(mutex, unlock_depth);
}
void ConditionVariable::TimedWait(Mutex& mutex, const timespec& ts) {
@@ -200,11 +202,31 @@
#define TIMEDWAIT pthread_cond_timedwait
#endif
CheckSafeToWait(mutex.rank_);
+ uint unlock_depth = UnlockBeforeWait(mutex);
int rc = TIMEDWAIT(&cond_, &mutex.mutex_, &ts);
+ RelockAfterWait(mutex, unlock_depth);
if (rc != 0 && rc != ETIMEDOUT) {
errno = rc;
PLOG(FATAL) << "TimedWait failed for " << name_;
}
}
+// Unlock a mutex down to depth == 1 so pthread conditional waiting can be used.
+// After waiting, use RelockAfterWait to restore the lock depth.
+uint32_t ConditionVariable::UnlockBeforeWait(Mutex& mutex) {
+ uint32_t unlock_count = 0;
+ CHECK_GT(mutex.GetDepth(), 0U);
+ while (mutex.GetDepth() != 1) {
+ mutex.Unlock();
+ unlock_count++;
+ }
+ return unlock_count;
+}
+
+void ConditionVariable::RelockAfterWait(Mutex& mutex, uint32_t unlock_count) {
+ for (uint32_t i = 0; i < unlock_count; i++) {
+ mutex.Lock();
+ }
+}
+
} // namespace art
diff --git a/src/mutex.h b/src/mutex.h
index 0e7c173..4c5d537 100644
--- a/src/mutex.h
+++ b/src/mutex.h
@@ -99,6 +99,9 @@
void TimedWait(Mutex& mutex, const timespec& ts);
private:
+ uint32_t UnlockBeforeWait(Mutex& mutex) NO_THREAD_SAFETY_ANALYSIS;
+ void RelockAfterWait(Mutex& mutex, uint32_t unlock_count) NO_THREAD_SAFETY_ANALYSIS;
+
pthread_cond_t cond_;
std::string name_;
DISALLOW_COPY_AND_ASSIGN(ConditionVariable);
diff --git a/src/mutex_test.cc b/src/mutex_test.cc
index 4220f2b..69507d1 100644
--- a/src/mutex_test.cc
+++ b/src/mutex_test.cc
@@ -16,10 +16,12 @@
#include "mutex.h"
-#include "gtest/gtest.h"
+#include "common_test.h"
namespace art {
+class MutexTest : public CommonTest {};
+
struct MutexTester {
static void AssertDepth(Mutex& mu, uint32_t expected_depth) {
ASSERT_EQ(expected_depth, mu.GetDepth());
@@ -33,7 +35,7 @@
}
};
-TEST(Mutex, LockUnlock) {
+TEST_F(MutexTest, LockUnlock) {
Mutex mu("test mutex");
MutexTester::AssertDepth(mu, 0U);
mu.Lock();
@@ -52,7 +54,7 @@
MutexTester::AssertDepth(mu, 0U);
}
-TEST(Mutex, TryLockUnlock) {
+TEST_F(MutexTest, TryLockUnlock) {
TryLockUnlockTest();
}
@@ -70,7 +72,7 @@
MutexTester::AssertDepth(mu, 0U);
}
-TEST(Mutex, RecursiveLockUnlock) {
+TEST_F(MutexTest, RecursiveLockUnlock) {
RecursiveLockUnlockTest();
}
@@ -88,8 +90,46 @@
MutexTester::AssertDepth(mu, 0U);
}
-TEST(Mutex, RecursiveTryLockUnlock) {
+TEST_F(MutexTest, RecursiveTryLockUnlock) {
RecursiveTryLockUnlockTest();
}
+
+struct RecursiveLockWait {
+ explicit RecursiveLockWait() : mu("test mutex"), cv("test condition variable") {}
+
+ static void* Callback(void* arg) {
+ RecursiveLockWait* state = reinterpret_cast<RecursiveLockWait*>(arg);
+ state->mu.Lock();
+ state->cv.Signal();
+ state->mu.Unlock();
+ return NULL;
+ }
+
+ Mutex mu;
+ ConditionVariable cv;
+};
+
+// GCC has trouble with our mutex tests, so we have to turn off thread safety analysis.
+static void RecursiveLockWaitTest() NO_THREAD_SAFETY_ANALYSIS {
+ RecursiveLockWait state;
+ state.mu.Lock();
+ state.mu.Lock();
+
+ pthread_t pthread;
+ int pthread_create_result = pthread_create(&pthread, NULL, RecursiveLockWait::Callback, &state);
+ ASSERT_EQ(0, pthread_create_result);
+
+ state.cv.Wait(state.mu);
+
+ state.mu.Unlock();
+ state.mu.Unlock();
+}
+
+// This ensures we don't hang when waiting on a recursively locked mutex,
+// which is not supported with bare pthread_mutex_t.
+TEST_F(MutexTest, RecursiveLockWait) {
+ RecursiveLockWaitTest();
+}
+
} // namespace art