ARM64: Add support for multiply-accumulate.

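The optimizing instruction simplifier now recognizes an HMul whose only
user is an HAdd or HSub (plus a few simple `a * (b <+/-> 1)` patterns)
and merges the pair into a new ARM64-specific IR node,
HArm64MultiplyAccumulate, which the backend emits as a single madd/msub
instead of a mul followed by an add/sub. As a sketch (register
assignment illustrative), for

    long f(long a, long b, long c) { return a + b * c; }

the generated code changes along the lines of

    mul  x0, x1, x2
    add  x0, x3, x0

becoming

    madd x0, x1, x2, x3

On targets that need the Cortex-A53 erratum 835769 workaround, the code
generator inserts a nop whenever a 64-bit multiply-accumulate would
otherwise directly follow a load or store.
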
Change-Id: I88dc313df520480f3fd16bbabda27f9435d25368
diff --git a/compiler/optimizing/code_generator_arm64.cc b/compiler/optimizing/code_generator_arm64.cc
index 2776b7d..e2aa4dc 100644
--- a/compiler/optimizing/code_generator_arm64.cc
+++ b/compiler/optimizing/code_generator_arm64.cc
@@ -1628,6 +1628,49 @@
          Operand(InputOperandAt(instruction, 1)));
 }
 
+void LocationsBuilderARM64::VisitArm64MultiplyAccumulate(HArm64MultiplyAccumulate* instr) {
+  LocationSummary* locations =
+      new (GetGraph()->GetArena()) LocationSummary(instr, LocationSummary::kNoCall);
+  locations->SetInAt(HArm64MultiplyAccumulate::kInputAccumulatorIndex,
+                     Location::RequiresRegister());
+  locations->SetInAt(HArm64MultiplyAccumulate::kInputMulLeftIndex, Location::RequiresRegister());
+  locations->SetInAt(HArm64MultiplyAccumulate::kInputMulRightIndex, Location::RequiresRegister());
+  locations->SetOut(Location::RequiresRegister(), Location::kNoOutputOverlap);
+}
+
+void InstructionCodeGeneratorARM64::VisitArm64MultiplyAccumulate(HArm64MultiplyAccumulate* instr) {
+  Register res = OutputRegister(instr);
+  Register accumulator = InputRegisterAt(instr, HArm64MultiplyAccumulate::kInputAccumulatorIndex);
+  Register mul_left = InputRegisterAt(instr, HArm64MultiplyAccumulate::kInputMulLeftIndex);
+  Register mul_right = InputRegisterAt(instr, HArm64MultiplyAccumulate::kInputMulRightIndex);
+
+  // Avoid emitting code that could trigger Cortex A53's erratum 835769.
+  // This fixup should be carried out for all multiply-accumulate instructions:
+  // madd, msub, smaddl, smsubl, umaddl and umsubl.
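+  // The erratum may be triggered when such an instruction directly follows a
+  // load or store; emitting a nop in between breaks up the sequence.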
+  if (instr->GetType() == Primitive::kPrimLong &&
+      codegen_->GetInstructionSetFeatures().NeedFixCortexA53_835769()) {
+    MacroAssembler* masm = down_cast<CodeGeneratorARM64*>(codegen_)->GetVIXLAssembler();
+    vixl::Instruction* prev = masm->GetCursorAddress<vixl::Instruction*>() - vixl::kInstructionSize;
+    if (prev->IsLoadOrStore()) {
+      // Make sure we emit exactly one nop.
+      vixl::CodeBufferCheckScope scope(masm,
+                                       vixl::kInstructionSize,
+                                       vixl::CodeBufferCheckScope::kCheck,
+                                       vixl::CodeBufferCheckScope::kExactSize);
+      __ nop();
+    }
+  }
+
+  if (instr->GetOpKind() == HInstruction::kAdd) {
+    __ Madd(res, mul_left, mul_right, accumulator);
+  } else {
+    DCHECK(instr->GetOpKind() == HInstruction::kSub);
+    __ Msub(res, mul_left, mul_right, accumulator);
+  }
+}
+
 void LocationsBuilderARM64::VisitArrayGet(HArrayGet* instruction) {
   LocationSummary* locations =
       new (GetGraph()->GetArena()) LocationSummary(instruction, LocationSummary::kNoCall);
diff --git a/compiler/optimizing/graph_visualizer.cc b/compiler/optimizing/graph_visualizer.cc
index d166d00..4438190 100644
--- a/compiler/optimizing/graph_visualizer.cc
+++ b/compiler/optimizing/graph_visualizer.cc
@@ -422,6 +422,12 @@
     StartAttributeStream("kind") << (try_boundary->IsEntry() ? "entry" : "exit");
   }
 
+#ifdef ART_ENABLE_CODEGEN_arm64
+  void VisitArm64MultiplyAccumulate(HArm64MultiplyAccumulate* instruction) OVERRIDE {
+    StartAttributeStream("kind") << instruction->GetOpKind();
+  }
+#endif
+
   bool IsPass(const char* name) {
     return strcmp(pass_name_, name) == 0;
   }
diff --git a/compiler/optimizing/instruction_simplifier_arm64.cc b/compiler/optimizing/instruction_simplifier_arm64.cc
index eb79f46..54dd2cc 100644
--- a/compiler/optimizing/instruction_simplifier_arm64.cc
+++ b/compiler/optimizing/instruction_simplifier_arm64.cc
@@ -62,6 +62,67 @@
   RecordSimplification();
 }
 
+bool InstructionSimplifierArm64Visitor::TrySimpleMultiplyAccumulatePatterns(
+    HMul* mul, HBinaryOperation* input_binop, HInstruction* input_other) {
+  DCHECK(Primitive::IsIntOrLongType(mul->GetType()));
+  DCHECK(input_binop->IsAdd() || input_binop->IsSub());
+  DCHECK_NE(input_binop, input_other);
+  if (!input_binop->HasOnlyOneNonEnvironmentUse()) {
+    return false;
+  }
+
+  // Try to interpret patterns like
+  //    a * (b <+/-> 1)
+  // as
+  //    (a * b) <+/-> a
+  HInstruction* input_a = input_other;
+  HInstruction* input_b = nullptr;  // Set to a non-null value if we found a pattern to optimize.
+  HInstruction::InstructionKind op_kind;
+
+  if (input_binop->IsAdd()) {
+    if ((input_binop->GetConstantRight() != nullptr) && input_binop->GetConstantRight()->IsOne()) {
+      // Interpret
+      //    a * (b + 1)
+      // as
+      //    (a * b) + a
+      input_b = input_binop->GetLeastConstantLeft();
+      op_kind = HInstruction::kAdd;
+    }
+  } else {
+    DCHECK(input_binop->IsSub());
+    if (input_binop->GetRight()->IsConstant() &&
+        input_binop->GetRight()->AsConstant()->IsMinusOne()) {
+      // Interpret
+      //    a * (b - (-1))
+      // as
+      //    a + (a * b)
+      input_b = input_binop->GetLeft();
+      op_kind = HInstruction::kAdd;
+    } else if (input_binop->GetLeft()->IsConstant() &&
+               input_binop->GetLeft()->AsConstant()->IsOne()) {
+      // Interpret
+      //    a * (1 - b)
+      // as
+      //    a - (a * b)
+      input_b = input_binop->GetRight();
+      op_kind = HInstruction::kSub;
+    }
+  }
+
+  if (input_b == nullptr) {
+    // We did not find a pattern we can optimize.
+    return false;
+  }
+
+  HArm64MultiplyAccumulate* mulacc = new (GetGraph()->GetArena()) HArm64MultiplyAccumulate(
+      mul->GetType(), op_kind, input_a, input_a, input_b, mul->GetDexPc());
+
+  mul->GetBlock()->ReplaceAndRemoveInstructionWith(mul, mulacc);
+  input_binop->GetBlock()->RemoveInstruction(input_binop);
+  RecordSimplification();
+  return true;
+}
+
 void InstructionSimplifierArm64Visitor::VisitArrayGet(HArrayGet* instruction) {
   TryExtractArrayAccessAddress(instruction,
                                instruction->GetArray(),
@@ -76,5 +137,78 @@
                                Primitive::ComponentSize(instruction->GetComponentType()));
 }
 
+void InstructionSimplifierArm64Visitor::VisitMul(HMul* instruction) {
+  Primitive::Type type = instruction->GetType();
+  if (!Primitive::IsIntOrLongType(type)) {
+    return;
+  }
+
+  HInstruction* use = instruction->HasNonEnvironmentUses()
+      ? instruction->GetUses().GetFirst()->GetUser()
+      : nullptr;
+
+  if (instruction->HasOnlyOneNonEnvironmentUse() && (use->IsAdd() || use->IsSub())) {
+    // Replace code looking like
+    //    MUL tmp, x, y
+    //    SUB dst, acc, tmp
+    // with
+    //    MULSUB dst, acc, x, y
+    // Note that we do not want to (unconditionally) perform the merge when the
+    // multiplication has multiple uses and it can be merged in all of them.
+    // Multiple uses could happen on the same control-flow path, and we would
+    // then increase the amount of work. In the future we could try to evaluate
+    // whether all uses are on different control-flow paths (using dominance and
+    // reverse-dominance information) and only perform the merge when they are.
+    HInstruction* accumulator = nullptr;
+    HBinaryOperation* binop = use->AsBinaryOperation();
+    HInstruction* binop_left = binop->GetLeft();
+    HInstruction* binop_right = binop->GetRight();
+    // Be careful after GVN: the two inputs of the binop could be the same
+    // instruction. That cannot happen here, since the `HMul` has only one use.
+    DCHECK_NE(binop_left, binop_right);
+    if (binop_right == instruction) {
+      accumulator = binop_left;
+    } else if (use->IsAdd()) {
+      DCHECK_EQ(binop_left, instruction);
+      accumulator = binop_right;
+    }
+
+    if (accumulator != nullptr) {
+      HArm64MultiplyAccumulate* mulacc =
+          new (GetGraph()->GetArena()) HArm64MultiplyAccumulate(type,
+                                                                binop->GetKind(),
+                                                                accumulator,
+                                                                instruction->GetLeft(),
+                                                                instruction->GetRight());
+
+      binop->GetBlock()->ReplaceAndRemoveInstructionWith(binop, mulacc);
+      DCHECK(!instruction->HasUses());
+      instruction->GetBlock()->RemoveInstruction(instruction);
+      RecordSimplification();
+      return;
+    }
+  }
+
+  // Use multiply accumulate instruction for a few simple patterns.
+  // We prefer not to apply the following transformations if the left and
+  // right inputs perform the same operation.
+  // We rely on GVN having squashed the inputs if appropriate. However the
+  // results are still correct even if that did not happen.
+  if (instruction->GetLeft() == instruction->GetRight()) {
+    return;
+  }
+
+  HInstruction* left = instruction->GetLeft();
+  HInstruction* right = instruction->GetRight();
+  if ((right->IsAdd() || right->IsSub()) &&
+      TrySimpleMultiplyAccumulatePatterns(instruction, right->AsBinaryOperation(), left)) {
+    return;
+  }
+  if ((left->IsAdd() || left->IsSub()) &&
+      TrySimpleMultiplyAccumulatePatterns(instruction, left->AsBinaryOperation(), right)) {
+    return;
+  }
+}
+
 }  // namespace arm64
 }  // namespace art
diff --git a/compiler/optimizing/instruction_simplifier_arm64.h b/compiler/optimizing/instruction_simplifier_arm64.h
index 4b697db..eed2276 100644
--- a/compiler/optimizing/instruction_simplifier_arm64.h
+++ b/compiler/optimizing/instruction_simplifier_arm64.h
@@ -40,8 +40,14 @@
                                     HInstruction* index,
                                     int access_size);
 
+  bool TrySimpleMultiplyAccumulatePatterns(HMul* mul,
+                                           HBinaryOperation* input_binop,
+                                           HInstruction* input_other);
+
+  // HInstruction visitors, sorted alphabetically.
   void VisitArrayGet(HArrayGet* instruction) OVERRIDE;
   void VisitArraySet(HArraySet* instruction) OVERRIDE;
+  void VisitMul(HMul* instruction) OVERRIDE;
 
   OptimizingCompilerStats* stats_;
 };
diff --git a/compiler/optimizing/nodes.h b/compiler/optimizing/nodes.h
index 4f894b0..f36e0f5 100644
--- a/compiler/optimizing/nodes.h
+++ b/compiler/optimizing/nodes.h
@@ -1096,7 +1096,8 @@
 #define FOR_EACH_CONCRETE_INSTRUCTION_ARM64(M)
 #else
 #define FOR_EACH_CONCRETE_INSTRUCTION_ARM64(M)                          \
-  M(Arm64IntermediateAddress, Instruction)
+  M(Arm64IntermediateAddress, Instruction)                              \
+  M(Arm64MultiplyAccumulate, Instruction)
 #endif
 
 #define FOR_EACH_CONCRETE_INSTRUCTION_MIPS(M)
diff --git a/compiler/optimizing/nodes_arm64.h b/compiler/optimizing/nodes_arm64.h
index 885d3a2..d07f019 100644
--- a/compiler/optimizing/nodes_arm64.h
+++ b/compiler/optimizing/nodes_arm64.h
@@ -42,6 +42,42 @@
   DISALLOW_COPY_AND_ASSIGN(HArm64IntermediateAddress);
 };
 
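+// Computes `accumulator <+/-> (mul_left * mul_right)`. The ARM64 backend
+// emits this as a single madd or msub instruction.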
+class HArm64MultiplyAccumulate : public HExpression<3> {
+ public:
+  HArm64MultiplyAccumulate(Primitive::Type type,
+                           InstructionKind op,
+                           HInstruction* accumulator,
+                           HInstruction* mul_left,
+                           HInstruction* mul_right,
+                           uint32_t dex_pc = kNoDexPc)
+      : HExpression(type, SideEffects::None(), dex_pc), op_kind_(op) {
+    SetRawInputAt(kInputAccumulatorIndex, accumulator);
+    SetRawInputAt(kInputMulLeftIndex, mul_left);
+    SetRawInputAt(kInputMulRightIndex, mul_right);
+  }
+
+  static constexpr int kInputAccumulatorIndex = 0;
+  static constexpr int kInputMulLeftIndex = 1;
+  static constexpr int kInputMulRightIndex = 2;
+
+  bool CanBeMoved() const OVERRIDE { return true; }
+  bool InstructionDataEquals(HInstruction* other) const OVERRIDE {
+    return op_kind_ == other->AsArm64MultiplyAccumulate()->op_kind_;
+  }
+
+  InstructionKind GetOpKind() const { return op_kind_; }
+
+  DECLARE_INSTRUCTION(Arm64MultiplyAccumulate);
+
+ private:
+  // Indicates whether this is a MADD or MSUB.
+  InstructionKind op_kind_;
+
+  DISALLOW_COPY_AND_ASSIGN(HArm64MultiplyAccumulate);
+};
+
 }  // namespace art
 
 #endif  // ART_COMPILER_OPTIMIZING_NODES_ARM64_H_