Revert "Revert "Implement on-stack replacement for arm/arm64/x86/x86_64.""

This reverts commit bd89a5c556324062b7d841843b039392e84cfaf4.

Change-Id: I08d190431520baa7fcec8fbdb444519f25ac8d44
diff --git a/runtime/jit/jit.cc b/runtime/jit/jit.cc
index fa5c41d..3e152e1 100644
--- a/runtime/jit/jit.cc
+++ b/runtime/jit/jit.cc
@@ -25,10 +25,12 @@
 #include "jit_code_cache.h"
 #include "jit_instrumentation.h"
 #include "oat_file_manager.h"
+#include "oat_quick_method_header.h"
 #include "offline_profiling_info.h"
 #include "profile_saver.h"
 #include "runtime.h"
 #include "runtime_options.h"
+#include "stack_map.h"
 #include "utils.h"
 
 namespace art {
@@ -43,6 +45,8 @@
       options.GetOrDefault(RuntimeArgumentMap::JITCodeCacheMaxCapacity);
   jit_options->compile_threshold_ =
       options.GetOrDefault(RuntimeArgumentMap::JITCompileThreshold);
+  // TODO(ngeoffray): Make this a proper option. For now, OSR uses twice the compile threshold.
+  jit_options->osr_threshold_ = jit_options->compile_threshold_ * 2;
   jit_options->warmup_threshold_ =
       options.GetOrDefault(RuntimeArgumentMap::JITWarmupThreshold);
   jit_options->dump_info_on_shutdown_ =
@@ -121,7 +125,7 @@
     *error_msg = "JIT couldn't find jit_unload entry point";
     return false;
   }
-  jit_compile_method_ = reinterpret_cast<bool (*)(void*, ArtMethod*, Thread*)>(
+  jit_compile_method_ = reinterpret_cast<bool (*)(void*, ArtMethod*, Thread*, bool)>(
       dlsym(jit_library_handle_, "jit_compile_method"));
   if (jit_compile_method_ == nullptr) {
     dlclose(jit_library_handle_);
@@ -156,7 +160,7 @@
   return true;
 }
 
-bool Jit::CompileMethod(ArtMethod* method, Thread* self) {
+bool Jit::CompileMethod(ArtMethod* method, Thread* self, bool osr) {
   DCHECK(!method->IsRuntimeMethod());
   // Don't compile the method if it has breakpoints.
   if (Dbg::IsDebuggerActive() && Dbg::MethodHasAnyBreakpoints(method)) {
@@ -171,10 +175,11 @@
     return false;
   }
 
-  if (!code_cache_->NotifyCompilationOf(method, self)) {
+  if (!code_cache_->NotifyCompilationOf(method, self, osr)) {
+    VLOG(jit) << "JIT not compiling " << PrettyMethod(method) << " due to code cache";
     return false;
   }
-  bool success = jit_compile_method_(jit_compiler_handle_, method, self);
+  bool success = jit_compile_method_(jit_compiler_handle_, method, self, osr);
   code_cache_->DoneCompiling(method, self);
   return success;
 }
@@ -224,9 +229,11 @@
   }
 }
 
-void Jit::CreateInstrumentationCache(size_t compile_threshold, size_t warmup_threshold) {
+void Jit::CreateInstrumentationCache(size_t compile_threshold,
+                                     size_t warmup_threshold,
+                                     size_t osr_threshold) {
   instrumentation_cache_.reset(
-      new jit::JitInstrumentationCache(compile_threshold, warmup_threshold));
+      new jit::JitInstrumentationCache(compile_threshold, warmup_threshold, osr_threshold));
 }
 
 void Jit::NewTypeLoadedIfUsingJit(mirror::Class* type) {
@@ -255,5 +262,139 @@
   }
 }
 
+// Entry stub defined outside this file (presumably per-ISA assembly). It receives the
+// frame image built by MaybeDoOnStackReplacement and the native pc to start executing at.
+extern "C" void art_quick_osr_stub(void** stack,
+                                   uint32_t stack_size_in_bytes,
+                                   const uint8_t* native_pc,
+                                   JValue* result,
+                                   const char* shorty,
+                                   Thread* self);
+
+// Builds a compiled frame from the interpreter's shadow frame and jumps into the OSR code.
+// Every early `return false` happens before the shadow frame is popped.
+bool Jit::MaybeDoOnStackReplacement(Thread* thread,
+                                    ArtMethod* method,
+                                    uint32_t dex_pc,
+                                    int32_t dex_pc_offset,
+                                    JValue* result) {
+  Jit* jit = Runtime::Current()->GetJit();
+  if (jit == nullptr) {
+    return false;
+  }
+
+  if (kRuntimeISA == kMips || kRuntimeISA == kMips64) {
+    VLOG(jit) << "OSR not supported on this platform";
+    return false;
+  }
+
+  // Cheap check if the method has been compiled already. That's an indicator that we should
+  // osr into it.
+  if (!jit->GetCodeCache()->ContainsPc(method->GetEntryPointFromQuickCompiledCode())) {
+    return false;
+  }
+
+  // Gather what we need before looking up the OSR code. Once we hold an OSR method header
+  // this thread must not suspend: a code cache collection clears osr_code_map_ and only
+  // keeps OSR code that is already running on a thread stack, so the header could be freed.
+  const size_t number_of_vregs = method->GetCodeItem()->registers_size_;
+  const char* shorty = method->GetInterfaceMethodIfProxy(sizeof(void*))->GetShorty();
+  void** memory = nullptr;
+  size_t frame_size = 0;
+  ShadowFrame* shadow_frame = nullptr;
+  const uint8_t* native_pc = nullptr;
+
+  {
+    ScopedAssertNoThreadSuspension sants(thread, "Holding OSR method");
+    const OatQuickMethodHeader* osr_method = jit->GetCodeCache()->LookupOsrMethodHeader(method);
+    if (osr_method == nullptr) {
+      // No osr method yet, just return to the interpreter.
+      return false;
+    }
+
+    CodeInfo code_info = osr_method->GetOptimizedCodeInfo();
+    StackMapEncoding encoding = code_info.ExtractEncoding();
+
+    // Find stack map starting at the target dex_pc.
+    StackMap stack_map = code_info.GetOsrStackMapForDexPc(dex_pc + dex_pc_offset, encoding);
+    if (!stack_map.IsValid()) {
+      // There is no OSR stack map for this dex pc offset. Just return to the interpreter in the
+      // hope that the next branch has one.
+      return false;
+    }
+
+    // We found a stack map, now fill the frame with dex register values from the interpreter's
+    // shadow frame.
+    DexRegisterMap vreg_map =
+        code_info.GetDexRegisterMapOf(stack_map, encoding, number_of_vregs);
+
+    shadow_frame = thread->PopShadowFrame();
+
+    // Scratch image of the compiled frame; art_quick_osr_stub receives it with its size.
+    frame_size = osr_method->GetFrameSizeInBytes();
+    memory = reinterpret_cast<void**>(malloc(frame_size));
+    CHECK(memory != nullptr) << "Could not allocate " << frame_size << " bytes for OSR frame";
+    memset(memory, 0, frame_size);
+
+    // Art ABI: ArtMethod is at the bottom of the stack.
+    memory[0] = method;
+
+    if (!vreg_map.IsValid()) {
+      // If we don't have a dex register map, then there are no live dex registers at
+      // this dex pc.
+    } else {
+      for (uint16_t vreg = 0; vreg < number_of_vregs; ++vreg) {
+        DexRegisterLocation::Kind location =
+            vreg_map.GetLocationKind(vreg, number_of_vregs, code_info, encoding);
+        if (location == DexRegisterLocation::Kind::kNone) {
+          // Dex register is dead or uninitialized.
+          continue;
+        }
+
+        if (location == DexRegisterLocation::Kind::kConstant) {
+          // We skip constants because the compiled code knows how to handle them.
+          continue;
+        }
+
+        DCHECK(location == DexRegisterLocation::Kind::kInStack);
+
+        int32_t vreg_value = shadow_frame->GetVReg(vreg);
+        int32_t slot_offset = vreg_map.GetStackOffsetInBytes(vreg,
+                                                             number_of_vregs,
+                                                             code_info,
+                                                             encoding);
+        DCHECK_LT(slot_offset, static_cast<int32_t>(frame_size));
+        DCHECK_GT(slot_offset, 0);
+        (reinterpret_cast<int32_t*>(memory))[slot_offset / sizeof(int32_t)] = vreg_value;
+      }
+    }
+
+    native_pc = stack_map.GetNativePcOffset(encoding) + osr_method->GetEntryPoint();
+    VLOG(jit) << "Jumping to "
+              << PrettyMethod(method)
+              << "@"
+              << std::hex << reinterpret_cast<uintptr_t>(native_pc);
+  }
+
+  {
+    ManagedStack fragment;
+    thread->PushManagedStackFragment(&fragment);
+    (*art_quick_osr_stub)(memory,
+                          frame_size,
+                          native_pc,
+                          result,
+                          shorty,
+                          thread);
+    if (UNLIKELY(thread->GetException() == Thread::GetDeoptimizationException())) {
+      thread->DeoptimizeWithDeoptimizationException(result);
+    }
+    thread->PopManagedStackFragment(fragment);
+  }
+  free(memory);
+  thread->PushShadowFrame(shadow_frame);
+  VLOG(jit) << "Done running OSR code for " << PrettyMethod(method);
+  return true;
+}
+
 }  // namespace jit
 }  // namespace art
diff --git a/runtime/jit/jit.h b/runtime/jit/jit.h
index a80f51f..042da92 100644
--- a/runtime/jit/jit.h
+++ b/runtime/jit/jit.h
@@ -49,9 +49,11 @@
 
   virtual ~Jit();
   static Jit* Create(JitOptions* options, std::string* error_msg);
-  bool CompileMethod(ArtMethod* method, Thread* self)
+  bool CompileMethod(ArtMethod* method, Thread* self, bool osr)
       SHARED_REQUIRES(Locks::mutator_lock_);
-  void CreateInstrumentationCache(size_t compile_threshold, size_t warmup_threshold);
+  void CreateInstrumentationCache(size_t compile_threshold,
+                                  size_t warmup_threshold,
+                                  size_t osr_threshold);
   void CreateThreadPool();
   CompilerCallbacks* GetCompilerCallbacks() {
     return compiler_callbacks_;
@@ -88,6 +90,17 @@
 
   bool JitAtFirstUse();
 
+  // If an OSR-compiled version of `method` is available with an OSR entry at
+  // `dex_pc + dex_pc_offset`, fills its frame from the thread's top shadow frame,
+  // runs the compiled code (its return value is stored in `result`) and returns
+  // true. Otherwise returns false without touching the interpreter's shadow frame.
+  static bool MaybeDoOnStackReplacement(Thread* thread,
+                                        ArtMethod* method,
+                                        uint32_t dex_pc,
+                                        int32_t dex_pc_offset,
+                                        JValue* result)
+      SHARED_REQUIRES(Locks::mutator_lock_);
+
  private:
   Jit();
   bool LoadCompiler(std::string* error_msg);
@@ -97,7 +110,7 @@
   void* jit_compiler_handle_;
   void* (*jit_load_)(CompilerCallbacks**, bool*);
   void (*jit_unload_)(void*);
-  bool (*jit_compile_method_)(void*, ArtMethod*, Thread*);
+  bool (*jit_compile_method_)(void*, ArtMethod*, Thread*, bool);
   void (*jit_types_loaded_)(void*, mirror::Class**, size_t count);
 
   // Performance monitoring.
@@ -123,6 +136,9 @@
   size_t GetWarmupThreshold() const {
     return warmup_threshold_;
   }
+  size_t GetOsrThreshold() const {
+    return osr_threshold_;
+  }
   size_t GetCodeCacheInitialCapacity() const {
     return code_cache_initial_capacity_;
   }
@@ -155,6 +171,7 @@
   size_t code_cache_max_capacity_;
   size_t compile_threshold_;
   size_t warmup_threshold_;
+  size_t osr_threshold_;  // Currently initialized to twice compile_threshold_.
   bool dump_info_on_shutdown_;
   bool save_profiling_info_;
 
diff --git a/runtime/jit/jit_code_cache.cc b/runtime/jit/jit_code_cache.cc
index f325949..464c441 100644
--- a/runtime/jit/jit_code_cache.cc
+++ b/runtime/jit/jit_code_cache.cc
@@ -184,7 +184,8 @@
                                   size_t core_spill_mask,
                                   size_t fp_spill_mask,
                                   const uint8_t* code,
-                                  size_t code_size) {
+                                  size_t code_size,
+                                  bool osr) {
   uint8_t* result = CommitCodeInternal(self,
                                        method,
                                        mapping_table,
@@ -194,7 +195,8 @@
                                        core_spill_mask,
                                        fp_spill_mask,
                                        code,
-                                       code_size);
+                                       code_size,
+                                       osr);
   if (result == nullptr) {
     // Retry.
     GarbageCollectCache(self);
@@ -207,7 +209,8 @@
                                 core_spill_mask,
                                 fp_spill_mask,
                                 code,
-                                code_size);
+                                code_size,
+                                osr);
   }
   return result;
 }
@@ -287,7 +290,8 @@
                                           size_t core_spill_mask,
                                           size_t fp_spill_mask,
                                           const uint8_t* code,
-                                          size_t code_size) {
+                                          size_t code_size,
+                                          bool osr) {
   size_t alignment = GetInstructionSetAlignment(kRuntimeISA);
   // Ensure the header ends up at expected instruction alignment.
   size_t header_size = RoundUp(sizeof(OatQuickMethodHeader), alignment);
@@ -329,8 +333,12 @@
   {
     MutexLock mu(self, lock_);
     method_code_map_.Put(code_ptr, method);
-    Runtime::Current()->GetInstrumentation()->UpdateMethodsCode(
-        method, method_header->GetEntryPoint());
+    if (!osr) {
+      Runtime::Current()->GetInstrumentation()->UpdateMethodsCode(
+          method, method_header->GetEntryPoint());
+    } else {
+      osr_code_map_.Put(method, code_ptr);
+    }
     if (collection_in_progress_) {
       // We need to update the live bitmap if there is a GC to ensure it sees this new
       // code.
@@ -338,7 +346,7 @@
     }
     last_update_time_ns_.StoreRelease(NanoTime());
     VLOG(jit)
-        << "JIT added "
+        << "JIT added (osr = " << std::boolalpha << osr << std::noboolalpha << ") "
         << PrettyMethod(method) << "@" << method
         << " ccache_size=" << PrettySize(CodeCacheSizeLocked()) << ": "
         << " dcache_size=" << PrettySize(DataCacheSizeLocked()) << ": "
@@ -569,6 +577,10 @@
         info->GetMethod()->SetProfilingInfo(nullptr);
       }
     }
+
+    // OSR code is about to be freed (except what is still running on a thread stack),
+    // so forget every OSR entry point.
+    osr_code_map_.clear();
   }
 
   // Run a checkpoint on all threads to mark the JIT compiled code they are running.
@@ -662,6 +674,15 @@
   return method_header;
 }
 
+OatQuickMethodHeader* JitCodeCache::LookupOsrMethodHeader(ArtMethod* method) {
+  MutexLock mu(Thread::Current(), lock_);
+  auto entry = osr_code_map_.find(method);
+  if (entry != osr_code_map_.end()) {
+    return OatQuickMethodHeader::FromCodePointer(entry->second);
+  }
+  return nullptr;
+}
+
 ProfilingInfo* JitCodeCache::AddProfilingInfo(Thread* self,
                                               ArtMethod* method,
                                               const std::vector<uint32_t>& entries,
@@ -733,12 +754,15 @@
   return last_update_time_ns_.LoadAcquire();
 }
 
-bool JitCodeCache::NotifyCompilationOf(ArtMethod* method, Thread* self) {
-  if (ContainsPc(method->GetEntryPointFromQuickCompiledCode())) {
+bool JitCodeCache::NotifyCompilationOf(ArtMethod* method, Thread* self, bool osr) {
+  if (!osr && ContainsPc(method->GetEntryPointFromQuickCompiledCode())) {
     return false;
   }
 
   MutexLock mu(self, lock_);
+  if (osr && osr_code_map_.find(method) != osr_code_map_.end()) {
+    return false;
+  }
   ProfilingInfo* info = method->GetProfilingInfo(sizeof(void*));
   if (info == nullptr || info->IsMethodBeingCompiled()) {
     return false;
diff --git a/runtime/jit/jit_code_cache.h b/runtime/jit/jit_code_cache.h
index 69fc553..048f8d0 100644
--- a/runtime/jit/jit_code_cache.h
+++ b/runtime/jit/jit_code_cache.h
@@ -71,7 +71,7 @@
   // Number of compilations done throughout the lifetime of the JIT.
   size_t NumberOfCompilations() REQUIRES(!lock_);
 
-  bool NotifyCompilationOf(ArtMethod* method, Thread* self)
+  bool NotifyCompilationOf(ArtMethod* method, Thread* self, bool osr)
       SHARED_REQUIRES(Locks::mutator_lock_)
       REQUIRES(!lock_);
 
@@ -89,7 +89,8 @@
                       size_t core_spill_mask,
                       size_t fp_spill_mask,
                       const uint8_t* code,
-                      size_t code_size)
+                      size_t code_size,
+                      bool osr)  // If true, store as OSR code; entry point unchanged.
       SHARED_REQUIRES(Locks::mutator_lock_)
       REQUIRES(!lock_);
 
@@ -131,6 +132,10 @@
       REQUIRES(!lock_)
       SHARED_REQUIRES(Locks::mutator_lock_);
 
+  OatQuickMethodHeader* LookupOsrMethodHeader(ArtMethod* method)  // nullptr if none.
+      REQUIRES(!lock_)
+      SHARED_REQUIRES(Locks::mutator_lock_);
+
   // Remove all methods in our cache that were allocated by 'alloc'.
   void RemoveMethodsIn(Thread* self, const LinearAlloc& alloc)
       REQUIRES(!lock_)
@@ -187,7 +192,8 @@
                               size_t core_spill_mask,
                               size_t fp_spill_mask,
                               const uint8_t* code,
-                              size_t code_size)
+                              size_t code_size,
+                              bool osr)  // If true, store as OSR code; entry point unchanged.
       REQUIRES(!lock_)
       SHARED_REQUIRES(Locks::mutator_lock_);
 
@@ -237,8 +243,10 @@
   void* data_mspace_ GUARDED_BY(lock_);
   // Bitmap for collecting code and data.
   std::unique_ptr<CodeCacheBitmap> live_bitmap_;
-  // This map holds compiled code associated to the ArtMethod.
+  // Maps each committed code pointer (OSR or not) to its ArtMethod.
   SafeMap<const void*, ArtMethod*> method_code_map_ GUARDED_BY(lock_);
+  // Maps an ArtMethod to its OSR code pointer; emptied on code cache collection.
+  SafeMap<ArtMethod*, const void*> osr_code_map_ GUARDED_BY(lock_);
   // ProfilingInfo objects we have allocated.
   std::vector<ProfilingInfo*> profiling_infos_ GUARDED_BY(lock_);
 
diff --git a/runtime/jit/jit_instrumentation.cc b/runtime/jit/jit_instrumentation.cc
index d597b36..a4e40ad 100644
--- a/runtime/jit/jit_instrumentation.cc
+++ b/runtime/jit/jit_instrumentation.cc
@@ -29,7 +29,8 @@
  public:
   enum TaskKind {
     kAllocateProfile,
-    kCompile
+    kCompile,
+    kCompileOsr  // Compile an OSR version (Jit::CompileMethod with osr = true).
   };
 
   JitCompileTask(ArtMethod* method, TaskKind kind) : method_(method), kind_(kind) {
@@ -48,9 +49,14 @@
     ScopedObjectAccess soa(self);
     if (kind_ == kCompile) {
       VLOG(jit) << "JitCompileTask compiling method " << PrettyMethod(method_);
-      if (!Runtime::Current()->GetJit()->CompileMethod(method_, self)) {
+      if (!Runtime::Current()->GetJit()->CompileMethod(method_, self, /* osr */ false)) {
         VLOG(jit) << "Failed to compile method " << PrettyMethod(method_);
       }
+    } else if (kind_ == kCompileOsr) {
+      VLOG(jit) << "JitCompileTask compiling method osr " << PrettyMethod(method_);
+      if (!Runtime::Current()->GetJit()->CompileMethod(method_, self, /* osr */ true)) {
+        VLOG(jit) << "Failed to compile method osr " << PrettyMethod(method_);
+      }
     } else {
       DCHECK(kind_ == kAllocateProfile);
       if (ProfilingInfo::Create(self, method_, /* retry_allocation */ true)) {
@@ -72,9 +78,11 @@
 };
 
 JitInstrumentationCache::JitInstrumentationCache(size_t hot_method_threshold,
-                                                 size_t warm_method_threshold)
+                                                 size_t warm_method_threshold,
+                                                 size_t osr_method_threshold)
     : hot_method_threshold_(hot_method_threshold),
       warm_method_threshold_(warm_method_threshold),
+      osr_method_threshold_(osr_method_threshold),
       listener_(this) {
 }
 
@@ -151,6 +159,11 @@
     DCHECK(thread_pool_ != nullptr);
     thread_pool_->AddTask(self, new JitCompileTask(method, JitCompileTask::kCompile));
   }
+  // Queue an OSR compilation when the sample count reaches the OSR threshold.
+  if (sample_count == osr_method_threshold_) {
+    DCHECK(thread_pool_ != nullptr);
+    thread_pool_->AddTask(self, new JitCompileTask(method, JitCompileTask::kCompileOsr));
+  }
 }
 
 JitInstrumentationListener::JitInstrumentationListener(JitInstrumentationCache* cache)
diff --git a/runtime/jit/jit_instrumentation.h b/runtime/jit/jit_instrumentation.h
index 06559ad..d1c5c44 100644
--- a/runtime/jit/jit_instrumentation.h
+++ b/runtime/jit/jit_instrumentation.h
@@ -96,7 +96,9 @@
 // Keeps track of which methods are hot.
 class JitInstrumentationCache {
  public:
-  JitInstrumentationCache(size_t hot_method_threshold, size_t warm_method_threshold);
+  JitInstrumentationCache(size_t hot_method_threshold,
+                          size_t warm_method_threshold,
+                          size_t osr_method_threshold);  // Sample count that queues an OSR compile.
   void AddSamples(Thread* self, ArtMethod* method, size_t samples)
       SHARED_REQUIRES(Locks::mutator_lock_);
   void CreateThreadPool();
@@ -112,6 +114,7 @@
  private:
   size_t hot_method_threshold_;
   size_t warm_method_threshold_;
+  size_t osr_method_threshold_;
   JitInstrumentationListener listener_;
   std::unique_ptr<ThreadPool> thread_pool_;