Make periodic loop optimization safe from overflow
The trip number is calculated by doing:
((end - start) + (step - 1)) / step
(Note that we add step - 1 as a way of doing ceiling).
This way of calculating can overflow and produce the wrong result
if e.g. end and start are in opposite sides of the spectrum.
We can add checks to make sure that the calculation doesn't overflow.
Bug: 256858062
Fixes: 256858062
Test: art/test/testrunner/testrunner.py --host --64 -b --optimizing
Change-Id: Ib2868406d2aa8f106dab06c0b44ef1b6208ef41a
diff --git a/compiler/optimizing/induction_var_range.cc b/compiler/optimizing/induction_var_range.cc
index b1f33ab..b54507a 100644
--- a/compiler/optimizing/induction_var_range.cc
+++ b/compiler/optimizing/induction_var_range.cc
@@ -1359,6 +1359,15 @@
HInstruction* x = nullptr;
HInstruction* y = nullptr;
HInstruction* t = nullptr;
+
+ // Overflows when the stride is equal to `1` are fine since the periodicity is
+ // `2` and the lowest bit is the same. Similar with `-1`.
+ auto allow_potential_overflow = [&]() {
+ int64_t stride_value = 0;
+ return IsConstant(context, loop, trip->op_a->op_b, kExact, &stride_value) &&
+ (stride_value == 1 || stride_value == -1);
+ };
+
if (period == 2 &&
GenerateCode(context,
loop,
@@ -1383,7 +1392,8 @@
graph,
block,
/*is_min=*/ false,
- graph ? &t : nullptr)) {
+ graph ? &t : nullptr,
+ allow_potential_overflow())) {
// During actual code generation (graph != nullptr), generate is_even ? x : y.
if (graph != nullptr) {
DataType::Type type = trip->type;
diff --git a/test/618-checker-induction/src/Main.java b/test/618-checker-induction/src/Main.java
index 21cca22..bcd545c 100644
--- a/test/618-checker-induction/src/Main.java
+++ b/test/618-checker-induction/src/Main.java
@@ -672,6 +672,40 @@
return k;
}
+ /// CHECK-START: int Main.periodicOverflowTripCountNotOptimized() loop_optimization (before)
+ /// CHECK-DAG: <<Phi1:i\d+>> Phi loop:<<Loop:B\d+>> outer_loop:none
+ /// CHECK-DAG: {{i\d+}} Phi loop:<<Loop>> outer_loop:none
+ /// CHECK-DAG: Return [<<Phi1>>] loop:none
+ //
+ /// CHECK-START: int Main.periodicOverflowTripCountNotOptimized() loop_optimization (after)
+ /// CHECK-DAG: <<Phi1:i\d+>> Phi loop:<<Loop:B\d+>> outer_loop:none
+ /// CHECK-DAG: {{i\d+}} Phi loop:<<Loop>> outer_loop:none
+ /// CHECK-DAG: Return [<<Phi1>>] loop:none
+ static int periodicOverflowTripCountNotOptimized() {
+ int k = 0;
+ for (int i = Integer.MIN_VALUE; i < Integer.MAX_VALUE - 81; i += 80) {
+ k = 1 - k;
+ }
+ return k;
+ }
+
+ /// CHECK-START: int Main.periodicCouldOverflowTripCountNotOptimized(int) loop_optimization (before)
+ /// CHECK-DAG: <<Phi1:i\d+>> Phi loop:<<Loop:B\d+>> outer_loop:none
+ /// CHECK-DAG: {{i\d+}} Phi loop:<<Loop>> outer_loop:none
+ /// CHECK-DAG: Return [<<Phi1>>] loop:none
+ //
+ /// CHECK-START: int Main.periodicCouldOverflowTripCountNotOptimized(int) loop_optimization (after)
+ /// CHECK-DAG: <<Phi1:i\d+>> Phi loop:<<Loop:B\d+>> outer_loop:none
+ /// CHECK-DAG: {{i\d+}} Phi loop:<<Loop>> outer_loop:none
+ /// CHECK-DAG: Return [<<Phi1>>] loop:none
+ static int periodicCouldOverflowTripCountNotOptimized(int start) {
+ int k = 0;
+ for (int i = start; i < Integer.MAX_VALUE - 81; i += 80) {
+ k = 1 - k;
+ }
+ return k;
+ }
+
// If ever replaced by closed form, last value should be correct!
private static int getSumN(int n) {
int k = 0;
@@ -1138,6 +1172,9 @@
expectEquals((tc * (tc + 1)) / 2, getSumN(n));
}
+ expectEquals(1, periodicOverflowTripCountNotOptimized());
+ expectEquals(1, periodicCouldOverflowTripCountNotOptimized(Integer.MIN_VALUE));
+
expectEquals(10, closedTwice());
expectEquals(20, closedFeed());
expectEquals(-10, closedLargeUp());
diff --git a/test/654-checker-periodic/src/Main.java b/test/654-checker-periodic/src/Main.java
index 7a0c98c..0ae0b31 100644
--- a/test/654-checker-periodic/src/Main.java
+++ b/test/654-checker-periodic/src/Main.java
@@ -50,6 +50,21 @@
return lI;
}
+ /// CHECK-START: int Main.doitDownInt2(int) loop_optimization (before)
+ /// CHECK-DAG: <<Phi:i\d+>> Phi loop:<<Loop:B\d+>> outer_loop:none
+ /// CHECK-DAG: Phi loop:<<Loop>> outer_loop:none
+ //
+ /// CHECK-START: int Main.doitDownInt2(int) loop_optimization (after)
+ /// CHECK-NOT: Phi
+ static int doitDownInt2(int n) {
+ // Complete loop is replaced by last-value.
+ int lI = 1;
+ for (int i1 = n; i1 > 0; i1--) {
+ lI = (1486662021 - lI);
+ }
+ return lI;
+ }
+
/// CHECK-START: float Main.doitUpFloat(int) loop_optimization (before)
/// CHECK-DAG: <<Phi:i\d+>> Phi loop:<<Loop:B\d+>> outer_loop:none
/// CHECK-DAG: Phi loop:<<Loop>> outer_loop:none
@@ -122,6 +137,25 @@
return lF;
}
+ /// CHECK-START: float Main.doitDownFloatAlt2(int) loop_optimization (before)
+ /// CHECK-DAG: <<Phi:i\d+>> Phi loop:<<Loop:B\d+>> outer_loop:none
+ /// CHECK-DAG: Phi loop:<<Loop>> outer_loop:none
+ //
+ /// CHECK-START: float Main.doitDownFloatAlt2(int) loop_optimization (after)
+ /// CHECK-NOT: Phi
+ static float doitDownFloatAlt2(int n) {
+ // Complete loop is replaced by last-value
+ // since the values are now precise.
+ float lF = 1.0f;
+ float l2 = 1486662020.0f;
+ for (int i1 = n; i1 > 0; i1--) {
+ float old = lF;
+ lF = l2;
+ l2 = old;
+ }
+ return lF;
+ }
+
// Main driver.
public static void main(String[] args) {
for (int i = 0; i < 10; i++) {
@@ -135,6 +169,11 @@
expectEquals(ei, ci);
}
for (int i = 0; i < 10; i++) {
+ int ei = (i & 1) == 0 ? 1 : 1486662020;
+ int ci = doitDownInt2(i);
+ expectEquals(ei, ci);
+ }
+ for (int i = 0; i < 10; i++) {
float ef = i == 0 ? 1.0f : ((i & 1) == 0 ? 0.0f : 1486662021.0f);
float cf = doitUpFloat(i);
expectEquals(ef, cf);
@@ -154,6 +193,11 @@
float cf = doitDownFloatAlt(i);
expectEquals(ef, cf);
}
+ for (int i = 0; i < 10; i++) {
+ float ef = (i & 1) == 0 ? 1.0f : 1486662020.0f;
+ float cf = doitDownFloatAlt2(i);
+ expectEquals(ef, cf);
+ }
System.out.println("passed");
}