RISCV: [Codegen] Add condition instructions

Test: m test-art-host-gtest
Bug: 283082089
Change-Id: Ic356fe85488660fa832400a8af94d6d240a83cd7
diff --git a/compiler/optimizing/code_generator_riscv64.cc b/compiler/optimizing/code_generator_riscv64.cc
index 6fc32a3..fa60266 100644
--- a/compiler/optimizing/code_generator_riscv64.cc
+++ b/compiler/optimizing/code_generator_riscv64.cc
@@ -347,13 +347,96 @@
   LOG(FATAL) << "Unimplemented";
 }
 
-void InstructionCodeGeneratorRISCV64::GenerateIntLongCompare(IfCondition cond,
-                                                             bool is_64bit,
-                                                             LocationSummary* locations) {
-  UNUSED(cond);
-  UNUSED(is_64bit);
-  UNUSED(locations);
-  LOG(FATAL) << "Unimplemented";
+void InstructionCodeGeneratorRISCV64::GenerateIntLongCondition(IfCondition cond,
+                                                               LocationSummary* locations) {
+  XRegister rd = locations->Out().AsRegister<XRegister>();
+  XRegister rs1 = locations->InAt(0).AsRegister<XRegister>();
+  Location rs2_location = locations->InAt(1);
+  bool use_imm = rs2_location.IsConstant();
+  int64_t imm = use_imm ? CodeGenerator::GetInt64ValueOf(rs2_location.GetConstant()) : 0;
+  XRegister rs2 = use_imm ? kNoXRegister : rs2_location.AsRegister<XRegister>();
+  switch (cond) {
+    case kCondEQ:
+    case kCondNE:
+      if (!use_imm) {
+        __ Sub(rd, rs1, rs2);  // SUB is OK here even for 32-bit comparison.
+      } else if (imm != 0) {
+        DCHECK(IsInt<12>(-imm));
+        __ Addi(rd, rs1, -imm);  // ADDI is OK here even for 32-bit comparison.
+      }  // else test `rs1` directly without subtraction for `use_imm && imm == 0`.
+      if (cond == kCondEQ) {
+        __ Seqz(rd, (use_imm && imm == 0) ? rs1 : rd);
+      } else {
+        __ Snez(rd, (use_imm && imm == 0) ? rs1 : rd);
+      }
+      break;
+
+    case kCondLT:
+    case kCondGE:
+      if (use_imm) {
+        DCHECK(IsInt<12>(imm));
+        __ Slti(rd, rs1, imm);
+      } else {
+        __ Slt(rd, rs1, rs2);
+      }
+      if (cond == kCondGE) {
+        // Calculate `rs1 >= rhs` as `!(rs1 < rhs)` since there's only the SLT but no SGE.
+        __ Xori(rd, rd, 1);
+      }
+      break;
+
+    case kCondLE:
+    case kCondGT:
+      if (use_imm) {
+        // Calculate `rs1 <= imm` as `rs1 < imm + 1`.
+        DCHECK(IsInt<12>(imm + 1));  // The value that overflows would fail this check.
+        __ Slti(rd, rs1, imm + 1);
+      } else {
+        __ Slt(rd, rs2, rs1);
+      }
+      if ((cond == kCondGT) == use_imm) {
+        // Calculate `rs1 > imm` as `!(rs1 < imm + 1)` and calculate
+        // `rs1 <= rs2` as `!(rs2 < rs1)` since there's only the SLT but no SGE.
+        __ Xori(rd, rd, 1);
+      }
+      break;
+
+    case kCondB:
+    case kCondAE:
+      if (use_imm) {
+        // Sltiu sign-extends its 12-bit immediate operand before the comparison
+        // and thus lets us compare directly with unsigned values in the ranges
+        // [0, 0x7ff] and [0x[ffffffff]fffff800, 0x[ffffffff]ffffffff].
+        DCHECK(IsInt<12>(imm));
+        __ Sltiu(rd, rs1, imm);
+      } else {
+        __ Sltu(rd, rs1, rs2);
+      }
+      if (cond == kCondAE) {
+        // Calculate `rs1 AE rhs` as `!(rs1 B rhs)` since there's only the SLTU but no SGEU.
+        __ Xori(rd, rd, 1);
+      }
+      break;
+
+    case kCondBE:
+    case kCondA:
+      if (use_imm) {
+        // Calculate `rs1 BE imm` as `rs1 B imm + 1`.
+        // Sltiu sign-extends its 12-bit immediate operand before the comparison
+        // and thus lets us compare directly with unsigned values in the ranges
+        // [0, 0x7ff] and [0x[ffffffff]fffff800, 0x[ffffffff]ffffffff].
+        DCHECK(IsInt<12>(imm + 1));  // The value that overflows would fail this check.
+        __ Sltiu(rd, rs1, imm + 1);
+      } else {
+        __ Sltu(rd, rs2, rs1);
+      }
+      if ((cond == kCondA) == use_imm) {
+        // Calculate `rs1 A imm` as `!(rs1 B imm + 1)` and calculate
+        // `rs1 BE rs2` as `!(rs2 B rs1)` since there's only the SLTU but no SGEU.
+        __ Xori(rd, rd, 1);
+      }
+      break;
+  }
 }
 
 bool InstructionCodeGeneratorRISCV64::MaterializeIntLongCompare(IfCondition cond,
@@ -379,15 +462,97 @@
   LOG(FATAL) << "UniMplemented";
 }
 
-void InstructionCodeGeneratorRISCV64::GenerateFpCompare(IfCondition cond,
-                                                        bool gt_bias,
-                                                        DataType::Type type,
-                                                        LocationSummary* locations) {
-  UNUSED(cond);
-  UNUSED(gt_bias);
-  UNUSED(type);
-  UNUSED(locations);
-  LOG(FATAL) << "Unimplemented";
+void InstructionCodeGeneratorRISCV64::GenerateFpCondition(IfCondition cond,
+                                                          bool gt_bias,
+                                                          DataType::Type type,
+                                                          LocationSummary* locations) {
+  XRegister rd = locations->Out().AsRegister<XRegister>();
+  FRegister rs1 = locations->InAt(0).AsFpuRegister<FRegister>();
+  FRegister rs2 = locations->InAt(1).AsFpuRegister<FRegister>();
+
+  // All FP compare operations yield 0 for NaN on either side but we need the result to be 1
+  // for certain combinations of `gt_bias` and `cond`.
+  // There is no dex instruction or HIR that would need the "equal or unordered" or "not equal"
+  // and the conditions "equal" and "not equal or unordered" do not need the NaN check here.
+  Riscv64Label done;
+  if (gt_bias ? (cond == kCondGT || cond == kCondGE) : (cond == kCondLT || cond == kCondLE)) {
+    // FCLASS.S/D examines the value in the floating-point register and writes to integer
+    // register a 10-bit mask that indicates the class of the floating-point number.
+    // rd[8]: Singaling NaN
+    // rd[9]: Quiet NaN
+    ScratchRegisterScope srs(GetAssembler());
+    XRegister tmp = srs.AllocateXRegister();
+    XRegister tmp2 = srs.AllocateXRegister();
+    if (type == DataType::Type::kFloat32) {
+      __ FClassS(tmp, rs1);
+      __ FClassS(tmp2, rs2);
+    } else {
+      DCHECK_EQ(type, DataType::Type::kFloat64);
+      __ FClassD(tmp, rs1);
+      __ FClassS(tmp2, rs2);
+    }
+    __ Li(rd, 1);  // Pre-load the result for NaN to avoid branching over it later.
+    __ Or(tmp, tmp, tmp2);
+    __ Srli(tmp, tmp, 8);
+    __ Bnez(tmp, &done);  // goto done if either input was NaN.
+  }
+
+  if (type == DataType::Type::kFloat32) {
+    switch (cond) {
+      case kCondEQ:
+        __ FEqS(rd, rs1, rs2);
+        break;
+      case kCondNE:
+        __ FEqS(rd, rs1, rs2);
+        __ Xori(rd, rd, 1);
+        break;
+      case kCondLT:
+        __ FLtS(rd, rs1, rs2);
+        break;
+      case kCondLE:
+        __ FLeS(rd, rs1, rs2);
+        break;
+      case kCondGT:
+        __ FLtS(rd, rs2, rs1);
+        break;
+      case kCondGE:
+        __ FLeS(rd, rs2, rs1);
+        break;
+      default:
+        LOG(FATAL) << "Unexpected floating-point condition " << cond;
+        UNREACHABLE();
+    }
+  } else {
+    DCHECK_EQ(type, DataType::Type::kFloat64);
+    switch (cond) {
+      case kCondEQ:
+        __ FEqD(rd, rs1, rs2);
+        break;
+      case kCondNE:
+        __ FEqD(rd, rs1, rs2);
+        __ Xori(rd, rd, 1);
+        break;
+      case kCondLT:
+        __ FLtD(rd, rs1, rs2);
+        break;
+      case kCondLE:
+        __ FLeD(rd, rs1, rs2);
+        break;
+      case kCondGT:
+        __ FLtD(rd, rs2, rs1);
+        break;
+      case kCondGE:
+        __ FLeD(rd, rs2, rs1);
+        break;
+      default:
+        LOG(FATAL) << "Unexpected floating-point condition " << cond;
+        UNREACHABLE();
+    }
+  }
+
+  if (done.IsLinked()) {
+    __ Bind(&done);
+  }
 }
 
 bool InstructionCodeGeneratorRISCV64::MaterializeFpCompare(IfCondition cond,
@@ -603,13 +768,68 @@
 }
 
 void LocationsBuilderRISCV64::HandleCondition(HCondition* instruction) {
-  UNUSED(instruction);
-  LOG(FATAL) << "Unimplemented";
+  LocationSummary* locations = new (GetGraph()->GetAllocator()) LocationSummary(instruction);
+  switch (instruction->InputAt(0)->GetType()) {
+    case DataType::Type::kFloat32:
+    case DataType::Type::kFloat64:
+      locations->SetInAt(0, Location::RequiresFpuRegister());
+      locations->SetInAt(1, Location::RequiresFpuRegister());
+      break;
+
+    default: {
+      locations->SetInAt(0, Location::RequiresRegister());
+      HInstruction* rhs = instruction->InputAt(1);
+      bool use_imm = false;
+      if (rhs->IsConstant()) {
+        int64_t imm = CodeGenerator::GetInt64ValueOf(rhs->AsConstant());
+        switch (instruction->GetCondition()) {
+          case kCondEQ:
+          case kCondNE:
+            imm = -imm;
+            break;
+          case kCondLE:
+          case kCondGT:
+          case kCondBE:
+          case kCondA:
+            imm += 1;
+            break;
+          default:
+            break;
+        }
+        // Constants that cannot be embedded in an instruction's 12-bit immediate shall be
+        // materialized. This simplifies the code and avoids cases with arithmetic overflow.
+        use_imm = IsInt<12>(imm);
+      }
+      if (use_imm) {
+        locations->SetInAt(1, Location::ConstantLocation(rhs->AsConstant()));
+      } else {
+        locations->SetInAt(1, Location::RequiresRegister());
+      }
+      break;
+    }
+  }
+  if (!instruction->IsEmittedAtUseSite()) {
+    locations->SetOut(Location::RequiresRegister(), Location::kNoOutputOverlap);
+  }
 }
 
 void InstructionCodeGeneratorRISCV64::HandleCondition(HCondition* instruction) {
-  UNUSED(instruction);
-  LOG(FATAL) << "Unimplemented";
+  if (instruction->IsEmittedAtUseSite()) {
+    return;
+  }
+
+  DataType::Type type = instruction->InputAt(0)->GetType();
+  LocationSummary* locations = instruction->GetLocations();
+  switch (type) {
+    case DataType::Type::kFloat32:
+    case DataType::Type::kFloat64:
+      GenerateFpCondition(instruction->GetCondition(), instruction->IsGtBias(), type, locations);
+      return;
+    default:
+      // Integral types.
+      GenerateIntLongCondition(instruction->GetCondition(), locations);
+      return;
+  }
 }
 
 void LocationsBuilderRISCV64::HandleShift(HBinaryOperation* instruction) {
@@ -761,23 +981,19 @@
 }
 
 void LocationsBuilderRISCV64::VisitAbove(HAbove* instruction) {
-  UNUSED(instruction);
-  LOG(FATAL) << "Unimplemented";
+  HandleCondition(instruction);
 }
 
 void InstructionCodeGeneratorRISCV64::VisitAbove(HAbove* instruction) {
-  UNUSED(instruction);
-  LOG(FATAL) << "Unimplemented";
+  HandleCondition(instruction);
 }
 
 void LocationsBuilderRISCV64::VisitAboveOrEqual(HAboveOrEqual* instruction) {
-  UNUSED(instruction);
-  LOG(FATAL) << "Unimplemented";
+  HandleCondition(instruction);
 }
 
 void InstructionCodeGeneratorRISCV64::VisitAboveOrEqual(HAboveOrEqual* instruction) {
-  UNUSED(instruction);
-  LOG(FATAL) << "Unimplemented";
+  HandleCondition(instruction);
 }
 
 void LocationsBuilderRISCV64::VisitAbs(HAbs* abs) {
@@ -885,23 +1101,19 @@
 }
 
 void LocationsBuilderRISCV64::VisitBelow(HBelow* instruction) {
-  UNUSED(instruction);
-  LOG(FATAL) << "Unimplemented";
+  HandleCondition(instruction);
 }
 
 void InstructionCodeGeneratorRISCV64::VisitBelow(HBelow* instruction) {
-  UNUSED(instruction);
-  LOG(FATAL) << "Unimplemented";
+  HandleCondition(instruction);
 }
 
 void LocationsBuilderRISCV64::VisitBelowOrEqual(HBelowOrEqual* instruction) {
-  UNUSED(instruction);
-  LOG(FATAL) << "Unimplemented";
+  HandleCondition(instruction);
 }
 
 void InstructionCodeGeneratorRISCV64::VisitBelowOrEqual(HBelowOrEqual* instruction) {
-  UNUSED(instruction);
-  LOG(FATAL) << "Unimplemented";
+  HandleCondition(instruction);
 }
 
 void LocationsBuilderRISCV64::VisitBooleanNot(HBooleanNot* instruction) {
@@ -1057,13 +1269,11 @@
 }
 
 void LocationsBuilderRISCV64::VisitEqual(HEqual* instruction) {
-  UNUSED(instruction);
-  LOG(FATAL) << "Unimplemented";
+  HandleCondition(instruction);
 }
 
 void InstructionCodeGeneratorRISCV64::VisitEqual(HEqual* instruction) {
-  UNUSED(instruction);
-  LOG(FATAL) << "Unimplemented";
+  HandleCondition(instruction);
 }
 
 void LocationsBuilderRISCV64::VisitExit(HExit* instruction) {
@@ -1092,23 +1302,19 @@
 }
 
 void LocationsBuilderRISCV64::VisitGreaterThan(HGreaterThan* instruction) {
-  UNUSED(instruction);
-  LOG(FATAL) << "Unimplemented";
+  HandleCondition(instruction);
 }
 
 void InstructionCodeGeneratorRISCV64::VisitGreaterThan(HGreaterThan* instruction) {
-  UNUSED(instruction);
-  LOG(FATAL) << "Unimplemented";
+  HandleCondition(instruction);
 }
 
 void LocationsBuilderRISCV64::VisitGreaterThanOrEqual(HGreaterThanOrEqual* instruction) {
-  UNUSED(instruction);
-  LOG(FATAL) << "Unimplemented";
+  HandleCondition(instruction);
 }
 
 void InstructionCodeGeneratorRISCV64::VisitGreaterThanOrEqual(HGreaterThanOrEqual* instruction) {
-  UNUSED(instruction);
-  LOG(FATAL) << "Unimplemented";
+  HandleCondition(instruction);
 }
 
 void LocationsBuilderRISCV64::VisitIf(HIf* instruction) {
@@ -1245,23 +1451,19 @@
 }
 
 void LocationsBuilderRISCV64::VisitLessThan(HLessThan* instruction) {
-  UNUSED(instruction);
-  LOG(FATAL) << "Unimplemented";
+  HandleCondition(instruction);
 }
 
 void InstructionCodeGeneratorRISCV64::VisitLessThan(HLessThan* instruction) {
-  UNUSED(instruction);
-  LOG(FATAL) << "Unimplemented";
+  HandleCondition(instruction);
 }
 
 void LocationsBuilderRISCV64::VisitLessThanOrEqual(HLessThanOrEqual* instruction) {
-  UNUSED(instruction);
-  LOG(FATAL) << "Unimplemented";
+  HandleCondition(instruction);
 }
 
 void InstructionCodeGeneratorRISCV64::VisitLessThanOrEqual(HLessThanOrEqual* instruction) {
-  UNUSED(instruction);
-  LOG(FATAL) << "Unimplemented";
+  HandleCondition(instruction);
 }
 
 void LocationsBuilderRISCV64::VisitLoadClass(HLoadClass* instruction) {
@@ -1445,13 +1647,11 @@
 }
 
 void LocationsBuilderRISCV64::VisitNotEqual(HNotEqual* instruction) {
-  UNUSED(instruction);
-  LOG(FATAL) << "Unimplemented";
+  HandleCondition(instruction);
 }
 
 void InstructionCodeGeneratorRISCV64::VisitNotEqual(HNotEqual* instruction) {
-  UNUSED(instruction);
-  LOG(FATAL) << "Unimplemented";
+  HandleCondition(instruction);
 }
 
 void LocationsBuilderRISCV64::VisitNullConstant(HNullConstant* instruction) {
diff --git a/compiler/optimizing/code_generator_riscv64.h b/compiler/optimizing/code_generator_riscv64.h
index 4942f77..e024ebd 100644
--- a/compiler/optimizing/code_generator_riscv64.h
+++ b/compiler/optimizing/code_generator_riscv64.h
@@ -218,7 +218,7 @@
   void DivRemByPowerOfTwo(HBinaryOperation* instruction);
   void GenerateDivRemWithAnyConstant(HBinaryOperation* instruction);
   void GenerateDivRemIntegral(HBinaryOperation* instruction);
-  void GenerateIntLongCompare(IfCondition cond, bool is64bit, LocationSummary* locations);
+  void GenerateIntLongCondition(IfCondition cond, LocationSummary* locations);
   // When the function returns `false` it means that the condition holds if `dst` is non-zero
   // and doesn't hold if `dst` is zero. If it returns `true`, the roles of zero and non-zero
   // `dst` are exchanged.
@@ -230,10 +230,10 @@
                                        bool is64bit,
                                        LocationSummary* locations,
                                        Riscv64Label* label);
-  void GenerateFpCompare(IfCondition cond,
-                         bool gt_bias,
-                         DataType::Type type,
-                         LocationSummary* locations);
+  void GenerateFpCondition(IfCondition cond,
+                           bool gt_bias,
+                           DataType::Type type,
+                           LocationSummary* locations);
   // When the function returns `false` it means that the condition holds if `dst` is non-zero
   // and doesn't hold if `dst` is zero. If it returns `true`, the roles of zero and non-zero
   // `dst` are exchanged.