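Ensure jvmti agents don't share ThreadLocalStorage

art::Thread has a single custom TLS slot, so Get/SetThreadLocalStorage
previously read and wrote the same value for every jvmtiEnv. Store a map
from jvmtiEnv to data in that slot instead, and add
ThreadUtil::RemoveEnvironment to erase an environment's entries from all
threads.
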
Bug: 63665647
Test: ./test.py --host -j40

Change-Id: Iea33cca5b708f60390b8c79462ca991363ad33a2
diff --git a/runtime/openjdkjvmti/ti_thread.cc b/runtime/openjdkjvmti/ti_thread.cc
index 3d447dc..d1cee9a 100644
--- a/runtime/openjdkjvmti/ti_thread.cc
+++ b/runtime/openjdkjvmti/ti_thread.cc
@@ -159,6 +159,17 @@
   return ERR(NONE);
 }
 
+// Variant of the jthread lookup for callers that already hold thread_list_lock_. As the JVMTI
+// spec requires, a null jthread denotes the calling thread.
+static art::Thread* GetNativeThreadLocked(jthread thread,
+                                          const art::ScopedObjectAccessAlreadyRunnable& soa)
+    REQUIRES_SHARED(art::Locks::mutator_lock_)
+    REQUIRES(art::Locks::thread_list_lock_) {
+  // Callers treat a null result as an invalid or no-longer-alive thread.
+  return (thread == nullptr) ? art::Thread::Current()
+                             : art::Thread::FromManagedThread(soa, thread);
+}
+
 // Get the native thread. The spec says a null object denotes the current thread.
 static art::Thread* GetNativeThread(jthread thread,
                                     const art::ScopedObjectAccessAlreadyRunnable& soa)
@@ -495,40 +506,82 @@
   return ERR(NONE);
 }
 
-jvmtiError ThreadUtil::SetThreadLocalStorage(jvmtiEnv* env ATTRIBUTE_UNUSED,
-                                             jthread thread,
-                                             const void* data) {
-  art::ScopedObjectAccess soa(art::Thread::Current());
-  art::Thread* self = GetNativeThread(thread, soa);
-  if (self == nullptr && thread == nullptr) {
+// Stored in the single art::Thread::custom_tls_ slot; maps each jvmtiEnv to the TLS value it set
+// on that thread, so agents using different jvmtiEnvs never see each other's data. NOTE(review):
+// it is allocated with new in SetThreadLocalStorage; confirm it is freed when the thread exits.
+struct JvmtiGlobalTLSData {
+  std::unordered_map<jvmtiEnv*, const void*> data GUARDED_BY(art::Locks::thread_list_lock_);
+};
+
+static void RemoveTLSData(art::Thread* target, void* ctx) REQUIRES(art::Locks::thread_list_lock_) {
+  // ctx carries the jvmtiEnv* given to ThreadList::ForEach by RemoveEnvironment; drop its entry.
+  art::Locks::thread_list_lock_->AssertHeld(art::Thread::Current());
+  auto* global_tls = static_cast<JvmtiGlobalTLSData*>(target->GetCustomTLS());
+  if (global_tls != nullptr) {
+    global_tls->data.erase(static_cast<jvmtiEnv*>(ctx));
+  }
+}
+
+void ThreadUtil::RemoveEnvironment(jvmtiEnv* env) {
+  // Erase this environment's TLS entry from every thread. RemoveTLSData requires (and asserts)
+  // thread_list_lock_, so hold it for the whole walk over the thread list.
+  art::MutexLock mu(art::Thread::Current(), *art::Locks::thread_list_lock_);
+  art::Runtime::Current()->GetThreadList()->ForEach(RemoveTLSData, env);
+}
+
+jvmtiError ThreadUtil::SetThreadLocalStorage(jvmtiEnv* env, jthread thread, const void* data) {
+  // Stores `data` for `env` only; other jvmtiEnvs keep their own values for this thread.
+  art::ScopedObjectAccess soa(art::Thread::Current());
+  art::MutexLock mu(soa.Self(), *art::Locks::thread_list_lock_);
+  art::Thread* target = GetNativeThreadLocked(thread, soa);
+  if (target == nullptr && thread == nullptr) {
     return ERR(INVALID_THREAD);
   }
-  if (self == nullptr) {
+  if (target == nullptr) {
     return ERR(THREAD_NOT_ALIVE);
   }
 
-  self->SetCustomTLS(data);
+  auto* global_tls = static_cast<JvmtiGlobalTLSData*>(target->GetCustomTLS());
+  if (global_tls == nullptr) {
+    // No map on this thread yet: create one and publish it in the thread's custom TLS slot.
+    global_tls = new JvmtiGlobalTLSData;
+    target->SetCustomTLS(global_tls);
+  }
+  global_tls->data[env] = data;
 
   return ERR(NONE);
 }
 
-jvmtiError ThreadUtil::GetThreadLocalStorage(jvmtiEnv* env ATTRIBUTE_UNUSED,
+jvmtiError ThreadUtil::GetThreadLocalStorage(jvmtiEnv* env,
                                              jthread thread,
                                              void** data_ptr) {
   if (data_ptr == nullptr) {
     return ERR(NULL_POINTER);
   }
 
-  art::ScopedObjectAccess soa(art::Thread::Current());
-  art::Thread* self = GetNativeThread(thread, soa);
-  if (self == nullptr && thread == nullptr) {
+  art::ScopedObjectAccess soa(art::Thread::Current());
+  // GetNativeThreadLocked requires thread_list_lock_, which also guards JvmtiGlobalTLSData::data.
+  art::MutexLock mu(soa.Self(), *art::Locks::thread_list_lock_);
+  art::Thread* target = GetNativeThreadLocked(thread, soa);
+  if (target == nullptr && thread == nullptr) {
     return ERR(INVALID_THREAD);
   }
-  if (self == nullptr) {
+  if (target == nullptr) {
     return ERR(THREAD_NOT_ALIVE);
   }
 
-  *data_ptr = const_cast<void*>(self->GetCustomTLS());
+  // Each env sees only its own entry. With no map on the thread, or no entry for this env, report
+  // null, as JVMTI specifies for storage that was never set with SetThreadLocalStorage.
+  void* result = nullptr;
+  auto* global_tls = static_cast<JvmtiGlobalTLSData*>(target->GetCustomTLS());
+  if (global_tls != nullptr) {
+    auto it = global_tls->data.find(env);
+    if (it != global_tls->data.end()) {
+      result = const_cast<void*>(it->second);
+    }
+  }
+  *data_ptr = result;
+
   return ERR(NONE);
 }