Simplify floating-point comparisons with NaN in Optimizing.
This change was suggested by Ian.
Also, simplify some art::HFloatConstant and
art::HDoubleConstant methods.
Change-Id: I7908df23581a7f61c8ec79c290fe5f70798ac3be
diff --git a/compiler/optimizing/constant_folding.cc b/compiler/optimizing/constant_folding.cc
index b7a92b5..5a1d9b4 100644
--- a/compiler/optimizing/constant_folding.cc
+++ b/compiler/optimizing/constant_folding.cc
@@ -28,6 +28,7 @@
void VisitShift(HBinaryOperation* shift);
void VisitAnd(HAnd* instruction) OVERRIDE;
+ void VisitCompare(HCompare* instruction) OVERRIDE;
void VisitMul(HMul* instruction) OVERRIDE;
void VisitOr(HOr* instruction) OVERRIDE;
void VisitRem(HRem* instruction) OVERRIDE;
@@ -108,6 +109,26 @@
}
}
+void InstructionWithAbsorbingInputSimplifier::VisitCompare(HCompare* instruction) {
+ HConstant* input_cst = instruction->GetConstantRight();
+ if (input_cst != nullptr) {
+ HInstruction* input_value = instruction->GetLeastConstantLeft();
+ if (Primitive::IsFloatingPointType(input_value->GetType()) &&
+ ((input_cst->IsFloatConstant() && input_cst->AsFloatConstant()->IsNaN()) ||
+ (input_cst->IsDoubleConstant() && input_cst->AsDoubleConstant()->IsNaN()))) {
+ // Replace code looking like
+ // CMP{G,L} dst, src, NaN
+ // with
+ // CONSTANT +1 (gt bias)
+ // or
+ // CONSTANT -1 (lt bias)
+ instruction->ReplaceWith(GetGraph()->GetConstant(Primitive::kPrimInt,
+ (instruction->IsGtBias() ? 1 : -1)));
+ instruction->GetBlock()->RemoveInstruction(instruction);
+ }
+ }
+}
+
void InstructionWithAbsorbingInputSimplifier::VisitMul(HMul* instruction) {
HConstant* input_cst = instruction->GetConstantRight();
Primitive::Type type = instruction->GetType();
diff --git a/compiler/optimizing/nodes.h b/compiler/optimizing/nodes.h
index 5fc0470..53bff91 100644
--- a/compiler/optimizing/nodes.h
+++ b/compiler/optimizing/nodes.h
@@ -2094,15 +2094,16 @@
size_t ComputeHashCode() const OVERRIDE { return static_cast<size_t>(GetValue()); }
bool IsMinusOne() const OVERRIDE {
- return bit_cast<uint32_t, float>(AsFloatConstant()->GetValue()) ==
- bit_cast<uint32_t, float>((-1.0f));
+ return bit_cast<uint32_t, float>(value_) == bit_cast<uint32_t, float>((-1.0f));
}
bool IsZero() const OVERRIDE {
- return AsFloatConstant()->GetValue() == 0.0f;
+ return value_ == 0.0f;
}
bool IsOne() const OVERRIDE {
- return bit_cast<uint32_t, float>(AsFloatConstant()->GetValue()) ==
- bit_cast<uint32_t, float>(1.0f);
+ return bit_cast<uint32_t, float>(value_) == bit_cast<uint32_t, float>(1.0f);
+ }
+ bool IsNaN() const {
+ return std::isnan(value_);
}
DECLARE_INSTRUCTION(FloatConstant);
@@ -2132,15 +2133,16 @@
size_t ComputeHashCode() const OVERRIDE { return static_cast<size_t>(GetValue()); }
bool IsMinusOne() const OVERRIDE {
- return bit_cast<uint64_t, double>(AsDoubleConstant()->GetValue()) ==
- bit_cast<uint64_t, double>((-1.0));
+ return bit_cast<uint64_t, double>(value_) == bit_cast<uint64_t, double>((-1.0));
}
bool IsZero() const OVERRIDE {
- return AsDoubleConstant()->GetValue() == 0.0;
+ return value_ == 0.0;
}
bool IsOne() const OVERRIDE {
- return bit_cast<uint64_t, double>(AsDoubleConstant()->GetValue()) ==
- bit_cast<uint64_t, double>(1.0);
+ return bit_cast<uint64_t, double>(value_) == bit_cast<uint64_t, double>(1.0);
+ }
+ bool IsNaN() const {
+ return std::isnan(value_);
}
DECLARE_INSTRUCTION(DoubleConstant);
diff --git a/test/442-checker-constant-folding/src/Main.java b/test/442-checker-constant-folding/src/Main.java
index 6b21fed..c89ab4d 100644
--- a/test/442-checker-constant-folding/src/Main.java
+++ b/test/442-checker-constant-folding/src/Main.java
@@ -16,6 +16,12 @@
public class Main {
+ public static void assertFalse(boolean condition) {
+ if (condition) {
+ throw new Error();
+ }
+ }
+
public static void assertIntEquals(int expected, int result) {
if (expected != result) {
throw new Error("Expected: " + expected + ", found: " + result);
@@ -407,6 +413,54 @@
return arg ^ arg;
}
+ // CHECK-START: boolean Main.CmpFloatGreaterThanNaN(float) constant_folding (before)
+ // CHECK-DAG: [[Arg:f\d+]] ParameterValue
+ // CHECK-DAG: [[ConstNan:f\d+]] FloatConstant nan
+ // CHECK-DAG: [[Const0:i\d+]] IntConstant 0
+ // CHECK-DAG: IntConstant 1
+ // CHECK-DAG: [[Cmp:i\d+]] Compare [ [[Arg]] [[ConstNan]] ]
+ // CHECK-DAG: [[Le:z\d+]] LessThanOrEqual [ [[Cmp]] [[Const0]] ]
+ // CHECK-DAG: If [ [[Le]] ]
+
+ // CHECK-START: boolean Main.CmpFloatGreaterThanNaN(float) constant_folding (after)
+ // CHECK-DAG: ParameterValue
+ // CHECK-DAG: FloatConstant nan
+ // CHECK-DAG: IntConstant 0
+ // CHECK-DAG: [[Const1:i\d+]] IntConstant 1
+ // CHECK-DAG: If [ [[Const1]] ]
+
+ // CHECK-START: boolean Main.CmpFloatGreaterThanNaN(float) constant_folding (after)
+ // CHECK-NOT: Compare
+ // CHECK-NOT: LessThanOrEqual
+
+ public static boolean CmpFloatGreaterThanNaN(float arg) {
+ return arg > Float.NaN;
+ }
+
+ // CHECK-START: boolean Main.CmpDoubleLessThanNaN(double) constant_folding (before)
+ // CHECK-DAG: [[Arg:d\d+]] ParameterValue
+ // CHECK-DAG: [[ConstNan:d\d+]] DoubleConstant nan
+ // CHECK-DAG: [[Const0:i\d+]] IntConstant 0
+ // CHECK-DAG: IntConstant 1
+ // CHECK-DAG: [[Cmp:i\d+]] Compare [ [[Arg]] [[ConstNan]] ]
+ // CHECK-DAG: [[Ge:z\d+]] GreaterThanOrEqual [ [[Cmp]] [[Const0]] ]
+ // CHECK-DAG: If [ [[Ge]] ]
+
+ // CHECK-START: boolean Main.CmpDoubleLessThanNaN(double) constant_folding (after)
+ // CHECK-DAG: ParameterValue
+ // CHECK-DAG: DoubleConstant nan
+ // CHECK-DAG: IntConstant 0
+ // CHECK-DAG: [[Const1:i\d+]] IntConstant 1
+ // CHECK-DAG: If [ [[Const1]] ]
+
+ // CHECK-START: boolean Main.CmpDoubleLessThanNaN(double) constant_folding (after)
+ // CHECK-NOT: Compare
+ // CHECK-NOT: GreaterThanOrEqual
+
+ public static boolean CmpDoubleLessThanNaN(double arg) {
+ return arg < Double.NaN;
+ }
+
public static void main(String[] args) {
assertIntEquals(IntNegation(), -42);
assertIntEquals(IntAddition1(), 3);
@@ -417,17 +471,19 @@
assertIntEquals(StaticCondition(), 5);
assertIntEquals(JumpsAndConditionals(true), 7);
assertIntEquals(JumpsAndConditionals(false), 3);
- int random = 123456; // Chosen randomly.
- assertIntEquals(And0(random), 0);
- assertLongEquals(Mul0(random), 0);
- assertIntEquals(OrAllOnes(random), -1);
- assertLongEquals(Rem0(random), 0);
- assertIntEquals(Rem1(random), 0);
- assertLongEquals(RemN1(random), 0);
- assertIntEquals(Shl0(random), 0);
- assertLongEquals(Shr0(random), 0);
- assertLongEquals(SubSameLong(random), 0);
- assertIntEquals(UShr0(random), 0);
- assertIntEquals(XorSameInt(random), 0);
+ int arbitrary = 123456; // Value chosen arbitrarily.
+ assertIntEquals(And0(arbitrary), 0);
+ assertLongEquals(Mul0(arbitrary), 0);
+ assertIntEquals(OrAllOnes(arbitrary), -1);
+ assertLongEquals(Rem0(arbitrary), 0);
+ assertIntEquals(Rem1(arbitrary), 0);
+ assertLongEquals(RemN1(arbitrary), 0);
+ assertIntEquals(Shl0(arbitrary), 0);
+ assertLongEquals(Shr0(arbitrary), 0);
+ assertLongEquals(SubSameLong(arbitrary), 0);
+ assertIntEquals(UShr0(arbitrary), 0);
+ assertIntEquals(XorSameInt(arbitrary), 0);
+ assertFalse(CmpFloatGreaterThanNaN(arbitrary));
+ assertFalse(CmpDoubleLessThanNaN(arbitrary));
}
}