Ensure jvmti agents don't share ThreadLocalStorage

Bug: 63665647
Test: ./test.py --host -j40

Change-Id: Iea33cca5b708f60390b8c79462ca991363ad33a2
diff --git a/runtime/openjdkjvmti/OpenjdkJvmTi.cc b/runtime/openjdkjvmti/OpenjdkJvmTi.cc
index 63892dd..3c1311b 100644
--- a/runtime/openjdkjvmti/OpenjdkJvmTi.cc
+++ b/runtime/openjdkjvmti/OpenjdkJvmTi.cc
@@ -1498,10 +1498,11 @@
 
   static jvmtiError DisposeEnvironment(jvmtiEnv* env) {
     ENSURE_VALID_ENV(env);
-    gEventHandler.RemoveArtJvmTiEnv(ArtJvmTiEnv::AsArtJvmTiEnv(env));
-    art::Runtime::Current()->RemoveSystemWeakHolder(
-        ArtJvmTiEnv::AsArtJvmTiEnv(env)->object_tag_table.get());
-    delete env;
+    ArtJvmTiEnv* art_env = ArtJvmTiEnv::AsArtJvmTiEnv(env);
+    gEventHandler.RemoveArtJvmTiEnv(art_env);
+    art::Runtime::Current()->RemoveSystemWeakHolder(art_env->object_tag_table.get());
+    ThreadUtil::RemoveEnvironment(art_env);  // Drop art_env's entries from every thread's TLS.
+    delete art_env;  // Delete through the concrete type rather than the jvmtiEnv* handle.
     return OK;
   }
 
@@ -1671,6 +1672,7 @@
 }
 
 extern const jvmtiInterface_1 gJvmtiInterface;
+
 ArtJvmTiEnv::ArtJvmTiEnv(art::JavaVMExt* runtime, EventHandler* event_handler)
     : art_vm(runtime),
       local_data(nullptr),
diff --git a/runtime/openjdkjvmti/ti_thread.cc b/runtime/openjdkjvmti/ti_thread.cc
index 3d447dc..d1cee9a 100644
--- a/runtime/openjdkjvmti/ti_thread.cc
+++ b/runtime/openjdkjvmti/ti_thread.cc
@@ -159,6 +159,17 @@
   return ERR(NONE);
 }
 
+static art::Thread* GetNativeThreadLocked(jthread thread,
+                                          const art::ScopedObjectAccessAlreadyRunnable& soa)
+    REQUIRES_SHARED(art::Locks::mutator_lock_)
+    REQUIRES(art::Locks::thread_list_lock_) {
+  // Per the JVMTI spec a null jthread denotes the calling thread; otherwise resolve the peer.
+  if (thread != nullptr) {
+    return art::Thread::FromManagedThread(soa, thread);
+  }
+  return art::Thread::Current();
+}
+
 // Get the native thread. The spec says a null object denotes the current thread.
 static art::Thread* GetNativeThread(jthread thread,
                                     const art::ScopedObjectAccessAlreadyRunnable& soa)
@@ -495,40 +506,82 @@
   return ERR(NONE);
 }
 
-jvmtiError ThreadUtil::SetThreadLocalStorage(jvmtiEnv* env ATTRIBUTE_UNUSED,
-                                             jthread thread,
-                                             const void* data) {
-  art::ScopedObjectAccess soa(art::Thread::Current());
-  art::Thread* self = GetNativeThread(thread, soa);
-  if (self == nullptr && thread == nullptr) {
+// Lives in the single art::Thread custom TLS slot and maps each jvmtiEnv to the TLS value it set
+// on that thread, so separate jvmtiEnvs never see each other's data. Created lazily on first Set;
+// NOTE(review): nothing here deletes it when the thread exits (only per-env entries are erased).
+struct JvmtiGlobalTLSData {
+  std::unordered_map<jvmtiEnv*, const void*> data GUARDED_BY(art::Locks::thread_list_lock_);
+};
+
+static void RemoveTLSData(art::Thread* target, void* ctx) REQUIRES(art::Locks::thread_list_lock_) {
+  // ThreadList::ForEach callback: |ctx| is the jvmtiEnv* whose entry is dropped from |target|.
+  art::Locks::thread_list_lock_->AssertHeld(art::Thread::Current());
+  auto* global_tls = static_cast<JvmtiGlobalTLSData*>(target->GetCustomTLS());
+  if (global_tls != nullptr) {
+    global_tls->data.erase(static_cast<jvmtiEnv*>(ctx));
+  }
+}
+
+void ThreadUtil::RemoveEnvironment(jvmtiEnv* env) {
+  // Drop |env|'s entry from every thread's TLS map; RemoveTLSData needs thread_list_lock_ held.
+  art::Thread* const current = art::Thread::Current();
+  art::MutexLock mu(current, *art::Locks::thread_list_lock_);
+  art::Runtime::Current()->GetThreadList()->ForEach(RemoveTLSData, env);
+}
+
+jvmtiError ThreadUtil::SetThreadLocalStorage(jvmtiEnv* env, jthread thread, const void* data) {
+  art::Thread* self = art::Thread::Current();
+  art::ScopedObjectAccess soa(self);
+  art::MutexLock mu(self, *art::Locks::thread_list_lock_);
+  art::Thread* target = GetNativeThreadLocked(thread, soa);
+  if (target == nullptr && thread == nullptr) {
     return ERR(INVALID_THREAD);
   }
-  if (self == nullptr) {
+  if (target == nullptr) {
     return ERR(THREAD_NOT_ALIVE);
   }
 
-  self->SetCustomTLS(data);
+  // Lazily created on first store; NOTE(review): nothing visible here frees it on thread exit.
+  auto* global_tls = static_cast<JvmtiGlobalTLSData*>(target->GetCustomTLS());
+  if (global_tls == nullptr) {
+    global_tls = new JvmtiGlobalTLSData;
+    target->SetCustomTLS(global_tls);
+  }
+  global_tls->data[env] = data;
 
   return ERR(NONE);
 }
 
-jvmtiError ThreadUtil::GetThreadLocalStorage(jvmtiEnv* env ATTRIBUTE_UNUSED,
+jvmtiError ThreadUtil::GetThreadLocalStorage(jvmtiEnv* env,
                                              jthread thread,
                                              void** data_ptr) {
   if (data_ptr == nullptr) {
     return ERR(NULL_POINTER);
   }
 
-  art::ScopedObjectAccess soa(art::Thread::Current());
-  art::Thread* self = GetNativeThread(thread, soa);
-  if (self == nullptr && thread == nullptr) {
+  art::Thread* self = art::Thread::Current();
+  art::ScopedObjectAccess soa(self);
+  art::MutexLock mu(self, *art::Locks::thread_list_lock_);
+  art::Thread* target = GetNativeThreadLocked(thread, soa);
+  if (target == nullptr && thread == nullptr) {
     return ERR(INVALID_THREAD);
   }
-  if (self == nullptr) {
+  if (target == nullptr) {
     return ERR(THREAD_NOT_ALIVE);
   }
 
-  *data_ptr = const_cast<void*>(self->GetCustomTLS());
+  auto* global_tls = static_cast<JvmtiGlobalTLSData*>(target->GetCustomTLS());
+  if (global_tls == nullptr) {
+    *data_ptr = nullptr;
+    return ERR(NONE);
+  }
+  auto it = global_tls->data.find(env);  // Envs never see each other's TLS values.
+  if (it != global_tls->data.end()) {
+    *data_ptr = const_cast<void*>(it->second);
+  } else {
+    *data_ptr = nullptr;
+  }
+
   return ERR(NONE);
 }
 
diff --git a/runtime/openjdkjvmti/ti_thread.h b/runtime/openjdkjvmti/ti_thread.h
index 57967eb..d07dc06 100644
--- a/runtime/openjdkjvmti/ti_thread.h
+++ b/runtime/openjdkjvmti/ti_thread.h
@@ -54,6 +54,9 @@
   // To be called when it is safe to cache data.
   static void CacheData();
 
+  // Erases |env|'s entry from every thread's TLS map. Called when |env| is disposed.
+  static void RemoveEnvironment(jvmtiEnv* env);
+
   static jvmtiError GetAllThreads(jvmtiEnv* env, jint* threads_count_ptr, jthread** threads_ptr);
 
   static jvmtiError GetCurrentThread(jvmtiEnv* env, jthread* thread_ptr);