Increase use of ScopedJniThreadState.

Move the routines for changing Object* to jobject and vice-versa
(AddLocalReference and Decode) to ScopedJniThreadState to enforce use of
Object*s in the Runnable thread state. In the Runnable thread state
suspension is necessary before GC can take place.

Reduce use of const ClassLoader* as the code bottoms out in FindClass
and with a field assignment where the const is cast away (ie if we're
not going to enforce the const-ness we shouldn't pretend it is).

Refactor the Thread::Attach API so that we're not handling raw Objects on
unattached threads.

Remove some unreachable code.

Change-Id: I0fa969f49ee6a8f10752af74a6b0e04d46b4cd97
diff --git a/src/check_jni.cc b/src/check_jni.cc
index 0fd5f6e..47f20e1 100644
--- a/src/check_jni.cc
+++ b/src/check_jni.cc
@@ -83,11 +83,6 @@
       reinterpret_cast<JNIEnvExt*>(env)->self->SirtContains(localRef);
 }
 
-template<typename T>
-T Decode(ScopedJniThreadState& ts, jobject obj) {
-  return reinterpret_cast<T>(ts.Self()->DecodeJObject(obj));
-}
-
 // Hack to allow forcecopy to work with jniGetNonMovableArrayElements.
 // The code deliberately uses an invalid sequence of operations, so we
 // need to pass it through unmodified.  Review that code before making
@@ -151,14 +146,14 @@
 class ScopedCheck {
  public:
   // For JNIEnv* functions.
-  explicit ScopedCheck(JNIEnv* env, int flags, const char* functionName) {
-    Init(env, reinterpret_cast<JNIEnvExt*>(env)->vm, flags, functionName, true);
+  explicit ScopedCheck(JNIEnv* env, int flags, const char* functionName) : ts_(env) {
+    Init(flags, functionName, true);
     CheckThread(flags);
   }
 
   // For JavaVM* functions.
-  explicit ScopedCheck(JavaVM* vm, bool has_method, const char* functionName) {
-    Init(NULL, vm, kFlag_Invocation, functionName, has_method);
+  explicit ScopedCheck(JavaVM* vm, bool has_method, const char* functionName) : ts_(vm) {
+    Init(kFlag_Invocation, functionName, has_method);
   }
 
   bool ForceCopy() {
@@ -185,7 +180,6 @@
    * Works for both static and instance fields.
    */
   void CheckFieldType(jobject java_object, jfieldID fid, char prim, bool isStatic) {
-    ScopedJniThreadState ts(env_);
     Field* f = CheckFieldID(fid);
     if (f == NULL) {
       return;
@@ -193,7 +187,7 @@
     Class* field_type = FieldHelper(f).GetType();
     if (!field_type->IsPrimitive()) {
       if (java_object != NULL) {
-        Object* obj = Decode<Object*>(ts, java_object);
+        Object* obj = ts_.Decode<Object*>(java_object);
         // If java_object is a weak global ref whose referent has been cleared,
         // obj will be NULL.  Otherwise, obj should always be non-NULL
         // and valid.
@@ -231,9 +225,7 @@
    * Assumes "jobj" has already been validated.
    */
   void CheckInstanceFieldID(jobject java_object, jfieldID fid) {
-    ScopedJniThreadState ts(env_);
-
-    Object* o = Decode<Object*>(ts, java_object);
+    Object* o = ts_.Decode<Object*>(java_object);
     if (o == NULL || !Runtime::Current()->GetHeap()->IsHeapAddress(o)) {
       JniAbortF(function_name_, "field operation on invalid %s: %p",
                 ToStr<IndirectRefKind>(GetIndirectRefKind(java_object)).c_str(), java_object);
@@ -266,7 +258,6 @@
    * 'expectedType' will be "L" for all objects, including arrays.
    */
   void CheckSig(jmethodID mid, const char* expectedType, bool isStatic) {
-    ScopedJniThreadState ts(env_);
     Method* m = CheckMethodID(mid);
     if (m == NULL) {
       return;
@@ -292,8 +283,7 @@
    * Assumes "java_class" has already been validated.
    */
   void CheckStaticFieldID(jclass java_class, jfieldID fid) {
-    ScopedJniThreadState ts(env_);
-    Class* c = Decode<Class*>(ts, java_class);
+    Class* c = ts_.Decode<Class*>(java_class);
     const Field* f = CheckFieldID(fid);
     if (f == NULL) {
       return;
@@ -314,12 +304,11 @@
    * Instances of "java_class" must be instances of the method's declaring class.
    */
   void CheckStaticMethod(jclass java_class, jmethodID mid) {
-    ScopedJniThreadState ts(env_);
     const Method* m = CheckMethodID(mid);
     if (m == NULL) {
       return;
     }
-    Class* c = Decode<Class*>(ts, java_class);
+    Class* c = ts_.Decode<Class*>(java_class);
     if (!c->IsAssignableFrom(m->GetDeclaringClass())) {
       JniAbortF(function_name_, "can't call static %s on class %s",
                 PrettyMethod(m).c_str(), PrettyClass(c).c_str());
@@ -334,12 +323,11 @@
    * will be handled automatically by the instanceof check.)
    */
   void CheckVirtualMethod(jobject java_object, jmethodID mid) {
-    ScopedJniThreadState ts(env_);
     const Method* m = CheckMethodID(mid);
     if (m == NULL) {
       return;
     }
-    Object* o = Decode<Object*>(ts, java_object);
+    Object* o = ts_.Decode<Object*>(java_object);
     if (!o->InstanceOf(m->GetDeclaringClass())) {
       JniAbortF(function_name_, "can't call %s on instance of %s",
                 PrettyMethod(m).c_str(), PrettyTypeOf(o).c_str());
@@ -386,7 +374,7 @@
     va_list ap;
 
     const Method* traceMethod = NULL;
-    if ((!vm_->trace.empty() || VLOG_IS_ON(third_party_jni)) && has_method_) {
+    if ((!ts_.Vm()->trace.empty() || VLOG_IS_ON(third_party_jni)) && has_method_) {
       // We need to guard some of the invocation interface's calls: a bad caller might
       // use DetachCurrentThread or GetEnv on a thread that's not yet attached.
       Thread* self = Thread::Current();
@@ -395,7 +383,7 @@
       }
     }
 
-    if (((flags_ & kFlag_ForceTrace) != 0) || (traceMethod != NULL && ShouldTrace(vm_, traceMethod))) {
+    if (((flags_ & kFlag_ForceTrace) != 0) || (traceMethod != NULL && ShouldTrace(ts_.Vm(), traceMethod))) {
       va_start(ap, fmt0);
       std::string msg;
       for (const char* fmt = fmt0; *fmt;) {
@@ -610,8 +598,7 @@
       return false;
     }
 
-    ScopedJniThreadState ts(env_);
-    Object* obj = Decode<Object*>(ts, java_object);
+    Object* obj = ts_.Decode<Object*>(java_object);
     if (!Runtime::Current()->GetHeap()->IsHeapAddress(obj)) {
       JniAbortF(function_name_, "%s is an invalid %s: %p (%p)",
                 what, ToStr<IndirectRefKind>(GetIndirectRefKind(java_object)).c_str(), java_object, obj);
@@ -647,9 +634,7 @@
   // Set "has_method" to true if we have a valid thread with a method pointer.
   // We won't have one before attaching a thread, after detaching a thread, or
   // when shutting down the runtime.
-  void Init(JNIEnv* env, JavaVM* vm, int flags, const char* functionName, bool has_method) {
-    env_ = reinterpret_cast<JNIEnvExt*>(env);
-    vm_ = reinterpret_cast<JavaVMExt*>(vm);
+  void Init(int flags, const char* functionName, bool has_method) {
     flags_ = flags;
     function_name_ = functionName;
     has_method_ = has_method;
@@ -666,8 +651,7 @@
       return;
     }
 
-    ScopedJniThreadState ts(env_);
-    Array* a = Decode<Array*>(ts, java_array);
+    Array* a = ts_.Decode<Array*>(java_array);
     if (!Runtime::Current()->GetHeap()->IsHeapAddress(a)) {
       JniAbortF(function_name_, "jarray is an invalid %s: %p (%p)",
                 ToStr<IndirectRefKind>(GetIndirectRefKind(java_array)).c_str(), java_array, a);
@@ -687,8 +671,8 @@
       JniAbortF(function_name_, "jfieldID was NULL");
       return NULL;
     }
-    Field* f = DecodeField(fid);
-    if (!Runtime::Current()->GetHeap()->IsHeapAddress(f)) {
+    Field* f = ts_.DecodeField(fid);
+    if (!Runtime::Current()->GetHeap()->IsHeapAddress(f) || !f->IsField()) {
       JniAbortF(function_name_, "invalid jfieldID: %p", fid);
       return NULL;
     }
@@ -700,8 +684,8 @@
       JniAbortF(function_name_, "jmethodID was NULL");
       return NULL;
     }
-    Method* m = DecodeMethod(mid);
-    if (!Runtime::Current()->GetHeap()->IsHeapAddress(m)) {
+    Method* m = ts_.DecodeMethod(mid);
+    if (!Runtime::Current()->GetHeap()->IsHeapAddress(m) || !m->IsMethod()) {
       JniAbortF(function_name_, "invalid jmethodID: %p", mid);
       return NULL;
     }
@@ -719,9 +703,7 @@
       return;
     }
 
-    ScopedJniThreadState ts(env_);
-
-    Object* o = Decode<Object*>(ts, java_object);
+    Object* o = ts_.Decode<Object*>(java_object);
     if (!Runtime::Current()->GetHeap()->IsHeapAddress(o)) {
       // TODO: when we remove work_around_app_jni_bugs, this should be impossible.
       JniAbortF(function_name_, "native code passing in reference to invalid %s: %p",
@@ -751,13 +733,13 @@
 
     // Verify that the current thread is (a) attached and (b) associated with
     // this particular instance of JNIEnv.
-    if (env_ != threadEnv) {
-      if (vm_->work_around_app_jni_bugs) {
+    if (ts_.Env() != threadEnv) {
+      if (ts_.Vm()->work_around_app_jni_bugs) {
         // If we're keeping broken code limping along, we need to suppress the abort...
-        LOG(ERROR) << "APP BUG DETECTED: thread " << *self << " using JNIEnv* from thread " << *env_->self;
+        LOG(ERROR) << "APP BUG DETECTED: thread " << *self << " using JNIEnv* from thread " << *ts_.Self();
       } else {
         JniAbortF(function_name_, "thread %s using JNIEnv* from thread %s",
-                  ToStr<Thread>(*self).c_str(), ToStr<Thread>(*env_->self).c_str());
+                  ToStr<Thread>(*self).c_str(), ToStr<Thread>(*ts_.Self()).c_str());
         return;
       }
     }
@@ -796,7 +778,7 @@
       // TODO: do we care any more? art always dumps pending exceptions on aborting threads.
       if (type != "java.lang.OutOfMemoryError") {
         JniAbortF(function_name_, "JNI %s called with pending exception: %s",
-                  function_name_, type.c_str(), jniGetStackTrace(env_).c_str());
+                  function_name_, type.c_str(), jniGetStackTrace(ts_.Env()).c_str());
       } else {
         JniAbortF(function_name_, "JNI %s called with %s pending", function_name_, type.c_str());
       }
@@ -873,8 +855,7 @@
     return 0;
   }
 
-  JNIEnvExt* env_;
-  JavaVMExt* vm_;
+  const ScopedJniThreadState ts_;
   const char* function_name_;
   int flags_;
   bool has_method_;
@@ -1072,7 +1053,7 @@
 static void* CreateGuardedPACopy(JNIEnv* env, const jarray java_array, jboolean* isCopy) {
   ScopedJniThreadState ts(env);
 
-  Array* a = Decode<Array*>(ts, java_array);
+  Array* a = ts.Decode<Array*>(java_array);
   size_t component_size = a->GetClass()->GetComponentSize();
   size_t byte_count = a->GetLength() * component_size;
   void* result = GuardedCopy::Create(a->GetRawData(component_size), byte_count, true);
@@ -1092,7 +1073,7 @@
   }
 
   ScopedJniThreadState ts(env);
-  Array* a = Decode<Array*>(ts, java_array);
+  Array* a = ts.Decode<Array*>(java_array);
 
   GuardedCopy::Check(__FUNCTION__, dataBuf, true);
 
@@ -1481,7 +1462,7 @@
     const jchar* result = baseEnv(env)->GetStringChars(env, java_string, isCopy);
     if (sc.ForceCopy() && result != NULL) {
       ScopedJniThreadState ts(env);
-      String* s = Decode<String*>(ts, java_string);
+      String* s = ts.Decode<String*>(java_string);
       int byteCount = s->GetLength() * 2;
       result = (const jchar*) GuardedCopy::Create(result, byteCount, false);
       if (isCopy != NULL) {
@@ -1709,7 +1690,7 @@
     const jchar* result = baseEnv(env)->GetStringCritical(env, java_string, isCopy);
     if (sc.ForceCopy() && result != NULL) {
       ScopedJniThreadState ts(env);
-      String* s = Decode<String*>(ts, java_string);
+      String* s = ts.Decode<String*>(java_string);
       int byteCount = s->GetLength() * 2;
       result = (const jchar*) GuardedCopy::Create(result, byteCount, false);
       if (isCopy != NULL) {
diff --git a/src/class_linker.cc b/src/class_linker.cc
index b18b31f..3c0c345 100644
--- a/src/class_linker.cc
+++ b/src/class_linker.cc
@@ -45,6 +45,7 @@
 #if defined(ART_USE_LLVM_COMPILER)
 #include "compiler_llvm/runtime_support_llvm.h"
 #endif
+#include "scoped_jni_thread_state.h"
 #include "ScopedLocalRef.h"
 #include "space.h"
 #include "stack_indirect_reference_table.h"
@@ -1116,7 +1117,7 @@
   return FindClass(descriptor, NULL);
 }
 
-Class* ClassLinker::FindClass(const char* descriptor, const ClassLoader* class_loader) {
+Class* ClassLinker::FindClass(const char* descriptor, ClassLoader* class_loader) {
   DCHECK_NE(*descriptor, '\0') << "descriptor is empty string";
   Thread* self = Thread::Current();
   DCHECK(self != NULL);
@@ -1159,19 +1160,24 @@
     }
 
   } else {
+    ScopedJniThreadState ts(self->GetJniEnv());
+    ScopedLocalRef<jobject> class_loader_object(ts.Env(),
+                                                ts.AddLocalReference<jobject>(class_loader));
     std::string class_name_string(DescriptorToDot(descriptor));
-    ScopedThreadStateChange tsc(self, kNative);
-    JNIEnv* env = self->GetJniEnv();
-    ScopedLocalRef<jobject> class_name_object(env, env->NewStringUTF(class_name_string.c_str()));
-    if (class_name_object.get() == NULL) {
-      return NULL;
+    ScopedLocalRef<jobject> result(ts.Env(), NULL);
+    {
+      ScopedThreadStateChange tsc(self, kNative);
+      ScopedLocalRef<jobject> class_name_object(ts.Env(),
+                                                ts.Env()->NewStringUTF(class_name_string.c_str()));
+      if (class_name_object.get() == NULL) {
+        return NULL;
+      }
+      CHECK(class_loader_object.get() != NULL);
+      result.reset(ts.Env()->CallObjectMethod(class_loader_object.get(),
+                                              WellKnownClasses::java_lang_ClassLoader_loadClass,
+                                              class_name_object.get()));
     }
-    ScopedLocalRef<jobject> class_loader_object(env, AddLocalReference<jobject>(env, class_loader));
-    CHECK(class_loader_object.get() != NULL);
-    ScopedLocalRef<jobject> result(env, env->CallObjectMethod(class_loader_object.get(),
-                                                              WellKnownClasses::java_lang_ClassLoader_loadClass,
-                                                              class_name_object.get()));
-    if (env->ExceptionCheck()) {
+    if (ts.Env()->ExceptionCheck()) {
       // If the ClassLoader threw, pass that exception up.
       return NULL;
     } else if (result.get() == NULL) {
@@ -1181,7 +1187,7 @@
       return NULL;
     } else {
       // success, return Class*
-      return Decode<Class*>(env, result.get());
+      return ts.Decode<Class*>(result.get());
     }
   }
 
@@ -1190,7 +1196,7 @@
 }
 
 Class* ClassLinker::DefineClass(const StringPiece& descriptor,
-                                const ClassLoader* class_loader,
+                                ClassLoader* class_loader,
                                 const DexFile& dex_file,
                                 const DexFile::ClassDef& dex_class_def) {
   SirtRef<Class> klass(NULL);
@@ -1453,7 +1459,7 @@
 void ClassLinker::LoadClass(const DexFile& dex_file,
                             const DexFile::ClassDef& dex_class_def,
                             SirtRef<Class>& klass,
-                            const ClassLoader* class_loader) {
+                            ClassLoader* class_loader) {
   CHECK(klass.get() != NULL);
   CHECK(klass->GetDexCache() != NULL);
   CHECK_EQ(Class::kStatusNotReady, klass->GetStatus());
@@ -1707,7 +1713,7 @@
 // array class; that always comes from the base element class.
 //
 // Returns NULL with an exception raised on failure.
-Class* ClassLinker::CreateArrayClass(const std::string& descriptor, const ClassLoader* class_loader) {
+Class* ClassLinker::CreateArrayClass(const std::string& descriptor, ClassLoader* class_loader) {
   CHECK_EQ('[', descriptor[0]);
 
   // Identify the underlying component type
@@ -2657,7 +2663,7 @@
 
 void ClassLinker::ConstructFieldMap(const DexFile& dex_file, const DexFile::ClassDef& dex_class_def,
                                     Class* c, SafeMap<uint32_t, Field*>& field_map) {
-  const ClassLoader* cl = c->GetClassLoader();
+  ClassLoader* cl = c->GetClassLoader();
   const byte* class_data = dex_file.GetClassData(dex_class_def);
   ClassDataItemIterator it(dex_file, class_data);
   for (size_t i = 0; it.HasNextStaticField(); i++, it.Next()) {
@@ -3342,7 +3348,7 @@
 Class* ClassLinker::ResolveType(const DexFile& dex_file,
                                 uint16_t type_idx,
                                 DexCache* dex_cache,
-                                const ClassLoader* class_loader) {
+                                ClassLoader* class_loader) {
   DCHECK(dex_cache != NULL);
   Class* resolved = dex_cache->GetResolvedType(type_idx);
   if (resolved == NULL) {
@@ -3369,7 +3375,7 @@
 Method* ClassLinker::ResolveMethod(const DexFile& dex_file,
                                    uint32_t method_idx,
                                    DexCache* dex_cache,
-                                   const ClassLoader* class_loader,
+                                   ClassLoader* class_loader,
                                    bool is_direct) {
   DCHECK(dex_cache != NULL);
   Method* resolved = dex_cache->GetResolvedMethod(method_idx);
@@ -3419,7 +3425,7 @@
 Field* ClassLinker::ResolveField(const DexFile& dex_file,
                                  uint32_t field_idx,
                                  DexCache* dex_cache,
-                                 const ClassLoader* class_loader,
+                                 ClassLoader* class_loader,
                                  bool is_static) {
   DCHECK(dex_cache != NULL);
   Field* resolved = dex_cache->GetResolvedField(field_idx);
@@ -3459,7 +3465,7 @@
 Field* ClassLinker::ResolveFieldJLS(const DexFile& dex_file,
                                     uint32_t field_idx,
                                     DexCache* dex_cache,
-                                    const ClassLoader* class_loader) {
+                                    ClassLoader* class_loader) {
   DCHECK(dex_cache != NULL);
   Field* resolved = dex_cache->GetResolvedField(field_idx);
   if (resolved != NULL) {
diff --git a/src/class_linker.h b/src/class_linker.h
index 01c1051..6cf2e14 100644
--- a/src/class_linker.h
+++ b/src/class_linker.h
@@ -54,12 +54,12 @@
 
   // Finds a class by its descriptor, loading it if necessary.
   // If class_loader is null, searches boot_class_path_.
-  Class* FindClass(const char* descriptor, const ClassLoader* class_loader);
+  Class* FindClass(const char* descriptor, ClassLoader* class_loader);
 
   Class* FindSystemClass(const char* descriptor);
 
   // Define a new a class based on a ClassDef from a DexFile
-  Class* DefineClass(const StringPiece& descriptor, const ClassLoader* class_loader,
+  Class* DefineClass(const StringPiece& descriptor, ClassLoader* class_loader,
                      const DexFile& dex_file, const DexFile::ClassDef& dex_class_def);
 
   // Finds a class by its descriptor, returning NULL if it isn't wasn't loaded
@@ -119,7 +119,7 @@
     if (UNLIKELY(resolved_type == NULL)) {
       Class* declaring_class = referrer->GetDeclaringClass();
       DexCache* dex_cache = declaring_class->GetDexCache();
-      const ClassLoader* class_loader = declaring_class->GetClassLoader();
+      ClassLoader* class_loader = declaring_class->GetClassLoader();
       const DexFile& dex_file = FindDexFile(dex_cache);
       resolved_type = ResolveType(dex_file, type_idx, dex_cache, class_loader);
     }
@@ -131,7 +131,7 @@
     DexCache* dex_cache = declaring_class->GetDexCache();
     Class* resolved_type = dex_cache->GetResolvedType(type_idx);
     if (UNLIKELY(resolved_type == NULL)) {
-      const ClassLoader* class_loader = declaring_class->GetClassLoader();
+      ClassLoader* class_loader = declaring_class->GetClassLoader();
       const DexFile& dex_file = FindDexFile(dex_cache);
       resolved_type = ResolveType(dex_file, type_idx, dex_cache, class_loader);
     }
@@ -145,7 +145,7 @@
   Class* ResolveType(const DexFile& dex_file,
                      uint16_t type_idx,
                      DexCache* dex_cache,
-                     const ClassLoader* class_loader);
+                     ClassLoader* class_loader);
 
   // Resolve a method with a given ID from the DexFile, storing the
   // result in DexCache. The ClassLinker and ClassLoader are used as
@@ -155,7 +155,7 @@
   Method* ResolveMethod(const DexFile& dex_file,
                         uint32_t method_idx,
                         DexCache* dex_cache,
-                        const ClassLoader* class_loader,
+                        ClassLoader* class_loader,
                         bool is_direct);
 
   Method* ResolveMethod(uint32_t method_idx, const Method* referrer, bool is_direct) {
@@ -163,7 +163,7 @@
     if (UNLIKELY(resolved_method == NULL || resolved_method->IsRuntimeMethod())) {
       Class* declaring_class = referrer->GetDeclaringClass();
       DexCache* dex_cache = declaring_class->GetDexCache();
-      const ClassLoader* class_loader = declaring_class->GetClassLoader();
+      ClassLoader* class_loader = declaring_class->GetClassLoader();
       const DexFile& dex_file = FindDexFile(dex_cache);
       resolved_method = ResolveMethod(dex_file, method_idx, dex_cache, class_loader, is_direct);
     }
@@ -176,7 +176,7 @@
     if (UNLIKELY(resolved_field == NULL)) {
       Class* declaring_class = referrer->GetDeclaringClass();
       DexCache* dex_cache = declaring_class->GetDexCache();
-      const ClassLoader* class_loader = declaring_class->GetClassLoader();
+      ClassLoader* class_loader = declaring_class->GetClassLoader();
       const DexFile& dex_file = FindDexFile(dex_cache);
       resolved_field = ResolveField(dex_file, field_idx, dex_cache, class_loader, is_static);
     }
@@ -191,7 +191,7 @@
   Field* ResolveField(const DexFile& dex_file,
                       uint32_t field_idx,
                       DexCache* dex_cache,
-                      const ClassLoader* class_loader,
+                      ClassLoader* class_loader,
                       bool is_static);
 
   // Resolve a field with a given ID from the DexFile, storing the
@@ -201,7 +201,7 @@
   Field* ResolveFieldJLS(const DexFile& dex_file,
                          uint32_t field_idx,
                          DexCache* dex_cache,
-                         const ClassLoader* class_loader);
+                         ClassLoader* class_loader);
 
   // Get shorty from method index without resolution. Used to do handlerization.
   const char* MethodShorty(uint32_t method_idx, Method* referrer, uint32_t* length);
@@ -323,7 +323,7 @@
                                   Primitive::Type type);
 
 
-  Class* CreateArrayClass(const std::string& descriptor, const ClassLoader* class_loader);
+  Class* CreateArrayClass(const std::string& descriptor, ClassLoader* class_loader);
 
   void AppendToBootClassPath(const DexFile& dex_file);
   void AppendToBootClassPath(const DexFile& dex_file, SirtRef<DexCache>& dex_cache);
@@ -337,7 +337,7 @@
   void LoadClass(const DexFile& dex_file,
                  const DexFile::ClassDef& dex_class_def,
                  SirtRef<Class>& klass,
-                 const ClassLoader* class_loader);
+                 ClassLoader* class_loader);
 
   void LoadField(const DexFile& dex_file, const ClassDataItemIterator& it, SirtRef<Class>& klass,
                  SirtRef<Field>& dst);
diff --git a/src/class_linker_test.cc b/src/class_linker_test.cc
index f677cae..a7f9c66 100644
--- a/src/class_linker_test.cc
+++ b/src/class_linker_test.cc
@@ -79,7 +79,7 @@
 
   void AssertArrayClass(const std::string& array_descriptor,
                         const std::string& component_type,
-                        const ClassLoader* class_loader) {
+                        ClassLoader* class_loader) {
     Class* array = class_linker_->FindClass(array_descriptor.c_str(), class_loader);
     ClassHelper array_component_ch(array->GetComponentType());
     EXPECT_STREQ(component_type.c_str(), array_component_ch.GetDescriptor());
diff --git a/src/common_test.h b/src/common_test.h
index fbd8b5b..a9bd139 100644
--- a/src/common_test.h
+++ b/src/common_test.h
@@ -468,7 +468,7 @@
     return class_loader.get();
   }
 
-  void CompileClass(const ClassLoader* class_loader, const char* class_name) {
+  void CompileClass(ClassLoader* class_loader, const char* class_name) {
     std::string class_descriptor(DotToDescriptor(class_name));
     Class* klass = class_linker_->FindClass(class_descriptor.c_str(), class_loader);
     CHECK(klass != NULL) << "Class not found " << class_name;
diff --git a/src/compiler.cc b/src/compiler.cc
index 2633b78..fd18713 100644
--- a/src/compiler.cc
+++ b/src/compiler.cc
@@ -439,7 +439,7 @@
   }
 }
 
-void Compiler::CompileAll(const ClassLoader* class_loader,
+void Compiler::CompileAll(ClassLoader* class_loader,
                           const std::vector<const DexFile*>& dex_files) {
   DCHECK(!Runtime::Current()->IsStarted());
 
@@ -469,7 +469,7 @@
 void Compiler::CompileOne(const Method* method) {
   DCHECK(!Runtime::Current()->IsStarted());
 
-  const ClassLoader* class_loader = method->GetDeclaringClass()->GetClassLoader();
+  ClassLoader* class_loader = method->GetDeclaringClass()->GetClassLoader();
 
   // Find the dex_file
   const DexCache* dex_cache = method->GetDeclaringClass()->GetDexCache();
@@ -487,7 +487,7 @@
   PostCompile(class_loader, dex_files);
 }
 
-void Compiler::Resolve(const ClassLoader* class_loader,
+void Compiler::Resolve(ClassLoader* class_loader,
                        const std::vector<const DexFile*>& dex_files, TimingLogger& timings) {
   for (size_t i = 0; i != dex_files.size(); ++i) {
     const DexFile* dex_file = dex_files[i];
@@ -496,7 +496,7 @@
   }
 }
 
-void Compiler::PreCompile(const ClassLoader* class_loader,
+void Compiler::PreCompile(ClassLoader* class_loader,
                           const std::vector<const DexFile*>& dex_files, TimingLogger& timings) {
   Resolve(class_loader, dex_files, timings);
 
@@ -507,7 +507,7 @@
   timings.AddSplit("PreCompile.InitializeClassesWithoutClinit");
 }
 
-void Compiler::PostCompile(const ClassLoader* class_loader,
+void Compiler::PostCompile(ClassLoader* class_loader,
                            const std::vector<const DexFile*>& dex_files) {
   SetGcMaps(class_loader, dex_files);
 #if defined(ART_USE_LLVM_COMPILER)
@@ -926,7 +926,7 @@
 class CompilationContext {
  public:
   CompilationContext(ClassLinker* class_linker,
-          const ClassLoader* class_loader,
+          ClassLoader* class_loader,
           Compiler* compiler,
           DexCache* dex_cache,
           const DexFile* dex_file)
@@ -940,7 +940,7 @@
     CHECK(class_linker_ != NULL);
     return class_linker_;
   }
-  const ClassLoader* GetClassLoader() {
+  ClassLoader* GetClassLoader() {
     return class_loader_;
   }
   Compiler* GetCompiler() {
@@ -958,7 +958,7 @@
 
  private:
   ClassLinker* class_linker_;
-  const ClassLoader* class_loader_;
+  ClassLoader* class_loader_;
   Compiler* compiler_;
   DexCache* dex_cache_;
   const DexFile* dex_file_;
@@ -1121,7 +1121,7 @@
   }
 }
 
-void Compiler::ResolveDexFile(const ClassLoader* class_loader, const DexFile& dex_file, TimingLogger& timings) {
+void Compiler::ResolveDexFile(ClassLoader* class_loader, const DexFile& dex_file, TimingLogger& timings) {
   ClassLinker* class_linker = Runtime::Current()->GetClassLinker();
   DexCache* dex_cache = class_linker->FindDexCache(dex_file);
 
@@ -1143,7 +1143,7 @@
   timings.AddSplit("Resolve " + dex_file.GetLocation() + " MethodsAndFields");
 }
 
-void Compiler::Verify(const ClassLoader* class_loader,
+void Compiler::Verify(ClassLoader* class_loader,
                       const std::vector<const DexFile*>& dex_files) {
   for (size_t i = 0; i != dex_files.size(); ++i) {
     const DexFile* dex_file = dex_files[i];
@@ -1190,7 +1190,7 @@
   CHECK(!Thread::Current()->IsExceptionPending()) << PrettyTypeOf(Thread::Current()->GetException());
 }
 
-void Compiler::VerifyDexFile(const ClassLoader* class_loader, const DexFile& dex_file) {
+void Compiler::VerifyDexFile(ClassLoader* class_loader, const DexFile& dex_file) {
   dex_file.ChangePermissions(PROT_READ | PROT_WRITE);
 
   ClassLinker* class_linker = Runtime::Current()->GetClassLinker();
@@ -1200,7 +1200,7 @@
   dex_file.ChangePermissions(PROT_READ);
 }
 
-void Compiler::InitializeClassesWithoutClinit(const ClassLoader* class_loader,
+void Compiler::InitializeClassesWithoutClinit(ClassLoader* class_loader,
                                               const std::vector<const DexFile*>& dex_files) {
   for (size_t i = 0; i != dex_files.size(); ++i) {
     const DexFile* dex_file = dex_files[i];
@@ -1209,7 +1209,7 @@
   }
 }
 
-void Compiler::InitializeClassesWithoutClinit(const ClassLoader* class_loader, const DexFile& dex_file) {
+void Compiler::InitializeClassesWithoutClinit(ClassLoader* class_loader, const DexFile& dex_file) {
   ClassLinker* class_linker = Runtime::Current()->GetClassLinker();
   for (size_t class_def_index = 0; class_def_index < dex_file.NumClassDefs(); class_def_index++) {
     const DexFile::ClassDef& class_def = dex_file.GetClassDef(class_def_index);
@@ -1390,7 +1390,7 @@
   STLDeleteElements(&threads);
 }
 
-void Compiler::Compile(const ClassLoader* class_loader,
+void Compiler::Compile(ClassLoader* class_loader,
                        const std::vector<const DexFile*>& dex_files) {
 #if defined(ART_USE_LLVM_COMPILER)
   if (dex_files.size() <= 0) {
@@ -1465,7 +1465,7 @@
   DCHECK(!it.HasNext());
 }
 
-void Compiler::CompileDexFile(const ClassLoader* class_loader, const DexFile& dex_file) {
+void Compiler::CompileDexFile(ClassLoader* class_loader, const DexFile& dex_file) {
   CompilationContext context(NULL, class_loader, this, NULL, &dex_file);
   ForAll(&context, 0, dex_file.NumClassDefs(), Compiler::CompileClass, thread_count_);
 }
@@ -1605,7 +1605,7 @@
   return it->second;
 }
 
-void Compiler::SetGcMaps(const ClassLoader* class_loader, const std::vector<const DexFile*>& dex_files) {
+void Compiler::SetGcMaps(ClassLoader* class_loader, const std::vector<const DexFile*>& dex_files) {
   for (size_t i = 0; i != dex_files.size(); ++i) {
     const DexFile* dex_file = dex_files[i];
     CHECK(dex_file != NULL);
@@ -1613,7 +1613,7 @@
   }
 }
 
-void Compiler::SetGcMapsDexFile(const ClassLoader* class_loader, const DexFile& dex_file) {
+void Compiler::SetGcMapsDexFile(ClassLoader* class_loader, const DexFile& dex_file) {
   ClassLinker* class_linker = Runtime::Current()->GetClassLinker();
   DexCache* dex_cache = class_linker->FindDexCache(dex_file);
   for (size_t class_def_index = 0; class_def_index < dex_file.NumClassDefs(); class_def_index++) {
diff --git a/src/compiler.h b/src/compiler.h
index 8f5d5b4..5202967 100644
--- a/src/compiler.h
+++ b/src/compiler.h
@@ -53,7 +53,7 @@
 
   ~Compiler();
 
-  void CompileAll(const ClassLoader* class_loader,
+  void CompileAll(ClassLoader* class_loader,
                   const std::vector<const DexFile*>& dex_files);
 
   // Compile a single Method
@@ -255,24 +255,24 @@
   // Checks if class specified by type_idx is one of the image_classes_
   bool IsImageClass(const std::string& descriptor) const;
 
-  void PreCompile(const ClassLoader* class_loader, const std::vector<const DexFile*>& dex_files, TimingLogger& timings);
-  void PostCompile(const ClassLoader* class_loader, const std::vector<const DexFile*>& dex_files);
+  void PreCompile(ClassLoader* class_loader, const std::vector<const DexFile*>& dex_files, TimingLogger& timings);
+  void PostCompile(ClassLoader* class_loader, const std::vector<const DexFile*>& dex_files);
 
   // Attempt to resolve all type, methods, fields, and strings
   // referenced from code in the dex file following PathClassLoader
   // ordering semantics.
-  void Resolve(const ClassLoader* class_loader, const std::vector<const DexFile*>& dex_files, TimingLogger& timings);
-  void ResolveDexFile(const ClassLoader* class_loader, const DexFile& dex_file, TimingLogger& timings);
+  void Resolve(ClassLoader* class_loader, const std::vector<const DexFile*>& dex_files, TimingLogger& timings);
+  void ResolveDexFile(ClassLoader* class_loader, const DexFile& dex_file, TimingLogger& timings);
 
-  void Verify(const ClassLoader* class_loader, const std::vector<const DexFile*>& dex_files);
-  void VerifyDexFile(const ClassLoader* class_loader, const DexFile& dex_file);
+  void Verify(ClassLoader* class_loader, const std::vector<const DexFile*>& dex_files);
+  void VerifyDexFile(ClassLoader* class_loader, const DexFile& dex_file);
 
-  void InitializeClassesWithoutClinit(const ClassLoader* class_loader, const std::vector<const DexFile*>& dex_files);
-  void InitializeClassesWithoutClinit(const ClassLoader* class_loader, const DexFile& dex_file);
+  void InitializeClassesWithoutClinit(ClassLoader* class_loader, const std::vector<const DexFile*>& dex_files);
+  void InitializeClassesWithoutClinit(ClassLoader* class_loader, const DexFile& dex_file);
 
-  void Compile(const ClassLoader* class_loader,
+  void Compile(ClassLoader* class_loader,
                const std::vector<const DexFile*>& dex_files);
-  void CompileDexFile(const ClassLoader* class_loader, const DexFile& dex_file);
+  void CompileDexFile(ClassLoader* class_loader, const DexFile& dex_file);
   void CompileClass(const DexFile::ClassDef& class_def, const ClassLoader* class_loader,
                     const DexFile& dex_file);
   void CompileMethod(const DexFile::CodeItem* code_item, uint32_t access_flags, uint32_t method_idx,
@@ -280,8 +280,8 @@
 
   static void CompileClass(CompilationContext* context, size_t class_def_index);
 
-  void SetGcMaps(const ClassLoader* class_loader, const std::vector<const DexFile*>& dex_files);
-  void SetGcMapsDexFile(const ClassLoader* class_loader, const DexFile& dex_file);
+  void SetGcMaps(ClassLoader* class_loader, const std::vector<const DexFile*>& dex_files);
+  void SetGcMapsDexFile(ClassLoader* class_loader, const DexFile& dex_file);
   void SetGcMapsMethod(const DexFile& dex_file, Method* method);
 
   void InsertInvokeStub(const std::string& key, const CompiledInvokeStub* compiled_invoke_stub);
diff --git a/src/compiler/CompilerIR.h b/src/compiler/CompilerIR.h
index ba99715..ef56876 100644
--- a/src/compiler/CompilerIR.h
+++ b/src/compiler/CompilerIR.h
@@ -436,7 +436,7 @@
   ClassLinker* class_linker;     // Linker to resolve fields and methods
   const DexFile* dex_file;       // DexFile containing the method being compiled
   DexCache* dex_cache;           // DexFile's corresponding cache
-  const ClassLoader* class_loader;  // compiling method's class loader
+  ClassLoader* class_loader;     // compiling method's class loader
   uint32_t method_idx;                // compiling method's index into method_ids of DexFile
   const DexFile::CodeItem* code_item;  // compiling method's DexFile code_item
   uint32_t access_flags;              // compiling method's access flags
diff --git a/src/compiler_llvm/runtime_support_llvm.cc b/src/compiler_llvm/runtime_support_llvm.cc
index 93785a6..cfaaea4 100644
--- a/src/compiler_llvm/runtime_support_llvm.cc
+++ b/src/compiler_llvm/runtime_support_llvm.cc
@@ -667,10 +667,11 @@
 
   // Start new JNI local reference state
   JNIEnvExt* env = thread->GetJniEnv();
+  ScopedJniThreadState ts(env);
   ScopedJniEnvLocalRefState env_state(env);
 
   // Create local ref. copies of the receiver
-  jobject rcvr_jobj = AddLocalReference<jobject>(env, receiver);
+  jobject rcvr_jobj = ts.AddLocalReference<jobject>(receiver);
 
   // Convert proxy method into expected interface method
   Method* interface_method = proxy_method->FindOverriddenMethod();
@@ -680,7 +681,7 @@
   // Set up arguments array and place in local IRT during boxing (which may allocate/GC)
   jvalue args_jobj[3];
   args_jobj[0].l = rcvr_jobj;
-  args_jobj[1].l = AddLocalReference<jobject>(env, interface_method);
+  args_jobj[1].l = ts.AddLocalReference<jobject>(interface_method);
   // Args array, if no arguments then NULL (don't include receiver in argument count)
   args_jobj[2].l = NULL;
   ObjectArray<Object>* args = NULL;
@@ -690,7 +691,7 @@
       CHECK(thread->IsExceptionPending());
       return;
     }
-    args_jobj[2].l = AddLocalReference<jobjectArray>(env, args);
+    args_jobj[2].l = ts.AddLocalReference<jobjectArray>(args);
   }
 
   // Get parameter types.
diff --git a/src/compiler_test.cc b/src/compiler_test.cc
index e3faa3b..088726f 100644
--- a/src/compiler_test.cc
+++ b/src/compiler_test.cc
@@ -31,13 +31,13 @@
 
 class CompilerTest : public CommonTest {
  protected:
-  void CompileAll(const ClassLoader* class_loader) {
+  void CompileAll(ClassLoader* class_loader) {
     compiler_->CompileAll(class_loader, Runtime::Current()->GetCompileTimeClassPath(class_loader));
     MakeAllExecutable(class_loader);
   }
 
-  void EnsureCompiled(const ClassLoader* class_loader,
-      const char* class_name, const char* method, const char* signature, bool is_virtual) {
+  void EnsureCompiled(ClassLoader* class_loader, const char* class_name, const char* method,
+                      const char* signature, bool is_virtual) {
     CompileAll(class_loader);
     runtime_->Start();
     env_ = Thread::Current()->GetJniEnv();
@@ -51,7 +51,7 @@
     CHECK(mid_ != NULL) << "Method not found: " << class_name << "." << method << signature;
   }
 
-  void MakeAllExecutable(const ClassLoader* class_loader) {
+  void MakeAllExecutable(ClassLoader* class_loader) {
     const std::vector<const DexFile*>& class_path
         = Runtime::Current()->GetCompileTimeClassPath(class_loader);
     for (size_t i = 0; i != class_path.size(); ++i) {
@@ -61,7 +61,7 @@
     }
   }
 
-  void MakeDexFileExecutable(const ClassLoader* class_loader, const DexFile& dex_file) {
+  void MakeDexFileExecutable(ClassLoader* class_loader, const DexFile& dex_file) {
     ClassLinker* class_linker = Runtime::Current()->GetClassLinker();
     for (size_t i = 0; i < dex_file.NumClassDefs(); i++) {
       const DexFile::ClassDef& class_def = dex_file.GetClassDef(i);
diff --git a/src/debugger.cc b/src/debugger.cc
index 7dfbd22..9e3ab3f 100644
--- a/src/debugger.cc
+++ b/src/debugger.cc
@@ -28,6 +28,7 @@
 #endif
 #include "object_utils.h"
 #include "safe_map.h"
+#include "scoped_jni_thread_state.h"
 #include "scoped_thread_list_lock.h"
 #include "ScopedLocalRef.h"
 #include "ScopedPrimitiveArray.h"
@@ -220,11 +221,12 @@
 }
 
 static Thread* DecodeThread(JDWP::ObjectId threadId) {
+  ScopedJniThreadState ts(Thread::Current());
   Object* thread_peer = gRegistry->Get<Object*>(threadId);
   if (thread_peer == NULL || thread_peer == kInvalidObject) {
     return NULL;
   }
-  return Thread::FromManagedThread(thread_peer);
+  return Thread::FromManagedThread(ts, thread_peer);
 }
 
 static JDWP::JdwpTag BasicTagFromDescriptor(const char* descriptor) {
@@ -1369,11 +1371,17 @@
 }
 
 JDWP::ObjectId Dbg::GetSystemThreadGroupId() {
-  return gRegistry->Add(Thread::GetSystemThreadGroup());
+  ScopedJniThreadState ts(Thread::Current());
+  Object* group =
+      ts.DecodeField(WellKnownClasses::java_lang_ThreadGroup_systemThreadGroup)->GetObject(NULL);
+  return gRegistry->Add(group);
 }
 
 JDWP::ObjectId Dbg::GetMainThreadGroupId() {
-  return gRegistry->Add(Thread::GetMainThreadGroup());
+  ScopedJniThreadState ts(Thread::Current());
+  Object* group =
+      ts.DecodeField(WellKnownClasses::java_lang_ThreadGroup_mainThreadGroup)->GetObject(NULL);
+  return gRegistry->Add(group);
 }
 
 bool Dbg::GetThreadStatus(JDWP::ObjectId threadId, JDWP::JdwpThreadStatus* pThreadStatus, JDWP::JdwpSuspendStatus* pSuspendStatus) {
@@ -1422,7 +1430,11 @@
 }
 
 void Dbg::GetThreadGroupThreadsImpl(Object* thread_group, JDWP::ObjectId** ppThreadIds, uint32_t* pThreadCount) {
-  struct ThreadListVisitor {
+  class ThreadListVisitor {
+   public:
+    ThreadListVisitor(const ScopedJniThreadState& ts, Object* thread_group)
+      : ts_(ts), thread_group_(thread_group) {}
+
     static void Visit(Thread* t, void* arg) {
       reinterpret_cast<ThreadListVisitor*>(arg)->Visit(t);
     }
@@ -1433,27 +1445,34 @@
         // query all threads, so it's easier if we just don't tell them about this thread.
         return;
       }
-      if (thread_group == NULL || t->GetThreadGroup() == thread_group) {
-        threads.push_back(gRegistry->Add(t->GetPeer()));
+      if (thread_group_ == NULL || t->GetThreadGroup(ts_) == thread_group_) {
+        threads_.push_back(gRegistry->Add(t->GetPeer()));
       }
     }
 
-    Object* thread_group;
-    std::vector<JDWP::ObjectId> threads;
+    const std::vector<JDWP::ObjectId>& GetThreads() {
+      return threads_;
+    }
+
+   private:
+    const ScopedJniThreadState& ts_;
+    Object* const thread_group_;
+    std::vector<JDWP::ObjectId> threads_;
   };
 
-  ThreadListVisitor tlv;
-  tlv.thread_group = thread_group;
+  ScopedJniThreadState ts(Thread::Current());
+  ThreadListVisitor tlv(ts, thread_group);
 
   Runtime::Current()->GetThreadList()->ForEach(ThreadListVisitor::Visit, &tlv);
 
-  *pThreadCount = tlv.threads.size();
+  *pThreadCount = tlv.GetThreads().size();
   if (*pThreadCount == 0) {
     *ppThreadIds = NULL;
   } else {
+    // TODO: pass in std::vector rather than passing around pointers.
     *ppThreadIds = new JDWP::ObjectId[*pThreadCount];
     for (size_t i = 0; i < *pThreadCount; ++i) {
-      (*ppThreadIds)[i] = tlv.threads[i];
+      (*ppThreadIds)[i] = tlv.GetThreads()[i];
     }
   }
 }
@@ -1546,9 +1565,10 @@
 }
 
 void Dbg::SuspendThread(JDWP::ObjectId threadId) {
+  ScopedJniThreadState ts(Thread::Current());
   Object* peer = gRegistry->Get<Object*>(threadId);
   ScopedThreadListLock thread_list_lock;
-  Thread* thread = Thread::FromManagedThread(peer);
+  Thread* thread = Thread::FromManagedThread(ts, peer);
   if (thread == NULL) {
     LOG(WARNING) << "No such thread for suspend: " << peer;
     return;
@@ -1557,9 +1577,10 @@
 }
 
 void Dbg::ResumeThread(JDWP::ObjectId threadId) {
+  ScopedJniThreadState ts(Thread::Current());
   Object* peer = gRegistry->Get<Object*>(threadId);
   ScopedThreadListLock thread_list_lock;
-  Thread* thread = Thread::FromManagedThread(peer);
+  Thread* thread = Thread::FromManagedThread(ts, peer);
   if (thread == NULL) {
     LOG(WARNING) << "No such thread for resume: " << peer;
     return;
@@ -2336,14 +2357,12 @@
 }
 
 void Dbg::ExecuteMethod(DebugInvokeReq* pReq) {
-  Thread* self = Thread::Current();
+  ScopedJniThreadState ts(Thread::Current());
 
   // We can be called while an exception is pending. We need
   // to preserve that across the method invocation.
-  SirtRef<Throwable> old_exception(self->GetException());
-  self->ClearException();
-
-  ScopedThreadStateChange tsc(self, kRunnable);
+  SirtRef<Throwable> old_exception(ts.Self()->GetException());
+  ts.Self()->ClearException();
 
   // Translate the method through the vtable, unless the debugger wants to suppress it.
   Method* m = pReq->method_;
@@ -2359,15 +2378,15 @@
 
   CHECK_EQ(sizeof(jvalue), sizeof(uint64_t));
 
-  LOG(INFO) << "self=" << self << " pReq->receiver_=" << pReq->receiver_ << " m=" << m << " #" << pReq->arg_count_ << " " << pReq->arg_values_;
-  pReq->result_value = InvokeWithJValues(self, pReq->receiver_, m, reinterpret_cast<JValue*>(pReq->arg_values_));
+  LOG(INFO) << "self=" << ts.Self() << " pReq->receiver_=" << pReq->receiver_ << " m=" << m << " #" << pReq->arg_count_ << " " << pReq->arg_values_;
+  pReq->result_value = InvokeWithJValues(ts, pReq->receiver_, m, reinterpret_cast<JValue*>(pReq->arg_values_));
 
-  pReq->exception = gRegistry->Add(self->GetException());
+  pReq->exception = gRegistry->Add(ts.Self()->GetException());
   pReq->result_tag = BasicTagFromDescriptor(MethodHelper(m).GetShorty());
   if (pReq->exception != 0) {
-    Object* exc = self->GetException();
+    Object* exc = ts.Self()->GetException();
     VLOG(jdwp) << "  JDWP invocation returning with exception=" << exc << " " << PrettyTypeOf(exc);
-    self->ClearException();
+    ts.Self()->ClearException();
     pReq->result_value.SetJ(0);
   } else if (pReq->result_tag == JDWP::JT_OBJECT) {
     /* if no exception thrown, examine object result more closely */
@@ -2390,7 +2409,7 @@
   }
 
   if (old_exception.get() != NULL) {
-    self->SetException(old_exception.get());
+    ts.Self()->SetException(old_exception.get());
   }
 }
 
@@ -2549,7 +2568,8 @@
     Dbg::DdmSendChunk(CHUNK_TYPE("THDE"), 4, buf);
   } else {
     CHECK(type == CHUNK_TYPE("THCR") || type == CHUNK_TYPE("THNM")) << type;
-    SirtRef<String> name(t->GetThreadName());
+    ScopedJniThreadState ts(Thread::Current());
+    SirtRef<String> name(t->GetThreadName(ts));
     size_t char_count = (name.get() != NULL) ? name->GetLength() : 0;
     const jchar* chars = name->GetCharArray()->GetData();
 
diff --git a/src/exception_test.cc b/src/exception_test.cc
index 90bcb7c..269822a 100644
--- a/src/exception_test.cc
+++ b/src/exception_test.cc
@@ -19,6 +19,7 @@
 #include "dex_file.h"
 #include "gtest/gtest.h"
 #include "runtime.h"
+#include "scoped_jni_thread_state.h"
 #include "thread.h"
 #include "UniquePtr.h"
 
@@ -160,12 +161,13 @@
 #endif
 
   JNIEnv* env = thread->GetJniEnv();
-  jobject internal = thread->CreateInternalStackTrace(env);
+  ScopedJniThreadState ts(env);
+  jobject internal = thread->CreateInternalStackTrace(ts);
   ASSERT_TRUE(internal != NULL);
   jobjectArray ste_array = Thread::InternalStackTraceToStackTraceElementArray(env, internal);
   ASSERT_TRUE(ste_array != NULL);
   ObjectArray<StackTraceElement>* trace_array =
-      Decode<ObjectArray<StackTraceElement>*>(env, ste_array);
+      ts.Decode<ObjectArray<StackTraceElement>*>(ste_array);
 
   ASSERT_TRUE(trace_array != NULL);
   ASSERT_TRUE(trace_array->Get(0) != NULL);
diff --git a/src/heap.cc b/src/heap.cc
index 9fbfa32..c6dfdf7 100644
--- a/src/heap.cc
+++ b/src/heap.cc
@@ -31,6 +31,7 @@
 #include "object_utils.h"
 #include "os.h"
 #include "scoped_heap_lock.h"
+#include "scoped_jni_thread_state.h"
 #include "scoped_thread_list_lock_releaser.h"
 #include "ScopedLocalRef.h"
 #include "space.h"
@@ -935,20 +936,19 @@
 }
 
 void Heap::AddFinalizerReference(Thread* self, Object* object) {
-  ScopedThreadStateChange tsc(self, kRunnable);
+  ScopedJniThreadState ts(self);
   JValue args[1];
   args[0].SetL(object);
-  DecodeMethod(WellKnownClasses::java_lang_ref_FinalizerReference_add)->Invoke(self, NULL, args, NULL);
+  ts.DecodeMethod(WellKnownClasses::java_lang_ref_FinalizerReference_add)->Invoke(self, NULL, args, NULL);
 }
 
 void Heap::EnqueueClearedReferences(Object** cleared) {
   DCHECK(cleared != NULL);
   if (*cleared != NULL) {
-    Thread* self = Thread::Current();
-    ScopedThreadStateChange tsc(self, kRunnable);
+    ScopedJniThreadState ts(Thread::Current());
     JValue args[1];
     args[0].SetL(*cleared);
-    DecodeMethod(WellKnownClasses::java_lang_ref_ReferenceQueue_add)->Invoke(self, NULL, args, NULL);
+    ts.DecodeMethod(WellKnownClasses::java_lang_ref_ReferenceQueue_add)->Invoke(ts.Self(), NULL, args, NULL);
     *cleared = NULL;
   }
 }
diff --git a/src/jdwp/jdwp_main.cc b/src/jdwp/jdwp_main.cc
index 3a4d398..dfe83ff 100644
--- a/src/jdwp/jdwp_main.cc
+++ b/src/jdwp/jdwp_main.cc
@@ -270,7 +270,7 @@
 
 void JdwpState::Run() {
   Runtime* runtime = Runtime::Current();
-  runtime->AttachCurrentThread("JDWP", true, Thread::GetSystemThreadGroup());
+  runtime->AttachCurrentThread("JDWP", true, runtime->GetSystemThreadGroup());
 
   VLOG(jdwp) << "JDWP: thread running";
 
diff --git a/src/jni_compiler_test.cc b/src/jni_compiler_test.cc
index f5e1d1e..22b4b2c 100644
--- a/src/jni_compiler_test.cc
+++ b/src/jni_compiler_test.cc
@@ -536,10 +536,10 @@
     ScopedJniThreadState ts(env);
 
     // Build stack trace
-    jobject internal = Thread::Current()->CreateInternalStackTrace(env);
+    jobject internal = Thread::Current()->CreateInternalStackTrace(ts);
     jobjectArray ste_array = Thread::InternalStackTraceToStackTraceElementArray(env, internal);
     ObjectArray<StackTraceElement>* trace_array =
-        Decode<ObjectArray<StackTraceElement>*>(env, ste_array);
+        ts.Decode<ObjectArray<StackTraceElement>*>(ste_array);
     EXPECT_TRUE(trace_array != NULL);
     EXPECT_EQ(11, trace_array->GetLength());
 
@@ -591,8 +591,9 @@
 
 jint local_ref_test(JNIEnv* env, jobject thisObj, jint x) {
   // Add 10 local references
+  ScopedJniThreadState ts(env);
   for (int i = 0; i < 10; i++) {
-    AddLocalReference<jobject>(env, Decode<Object*>(env, thisObj));
+    ts.AddLocalReference<jobject>(ts.Decode<Object*>(thisObj));
   }
   return x+1;
 }
diff --git a/src/jni_internal.cc b/src/jni_internal.cc
index fa79a01..74b740a 100644
--- a/src/jni_internal.cc
+++ b/src/jni_internal.cc
@@ -72,85 +72,6 @@
   }
 }
 
-/*
- * Add a local reference for an object to the current stack frame.  When
- * the native function returns, the reference will be discarded.
- *
- * We need to allow the same reference to be added multiple times.
- *
- * This will be called on otherwise unreferenced objects.  We cannot do
- * GC allocations here, and it's best if we don't grab a mutex.
- *
- * Returns the local reference (currently just the same pointer that was
- * passed in), or NULL on failure.
- */
-template<typename T>
-T AddLocalReference(JNIEnv* public_env, const Object* const_obj) {
-  // The jobject type hierarchy has no notion of const, so it's not worth carrying through.
-  Object* obj = const_cast<Object*>(const_obj);
-
-  if (obj == NULL) {
-    return NULL;
-  }
-
-  DCHECK_NE((reinterpret_cast<uintptr_t>(obj) & 0xffff0000), 0xebad0000);
-
-  JNIEnvExt* env = reinterpret_cast<JNIEnvExt*>(public_env);
-  IndirectReferenceTable& locals = env->locals;
-
-  uint32_t cookie = env->local_ref_cookie;
-  IndirectRef ref = locals.Add(cookie, obj);
-
-#if 0 // TODO: fix this to understand PushLocalFrame, so we can turn it on.
-  if (env->check_jni) {
-    size_t entry_count = locals.Capacity();
-    if (entry_count > 16) {
-      LOG(WARNING) << "Warning: more than 16 JNI local references: "
-                   << entry_count << " (most recent was a " << PrettyTypeOf(obj) << ")\n"
-                   << Dumpable<IndirectReferenceTable>(locals);
-      // TODO: LOG(FATAL) in a later release?
-    }
-  }
-#endif
-
-  if (env->vm->work_around_app_jni_bugs) {
-    // Hand out direct pointers to support broken old apps.
-    return reinterpret_cast<T>(obj);
-  }
-
-  return reinterpret_cast<T>(ref);
-}
-// Explicit instantiations
-template jclass AddLocalReference<jclass>(JNIEnv* public_env, const Object* const_obj);
-template jobject AddLocalReference<jobject>(JNIEnv* public_env, const Object* const_obj);
-template jobjectArray AddLocalReference<jobjectArray>(JNIEnv* public_env, const Object* const_obj);
-template jstring AddLocalReference<jstring>(JNIEnv* public_env, const Object* const_obj);
-template jthrowable AddLocalReference<jthrowable>(JNIEnv* public_env, const Object* const_obj);
-
-// For external use.
-template<typename T>
-T Decode(JNIEnv* public_env, jobject obj) {
-  JNIEnvExt* env = reinterpret_cast<JNIEnvExt*>(public_env);
-  return reinterpret_cast<T>(env->self->DecodeJObject(obj));
-}
-// TODO: Change to use template when Mac OS build server no longer uses GCC 4.2.*.
-Object* DecodeObj(JNIEnv* public_env, jobject obj) {
-  JNIEnvExt* env = reinterpret_cast<JNIEnvExt*>(public_env);
-  return reinterpret_cast<Object*>(env->self->DecodeJObject(obj));
-}
-// Explicit instantiations.
-template Array* Decode<Array*>(JNIEnv*, jobject);
-template Class* Decode<Class*>(JNIEnv*, jobject);
-template ClassLoader* Decode<ClassLoader*>(JNIEnv*, jobject);
-template Object* Decode<Object*>(JNIEnv*, jobject);
-template ObjectArray<Class>* Decode<ObjectArray<Class>*>(JNIEnv*, jobject);
-template ObjectArray<ObjectArray<Class> >* Decode<ObjectArray<ObjectArray<Class> >*>(JNIEnv*, jobject);
-template ObjectArray<Object>* Decode<ObjectArray<Object>*>(JNIEnv*, jobject);
-template ObjectArray<StackTraceElement>* Decode<ObjectArray<StackTraceElement>*>(JNIEnv*, jobject);
-template ObjectArray<Method>* Decode<ObjectArray<Method>*>(JNIEnv*, jobject);
-template String* Decode<String*>(JNIEnv*, jobject);
-template Throwable* Decode<Throwable*>(JNIEnv*, jobject);
-
 size_t NumArgArrayBytes(const char* shorty, uint32_t shorty_len) {
   size_t num_bytes = 0;
   for (size_t i = 1; i < shorty_len; ++i) {
@@ -186,8 +107,7 @@
     return arg_array_;
   }
 
-  void BuildArgArray(JNIEnv* public_env, va_list ap) {
-    JNIEnvExt* env = reinterpret_cast<JNIEnvExt*>(public_env);
+  void BuildArgArray(const ScopedJniThreadState& ts, va_list ap) {
     for (size_t i = 1, offset = 0; i < shorty_len_; ++i, ++offset) {
       switch (shorty_[i]) {
         case 'Z':
@@ -209,7 +129,7 @@
           arg_array_[offset].SetF(va_arg(ap, jdouble));
           break;
         case 'L':
-          arg_array_[offset].SetL(DecodeObj(env, va_arg(ap, jobject)));
+          arg_array_[offset].SetL(ts.Decode<Object*>(va_arg(ap, jobject)));
           break;
         case 'D':
           arg_array_[offset].SetD(va_arg(ap, jdouble));
@@ -221,8 +141,7 @@
     }
   }
 
-  void BuildArgArray(JNIEnv* public_env, jvalue* args) {
-    JNIEnvExt* env = reinterpret_cast<JNIEnvExt*>(public_env);
+  void BuildArgArray(const ScopedJniThreadState& ts, jvalue* args) {
     for (size_t i = 1, offset = 0; i < shorty_len_; ++i, ++offset) {
       switch (shorty_[i]) {
         case 'Z':
@@ -244,7 +163,7 @@
           arg_array_[offset].SetF(args[offset].f);
           break;
         case 'L':
-          arg_array_[offset].SetL(DecodeObj(env, args[offset].l));
+          arg_array_[offset].SetL(ts.Decode<Object*>(args[offset].l));
           break;
         case 'D':
           arg_array_[offset].SetD(args[offset].d);
@@ -276,12 +195,6 @@
   return reinterpret_cast<jweak>(ref);
 }
 
-// For internal use.
-template<typename T>
-static T Decode(ScopedJniThreadState& ts, jobject obj) {
-  return reinterpret_cast<T>(ts.Self()->DecodeJObject(obj));
-}
-
 static void CheckMethodArguments(Method* m, JValue* args) {
   MethodHelper mh(m);
   ObjectArray<Class>* parameter_types = mh.GetParameterTypes();
@@ -306,47 +219,45 @@
   }
 }
 
-static JValue InvokeWithArgArray(JNIEnv* public_env, Object* receiver, Method* method, JValue* args) {
-  JNIEnvExt* env = reinterpret_cast<JNIEnvExt*>(public_env);
-  if (UNLIKELY(env->check_jni)) {
+static JValue InvokeWithArgArray(const ScopedJniThreadState& ts, Object* receiver, Method* method,
+                                 JValue* args) {
+  if (UNLIKELY(ts.Env()->check_jni)) {
     CheckMethodArguments(method, args);
   }
   JValue result;
-  method->Invoke(env->self, receiver, args, &result);
+  method->Invoke(ts.Self(), receiver, args, &result);
   return result;
 }
 
-static JValue InvokeWithVarArgs(JNIEnv* public_env, jobject obj, jmethodID mid, va_list args) {
-  JNIEnvExt* env = reinterpret_cast<JNIEnvExt*>(public_env);
-  Object* receiver = DecodeObj(env, obj);
-  Method* method = DecodeMethod(mid);
+static JValue InvokeWithVarArgs(const ScopedJniThreadState& ts, jobject obj, jmethodID mid,
+                                va_list args) {
+  Object* receiver = ts.Decode<Object*>(obj);
+  Method* method = ts.DecodeMethod(mid);
   ArgArray arg_array(method);
-  arg_array.BuildArgArray(env, args);
-  return InvokeWithArgArray(env, receiver, method, arg_array.get());
+  arg_array.BuildArgArray(ts, args);
+  return InvokeWithArgArray(ts, receiver, method, arg_array.get());
 }
 
 static Method* FindVirtualMethod(Object* receiver, Method* method) {
   return receiver->GetClass()->FindVirtualMethodForVirtualOrInterface(method);
 }
 
-static JValue InvokeVirtualOrInterfaceWithJValues(JNIEnv* public_env, jobject obj, jmethodID mid,
-                                                  jvalue* args) {
-  JNIEnvExt* env = reinterpret_cast<JNIEnvExt*>(public_env);
-  Object* receiver = DecodeObj(env, obj);
-  Method* method = FindVirtualMethod(receiver, DecodeMethod(mid));
+static JValue InvokeVirtualOrInterfaceWithJValues(const ScopedJniThreadState& ts, jobject obj,
+                                                  jmethodID mid, jvalue* args) {
+  Object* receiver = ts.Decode<Object*>(obj);
+  Method* method = FindVirtualMethod(receiver, ts.DecodeMethod(mid));
   ArgArray arg_array(method);
-  arg_array.BuildArgArray(env, args);
-  return InvokeWithArgArray(env, receiver, method, arg_array.get());
+  arg_array.BuildArgArray(ts, args);
+  return InvokeWithArgArray(ts, receiver, method, arg_array.get());
 }
 
-static JValue InvokeVirtualOrInterfaceWithVarArgs(JNIEnv* public_env, jobject obj, jmethodID mid,
-                                                  va_list args) {
-  JNIEnvExt* env = reinterpret_cast<JNIEnvExt*>(public_env);
-  Object* receiver = DecodeObj(env, obj);
-  Method* method = FindVirtualMethod(receiver, DecodeMethod(mid));
+static JValue InvokeVirtualOrInterfaceWithVarArgs(const ScopedJniThreadState& ts, jobject obj,
+                                                  jmethodID mid, va_list args) {
+  Object* receiver = ts.Decode<Object*>(obj);
+  Method* method = FindVirtualMethod(receiver, ts.DecodeMethod(mid));
   ArgArray arg_array(method);
-  arg_array.BuildArgArray(env, args);
-  return InvokeWithArgArray(env, receiver, method, arg_array.get());
+  arg_array.BuildArgArray(ts, args);
+  return InvokeWithArgArray(ts, receiver, method, arg_array.get());
 }
 
 // Section 12.3.2 of the JNI spec describes JNI class descriptors. They're
@@ -379,7 +290,7 @@
 }
 
 static jmethodID FindMethodID(ScopedJniThreadState& ts, jclass jni_class, const char* name, const char* sig, bool is_static) {
-  Class* c = Decode<Class*>(ts, jni_class);
+  Class* c = ts.Decode<Class*>(jni_class);
   if (!Runtime::Current()->GetClassLinker()->EnsureInitialized(c, true, true)) {
     return NULL;
   }
@@ -401,10 +312,10 @@
     return NULL;
   }
 
-  return EncodeMethod(method);
+  return ts.EncodeMethod(method);
 }
 
-static const ClassLoader* GetClassLoader(Thread* self) {
+static ClassLoader* GetClassLoader(Thread* self) {
   Method* method = self->GetCurrentMethod();
   if (method == NULL || PrettyMethod(method, false) == "java.lang.Runtime.nativeLoad") {
     return self->GetClassLoaderOverride();
@@ -412,8 +323,9 @@
   return method->GetDeclaringClass()->GetClassLoader();
 }
 
-static jfieldID FindFieldID(ScopedJniThreadState& ts, jclass jni_class, const char* name, const char* sig, bool is_static) {
-  Class* c = Decode<Class*>(ts, jni_class);
+static jfieldID FindFieldID(const ScopedJniThreadState& ts, jclass jni_class, const char* name,
+                            const char* sig, bool is_static) {
+  Class* c = ts.Decode<Class*>(jni_class);
   if (!Runtime::Current()->GetClassLinker()->EnsureInitialized(c, true, true)) {
     return NULL;
   }
@@ -422,7 +334,7 @@
   Class* field_type;
   ClassLinker* class_linker = Runtime::Current()->GetClassLinker();
   if (sig[1] != '\0') {
-    const ClassLoader* cl = GetClassLoader(ts.Self());
+    ClassLoader* cl = GetClassLoader(ts.Self());
     field_type = class_linker->FindClass(sig, cl);
   } else {
     field_type = class_linker->FindPrimitiveClass(*sig);
@@ -447,31 +359,31 @@
         name, ClassHelper(c).GetDescriptor());
     return NULL;
   }
-  return EncodeField(field);
+  return ts.EncodeField(field);
 }
 
-static void PinPrimitiveArray(ScopedJniThreadState& ts, const Array* array) {
+static void PinPrimitiveArray(const ScopedJniThreadState& ts, const Array* array) {
   JavaVMExt* vm = ts.Vm();
   MutexLock mu(vm->pins_lock);
   vm->pin_table.Add(array);
 }
 
-static void UnpinPrimitiveArray(ScopedJniThreadState& ts, const Array* array) {
+static void UnpinPrimitiveArray(const ScopedJniThreadState& ts, const Array* array) {
   JavaVMExt* vm = ts.Vm();
   MutexLock mu(vm->pins_lock);
   vm->pin_table.Remove(array);
 }
 
 template<typename JniT, typename ArtT>
-static JniT NewPrimitiveArray(ScopedJniThreadState& ts, jsize length) {
+static JniT NewPrimitiveArray(const ScopedJniThreadState& ts, jsize length) {
   CHECK_GE(length, 0); // TODO: ReportJniError
   ArtT* result = ArtT::Alloc(length);
-  return AddLocalReference<JniT>(ts.Env(), result);
+  return ts.AddLocalReference<JniT>(result);
 }
 
 template <typename ArrayT, typename CArrayT, typename ArtArrayT>
 static CArrayT GetPrimitiveArray(ScopedJniThreadState& ts, ArrayT java_array, jboolean* is_copy) {
-  ArtArrayT* array = Decode<ArtArrayT*>(ts, java_array);
+  ArtArrayT* array = ts.Decode<ArtArrayT*>(java_array);
   PinPrimitiveArray(ts, array);
   if (is_copy != NULL) {
     *is_copy = JNI_FALSE;
@@ -482,7 +394,7 @@
 template <typename ArrayT>
 static void ReleasePrimitiveArray(ScopedJniThreadState& ts, ArrayT java_array, jint mode) {
   if (mode != JNI_COMMIT) {
-    Array* array = Decode<Array*>(ts, java_array);
+    Array* array = ts.Decode<Array*>(java_array);
     UnpinPrimitiveArray(ts, array);
   }
 }
@@ -501,7 +413,7 @@
 
 template <typename JavaArrayT, typename JavaT, typename ArrayT>
 static void GetPrimitiveArrayRegion(ScopedJniThreadState& ts, JavaArrayT java_array, jsize start, jsize length, JavaT* buf) {
-  ArrayT* array = Decode<ArrayT*>(ts, java_array);
+  ArrayT* array = ts.Decode<ArrayT*>(java_array);
   if (start < 0 || length < 0 || start + length > array->GetLength()) {
     ThrowAIOOBE(ts, array, start, length, "src");
   } else {
@@ -512,7 +424,7 @@
 
 template <typename JavaArrayT, typename JavaT, typename ArrayT>
 static void SetPrimitiveArrayRegion(ScopedJniThreadState& ts, JavaArrayT java_array, jsize start, jsize length, const JavaT* buf) {
-  ArrayT* array = Decode<ArrayT*>(ts, java_array);
+  ArrayT* array = ts.Decode<ArrayT*>(java_array);
   if (start < 0 || length < 0 || start + length > array->GetLength()) {
     ThrowAIOOBE(ts, array, start, length, "dst");
   } else {
@@ -548,7 +460,8 @@
     }
     jmethodID mid = env->GetMethodID(exception_class, "<init>", signature);
     if (mid == NULL) {
-      LOG(ERROR) << "No <init>" << signature << " in " << PrettyClass(Decode<Class*>(env, exception_class));
+      LOG(ERROR) << "No <init>" << signature << " in "
+          << PrettyClass(ts.Decode<Class*>(exception_class));
       return JNI_ERR;
     }
 
@@ -557,7 +470,7 @@
       return JNI_ERR;
     }
 
-    ts.Self()->SetException(Decode<Throwable*>(ts, exception.get()));
+    ts.Self()->SetException(ts.Decode<Throwable*>(exception.get()));
 
     return JNI_OK;
 }
@@ -584,11 +497,11 @@
 
   JavaVMAttachArgs* args = static_cast<JavaVMAttachArgs*>(raw_args);
   const char* thread_name = NULL;
-  Object* thread_group = NULL;
+  jobject thread_group = NULL;
   if (args != NULL) {
     CHECK_GE(args->version, JNI_VERSION_1_2);
     thread_name = args->name;
-    thread_group = static_cast<Thread*>(NULL)->DecodeJObject(args->group);
+    thread_group = args->group;
   }
 
   runtime->AttachCurrentThread(thread_name, as_daemon, thread_group);
@@ -754,17 +667,16 @@
   SafeMap<std::string, SharedLibrary*> libraries_;
 };
 
-JValue InvokeWithJValues(JNIEnv* public_env, jobject obj, jmethodID mid, jvalue* args) {
-  JNIEnvExt* env = reinterpret_cast<JNIEnvExt*>(public_env);
-  Object* receiver = Decode<Object*>(env, obj);
-  Method* method = DecodeMethod(mid);
+JValue InvokeWithJValues(const ScopedJniThreadState& ts, jobject obj, jmethodID mid, jvalue* args) {
+  Object* receiver = ts.Decode<Object*>(obj);
+  Method* method = ts.DecodeMethod(mid);
   ArgArray arg_array(method);
-  arg_array.BuildArgArray(env, args);
-  return InvokeWithArgArray(env, receiver, method, arg_array.get());
+  arg_array.BuildArgArray(ts, args);
+  return InvokeWithArgArray(ts, receiver, method, arg_array.get());
 }
 
-JValue InvokeWithJValues(Thread* self, Object* receiver, Method* m, JValue* args) {
-  return InvokeWithArgArray(self->GetJniEnv(), receiver, m, args);
+JValue InvokeWithJValues(const ScopedJniThreadState& ts, Object* receiver, Method* m, JValue* args) {
+  return InvokeWithArgArray(ts, receiver, m, args);
 }
 
 class JNI {
@@ -787,54 +699,54 @@
     std::string descriptor(NormalizeJniClassDescriptor(name));
     Class* c = NULL;
     if (runtime->IsStarted()) {
-      const ClassLoader* cl = GetClassLoader(ts.Self());
+      ClassLoader* cl = GetClassLoader(ts.Self());
       c = class_linker->FindClass(descriptor.c_str(), cl);
     } else {
       c = class_linker->FindSystemClass(descriptor.c_str());
     }
-    return AddLocalReference<jclass>(env, c);
+    return ts.AddLocalReference<jclass>(c);
   }
 
   static jmethodID FromReflectedMethod(JNIEnv* env, jobject java_method) {
     ScopedJniThreadState ts(env);
-    Method* method = Decode<Method*>(ts, java_method);
-    return EncodeMethod(method);
+    Method* method = ts.Decode<Method*>(java_method);
+    return ts.EncodeMethod(method);
   }
 
   static jfieldID FromReflectedField(JNIEnv* env, jobject java_field) {
     ScopedJniThreadState ts(env);
-    Field* field = Decode<Field*>(ts, java_field);
-    return EncodeField(field);
+    Field* field = ts.Decode<Field*>(java_field);
+    return ts.EncodeField(field);
   }
 
   static jobject ToReflectedMethod(JNIEnv* env, jclass, jmethodID mid, jboolean) {
     ScopedJniThreadState ts(env);
-    Method* method = DecodeMethod(mid);
-    return AddLocalReference<jobject>(env, method);
+    Method* method = ts.DecodeMethod(mid);
+    return ts.AddLocalReference<jobject>(method);
   }
 
   static jobject ToReflectedField(JNIEnv* env, jclass, jfieldID fid, jboolean) {
     ScopedJniThreadState ts(env);
-    Field* field = DecodeField(fid);
-    return AddLocalReference<jobject>(env, field);
+    Field* field = ts.DecodeField(fid);
+    return ts.AddLocalReference<jobject>(field);
   }
 
   static jclass GetObjectClass(JNIEnv* env, jobject java_object) {
     ScopedJniThreadState ts(env);
-    Object* o = Decode<Object*>(ts, java_object);
-    return AddLocalReference<jclass>(env, o->GetClass());
+    Object* o = ts.Decode<Object*>(java_object);
+    return ts.AddLocalReference<jclass>(o->GetClass());
   }
 
   static jclass GetSuperclass(JNIEnv* env, jclass java_class) {
     ScopedJniThreadState ts(env);
-    Class* c = Decode<Class*>(ts, java_class);
-    return AddLocalReference<jclass>(env, c->GetSuperClass());
+    Class* c = ts.Decode<Class*>(java_class);
+    return ts.AddLocalReference<jclass>(c->GetSuperClass());
   }
 
   static jboolean IsAssignableFrom(JNIEnv* env, jclass java_class1, jclass java_class2) {
     ScopedJniThreadState ts(env);
-    Class* c1 = Decode<Class*>(ts, java_class1);
-    Class* c2 = Decode<Class*>(ts, java_class2);
+    Class* c1 = ts.Decode<Class*>(java_class1);
+    Class* c2 = ts.Decode<Class*>(java_class2);
     return c1->IsAssignableFrom(c2) ? JNI_TRUE : JNI_FALSE;
   }
 
@@ -845,15 +757,15 @@
       // Note: JNI is different from regular Java instanceof in this respect
       return JNI_TRUE;
     } else {
-      Object* obj = Decode<Object*>(ts, jobj);
-      Class* c = Decode<Class*>(ts, java_class);
+      Object* obj = ts.Decode<Object*>(jobj);
+      Class* c = ts.Decode<Class*>(java_class);
       return obj->InstanceOf(c) ? JNI_TRUE : JNI_FALSE;
     }
   }
 
   static jint Throw(JNIEnv* env, jthrowable java_exception) {
     ScopedJniThreadState ts(env);
-    Throwable* exception = Decode<Throwable*>(ts, java_exception);
+    Throwable* exception = ts.Decode<Throwable*>(java_exception);
     if (exception == NULL) {
       return JNI_ERR;
     }
@@ -882,7 +794,7 @@
     Throwable* original_exception = self->GetException();
     self->ClearException();
 
-    ScopedLocalRef<jthrowable> exception(env, AddLocalReference<jthrowable>(env, original_exception));
+    ScopedLocalRef<jthrowable> exception(env, ts.AddLocalReference<jthrowable>(original_exception));
     ScopedLocalRef<jclass> exception_class(env, env->GetObjectClass(exception.get()));
     jmethodID mid = env->GetMethodID(exception_class.get(), "printStackTrace", "()V");
     if (mid == NULL) {
@@ -903,7 +815,7 @@
   static jthrowable ExceptionOccurred(JNIEnv* env) {
     ScopedJniThreadState ts(env);
     Object* exception = ts.Self()->GetException();
-    return (exception != NULL) ? AddLocalReference<jthrowable>(env, exception) : NULL;
+    return ts.AddLocalReference<jthrowable>(exception);
   }
 
   static void FatalError(JNIEnv* env, const char* msg) {
@@ -922,9 +834,9 @@
 
   static jobject PopLocalFrame(JNIEnv* env, jobject java_survivor) {
     ScopedJniThreadState ts(env);
-    Object* survivor = Decode<Object*>(ts, java_survivor);
+    Object* survivor = ts.Decode<Object*>(java_survivor);
     ts.Env()->PopFrame();
-    return AddLocalReference<jobject>(env, survivor);
+    return ts.AddLocalReference<jobject>(survivor);
   }
 
   static jint EnsureLocalCapacity(JNIEnv* env, jint desired_capacity) {
@@ -932,7 +844,7 @@
     return EnsureLocalCapacity(ts, desired_capacity, "EnsureLocalCapacity");
   }
 
-  static jint EnsureLocalCapacity(ScopedJniThreadState& ts, jint desired_capacity, const char* caller) {
+  static jint EnsureLocalCapacity(const ScopedJniThreadState& ts, jint desired_capacity, const char* caller) {
     // TODO: we should try to expand the table if necessary.
     if (desired_capacity < 1 || desired_capacity > static_cast<jint>(kLocalsMax)) {
       LOG(ERROR) << "Invalid capacity given to " << caller << ": " << desired_capacity;
@@ -956,7 +868,7 @@
     JavaVMExt* vm = ts.Vm();
     IndirectReferenceTable& globals = vm->globals;
     MutexLock mu(vm->globals_lock);
-    IndirectRef ref = globals.Add(IRT_FIRST_SEGMENT, Decode<Object*>(ts, obj));
+    IndirectRef ref = globals.Add(IRT_FIRST_SEGMENT, ts.Decode<Object*>(obj));
     return reinterpret_cast<jobject>(ref);
   }
 
@@ -978,7 +890,7 @@
 
   static jweak NewWeakGlobalRef(JNIEnv* env, jobject obj) {
     ScopedJniThreadState ts(env);
-    return AddWeakGlobalReference(ts, Decode<Object*>(ts, obj));
+    return AddWeakGlobalReference(ts, ts.Decode<Object*>(obj));
   }
 
   static void DeleteWeakGlobalRef(JNIEnv* env, jweak obj) {
@@ -1006,7 +918,7 @@
     IndirectReferenceTable& locals = ts.Env()->locals;
 
     uint32_t cookie = ts.Env()->local_ref_cookie;
-    IndirectRef ref = locals.Add(cookie, Decode<Object*>(ts, obj));
+    IndirectRef ref = locals.Add(cookie, ts.Decode<Object*>(obj));
     return reinterpret_cast<jobject>(ref);
   }
 
@@ -1032,17 +944,17 @@
 
   static jboolean IsSameObject(JNIEnv* env, jobject obj1, jobject obj2) {
     ScopedJniThreadState ts(env);
-    return (Decode<Object*>(ts, obj1) == Decode<Object*>(ts, obj2))
+    return (ts.Decode<Object*>(obj1) == ts.Decode<Object*>(obj2))
         ? JNI_TRUE : JNI_FALSE;
   }
 
   static jobject AllocObject(JNIEnv* env, jclass java_class) {
     ScopedJniThreadState ts(env);
-    Class* c = Decode<Class*>(ts, java_class);
+    Class* c = ts.Decode<Class*>(java_class);
     if (!Runtime::Current()->GetClassLinker()->EnsureInitialized(c, true, true)) {
       return NULL;
     }
-    return AddLocalReference<jobject>(env, c->AllocObject());
+    return ts.AddLocalReference<jobject>(c->AllocObject());
   }
 
   static jobject NewObject(JNIEnv* env, jclass c, jmethodID mid, ...) {
@@ -1056,7 +968,7 @@
 
   static jobject NewObjectV(JNIEnv* env, jclass java_class, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    Class* c = Decode<Class*>(ts, java_class);
+    Class* c = ts.Decode<Class*>(java_class);
     if (!Runtime::Current()->GetClassLinker()->EnsureInitialized(c, true, true)) {
       return NULL;
     }
@@ -1064,7 +976,7 @@
     if (result == NULL) {
       return NULL;
     }
-    jobject local_result = AddLocalReference<jobject>(env, result);
+    jobject local_result = ts.AddLocalReference<jobject>(result);
     CallNonvirtualVoidMethodV(env, local_result, java_class, mid, args);
     if (!ts.Self()->IsExceptionPending()) {
       return local_result;
@@ -1075,7 +987,7 @@
 
   static jobject NewObjectA(JNIEnv* env, jclass java_class, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    Class* c = Decode<Class*>(ts, java_class);
+    Class* c = ts.Decode<Class*>(java_class);
     if (!Runtime::Current()->GetClassLinker()->EnsureInitialized(c, true, true)) {
       return NULL;
     }
@@ -1083,7 +995,7 @@
     if (result == NULL) {
       return NULL;
     }
-    jobject local_result = AddLocalReference<jobjectArray>(env, result);
+    jobject local_result = ts.AddLocalReference<jobjectArray>(result);
     CallNonvirtualVoidMethodA(env, local_result, java_class, mid, args);
     if (!ts.Self()->IsExceptionPending()) {
       return local_result;
@@ -1106,199 +1018,199 @@
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeVirtualOrInterfaceWithVarArgs(env, obj, mid, ap));
+    JValue result(InvokeVirtualOrInterfaceWithVarArgs(ts, obj, mid, ap));
     va_end(ap);
-    return AddLocalReference<jobject>(env, result.GetL());
+    return ts.AddLocalReference<jobject>(result.GetL());
   }
 
   static jobject CallObjectMethodV(JNIEnv* env, jobject obj, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    JValue result(InvokeVirtualOrInterfaceWithVarArgs(env, obj, mid, args));
-    return AddLocalReference<jobject>(env, result.GetL());
+    JValue result(InvokeVirtualOrInterfaceWithVarArgs(ts, obj, mid, args));
+    return ts.AddLocalReference<jobject>(result.GetL());
   }
 
   static jobject CallObjectMethodA(JNIEnv* env, jobject obj, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    JValue result(InvokeVirtualOrInterfaceWithJValues(env, obj, mid, args));
-    return AddLocalReference<jobject>(env, result.GetL());
+    JValue result(InvokeVirtualOrInterfaceWithJValues(ts, obj, mid, args));
+    return ts.AddLocalReference<jobject>(result.GetL());
   }
 
   static jboolean CallBooleanMethod(JNIEnv* env, jobject obj, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeVirtualOrInterfaceWithVarArgs(env, obj, mid, ap));
+    JValue result(InvokeVirtualOrInterfaceWithVarArgs(ts, obj, mid, ap));
     va_end(ap);
     return result.GetZ();
   }
 
   static jboolean CallBooleanMethodV(JNIEnv* env, jobject obj, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeVirtualOrInterfaceWithVarArgs(env, obj, mid, args).GetZ();
+    return InvokeVirtualOrInterfaceWithVarArgs(ts, obj, mid, args).GetZ();
   }
 
   static jboolean CallBooleanMethodA(JNIEnv* env, jobject obj, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeVirtualOrInterfaceWithJValues(env, obj, mid, args).GetZ();
+    return InvokeVirtualOrInterfaceWithJValues(ts, obj, mid, args).GetZ();
   }
 
   static jbyte CallByteMethod(JNIEnv* env, jobject obj, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeVirtualOrInterfaceWithVarArgs(env, obj, mid, ap));
+    JValue result(InvokeVirtualOrInterfaceWithVarArgs(ts, obj, mid, ap));
     va_end(ap);
     return result.GetB();
   }
 
   static jbyte CallByteMethodV(JNIEnv* env, jobject obj, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeVirtualOrInterfaceWithVarArgs(env, obj, mid, args).GetB();
+    return InvokeVirtualOrInterfaceWithVarArgs(ts, obj, mid, args).GetB();
   }
 
   static jbyte CallByteMethodA(JNIEnv* env, jobject obj, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeVirtualOrInterfaceWithJValues(env, obj, mid, args).GetB();
+    return InvokeVirtualOrInterfaceWithJValues(ts, obj, mid, args).GetB();
   }
 
   static jchar CallCharMethod(JNIEnv* env, jobject obj, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeVirtualOrInterfaceWithVarArgs(env, obj, mid, ap));
+    JValue result(InvokeVirtualOrInterfaceWithVarArgs(ts, obj, mid, ap));
     va_end(ap);
     return result.GetC();
   }
 
   static jchar CallCharMethodV(JNIEnv* env, jobject obj, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeVirtualOrInterfaceWithVarArgs(env, obj, mid, args).GetC();
+    return InvokeVirtualOrInterfaceWithVarArgs(ts, obj, mid, args).GetC();
   }
 
   static jchar CallCharMethodA(JNIEnv* env, jobject obj, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeVirtualOrInterfaceWithJValues(env, obj, mid, args).GetC();
+    return InvokeVirtualOrInterfaceWithJValues(ts, obj, mid, args).GetC();
   }
 
   static jdouble CallDoubleMethod(JNIEnv* env, jobject obj, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeVirtualOrInterfaceWithVarArgs(env, obj, mid, ap));
+    JValue result(InvokeVirtualOrInterfaceWithVarArgs(ts, obj, mid, ap));
     va_end(ap);
     return result.GetD();
   }
 
   static jdouble CallDoubleMethodV(JNIEnv* env, jobject obj, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeVirtualOrInterfaceWithVarArgs(env, obj, mid, args).GetD();
+    return InvokeVirtualOrInterfaceWithVarArgs(ts, obj, mid, args).GetD();
   }
 
   static jdouble CallDoubleMethodA(JNIEnv* env, jobject obj, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeVirtualOrInterfaceWithJValues(env, obj, mid, args).GetD();
+    return InvokeVirtualOrInterfaceWithJValues(ts, obj, mid, args).GetD();
   }
 
   static jfloat CallFloatMethod(JNIEnv* env, jobject obj, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeVirtualOrInterfaceWithVarArgs(env, obj, mid, ap));
+    JValue result(InvokeVirtualOrInterfaceWithVarArgs(ts, obj, mid, ap));
     va_end(ap);
     return result.GetF();
   }
 
   static jfloat CallFloatMethodV(JNIEnv* env, jobject obj, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeVirtualOrInterfaceWithVarArgs(env, obj, mid, args).GetF();
+    return InvokeVirtualOrInterfaceWithVarArgs(ts, obj, mid, args).GetF();
   }
 
   static jfloat CallFloatMethodA(JNIEnv* env, jobject obj, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeVirtualOrInterfaceWithJValues(env, obj, mid, args).GetF();
+    return InvokeVirtualOrInterfaceWithJValues(ts, obj, mid, args).GetF();
   }
 
   static jint CallIntMethod(JNIEnv* env, jobject obj, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeVirtualOrInterfaceWithVarArgs(env, obj, mid, ap));
+    JValue result(InvokeVirtualOrInterfaceWithVarArgs(ts, obj, mid, ap));
     va_end(ap);
     return result.GetI();
   }
 
   static jint CallIntMethodV(JNIEnv* env, jobject obj, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeVirtualOrInterfaceWithVarArgs(env, obj, mid, args).GetI();
+    return InvokeVirtualOrInterfaceWithVarArgs(ts, obj, mid, args).GetI();
   }
 
   static jint CallIntMethodA(JNIEnv* env, jobject obj, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeVirtualOrInterfaceWithJValues(env, obj, mid, args).GetI();
+    return InvokeVirtualOrInterfaceWithJValues(ts, obj, mid, args).GetI();
   }
 
   static jlong CallLongMethod(JNIEnv* env, jobject obj, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeVirtualOrInterfaceWithVarArgs(env, obj, mid, ap));
+    JValue result(InvokeVirtualOrInterfaceWithVarArgs(ts, obj, mid, ap));
     va_end(ap);
     return result.GetJ();
   }
 
   static jlong CallLongMethodV(JNIEnv* env, jobject obj, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeVirtualOrInterfaceWithVarArgs(env, obj, mid, args).GetJ();
+    return InvokeVirtualOrInterfaceWithVarArgs(ts, obj, mid, args).GetJ();
   }
 
   static jlong CallLongMethodA(JNIEnv* env, jobject obj, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeVirtualOrInterfaceWithJValues(env, obj, mid, args).GetJ();
+    return InvokeVirtualOrInterfaceWithJValues(ts, obj, mid, args).GetJ();
   }
 
   static jshort CallShortMethod(JNIEnv* env, jobject obj, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeVirtualOrInterfaceWithVarArgs(env, obj, mid, ap));
+    JValue result(InvokeVirtualOrInterfaceWithVarArgs(ts, obj, mid, ap));
     va_end(ap);
     return result.GetS();
   }
 
   static jshort CallShortMethodV(JNIEnv* env, jobject obj, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeVirtualOrInterfaceWithVarArgs(env, obj, mid, args).GetS();
+    return InvokeVirtualOrInterfaceWithVarArgs(ts, obj, mid, args).GetS();
   }
 
   static jshort CallShortMethodA(JNIEnv* env, jobject obj, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeVirtualOrInterfaceWithJValues(env, obj, mid, args).GetS();
+    return InvokeVirtualOrInterfaceWithJValues(ts, obj, mid, args).GetS();
   }
 
   static void CallVoidMethod(JNIEnv* env, jobject obj, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeVirtualOrInterfaceWithVarArgs(env, obj, mid, ap));
+    JValue result(InvokeVirtualOrInterfaceWithVarArgs(ts, obj, mid, ap));
     va_end(ap);
   }
 
   static void CallVoidMethodV(JNIEnv* env, jobject obj, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    InvokeVirtualOrInterfaceWithVarArgs(env, obj, mid, args);
+    InvokeVirtualOrInterfaceWithVarArgs(ts, obj, mid, args);
   }
 
   static void CallVoidMethodA(JNIEnv* env, jobject obj, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    InvokeVirtualOrInterfaceWithJValues(env, obj, mid, args);
+    InvokeVirtualOrInterfaceWithJValues(ts, obj, mid, args);
   }
 
   static jobject CallNonvirtualObjectMethod(JNIEnv* env, jobject obj, jclass, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeWithVarArgs(env, obj, mid, ap));
-    jobject local_result = AddLocalReference<jobject>(env, result.GetL());
+    JValue result(InvokeWithVarArgs(ts, obj, mid, ap));
+    jobject local_result = ts.AddLocalReference<jobject>(result.GetL());
     va_end(ap);
     return local_result;
   }
@@ -1306,15 +1218,15 @@
   static jobject CallNonvirtualObjectMethodV(JNIEnv* env,
       jobject obj, jclass, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    JValue result(InvokeWithVarArgs(env, obj, mid, args));
-    return AddLocalReference<jobject>(env, result.GetL());
+    JValue result(InvokeWithVarArgs(ts, obj, mid, args));
+    return ts.AddLocalReference<jobject>(result.GetL());
   }
 
   static jobject CallNonvirtualObjectMethodA(JNIEnv* env,
       jobject obj, jclass, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    JValue result(InvokeWithJValues(env, obj, mid, args));
-    return AddLocalReference<jobject>(env, result.GetL());
+    JValue result(InvokeWithJValues(ts, obj, mid, args));
+    return ts.AddLocalReference<jobject>(result.GetL());
   }
 
   static jboolean CallNonvirtualBooleanMethod(JNIEnv* env,
@@ -1322,7 +1234,7 @@
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeWithVarArgs(env, obj, mid, ap));
+    JValue result(InvokeWithVarArgs(ts, obj, mid, ap));
     va_end(ap);
     return result.GetZ();
   }
@@ -1330,20 +1242,20 @@
   static jboolean CallNonvirtualBooleanMethodV(JNIEnv* env,
       jobject obj, jclass, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithVarArgs(env, obj, mid, args).GetZ();
+    return InvokeWithVarArgs(ts, obj, mid, args).GetZ();
   }
 
   static jboolean CallNonvirtualBooleanMethodA(JNIEnv* env,
       jobject obj, jclass, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithJValues(env, obj, mid, args).GetZ();
+    return InvokeWithJValues(ts, obj, mid, args).GetZ();
   }
 
   static jbyte CallNonvirtualByteMethod(JNIEnv* env, jobject obj, jclass, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeWithVarArgs(env, obj, mid, ap));
+    JValue result(InvokeWithVarArgs(ts, obj, mid, ap));
     va_end(ap);
     return result.GetB();
   }
@@ -1351,20 +1263,20 @@
   static jbyte CallNonvirtualByteMethodV(JNIEnv* env,
       jobject obj, jclass, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithVarArgs(env, obj, mid, args).GetB();
+    return InvokeWithVarArgs(ts, obj, mid, args).GetB();
   }
 
   static jbyte CallNonvirtualByteMethodA(JNIEnv* env,
       jobject obj, jclass, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithJValues(env, obj, mid, args).GetB();
+    return InvokeWithJValues(ts, obj, mid, args).GetB();
   }
 
   static jchar CallNonvirtualCharMethod(JNIEnv* env, jobject obj, jclass, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeWithVarArgs(env, obj, mid, ap));
+    JValue result(InvokeWithVarArgs(ts, obj, mid, ap));
     va_end(ap);
     return result.GetC();
   }
@@ -1372,20 +1284,20 @@
   static jchar CallNonvirtualCharMethodV(JNIEnv* env,
       jobject obj, jclass, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithVarArgs(env, obj, mid, args).GetC();
+    return InvokeWithVarArgs(ts, obj, mid, args).GetC();
   }
 
   static jchar CallNonvirtualCharMethodA(JNIEnv* env,
       jobject obj, jclass, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithJValues(env, obj, mid, args).GetC();
+    return InvokeWithJValues(ts, obj, mid, args).GetC();
   }
 
   static jshort CallNonvirtualShortMethod(JNIEnv* env, jobject obj, jclass, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeWithVarArgs(env, obj, mid, ap));
+    JValue result(InvokeWithVarArgs(ts, obj, mid, ap));
     va_end(ap);
     return result.GetS();
   }
@@ -1393,20 +1305,20 @@
   static jshort CallNonvirtualShortMethodV(JNIEnv* env,
       jobject obj, jclass, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithVarArgs(env, obj, mid, args).GetS();
+    return InvokeWithVarArgs(ts, obj, mid, args).GetS();
   }
 
   static jshort CallNonvirtualShortMethodA(JNIEnv* env,
       jobject obj, jclass, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithJValues(env, obj, mid, args).GetS();
+    return InvokeWithJValues(ts, obj, mid, args).GetS();
   }
 
   static jint CallNonvirtualIntMethod(JNIEnv* env, jobject obj, jclass, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeWithVarArgs(env, obj, mid, ap));
+    JValue result(InvokeWithVarArgs(ts, obj, mid, ap));
     va_end(ap);
     return result.GetI();
   }
@@ -1414,20 +1326,20 @@
   static jint CallNonvirtualIntMethodV(JNIEnv* env,
       jobject obj, jclass, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithVarArgs(env, obj, mid, args).GetI();
+    return InvokeWithVarArgs(ts, obj, mid, args).GetI();
   }
 
   static jint CallNonvirtualIntMethodA(JNIEnv* env,
       jobject obj, jclass, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithJValues(env, obj, mid, args).GetI();
+    return InvokeWithJValues(ts, obj, mid, args).GetI();
   }
 
   static jlong CallNonvirtualLongMethod(JNIEnv* env, jobject obj, jclass, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeWithVarArgs(env, obj, mid, ap));
+    JValue result(InvokeWithVarArgs(ts, obj, mid, ap));
     va_end(ap);
     return result.GetJ();
   }
@@ -1435,20 +1347,20 @@
   static jlong CallNonvirtualLongMethodV(JNIEnv* env,
       jobject obj, jclass, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithVarArgs(env, obj, mid, args).GetJ();
+    return InvokeWithVarArgs(ts, obj, mid, args).GetJ();
   }
 
   static jlong CallNonvirtualLongMethodA(JNIEnv* env,
       jobject obj, jclass, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithJValues(env, obj, mid, args).GetJ();
+    return InvokeWithJValues(ts, obj, mid, args).GetJ();
   }
 
   static jfloat CallNonvirtualFloatMethod(JNIEnv* env, jobject obj, jclass, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeWithVarArgs(env, obj, mid, ap));
+    JValue result(InvokeWithVarArgs(ts, obj, mid, ap));
     va_end(ap);
     return result.GetF();
   }
@@ -1456,20 +1368,20 @@
   static jfloat CallNonvirtualFloatMethodV(JNIEnv* env,
       jobject obj, jclass, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithVarArgs(env, obj, mid, args).GetF();
+    return InvokeWithVarArgs(ts, obj, mid, args).GetF();
   }
 
   static jfloat CallNonvirtualFloatMethodA(JNIEnv* env,
       jobject obj, jclass, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithJValues(env, obj, mid, args).GetF();
+    return InvokeWithJValues(ts, obj, mid, args).GetF();
   }
 
   static jdouble CallNonvirtualDoubleMethod(JNIEnv* env, jobject obj, jclass, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeWithVarArgs(env, obj, mid, ap));
+    JValue result(InvokeWithVarArgs(ts, obj, mid, ap));
     va_end(ap);
     return result.GetD();
   }
@@ -1477,33 +1389,33 @@
   static jdouble CallNonvirtualDoubleMethodV(JNIEnv* env,
       jobject obj, jclass, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithVarArgs(env, obj, mid, args).GetD();
+    return InvokeWithVarArgs(ts, obj, mid, args).GetD();
   }
 
   static jdouble CallNonvirtualDoubleMethodA(JNIEnv* env,
       jobject obj, jclass, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithJValues(env, obj, mid, args).GetD();
+    return InvokeWithJValues(ts, obj, mid, args).GetD();
   }
 
   static void CallNonvirtualVoidMethod(JNIEnv* env, jobject obj, jclass, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    InvokeWithVarArgs(env, obj, mid, ap);
+    InvokeWithVarArgs(ts, obj, mid, ap);
     va_end(ap);
   }
 
   static void CallNonvirtualVoidMethodV(JNIEnv* env,
       jobject obj, jclass, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    InvokeWithVarArgs(env, obj, mid, args);
+    InvokeWithVarArgs(ts, obj, mid, args);
   }
 
   static void CallNonvirtualVoidMethodA(JNIEnv* env,
       jobject obj, jclass, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    InvokeWithJValues(env, obj, mid, args);
+    InvokeWithJValues(ts, obj, mid, args);
   }
 
   static jfieldID GetFieldID(JNIEnv* env, jclass c, const char* name, const char* sig) {
@@ -1519,42 +1431,42 @@
 
   static jobject GetObjectField(JNIEnv* env, jobject obj, jfieldID fid) {
     ScopedJniThreadState ts(env);
-    Object* o = Decode<Object*>(ts, obj);
-    Field* f = DecodeField(fid);
-    return AddLocalReference<jobject>(env, f->GetObject(o));
+    Object* o = ts.Decode<Object*>(obj);
+    Field* f = ts.DecodeField(fid);
+    return ts.AddLocalReference<jobject>(f->GetObject(o));
   }
 
   static jobject GetStaticObjectField(JNIEnv* env, jclass, jfieldID fid) {
     ScopedJniThreadState ts(env);
-    Field* f = DecodeField(fid);
-    return AddLocalReference<jobject>(env, f->GetObject(NULL));
+    Field* f = ts.DecodeField(fid);
+    return ts.AddLocalReference<jobject>(f->GetObject(NULL));
   }
 
   static void SetObjectField(JNIEnv* env, jobject java_object, jfieldID fid, jobject java_value) {
     ScopedJniThreadState ts(env);
-    Object* o = Decode<Object*>(ts, java_object);
-    Object* v = Decode<Object*>(ts, java_value);
-    Field* f = DecodeField(fid);
+    Object* o = ts.Decode<Object*>(java_object);
+    Object* v = ts.Decode<Object*>(java_value);
+    Field* f = ts.DecodeField(fid);
     f->SetObject(o, v);
   }
 
   static void SetStaticObjectField(JNIEnv* env, jclass, jfieldID fid, jobject java_value) {
     ScopedJniThreadState ts(env);
-    Object* v = Decode<Object*>(ts, java_value);
-    Field* f = DecodeField(fid);
+    Object* v = ts.Decode<Object*>(java_value);
+    Field* f = ts.DecodeField(fid);
     f->SetObject(NULL, v);
   }
 
 #define GET_PRIMITIVE_FIELD(fn, instance) \
   ScopedJniThreadState ts(env); \
-  Object* o = Decode<Object*>(ts, instance); \
-  Field* f = DecodeField(fid); \
+  Object* o = ts.Decode<Object*>(instance); \
+  Field* f = ts.DecodeField(fid); \
   return f->fn(o)
 
 #define SET_PRIMITIVE_FIELD(fn, instance, value) \
   ScopedJniThreadState ts(env); \
-  Object* o = Decode<Object*>(ts, instance); \
-  Field* f = DecodeField(fid); \
+  Object* o = ts.Decode<Object*>(instance); \
+  Field* f = ts.DecodeField(fid); \
   f->fn(o, value)
 
   static jboolean GetBooleanField(JNIEnv* env, jobject obj, jfieldID fid) {
@@ -1689,222 +1601,222 @@
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeWithVarArgs(env, NULL, mid, ap));
-    jobject local_result = AddLocalReference<jobject>(env, result.GetL());
+    JValue result(InvokeWithVarArgs(ts, NULL, mid, ap));
+    jobject local_result = ts.AddLocalReference<jobject>(result.GetL());
     va_end(ap);
     return local_result;
   }
 
   static jobject CallStaticObjectMethodV(JNIEnv* env, jclass, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    JValue result(InvokeWithVarArgs(env, NULL, mid, args));
-    return AddLocalReference<jobject>(env, result.GetL());
+    JValue result(InvokeWithVarArgs(ts, NULL, mid, args));
+    return ts.AddLocalReference<jobject>(result.GetL());
   }
 
   static jobject CallStaticObjectMethodA(JNIEnv* env, jclass, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    JValue result(InvokeWithJValues(env, NULL, mid, args));
-    return AddLocalReference<jobject>(env, result.GetL());
+    JValue result(InvokeWithJValues(ts, NULL, mid, args));
+    return ts.AddLocalReference<jobject>(result.GetL());
   }
 
   static jboolean CallStaticBooleanMethod(JNIEnv* env, jclass, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeWithVarArgs(env, NULL, mid, ap));
+    JValue result(InvokeWithVarArgs(ts, NULL, mid, ap));
     va_end(ap);
     return result.GetZ();
   }
 
   static jboolean CallStaticBooleanMethodV(JNIEnv* env, jclass, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithVarArgs(env, NULL, mid, args).GetZ();
+    return InvokeWithVarArgs(ts, NULL, mid, args).GetZ();
   }
 
   static jboolean CallStaticBooleanMethodA(JNIEnv* env, jclass, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithJValues(env, NULL, mid, args).GetZ();
+    return InvokeWithJValues(ts, NULL, mid, args).GetZ();
   }
 
   static jbyte CallStaticByteMethod(JNIEnv* env, jclass, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeWithVarArgs(env, NULL, mid, ap));
+    JValue result(InvokeWithVarArgs(ts, NULL, mid, ap));
     va_end(ap);
     return result.GetB();
   }
 
   static jbyte CallStaticByteMethodV(JNIEnv* env, jclass, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithVarArgs(env, NULL, mid, args).GetB();
+    return InvokeWithVarArgs(ts, NULL, mid, args).GetB();
   }
 
   static jbyte CallStaticByteMethodA(JNIEnv* env, jclass, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithJValues(env, NULL, mid, args).GetB();
+    return InvokeWithJValues(ts, NULL, mid, args).GetB();
   }
 
   static jchar CallStaticCharMethod(JNIEnv* env, jclass, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeWithVarArgs(env, NULL, mid, ap));
+    JValue result(InvokeWithVarArgs(ts, NULL, mid, ap));
     va_end(ap);
     return result.GetC();
   }
 
   static jchar CallStaticCharMethodV(JNIEnv* env, jclass, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithVarArgs(env, NULL, mid, args).GetC();
+    return InvokeWithVarArgs(ts, NULL, mid, args).GetC();
   }
 
   static jchar CallStaticCharMethodA(JNIEnv* env, jclass, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithJValues(env, NULL, mid, args).GetC();
+    return InvokeWithJValues(ts, NULL, mid, args).GetC();
   }
 
   static jshort CallStaticShortMethod(JNIEnv* env, jclass, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeWithVarArgs(env, NULL, mid, ap));
+    JValue result(InvokeWithVarArgs(ts, NULL, mid, ap));
     va_end(ap);
     return result.GetS();
   }
 
   static jshort CallStaticShortMethodV(JNIEnv* env, jclass, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithVarArgs(env, NULL, mid, args).GetS();
+    return InvokeWithVarArgs(ts, NULL, mid, args).GetS();
   }
 
   static jshort CallStaticShortMethodA(JNIEnv* env, jclass, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithJValues(env, NULL, mid, args).GetS();
+    return InvokeWithJValues(ts, NULL, mid, args).GetS();
   }
 
   static jint CallStaticIntMethod(JNIEnv* env, jclass, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeWithVarArgs(env, NULL, mid, ap));
+    JValue result(InvokeWithVarArgs(ts, NULL, mid, ap));
     va_end(ap);
     return result.GetI();
   }
 
   static jint CallStaticIntMethodV(JNIEnv* env, jclass, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithVarArgs(env, NULL, mid, args).GetI();
+    return InvokeWithVarArgs(ts, NULL, mid, args).GetI();
   }
 
   static jint CallStaticIntMethodA(JNIEnv* env, jclass, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithJValues(env, NULL, mid, args).GetI();
+    return InvokeWithJValues(ts, NULL, mid, args).GetI();
   }
 
   static jlong CallStaticLongMethod(JNIEnv* env, jclass, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeWithVarArgs(env, NULL, mid, ap));
+    JValue result(InvokeWithVarArgs(ts, NULL, mid, ap));
     va_end(ap);
     return result.GetJ();
   }
 
   static jlong CallStaticLongMethodV(JNIEnv* env, jclass, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithVarArgs(env, NULL, mid, args).GetJ();
+    return InvokeWithVarArgs(ts, NULL, mid, args).GetJ();
   }
 
   static jlong CallStaticLongMethodA(JNIEnv* env, jclass, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithJValues(env, NULL, mid, args).GetJ();
+    return InvokeWithJValues(ts, NULL, mid, args).GetJ();
   }
 
   static jfloat CallStaticFloatMethod(JNIEnv* env, jclass, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeWithVarArgs(env, NULL, mid, ap));
+    JValue result(InvokeWithVarArgs(ts, NULL, mid, ap));
     va_end(ap);
     return result.GetF();
   }
 
   static jfloat CallStaticFloatMethodV(JNIEnv* env, jclass, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithVarArgs(env, NULL, mid, args).GetF();
+    return InvokeWithVarArgs(ts, NULL, mid, args).GetF();
   }
 
   static jfloat CallStaticFloatMethodA(JNIEnv* env, jclass, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithJValues(env, NULL, mid, args).GetF();
+    return InvokeWithJValues(ts, NULL, mid, args).GetF();
   }
 
   static jdouble CallStaticDoubleMethod(JNIEnv* env, jclass, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    JValue result(InvokeWithVarArgs(env, NULL, mid, ap));
+    JValue result(InvokeWithVarArgs(ts, NULL, mid, ap));
     va_end(ap);
     return result.GetD();
   }
 
   static jdouble CallStaticDoubleMethodV(JNIEnv* env, jclass, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithVarArgs(env, NULL, mid, args).GetD();
+    return InvokeWithVarArgs(ts, NULL, mid, args).GetD();
   }
 
   static jdouble CallStaticDoubleMethodA(JNIEnv* env, jclass, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    return InvokeWithJValues(env, NULL, mid, args).GetD();
+    return InvokeWithJValues(ts, NULL, mid, args).GetD();
   }
 
   static void CallStaticVoidMethod(JNIEnv* env, jclass, jmethodID mid, ...) {
     ScopedJniThreadState ts(env);
     va_list ap;
     va_start(ap, mid);
-    InvokeWithVarArgs(env, NULL, mid, ap);
+    InvokeWithVarArgs(ts, NULL, mid, ap);
     va_end(ap);
   }
 
   static void CallStaticVoidMethodV(JNIEnv* env, jclass, jmethodID mid, va_list args) {
     ScopedJniThreadState ts(env);
-    InvokeWithVarArgs(env, NULL, mid, args);
+    InvokeWithVarArgs(ts, NULL, mid, args);
   }
 
   static void CallStaticVoidMethodA(JNIEnv* env, jclass, jmethodID mid, jvalue* args) {
     ScopedJniThreadState ts(env);
-    InvokeWithJValues(env, NULL, mid, args);
+    InvokeWithJValues(ts, NULL, mid, args);
   }
 
   static jstring NewString(JNIEnv* env, const jchar* chars, jsize char_count) {
     ScopedJniThreadState ts(env);
     String* result = String::AllocFromUtf16(char_count, chars);
-    return AddLocalReference<jstring>(env, result);
+    return ts.AddLocalReference<jstring>(result);
   }
 
   static jstring NewStringUTF(JNIEnv* env, const char* utf) {
-    ScopedJniThreadState ts(env);
     if (utf == NULL) {
       return NULL;
     }
+    ScopedJniThreadState ts(env);
     String* result = String::AllocFromModifiedUtf8(utf);
-    return AddLocalReference<jstring>(env, result);
+    return ts.AddLocalReference<jstring>(result);
   }
 
   static jsize GetStringLength(JNIEnv* env, jstring java_string) {
     ScopedJniThreadState ts(env);
-    return Decode<String*>(ts, java_string)->GetLength();
+    return ts.Decode<String*>(java_string)->GetLength();
   }
 
   static jsize GetStringUTFLength(JNIEnv* env, jstring java_string) {
     ScopedJniThreadState ts(env);
-    return Decode<String*>(ts, java_string)->GetUtfLength();
+    return ts.Decode<String*>(java_string)->GetUtfLength();
   }
 
   static void GetStringRegion(JNIEnv* env, jstring java_string, jsize start, jsize length, jchar* buf) {
     ScopedJniThreadState ts(env);
-    String* s = Decode<String*>(ts, java_string);
+    String* s = ts.Decode<String*>(java_string);
     if (start < 0 || length < 0 || start + length > s->GetLength()) {
       ThrowSIOOBE(ts, start, length, s->GetLength());
     } else {
@@ -1915,7 +1827,7 @@
 
   static void GetStringUTFRegion(JNIEnv* env, jstring java_string, jsize start, jsize length, char* buf) {
     ScopedJniThreadState ts(env);
-    String* s = Decode<String*>(ts, java_string);
+    String* s = ts.Decode<String*>(java_string);
     if (start < 0 || length < 0 || start + length > s->GetLength()) {
       ThrowSIOOBE(ts, start, length, s->GetLength());
     } else {
@@ -1926,7 +1838,7 @@
 
   static const jchar* GetStringChars(JNIEnv* env, jstring java_string, jboolean* is_copy) {
     ScopedJniThreadState ts(env);
-    String* s = Decode<String*>(ts, java_string);
+    String* s = ts.Decode<String*>(java_string);
     const CharArray* chars = s->GetCharArray();
     PinPrimitiveArray(ts, chars);
     if (is_copy != NULL) {
@@ -1937,7 +1849,7 @@
 
   static void ReleaseStringChars(JNIEnv* env, jstring java_string, const jchar*) {
     ScopedJniThreadState ts(env);
-    UnpinPrimitiveArray(ts, Decode<String*>(ts, java_string)->GetCharArray());
+    UnpinPrimitiveArray(ts, ts.Decode<String*>(java_string)->GetCharArray());
   }
 
   static const jchar* GetStringCritical(JNIEnv* env, jstring java_string, jboolean* is_copy) {
@@ -1958,7 +1870,7 @@
     if (is_copy != NULL) {
       *is_copy = JNI_TRUE;
     }
-    String* s = Decode<String*>(ts, java_string);
+    String* s = ts.Decode<String*>(java_string);
     size_t byte_count = s->GetUtfLength();
     char* bytes = new char[byte_count + 1];
     CHECK(bytes != NULL); // bionic aborts anyway.
@@ -1975,7 +1887,7 @@
 
   static jsize GetArrayLength(JNIEnv* env, jarray java_array) {
     ScopedJniThreadState ts(env);
-    Object* obj = Decode<Object*>(ts, java_array);
+    Object* obj = ts.Decode<Object*>(java_array);
     CHECK(obj->IsArrayInstance()); // TODO: ReportJniError
     Array* array = obj->AsArray();
     return array->GetLength();
@@ -1983,15 +1895,15 @@
 
   static jobject GetObjectArrayElement(JNIEnv* env, jobjectArray java_array, jsize index) {
     ScopedJniThreadState ts(env);
-    ObjectArray<Object>* array = Decode<ObjectArray<Object>*>(ts, java_array);
-    return AddLocalReference<jobject>(env, array->Get(index));
+    ObjectArray<Object>* array = ts.Decode<ObjectArray<Object>*>(java_array);
+    return ts.AddLocalReference<jobject>(array->Get(index));
   }
 
   static void SetObjectArrayElement(JNIEnv* env,
       jobjectArray java_array, jsize index, jobject java_value) {
     ScopedJniThreadState ts(env);
-    ObjectArray<Object>* array = Decode<ObjectArray<Object>*>(ts, java_array);
-    Object* value = Decode<Object*>(ts, java_value);
+    ObjectArray<Object>* array = ts.Decode<ObjectArray<Object>*>(java_array);
+    Object* value = ts.Decode<Object*>(java_value);
     array->Set(index, value);
   }
 
@@ -2035,7 +1947,7 @@
     CHECK_GE(length, 0); // TODO: ReportJniError
 
     // Compute the array class corresponding to the given element class.
-    Class* element_class = Decode<Class*>(ts, element_jclass);
+    Class* element_class = ts.Decode<Class*>(element_jclass);
     std::string descriptor;
     descriptor += "[";
     descriptor += ClassHelper(element_class).GetDescriptor();
@@ -2047,15 +1959,15 @@
     }
 
     // Allocate and initialize if necessary.
-    Class* array_class = Decode<Class*>(ts, java_array_class.get());
+    Class* array_class = ts.Decode<Class*>(java_array_class.get());
     ObjectArray<Object>* result = ObjectArray<Object>::Alloc(array_class, length);
     if (initial_element != NULL) {
-      Object* initial_object = Decode<Object*>(ts, initial_element);
+      Object* initial_object = ts.Decode<Object*>(initial_element);
       for (jsize i = 0; i < length; ++i) {
         result->Set(i, initial_object);
       }
     }
-    return AddLocalReference<jobjectArray>(env, result);
+    return ts.AddLocalReference<jobjectArray>(result);
   }
 
   static jshortArray NewShortArray(JNIEnv* env, jsize length) {
@@ -2065,7 +1977,7 @@
 
   static void* GetPrimitiveArrayCritical(JNIEnv* env, jarray java_array, jboolean* is_copy) {
     ScopedJniThreadState ts(env);
-    Array* array = Decode<Array*>(ts, java_array);
+    Array* array = ts.Decode<Array*>(java_array);
     PinPrimitiveArray(ts, array);
     if (is_copy != NULL) {
       *is_copy = JNI_FALSE;
@@ -2240,7 +2152,7 @@
 
   static jint RegisterNatives(JNIEnv* env, jclass java_class, const JNINativeMethod* methods, jint method_count) {
     ScopedJniThreadState ts(env);
-    Class* c = Decode<Class*>(ts, java_class);
+    Class* c = ts.Decode<Class*>(java_class);
 
     for (int i = 0; i < method_count; i++) {
       const char* name = methods[i].name;
@@ -2274,7 +2186,7 @@
 
   static jint UnregisterNatives(JNIEnv* env, jclass java_class) {
     ScopedJniThreadState ts(env);
-    Class* c = Decode<Class*>(ts, java_class);
+    Class* c = ts.Decode<Class*>(java_class);
 
     VLOG(jni) << "[Unregistering JNI native methods for " << PrettyClass(c) << "]";
 
@@ -2296,7 +2208,7 @@
 
   static jint MonitorEnter(JNIEnv* env, jobject java_object) {
     ScopedJniThreadState ts(env);
-    Object* o = Decode<Object*>(ts, java_object);
+    Object* o = ts.Decode<Object*>(java_object);
     o->MonitorEnter(ts.Self());
     if (ts.Self()->IsExceptionPending()) {
       return JNI_ERR;
@@ -2307,7 +2219,7 @@
 
   static jint MonitorExit(JNIEnv* env, jobject java_object) {
     ScopedJniThreadState ts(env);
-    Object* o = Decode<Object*>(ts, java_object);
+    Object* o = ts.Decode<Object*>(java_object);
     o->MonitorExit(ts.Self());
     if (ts.Self()->IsExceptionPending()) {
       return JNI_ERR;
@@ -2386,7 +2298,7 @@
 
       // If we're handing out direct pointers, check whether it's a direct pointer
       // to a local reference.
-      if (Decode<Object*>(ts, java_object) == reinterpret_cast<Object*>(java_object)) {
+      if (ts.Decode<Object*>(java_object) == reinterpret_cast<Object*>(java_object)) {
         if (ts.Env()->locals.ContainsDirectPointer(reinterpret_cast<Object*>(java_object))) {
           return JNILocalRefType;
         }
@@ -2950,7 +2862,7 @@
     // the comments in the JNI FindClass function.)
     typedef int (*JNI_OnLoadFn)(JavaVM*, void*);
     JNI_OnLoadFn jni_on_load = reinterpret_cast<JNI_OnLoadFn>(sym);
-    const ClassLoader* old_class_loader = self->GetClassLoaderOverride();
+    ClassLoader* old_class_loader = self->GetClassLoaderOverride();
     self->SetClassLoaderOverride(class_loader);
 
     int version = 0;
diff --git a/src/jni_internal.h b/src/jni_internal.h
index be5bca0..b96a4d7 100644
--- a/src/jni_internal.h
+++ b/src/jni_internal.h
@@ -43,6 +43,7 @@
 union JValue;
 class Libraries;
 class Method;
+class ScopedJniThreadState;
 class Thread;
 
 void SetJniGlobalsMax(size_t max);
@@ -50,42 +51,9 @@
 void* FindNativeMethod(Thread* thread);
 void RegisterNativeMethods(JNIEnv* env, const char* jni_class_name, const JNINativeMethod* methods, size_t method_count);
 
-template<typename T> T Decode(JNIEnv*, jobject);
-template<typename T> T AddLocalReference(JNIEnv*, const Object*);
-
-inline Field* DecodeField(jfieldID fid) {
-#ifdef MOVING_GARBAGE_COLLECTOR
-  // TODO: we should make these unique weak globals if Field instances can ever move.
-  UNIMPLEMENTED(WARNING);
-#endif
-  return reinterpret_cast<Field*>(fid);
-}
-
-inline jfieldID EncodeField(Field* field) {
-#ifdef MOVING_GARBAGE_COLLECTOR
-  UNIMPLEMENTED(WARNING);
-#endif
-  return reinterpret_cast<jfieldID>(field);
-}
-
-inline Method* DecodeMethod(jmethodID mid) {
-#ifdef MOVING_GARBAGE_COLLECTOR
-  // TODO: we should make these unique weak globals if Method instances can ever move.
-  UNIMPLEMENTED(WARNING);
-#endif
-  return reinterpret_cast<Method*>(mid);
-}
-
-inline jmethodID EncodeMethod(Method* method) {
-#ifdef MOVING_GARBAGE_COLLECTOR
-  UNIMPLEMENTED(WARNING);
-#endif
-  return reinterpret_cast<jmethodID>(method);
-}
-
 size_t NumArgArrayBytes(const char* shorty, uint32_t shorty_len);
-JValue InvokeWithJValues(JNIEnv* env, jobject obj, jmethodID mid, jvalue* args);
-JValue InvokeWithJValues(Thread* self, Object* receiver, Method* m, JValue* args);
+JValue InvokeWithJValues(const ScopedJniThreadState&, jobject obj, jmethodID mid, jvalue* args);
+JValue InvokeWithJValues(const ScopedJniThreadState&, Object* receiver, Method* m, JValue* args);
 
 int ThrowNewException(JNIEnv* env, jclass exception_class, const char* msg, jobject cause);
 
diff --git a/src/jni_internal_test.cc b/src/jni_internal_test.cc
index 00397a3..daca1b5 100644
--- a/src/jni_internal_test.cc
+++ b/src/jni_internal_test.cc
@@ -20,6 +20,7 @@
 
 #include "common_test.h"
 #include "ScopedLocalRef.h"
+#include "scoped_jni_thread_state.h"
 
 namespace art {
 
@@ -1245,6 +1246,7 @@
 
   jobject outer;
   jobject inner1, inner2;
+  ScopedJniThreadState ts(env_);
   Object* inner2_direct_pointer;
   {
     env_->PushLocalFrame(4);
@@ -1254,7 +1256,7 @@
       env_->PushLocalFrame(4);
       inner1 = env_->NewLocalRef(outer);
       inner2 = env_->NewStringUTF("survivor");
-      inner2_direct_pointer = Decode<Object*>(env_, inner2);
+      inner2_direct_pointer = ts.Decode<Object*>(inner2);
       env_->PopLocalFrame(inner2);
     }
 
diff --git a/src/monitor.cc b/src/monitor.cc
index dde67ea..de08b88 100644
--- a/src/monitor.cc
+++ b/src/monitor.cc
@@ -28,6 +28,7 @@
 #include "mutex.h"
 #include "object.h"
 #include "object_utils.h"
+#include "scoped_jni_thread_state.h"
 #include "scoped_thread_list_lock.h"
 #include "stl_util.h"
 #include "thread.h"
@@ -825,15 +826,17 @@
 }
 
 static uint32_t LockOwnerFromThreadLock(Object* thread_lock) {
-  if (thread_lock == NULL || thread_lock->GetClass() != WellKnownClasses::ToClass(WellKnownClasses::java_lang_ThreadLock)) {
+  ScopedJniThreadState ts(Thread::Current());
+  if (thread_lock == NULL ||
+      thread_lock->GetClass() != ts.Decode<Class*>(WellKnownClasses::java_lang_ThreadLock)) {
     return ThreadList::kInvalidId;
   }
-  Field* thread_field = DecodeField(WellKnownClasses::java_lang_ThreadLock_thread);
+  Field* thread_field = ts.DecodeField(WellKnownClasses::java_lang_ThreadLock_thread);
   Object* managed_thread = thread_field->GetObject(thread_lock);
   if (managed_thread == NULL) {
     return ThreadList::kInvalidId;
   }
-  Field* vmData_field = DecodeField(WellKnownClasses::java_lang_Thread_vmData);
+  Field* vmData_field = ts.DecodeField(WellKnownClasses::java_lang_Thread_vmData);
   uintptr_t vmData = static_cast<uintptr_t>(vmData_field->GetInt(managed_thread));
   Thread* thread = reinterpret_cast<Thread*>(vmData);
   if (thread == NULL) {
diff --git a/src/monitor_android.cc b/src/monitor_android.cc
index dc77b6d..94f86e8 100644
--- a/src/monitor_android.cc
+++ b/src/monitor_android.cc
@@ -69,7 +69,8 @@
   cp = EventLogWriteInt(cp, Monitor::IsSensitiveThread());
 
   // Emit self thread name string, <= 37 bytes.
-  std::string thread_name(self->GetThreadName()->ToModifiedUtf8());
+  std::string thread_name;
+  self->GetThreadName(thread_name);
   cp = EventLogWriteString(cp, thread_name.c_str(), thread_name.size());
 
   // Emit the wait time, 5 bytes.
diff --git a/src/native/dalvik_system_DexFile.cc b/src/native/dalvik_system_DexFile.cc
index 89d7130..3bf0ea5 100644
--- a/src/native/dalvik_system_DexFile.cc
+++ b/src/native/dalvik_system_DexFile.cc
@@ -126,7 +126,7 @@
 
 static jclass DexFile_defineClassNative(JNIEnv* env, jclass, jstring javaName, jobject javaLoader,
                                         jint cookie) {
-  ScopedJniThreadState tsc(env);
+  ScopedJniThreadState ts(env);
   const DexFile* dex_file = toDexFile(cookie);
   if (dex_file == NULL) {
     return NULL;
@@ -142,10 +142,10 @@
   }
   ClassLinker* class_linker = Runtime::Current()->GetClassLinker();
   class_linker->RegisterDexFile(*dex_file);
-  Object* class_loader_object = Decode<Object*>(env, javaLoader);
+  Object* class_loader_object = ts.Decode<Object*>(javaLoader);
   ClassLoader* class_loader = down_cast<ClassLoader*>(class_loader_object);
   Class* result = class_linker->DefineClass(descriptor, class_loader, *dex_file, *dex_class_def);
-  return AddLocalReference<jclass>(env, result);
+  return ts.AddLocalReference<jclass>(result);
 }
 
 static jobjectArray DexFile_getClassNameList(JNIEnv* env, jclass, jint cookie) {
diff --git a/src/native/dalvik_system_VMDebug.cc b/src/native/dalvik_system_VMDebug.cc
index 9b10cda..70067fe 100644
--- a/src/native/dalvik_system_VMDebug.cc
+++ b/src/native/dalvik_system_VMDebug.cc
@@ -22,6 +22,7 @@
 #include "hprof/hprof.h"
 #include "jni_internal.h"
 #include "ScopedUtfChars.h"
+#include "scoped_jni_thread_state.h"
 #include "toStringArray.h"
 #include "trace.h"
 
@@ -204,7 +205,8 @@
 }
 
 static jlong VMDebug_countInstancesOfClass(JNIEnv* env, jclass, jclass javaClass, jboolean countAssignable) {
-  Class* c = Decode<Class*>(env, javaClass);
+  ScopedJniThreadState ts(env);
+  Class* c = ts.Decode<Class*>(javaClass);
   if (c == NULL) {
     return 0;
   }
diff --git a/src/native/dalvik_system_VMRuntime.cc b/src/native/dalvik_system_VMRuntime.cc
index 09ca251..417ae5b 100644
--- a/src/native/dalvik_system_VMRuntime.cc
+++ b/src/native/dalvik_system_VMRuntime.cc
@@ -22,6 +22,7 @@
 #include "object.h"
 #include "object_utils.h"
 #include "scoped_heap_lock.h"
+#include "scoped_jni_thread_state.h"
 #include "scoped_thread_list_lock.h"
 #include "space.h"
 #include "thread.h"
@@ -48,7 +49,7 @@
 }
 
 static jobject VMRuntime_newNonMovableArray(JNIEnv* env, jobject, jclass javaElementClass, jint length) {
-  ScopedThreadStateChange tsc(Thread::Current(), kRunnable);
+  ScopedJniThreadState ts(env);
 #ifdef MOVING_GARBAGE_COLLECTOR
   // TODO: right now, we don't have a copying collector, so there's no need
   // to do anything special here, but we ought to pass the non-movability
@@ -56,7 +57,7 @@
   UNIMPLEMENTED(FATAL);
 #endif
 
-  Class* element_class = Decode<Class*>(env, javaElementClass);
+  Class* element_class = ts.Decode<Class*>(javaElementClass);
   if (element_class == NULL) {
     Thread::Current()->ThrowNewException("Ljava/lang/NullPointerException;", "element class == null");
     return NULL;
@@ -75,15 +76,15 @@
   if (result == NULL) {
     return NULL;
   }
-  return AddLocalReference<jobject>(env, result);
+  return ts.AddLocalReference<jobject>(result);
 }
 
 static jlong VMRuntime_addressOf(JNIEnv* env, jobject, jobject javaArray) {
   if (javaArray == NULL) {  // Most likely allocation failed
     return 0;
   }
-  ScopedThreadStateChange tsc(Thread::Current(), kRunnable);
-  Array* array = Decode<Array*>(env, javaArray);
+  ScopedJniThreadState ts(env);
+  Array* array = ts.Decode<Array*>(javaArray);
   if (!array->IsArrayInstance()) {
     Thread::Current()->ThrowNewException("Ljava/lang/IllegalArgumentException;", "not an array");
     return 0;
diff --git a/src/native/dalvik_system_VMStack.cc b/src/native/dalvik_system_VMStack.cc
index e3ecbd9..933a5d5 100644
--- a/src/native/dalvik_system_VMStack.cc
+++ b/src/native/dalvik_system_VMStack.cc
@@ -26,10 +26,11 @@
 namespace art {
 
 static jobject GetThreadStack(JNIEnv* env, jobject javaThread) {
+  ScopedJniThreadState ts(env);
   ScopedHeapLock heap_lock;
   ScopedThreadListLock thread_list_lock;
-  Thread* thread = Thread::FromManagedThread(env, javaThread);
-  return (thread != NULL) ? GetThreadStack(env, thread) : NULL;
+  Thread* thread = Thread::FromManagedThread(ts, javaThread);
+  return (thread != NULL) ? GetThreadStack(ts, thread) : NULL;
 }
 
 static jint VMStack_fillStackTraceElements(JNIEnv* env, jclass, jobject javaThread, jobjectArray javaSteArray) {
@@ -44,10 +45,10 @@
 
 // Returns the defining class loader of the caller's caller.
 static jobject VMStack_getCallingClassLoader(JNIEnv* env, jclass) {
-  ScopedJniThreadState ts(env, kNative);  // Not a state change out of native.
+  ScopedJniThreadState ts(env);
   NthCallerVisitor visitor(ts.Self()->GetManagedStack(), ts.Self()->GetTraceStack(), 2);
   visitor.WalkStack();
-  return AddLocalReference<jobject>(env, visitor.caller->GetDeclaringClass()->GetClassLoader());
+  return ts.AddLocalReference<jobject>(visitor.caller->GetDeclaringClass()->GetClassLoader());
 }
 
 static jobject VMStack_getClosestUserClassLoader(JNIEnv* env, jclass, jobject javaBootstrap, jobject javaSystem) {
@@ -72,20 +73,20 @@
     Object* class_loader;
   };
   ScopedJniThreadState ts(env);
-  Object* bootstrap = Decode<Object*>(env, javaBootstrap);
-  Object* system = Decode<Object*>(env, javaSystem);
+  Object* bootstrap = ts.Decode<Object*>(javaBootstrap);
+  Object* system = ts.Decode<Object*>(javaSystem);
   ClosestUserClassLoaderVisitor visitor(ts.Self()->GetManagedStack(), ts.Self()->GetTraceStack(),
                                         bootstrap, system);
   visitor.WalkStack();
-  return AddLocalReference<jobject>(env, visitor.class_loader);
+  return ts.AddLocalReference<jobject>(visitor.class_loader);
 }
 
 // Returns the class of the caller's caller's caller.
 static jclass VMStack_getStackClass2(JNIEnv* env, jclass) {
-  ScopedJniThreadState ts(env, kNative);  // Not a state change out of native.
+  ScopedJniThreadState ts(env);
   NthCallerVisitor visitor(ts.Self()->GetManagedStack(), ts.Self()->GetTraceStack(), 3);
   visitor.WalkStack();
-  return AddLocalReference<jclass>(env, visitor.caller->GetDeclaringClass());
+  return ts.AddLocalReference<jclass>(visitor.caller->GetDeclaringClass());
 }
 
 static jobjectArray VMStack_getThreadStackTrace(JNIEnv* env, jclass, jobject javaThread) {
diff --git a/src/native/java_lang_Class.cc b/src/native/java_lang_Class.cc
index 99e3a26..ecab777 100644
--- a/src/native/java_lang_Class.cc
+++ b/src/native/java_lang_Class.cc
@@ -27,8 +27,8 @@
 
 namespace art {
 
-static Class* DecodeClass(JNIEnv* env, jobject java_class) {
-  Class* c = Decode<Class*>(env, java_class);
+static Class* DecodeClass(const ScopedJniThreadState& ts, jobject java_class) {
+  Class* c = ts.Decode<Class*>(java_class);
   DCHECK(c != NULL);
   DCHECK(c->IsClass());
   // TODO: we could EnsureInitialized here, rather than on every reflective get/set or invoke .
@@ -39,7 +39,7 @@
 
 // "name" is in "binary name" format, e.g. "dalvik.system.Debug$1".
 static jclass Class_classForName(JNIEnv* env, jclass, jstring javaName, jboolean initialize, jobject javaLoader) {
-  ScopedThreadStateChange tsc(Thread::Current(), kRunnable);
+  ScopedJniThreadState ts(env);
   ScopedUtfChars name(env, javaName);
   if (name.c_str() == NULL) {
     return NULL;
@@ -55,8 +55,7 @@
   }
 
   std::string descriptor(DotToDescriptor(name.c_str()));
-  Object* loader = Decode<Object*>(env, javaLoader);
-  ClassLoader* class_loader = down_cast<ClassLoader*>(loader);
+  ClassLoader* class_loader = ts.Decode<ClassLoader*>(javaLoader);
   ClassLinker* class_linker = Runtime::Current()->GetClassLinker();
   Class* c = class_linker->FindClass(descriptor.c_str(), class_loader);
   if (c == NULL) {
@@ -71,11 +70,12 @@
   if (initialize) {
     class_linker->EnsureInitialized(c, true, true);
   }
-  return AddLocalReference<jclass>(env, c);
+  return ts.AddLocalReference<jclass>(c);
 }
 
 static jint Class_getAnnotationDirectoryOffset(JNIEnv* env, jclass javaClass) {
-  Class* c = DecodeClass(env, javaClass);
+  ScopedJniThreadState ts(env);
+  Class* c = DecodeClass(ts, javaClass);
   if (c->IsPrimitive() || c->IsArrayClass() || c->IsProxyClass()) {
     return 0;  // primitive, array and proxy classes don't have class definitions
   }
@@ -88,12 +88,13 @@
 }
 
 template<typename T>
-static jobjectArray ToArray(JNIEnv* env, const char* array_class_name, const std::vector<T*>& objects) {
-  ScopedLocalRef<jclass> array_class(env, env->FindClass(array_class_name));
-  jobjectArray result = env->NewObjectArray(objects.size(), array_class.get(), NULL);
+static jobjectArray ToArray(const ScopedJniThreadState& ts, const char* array_class_name,
+                            const std::vector<T*>& objects) {
+  ScopedLocalRef<jclass> array_class(ts.Env(), ts.Env()->FindClass(array_class_name));
+  jobjectArray result = ts.Env()->NewObjectArray(objects.size(), array_class.get(), NULL);
   for (size_t i = 0; i < objects.size(); ++i) {
-    ScopedLocalRef<jobject> object(env, AddLocalReference<jobject>(env, objects[i]));
-    env->SetObjectArrayElement(result, i, object.get());
+    ScopedLocalRef<jobject> object(ts.Env(), ts.AddLocalReference<jobject>(objects[i]));
+    ts.Env()->SetObjectArrayElement(result, i, object.get());
   }
   return result;
 }
@@ -109,11 +110,8 @@
 }
 
 static jobjectArray Class_getDeclaredConstructors(JNIEnv* env, jclass javaClass, jboolean publicOnly) {
-  Class* c = DecodeClass(env, javaClass);
-  if (c == NULL) {
-    return NULL;
-  }
-
+  ScopedJniThreadState ts(env);
+  Class* c = DecodeClass(ts, javaClass);
   std::vector<Method*> constructors;
   for (size_t i = 0; i < c->NumDirectMethods(); ++i) {
     Method* m = c->GetDirectMethod(i);
@@ -122,7 +120,7 @@
     }
   }
 
-  return ToArray(env, "java/lang/reflect/Constructor", constructors);
+  return ToArray(ts, "java/lang/reflect/Constructor", constructors);
 }
 
 static bool IsVisibleField(Field* f, bool public_only) {
@@ -133,12 +131,8 @@
 }
 
 static jobjectArray Class_getDeclaredFields(JNIEnv* env, jclass javaClass, jboolean publicOnly) {
-  ScopedThreadStateChange tsc(Thread::Current(), kRunnable);
-  Class* c = DecodeClass(env, javaClass);
-  if (c == NULL) {
-    return NULL;
-  }
-
+  ScopedJniThreadState ts(env);
+  Class* c = DecodeClass(ts, javaClass);
   std::vector<Field*> fields;
   FieldHelper fh;
   for (size_t i = 0; i < c->NumInstanceFields(); ++i) {
@@ -170,7 +164,7 @@
     }
   }
 
-  return ToArray(env, "java/lang/reflect/Field", fields);
+  return ToArray(ts, "java/lang/reflect/Field", fields);
 }
 
 static bool IsVisibleMethod(Method* m, bool public_only) {
@@ -187,8 +181,8 @@
 }
 
 static jobjectArray Class_getDeclaredMethods(JNIEnv* env, jclass javaClass, jboolean publicOnly) {
-  ScopedThreadStateChange tsc(Thread::Current(), kRunnable);
-  Class* c = DecodeClass(env, javaClass);
+  ScopedJniThreadState ts(env);
+  Class* c = DecodeClass(ts, javaClass);
   if (c == NULL) {
     return NULL;
   }
@@ -224,11 +218,12 @@
     }
   }
 
-  return ToArray(env, "java/lang/reflect/Method", methods);
+  return ToArray(ts, "java/lang/reflect/Method", methods);
 }
 
 static jobject Class_getDex(JNIEnv* env, jobject javaClass) {
-  Class* c = DecodeClass(env, javaClass);
+  ScopedJniThreadState ts(env);
+  Class* c = DecodeClass(ts, javaClass);
 
   DexCache* dex_cache = c->GetDexCache();
   if (dex_cache == NULL) {
@@ -287,13 +282,10 @@
 
 static jobject Class_getDeclaredConstructorOrMethod(JNIEnv* env, jclass javaClass, jstring javaName,
                                                     jobjectArray javaArgs) {
-  Class* c = DecodeClass(env, javaClass);
-  if (c == NULL) {
-    return NULL;
-  }
-
-  std::string name(Decode<String*>(env, javaName)->ToModifiedUtf8());
-  ObjectArray<Class>* arg_array = Decode<ObjectArray<Class>*>(env, javaArgs);
+  ScopedJniThreadState ts(env);
+  Class* c = DecodeClass(ts, javaClass);
+  std::string name(ts.Decode<String*>(javaName)->ToModifiedUtf8());
+  ObjectArray<Class>* arg_array = ts.Decode<ObjectArray<Class>*>(javaArgs);
 
   Method* m = FindConstructorOrMethodInArray(c->GetDirectMethods(), name, arg_array);
   if (m == NULL) {
@@ -301,7 +293,7 @@
   }
 
   if (m != NULL) {
-    return AddLocalReference<jobject>(env, m);
+    return ts.AddLocalReference<jobject>(m);
   } else {
     return NULL;
   }
@@ -309,12 +301,8 @@
 
 static jobject Class_getDeclaredFieldNative(JNIEnv* env, jclass java_class, jobject jname) {
   ScopedJniThreadState ts(env);
-  Class* c = DecodeClass(env, java_class);
-  if (c == NULL) {
-    return NULL;
-  }
-
-  String* name = Decode<String*>(env, jname);
+  Class* c = DecodeClass(ts, java_class);
+  String* name = ts.Decode<String*>(jname);
   DCHECK(name->GetClass()->IsStringClass());
 
   FieldHelper fh;
@@ -326,7 +314,7 @@
         DCHECK(env->ExceptionOccurred());
         return NULL;
       }
-      return AddLocalReference<jclass>(env, f);
+      return ts.AddLocalReference<jclass>(f);
     }
   }
   for (size_t i = 0; i < c->NumStaticFields(); ++i) {
@@ -337,7 +325,7 @@
         DCHECK(env->ExceptionOccurred());
         return NULL;
       }
-      return AddLocalReference<jclass>(env, f);
+      return ts.AddLocalReference<jclass>(f);
     }
   }
   return NULL;
@@ -345,20 +333,20 @@
 
 static jstring Class_getNameNative(JNIEnv* env, jobject javaThis) {
   ScopedJniThreadState ts(env);
-  Class* c = DecodeClass(env, javaThis);
-  return AddLocalReference<jstring>(env, c->ComputeName());
+  Class* c = DecodeClass(ts, javaThis);
+  return ts.AddLocalReference<jstring>(c->ComputeName());
 }
 
 static jobjectArray Class_getProxyInterfaces(JNIEnv* env, jobject javaThis) {
   ScopedJniThreadState ts(env);
-  SynthesizedProxyClass* c = down_cast<SynthesizedProxyClass*>(DecodeClass(env, javaThis));
-  return AddLocalReference<jobjectArray>(env, c->GetInterfaces()->Clone());
+  SynthesizedProxyClass* c = down_cast<SynthesizedProxyClass*>(DecodeClass(ts, javaThis));
+  return ts.AddLocalReference<jobjectArray>(c->GetInterfaces()->Clone());
 }
 
 static jboolean Class_isAssignableFrom(JNIEnv* env, jobject javaLhs, jclass javaRhs) {
   ScopedJniThreadState ts(env);
-  Class* lhs = DecodeClass(env, javaLhs);
-  Class* rhs = Decode<Class*>(env, javaRhs); // Can be null.
+  Class* lhs = DecodeClass(ts, javaLhs);
+  Class* rhs = ts.Decode<Class*>(javaRhs); // Can be null.
   if (rhs == NULL) {
     ts.Self()->ThrowNewException("Ljava/lang/NullPointerException;", "class == null");
     return JNI_FALSE;
@@ -397,7 +385,7 @@
 
 static jobject Class_newInstanceImpl(JNIEnv* env, jobject javaThis) {
   ScopedJniThreadState ts(env);
-  Class* c = DecodeClass(env, javaThis);
+  Class* c = DecodeClass(ts, javaThis);
   if (c->IsPrimitive() || c->IsInterface() || c->IsArrayClass() || c->IsAbstract()) {
     ts.Self()->ThrowNewExceptionF("Ljava/lang/InstantiationException;",
         "Class %s can not be instantiated", PrettyDescriptor(ClassHelper(c).GetDescriptor()).c_str());
@@ -451,8 +439,8 @@
   }
 
   // invoke constructor; unlike reflection calls, we don't wrap exceptions
-  jclass java_class = AddLocalReference<jclass>(env, c);
-  jmethodID mid = EncodeMethod(init);
+  jclass java_class = ts.AddLocalReference<jclass>(c);
+  jmethodID mid = ts.EncodeMethod(init);
   return env->NewObject(java_class, mid);
 }
 
diff --git a/src/native/java_lang_Object.cc b/src/native/java_lang_Object.cc
index 51e4581..d6b1bd6 100644
--- a/src/native/java_lang_Object.cc
+++ b/src/native/java_lang_Object.cc
@@ -16,27 +16,31 @@
 
 #include "jni_internal.h"
 #include "object.h"
+#include "scoped_jni_thread_state.h"
 
 namespace art {
 
 static jobject Object_internalClone(JNIEnv* env, jobject javaThis) {
-  ScopedThreadStateChange tsc(Thread::Current(), kRunnable);
-  Object* o = Decode<Object*>(env, javaThis);
-  return AddLocalReference<jobject>(env, o->Clone());
+  ScopedJniThreadState ts(env);
+  Object* o = ts.Decode<Object*>(javaThis);
+  return ts.AddLocalReference<jobject>(o->Clone());
 }
 
 static void Object_notify(JNIEnv* env, jobject javaThis) {
-  Object* o = Decode<Object*>(env, javaThis);
+  ScopedJniThreadState ts(env);
+  Object* o = ts.Decode<Object*>(javaThis);
   o->Notify();
 }
 
 static void Object_notifyAll(JNIEnv* env, jobject javaThis) {
-  Object* o = Decode<Object*>(env, javaThis);
+  ScopedJniThreadState ts(env);
+  Object* o = ts.Decode<Object*>(javaThis);
   o->NotifyAll();
 }
 
 static void Object_wait(JNIEnv* env, jobject javaThis, jlong ms, jint ns) {
-  Object* o = Decode<Object*>(env, javaThis);
+  ScopedJniThreadState ts(env);
+  Object* o = ts.Decode<Object*>(javaThis);
   o->Wait(ms, ns);
 }
 
diff --git a/src/native/java_lang_Runtime.cc b/src/native/java_lang_Runtime.cc
index 3019e95..1b657b1 100644
--- a/src/native/java_lang_Runtime.cc
+++ b/src/native/java_lang_Runtime.cc
@@ -17,16 +17,18 @@
 #include <limits.h>
 #include <unistd.h>
 
+#include "class_loader.h"
 #include "heap.h"
 #include "jni_internal.h"
 #include "object.h"
 #include "runtime.h"
+#include "scoped_jni_thread_state.h"
 #include "ScopedUtfChars.h"
 
 namespace art {
 
-static void Runtime_gc(JNIEnv*, jclass) {
-  ScopedThreadStateChange tsc(Thread::Current(), kRunnable);
+static void Runtime_gc(JNIEnv* env, jclass) {
+  ScopedJniThreadState ts(env);
   Runtime::Current()->GetHeap()->CollectGarbage(false);
 }
 
@@ -43,12 +45,13 @@
  * message on failure.
  */
 static jstring Runtime_nativeLoad(JNIEnv* env, jclass, jstring javaFilename, jobject javaLoader) {
+  ScopedJniThreadState ts(env);
   ScopedUtfChars filename(env, javaFilename);
   if (filename.c_str() == NULL) {
     return NULL;
   }
 
-  ClassLoader* classLoader = Decode<ClassLoader*>(env, javaLoader);
+  ClassLoader* classLoader = ts.Decode<ClassLoader*>(javaLoader);
   std::string detail;
   JavaVMExt* vm = Runtime::Current()->GetJavaVM();
   bool success = vm->LoadNativeLibrary(filename.c_str(), classLoader, detail);
diff --git a/src/native/java_lang_String.cc b/src/native/java_lang_String.cc
index f8fb4a7..96fcf96 100644
--- a/src/native/java_lang_String.cc
+++ b/src/native/java_lang_String.cc
@@ -16,6 +16,7 @@
 
 #include "jni_internal.h"
 #include "object.h"
+#include "scoped_jni_thread_state.h"
 
 #ifdef HAVE__MEMCMP16
 // "count" is in 16-bit units.
@@ -35,9 +36,9 @@
 namespace art {
 
 static jint String_compareTo(JNIEnv* env, jobject javaThis, jobject javaRhs) {
-  ScopedThreadStateChange tsc(Thread::Current(), kRunnable);
-  String* lhs = Decode<String*>(env, javaThis);
-  String* rhs = Decode<String*>(env, javaRhs);
+  ScopedJniThreadState ts(env);
+  String* lhs = ts.Decode<String*>(javaThis);
+  String* rhs = ts.Decode<String*>(javaRhs);
 
   if (rhs == NULL) {
     Thread::Current()->ThrowNewException("Ljava/lang/NullPointerException;", "rhs == null");
@@ -69,10 +70,11 @@
 }
 
 static jint String_fastIndexOf(JNIEnv* env, jobject java_this, jint ch, jint start) {
+  ScopedJniThreadState ts(env);
   // This method does not handle supplementary characters. They're dealt with in managed code.
   DCHECK_LE(ch, 0xffff);
 
-  String* s = Decode<String*>(env, java_this);
+  String* s = ts.Decode<String*>(java_this);
 
   jint count = s->GetLength();
   if (start < 0) {
@@ -94,9 +96,10 @@
 }
 
 static jstring String_intern(JNIEnv* env, jobject javaThis) {
-  String* s = Decode<String*>(env, javaThis);
+  ScopedJniThreadState ts(env);
+  String* s = ts.Decode<String*>(javaThis);
   String* result = s->Intern();
-  return AddLocalReference<jstring>(env, result);
+  return ts.AddLocalReference<jstring>(result);
 }
 
 static JNINativeMethod gMethods[] = {
diff --git a/src/native/java_lang_System.cc b/src/native/java_lang_System.cc
index b0d1eec..76ac670 100644
--- a/src/native/java_lang_System.cc
+++ b/src/native/java_lang_System.cc
@@ -16,6 +16,7 @@
 
 #include "jni_internal.h"
 #include "object.h"
+#include "scoped_jni_thread_state.h"
 
 /*
  * We make guarantees about the atomicity of accesses to primitive
@@ -107,22 +108,21 @@
 }
 
 static void System_arraycopy(JNIEnv* env, jclass, jobject javaSrc, jint srcPos, jobject javaDst, jint dstPos, jint length) {
-  ScopedThreadStateChange tsc(Thread::Current(), kRunnable);
-  Thread* self = Thread::Current();
+  ScopedJniThreadState ts(env);
 
   // Null pointer checks.
   if (javaSrc == NULL) {
-    self->ThrowNewException("Ljava/lang/NullPointerException;", "src == null");
+    ts.Self()->ThrowNewException("Ljava/lang/NullPointerException;", "src == null");
     return;
   }
   if (javaDst == NULL) {
-    self->ThrowNewException("Ljava/lang/NullPointerException;", "dst == null");
+    ts.Self()->ThrowNewException("Ljava/lang/NullPointerException;", "dst == null");
     return;
   }
 
   // Make sure source and destination are both arrays.
-  Object* srcObject = Decode<Object*>(env, javaSrc);
-  Object* dstObject = Decode<Object*>(env, javaDst);
+  Object* srcObject = ts.Decode<Object*>(javaSrc);
+  Object* dstObject = ts.Decode<Object*>(javaDst);
   if (!srcObject->IsArrayInstance()) {
     ThrowArrayStoreException_NotAnArray("source", srcObject);
     return;
@@ -138,7 +138,7 @@
 
   // Bounds checking.
   if (srcPos < 0 || dstPos < 0 || length < 0 || srcPos > srcArray->GetLength() - length || dstPos > dstArray->GetLength() - length) {
-    self->ThrowNewExceptionF("Ljava/lang/ArrayIndexOutOfBoundsException;",
+    ts.Self()->ThrowNewExceptionF("Ljava/lang/ArrayIndexOutOfBoundsException;",
         "src.length=%d srcPos=%d dst.length=%d dstPos=%d length=%d",
         srcArray->GetLength(), srcPos, dstArray->GetLength(), dstPos, length);
     return;
@@ -150,7 +150,7 @@
     if (srcComponentType->IsPrimitive() != dstComponentType->IsPrimitive() || srcComponentType != dstComponentType) {
       std::string srcType(PrettyTypeOf(srcArray));
       std::string dstType(PrettyTypeOf(dstArray));
-      self->ThrowNewExceptionF("Ljava/lang/ArrayStoreException;",
+      ts.Self()->ThrowNewExceptionF("Ljava/lang/ArrayStoreException;",
           "Incompatible types: src=%s, dst=%s", srcType.c_str(), dstType.c_str());
       return;
     }
@@ -233,7 +233,7 @@
   if (i != length) {
     std::string actualSrcType(PrettyTypeOf(o));
     std::string dstType(PrettyTypeOf(dstArray));
-    self->ThrowNewExceptionF("Ljava/lang/ArrayStoreException;",
+    ts.Self()->ThrowNewExceptionF("Ljava/lang/ArrayStoreException;",
         "source[%d] of type %s cannot be stored in destination array of type %s",
         srcPos + i, actualSrcType.c_str(), dstType.c_str());
     return;
@@ -241,7 +241,8 @@
 }
 
 static jint System_identityHashCode(JNIEnv* env, jclass, jobject javaObject) {
-  Object* o = Decode<Object*>(env, javaObject);
+  ScopedJniThreadState ts(env);
+  Object* o = ts.Decode<Object*>(javaObject);
   return static_cast<jint>(reinterpret_cast<uintptr_t>(o));
 }
 
diff --git a/src/native/java_lang_Thread.cc b/src/native/java_lang_Thread.cc
index ed95a6c..86b3a20 100644
--- a/src/native/java_lang_Thread.cc
+++ b/src/native/java_lang_Thread.cc
@@ -17,6 +17,7 @@
 #include "debugger.h"
 #include "jni_internal.h"
 #include "object.h"
+#include "scoped_jni_thread_state.h"
 #include "scoped_thread_list_lock.h"
 #include "ScopedUtfChars.h"
 #include "thread.h"
@@ -25,22 +26,24 @@
 namespace art {
 
 static jobject Thread_currentThread(JNIEnv* env, jclass) {
-  return AddLocalReference<jobject>(env, Thread::Current()->GetPeer());
+  ScopedJniThreadState ts(env);
+  return ts.AddLocalReference<jobject>(ts.Self()->GetPeer());
 }
 
-static jboolean Thread_interrupted(JNIEnv*, jclass) {
-  return Thread::Current()->Interrupted();
+static jboolean Thread_interrupted(JNIEnv* env, jclass) {
+  ScopedJniThreadState ts(env, kNative);  // Doesn't touch objects, so keep in native state.
+  return ts.Self()->Interrupted();
 }
 
 static jboolean Thread_isInterrupted(JNIEnv* env, jobject java_thread) {
+  ScopedJniThreadState ts(env);
   ScopedThreadListLock thread_list_lock;
-  Thread* thread = Thread::FromManagedThread(env, java_thread);
+  Thread* thread = Thread::FromManagedThread(ts, java_thread);
   return (thread != NULL) ? thread->IsInterrupted() : JNI_FALSE;
 }
 
 static void Thread_nativeCreate(JNIEnv* env, jclass, jobject java_thread, jlong stack_size) {
-  Object* managedThread = Decode<Object*>(env, java_thread);
-  Thread::CreateNativeThread(managedThread, stack_size);
+  Thread::CreateNativeThread(env, java_thread, stack_size);
 }
 
 static jint Thread_nativeGetStatus(JNIEnv* env, jobject java_thread, jboolean has_been_started) {
@@ -52,9 +55,10 @@
   const jint kJavaTimedWaiting = 4;
   const jint kJavaTerminated = 5;
 
+  ScopedJniThreadState ts(env);
   ThreadState internal_thread_state = (has_been_started ? kTerminated : kStarting);
   ScopedThreadListLock thread_list_lock;
-  Thread* thread = Thread::FromManagedThread(env, java_thread);
+  Thread* thread = Thread::FromManagedThread(ts, java_thread);
   if (thread != NULL) {
     internal_thread_state = thread->GetState();
   }
@@ -74,28 +78,30 @@
 }
 
 static jboolean Thread_nativeHoldsLock(JNIEnv* env, jobject java_thread, jobject java_object) {
-  Object* object = Decode<Object*>(env, java_object);
+  ScopedJniThreadState ts(env);
+  Object* object = ts.Decode<Object*>(java_object);
   if (object == NULL) {
-    ScopedThreadStateChange tsc(Thread::Current(), kRunnable);
     Thread::Current()->ThrowNewException("Ljava/lang/NullPointerException;", "object == null");
     return JNI_FALSE;
   }
   ScopedThreadListLock thread_list_lock;
-  Thread* thread = Thread::FromManagedThread(env, java_thread);
+  Thread* thread = Thread::FromManagedThread(ts, java_thread);
   return thread->HoldsLock(object);
 }
 
 static void Thread_nativeInterrupt(JNIEnv* env, jobject java_thread) {
+  ScopedJniThreadState ts(env);
   ScopedThreadListLock thread_list_lock;
-  Thread* thread = Thread::FromManagedThread(env, java_thread);
+  Thread* thread = Thread::FromManagedThread(ts, java_thread);
   if (thread != NULL) {
     thread->Interrupt();
   }
 }
 
 static void Thread_nativeSetName(JNIEnv* env, jobject java_thread, jstring java_name) {
+  ScopedJniThreadState ts(env);
   ScopedThreadListLock thread_list_lock;
-  Thread* thread = Thread::FromManagedThread(env, java_thread);
+  Thread* thread = Thread::FromManagedThread(ts, java_thread);
   if (thread == NULL) {
     return;
   }
@@ -112,8 +118,9 @@
  * threads at Thread.NORM_PRIORITY (5).
  */
 static void Thread_nativeSetPriority(JNIEnv* env, jobject java_thread, jint new_priority) {
+  ScopedJniThreadState ts(env);
   ScopedThreadListLock thread_list_lock;
-  Thread* thread = Thread::FromManagedThread(env, java_thread);
+  Thread* thread = Thread::FromManagedThread(ts, java_thread);
   if (thread != NULL) {
     thread->SetNativePriority(new_priority);
   }
diff --git a/src/native/java_lang_Throwable.cc b/src/native/java_lang_Throwable.cc
index 625a34b..1c59a34 100644
--- a/src/native/java_lang_Throwable.cc
+++ b/src/native/java_lang_Throwable.cc
@@ -15,13 +15,14 @@
  */
 
 #include "jni_internal.h"
+#include "scoped_jni_thread_state.h"
 #include "thread.h"
 
 namespace art {
 
 static jobject Throwable_nativeFillInStackTrace(JNIEnv* env, jclass) {
-  JNIEnvExt* env_ext = reinterpret_cast<JNIEnvExt*>(env);
-  return env_ext->self->CreateInternalStackTrace(env);
+  ScopedJniThreadState ts(env);
+  return ts.Self()->CreateInternalStackTrace(ts);
 }
 
 static jobjectArray Throwable_nativeGetStackTrace(JNIEnv* env, jclass, jobject javaStackState) {
diff --git a/src/native/java_lang_VMClassLoader.cc b/src/native/java_lang_VMClassLoader.cc
index a976933..0689f74 100644
--- a/src/native/java_lang_VMClassLoader.cc
+++ b/src/native/java_lang_VMClassLoader.cc
@@ -15,14 +15,17 @@
  */
 
 #include "class_linker.h"
+#include "class_loader.h"
 #include "jni_internal.h"
+#include "scoped_jni_thread_state.h"
 #include "ScopedUtfChars.h"
 #include "zip_archive.h"
 
 namespace art {
 
 static jclass VMClassLoader_findLoadedClass(JNIEnv* env, jclass, jobject javaLoader, jstring javaName) {
-  ClassLoader* loader = Decode<ClassLoader*>(env, javaLoader);
+  ScopedJniThreadState ts(env);
+  ClassLoader* loader = ts.Decode<ClassLoader*>(javaLoader);
   ScopedUtfChars name(env, javaName);
   if (name.c_str() == NULL) {
     return NULL;
@@ -31,7 +34,7 @@
   std::string descriptor(DotToDescriptor(name.c_str()));
   Class* c = Runtime::Current()->GetClassLinker()->LookupClass(descriptor.c_str(), loader);
   if (c != NULL && c->IsResolved()) {
-    return AddLocalReference<jclass>(env, c);
+    return ts.AddLocalReference<jclass>(c);
   } else {
     // Class wasn't resolved so it may be erroneous or not yet ready, force the caller to go into
     // the regular loadClass code.
diff --git a/src/native/java_lang_reflect_Array.cc b/src/native/java_lang_reflect_Array.cc
index ea635d3..729312e 100644
--- a/src/native/java_lang_reflect_Array.cc
+++ b/src/native/java_lang_reflect_Array.cc
@@ -18,6 +18,7 @@
 #include "jni_internal.h"
 #include "object.h"
 #include "object_utils.h"
+#include "scoped_jni_thread_state.h"
 
 namespace art {
 
@@ -68,12 +69,12 @@
 // subtract pieces off.  Besides, we want to start with the outermost
 // piece and work our way in.
 static jobject Array_createMultiArray(JNIEnv* env, jclass, jclass javaElementClass, jobject javaDimArray) {
-  ScopedThreadStateChange tsc(Thread::Current(), kRunnable);
+  ScopedJniThreadState ts(env);
   DCHECK(javaElementClass != NULL);
-  Class* element_class = Decode<Class*>(env, javaElementClass);
+  Class* element_class = ts.Decode<Class*>(javaElementClass);
   DCHECK(element_class->IsClass());
   DCHECK(javaDimArray != NULL);
-  Object* dimensions_obj = Decode<Object*>(env, javaDimArray);
+  Object* dimensions_obj = ts.Decode<Object*>(javaDimArray);
   DCHECK(dimensions_obj->IsArrayInstance());
   DCHECK_STREQ(ClassHelper(dimensions_obj->GetClass()).GetDescriptor(), "[I");
   IntArray* dimensions_array = down_cast<IntArray*>(dimensions_obj);
@@ -89,7 +90,7 @@
   for (int i = 0; i < num_dimensions; i++) {
     int dimension = dimensions_array->Get(i);
     if (dimension < 0) {
-      Thread::Current()->ThrowNewExceptionF("Ljava/lang/NegativeArraySizeException;",
+      ts.Self()->ThrowNewExceptionF("Ljava/lang/NegativeArraySizeException;",
           "Dimension %d: %d", i, dimension);
       return NULL;
     }
@@ -112,15 +113,15 @@
     CHECK(Thread::Current()->IsExceptionPending());
     return NULL;
   }
-  return AddLocalReference<jobject>(env, new_array);
+  return ts.AddLocalReference<jobject>(new_array);
 }
 
 static jobject Array_createObjectArray(JNIEnv* env, jclass, jclass javaElementClass, jint length) {
-  ScopedThreadStateChange tsc(Thread::Current(), kRunnable);
+  ScopedJniThreadState ts(env);
   DCHECK(javaElementClass != NULL);
-  Class* element_class = Decode<Class*>(env, javaElementClass);
+  Class* element_class = ts.Decode<Class*>(javaElementClass);
   if (length < 0) {
-    Thread::Current()->ThrowNewExceptionF("Ljava/lang/NegativeArraySizeException;", "%d", length);
+    ts.Self()->ThrowNewExceptionF("Ljava/lang/NegativeArraySizeException;", "%d", length);
     return NULL;
   }
   std::string descriptor;
@@ -130,16 +131,16 @@
   ClassLinker* class_linker = Runtime::Current()->GetClassLinker();
   Class* array_class = class_linker->FindClass(descriptor.c_str(), element_class->GetClassLoader());
   if (array_class == NULL) {
-    CHECK(Thread::Current()->IsExceptionPending());
+    CHECK(ts.Self()->IsExceptionPending());
     return NULL;
   }
   DCHECK(array_class->IsArrayClass());
   Array* new_array = Array::Alloc(array_class, length);
   if (new_array == NULL) {
-    CHECK(Thread::Current()->IsExceptionPending());
+    CHECK(ts.Self()->IsExceptionPending());
     return NULL;
   }
-  return AddLocalReference<jobject>(env, new_array);
+  return ts.AddLocalReference<jobject>(new_array);
 }
 
 static JNINativeMethod gMethods[] = {
diff --git a/src/native/java_lang_reflect_Constructor.cc b/src/native/java_lang_reflect_Constructor.cc
index 1094d06..564d6db 100644
--- a/src/native/java_lang_reflect_Constructor.cc
+++ b/src/native/java_lang_reflect_Constructor.cc
@@ -19,6 +19,7 @@
 #include "object.h"
 #include "object_utils.h"
 #include "reflection.h"
+#include "scoped_jni_thread_state.h"
 
 namespace art {
 
@@ -30,17 +31,17 @@
  * with an interface, array, or primitive class.
  */
 static jobject Constructor_newInstance(JNIEnv* env, jobject javaMethod, jobjectArray javaArgs) {
-  ScopedThreadStateChange tsc(Thread::Current(), kRunnable);
-  Method* m = Decode<Object*>(env, javaMethod)->AsMethod();
+  ScopedJniThreadState ts(env);
+  Method* m = ts.Decode<Object*>(javaMethod)->AsMethod();
   Class* c = m->GetDeclaringClass();
   if (c->IsAbstract()) {
-    Thread::Current()->ThrowNewExceptionF("Ljava/lang/InstantiationException;",
+    ts.Self()->ThrowNewExceptionF("Ljava/lang/InstantiationException;",
         "Can't instantiate abstract class %s", PrettyDescriptor(c).c_str());
     return NULL;
   }
 
   if (!Runtime::Current()->GetClassLinker()->EnsureInitialized(c, true, true)) {
-    DCHECK(Thread::Current()->IsExceptionPending());
+    DCHECK(ts.Self()->IsExceptionPending());
     return NULL;
   }
 
@@ -49,8 +50,8 @@
     return NULL;
   }
 
-  jobject javaReceiver = AddLocalReference<jobject>(env, receiver);
-  InvokeMethod(env, javaMethod, javaReceiver, javaArgs);
+  jobject javaReceiver = ts.AddLocalReference<jobject>(receiver);
+  InvokeMethod(ts, javaMethod, javaReceiver, javaArgs);
 
   // Constructors are ()V methods, so we shouldn't touch the result of InvokeMethod.
   return javaReceiver;
diff --git a/src/native/java_lang_reflect_Field.cc b/src/native/java_lang_reflect_Field.cc
index bd33c0e..b2ede63 100644
--- a/src/native/java_lang_reflect_Field.cc
+++ b/src/native/java_lang_reflect_Field.cc
@@ -19,12 +19,13 @@
 #include "object.h"
 #include "object_utils.h"
 #include "reflection.h"
+#include "scoped_jni_thread_state.h"
 
 namespace art {
 
-static bool GetFieldValue(Object* o, Field* f, JValue& value, bool allow_references) {
+static bool GetFieldValue(const ScopedJniThreadState& ts, Object* o, Field* f, JValue& value,
+                          bool allow_references) {
   DCHECK_EQ(value.GetJ(), 0LL);
-  ScopedThreadStateChange tsc(Thread::Current(), kRunnable);
   if (!Runtime::Current()->GetClassLinker()->EnsureInitialized(f->GetDeclaringClass(), true, true)) {
     return false;
   }
@@ -64,18 +65,18 @@
     // Never okay.
     break;
   }
-  Thread::Current()->ThrowNewExceptionF("Ljava/lang/IllegalArgumentException;",
+  ts.Self()->ThrowNewExceptionF("Ljava/lang/IllegalArgumentException;",
       "Not a primitive field: %s", PrettyField(f).c_str());
   return false;
 }
 
-static bool CheckReceiver(JNIEnv* env, jobject javaObj, Field* f, Object*& o) {
+static bool CheckReceiver(const ScopedJniThreadState& ts, jobject javaObj, Field* f, Object*& o) {
   if (f->IsStatic()) {
     o = NULL;
     return true;
   }
 
-  o = Decode<Object*>(env, javaObj);
+  o = ts.Decode<Object*>(javaObj);
   Class* declaringClass = f->GetDeclaringClass();
   if (!VerifyObjectInClass(o, declaringClass)) {
     return false;
@@ -84,32 +85,34 @@
 }
 
 static jobject Field_get(JNIEnv* env, jobject javaField, jobject javaObj) {
-  Field* f = DecodeField(env->FromReflectedField(javaField));
+  ScopedJniThreadState ts(env);
+  Field* f = ts.DecodeField(env->FromReflectedField(javaField));
   Object* o = NULL;
-  if (!CheckReceiver(env, javaObj, f, o)) {
+  if (!CheckReceiver(ts, javaObj, f, o)) {
     return NULL;
   }
 
   // Get the field's value, boxing if necessary.
   JValue value;
-  if (!GetFieldValue(o, f, value, true)) {
+  if (!GetFieldValue(ts, o, f, value, true)) {
     return NULL;
   }
   BoxPrimitive(FieldHelper(f).GetTypeAsPrimitiveType(), value);
 
-  return AddLocalReference<jobject>(env, value.GetL());
+  return ts.AddLocalReference<jobject>(value.GetL());
 }
 
 static JValue GetPrimitiveField(JNIEnv* env, jobject javaField, jobject javaObj, char dst_descriptor) {
-  Field* f = DecodeField(env->FromReflectedField(javaField));
+  ScopedJniThreadState ts(env);
+  Field* f = ts.DecodeField(env->FromReflectedField(javaField));
   Object* o = NULL;
-  if (!CheckReceiver(env, javaObj, f, o)) {
+  if (!CheckReceiver(ts, javaObj, f, o)) {
     return JValue();
   }
 
   // Read the value.
   JValue field_value;
-  if (!GetFieldValue(o, f, field_value, false)) {
+  if (!GetFieldValue(ts, o, f, field_value, false)) {
     return JValue();
   }
 
@@ -205,11 +208,11 @@
 }
 
 static void Field_set(JNIEnv* env, jobject javaField, jobject javaObj, jobject javaValue) {
-  ScopedThreadStateChange tsc(Thread::Current(), kRunnable);
-  Field* f = DecodeField(env->FromReflectedField(javaField));
+  ScopedJniThreadState ts(env);
+  Field* f = ts.DecodeField(env->FromReflectedField(javaField));
 
   // Unbox the value, if necessary.
-  Object* boxed_value = Decode<Object*>(env, javaValue);
+  Object* boxed_value = ts.Decode<Object*>(javaValue);
   JValue unboxed_value;
   if (!UnboxPrimitiveForField(boxed_value, FieldHelper(f).GetType(), unboxed_value, f)) {
     return;
@@ -217,7 +220,7 @@
 
   // Check that the receiver is non-null and an instance of the field's declaring class.
   Object* o = NULL;
-  if (!CheckReceiver(env, javaObj, f, o)) {
+  if (!CheckReceiver(ts, javaObj, f, o)) {
     return;
   }
 
@@ -226,15 +229,15 @@
 
 static void SetPrimitiveField(JNIEnv* env, jobject javaField, jobject javaObj, char src_descriptor,
                               const JValue& new_value) {
-  ScopedThreadStateChange tsc(Thread::Current(), kRunnable);
-  Field* f = DecodeField(env->FromReflectedField(javaField));
+  ScopedJniThreadState ts(env);
+  Field* f = ts.DecodeField(env->FromReflectedField(javaField));
   Object* o = NULL;
-  if (!CheckReceiver(env, javaObj, f, o)) {
+  if (!CheckReceiver(ts, javaObj, f, o)) {
     return;
   }
   FieldHelper fh(f);
   if (!fh.IsPrimitiveType()) {
-    Thread::Current()->ThrowNewExceptionF("Ljava/lang/IllegalArgumentException;",
+    ts.Self()->ThrowNewExceptionF("Ljava/lang/IllegalArgumentException;",
         "Not a primitive field: %s", PrettyField(f).c_str());
     return;
   }
diff --git a/src/native/java_lang_reflect_Method.cc b/src/native/java_lang_reflect_Method.cc
index bf5c850..2695822 100644
--- a/src/native/java_lang_reflect_Method.cc
+++ b/src/native/java_lang_reflect_Method.cc
@@ -19,15 +19,18 @@
 #include "object.h"
 #include "object_utils.h"
 #include "reflection.h"
+#include "scoped_jni_thread_state.h"
 
 namespace art {
 
 static jobject Method_invoke(JNIEnv* env, jobject javaMethod, jobject javaReceiver, jobject javaArgs) {
-  return InvokeMethod(env, javaMethod, javaReceiver, javaArgs);
+  ScopedJniThreadState ts(env);
+  return InvokeMethod(ts, javaMethod, javaReceiver, javaArgs);
 }
 
 static jobject Method_getExceptionTypesNative(JNIEnv* env, jobject javaMethod) {
-  Method* proxy_method = Decode<Object*>(env, javaMethod)->AsMethod();
+  ScopedJniThreadState ts(env);
+  Method* proxy_method = ts.Decode<Object*>(javaMethod)->AsMethod();
   CHECK(proxy_method->GetDeclaringClass()->IsProxyClass());
   SynthesizedProxyClass* proxy_class =
       down_cast<SynthesizedProxyClass*>(proxy_method->GetDeclaringClass());
@@ -41,14 +44,13 @@
   }
   CHECK_NE(throws_index, -1);
   ObjectArray<Class>* declared_exceptions = proxy_class->GetThrows()->Get(throws_index);
-  // Change thread state for allocation
-  ScopedThreadStateChange tsc(Thread::Current(), kRunnable);
-  return AddLocalReference<jobject>(env, declared_exceptions->Clone());
+  return ts.AddLocalReference<jobject>(declared_exceptions->Clone());
 }
 
 static jobject Method_findOverriddenMethodNative(JNIEnv* env, jobject javaMethod) {
-  Method* method = Decode<Object*>(env, javaMethod)->AsMethod();
-  return AddLocalReference<jobject>(env, method->FindOverriddenMethod());
+  ScopedJniThreadState ts(env);
+  Method* method = ts.Decode<Object*>(javaMethod)->AsMethod();
+  return ts.AddLocalReference<jobject>(method->FindOverriddenMethod());
 }
 
 static JNINativeMethod gMethods[] = {
diff --git a/src/native/java_lang_reflect_Proxy.cc b/src/native/java_lang_reflect_Proxy.cc
index eca6c32..a1337a6 100644
--- a/src/native/java_lang_reflect_Proxy.cc
+++ b/src/native/java_lang_reflect_Proxy.cc
@@ -15,22 +15,23 @@
  */
 
 #include "class_linker.h"
+#include "class_loader.h"
 #include "jni_internal.h"
 #include "object.h"
+#include "scoped_jni_thread_state.h"
 
 namespace art {
 
 static jclass Proxy_generateProxy(JNIEnv* env, jclass, jstring javaName, jobjectArray javaInterfaces, jobject javaLoader, jobjectArray javaMethods, jobjectArray javaThrows) {
-  // Allocates Class so transition thread state to runnable
-  ScopedThreadStateChange tsc(Thread::Current(), kRunnable);
-  String* name = Decode<String*>(env, javaName);
-  ObjectArray<Class>* interfaces = Decode<ObjectArray<Class>*>(env, javaInterfaces);
-  ClassLoader* loader = Decode<ClassLoader*>(env, javaLoader);
-  ObjectArray<Method>* methods = Decode<ObjectArray<Method>*>(env, javaMethods);
-  ObjectArray<ObjectArray<Class> >* throws = Decode<ObjectArray<ObjectArray<Class> >*>(env, javaThrows);
+  ScopedJniThreadState ts(env);
+  String* name = ts.Decode<String*>(javaName);
+  ObjectArray<Class>* interfaces = ts.Decode<ObjectArray<Class>*>(javaInterfaces);
+  ClassLoader* loader = ts.Decode<ClassLoader*>(javaLoader);
+  ObjectArray<Method>* methods = ts.Decode<ObjectArray<Method>*>(javaMethods);
+  ObjectArray<ObjectArray<Class> >* throws = ts.Decode<ObjectArray<ObjectArray<Class> >*>(javaThrows);
   ClassLinker* class_linker = Runtime::Current()->GetClassLinker();
   Class* result = class_linker->CreateProxyClass(name, interfaces, loader, methods, throws);
-  return AddLocalReference<jclass>(env, result);
+  return ts.AddLocalReference<jclass>(result);
 }
 
 static JNINativeMethod gMethods[] = {
diff --git a/src/native/org_apache_harmony_dalvik_ddmc_DdmVmInternal.cc b/src/native/org_apache_harmony_dalvik_ddmc_DdmVmInternal.cc
index 3766546..87d2b22 100644
--- a/src/native/org_apache_harmony_dalvik_ddmc_DdmVmInternal.cc
+++ b/src/native/org_apache_harmony_dalvik_ddmc_DdmVmInternal.cc
@@ -18,6 +18,7 @@
 #include "jni_internal.h"
 #include "logging.h"
 #include "scoped_heap_lock.h"
+#include "scoped_jni_thread_state.h"
 #include "scoped_thread_list_lock.h"
 #include "ScopedPrimitiveArray.h"
 #include "stack.h"
@@ -68,7 +69,8 @@
   if (thread == NULL) {
     return NULL;
   }
-  jobject stack = GetThreadStack(env, thread);
+  ScopedJniThreadState ts(env);
+  jobject stack = GetThreadStack(ts, thread);
   return (stack != NULL) ? Thread::InternalStackTraceToStackTraceElementArray(env, stack) : NULL;
 }
 
diff --git a/src/native/sun_misc_Unsafe.cc b/src/native/sun_misc_Unsafe.cc
index 360f241..dfddd86 100644
--- a/src/native/sun_misc_Unsafe.cc
+++ b/src/native/sun_misc_Unsafe.cc
@@ -16,30 +16,34 @@
 
 #include "jni_internal.h"
 #include "object.h"
+#include "scoped_jni_thread_state.h"
 
 namespace art {
 
 static jlong Unsafe_objectFieldOffset0(JNIEnv* env, jclass, jobject javaField) {
   // TODO: move to Java code
   jfieldID fid = env->FromReflectedField(javaField);
-  Field* field = DecodeField(fid);
+  ScopedJniThreadState ts(env);
+  Field* field = ts.DecodeField(fid);
   return field->GetOffset().Int32Value();
 }
 
 static jint Unsafe_arrayBaseOffset0(JNIEnv* env, jclass, jclass javaArrayClass) {
   // TODO: move to Java code
-  ScopedThreadStateChange tsc(Thread::Current(), kRunnable);
-  Class* array_class = Decode<Class*>(env, javaArrayClass);
+  ScopedJniThreadState ts(env);
+  Class* array_class = ts.Decode<Class*>(javaArrayClass);
   return Array::DataOffset(array_class->GetComponentSize()).Int32Value();
 }
 
 static jint Unsafe_arrayIndexScale0(JNIEnv* env, jclass, jclass javaClass) {
-  Class* c = Decode<Class*>(env, javaClass);
+  ScopedJniThreadState ts(env);
+  Class* c = ts.Decode<Class*>(javaClass);
   return c->GetComponentSize();
 }
 
 static jboolean Unsafe_compareAndSwapInt(JNIEnv* env, jobject, jobject javaObj, jlong offset, jint expectedValue, jint newValue) {
-  Object* obj = Decode<Object*>(env, javaObj);
+  ScopedJniThreadState ts(env);
+  Object* obj = ts.Decode<Object*>(javaObj);
   byte* raw_addr = reinterpret_cast<byte*>(obj) + offset;
   volatile int32_t* address = reinterpret_cast<volatile int32_t*>(raw_addr);
   // Note: android_atomic_release_cas() returns 0 on success, not failure.
@@ -48,7 +52,8 @@
 }
 
 static jboolean Unsafe_compareAndSwapLong(JNIEnv* env, jobject, jobject javaObj, jlong offset, jlong expectedValue, jlong newValue) {
-  Object* obj = Decode<Object*>(env, javaObj);
+  ScopedJniThreadState ts(env);
+  Object* obj = ts.Decode<Object*>(javaObj);
   byte* raw_addr = reinterpret_cast<byte*>(obj) + offset;
   volatile int64_t* address = reinterpret_cast<volatile int64_t*>(raw_addr);
   // Note: android_atomic_cmpxchg() returns 0 on success, not failure.
@@ -57,9 +62,10 @@
 }
 
 static jboolean Unsafe_compareAndSwapObject(JNIEnv* env, jobject, jobject javaObj, jlong offset, jobject javaExpectedValue, jobject javaNewValue) {
-  Object* obj = Decode<Object*>(env, javaObj);
-  Object* expectedValue = Decode<Object*>(env, javaExpectedValue);
-  Object* newValue = Decode<Object*>(env, javaNewValue);
+  ScopedJniThreadState ts(env);
+  Object* obj = ts.Decode<Object*>(javaObj);
+  Object* expectedValue = ts.Decode<Object*>(javaExpectedValue);
+  Object* newValue = ts.Decode<Object*>(javaNewValue);
   byte* raw_addr = reinterpret_cast<byte*>(obj) + offset;
   int32_t* address = reinterpret_cast<int32_t*>(raw_addr);
   // Note: android_atomic_cmpxchg() returns 0 on success, not failure.
@@ -72,90 +78,105 @@
 }
 
 static jint Unsafe_getInt(JNIEnv* env, jobject, jobject javaObj, jlong offset) {
-  Object* obj = Decode<Object*>(env, javaObj);
+  ScopedJniThreadState ts(env);
+  Object* obj = ts.Decode<Object*>(javaObj);
   return obj->GetField32(MemberOffset(offset), false);
 }
 
 static jint Unsafe_getIntVolatile(JNIEnv* env, jobject, jobject javaObj, jlong offset) {
-  Object* obj = Decode<Object*>(env, javaObj);
+  ScopedJniThreadState ts(env);
+  Object* obj = ts.Decode<Object*>(javaObj);
   byte* raw_addr = reinterpret_cast<byte*>(obj) + offset;
   volatile int32_t* address = reinterpret_cast<volatile int32_t*>(raw_addr);
   return android_atomic_acquire_load(address);
 }
 
 static void Unsafe_putInt(JNIEnv* env, jobject, jobject javaObj, jlong offset, jint newValue) {
-  Object* obj = Decode<Object*>(env, javaObj);
+  ScopedJniThreadState ts(env);
+  Object* obj = ts.Decode<Object*>(javaObj);
   obj->SetField32(MemberOffset(offset), newValue, false);
 }
 
 static void Unsafe_putIntVolatile(JNIEnv* env, jobject, jobject javaObj, jlong offset, jint newValue) {
-  Object* obj = Decode<Object*>(env, javaObj);
+  ScopedJniThreadState ts(env);
+  Object* obj = ts.Decode<Object*>(javaObj);
   byte* raw_addr = reinterpret_cast<byte*>(obj) + offset;
   volatile int32_t* address = reinterpret_cast<volatile int32_t*>(raw_addr);
   android_atomic_release_store(newValue, address);
 }
 
 static void Unsafe_putOrderedInt(JNIEnv* env, jobject, jobject javaObj, jlong offset, jint newValue) {
-  Object* obj = Decode<Object*>(env, javaObj);
+  ScopedJniThreadState ts(env);
+  Object* obj = ts.Decode<Object*>(javaObj);
   ANDROID_MEMBAR_STORE();
   obj->SetField32(MemberOffset(offset), newValue, false);
 }
 
 static jlong Unsafe_getLong(JNIEnv* env, jobject, jobject javaObj, jlong offset) {
-  Object* obj = Decode<Object*>(env, javaObj);
+  ScopedJniThreadState ts(env);
+  Object* obj = ts.Decode<Object*>(javaObj);
   byte* raw_addr = reinterpret_cast<byte*>(obj) + offset;
   int64_t* address = reinterpret_cast<int64_t*>(raw_addr);
   return *address;
 }
 
 static jlong Unsafe_getLongVolatile(JNIEnv* env, jobject, jobject javaObj, jlong offset) {
-  Object* obj = Decode<Object*>(env, javaObj);
+  ScopedJniThreadState ts(env);
+  Object* obj = ts.Decode<Object*>(javaObj);
   return obj->GetField64(MemberOffset(offset), true);
 }
 
 static void Unsafe_putLong(JNIEnv* env, jobject, jobject javaObj, jlong offset, jlong newValue) {
-  Object* obj = Decode<Object*>(env, javaObj);
+  ScopedJniThreadState ts(env);
+  Object* obj = ts.Decode<Object*>(javaObj);
   obj->SetField64(MemberOffset(offset), newValue, false);
 }
 
 static void Unsafe_putLongVolatile(JNIEnv* env, jobject, jobject javaObj, jlong offset, jlong newValue) {
-  Object* obj = Decode<Object*>(env, javaObj);
+  ScopedJniThreadState ts(env);
+  Object* obj = ts.Decode<Object*>(javaObj);
   obj->SetField64(MemberOffset(offset), newValue, true);
 }
 
 static void Unsafe_putOrderedLong(JNIEnv* env, jobject, jobject javaObj, jlong offset, jlong newValue) {
-  Object* obj = Decode<Object*>(env, javaObj);
+  ScopedJniThreadState ts(env);
+  Object* obj = ts.Decode<Object*>(javaObj);
   ANDROID_MEMBAR_STORE();
   obj->SetField64(MemberOffset(offset), newValue, false);
 }
 
 static jobject Unsafe_getObjectVolatile(JNIEnv* env, jobject, jobject javaObj, jlong offset) {
-  Object* obj = Decode<Object*>(env, javaObj);
+  ScopedJniThreadState ts(env);
+  Object* obj = ts.Decode<Object*>(javaObj);
   Object* value = obj->GetFieldObject<Object*>(MemberOffset(offset), true);
-  return AddLocalReference<jobject>(env, value);
+  return ts.AddLocalReference<jobject>(value);
 }
 
 static jobject Unsafe_getObject(JNIEnv* env, jobject, jobject javaObj, jlong offset) {
-  Object* obj = Decode<Object*>(env, javaObj);
+  ScopedJniThreadState ts(env);
+  Object* obj = ts.Decode<Object*>(javaObj);
   Object* value = obj->GetFieldObject<Object*>(MemberOffset(offset), false);
-  return AddLocalReference<jobject>(env, value);
+  return ts.AddLocalReference<jobject>(value);
 }
 
 static void Unsafe_putObject(JNIEnv* env, jobject, jobject javaObj, jlong offset, jobject javaNewValue) {
-  Object* obj = Decode<Object*>(env, javaObj);
-  Object* newValue = Decode<Object*>(env, javaNewValue);
+  ScopedJniThreadState ts(env);
+  Object* obj = ts.Decode<Object*>(javaObj);
+  Object* newValue = ts.Decode<Object*>(javaNewValue);
   obj->SetFieldObject(MemberOffset(offset), newValue, false);
 }
 
 static void Unsafe_putObjectVolatile(JNIEnv* env, jobject, jobject javaObj, jlong offset, jobject javaNewValue) {
-  Object* obj = Decode<Object*>(env, javaObj);
-  Object* newValue = Decode<Object*>(env, javaNewValue);
+  ScopedJniThreadState ts(env);
+  Object* obj = ts.Decode<Object*>(javaObj);
+  Object* newValue = ts.Decode<Object*>(javaNewValue);
   obj->SetFieldObject(MemberOffset(offset), newValue, true);
 }
 
 static void Unsafe_putOrderedObject(JNIEnv* env, jobject, jobject javaObj, jlong offset, jobject javaNewValue) {
-  Object* obj = Decode<Object*>(env, javaObj);
-  Object* newValue = Decode<Object*>(env, javaNewValue);
+  ScopedJniThreadState ts(env);
+  Object* obj = ts.Decode<Object*>(javaObj);
+  Object* newValue = ts.Decode<Object*>(javaNewValue);
   ANDROID_MEMBAR_STORE();
   obj->SetFieldObject(MemberOffset(offset), newValue, false);
 }
diff --git a/src/oat/runtime/support_proxy.cc b/src/oat/runtime/support_proxy.cc
index 37cacb4..83d2265 100644
--- a/src/oat/runtime/support_proxy.cc
+++ b/src/oat/runtime/support_proxy.cc
@@ -18,6 +18,7 @@
 #include "object_utils.h"
 #include "reflection.h"
 #include "runtime_support.h"
+#include "scoped_jni_thread_state.h"
 #include "thread.h"
 #include "well_known_classes.h"
 
@@ -50,10 +51,11 @@
   DCHECK_EQ(proxy_method->GetFrameSizeInBytes(), FRAME_SIZE_IN_BYTES);
   // Start new JNI local reference state
   JNIEnvExt* env = self->GetJniEnv();
+  ScopedJniThreadState ts(env);
   ScopedJniEnvLocalRefState env_state(env);
   // Create local ref. copies of proxy method and the receiver
-  jobject rcvr_jobj = AddLocalReference<jobject>(env, receiver);
-  jobject proxy_method_jobj = AddLocalReference<jobject>(env, proxy_method);
+  jobject rcvr_jobj = ts.AddLocalReference<jobject>(receiver);
+  jobject proxy_method_jobj = ts.AddLocalReference<jobject>(proxy_method);
 
   // Placing into local references incoming arguments from the caller's register arguments,
   // replacing original Object* with jobject
@@ -72,7 +74,7 @@
   while (cur_arg < args_in_regs && param_index < num_params) {
     if (proxy_mh.IsParamAReference(param_index)) {
       Object* obj = *reinterpret_cast<Object**>(stack_args + (cur_arg * kPointerSize));
-      jobject jobj = AddLocalReference<jobject>(env, obj);
+      jobject jobj = ts.AddLocalReference<jobject>(obj);
       *reinterpret_cast<jobject*>(stack_args + (cur_arg * kPointerSize)) = jobj;
     }
     cur_arg = cur_arg + (proxy_mh.IsParamALongOrDouble(param_index) ? 2 : 1);
@@ -83,7 +85,7 @@
   while (param_index < num_params) {
     if (proxy_mh.IsParamAReference(param_index)) {
       Object* obj = *reinterpret_cast<Object**>(stack_args + (cur_arg * kPointerSize));
-      jobject jobj = AddLocalReference<jobject>(env, obj);
+      jobject jobj = ts.AddLocalReference<jobject>(obj);
       *reinterpret_cast<jobject*>(stack_args + (cur_arg * kPointerSize)) = jobj;
     }
     cur_arg = cur_arg + (proxy_mh.IsParamALongOrDouble(param_index) ? 2 : 1);
@@ -102,13 +104,13 @@
       CHECK(self->IsExceptionPending());
       return;
     }
-    args_jobj[2].l = AddLocalReference<jobjectArray>(env, args);
+    args_jobj[2].l = ts.AddLocalReference<jobjectArray>(args);
   }
   // Convert proxy method into expected interface method
   Method* interface_method = proxy_method->FindOverriddenMethod();
   DCHECK(interface_method != NULL);
   DCHECK(!interface_method->IsProxyMethod()) << PrettyMethod(interface_method);
-  args_jobj[1].l = AddLocalReference<jobject>(env, interface_method);
+  args_jobj[1].l = ts.AddLocalReference<jobject>(interface_method);
   // Box arguments
   cur_arg = 0;  // reset stack location to read to start
   // reset index, will index into param type array which doesn't include the receiver
diff --git a/src/oat/runtime/support_stubs.cc b/src/oat/runtime/support_stubs.cc
index 522ccf2..3f6bc8f 100644
--- a/src/oat/runtime/support_stubs.cc
+++ b/src/oat/runtime/support_stubs.cc
@@ -23,6 +23,7 @@
 #if defined(ART_USE_LLVM_COMPILER)
 #include "nth_caller_visitor.h"
 #endif
+#include "scoped_jni_thread_state.h"
 
 // Architecture specific assembler helper to deliver exception.
 extern "C" void art_deliver_exception_from_code(void*);
@@ -81,6 +82,7 @@
   FinishCalleeSaveFrameSetup(thread, sp, Runtime::kRefsAndArgs);
   // Start new JNI local reference state
   JNIEnvExt* env = thread->GetJniEnv();
+  ScopedJniThreadState ts(env);
   ScopedJniEnvLocalRefState env_state(env);
 
   // Compute details about the called method (avoid GCs)
@@ -145,7 +147,7 @@
       // If we thought we had fewer than 3 arguments in registers, account for the receiver
       args_in_regs++;
     }
-    AddLocalReference<jobject>(env, obj);
+    ts.AddLocalReference<jobject>(obj);
   }
   size_t shorty_index = 1;  // skip return value
   // Iterate while arguments and arguments in registers (less 1 from cur_arg which is offset to skip
@@ -155,7 +157,7 @@
     shorty_index++;
     if (c == 'L') {
       Object* obj = reinterpret_cast<Object*>(regs[cur_arg]);
-      AddLocalReference<jobject>(env, obj);
+      ts.AddLocalReference<jobject>(obj);
     }
     cur_arg = cur_arg + (c == 'J' || c == 'D' ? 2 : 1);
   }
@@ -166,7 +168,7 @@
     shorty_index++;
     if (c == 'L') {
       Object* obj = reinterpret_cast<Object*>(regs[cur_arg]);
-      AddLocalReference<jobject>(env, obj);
+      ts.AddLocalReference<jobject>(obj);
     }
     cur_arg = cur_arg + (c == 'J' || c == 'D' ? 2 : 1);
   }
diff --git a/src/oat_compilation_unit.h b/src/oat_compilation_unit.h
index 0000f21..41c1847 100644
--- a/src/oat_compilation_unit.h
+++ b/src/oat_compilation_unit.h
@@ -30,7 +30,7 @@
 
 class OatCompilationUnit {
  public:
-  OatCompilationUnit(const ClassLoader* class_loader, ClassLinker* class_linker,
+  OatCompilationUnit(ClassLoader* class_loader, ClassLinker* class_linker,
                      const DexFile& dex_file, DexCache& dex_cache,
                      const DexFile::CodeItem* code_item,
                      uint32_t method_idx, uint32_t access_flags)
@@ -46,7 +46,7 @@
                                   callee_access_flags);
   }
 
-  const ClassLoader* GetClassLoader() const {
+  ClassLoader* GetClassLoader() const {
     return class_loader_;
   }
 
@@ -85,7 +85,7 @@
   }
 
  public:
-  const ClassLoader* class_loader_;
+  ClassLoader* class_loader_;
   ClassLinker* class_linker_;
 
   const DexFile* dex_file_;
diff --git a/src/oat_writer.cc b/src/oat_writer.cc
index 852320d..ec25ae9 100644
--- a/src/oat_writer.cc
+++ b/src/oat_writer.cc
@@ -30,7 +30,7 @@
 namespace art {
 
 bool OatWriter::Create(File* file,
-                       const ClassLoader* class_loader,
+                       ClassLoader* class_loader,
                        const std::vector<const DexFile*>& dex_files,
                        uint32_t image_file_location_checksum,
                        const std::string& image_file_location,
@@ -46,7 +46,7 @@
 OatWriter::OatWriter(const std::vector<const DexFile*>& dex_files,
                      uint32_t image_file_location_checksum,
                      const std::string& image_file_location,
-                     const ClassLoader* class_loader,
+                     ClassLoader* class_loader,
                      const Compiler& compiler) {
   compiler_ = &compiler;
   class_loader_ = class_loader;
diff --git a/src/oat_writer.h b/src/oat_writer.h
index fe0bd67..29072ab 100644
--- a/src/oat_writer.h
+++ b/src/oat_writer.h
@@ -74,7 +74,7 @@
  public:
   // Write an oat file. Returns true on success, false on failure.
   static bool Create(File* file,
-                     const ClassLoader* class_loader,
+                     ClassLoader* class_loader,
                      const std::vector<const DexFile*>& dex_files,
                      uint32_t image_file_location_checksum,
                      const std::string& image_file_location,
@@ -84,7 +84,7 @@
   OatWriter(const std::vector<const DexFile*>& dex_files,
             uint32_t image_file_location_checksum,
             const std::string& image_file_location,
-            const ClassLoader* class_loader,
+            ClassLoader* class_loader,
             const Compiler& compiler);
   ~OatWriter();
 
@@ -177,7 +177,7 @@
   const Compiler* compiler_;
 
   // TODO: remove the ClassLoader when the code storage moves out of Method
-  const ClassLoader* class_loader_;
+  ClassLoader* class_loader_;
 
   // note OatFile does not take ownership of the DexFiles
   const std::vector<const DexFile*>* dex_files_;
diff --git a/src/object.cc b/src/object.cc
index b728e28..94e1759 100644
--- a/src/object.cc
+++ b/src/object.cc
@@ -322,35 +322,6 @@
   java_lang_reflect_Method_ = NULL;
 }
 
-Class* ExtractNextClassFromSignature(ClassLinker* class_linker, const ClassLoader* cl, const char*& p) {
-  if (*p == '[') {
-    // Something like "[[[Ljava/lang/String;".
-    const char* start = p;
-    while (*p == '[') {
-      ++p;
-    }
-    if (*p == 'L') {
-      while (*p != ';') {
-        ++p;
-      }
-    }
-    ++p; // Either the ';' or the primitive type.
-
-    std::string descriptor(start, (p - start));
-    return class_linker->FindClass(descriptor.c_str(), cl);
-  } else if (*p == 'L') {
-    const char* start = p;
-    while (*p != ';') {
-      ++p;
-    }
-    ++p;
-    std::string descriptor(start, (p - start));
-    return class_linker->FindClass(descriptor.c_str(), cl);
-  } else {
-    return class_linker->FindPrimitiveClass(*p++);
-  }
-}
-
 ObjectArray<String>* Method::GetDexCacheStrings() const {
   return GetFieldObject<ObjectArray<String>*>(
       OFFSET_OF_OBJECT_MEMBER(Method, dex_cache_strings_), false);
@@ -937,8 +908,7 @@
   return GetFieldObject<ClassLoader*>(OFFSET_OF_OBJECT_MEMBER(Class, class_loader_), false);
 }
 
-void Class::SetClassLoader(const ClassLoader* new_cl) {
-  ClassLoader* new_class_loader = const_cast<ClassLoader*>(new_cl);
+void Class::SetClassLoader(ClassLoader* new_class_loader) {
   SetFieldObject(OFFSET_OF_OBJECT_MEMBER(Class, class_loader_), new_class_loader, false);
 }
 
diff --git a/src/object.h b/src/object.h
index eeac3ea..6334f8e 100644
--- a/src/object.h
+++ b/src/object.h
@@ -1513,7 +1513,7 @@
 
   ClassLoader* GetClassLoader() const;
 
-  void SetClassLoader(const ClassLoader* new_cl);
+  void SetClassLoader(ClassLoader* new_cl);
 
   static MemberOffset DexCacheOffset() {
     return MemberOffset(OFFSETOF_MEMBER(Class, dex_cache_));
diff --git a/src/reflection.cc b/src/reflection.cc
index 2b72944..7726998 100644
--- a/src/reflection.cc
+++ b/src/reflection.cc
@@ -20,6 +20,7 @@
 #include "jni_internal.h"
 #include "object.h"
 #include "object_utils.h"
+#include "scoped_jni_thread_state.h"
 
 namespace art {
 
@@ -44,12 +45,10 @@
   gShort_valueOf = class_linker->FindSystemClass("Ljava/lang/Short;")->FindDeclaredDirectMethod("valueOf", "(S)Ljava/lang/Short;");
 }
 
-jobject InvokeMethod(JNIEnv* env, jobject javaMethod, jobject javaReceiver, jobject javaArgs) {
-  Thread* self = Thread::Current();
-  ScopedThreadStateChange tsc(self, kRunnable);
-
-  jmethodID mid = env->FromReflectedMethod(javaMethod);
-  Method* m = reinterpret_cast<Method*>(mid);
+jobject InvokeMethod(const ScopedJniThreadState& ts, jobject javaMethod, jobject javaReceiver,
+                     jobject javaArgs) {
+  jmethodID mid = ts.Env()->FromReflectedMethod(javaMethod);
+  Method* m = ts.DecodeMethod(mid);
 
   Class* declaring_class = m->GetDeclaringClass();
   if (!Runtime::Current()->GetClassLinker()->EnsureInitialized(declaring_class, true, true)) {
@@ -59,24 +58,24 @@
   Object* receiver = NULL;
   if (!m->IsStatic()) {
     // Check that the receiver is non-null and an instance of the field's declaring class.
-    receiver = Decode<Object*>(env, javaReceiver);
+    receiver = ts.Decode<Object*>(javaReceiver);
     if (!VerifyObjectInClass(receiver, declaring_class)) {
       return NULL;
     }
 
     // Find the actual implementation of the virtual method.
     m = receiver->GetClass()->FindVirtualMethodForVirtualOrInterface(m);
-    mid = reinterpret_cast<jmethodID>(m);
+    mid = ts.EncodeMethod(m);
   }
 
   // Get our arrays of arguments and their types, and check they're the same size.
-  ObjectArray<Object>* objects = Decode<ObjectArray<Object>*>(env, javaArgs);
+  ObjectArray<Object>* objects = ts.Decode<ObjectArray<Object>*>(javaArgs);
   MethodHelper mh(m);
   const DexFile::TypeList* classes = mh.GetParameterTypeList();
   uint32_t classes_size = classes == NULL ? 0 : classes->Size();
   uint32_t arg_count = (objects != NULL) ? objects->GetLength() : 0;
   if (arg_count != classes_size) {
-    self->ThrowNewExceptionF("Ljava/lang/IllegalArgumentException;",
+    ts.Self()->ThrowNewExceptionF("Ljava/lang/IllegalArgumentException;",
         "wrong number of arguments; expected %d, got %d",
         classes_size, arg_count);
     return NULL;
@@ -92,27 +91,27 @@
       return NULL;
     }
     if (!dst_class->IsPrimitive()) {
-      args[i].l = AddLocalReference<jobject>(env, arg);
+      args[i].l = ts.AddLocalReference<jobject>(arg);
     }
   }
 
   // Invoke the method.
-  JValue value(InvokeWithJValues(env, javaReceiver, mid, args.get()));
+  JValue value(InvokeWithJValues(ts, javaReceiver, mid, args.get()));
 
   // Wrap any exception with "Ljava/lang/reflect/InvocationTargetException;" and return early.
-  if (self->IsExceptionPending()) {
-    jthrowable th = env->ExceptionOccurred();
-    env->ExceptionClear();
-    jclass exception_class = env->FindClass("java/lang/reflect/InvocationTargetException");
-    jmethodID mid = env->GetMethodID(exception_class, "<init>", "(Ljava/lang/Throwable;)V");
-    jobject exception_instance = env->NewObject(exception_class, mid, th);
-    env->Throw(reinterpret_cast<jthrowable>(exception_instance));
+  if (ts.Self()->IsExceptionPending()) {
+    jthrowable th = ts.Env()->ExceptionOccurred();
+    ts.Env()->ExceptionClear();
+    jclass exception_class = ts.Env()->FindClass("java/lang/reflect/InvocationTargetException");
+    jmethodID mid = ts.Env()->GetMethodID(exception_class, "<init>", "(Ljava/lang/Throwable;)V");
+    jobject exception_instance = ts.Env()->NewObject(exception_class, mid, th);
+    ts.Env()->Throw(reinterpret_cast<jthrowable>(exception_instance));
     return NULL;
   }
 
   // Box if necessary and return.
   BoxPrimitive(mh.GetReturnType()->GetPrimitiveType(), value);
-  return AddLocalReference<jobject>(env, value.GetL());
+  return ts.AddLocalReference<jobject>(value.GetL());
 }
 
 bool VerifyObjectInClass(Object* o, Class* c) {
diff --git a/src/reflection.h b/src/reflection.h
index 6b47440..03847f8 100644
--- a/src/reflection.h
+++ b/src/reflection.h
@@ -27,6 +27,7 @@
 union JValue;
 class Method;
 class Object;
+class ScopedJniThreadState;
 
 void InitBoxingMethods();
 void BoxPrimitive(Primitive::Type src_class, JValue& value);
@@ -36,7 +37,7 @@
 
 bool ConvertPrimitiveValue(Primitive::Type src_class, Primitive::Type dst_class, const JValue& src, JValue& dst);
 
-jobject InvokeMethod(JNIEnv* env, jobject method, jobject receiver, jobject args);
+jobject InvokeMethod(const ScopedJniThreadState& ts, jobject method, jobject receiver, jobject args);
 
 bool VerifyObjectInClass(Object* o, Class* c);
 
diff --git a/src/runtime.cc b/src/runtime.cc
index b071ef4..5f20a4b 100644
--- a/src/runtime.cc
+++ b/src/runtime.cc
@@ -36,6 +36,7 @@
 #include "monitor.h"
 #include "oat_file.h"
 #include "scoped_heap_lock.h"
+#include "scoped_jni_thread_state.h"
 #include "ScopedLocalRef.h"
 #include "signal_catcher.h"
 #include "signal_set.h"
@@ -80,7 +81,9 @@
       method_trace_(0),
       method_trace_file_size_(0),
       tracer_(NULL),
-      use_compile_time_class_path_(false) {
+      use_compile_time_class_path_(false),
+      main_thread_group_(NULL),
+      system_thread_group_(NULL) {
   for (int i = 0; i < Runtime::kLastTrampolineMethodType; i++) {
     resolution_stub_array_[i] = NULL;
   }
@@ -534,33 +537,33 @@
   return instance_;
 }
 
-void CreateSystemClassLoader() {
+static void CreateSystemClassLoader() {
   if (Runtime::Current()->UseCompileTimeClassPath()) {
     return;
   }
 
-  Thread* self = Thread::Current();
+  ScopedJniThreadState ts(Thread::Current());
 
-  // Must be in the kNative state for calling native methods.
-  CHECK_EQ(self->GetState(), kNative);
+  Class* class_loader_class = ts.Decode<Class*>(WellKnownClasses::java_lang_ClassLoader);
+  CHECK(Runtime::Current()->GetClassLinker()->EnsureInitialized(class_loader_class, true, true));
 
-  JNIEnv* env = self->GetJniEnv();
-  jmethodID getSystemClassLoader = env->GetStaticMethodID(WellKnownClasses::java_lang_ClassLoader,
-                                                          "getSystemClassLoader",
-                                                          "()Ljava/lang/ClassLoader;");
+  Method* getSystemClassLoader = class_loader_class->FindDirectMethod("getSystemClassLoader", "()Ljava/lang/ClassLoader;");
   CHECK(getSystemClassLoader != NULL);
-  ScopedLocalRef<jobject> class_loader(env, env->CallStaticObjectMethod(WellKnownClasses::java_lang_ClassLoader,
-                                                                        getSystemClassLoader));
-  CHECK(class_loader.get() != NULL);
 
-  Thread::Current()->SetClassLoaderOverride(Decode<ClassLoader*>(env, class_loader.get()));
+  ClassLoader* class_loader =
+    down_cast<ClassLoader*>(InvokeWithJValues(ts, NULL, getSystemClassLoader, NULL).GetL());
+  CHECK(class_loader != NULL);
 
-  jfieldID contextClassLoader = env->GetFieldID(WellKnownClasses::java_lang_Thread,
-                                                "contextClassLoader",
-                                                "Ljava/lang/ClassLoader;");
+  ts.Self()->SetClassLoaderOverride(class_loader);
+
+  Class* thread_class = ts.Decode<Class*>(WellKnownClasses::java_lang_Thread);
+  CHECK(Runtime::Current()->GetClassLinker()->EnsureInitialized(thread_class, true, true));
+
+  Field* contextClassLoader = thread_class->FindDeclaredInstanceField("contextClassLoader",
+                                                                      "Ljava/lang/ClassLoader;");
   CHECK(contextClassLoader != NULL);
-  ScopedLocalRef<jobject> self_jobject(env, AddLocalReference<jobject>(env, self->GetPeer()));
-  env->SetObjectField(self_jobject.get(), contextClassLoader, class_loader.get());
+
+  contextClassLoader->SetObject(ts.Self()->GetPeer(), class_loader);
 }
 
 void Runtime::Start() {
@@ -587,6 +590,9 @@
   // it touches will have methods linked to the oat file if necessary.
   InitNativeMethods();
 
+  // Initialize well known thread group values that may be accessed threads while attaching.
+  InitThreadGroups(self);
+
   Thread::FinishStartup();
 
   if (!is_zygote_) {
@@ -739,6 +745,17 @@
   VLOG(startup) << "Runtime::InitNativeMethods exiting";
 }
 
+void Runtime::InitThreadGroups(Thread* self) {
+  JNIEnvExt* env = self->GetJniEnv();
+  ScopedJniEnvLocalRefState env_state(env);
+  main_thread_group_ =
+      env->NewGlobalRef(env->GetStaticObjectField(WellKnownClasses::java_lang_ThreadGroup,
+                                                  WellKnownClasses::java_lang_ThreadGroup_mainThreadGroup));
+  system_thread_group_ =
+      env->NewGlobalRef(env->GetStaticObjectField(WellKnownClasses::java_lang_ThreadGroup,
+                                                  WellKnownClasses::java_lang_ThreadGroup_systemThreadGroup));
+}
+
 void Runtime::RegisterRuntimeNativeMethods(JNIEnv* env) {
 #define REGISTER(FN) extern void FN(JNIEnv*); FN(env)
   // Register Throwable first so that registration of other native methods can throw exceptions
@@ -850,7 +867,7 @@
   signals.Block();
 }
 
-void Runtime::AttachCurrentThread(const char* thread_name, bool as_daemon, Object* thread_group) {
+void Runtime::AttachCurrentThread(const char* thread_name, bool as_daemon, jobject thread_group) {
   Thread::Attach(thread_name, as_daemon, thread_group);
   if (thread_name == NULL) {
     LOG(WARNING) << *Thread::Current() << " attached without supplying a name";
diff --git a/src/runtime.h b/src/runtime.h
index e3e0caf..a6cebe7 100644
--- a/src/runtime.h
+++ b/src/runtime.h
@@ -136,8 +136,18 @@
   // that the native stack trace we get may point at the wrong call site.
   static void Abort();
 
+  // Returns the "main" ThreadGroup, used when attaching user threads.
+  jobject GetMainThreadGroup() const {
+    return main_thread_group_;
+  }
+
+  // Returns the "system" ThreadGroup, used when attaching our internal threads.
+  jobject GetSystemThreadGroup() const {
+    return system_thread_group_;
+  }
+
   // Attaches the calling native thread to the runtime.
-  void AttachCurrentThread(const char* thread_name, bool as_daemon, Object* thread_group);
+  void AttachCurrentThread(const char* thread_name, bool as_daemon, jobject thread_group);
 
   void CallExitHook(jint status);
 
@@ -323,6 +333,7 @@
 
   bool Init(const Options& options, bool ignore_unrecognized);
   void InitNativeMethods();
+  void InitThreadGroups(Thread* self);
   void RegisterRuntimeNativeMethods(JNIEnv* env);
 
   void StartDaemonThreads();
@@ -409,6 +420,9 @@
   CompileTimeClassPaths compile_time_class_paths_;
   bool use_compile_time_class_path_;
 
+  jobject main_thread_group_;
+  jobject system_thread_group_;
+
   DISALLOW_COPY_AND_ASSIGN(Runtime);
 };
 
diff --git a/src/scoped_jni_thread_state.h b/src/scoped_jni_thread_state.h
index 9b3e63c..42ed19c 100644
--- a/src/scoped_jni_thread_state.h
+++ b/src/scoped_jni_thread_state.h
@@ -14,38 +14,147 @@
  * limitations under the License.
  */
 
+#include "casts.h"
 #include "jni_internal.h"
+#include "thread.h"
 
 namespace art {
 
-// Entry/exit processing for all JNI calls.
+// Entry/exit processing for transitions from Native to Runnable (ie within JNI functions).
 //
-// This performs the necessary thread state switching, lets us amortize the
-// cost of working out the current thread, and lets us check (and repair) apps
-// that are using a JNIEnv on the wrong thread.
+// This class performs the necessary thread state switching to and from Runnable and lets us
+// amortize the cost of working out the current thread. Additionally it lets us check (and repair)
+// apps that are using a JNIEnv on the wrong thread. The class also decodes and encodes Objects
+// into jobjects via methods of this class. Performing this here enforces the Runnable thread state
+// for use of Object, thereby inhibiting the Object being modified by GC whilst native or VM code
+// is also manipulating the Object.
+//
+// The destructor transitions back to the previous thread state, typically Native. In this case
+// GC and thread suspension may occur.
 class ScopedJniThreadState {
  public:
   explicit ScopedJniThreadState(JNIEnv* env, ThreadState new_state = kRunnable)
-      : env_(reinterpret_cast<JNIEnvExt*>(env)) {
-    self_ = ThreadForEnv(env);
-    old_thread_state_ = self_->SetState(new_state);
+    : env_(reinterpret_cast<JNIEnvExt*>(env)), vm_(env_->vm), self_(ThreadForEnv(env)),
+      old_thread_state_(self_->SetState(new_state)), thread_state_(new_state) {
     self_->VerifyStack();
   }
 
-  ~ScopedJniThreadState() {
-    self_->SetState(old_thread_state_);
+  explicit ScopedJniThreadState(Thread* self, ThreadState new_state = kRunnable)
+    : env_(reinterpret_cast<JNIEnvExt*>(self->GetJniEnv())), vm_(env_->vm), self_(self),
+      old_thread_state_(self_->SetState(new_state)), thread_state_(new_state) {
+    self_->VerifyStack();
   }
 
-  JNIEnvExt* Env() {
+  // Used when we want a scoped jni thread state but have no thread/JNIEnv.
+  explicit ScopedJniThreadState(JavaVM* vm)
+    : env_(NULL), vm_(reinterpret_cast<JavaVMExt*>(vm)), self_(NULL),
+      old_thread_state_(kTerminated), thread_state_(kTerminated) {
+  }
+
+  ~ScopedJniThreadState() {
+    if (self_ != NULL) {
+      self_->SetState(old_thread_state_);
+    }
+  }
+
+  JNIEnvExt* Env() const {
     return env_;
   }
 
-  Thread* Self() {
+  Thread* Self() const {
     return self_;
   }
 
-  JavaVMExt* Vm() {
-    return env_->vm;
+  JavaVMExt* Vm() const {
+    return vm_;
+  }
+
+  /*
+   * Add a local reference for an object to the indirect reference table associated with the
+   * current stack frame.  When the native function returns, the reference will be discarded.
+   * Part of the ScopedJniThreadState as native code shouldn't be working on raw Object* without
+   * having transitioned its state.
+   *
+   * We need to allow the same reference to be added multiple times.
+   *
+   * This will be called on otherwise unreferenced objects.  We cannot do GC allocations here, and
+   * it's best if we don't grab a mutex.
+   *
+   * Returns the local reference (currently just the same pointer that was
+   * passed in), or NULL on failure.
+   */
+  template<typename T>
+  T AddLocalReference(Object* obj) const {
+    DCHECK_EQ(thread_state_, kRunnable);  // Don't work with raw objects in non-runnable states.
+    if (obj == NULL) {
+      return NULL;
+    }
+
+    DCHECK_NE((reinterpret_cast<uintptr_t>(obj) & 0xffff0000), 0xebad0000);
+
+    IndirectReferenceTable& locals = Env()->locals;
+
+    uint32_t cookie = Env()->local_ref_cookie;
+    IndirectRef ref = locals.Add(cookie, obj);
+
+  #if 0 // TODO: fix this to understand PushLocalFrame, so we can turn it on.
+    if (Env()->check_jni) {
+      size_t entry_count = locals.Capacity();
+      if (entry_count > 16) {
+        LOG(WARNING) << "Warning: more than 16 JNI local references: "
+                     << entry_count << " (most recent was a " << PrettyTypeOf(obj) << ")\n"
+                     << Dumpable<IndirectReferenceTable>(locals);
+        // TODO: LOG(FATAL) in a later release?
+      }
+    }
+  #endif
+
+    if (Vm()->work_around_app_jni_bugs) {
+      // Hand out direct pointers to support broken old apps.
+      return reinterpret_cast<T>(obj);
+    }
+
+    return reinterpret_cast<T>(ref);
+  }
+
+  template<typename T>
+  T Decode(jobject obj) const {
+    DCHECK_EQ(thread_state_, kRunnable);  // Don't work with raw objects in non-runnable states.
+    return down_cast<T>(Self()->DecodeJObject(obj));
+  }
+
+  Field* DecodeField(jfieldID fid) const {
+    DCHECK_EQ(thread_state_, kRunnable);  // Don't work with raw objects in non-runnable states.
+  #ifdef MOVING_GARBAGE_COLLECTOR
+    // TODO: we should make these unique weak globals if Field instances can ever move.
+    UNIMPLEMENTED(WARNING);
+  #endif
+    return reinterpret_cast<Field*>(fid);
+  }
+
+  jfieldID EncodeField(Field* field) const {
+    DCHECK_EQ(thread_state_, kRunnable);  // Don't work with raw objects in non-runnable states.
+  #ifdef MOVING_GARBAGE_COLLECTOR
+    UNIMPLEMENTED(WARNING);
+  #endif
+    return reinterpret_cast<jfieldID>(field);
+  }
+
+  Method* DecodeMethod(jmethodID mid) const {
+    DCHECK_EQ(thread_state_, kRunnable);  // Don't work with raw objects in non-runnable states.
+  #ifdef MOVING_GARBAGE_COLLECTOR
+    // TODO: we should make these unique weak globals if Method instances can ever move.
+    UNIMPLEMENTED(WARNING);
+  #endif
+    return reinterpret_cast<Method*>(mid);
+  }
+
+  jmethodID EncodeMethod(Method* method) const {
+    DCHECK_EQ(thread_state_, kRunnable);  // Don't work with raw objects in non-runnable states.
+  #ifdef MOVING_GARBAGE_COLLECTOR
+    UNIMPLEMENTED(WARNING);
+  #endif
+    return reinterpret_cast<jmethodID>(method);
   }
 
  private:
@@ -62,9 +171,16 @@
     return self;
   }
 
-  JNIEnvExt* env_;
-  Thread* self_;
-  ThreadState old_thread_state_;
+  // The full JNIEnv.
+  JNIEnvExt* const env_;
+  // The full JavaVM.
+  JavaVMExt* const vm_;
+  // Cached current thread derived from the JNIEnv.
+  Thread* const self_;
+  // Previous thread state, most likely kNative.
+  const ThreadState old_thread_state_;
+  // Local cache of thread state to enable quick sanity checks.
+  const ThreadState thread_state_;
   DISALLOW_COPY_AND_ASSIGN(ScopedJniThreadState);
 };
 
diff --git a/src/signal_catcher.cc b/src/signal_catcher.cc
index d3c799c..919923e 100644
--- a/src/signal_catcher.cc
+++ b/src/signal_catcher.cc
@@ -180,7 +180,7 @@
   CHECK(signal_catcher != NULL);
 
   Runtime* runtime = Runtime::Current();
-  runtime->AttachCurrentThread("Signal Catcher", true, Thread::GetSystemThreadGroup());
+  runtime->AttachCurrentThread("Signal Catcher", true, runtime->GetSystemThreadGroup());
 
   Thread* self = Thread::Current();
   self->SetState(kRunnable);
diff --git a/src/stack.cc b/src/stack.cc
index 336f8ad..07a1cb1 100644
--- a/src/stack.cc
+++ b/src/stack.cc
@@ -26,7 +26,8 @@
 
 class StackGetter {
  public:
-  StackGetter(JNIEnv* env, Thread* thread) : env_(env), thread_(thread), trace_(NULL) {
+  StackGetter(const ScopedJniThreadState& ts, Thread* thread)
+    : ts_(ts), thread_(thread), trace_(NULL) {
   }
 
   static void Callback(void* arg) {
@@ -39,17 +40,17 @@
 
  private:
   void Callback() {
-    trace_ = thread_->CreateInternalStackTrace(env_);
+    trace_ = thread_->CreateInternalStackTrace(ts_);
   }
 
-  JNIEnv* env_;
-  Thread* thread_;
+  const ScopedJniThreadState& ts_;
+  Thread* const thread_;
   jobject trace_;
 };
 
-jobject GetThreadStack(JNIEnv* env, Thread* thread) {
+jobject GetThreadStack(const ScopedJniThreadState& ts, Thread* thread) {
   ThreadList* thread_list = Runtime::Current()->GetThreadList();
-  StackGetter stack_getter(env, thread);
+  StackGetter stack_getter(ts, thread);
   thread_list->RunWhileSuspended(thread, StackGetter::Callback, &stack_getter);
   return stack_getter.GetTrace();
 }
diff --git a/src/stack.h b/src/stack.h
index bd0aee6..243ca28 100644
--- a/src/stack.h
+++ b/src/stack.h
@@ -31,9 +31,10 @@
 class Method;
 class Object;
 class ShadowFrame;
+class ScopedJniThreadState;
 class Thread;
 
-jobject GetThreadStack(JNIEnv*, Thread*);
+jobject GetThreadStack(const ScopedJniThreadState&, Thread*);
 
 class ShadowFrame {
  public:
diff --git a/src/thread.cc b/src/thread.cc
index a8ba701..ba2919a 100644
--- a/src/thread.cc
+++ b/src/thread.cc
@@ -107,38 +107,41 @@
   runtime->GetThreadList()->WaitForGo();
 
   {
-    CHECK_EQ(self->GetState(), kRunnable);
-    SirtRef<String> thread_name(self->GetThreadName());
-    self->SetThreadName(thread_name->ToModifiedUtf8().c_str());
+    ScopedJniThreadState ts(self);
+    {
+      SirtRef<String> thread_name(self->GetThreadName(ts));
+      self->SetThreadName(thread_name->ToModifiedUtf8().c_str());
+    }
+
+    Dbg::PostThreadStart(self);
+
+    // Invoke the 'run' method of our java.lang.Thread.
+    CHECK(self->peer_ != NULL);
+    Object* receiver = self->peer_;
+    jmethodID mid = WellKnownClasses::java_lang_Thread_run;
+    Method* m = receiver->GetClass()->FindVirtualMethodForVirtualOrInterface(ts.DecodeMethod(mid));
+    m->Invoke(self, receiver, NULL, NULL);
   }
 
-  Dbg::PostThreadStart(self);
-
-  // Invoke the 'run' method of our java.lang.Thread.
-  CHECK(self->peer_ != NULL);
-  Object* receiver = self->peer_;
-  jmethodID mid = WellKnownClasses::java_lang_Thread_run;
-  Method* m = receiver->GetClass()->FindVirtualMethodForVirtualOrInterface(DecodeMethod(mid));
-  m->Invoke(self, receiver, NULL, NULL);
-
   // Detach.
   runtime->GetThreadList()->Unregister();
 
   return NULL;
 }
 
-static void SetVmData(Object* managed_thread, Thread* native_thread) {
-  Field* f = DecodeField(WellKnownClasses::java_lang_Thread_vmData);
+static void SetVmData(const ScopedJniThreadState& ts, Object* managed_thread,
+                      Thread* native_thread) {
+  Field* f = ts.DecodeField(WellKnownClasses::java_lang_Thread_vmData);
   f->SetInt(managed_thread, reinterpret_cast<uintptr_t>(native_thread));
 }
 
-Thread* Thread::FromManagedThread(Object* thread_peer) {
-  Field* f = DecodeField(WellKnownClasses::java_lang_Thread_vmData);
+Thread* Thread::FromManagedThread(const ScopedJniThreadState& ts, Object* thread_peer) {
+  Field* f = ts.DecodeField(WellKnownClasses::java_lang_Thread_vmData);
   return reinterpret_cast<Thread*>(static_cast<uintptr_t>(f->GetInt(thread_peer)));
 }
 
-Thread* Thread::FromManagedThread(JNIEnv* env, jobject java_thread) {
-  return FromManagedThread(Decode<Object*>(env, java_thread));
+Thread* Thread::FromManagedThread(const ScopedJniThreadState& ts, jobject java_thread) {
+  return FromManagedThread(ts, ts.Decode<Object*>(java_thread));
 }
 
 static size_t FixStackSize(size_t stack_size) {
@@ -204,41 +207,43 @@
   delete[] allocated_signal_stack;
 }
 
-void Thread::CreateNativeThread(Object* peer, size_t stack_size) {
-  CHECK(peer != NULL);
-
-  stack_size = FixStackSize(stack_size);
-
+void Thread::CreateNativeThread(JNIEnv* env, jobject java_peer, size_t stack_size) {
   Thread* native_thread = new Thread;
-  native_thread->peer_ = peer;
-
-  // Thread.start is synchronized, so we know that vmData is 0,
-  // and know that we're not racing to assign it.
-  SetVmData(peer, native_thread);
-
-  int pthread_create_result = 0;
   {
-    ScopedThreadStateChange tsc(Thread::Current(), kVmWait);
-    pthread_t new_pthread;
-    pthread_attr_t attr;
-    CHECK_PTHREAD_CALL(pthread_attr_init, (&attr), "new thread");
-    CHECK_PTHREAD_CALL(pthread_attr_setdetachstate, (&attr, PTHREAD_CREATE_DETACHED), "PTHREAD_CREATE_DETACHED");
-    CHECK_PTHREAD_CALL(pthread_attr_setstacksize, (&attr, stack_size), stack_size);
-    pthread_create_result = pthread_create(&new_pthread, &attr, Thread::CreateCallback, native_thread);
-    CHECK_PTHREAD_CALL(pthread_attr_destroy, (&attr), "new thread");
+    ScopedJniThreadState ts(env);
+    Object* peer = ts.Decode<Object*>(java_peer);
+    CHECK(peer != NULL);
+    native_thread->peer_ = peer;
+
+    stack_size = FixStackSize(stack_size);
+
+    // Thread.start is synchronized, so we know that vmData is 0,
+    // and know that we're not racing to assign it.
+    SetVmData(ts, peer, native_thread);
+
+    int pthread_create_result = 0;
+    {
+      ScopedThreadStateChange tsc(Thread::Current(), kVmWait);
+      pthread_t new_pthread;
+      pthread_attr_t attr;
+      CHECK_PTHREAD_CALL(pthread_attr_init, (&attr), "new thread");
+      CHECK_PTHREAD_CALL(pthread_attr_setdetachstate, (&attr, PTHREAD_CREATE_DETACHED), "PTHREAD_CREATE_DETACHED");
+      CHECK_PTHREAD_CALL(pthread_attr_setstacksize, (&attr, stack_size), stack_size);
+      pthread_create_result = pthread_create(&new_pthread, &attr, Thread::CreateCallback, native_thread);
+      CHECK_PTHREAD_CALL(pthread_attr_destroy, (&attr), "new thread");
+    }
+
+    if (pthread_create_result != 0) {
+      // pthread_create(3) failed, so clean up.
+      SetVmData(ts, peer, 0);
+      delete native_thread;
+
+      std::string msg(StringPrintf("pthread_create (%s stack) failed: %s",
+                                   PrettySize(stack_size).c_str(), strerror(pthread_create_result)));
+      Thread::Current()->ThrowOutOfMemoryError(msg.c_str());
+      return;
+    }
   }
-
-  if (pthread_create_result != 0) {
-    // pthread_create(3) failed, so clean up.
-    SetVmData(peer, 0);
-    delete native_thread;
-
-    std::string msg(StringPrintf("pthread_create (%s stack) failed: %s",
-                                 PrettySize(stack_size).c_str(), strerror(pthread_create_result)));
-    Thread::Current()->ThrowOutOfMemoryError(msg.c_str());
-    return;
-  }
-
   // Let the child know when it's safe to start running.
   Runtime::Current()->GetThreadList()->SignalGo(native_thread);
 }
@@ -271,7 +276,7 @@
   runtime->GetThreadList()->Register();
 }
 
-Thread* Thread::Attach(const char* thread_name, bool as_daemon, Object* thread_group) {
+Thread* Thread::Attach(const char* thread_name, bool as_daemon, jobject thread_group) {
   Thread* self = new Thread;
   self->Init();
 
@@ -295,30 +300,14 @@
   return self;
 }
 
-static Object* GetWellKnownThreadGroup(jfieldID which) {
-  Class* c = WellKnownClasses::ToClass(WellKnownClasses::java_lang_ThreadGroup);
-  if (!Runtime::Current()->GetClassLinker()->EnsureInitialized(c, true, true)) {
-    return NULL;
-  }
-  return DecodeField(which)->GetObject(NULL);
-}
-
-Object* Thread::GetMainThreadGroup() {
-  return GetWellKnownThreadGroup(WellKnownClasses::java_lang_ThreadGroup_mainThreadGroup);
-}
-
-Object* Thread::GetSystemThreadGroup() {
-  return GetWellKnownThreadGroup(WellKnownClasses::java_lang_ThreadGroup_systemThreadGroup);
-}
-
-void Thread::CreatePeer(const char* name, bool as_daemon, Object* thread_group) {
-  CHECK(Runtime::Current()->IsStarted());
+void Thread::CreatePeer(const char* name, bool as_daemon, jobject thread_group) {
+  Runtime* runtime = Runtime::Current();
+  CHECK(runtime->IsStarted());
   JNIEnv* env = jni_env_;
 
   if (thread_group == NULL) {
-    thread_group = Thread::GetMainThreadGroup();
+    thread_group = runtime->GetMainThreadGroup();
   }
-  ScopedLocalRef<jobject> java_thread_group(env, AddLocalReference<jobject>(env, thread_group));
   ScopedLocalRef<jobject> thread_name(env, env->NewStringUTF(name));
   jint thread_priority = GetNativePriority();
   jboolean thread_is_daemon = as_daemon;
@@ -332,21 +321,22 @@
   env->CallNonvirtualVoidMethod(peer.get(),
                                 WellKnownClasses::java_lang_Thread,
                                 WellKnownClasses::java_lang_Thread_init,
-                                java_thread_group.get(), thread_name.get(), thread_priority, thread_is_daemon);
+                                thread_group, thread_name.get(), thread_priority, thread_is_daemon);
   CHECK(!IsExceptionPending()) << " " << PrettyTypeOf(GetException());
-  SetVmData(peer_, Thread::Current());
 
-  SirtRef<String> peer_thread_name(GetThreadName());
+  ScopedJniThreadState ts(this);
+  SetVmData(ts, peer_, Thread::Current());
+  SirtRef<String> peer_thread_name(GetThreadName(ts));
   if (peer_thread_name.get() == NULL) {
     // The Thread constructor should have set the Thread.name to a
     // non-null value. However, because we can run without code
     // available (in the compiler, in tests), we manually assign the
     // fields the constructor should have set.
-    DecodeField(WellKnownClasses::java_lang_Thread_daemon)->SetBoolean(peer_, thread_is_daemon);
-    DecodeField(WellKnownClasses::java_lang_Thread_group)->SetObject(peer_, thread_group);
-    DecodeField(WellKnownClasses::java_lang_Thread_name)->SetObject(peer_, Decode<Object*>(env, thread_name.get()));
-    DecodeField(WellKnownClasses::java_lang_Thread_priority)->SetInt(peer_, thread_priority);
-    peer_thread_name.reset(GetThreadName());
+    ts.DecodeField(WellKnownClasses::java_lang_Thread_daemon)->SetBoolean(peer_, thread_is_daemon);
+    ts.DecodeField(WellKnownClasses::java_lang_Thread_group)->SetObject(peer_, ts.Decode<Object*>(thread_group));
+    ts.DecodeField(WellKnownClasses::java_lang_Thread_name)->SetObject(peer_, ts.Decode<Object*>(thread_name.get()));
+    ts.DecodeField(WellKnownClasses::java_lang_Thread_priority)->SetInt(peer_, thread_priority);
+    peer_thread_name.reset(GetThreadName(ts));
   }
   // 'thread_name' may have been null, so don't trust 'peer_thread_name' to be non-null.
   if (peer_thread_name.get() != NULL) {
@@ -432,8 +422,8 @@
   }
 }
 
-String* Thread::GetThreadName() const {
-  Field* f = DecodeField(WellKnownClasses::java_lang_Thread_name);
+String* Thread::GetThreadName(const ScopedJniThreadState& ts) const {
+  Field* f = ts.DecodeField(WellKnownClasses::java_lang_Thread_name);
   return (peer_ != NULL) ? reinterpret_cast<String*>(f->GetObject(peer_)) : NULL;
 }
 
@@ -447,12 +437,13 @@
   bool is_daemon = false;
 
   if (thread != NULL && thread->peer_ != NULL) {
-    priority = DecodeField(WellKnownClasses::java_lang_Thread_priority)->GetInt(thread->peer_);
-    is_daemon = DecodeField(WellKnownClasses::java_lang_Thread_daemon)->GetBoolean(thread->peer_);
+    ScopedJniThreadState ts(Thread::Current());
+    priority = ts.DecodeField(WellKnownClasses::java_lang_Thread_priority)->GetInt(thread->peer_);
+    is_daemon = ts.DecodeField(WellKnownClasses::java_lang_Thread_daemon)->GetBoolean(thread->peer_);
 
-    Object* thread_group = thread->GetThreadGroup();
+    Object* thread_group = thread->GetThreadGroup(ts);
     if (thread_group != NULL) {
-      Field* group_name_field = DecodeField(WellKnownClasses::java_lang_ThreadGroup_name);
+      Field* group_name_field = ts.DecodeField(WellKnownClasses::java_lang_ThreadGroup_name);
       String* group_name_string = reinterpret_cast<String*>(group_name_field->GetObject(thread_group));
       group_name = (group_name_string != NULL) ? group_name_string->ToModifiedUtf8() : "<null>";
     }
@@ -750,12 +741,13 @@
 }
 
 void Thread::FinishStartup() {
-  CHECK(Runtime::Current()->IsStarted());
+  Runtime* runtime = Runtime::Current();
+  CHECK(runtime->IsStarted());
   Thread* self = Thread::Current();
 
   // Finish attaching the main thread.
   ScopedThreadStateChange tsc(self, kRunnable);
-  Thread::Current()->CreatePeer("main", false, Thread::GetMainThreadGroup());
+  Thread::Current()->CreatePeer("main", false, runtime->GetMainThreadGroup());
 
   InitBoxingMethods();
   Runtime::Current()->GetClassLinker()->RunRootClinits();
@@ -826,19 +818,19 @@
     Thread* self = this;
 
     // We may need to call user-supplied managed code.
-    SetState(kRunnable);
+    ScopedJniThreadState ts(this);
 
-    HandleUncaughtExceptions();
-    RemoveFromThreadGroup();
+    HandleUncaughtExceptions(ts);
+    RemoveFromThreadGroup(ts);
 
     // this.vmData = 0;
-    SetVmData(peer_, NULL);
+    SetVmData(ts, peer_, NULL);
 
     Dbg::PostThreadDeath(self);
 
     // Thread.join() is implemented as an Object.wait() on the Thread.lock
     // object. Signal anyone who is waiting.
-    Object* lock = DecodeField(WellKnownClasses::java_lang_Thread_lock)->GetObject(peer_);
+    Object* lock = ts.DecodeField(WellKnownClasses::java_lang_Thread_lock)->GetObject(peer_);
     // (This conditional is only needed for tests, where Thread.lock won't have been set.)
     if (lock != NULL) {
       lock->MonitorEnter(self);
@@ -868,25 +860,25 @@
   TearDownAlternateSignalStack();
 }
 
-void Thread::HandleUncaughtExceptions() {
+void Thread::HandleUncaughtExceptions(const ScopedJniThreadState& ts) {
   if (!IsExceptionPending()) {
     return;
   }
-
   // Get and clear the exception.
   Object* exception = GetException();
   ClearException();
 
   // If the thread has its own handler, use that.
-  Object* handler = DecodeField(WellKnownClasses::java_lang_Thread_uncaughtHandler)->GetObject(peer_);
+  Object* handler =
+      ts.DecodeField(WellKnownClasses::java_lang_Thread_uncaughtHandler)->GetObject(peer_);
   if (handler == NULL) {
     // Otherwise use the thread group's default handler.
-    handler = GetThreadGroup();
+    handler = GetThreadGroup(ts);
   }
 
   // Call the handler.
   jmethodID mid = WellKnownClasses::java_lang_Thread$UncaughtExceptionHandler_uncaughtException;
-  Method* m = handler->GetClass()->FindVirtualMethodForVirtualOrInterface(DecodeMethod(mid));
+  Method* m = handler->GetClass()->FindVirtualMethodForVirtualOrInterface(ts.DecodeMethod(mid));
   JValue args[2];
   args[0].SetL(peer_);
   args[1].SetL(exception);
@@ -896,17 +888,17 @@
   ClearException();
 }
 
-Object* Thread::GetThreadGroup() const {
-  return DecodeField(WellKnownClasses::java_lang_Thread_group)->GetObject(peer_);
+Object* Thread::GetThreadGroup(const ScopedJniThreadState& ts) const {
+  return ts.DecodeField(WellKnownClasses::java_lang_Thread_group)->GetObject(peer_);
 }
 
-void Thread::RemoveFromThreadGroup() {
+void Thread::RemoveFromThreadGroup(const ScopedJniThreadState& ts) {
   // this.group.removeThread(this);
   // group can be null if we're in the compiler or a test.
-  Object* group = GetThreadGroup();
+  Object* group = GetThreadGroup(ts);
   if (group != NULL) {
     jmethodID mid = WellKnownClasses::java_lang_ThreadGroup_removeThread;
-    Method* m = group->GetClass()->FindVirtualMethodForVirtualOrInterface(DecodeMethod(mid));
+    Method* m = group->GetClass()->FindVirtualMethodForVirtualOrInterface(ts.DecodeMethod(mid));
     JValue args[1];
     args[0].SetL(peer_);
     m->Invoke(this, group, args, NULL);
@@ -1051,7 +1043,7 @@
     StackVisitor(stack, trace_stack), skip_depth_(skip_depth), count_(0), dex_pc_trace_(NULL),
     method_trace_(NULL) {}
 
-  bool Init(int depth, ScopedJniThreadState& ts) {
+  bool Init(int depth, const ScopedJniThreadState& ts) {
     // Allocate method trace with an extra slot that will hold the PC trace
     SirtRef<ObjectArray<Object> >
       method_trace(Runtime::Current()->GetClassLinker()->AllocObjectArray<Object>(depth + 1));
@@ -1121,16 +1113,13 @@
   return sirt;
 }
 
-jobject Thread::CreateInternalStackTrace(JNIEnv* env) const {
+jobject Thread::CreateInternalStackTrace(const ScopedJniThreadState& ts) const {
   // Compute depth of stack
   CountStackDepthVisitor count_visitor(GetManagedStack(), GetTraceStack());
   count_visitor.WalkStack();
   int32_t depth = count_visitor.GetDepth();
   int32_t skip_depth = count_visitor.GetSkipDepth();
 
-  // Transition into runnable state to work on Object*/Array*
-  ScopedJniThreadState ts(env);
-
   // Build internal stack trace
   BuildInternalStackTraceVisitor build_trace_visitor(GetManagedStack(), GetTraceStack(),
                                                      skip_depth);
@@ -1138,7 +1127,7 @@
     return NULL;  // Allocation failed
   }
   build_trace_visitor.WalkStack();
-  return AddLocalReference<jobjectArray>(ts.Env(), build_trace_visitor.GetInternalStackTrace());
+  return ts.AddLocalReference<jobjectArray>(build_trace_visitor.GetInternalStackTrace());
 }
 
 jobjectArray Thread::InternalStackTraceToStackTraceElementArray(JNIEnv* env, jobject internal,
@@ -1146,8 +1135,7 @@
   // Transition into runnable state to work on Object*/Array*
   ScopedJniThreadState ts(env);
   // Decode the internal stack trace into the depth, method trace and PC trace
-  ObjectArray<Object>* method_trace =
-      down_cast<ObjectArray<Object>*>(Decode<Object*>(ts.Env(), internal));
+  ObjectArray<Object>* method_trace = ts.Decode<ObjectArray<Object>*>(internal);
   int32_t depth = method_trace->GetLength() - 1;
   IntArray* pc_trace = down_cast<IntArray*>(method_trace->Get(depth));
 
@@ -1158,8 +1146,7 @@
   if (output_array != NULL) {
     // Reuse the array we were given.
     result = output_array;
-    java_traces = reinterpret_cast<ObjectArray<StackTraceElement>*>(Decode<Array*>(env,
-        output_array));
+    java_traces = ts.Decode<ObjectArray<StackTraceElement>*>(output_array);
     // ...adjusting the number of frames we'll write to not exceed the array length.
     depth = std::min(depth, java_traces->GetLength());
   } else {
@@ -1168,7 +1155,7 @@
     if (java_traces == NULL) {
       return NULL;
     }
-    result = AddLocalReference<jobjectArray>(ts.Env(), java_traces);
+    result = ts.AddLocalReference<jobjectArray>(java_traces);
   }
 
   if (stack_depth != NULL) {
@@ -1602,7 +1589,8 @@
 }
 
 bool Thread::IsDaemon() {
-  return DecodeField(WellKnownClasses::java_lang_Thread_daemon)->GetBoolean(peer_);
+  ScopedJniThreadState ts(this);
+  return ts.DecodeField(WellKnownClasses::java_lang_Thread_daemon)->GetBoolean(peer_);
 }
 
 class ReferenceMapVisitor : public StackVisitor {
diff --git a/src/thread.h b/src/thread.h
index 5ff0414..7cd55a3 100644
--- a/src/thread.h
+++ b/src/thread.h
@@ -94,11 +94,11 @@
 
   // Creates a new native thread corresponding to the given managed peer.
   // Used to implement Thread.start.
-  static void CreateNativeThread(Object* peer, size_t stack_size);
+  static void CreateNativeThread(JNIEnv* env, jobject peer, size_t stack_size);
 
   // Attaches the calling native thread to the runtime, returning the new native peer.
   // Used to implement JNI AttachCurrentThread and AttachCurrentThreadAsDaemon calls.
-  static Thread* Attach(const char* thread_name, bool as_daemon, Object* thread_group);
+  static Thread* Attach(const char* thread_name, bool as_daemon, jobject thread_group);
 
   // Reset internal state of child thread after fork.
   void InitAfterFork();
@@ -110,8 +110,8 @@
     return reinterpret_cast<Thread*>(thread);
   }
 
-  static Thread* FromManagedThread(Object* thread_peer);
-  static Thread* FromManagedThread(JNIEnv* env, jobject thread);
+  static Thread* FromManagedThread(const ScopedJniThreadState& ts, Object* thread_peer);
+  static Thread* FromManagedThread(const ScopedJniThreadState& ts, jobject thread);
 
   // Translates 172 to pAllocArrayFromCode and so on.
   static void DumpThreadOffset(std::ostream& os, uint32_t offset, size_t size_of_pointers);
@@ -179,11 +179,6 @@
    */
   static int GetNativePriority();
 
-  // Returns the "main" ThreadGroup, used when attaching user threads.
-  static Object* GetMainThreadGroup();
-  // Returns the "system" ThreadGroup, used when attaching our internal threads.
-  static Object* GetSystemThreadGroup();
-
   uint32_t GetThinLockId() const {
     return thin_lock_id_;
   }
@@ -193,7 +188,7 @@
   }
 
   // Returns the java.lang.Thread's name, or NULL if this Thread* doesn't have a peer.
-  String* GetThreadName() const;
+  String* GetThreadName(const ScopedJniThreadState& ts) const;
 
   // Sets 'name' to the java.lang.Thread's name. This requires no transition to managed code,
   // allocation, or locking.
@@ -206,7 +201,7 @@
     return peer_;
   }
 
-  Object* GetThreadGroup() const;
+  Object* GetThreadGroup(const ScopedJniThreadState& ts) const;
 
   RuntimeStats* GetStats() {
     return &stats_;
@@ -322,19 +317,19 @@
     NotifyLocked();
   }
 
-  const ClassLoader* GetClassLoaderOverride() {
+  ClassLoader* GetClassLoaderOverride() {
     // TODO: need to place the class_loader_override_ in a handle
     // DCHECK(CanAccessDirectReferences());
     return class_loader_override_;
   }
 
-  void SetClassLoaderOverride(const ClassLoader* class_loader_override) {
+  void SetClassLoaderOverride(ClassLoader* class_loader_override) {
     class_loader_override_ = class_loader_override;
   }
 
   // Create the internal representation of a stack trace, that is more time
   // and space efficient to compute than the StackTraceElement[]
-  jobject CreateInternalStackTrace(JNIEnv* env) const;
+  jobject CreateInternalStackTrace(const ScopedJniThreadState& ts) const;
 
   // Convert an internal stack trace representation (returned by CreateInternalStackTrace) to a
   // StackTraceElement[]. If output_array is NULL, a new array is created, otherwise as many
@@ -504,7 +499,7 @@
   void Destroy();
   friend class ThreadList;  // For ~Thread and Destroy.
 
-  void CreatePeer(const char* name, bool as_daemon, Object* thread_group);
+  void CreatePeer(const char* name, bool as_daemon, jobject thread_group);
   friend class Runtime; // For CreatePeer.
 
   void DumpState(std::ostream& os) const;
@@ -516,8 +511,8 @@
 
   static void* CreateCallback(void* arg);
 
-  void HandleUncaughtExceptions();
-  void RemoveFromThreadGroup();
+  void HandleUncaughtExceptions(const ScopedJniThreadState& ts);
+  void RemoveFromThreadGroup(const ScopedJniThreadState& ts);
 
   void Init();
   void InitCardTable();
@@ -609,7 +604,7 @@
 
   // Needed to get the right ClassLoader in JNI_OnLoad, but also
   // useful for testing.
-  const ClassLoader* class_loader_override_;
+  ClassLoader* class_loader_override_;
 
   // Thread local, lazily allocated, long jump context. Used to deliver exceptions.
   Context* long_jump_context_;
diff --git a/src/verifier/method_verifier.cc b/src/verifier/method_verifier.cc
index f6a2ddb..ccc83da 100644
--- a/src/verifier/method_verifier.cc
+++ b/src/verifier/method_verifier.cc
@@ -204,8 +204,8 @@
   return VerifyClass(&dex_file, kh.GetDexCache(), klass->GetClassLoader(), class_def_idx, error);
 }
 
-MethodVerifier::FailureKind MethodVerifier::VerifyClass(const DexFile* dex_file, DexCache* dex_cache,
-    const ClassLoader* class_loader, uint32_t class_def_idx, std::string& error) {
+MethodVerifier::FailureKind MethodVerifier::VerifyClass(const DexFile* dex_file,
+    DexCache* dex_cache, ClassLoader* class_loader, uint32_t class_def_idx, std::string& error) {
   const DexFile::ClassDef& class_def = dex_file->GetClassDef(class_def_idx);
   const byte* class_data = dex_file->GetClassData(class_def);
   if (class_data == NULL) {
@@ -277,7 +277,7 @@
 }
 
 MethodVerifier::FailureKind MethodVerifier::VerifyMethod(uint32_t method_idx, const DexFile* dex_file,
-    DexCache* dex_cache, const ClassLoader* class_loader, uint32_t class_def_idx,
+    DexCache* dex_cache, ClassLoader* class_loader, uint32_t class_def_idx,
     const DexFile::CodeItem* code_item, Method* method, uint32_t method_access_flags) {
   MethodVerifier verifier(dex_file, dex_cache, class_loader, class_def_idx, code_item, method_idx,
                           method, method_access_flags);
@@ -317,7 +317,7 @@
 }
 
 MethodVerifier::MethodVerifier(const DexFile* dex_file, DexCache* dex_cache,
-    const ClassLoader* class_loader, uint32_t class_def_idx, const DexFile::CodeItem* code_item,
+    ClassLoader* class_loader, uint32_t class_def_idx, const DexFile::CodeItem* code_item,
     uint32_t method_idx, Method* method, uint32_t method_access_flags)
     : work_insn_idx_(-1),
       method_idx_(method_idx),
@@ -2900,7 +2900,7 @@
     field = GetInstanceField(object_type, field_idx);
   }
   const char* descriptor;
-  const ClassLoader* loader;
+  ClassLoader* loader;
   if (field != NULL) {
     descriptor = FieldHelper(field).GetTypeDescriptor();
     loader = field->GetDeclaringClass()->GetClassLoader();
@@ -2949,7 +2949,7 @@
     field = GetInstanceField(object_type, field_idx);
   }
   const char* descriptor;
-  const ClassLoader* loader;
+  ClassLoader* loader;
   if (field != NULL) {
     descriptor = FieldHelper(field).GetTypeDescriptor();
     loader = field->GetDeclaringClass()->GetClassLoader();
diff --git a/src/verifier/method_verifier.h b/src/verifier/method_verifier.h
index 5f72678..64a723e 100644
--- a/src/verifier/method_verifier.h
+++ b/src/verifier/method_verifier.h
@@ -171,7 +171,8 @@
   /* Verify a class. Returns "kNoFailure" on success. */
   static FailureKind VerifyClass(const Class* klass, std::string& error);
   static FailureKind VerifyClass(const DexFile* dex_file, DexCache* dex_cache,
-      const ClassLoader* class_loader, uint32_t class_def_idx, std::string& error);
+                                 ClassLoader* class_loader, uint32_t class_def_idx,
+                                 std::string& error);
 
   uint8_t EncodePcToReferenceMapData() const;
 
@@ -212,7 +213,7 @@
 
  private:
   explicit MethodVerifier(const DexFile* dex_file, DexCache* dex_cache,
-      const ClassLoader* class_loader, uint32_t class_def_idx, const DexFile::CodeItem* code_item,
+      ClassLoader* class_loader, uint32_t class_def_idx, const DexFile::CodeItem* code_item,
       uint32_t method_idx, Method* method, uint32_t access_flags);
 
   // Adds the given string to the beginning of the last failure message.
@@ -233,7 +234,7 @@
    *      for code flow problems.
    */
   static FailureKind VerifyMethod(uint32_t method_idx, const DexFile* dex_file, DexCache* dex_cache,
-      const ClassLoader* class_loader, uint32_t class_def_idx, const DexFile::CodeItem* code_item,
+      ClassLoader* class_loader, uint32_t class_def_idx, const DexFile::CodeItem* code_item,
       Method* method, uint32_t method_access_flags);
   static void VerifyMethodAndDump(Method* method);
 
@@ -611,7 +612,7 @@
   uint32_t method_access_flags_;  // Method's access flags.
   const DexFile* dex_file_;  // The dex file containing the method.
   DexCache* dex_cache_;  // The dex_cache for the declaring class of the method.
-  const ClassLoader* class_loader_;  // The class loader for the declaring class of the method.
+  ClassLoader* class_loader_;  // The class loader for the declaring class of the method.
   uint32_t class_def_idx_;  // The class def index of the declaring class of the method.
   const DexFile::CodeItem* code_item_;  // The code item containing the code for the method.
   UniquePtr<InsnFlags[]> insn_flags_;  // Instruction widths and flags, one entry per code unit.
diff --git a/src/verifier/reg_type.cc b/src/verifier/reg_type.cc
index 217084f..dd54b5f 100644
--- a/src/verifier/reg_type.cc
+++ b/src/verifier/reg_type.cc
@@ -294,7 +294,7 @@
     }
     Class* common_elem = ClassJoin(s_ct, t_ct);
     ClassLinker* class_linker = Runtime::Current()->GetClassLinker();
-    const ClassLoader* class_loader = s->GetClassLoader();
+    ClassLoader* class_loader = s->GetClassLoader();
     std::string descriptor("[");
     descriptor += ClassHelper(common_elem).GetDescriptor();
     Class* array_class = class_linker->FindClass(descriptor.c_str(), class_loader);
diff --git a/src/verifier/reg_type_cache.cc b/src/verifier/reg_type_cache.cc
index c860bd7..bb05e7e 100644
--- a/src/verifier/reg_type_cache.cc
+++ b/src/verifier/reg_type_cache.cc
@@ -57,13 +57,11 @@
   }
 }
 
-const RegType& RegTypeCache::FromDescriptor(const ClassLoader* loader,
-                                            const char* descriptor) {
+const RegType& RegTypeCache::FromDescriptor(ClassLoader* loader, const char* descriptor) {
   return From(RegTypeFromDescriptor(descriptor), loader, descriptor);
 }
 
-const RegType& RegTypeCache::From(RegType::Type type, const ClassLoader* loader,
-                                  const char* descriptor) {
+const RegType& RegTypeCache::From(RegType::Type type, ClassLoader* loader, const char* descriptor) {
   if (type <= RegType::kRegTypeLastFixedLocation) {
     // entries should be sized greater than primitive types
     DCHECK_GT(entries_.size(), static_cast<size_t>(type));
@@ -258,7 +256,7 @@
   return *entry;
 }
 
-const RegType& RegTypeCache::GetComponentType(const RegType& array, const ClassLoader* loader) {
+const RegType& RegTypeCache::GetComponentType(const RegType& array, ClassLoader* loader) {
   CHECK(array.IsArrayTypes());
   if (array.IsUnresolvedTypes()) {
     std::string descriptor(array.GetDescriptor()->ToModifiedUtf8());
diff --git a/src/verifier/reg_type_cache.h b/src/verifier/reg_type_cache.h
index 91a2933..765809c 100644
--- a/src/verifier/reg_type_cache.h
+++ b/src/verifier/reg_type_cache.h
@@ -40,10 +40,10 @@
     return *result;
   }
 
-  const RegType& From(RegType::Type type, const ClassLoader* loader, const char* descriptor);
+  const RegType& From(RegType::Type type, ClassLoader* loader, const char* descriptor);
   const RegType& FromClass(Class* klass);
   const RegType& FromCat1Const(int32_t value);
-  const RegType& FromDescriptor(const ClassLoader* loader, const char* descriptor);
+  const RegType& FromDescriptor(ClassLoader* loader, const char* descriptor);
   const RegType& FromType(RegType::Type);
 
   const RegType& Boolean() { return FromType(RegType::kRegTypeBoolean); }
@@ -77,7 +77,7 @@
   const RegType& ShortConstant() { return FromCat1Const(std::numeric_limits<jshort>::min()); }
   const RegType& IntConstant() { return FromCat1Const(std::numeric_limits<jint>::max()); }
 
-  const RegType& GetComponentType(const RegType& array, const ClassLoader* loader);
+  const RegType& GetComponentType(const RegType& array, ClassLoader* loader);
 
  private:
   // The allocated entries