Rewrite thread CustomTLS to support keyed TLS entries

The Thread struct exposed a single custom_tls_ void* pointer that
plugins or other parts of the runtime could use to store additional
per-thread data. Unfortunately this limited it to a single value. This
CL changes the API to take a const char* key that is used to address
the TLS data, which lets multiple plugins or parts of the runtime set
their own independent TLS entries. It also adds support for
deallocating TLS entries automatically: the thread takes ownership of
each entry and destroys it when the thread exits or when the entry is
replaced.

Test: ./test.py --host
Change-Id: I40fa767b9c61a755b2ed910e4ad1e6327705e941
diff --git a/libartbase/base/safe_map.h b/libartbase/base/safe_map.h
index e08394e..a4d8459 100644
--- a/libartbase/base/safe_map.h
+++ b/libartbase/base/safe_map.h
@@ -129,7 +129,7 @@
   }
 
   template <typename CreateFn>
-  V GetOrCreate(const K& k, CreateFn create) {
+  V& GetOrCreate(const K& k, CreateFn create) {
     static_assert(std::is_same<V, typename std::result_of<CreateFn()>::type>::value,
                   "Argument `create` should return a value of type V.");
     auto lb = lower_bound(k);
diff --git a/openjdkjvmti/ti_thread.cc b/openjdkjvmti/ti_thread.cc
index cabf9e8..369eb76 100644
--- a/openjdkjvmti/ti_thread.cc
+++ b/openjdkjvmti/ti_thread.cc
@@ -60,6 +60,8 @@
 
 namespace openjdkjvmti {
 
+static const char* kJvmtiTlsKey = "JvmtiTlsKey";
+
 art::ArtField* ThreadUtil::context_class_loader_ = nullptr;
 
 struct ThreadCallback : public art::ThreadLifecycleCallback {
@@ -624,14 +626,15 @@
 // The struct that we store in the art::Thread::custom_tls_ that maps the jvmtiEnvs to the data
 // stored with that thread. This is needed since different jvmtiEnvs are not supposed to share TLS
 // data but we only have a single slot in Thread objects to store data.
-struct JvmtiGlobalTLSData {
+struct JvmtiGlobalTLSData : public art::TLSData {
   std::unordered_map<jvmtiEnv*, const void*> data GUARDED_BY(art::Locks::thread_list_lock_);
 };
 
 static void RemoveTLSData(art::Thread* target, void* ctx) REQUIRES(art::Locks::thread_list_lock_) {
   jvmtiEnv* env = reinterpret_cast<jvmtiEnv*>(ctx);
   art::Locks::thread_list_lock_->AssertHeld(art::Thread::Current());
-  JvmtiGlobalTLSData* global_tls = reinterpret_cast<JvmtiGlobalTLSData*>(target->GetCustomTLS());
+  JvmtiGlobalTLSData* global_tls =
+      reinterpret_cast<JvmtiGlobalTLSData*>(target->GetCustomTLS(kJvmtiTlsKey));
   if (global_tls != nullptr) {
     global_tls->data.erase(env);
   }
@@ -654,10 +657,12 @@
     return err;
   }
 
-  JvmtiGlobalTLSData* global_tls = reinterpret_cast<JvmtiGlobalTLSData*>(target->GetCustomTLS());
+  JvmtiGlobalTLSData* global_tls =
+      reinterpret_cast<JvmtiGlobalTLSData*>(target->GetCustomTLS(kJvmtiTlsKey));
   if (global_tls == nullptr) {
-    target->SetCustomTLS(new JvmtiGlobalTLSData);
-    global_tls = reinterpret_cast<JvmtiGlobalTLSData*>(target->GetCustomTLS());
+    // Synchronized using thread_list_lock_ to prevent racing sets.
+    target->SetCustomTLS(kJvmtiTlsKey, new JvmtiGlobalTLSData);
+    global_tls = reinterpret_cast<JvmtiGlobalTLSData*>(target->GetCustomTLS(kJvmtiTlsKey));
   }
 
   global_tls->data[env] = data;
@@ -681,7 +686,8 @@
     return err;
   }
 
-  JvmtiGlobalTLSData* global_tls = reinterpret_cast<JvmtiGlobalTLSData*>(target->GetCustomTLS());
+  JvmtiGlobalTLSData* global_tls =
+      reinterpret_cast<JvmtiGlobalTLSData*>(target->GetCustomTLS(kJvmtiTlsKey));
   if (global_tls == nullptr) {
     *data_ptr = nullptr;
     return OK;
diff --git a/runtime/base/mutex.cc b/runtime/base/mutex.cc
index da286d7..dd58d75 100644
--- a/runtime/base/mutex.cc
+++ b/runtime/base/mutex.cc
@@ -42,6 +42,7 @@
 Mutex* Locks::allocated_thread_ids_lock_ = nullptr;
 ReaderWriterMutex* Locks::breakpoint_lock_ = nullptr;
 ReaderWriterMutex* Locks::classlinker_classes_lock_ = nullptr;
+Mutex* Locks::custom_tls_lock_ = nullptr;
 Mutex* Locks::deoptimization_lock_ = nullptr;
 ReaderWriterMutex* Locks::heap_bitmap_lock_ = nullptr;
 Mutex* Locks::instrument_entrypoints_lock_ = nullptr;
@@ -1057,6 +1058,7 @@
     DCHECK(allocated_thread_ids_lock_ != nullptr);
     DCHECK(breakpoint_lock_ != nullptr);
     DCHECK(classlinker_classes_lock_ != nullptr);
+    DCHECK(custom_tls_lock_ != nullptr);
     DCHECK(deoptimization_lock_ != nullptr);
     DCHECK(heap_bitmap_lock_ != nullptr);
     DCHECK(oat_file_manager_lock_ != nullptr);
@@ -1220,6 +1222,10 @@
     DCHECK(jni_function_table_lock_ == nullptr);
     jni_function_table_lock_ = new Mutex("JNI function table lock", current_lock_level);
 
+    UPDATE_CURRENT_LOCK_LEVEL(kCustomTlsLock);
+    DCHECK(custom_tls_lock_ == nullptr);
+    custom_tls_lock_ = new Mutex("Thread::custom_tls_ lock", current_lock_level);
+
     UPDATE_CURRENT_LOCK_LEVEL(kNativeDebugInterfaceLock);
     DCHECK(native_debug_interface_lock_ == nullptr);
     native_debug_interface_lock_ = new Mutex("Native debug interface lock", current_lock_level);
diff --git a/runtime/base/mutex.h b/runtime/base/mutex.h
index ced0cb1..af2e7b2 100644
--- a/runtime/base/mutex.h
+++ b/runtime/base/mutex.h
@@ -78,6 +78,7 @@
   kRosAllocBulkFreeLock,
   kTaggingLockLevel,
   kTransactionLogLock,
+  kCustomTlsLock,
   kJniFunctionTableLock,
   kJniWeakGlobalsLock,
   kJniGlobalsLock,
@@ -738,14 +739,20 @@
   // Guard accesses to the JNI function table override.
   static Mutex* jni_function_table_lock_ ACQUIRED_AFTER(jni_weak_globals_lock_);
 
+  // Guard accesses to the Thread::custom_tls_. We use this to allow the TLS of other threads to be
+  // read (the reader must hold the thread_list_lock_ or have some other way of ensuring the thread
+  // will not die in that case though). This is useful for (e.g.) the implementation of
+  // GetThreadLocalStorage.
+  static Mutex* custom_tls_lock_ ACQUIRED_AFTER(jni_function_table_lock_);
+
   // When declaring any Mutex add BOTTOM_MUTEX_ACQUIRED_AFTER to use annotalysis to check the code
   // doesn't try to acquire a higher level Mutex. NB Due to the way the annotalysis works this
   // actually only encodes the mutex being below jni_function_table_lock_ although having
   // kGenericBottomLock level is lower than this.
-  #define BOTTOM_MUTEX_ACQUIRED_AFTER ACQUIRED_AFTER(art::Locks::jni_function_table_lock_)
+  #define BOTTOM_MUTEX_ACQUIRED_AFTER ACQUIRED_AFTER(art::Locks::custom_tls_lock_)
 
   // Have an exclusive aborting thread.
-  static Mutex* abort_lock_ ACQUIRED_AFTER(jni_function_table_lock_);
+  static Mutex* abort_lock_ ACQUIRED_AFTER(custom_tls_lock_);
 
   // Allow mutual exclusion when manipulating Thread::suspend_count_.
   // TODO: Does the trade-off of a per-thread lock make sense?
diff --git a/runtime/thread.cc b/runtime/thread.cc
index 99a8829..cd6c834 100644
--- a/runtime/thread.cc
+++ b/runtime/thread.cc
@@ -42,6 +42,7 @@
 #include "base/file_utils.h"
 #include "base/memory_tool.h"
 #include "base/mutex.h"
+#include "base/stl_util.h"
 #include "base/systrace.h"
 #include "base/timing_logger.h"
 #include "base/to_str.h"
@@ -393,6 +394,22 @@
   return shadow_frame;
 }
 
+TLSData* Thread::GetCustomTLS(const char* key) {
+  MutexLock mu(Thread::Current(), *Locks::custom_tls_lock_);
+  auto it = custom_tls_.find(key);
+  return (it != custom_tls_.end()) ? it->second.get() : nullptr;
+}
+
+void Thread::SetCustomTLS(const char* key, TLSData* data) {
+  // We will swap the old data (which might be nullptr) with this and then delete it outside of the
+  // custom_tls_lock_.
+  std::unique_ptr<TLSData> old_data(data);
+  {
+    MutexLock mu(Thread::Current(), *Locks::custom_tls_lock_);
+    custom_tls_.GetOrCreate(key, []() { return std::unique_ptr<TLSData>(); }).swap(old_data);
+  }
+}
+
 void Thread::RemoveDebuggerShadowFrameMapping(size_t frame_id) {
   FrameIdToShadowFrame* head = tlsPtr_.frame_id_to_shadow_frame;
   if (head->GetFrameId() == frame_id) {
@@ -2092,7 +2109,6 @@
 Thread::Thread(bool daemon)
     : tls32_(daemon),
       wait_monitor_(nullptr),
-      custom_tls_(nullptr),
       can_call_into_java_(true) {
   wait_mutex_ = new Mutex("a thread wait mutex");
   wait_cond_ = new ConditionVariable("a thread wait condition variable", *wait_mutex_);
diff --git a/runtime/thread.h b/runtime/thread.h
index c8a4b61..edc429b 100644
--- a/runtime/thread.h
+++ b/runtime/thread.h
@@ -33,6 +33,7 @@
 #include "base/globals.h"
 #include "base/macros.h"
 #include "base/mutex.h"
+#include "base/safe_map.h"
 #include "entrypoints/jni/jni_entrypoints.h"
 #include "entrypoints/quick/quick_entrypoints.h"
 #include "handle_scope.h"
@@ -97,6 +98,14 @@
 class ThreadList;
 enum VisitRootFlags : uint8_t;
 
+// A piece of data that can be stored in a Thread's custom TLS. The destructor is called when the
+// thread shuts down or when the entry is replaced via SetCustomTLS; it does not necessarily run on
+// the thread the data was stored on.
+class TLSData {
+ public:
+  virtual ~TLSData() {}
+};
+
 // Thread priorities. These must match the Thread.MIN_PRIORITY,
 // Thread.NORM_PRIORITY, and Thread.MAX_PRIORITY constants.
 enum ThreadPriority {
@@ -1248,13 +1257,14 @@
     return debug_disallow_read_barrier_;
   }
 
-  void* GetCustomTLS() const REQUIRES(Locks::thread_list_lock_) {
-    return custom_tls_;
-  }
+  // Gets the current TLSData associated with the key or nullptr if there isn't any. Note that users
+  // do not gain ownership of TLSData and must synchronize with SetCustomTLS themselves to prevent
+  // it from being deleted.
+  TLSData* GetCustomTLS(const char* key) REQUIRES(!Locks::custom_tls_lock_);
 
-  void SetCustomTLS(void* data) REQUIRES(Locks::thread_list_lock_) {
-    custom_tls_ = data;
-  }
+  // Sets the tls entry at 'key' to data. The thread takes ownership of the TLSData. The destructor
+  // will be run when the thread exits or when SetCustomTLS is called again with the same key.
+  void SetCustomTLS(const char* key, TLSData* data) REQUIRES(!Locks::custom_tls_lock_);
 
   // Returns true if the current thread is the jit sensitive thread.
   bool IsJitSensitiveThread() const {
@@ -1754,9 +1764,9 @@
   // Pending extra checkpoints if checkpoint_function_ is already used.
   std::list<Closure*> checkpoint_overflow_ GUARDED_BY(Locks::thread_suspend_count_lock_);
 
-  // Custom TLS field that can be used by plugins.
-  // TODO: Generalize once we have more plugins.
-  void* custom_tls_;
+  // Custom TLS field that can be used by plugins or the runtime. Should not be accessed directly by
+  // compiled code or entrypoints.
+  SafeMap<std::string, std::unique_ptr<TLSData>> custom_tls_ GUARDED_BY(Locks::custom_tls_lock_);
 
   // True if the thread is allowed to call back into java (for e.g. during class resolution).
   // By default this is true.