Bring ReferenceTypePropagation to HInvoke return types

Change-Id: Id0683f67e32874713a30c072c71dc537b1271926
diff --git a/compiler/optimizing/builder.cc b/compiler/optimizing/builder.cc
index e4680ff..1f9287c 100644
--- a/compiler/optimizing/builder.cc
+++ b/compiler/optimizing/builder.cc
@@ -723,10 +723,16 @@
       }
     }
 
-    invoke = new (arena_) HInvokeStaticOrDirect(
-        arena_, number_of_arguments, return_type, dex_pc, target_method.dex_method_index,
-        is_recursive, string_init_offset, invoke_type, optimized_invoke_type,
-        clinit_check_requirement);
+    invoke = new (arena_) HInvokeStaticOrDirect(arena_,
+                                                number_of_arguments,
+                                                return_type,
+                                                dex_pc,
+                                                target_method.dex_method_index,
+                                                is_recursive,
+                                                string_init_offset,
+                                                invoke_type,
+                                                optimized_invoke_type,
+                                                clinit_check_requirement);
   }
 
   size_t start_index = 0;
diff --git a/compiler/optimizing/inliner.cc b/compiler/optimizing/inliner.cc
index 5aeaad2..92ebf06 100644
--- a/compiler/optimizing/inliner.cc
+++ b/compiler/optimizing/inliner.cc
@@ -256,7 +256,7 @@
     return false;
   }
 
-  if (!TryBuildAndInline(resolved_method, invoke_instruction, method_index, same_dex_file)) {
+  if (!TryBuildAndInline(resolved_method, invoke_instruction, same_dex_file)) {
     return false;
   }
 
@@ -267,11 +267,11 @@
 
 bool HInliner::TryBuildAndInline(ArtMethod* resolved_method,
                                  HInvoke* invoke_instruction,
-                                 uint32_t method_index,
                                  bool same_dex_file) const {
   ScopedObjectAccess soa(Thread::Current());
   const DexFile::CodeItem* code_item = resolved_method->GetCodeItem();
-  const DexFile& caller_dex_file = *caller_compilation_unit_.GetDexFile();
+  const DexFile& callee_dex_file = *resolved_method->GetDexFile();
+  uint32_t method_index = resolved_method->GetDexMethodIndex();
 
   DexCompilationUnit dex_compilation_unit(
     nullptr,
@@ -311,7 +311,7 @@
   }
   HGraph* callee_graph = new (graph_->GetArena()) HGraph(
       graph_->GetArena(),
-      caller_dex_file,
+      callee_dex_file,
       method_index,
       requires_ctor_barrier,
       compiler_driver_->GetInstructionSet(),
@@ -328,7 +328,7 @@
                         &inline_stats);
 
   if (!builder.BuildGraph(*code_item)) {
-    VLOG(compiler) << "Method " << PrettyMethod(method_index, caller_dex_file)
+    VLOG(compiler) << "Method " << PrettyMethod(method_index, callee_dex_file)
                    << " could not be built, so cannot be inlined";
     // There could be multiple reasons why the graph could not be built, including
     // unaccessible methods/fields due to using a different dex cache. We do not mark
@@ -338,14 +338,14 @@
 
   if (!RegisterAllocator::CanAllocateRegistersFor(*callee_graph,
                                                   compiler_driver_->GetInstructionSet())) {
-    VLOG(compiler) << "Method " << PrettyMethod(method_index, caller_dex_file)
+    VLOG(compiler) << "Method " << PrettyMethod(method_index, callee_dex_file)
                    << " cannot be inlined because of the register allocator";
     resolved_method->SetShouldNotInline();
     return false;
   }
 
   if (!callee_graph->TryBuildingSsa()) {
-    VLOG(compiler) << "Method " << PrettyMethod(method_index, caller_dex_file)
+    VLOG(compiler) << "Method " << PrettyMethod(method_index, callee_dex_file)
                    << " could not be transformed to SSA";
     resolved_method->SetShouldNotInline();
     return false;
@@ -385,7 +385,7 @@
   // a throw predecessor.
   HBasicBlock* exit_block = callee_graph->GetExitBlock();
   if (exit_block == nullptr) {
-    VLOG(compiler) << "Method " << PrettyMethod(method_index, caller_dex_file)
+    VLOG(compiler) << "Method " << PrettyMethod(method_index, callee_dex_file)
                    << " could not be inlined because it has an infinite loop";
     resolved_method->SetShouldNotInline();
     return false;
@@ -399,7 +399,7 @@
     }
   }
   if (has_throw_predecessor) {
-    VLOG(compiler) << "Method " << PrettyMethod(method_index, caller_dex_file)
+    VLOG(compiler) << "Method " << PrettyMethod(method_index, callee_dex_file)
                    << " could not be inlined because one branch always throws";
     resolved_method->SetShouldNotInline();
     return false;
@@ -410,7 +410,7 @@
   for (; !it.Done(); it.Advance()) {
     HBasicBlock* block = it.Current();
     if (block->IsLoopHeader()) {
-      VLOG(compiler) << "Method " << PrettyMethod(method_index, caller_dex_file)
+      VLOG(compiler) << "Method " << PrettyMethod(method_index, callee_dex_file)
                      << " could not be inlined because it contains a loop";
       resolved_method->SetShouldNotInline();
       return false;
@@ -424,21 +424,21 @@
       if (current->IsInvokeInterface()) {
         // Disable inlining of interface calls. The cost in case of entering the
         // resolution conflict is currently too high.
-        VLOG(compiler) << "Method " << PrettyMethod(method_index, caller_dex_file)
+        VLOG(compiler) << "Method " << PrettyMethod(method_index, callee_dex_file)
                        << " could not be inlined because it has an interface call.";
         resolved_method->SetShouldNotInline();
         return false;
       }
 
       if (!same_dex_file && current->NeedsEnvironment()) {
-        VLOG(compiler) << "Method " << PrettyMethod(method_index, caller_dex_file)
+        VLOG(compiler) << "Method " << PrettyMethod(method_index, callee_dex_file)
                        << " could not be inlined because " << current->DebugName()
                        << " needs an environment and is in a different dex file";
         return false;
       }
 
       if (!same_dex_file && current->NeedsDexCache()) {
-        VLOG(compiler) << "Method " << PrettyMethod(method_index, caller_dex_file)
+        VLOG(compiler) << "Method " << PrettyMethod(method_index, callee_dex_file)
                        << " could not be inlined because " << current->DebugName()
                        << " it is in a different dex file and requires access to the dex cache";
         // Do not flag the method as not-inlineable. A caller within the same
diff --git a/compiler/optimizing/inliner.h b/compiler/optimizing/inliner.h
index 7465278..24044b7 100644
--- a/compiler/optimizing/inliner.h
+++ b/compiler/optimizing/inliner.h
@@ -52,7 +52,6 @@
   bool TryInline(HInvoke* invoke_instruction, uint32_t method_index) const;
   bool TryBuildAndInline(ArtMethod* resolved_method,
                          HInvoke* invoke_instruction,
-                         uint32_t method_index,
                          bool same_dex_file) const;
 
   const DexCompilationUnit& outer_compilation_unit_;
diff --git a/compiler/optimizing/reference_type_propagation.cc b/compiler/optimizing/reference_type_propagation.cc
index 3d81c20..c3db551 100644
--- a/compiler/optimizing/reference_type_propagation.cc
+++ b/compiler/optimizing/reference_type_propagation.cc
@@ -23,6 +23,29 @@
 
 namespace art {
 
+class RTPVisitor : public HGraphDelegateVisitor {
+ public:
+  RTPVisitor(HGraph* graph, StackHandleScopeCollection* handles)
+    : HGraphDelegateVisitor(graph),
+      handles_(handles) {}
+
+  void VisitNewInstance(HNewInstance* new_instance) OVERRIDE;
+  void VisitLoadClass(HLoadClass* load_class) OVERRIDE;
+  void VisitNewArray(HNewArray* instr) OVERRIDE;
+  void UpdateFieldAccessTypeInfo(HInstruction* instr, const FieldInfo& info);
+  void SetClassAsTypeInfo(HInstruction* instr, mirror::Class* klass, bool is_exact);
+  void VisitInstanceFieldGet(HInstanceFieldGet* instr) OVERRIDE;
+  void VisitStaticFieldGet(HStaticFieldGet* instr) OVERRIDE;
+  void VisitInvoke(HInvoke* instr) OVERRIDE;
+  void UpdateReferenceTypeInfo(HInstruction* instr,
+                               uint16_t type_idx,
+                               const DexFile& dex_file,
+                               bool is_exact);
+
+ private:
+  StackHandleScopeCollection* handles_;
+};
+
 void ReferenceTypePropagation::Run() {
   // To properly propagate type info we need to visit in the dominator-based order.
   // Reverse post order guarantees a node's dominators are visited first.
@@ -35,23 +58,13 @@
 
 void ReferenceTypePropagation::VisitBasicBlock(HBasicBlock* block) {
   // TODO: handle other instructions that give type info
-  // (Call/array accesses)
+  // (array accesses)
 
+  RTPVisitor visitor(graph_, handles_);
   // Initialize exact types first for faster convergence.
   for (HInstructionIterator it(block->GetInstructions()); !it.Done(); it.Advance()) {
     HInstruction* instr = it.Current();
-    // TODO: Make ReferenceTypePropagation a visitor or create a new one.
-    if (instr->IsNewInstance()) {
-      VisitNewInstance(instr->AsNewInstance());
-    } else if (instr->IsLoadClass()) {
-      VisitLoadClass(instr->AsLoadClass());
-    } else if (instr->IsNewArray()) {
-      VisitNewArray(instr->AsNewArray());
-    } else if (instr->IsInstanceFieldGet()) {
-      VisitInstanceFieldGet(instr->AsInstanceFieldGet());
-    } else if (instr->IsStaticFieldGet()) {
-      VisitStaticFieldGet(instr->AsStaticFieldGet());
-    }
+    instr->Accept(&visitor);
   }
 
   // Handle Phis.
@@ -166,9 +179,9 @@
   }
 }
 
-void ReferenceTypePropagation::SetClassAsTypeInfo(HInstruction* instr,
-                                                  mirror::Class* klass,
-                                                  bool is_exact) {
+void RTPVisitor::SetClassAsTypeInfo(HInstruction* instr,
+                                    mirror::Class* klass,
+                                    bool is_exact) {
   if (klass != nullptr) {
     ScopedObjectAccess soa(Thread::Current());
     MutableHandle<mirror::Class> handle = handles_->NewHandle(klass);
@@ -177,10 +190,10 @@
   }
 }
 
-void ReferenceTypePropagation::UpdateReferenceTypeInfo(HInstruction* instr,
-                                                       uint16_t type_idx,
-                                                       const DexFile& dex_file,
-                                                       bool is_exact) {
+void RTPVisitor::UpdateReferenceTypeInfo(HInstruction* instr,
+                                         uint16_t type_idx,
+                                         const DexFile& dex_file,
+                                         bool is_exact) {
   DCHECK_EQ(instr->GetType(), Primitive::kPrimNot);
 
   ScopedObjectAccess soa(Thread::Current());
@@ -189,16 +202,16 @@
   SetClassAsTypeInfo(instr, dex_cache->GetResolvedType(type_idx), is_exact);
 }
 
-void ReferenceTypePropagation::VisitNewInstance(HNewInstance* instr) {
+void RTPVisitor::VisitNewInstance(HNewInstance* instr) {
   UpdateReferenceTypeInfo(instr, instr->GetTypeIndex(), instr->GetDexFile(), /* is_exact */ true);
 }
 
-void ReferenceTypePropagation::VisitNewArray(HNewArray* instr) {
+void RTPVisitor::VisitNewArray(HNewArray* instr) {
   UpdateReferenceTypeInfo(instr, instr->GetTypeIndex(), instr->GetDexFile(), /* is_exact */ true);
 }
 
-void ReferenceTypePropagation::UpdateFieldAccessTypeInfo(HInstruction* instr,
-                                                         const FieldInfo& info) {
+void RTPVisitor::UpdateFieldAccessTypeInfo(HInstruction* instr,
+                                           const FieldInfo& info) {
   // The field index is unknown only during tests.
   if (instr->GetType() != Primitive::kPrimNot || info.GetFieldIndex() == kUnknownFieldIndex) {
     return;
@@ -213,15 +226,15 @@
   SetClassAsTypeInfo(instr, klass, /* is_exact */ false);
 }
 
-void ReferenceTypePropagation::VisitInstanceFieldGet(HInstanceFieldGet* instr) {
+void RTPVisitor::VisitInstanceFieldGet(HInstanceFieldGet* instr) {
   UpdateFieldAccessTypeInfo(instr, instr->GetFieldInfo());
 }
 
-void ReferenceTypePropagation::VisitStaticFieldGet(HStaticFieldGet* instr) {
+void RTPVisitor::VisitStaticFieldGet(HStaticFieldGet* instr) {
   UpdateFieldAccessTypeInfo(instr, instr->GetFieldInfo());
 }
 
-void ReferenceTypePropagation::VisitLoadClass(HLoadClass* instr) {
+void RTPVisitor::VisitLoadClass(HLoadClass* instr) {
   ScopedObjectAccess soa(Thread::Current());
   mirror::DexCache* dex_cache =
       Runtime::Current()->GetClassLinker()->FindDexCache(instr->GetDexFile());
@@ -299,6 +312,21 @@
   return !previous_rti.IsEqual(instr->GetReferenceTypeInfo());
 }
 
+void RTPVisitor::VisitInvoke(HInvoke* instr) {
+  if (instr->GetType() != Primitive::kPrimNot) {
+    return;
+  }
+
+  ScopedObjectAccess soa(Thread::Current());
+  ClassLinker* cl = Runtime::Current()->GetClassLinker();
+  mirror::DexCache* dex_cache = cl->FindDexCache(instr->GetDexFile());
+  ArtMethod* method = dex_cache->GetResolvedMethod(
+      instr->GetDexMethodIndex(), cl->GetImagePointerSize());
+  DCHECK(method != nullptr);
+  mirror::Class* klass = method->GetReturnType(false);
+  SetClassAsTypeInfo(instr, klass, /* is_exact */ false);
+}
+
 void ReferenceTypePropagation::UpdateBoundType(HBoundType* instr) {
   ReferenceTypeInfo new_rti = instr->InputAt(0)->GetReferenceTypeInfo();
   // Be sure that we don't go over the bounded type.
diff --git a/compiler/optimizing/reference_type_propagation.h b/compiler/optimizing/reference_type_propagation.h
index 0a1d4c4..0d687d2 100644
--- a/compiler/optimizing/reference_type_propagation.h
+++ b/compiler/optimizing/reference_type_propagation.h
@@ -40,26 +40,12 @@
   static constexpr const char* kReferenceTypePropagationPassName = "reference_type_propagation";
 
  private:
-  void VisitNewInstance(HNewInstance* new_instance);
-  void VisitLoadClass(HLoadClass* load_class);
-  void VisitNewArray(HNewArray* instr);
   void VisitPhi(HPhi* phi);
   void VisitBasicBlock(HBasicBlock* block);
-  void UpdateFieldAccessTypeInfo(HInstruction* instr, const FieldInfo& info);
-  void SetClassAsTypeInfo(HInstruction* instr, mirror::Class* klass, bool is_exact);
-
   void UpdateBoundType(HBoundType* bound_type) SHARED_LOCKS_REQUIRED(Locks::mutator_lock_);
   void UpdatePhi(HPhi* phi) SHARED_LOCKS_REQUIRED(Locks::mutator_lock_);
-
   void BoundTypeForIfNotNull(HBasicBlock* block);
   void BoundTypeForIfInstanceOf(HBasicBlock* block);
-  void UpdateReferenceTypeInfo(HInstruction* instr,
-                               uint16_t type_idx,
-                               const DexFile& dex_file,
-                               bool is_exact);
-  void VisitInstanceFieldGet(HInstanceFieldGet* instr);
-  void VisitStaticFieldGet(HStaticFieldGet* instr);
-
   void ProcessWorklist();
   void AddToWorklist(HInstruction* instr);
   void AddDependentInstructionsToWorklist(HInstruction* instr);
diff --git a/test/450-checker-types/src/Main.java b/test/450-checker-types/src/Main.java
index 4056275..9b447e9 100644
--- a/test/450-checker-types/src/Main.java
+++ b/test/450-checker-types/src/Main.java
@@ -364,6 +364,28 @@
     ((SubclassA)b).$noinline$g();
   }
 
+  public SubclassA $noinline$getSubclass() { throw new RuntimeException(); }
+
+  /// CHECK-START: void Main.testArraySimpleRemove() instruction_simplifier_after_types (before)
+  /// CHECK:         CheckCast
+
+  /// CHECK-START: void Main.testArraySimpleRemove() instruction_simplifier_after_types (after)
+  /// CHECK-NOT:     CheckCast
+  public void testArraySimpleRemove() {
+    Super[] b = new SubclassA[10];
+    SubclassA[] c = (SubclassA[])b;
+  }
+
+  /// CHECK-START: void Main.testInvokeSimpleRemove() instruction_simplifier_after_types (before)
+  /// CHECK:         CheckCast
+
+  /// CHECK-START: void Main.testInvokeSimpleRemove() instruction_simplifier_after_types (after)
+  /// CHECK-NOT:     CheckCast
+  public void testInvokeSimpleRemove() {
+    Super b = $noinline$getSubclass();
+    ((SubclassA)b).$noinline$g();
+  }
+
   public static void main(String[] args) {
   }
 }