diff options
| author | 2018-11-22 16:47:16 +0000 | |
|---|---|---|
| committer | 2018-12-03 18:49:10 +0000 | |
| commit | 3a16a8f56e989a13419abb554a6f14a9a35570b7 (patch) | |
| tree | e64002a4adf1d3fc56880276dfbdb95016a9b0bc | |
| parent | aa6f48362b3258a5df5e527987ffe7e068eb4a79 (diff) | |
RFC: ART: ARM64: Support SDOT/UDOT instructions.
Support generation of Armv8.4-A dot product instructions.
Test: 684-checker-simd-dotprod.
Test: test-art-host, test-art-target.
Change-Id: Ia5ea8b59644fddb7db9bf22c4f9c24e43529e4bf
| -rw-r--r-- | compiler/optimizing/code_generator_vector_arm64.cc | 55 |
1 files changed, 36 insertions, 19 deletions
diff --git a/compiler/optimizing/code_generator_vector_arm64.cc b/compiler/optimizing/code_generator_vector_arm64.cc index 5a18c1f72b..9506ff83a7 100644 --- a/compiler/optimizing/code_generator_vector_arm64.cc +++ b/compiler/optimizing/code_generator_vector_arm64.cc @@ -16,6 +16,7 @@ #include "code_generator_arm64.h" +#include "arch/arm64/instruction_set_features_arm64.h" #include "mirror/array-inl.h" #include "mirror/string.h" @@ -37,6 +38,14 @@ using helpers::XRegisterFrom; #define __ GetVIXLAssembler()-> +// Build-time switch for Armv8.4-a dot product instructions. +static constexpr bool kArm64EmitDotProdInstructions = true; + +// Returns whether dot product instructions should be emitted. +static bool ShouldEmitDotProductInstructions(const CodeGeneratorARM64* codegen_) { + return kArm64EmitDotProdInstructions && codegen_->GetInstructionSetFeatures().HasDotProd(); +} + void LocationsBuilderARM64::VisitVecReplicateScalar(HVecReplicateScalar* instruction) { LocationSummary* locations = new (GetGraph()->GetAllocator()) LocationSummary(instruction); HInstruction* input = instruction->InputAt(0); @@ -1285,8 +1294,9 @@ void LocationsBuilderARM64::VisitVecDotProd(HVecDotProd* instruction) { locations->SetInAt(2, Location::RequiresFpuRegister()); locations->SetOut(Location::SameAsFirstInput()); - // For Int8 and Uint8 we need a temp register. - if (DataType::Size(instruction->InputAt(1)->AsVecOperation()->GetPackedType()) == 1) { + // For Int8 and Uint8 general case we need a temp register. + if ((DataType::Size(instruction->InputAt(1)->AsVecOperation()->GetPackedType()) == 1) && + !ShouldEmitDotProductInstructions(codegen_)) { locations->AddTemp(Location::RequiresFpuRegister()); } } @@ -1308,25 +1318,32 @@ void InstructionCodeGeneratorARM64::VisitVecDotProd(HVecDotProd* instruction) { switch (inputs_data_size) { case 1u: { DCHECK_EQ(16u, a->GetVectorLength()); - VRegister tmp = VRegisterFrom(locations->GetTemp(0)); if (instruction->IsZeroExtending()) { - // TODO: Use Armv8.4-A UDOT instruction when it is available. - __ Umull(tmp.V8H(), left.V8B(), right.V8B()); - __ Uaddw(acc.V4S(), acc.V4S(), tmp.V4H()); - __ Uaddw2(acc.V4S(), acc.V4S(), tmp.V8H()); - - __ Umull2(tmp.V8H(), left.V16B(), right.V16B()); - __ Uaddw(acc.V4S(), acc.V4S(), tmp.V4H()); - __ Uaddw2(acc.V4S(), acc.V4S(), tmp.V8H()); + if (ShouldEmitDotProductInstructions(codegen_)) { + __ Udot(acc.V4S(), left.V16B(), right.V16B()); + } else { + VRegister tmp = VRegisterFrom(locations->GetTemp(0)); + __ Umull(tmp.V8H(), left.V8B(), right.V8B()); + __ Uaddw(acc.V4S(), acc.V4S(), tmp.V4H()); + __ Uaddw2(acc.V4S(), acc.V4S(), tmp.V8H()); + + __ Umull2(tmp.V8H(), left.V16B(), right.V16B()); + __ Uaddw(acc.V4S(), acc.V4S(), tmp.V4H()); + __ Uaddw2(acc.V4S(), acc.V4S(), tmp.V8H()); + } } else { - // TODO: Use Armv8.4-A SDOT instruction when it is available. - __ Smull(tmp.V8H(), left.V8B(), right.V8B()); - __ Saddw(acc.V4S(), acc.V4S(), tmp.V4H()); - __ Saddw2(acc.V4S(), acc.V4S(), tmp.V8H()); - - __ Smull2(tmp.V8H(), left.V16B(), right.V16B()); - __ Saddw(acc.V4S(), acc.V4S(), tmp.V4H()); - __ Saddw2(acc.V4S(), acc.V4S(), tmp.V8H()); + if (ShouldEmitDotProductInstructions(codegen_)) { + __ Sdot(acc.V4S(), left.V16B(), right.V16B()); + } else { + VRegister tmp = VRegisterFrom(locations->GetTemp(0)); + __ Smull(tmp.V8H(), left.V8B(), right.V8B()); + __ Saddw(acc.V4S(), acc.V4S(), tmp.V4H()); + __ Saddw2(acc.V4S(), acc.V4S(), tmp.V8H()); + + __ Smull2(tmp.V8H(), left.V16B(), right.V16B()); + __ Saddw(acc.V4S(), acc.V4S(), tmp.V4H()); + __ Saddw2(acc.V4S(), acc.V4S(), tmp.V8H()); + } } break; } |