Implement Sum-of-Abs-Differences idiom recognition.
Rationale:
Currently implemented only on ARM64 (x86 lacks proper support).
Using the SAD idiom yields a great speedup on loops
that compute the sum-of-abs-differences operation.
Also includes some refinements around type conversions.
Speedup ExoPlayerAudio (golem run):
1.3x on ARM64
1.1x on x86
Test: test-art-host test-art-target
Bug: 64091002
Change-Id: Ia2b711d2bc23609a2ed50493dfe6719eedfe0130
diff --git a/compiler/optimizing/code_generator_vector_arm64.cc b/compiler/optimizing/code_generator_vector_arm64.cc
index 18a55c8..3f576c8 100644
--- a/compiler/optimizing/code_generator_vector_arm64.cc
+++ b/compiler/optimizing/code_generator_vector_arm64.cc
@@ -949,20 +949,18 @@
}
}
-void LocationsBuilderARM64::VisitVecMultiplyAccumulate(HVecMultiplyAccumulate* instr) {
- LocationSummary* locations = new (GetGraph()->GetArena()) LocationSummary(instr);
- switch (instr->GetPackedType()) {
+// Helper to set up locations for vector accumulations.
+static void CreateVecAccumLocations(ArenaAllocator* arena, HVecOperation* instruction) {
+ LocationSummary* locations = new (arena) LocationSummary(instruction);
+ switch (instruction->GetPackedType()) {
case Primitive::kPrimByte:
case Primitive::kPrimChar:
case Primitive::kPrimShort:
case Primitive::kPrimInt:
- locations->SetInAt(
- HVecMultiplyAccumulate::kInputAccumulatorIndex, Location::RequiresFpuRegister());
- locations->SetInAt(
- HVecMultiplyAccumulate::kInputMulLeftIndex, Location::RequiresFpuRegister());
- locations->SetInAt(
- HVecMultiplyAccumulate::kInputMulRightIndex, Location::RequiresFpuRegister());
- DCHECK_EQ(HVecMultiplyAccumulate::kInputAccumulatorIndex, 0);
+ case Primitive::kPrimLong:
+ locations->SetInAt(0, Location::RequiresFpuRegister());
+ locations->SetInAt(1, Location::RequiresFpuRegister());
+ locations->SetInAt(2, Location::RequiresFpuRegister());
locations->SetOut(Location::SameAsFirstInput());
break;
default:
@@ -971,18 +969,25 @@
}
}
+void LocationsBuilderARM64::VisitVecMultiplyAccumulate(HVecMultiplyAccumulate* instruction) {
+ CreateVecAccumLocations(GetGraph()->GetArena(), instruction);
+}
+
// Some early revisions of the Cortex-A53 have an erratum (835769) whereby it is possible for a
// 64-bit scalar multiply-accumulate instruction in AArch64 state to generate an incorrect result.
// However vector MultiplyAccumulate instruction is not affected.
-void InstructionCodeGeneratorARM64::VisitVecMultiplyAccumulate(HVecMultiplyAccumulate* instr) {
- LocationSummary* locations = instr->GetLocations();
- VRegister acc = VRegisterFrom(locations->InAt(HVecMultiplyAccumulate::kInputAccumulatorIndex));
- VRegister left = VRegisterFrom(locations->InAt(HVecMultiplyAccumulate::kInputMulLeftIndex));
- VRegister right = VRegisterFrom(locations->InAt(HVecMultiplyAccumulate::kInputMulRightIndex));
- switch (instr->GetPackedType()) {
+void InstructionCodeGeneratorARM64::VisitVecMultiplyAccumulate(HVecMultiplyAccumulate* instruction) {
+ LocationSummary* locations = instruction->GetLocations();
+ VRegister acc = VRegisterFrom(locations->InAt(0));
+ VRegister left = VRegisterFrom(locations->InAt(1));
+ VRegister right = VRegisterFrom(locations->InAt(2));
+
+ DCHECK(locations->InAt(0).Equals(locations->Out()));
+
+ switch (instruction->GetPackedType()) {
case Primitive::kPrimByte:
- DCHECK_EQ(16u, instr->GetVectorLength());
- if (instr->GetOpKind() == HInstruction::kAdd) {
+ DCHECK_EQ(16u, instruction->GetVectorLength());
+ if (instruction->GetOpKind() == HInstruction::kAdd) {
__ Mla(acc.V16B(), left.V16B(), right.V16B());
} else {
__ Mls(acc.V16B(), left.V16B(), right.V16B());
@@ -990,16 +995,16 @@
break;
case Primitive::kPrimChar:
case Primitive::kPrimShort:
- DCHECK_EQ(8u, instr->GetVectorLength());
- if (instr->GetOpKind() == HInstruction::kAdd) {
+ DCHECK_EQ(8u, instruction->GetVectorLength());
+ if (instruction->GetOpKind() == HInstruction::kAdd) {
__ Mla(acc.V8H(), left.V8H(), right.V8H());
} else {
__ Mls(acc.V8H(), left.V8H(), right.V8H());
}
break;
case Primitive::kPrimInt:
- DCHECK_EQ(4u, instr->GetVectorLength());
- if (instr->GetOpKind() == HInstruction::kAdd) {
+ DCHECK_EQ(4u, instruction->GetVectorLength());
+ if (instruction->GetOpKind() == HInstruction::kAdd) {
__ Mla(acc.V4S(), left.V4S(), right.V4S());
} else {
__ Mls(acc.V4S(), left.V4S(), right.V4S());
@@ -1007,6 +1012,186 @@
break;
default:
LOG(FATAL) << "Unsupported SIMD type";
+ UNREACHABLE();
+ }
+}
+
+void LocationsBuilderARM64::VisitVecSADAccumulate(HVecSADAccumulate* instruction) {
+ CreateVecAccumLocations(GetGraph()->GetArena(), instruction);
+ // Some conversions require temporary registers.
+ LocationSummary* locations = instruction->GetLocations();
+ HVecOperation* a = instruction->InputAt(1)->AsVecOperation();
+ HVecOperation* b = instruction->InputAt(2)->AsVecOperation();
+ DCHECK_EQ(a->GetPackedType(), b->GetPackedType());
+ switch (a->GetPackedType()) {
+ case Primitive::kPrimByte:
+ switch (instruction->GetPackedType()) {
+ case Primitive::kPrimLong:
+ locations->AddTemp(Location::RequiresFpuRegister());
+ locations->AddTemp(Location::RequiresFpuRegister());
+ FALLTHROUGH_INTENDED;
+ case Primitive::kPrimInt:
+ locations->AddTemp(Location::RequiresFpuRegister());
+ locations->AddTemp(Location::RequiresFpuRegister());
+ break;
+ default:
+ break;
+ }
+ break;
+ case Primitive::kPrimChar:
+ case Primitive::kPrimShort:
+ if (instruction->GetPackedType() == Primitive::kPrimLong) {
+ locations->AddTemp(Location::RequiresFpuRegister());
+ locations->AddTemp(Location::RequiresFpuRegister());
+ }
+ break;
+ case Primitive::kPrimInt:
+ case Primitive::kPrimLong:
+ if (instruction->GetPackedType() == a->GetPackedType()) {
+ locations->AddTemp(Location::RequiresFpuRegister());
+ }
+ break;
+ default:
+ break;
+ }
+}
+
+void InstructionCodeGeneratorARM64::VisitVecSADAccumulate(HVecSADAccumulate* instruction) {
+ LocationSummary* locations = instruction->GetLocations();
+ VRegister acc = VRegisterFrom(locations->InAt(0));
+ VRegister left = VRegisterFrom(locations->InAt(1));
+ VRegister right = VRegisterFrom(locations->InAt(2));
+
+ DCHECK(locations->InAt(0).Equals(locations->Out()));
+
+ // Handle all feasible acc_T += sad(a_S, b_S) type combinations (T x S).
+ HVecOperation* a = instruction->InputAt(1)->AsVecOperation();
+ HVecOperation* b = instruction->InputAt(2)->AsVecOperation();
+ DCHECK_EQ(a->GetPackedType(), b->GetPackedType());
+ switch (a->GetPackedType()) {
+ case Primitive::kPrimByte:
+ DCHECK_EQ(16u, a->GetVectorLength());
+ switch (instruction->GetPackedType()) {
+ case Primitive::kPrimChar:
+ case Primitive::kPrimShort:
+ DCHECK_EQ(8u, instruction->GetVectorLength());
+ __ Sabal(acc.V8H(), left.V8B(), right.V8B());
+ __ Sabal2(acc.V8H(), left.V16B(), right.V16B());
+ break;
+ case Primitive::kPrimInt: {
+ DCHECK_EQ(4u, instruction->GetVectorLength());
+ VRegister tmp1 = VRegisterFrom(locations->GetTemp(0));
+ VRegister tmp2 = VRegisterFrom(locations->GetTemp(1));
+ __ Sxtl(tmp1.V8H(), left.V8B());
+ __ Sxtl(tmp2.V8H(), right.V8B());
+ __ Sabal(acc.V4S(), tmp1.V4H(), tmp2.V4H());
+ __ Sabal2(acc.V4S(), tmp1.V8H(), tmp2.V8H());
+ __ Sxtl2(tmp1.V8H(), left.V16B());
+ __ Sxtl2(tmp2.V8H(), right.V16B());
+ __ Sabal(acc.V4S(), tmp1.V4H(), tmp2.V4H());
+ __ Sabal2(acc.V4S(), tmp1.V8H(), tmp2.V8H());
+ break;
+ }
+ case Primitive::kPrimLong: {
+ DCHECK_EQ(2u, instruction->GetVectorLength());
+ VRegister tmp1 = VRegisterFrom(locations->GetTemp(0));
+ VRegister tmp2 = VRegisterFrom(locations->GetTemp(1));
+ VRegister tmp3 = VRegisterFrom(locations->GetTemp(2));
+ VRegister tmp4 = VRegisterFrom(locations->GetTemp(3));
+ __ Sxtl(tmp1.V8H(), left.V8B());
+ __ Sxtl(tmp2.V8H(), right.V8B());
+ __ Sxtl(tmp3.V4S(), tmp1.V4H());
+ __ Sxtl(tmp4.V4S(), tmp2.V4H());
+ __ Sabal(acc.V2D(), tmp3.V2S(), tmp4.V2S());
+ __ Sabal2(acc.V2D(), tmp3.V4S(), tmp4.V4S());
+ __ Sxtl2(tmp3.V4S(), tmp1.V8H());
+ __ Sxtl2(tmp4.V4S(), tmp2.V8H());
+ __ Sabal(acc.V2D(), tmp3.V2S(), tmp4.V2S());
+ __ Sabal2(acc.V2D(), tmp3.V4S(), tmp4.V4S());
+ __ Sxtl2(tmp1.V8H(), left.V16B());
+ __ Sxtl2(tmp2.V8H(), right.V16B());
+ __ Sxtl(tmp3.V4S(), tmp1.V4H());
+ __ Sxtl(tmp4.V4S(), tmp2.V4H());
+ __ Sabal(acc.V2D(), tmp3.V2S(), tmp4.V2S());
+ __ Sabal2(acc.V2D(), tmp3.V4S(), tmp4.V4S());
+ __ Sxtl2(tmp3.V4S(), tmp1.V8H());
+ __ Sxtl2(tmp4.V4S(), tmp2.V8H());
+ __ Sabal(acc.V2D(), tmp3.V2S(), tmp4.V2S());
+ __ Sabal2(acc.V2D(), tmp3.V4S(), tmp4.V4S());
+ break;
+ }
+ default:
+ LOG(FATAL) << "Unsupported SIMD type";
+ UNREACHABLE();
+ }
+ break;
+ case Primitive::kPrimChar:
+ case Primitive::kPrimShort:
+ DCHECK_EQ(8u, a->GetVectorLength());
+ switch (instruction->GetPackedType()) {
+ case Primitive::kPrimInt:
+ DCHECK_EQ(4u, instruction->GetVectorLength());
+ __ Sabal(acc.V4S(), left.V4H(), right.V4H());
+ __ Sabal2(acc.V4S(), left.V8H(), right.V8H());
+ break;
+ case Primitive::kPrimLong: {
+ DCHECK_EQ(2u, instruction->GetVectorLength());
+ VRegister tmp1 = VRegisterFrom(locations->GetTemp(0));
+ VRegister tmp2 = VRegisterFrom(locations->GetTemp(1));
+ __ Sxtl(tmp1.V4S(), left.V4H());
+ __ Sxtl(tmp2.V4S(), right.V4H());
+ __ Sabal(acc.V2D(), tmp1.V2S(), tmp2.V2S());
+ __ Sabal2(acc.V2D(), tmp1.V4S(), tmp2.V4S());
+ __ Sxtl2(tmp1.V4S(), left.V8H());
+ __ Sxtl2(tmp2.V4S(), right.V8H());
+ __ Sabal(acc.V2D(), tmp1.V2S(), tmp2.V2S());
+ __ Sabal2(acc.V2D(), tmp1.V4S(), tmp2.V4S());
+ break;
+ }
+ default:
+ LOG(FATAL) << "Unsupported SIMD type";
+ UNREACHABLE();
+ }
+ break;
+ case Primitive::kPrimInt:
+ DCHECK_EQ(4u, a->GetVectorLength());
+ switch (instruction->GetPackedType()) {
+ case Primitive::kPrimInt: {
+ DCHECK_EQ(4u, instruction->GetVectorLength());
+ VRegister tmp = VRegisterFrom(locations->GetTemp(0));
+ __ Sub(tmp.V4S(), left.V4S(), right.V4S());
+ __ Abs(tmp.V4S(), tmp.V4S());
+ __ Add(acc.V4S(), acc.V4S(), tmp.V4S());
+ break;
+ }
+ case Primitive::kPrimLong:
+ DCHECK_EQ(2u, instruction->GetVectorLength());
+ __ Sabal(acc.V2D(), left.V2S(), right.V2S());
+ __ Sabal2(acc.V2D(), left.V4S(), right.V4S());
+ break;
+ default:
+ LOG(FATAL) << "Unsupported SIMD type";
+ UNREACHABLE();
+ }
+ break;
+ case Primitive::kPrimLong:
+ DCHECK_EQ(2u, a->GetVectorLength());
+ switch (instruction->GetPackedType()) {
+ case Primitive::kPrimLong: {
+ DCHECK_EQ(2u, instruction->GetVectorLength());
+ VRegister tmp = VRegisterFrom(locations->GetTemp(0));
+ __ Sub(tmp.V2D(), left.V2D(), right.V2D());
+ __ Abs(tmp.V2D(), tmp.V2D());
+ __ Add(acc.V2D(), acc.V2D(), tmp.V2D());
+ break;
+ }
+ default:
+ LOG(FATAL) << "Unsupported SIMD type";
+ UNREACHABLE();
+ }
+ break;
+ default:
+ LOG(FATAL) << "Unsupported SIMD type";
}
}