Quick: Eliminate check-cast guaranteed by instance-of.

Eliminate check-cast if the result of an instance-of with
the very same type on the same value is tested by if-nez or
if-eqz such that the check-cast's block, or a dominator of
it, is entered only when the instance-of result is true.

Note that there already exists a verifier-based elimination
of check-cast but it excludes check-cast on interfaces. This
new optimization works for interface types and, since it's
GVN-based, it can better recognize when the same reference
is used for instance-of and check-cast.

Change-Id: Ib315199805099d1cb0534bb4a90dc51baa409685
diff --git a/compiler/dex/compiler_enums.h b/compiler/dex/compiler_enums.h
index 7edb490..39725de 100644
--- a/compiler/dex/compiler_enums.h
+++ b/compiler/dex/compiler_enums.h
@@ -345,6 +345,7 @@
 enum MIROptimizationFlagPositions {
   kMIRIgnoreNullCheck = 0,
   kMIRIgnoreRangeCheck,
+  kMIRIgnoreCheckCast,
   kMIRStoreNonNullValue,              // Storing non-null value, always mark GC card.
   kMIRClassIsInitialized,
   kMIRClassIsInDexCache,
diff --git a/compiler/dex/global_value_numbering.cc b/compiler/dex/global_value_numbering.cc
index ab3c946..30e3ce0 100644
--- a/compiler/dex/global_value_numbering.cc
+++ b/compiler/dex/global_value_numbering.cc
@@ -16,6 +16,7 @@
 
 #include "global_value_numbering.h"
 
+#include "base/bit_vector-inl.h"
 #include "base/stl_util.h"
 #include "local_value_numbering.h"
 
@@ -206,4 +207,49 @@
   return true;
 }
 
+// Returns true if block |bb_id| has a single predecessor ending in if-nez/if-eqz on value
+// |cond| and that branch reaches |bb_id| only when |cond| is non-zero (true).
+bool GlobalValueNumbering::IsBlockEnteredOnTrue(uint16_t cond, BasicBlockId bb_id) {
+  DCHECK_NE(cond, kNoValue);
+  BasicBlock* bb = mir_graph_->GetBasicBlock(bb_id);
+  if (bb->predecessors.size() == 1u) {
+    BasicBlockId pred_id = bb->predecessors[0];
+    BasicBlock* pred_bb = mir_graph_->GetBasicBlock(pred_id);
+    if (pred_bb->last_mir_insn != nullptr) {
+      Instruction::Code opcode = pred_bb->last_mir_insn->dalvikInsn.opcode;
+      // If both successors are the same block (an "empty if"), the block is entered
+      // regardless of the condition, so such a branch proves nothing about |cond|.
+      if (pred_bb->taken != pred_bb->fall_through &&
+          ((opcode == Instruction::IF_NEZ && pred_bb->taken == bb_id) ||
+           (opcode == Instruction::IF_EQZ && pred_bb->fall_through == bb_id))) {
+        DCHECK(lvns_[pred_id] != nullptr);
+        uint16_t operand = lvns_[pred_id]->GetSregValue(pred_bb->last_mir_insn->ssa_rep->uses[0]);
+        if (operand == cond) {
+          return true;
+        }
+      }
+    }
+  }
+  return false;
+}
+
+// Conservatively checks whether |cond| is known to be true (non-zero) in block |bb_id|, i.e.
+// whether |bb_id| or one of its dominators is entered only on the true edge of a branch on
+// |cond| (see IsBlockEnteredOnTrue()). A false result means "unknown", not "false".
+bool GlobalValueNumbering::IsTrueInBlock(uint16_t cond, BasicBlockId bb_id) {
+  // We're not doing proper value propagation, so just see if the condition is used
+  // with if-nez/if-eqz to branch/fall-through to this bb or one of its dominators.
+  DCHECK_NE(cond, kNoValue);
+  if (IsBlockEnteredOnTrue(cond, bb_id)) {
+    return true;
+  }
+  BasicBlock* bb = mir_graph_->GetBasicBlock(bb_id);
+  for (uint32_t dom_id : bb->dominators->Indexes()) {
+    if (IsBlockEnteredOnTrue(cond, dom_id)) {
+      return true;
+    }
+  }
+  return false;
+}
+
 }  // namespace art
diff --git a/compiler/dex/global_value_numbering.h b/compiler/dex/global_value_numbering.h
index 6fa658c..bd2f187 100644
--- a/compiler/dex/global_value_numbering.h
+++ b/compiler/dex/global_value_numbering.h
@@ -200,6 +200,9 @@
 
   bool DivZeroCheckedInAllPredecessors(const ScopedArenaVector<uint16_t>& merge_names) const;
 
+  bool IsBlockEnteredOnTrue(uint16_t cond, BasicBlockId bb_id);
+  bool IsTrueInBlock(uint16_t cond, BasicBlockId bb_id);
+
   ScopedArenaAllocator* Allocator() const {
     return allocator_;
   }
diff --git a/compiler/dex/global_value_numbering_test.cc b/compiler/dex/global_value_numbering_test.cc
index b91c3ca..b4559ef 100644
--- a/compiler/dex/global_value_numbering_test.cc
+++ b/compiler/dex/global_value_numbering_test.cc
@@ -136,6 +136,7 @@
     { bb, static_cast<Instruction::Code>(kMirOpPhi), 0, 0u, 2u, { src1, src2 }, 1, { reg } }
 #define DEF_BINOP(bb, opcode, result, src1, src2) \
     { bb, opcode, 0u, 0u, 2, { src1, src2 }, 1, { result } }
+#define DEF_UNOP(bb, opcode, result, src) DEF_MOVE(bb, opcode, result, src)
 
   void DoPrepareIFields(const IFieldDef* defs, size_t count) {
     cu_.mir_graph->ifield_lowering_infos_.clear();
@@ -2315,4 +2316,95 @@
   }
 }
 
+TEST_F(GlobalValueNumberingTestDiamond, CheckCastDiamond) {
+  static const MIRDef mirs[] = {
+      DEF_UNOP(3u, Instruction::INSTANCE_OF, 0u, 100u),  // sreg 0 = instance-of(sreg 100).
+      DEF_UNOP(3u, Instruction::INSTANCE_OF, 1u, 200u),  // sreg 1 = instance-of(sreg 200).
+      DEF_IFZ(3u, Instruction::IF_NEZ, 0u),              // Branches on sreg 0 only.
+      DEF_INVOKE1(4u, Instruction::CHECK_CAST, 100u),
+      DEF_INVOKE1(5u, Instruction::CHECK_CAST, 100u),
+      DEF_INVOKE1(5u, Instruction::CHECK_CAST, 200u),
+      DEF_INVOKE1(5u, Instruction::CHECK_CAST, 100u),
+      DEF_INVOKE1(6u, Instruction::CHECK_CAST, 100u),
+  };
+  // Expected MIR_IGNORE_CHECK_CAST state for each entry of mirs[], in the same order.
+  static const bool expected_ignore_check_cast[] = {
+      false,  // instance-of
+      false,  // instance-of
+      false,  // if-nez
+      false,  // Not eliminated, fall-through branch: sreg 0 may be zero here.
+      true,   // Eliminated, block entered only when the if-nez on sreg 0 is taken.
+      false,  // Not eliminated, different value: sreg 1 is never branched on.
+      false,  // Not eliminated, different type than the instance-of.
+      false,  // Not eliminated, bottom block (merge point of both branches).
+  };
+  // Type indexes: instance-of reads its type from vC, check-cast from vB.
+  PrepareMIRs(mirs);
+  mirs_[0].dalvikInsn.vC = 1234;  // type for instance-of
+  mirs_[1].dalvikInsn.vC = 1234;  // type for instance-of
+  mirs_[3].dalvikInsn.vB = 1234;  // type for check-cast
+  mirs_[4].dalvikInsn.vB = 1234;  // type for check-cast
+  mirs_[5].dalvikInsn.vB = 1234;  // type for check-cast
+  mirs_[6].dalvikInsn.vB = 4321;  // type for check-cast
+  mirs_[7].dalvikInsn.vB = 1234;  // type for check-cast
+  PerformGVN();
+  PerformGVNCodeModifications();
+  ASSERT_EQ(arraysize(expected_ignore_check_cast), mir_count_);
+  for (size_t i = 0u; i != mir_count_; ++i) {
+    int expected = expected_ignore_check_cast[i] ? MIR_IGNORE_CHECK_CAST : 0u;
+    EXPECT_EQ(expected, mirs_[i].optimization_flags) << i;
+  }
+}
+
+TEST_F(GlobalValueNumberingTest, CheckCastDominators) {
+  const BBDef bbs[] = {
+      DEF_BB(kNullBlock, DEF_SUCC0(), DEF_PRED0()),
+      DEF_BB(kEntryBlock, DEF_SUCC1(3), DEF_PRED0()),
+      DEF_BB(kExitBlock, DEF_SUCC0(), DEF_PRED1(7)),
+      DEF_BB(kDalvikByteCode, DEF_SUCC2(4, 5), DEF_PRED1(1)),  // Block #3, top of the diamond.
+      DEF_BB(kDalvikByteCode, DEF_SUCC1(7), DEF_PRED1(3)),     // Block #4, left side.
+      DEF_BB(kDalvikByteCode, DEF_SUCC1(6), DEF_PRED1(3)),     // Block #5, right side.
+      DEF_BB(kDalvikByteCode, DEF_SUCC1(7), DEF_PRED1(5)),     // Block #6, dominated by #5.
+      DEF_BB(kDalvikByteCode, DEF_SUCC1(2), DEF_PRED2(4, 6)),  // Block #7, bottom.
+  };
+  static const MIRDef mirs[] = {
+      DEF_UNOP(3u, Instruction::INSTANCE_OF, 0u, 100u),  // sreg 0 = instance-of(sreg 100).
+      DEF_UNOP(3u, Instruction::INSTANCE_OF, 1u, 200u),  // sreg 1 = instance-of(sreg 200).
+      DEF_IFZ(3u, Instruction::IF_NEZ, 0u),              // Branches on sreg 0 only.
+      DEF_INVOKE1(4u, Instruction::CHECK_CAST, 100u),
+      DEF_INVOKE1(6u, Instruction::CHECK_CAST, 100u),
+      DEF_INVOKE1(6u, Instruction::CHECK_CAST, 200u),
+      DEF_INVOKE1(6u, Instruction::CHECK_CAST, 100u),
+      DEF_INVOKE1(7u, Instruction::CHECK_CAST, 100u),
+  };
+  // Expected MIR_IGNORE_CHECK_CAST state for each entry of mirs[], in the same order.
+  static const bool expected_ignore_check_cast[] = {
+      false,  // instance-of
+      false,  // instance-of
+      false,  // if-nez
+      false,  // Not eliminated, fall-through branch: sreg 0 may be zero here.
+      true,   // Eliminated: dominator #5 is entered only when the if-nez on sreg 0 is taken.
+      false,  // Not eliminated, different value: sreg 1 is never branched on.
+      false,  // Not eliminated, different type than the instance-of.
+      false,  // Not eliminated, bottom block (merge point of both branches).
+  };
+  // Block #6 is not itself entered by a branch, so elimination there relies on dominators.
+  PrepareBasicBlocks(bbs);
+  PrepareMIRs(mirs);
+  mirs_[0].dalvikInsn.vC = 1234;  // type for instance-of
+  mirs_[1].dalvikInsn.vC = 1234;  // type for instance-of
+  mirs_[3].dalvikInsn.vB = 1234;  // type for check-cast
+  mirs_[4].dalvikInsn.vB = 1234;  // type for check-cast
+  mirs_[5].dalvikInsn.vB = 1234;  // type for check-cast
+  mirs_[6].dalvikInsn.vB = 4321;  // type for check-cast
+  mirs_[7].dalvikInsn.vB = 1234;  // type for check-cast
+  PerformGVN();
+  PerformGVNCodeModifications();
+  ASSERT_EQ(arraysize(expected_ignore_check_cast), mir_count_);
+  for (size_t i = 0u; i != mir_count_; ++i) {
+    int expected = expected_ignore_check_cast[i] ? MIR_IGNORE_CHECK_CAST : 0u;
+    EXPECT_EQ(expected, mirs_[i].optimization_flags) << i;
+  }
+}
+
 }  // namespace art
diff --git a/compiler/dex/gvn_dead_code_elimination.cc b/compiler/dex/gvn_dead_code_elimination.cc
index 2e7f032..2d4c18f 100644
--- a/compiler/dex/gvn_dead_code_elimination.cc
+++ b/compiler/dex/gvn_dead_code_elimination.cc
@@ -1058,7 +1058,6 @@
     case Instruction::INVOKE_INTERFACE_RANGE:
     case Instruction::INVOKE_STATIC:
     case Instruction::INVOKE_STATIC_RANGE:
-    case Instruction::CHECK_CAST:
     case Instruction::THROW:
     case Instruction::FILLED_NEW_ARRAY:
     case Instruction::FILLED_NEW_ARRAY_RANGE:
@@ -1073,6 +1072,12 @@
       uses_all_vregs = true;
       break;
 
+    case Instruction::CHECK_CAST:
+      DCHECK_EQ(mir->ssa_rep->num_uses, 1);
+      must_keep = true;  // Keep for type information even if MIR_IGNORE_CHECK_CAST.
+      uses_all_vregs = (mir->optimization_flags & MIR_IGNORE_CHECK_CAST) == 0;
+      break;
+
     case kMirOpNullCheck:
       DCHECK_EQ(mir->ssa_rep->num_uses, 1);
       if ((mir->optimization_flags & MIR_IGNORE_NULL_CHECK) != 0) {
diff --git a/compiler/dex/local_value_numbering.cc b/compiler/dex/local_value_numbering.cc
index 99b6683..dc222b5 100644
--- a/compiler/dex/local_value_numbering.cc
+++ b/compiler/dex/local_value_numbering.cc
@@ -1520,7 +1520,6 @@
     case Instruction::GOTO:
     case Instruction::GOTO_16:
     case Instruction::GOTO_32:
-    case Instruction::CHECK_CAST:
     case Instruction::THROW:
     case Instruction::FILL_ARRAY_DATA:
     case Instruction::PACKED_SWITCH:
@@ -1612,9 +1611,32 @@
       HandleInvokeOrClInitOrAcquireOp(mir);
       break;
 
+    case Instruction::INSTANCE_OF: {
+        uint16_t operand = GetOperandValue(mir->ssa_rep->uses[0]);
+        uint16_t type = mir->dalvikInsn.vC;
+        res = gvn_->LookupValue(Instruction::INSTANCE_OF, operand, type, kNoValue);
+        SetOperandValue(mir->ssa_rep->defs[0], res);
+      }
+      break;
+    case Instruction::CHECK_CAST:
+      if (gvn_->CanModify()) {
+        // Check if there was an instance-of operation on the same value and if we are
+        // in a block where its result is true. If so, we can eliminate the check-cast.
+        uint16_t operand = GetOperandValue(mir->ssa_rep->uses[0]);
+        uint16_t type = mir->dalvikInsn.vB;
+        uint16_t cond = gvn_->FindValue(Instruction::INSTANCE_OF, operand, type, kNoValue);
+        if (cond != kNoValue && gvn_->IsTrueInBlock(cond, Id())) {
+          if (gvn_->GetCompilationUnit()->verbose) {
+            LOG(INFO) << "Removing check-cast at 0x" << std::hex << mir->offset;
+          }
+          // Don't use kMirOpNop. Keep the check-cast as it defines the type of the register.
+          mir->optimization_flags |= MIR_IGNORE_CHECK_CAST;
+        }
+      }
+      break;
+
     case Instruction::MOVE_RESULT:
     case Instruction::MOVE_RESULT_OBJECT:
-    case Instruction::INSTANCE_OF:
       // 1 result, treat as unique each time, use result s_reg - will be unique.
       res = GetOperandValue(mir->ssa_rep->defs[0]);
       SetOperandValue(mir->ssa_rep->defs[0], res);
diff --git a/compiler/dex/mir_graph.h b/compiler/dex/mir_graph.h
index 9da39d1..3298af1 100644
--- a/compiler/dex/mir_graph.h
+++ b/compiler/dex/mir_graph.h
@@ -150,6 +150,7 @@
 
 #define MIR_IGNORE_NULL_CHECK           (1 << kMIRIgnoreNullCheck)
 #define MIR_IGNORE_RANGE_CHECK          (1 << kMIRIgnoreRangeCheck)
+#define MIR_IGNORE_CHECK_CAST           (1 << kMIRIgnoreCheckCast)
 #define MIR_STORE_NON_NULL_VALUE        (1 << kMIRStoreNonNullValue)
 #define MIR_CLASS_IS_INITIALIZED        (1 << kMIRClassIsInitialized)
 #define MIR_CLASS_IS_IN_DEX_CACHE       (1 << kMIRClassIsInDexCache)
diff --git a/compiler/dex/mir_optimization.cc b/compiler/dex/mir_optimization.cc
index 93749e4..266b7c3 100644
--- a/compiler/dex/mir_optimization.cc
+++ b/compiler/dex/mir_optimization.cc
@@ -1751,6 +1751,9 @@
     DCHECK_NE(opt_flags & MIR_IGNORE_NULL_CHECK, 0);
     // Non-throwing only if range check has been eliminated.
     return ((opt_flags & MIR_IGNORE_RANGE_CHECK) == 0);
+  } else if (mir->dalvikInsn.opcode == Instruction::CHECK_CAST &&
+      (opt_flags & MIR_IGNORE_CHECK_CAST) != 0) {
+    return false;
   } else if (mir->dalvikInsn.opcode == Instruction::ARRAY_LENGTH ||
       static_cast<int>(mir->dalvikInsn.opcode) == kMirOpNullCheck) {
     // No more checks for these (null check was processed above).
diff --git a/compiler/dex/quick/gen_common.cc b/compiler/dex/quick/gen_common.cc
index e57889a..32a469d 100644
--- a/compiler/dex/quick/gen_common.cc
+++ b/compiler/dex/quick/gen_common.cc
@@ -1403,7 +1403,12 @@
   }
 }
 
-void Mir2Lir::GenCheckCast(uint32_t insn_idx, uint32_t type_idx, RegLocation rl_src) {
+void Mir2Lir::GenCheckCast(int opt_flags, uint32_t insn_idx, uint32_t type_idx,
+                           RegLocation rl_src) {
+  if ((opt_flags & MIR_IGNORE_CHECK_CAST) != 0) {
+    // Compiler analysis proved that this check-cast would never cause an exception.
+    return;
+  }
   bool type_known_final, type_known_abstract, use_declaring_class;
   bool needs_access_check = !cu_->compiler_driver->CanAccessTypeWithoutChecks(cu_->method_idx,
                                                                               *cu_->dex_file,
diff --git a/compiler/dex/quick/mir_to_lir.cc b/compiler/dex/quick/mir_to_lir.cc
index 8348626..8edc5fc 100644
--- a/compiler/dex/quick/mir_to_lir.cc
+++ b/compiler/dex/quick/mir_to_lir.cc
@@ -632,7 +632,7 @@
       break;
 
     case Instruction::CHECK_CAST: {
-      GenCheckCast(mir->offset, vB, rl_src[0]);
+      GenCheckCast(opt_flags, mir->offset, vB, rl_src[0]);
       break;
     }
     case Instruction::INSTANCE_OF:
diff --git a/compiler/dex/quick/mir_to_lir.h b/compiler/dex/quick/mir_to_lir.h
index 6f3f057..9a56171 100644
--- a/compiler/dex/quick/mir_to_lir.h
+++ b/compiler/dex/quick/mir_to_lir.h
@@ -826,7 +826,7 @@
     void GenNewInstance(uint32_t type_idx, RegLocation rl_dest);
     void GenThrow(RegLocation rl_src);
     void GenInstanceof(uint32_t type_idx, RegLocation rl_dest, RegLocation rl_src);
-    void GenCheckCast(uint32_t insn_idx, uint32_t type_idx, RegLocation rl_src);
+    void GenCheckCast(int opt_flags, uint32_t insn_idx, uint32_t type_idx, RegLocation rl_src);
     void GenLong3Addr(OpKind first_op, OpKind second_op, RegLocation rl_dest,
                       RegLocation rl_src1, RegLocation rl_src2);
     virtual void GenShiftOpLong(Instruction::Code opcode, RegLocation rl_dest,