diff options
-rw-r--r-- | compiler/optimizing/intrinsics_arm64.cc | 8 | ||||
-rw-r--r-- | compiler/optimizing/intrinsics_arm_vixl.cc | 168 | ||||
-rw-r--r-- | test/021-string2/src/Main.java | 108 |
3 files changed, 212 insertions, 72 deletions
diff --git a/compiler/optimizing/intrinsics_arm64.cc b/compiler/optimizing/intrinsics_arm64.cc index 3fdad8cca5..1ed1b7537e 100644 --- a/compiler/optimizing/intrinsics_arm64.cc +++ b/compiler/optimizing/intrinsics_arm64.cc @@ -1633,12 +1633,13 @@ void IntrinsicCodeGeneratorARM64::VisitStringEquals(HInvoke* invoke) { } // Assertions that must hold in order to compare strings 8 bytes at a time. + // Ok to do this because strings are zero-padded to kObjectAlignment. DCHECK_ALIGNED(value_offset, 8); static_assert(IsAligned<8>(kObjectAlignment), "String of odd length is not zero padded"); if (const_string != nullptr && - const_string_length < (is_compressed ? kShortConstStringEqualsCutoffInBytes - : kShortConstStringEqualsCutoffInBytes / 2u)) { + const_string_length <= (is_compressed ? kShortConstStringEqualsCutoffInBytes + : kShortConstStringEqualsCutoffInBytes / 2u)) { // Load and compare the contents. Though we know the contents of the short const string // at compile time, materializing constants may be more code than loading from memory. int32_t offset = value_offset; @@ -1646,7 +1647,7 @@ void IntrinsicCodeGeneratorARM64::VisitStringEquals(HInvoke* invoke) { RoundUp(is_compressed ? const_string_length : const_string_length * 2u, 8u); temp = temp.X(); temp1 = temp1.X(); - while (remaining_bytes > 8u) { + while (remaining_bytes > sizeof(uint64_t)) { Register temp2 = XRegisterFrom(locations->GetTemp(0)); __ Ldp(temp, temp1, MemOperand(str.X(), offset)); __ Ldp(temp2, out, MemOperand(arg.X(), offset)); @@ -1682,7 +1683,6 @@ void IntrinsicCodeGeneratorARM64::VisitStringEquals(HInvoke* invoke) { temp1 = temp1.X(); Register temp2 = XRegisterFrom(locations->GetTemp(0)); // Loop to compare strings 8 bytes at a time starting at the front of the string. - // Ok to do this because strings are zero-padded to kObjectAlignment. __ Bind(&loop); __ Ldr(out, MemOperand(str.X(), temp1)); __ Ldr(temp2, MemOperand(arg.X(), temp1)); diff --git a/compiler/optimizing/intrinsics_arm_vixl.cc b/compiler/optimizing/intrinsics_arm_vixl.cc index 82a97bccb7..76c1410340 100644 --- a/compiler/optimizing/intrinsics_arm_vixl.cc +++ b/compiler/optimizing/intrinsics_arm_vixl.cc @@ -1721,6 +1721,22 @@ void IntrinsicCodeGeneratorARMVIXL::VisitStringCompareTo(HInvoke* invoke) { } } +// The cut off for unrolling the loop in String.equals() intrinsic for const strings. +// The normal loop plus the pre-header is 9 instructions (18-26 bytes) without string compression +// and 12 instructions (24-32 bytes) with string compression. We can compare up to 4 bytes in 4 +// instructions (LDR+LDR+CMP+BNE) and up to 8 bytes in 6 instructions (LDRD+LDRD+CMP+BNE+CMP+BNE). +// Allow up to 12 instructions (32 bytes) for the unrolled loop. +constexpr size_t kShortConstStringEqualsCutoffInBytes = 16; + +static const char* GetConstString(HInstruction* candidate, uint32_t* utf16_length) { + if (candidate->IsLoadString()) { + HLoadString* load_string = candidate->AsLoadString(); + const DexFile& dex_file = load_string->GetDexFile(); + return dex_file.StringDataAndUtf16LengthByIdx(load_string->GetStringIndex(), utf16_length); + } + return nullptr; +} + void IntrinsicLocationsBuilderARMVIXL::VisitStringEquals(HInvoke* invoke) { LocationSummary* locations = new (arena_) LocationSummary(invoke, LocationSummary::kNoCall, @@ -1728,12 +1744,29 @@ void IntrinsicLocationsBuilderARMVIXL::VisitStringEquals(HInvoke* invoke) { InvokeRuntimeCallingConventionARMVIXL calling_convention; locations->SetInAt(0, Location::RequiresRegister()); locations->SetInAt(1, Location::RequiresRegister()); + // Temporary registers to store lengths of strings and for calculations. // Using instruction cbz requires a low register, so explicitly set a temp to be R0. locations->AddTemp(LocationFrom(r0)); - locations->AddTemp(Location::RequiresRegister()); - locations->AddTemp(Location::RequiresRegister()); + // For the generic implementation and for long const strings we need an extra temporary. + // We do not need it for short const strings, up to 4 bytes, see code generation below. + uint32_t const_string_length = 0u; + const char* const_string = GetConstString(invoke->InputAt(0), &const_string_length); + if (const_string == nullptr) { + const_string = GetConstString(invoke->InputAt(1), &const_string_length); + } + bool is_compressed = + mirror::kUseStringCompression && + const_string != nullptr && + mirror::String::DexFileStringAllASCII(const_string, const_string_length); + if (const_string == nullptr || const_string_length > (is_compressed ? 4u : 2u)) { + locations->AddTemp(Location::RequiresRegister()); + } + + // TODO: If the String.equals() is used only for an immediately following HIf, we can + // mark it as emitted-at-use-site and emit branches directly to the appropriate blocks. + // Then we shall need an extra temporary register instead of the output register. locations->SetOut(Location::RequiresRegister()); } @@ -1746,8 +1779,6 @@ void IntrinsicCodeGeneratorARMVIXL::VisitStringEquals(HInvoke* invoke) { vixl32::Register out = OutputRegister(invoke); vixl32::Register temp = RegisterFrom(locations->GetTemp(0)); - vixl32::Register temp1 = RegisterFrom(locations->GetTemp(1)); - vixl32::Register temp2 = RegisterFrom(locations->GetTemp(2)); vixl32::Label loop; vixl32::Label end; @@ -1779,52 +1810,109 @@ void IntrinsicCodeGeneratorARMVIXL::VisitStringEquals(HInvoke* invoke) { // Receiver must be a string object, so its class field is equal to all strings' class fields. // If the argument is a string object, its class field must be equal to receiver's class field. __ Ldr(temp, MemOperand(str, class_offset)); - __ Ldr(temp1, MemOperand(arg, class_offset)); - __ Cmp(temp, temp1); + __ Ldr(out, MemOperand(arg, class_offset)); + __ Cmp(temp, out); __ B(ne, &return_false, /* far_target */ false); } - // Load `count` fields of this and argument strings. - __ Ldr(temp, MemOperand(str, count_offset)); - __ Ldr(temp1, MemOperand(arg, count_offset)); - // Check if `count` fields are equal, return false if they're not. - // Also compares the compression style, if differs return false. - __ Cmp(temp, temp1); - __ B(ne, &return_false, /* far_target */ false); - // Return true if both strings are empty. Even with string compression `count == 0` means empty. - static_assert(static_cast<uint32_t>(mirror::StringCompressionFlag::kCompressed) == 0u, - "Expecting 0=compressed, 1=uncompressed"); - __ CompareAndBranchIfZero(temp, &return_true, /* far_target */ false); + // Check if one of the inputs is a const string. Do not special-case both strings + // being const, such cases should be handled by constant folding if needed. + uint32_t const_string_length = 0u; + const char* const_string = GetConstString(invoke->InputAt(0), &const_string_length); + if (const_string == nullptr) { + const_string = GetConstString(invoke->InputAt(1), &const_string_length); + if (const_string != nullptr) { + std::swap(str, arg); // Make sure the const string is in `str`. + } + } + bool is_compressed = + mirror::kUseStringCompression && + const_string != nullptr && + mirror::String::DexFileStringAllASCII(const_string, const_string_length); + + if (const_string != nullptr) { + // Load `count` field of the argument string and check if it matches the const string. + // Also compares the compression style, if differs return false. + __ Ldr(temp, MemOperand(arg, count_offset)); + __ Cmp(temp, Operand(mirror::String::GetFlaggedCount(const_string_length, is_compressed))); + __ B(ne, &return_false, /* far_target */ false); + } else { + // Load `count` fields of this and argument strings. + __ Ldr(temp, MemOperand(str, count_offset)); + __ Ldr(out, MemOperand(arg, count_offset)); + // Check if `count` fields are equal, return false if they're not. + // Also compares the compression style, if differs return false. + __ Cmp(temp, out); + __ B(ne, &return_false, /* far_target */ false); + } // Assertions that must hold in order to compare strings 4 bytes at a time. + // Ok to do this because strings are zero-padded to kObjectAlignment. DCHECK_ALIGNED(value_offset, 4); static_assert(IsAligned<4>(kObjectAlignment), "String data must be aligned for fast compare."); - if (mirror::kUseStringCompression) { - // For string compression, calculate the number of bytes to compare (not chars). - // This could in theory exceed INT32_MAX, so treat temp as unsigned. - __ Lsrs(temp, temp, 1u); // Extract length and check compression flag. - ExactAssemblyScope aas(assembler->GetVIXLAssembler(), - 2 * kMaxInstructionSizeInBytes, - CodeBufferCheckScope::kMaximumSize); - __ it(cs); // If uncompressed, - __ add(cs, temp, temp, temp); // double the byte count. - } + if (const_string != nullptr && + const_string_length <= (is_compressed ? kShortConstStringEqualsCutoffInBytes + : kShortConstStringEqualsCutoffInBytes / 2u)) { + // Load and compare the contents. Though we know the contents of the short const string + // at compile time, materializing constants may be more code than loading from memory. + int32_t offset = value_offset; + size_t remaining_bytes = + RoundUp(is_compressed ? const_string_length : const_string_length * 2u, 4u); + while (remaining_bytes > sizeof(uint32_t)) { + vixl32::Register temp1 = RegisterFrom(locations->GetTemp(1)); + UseScratchRegisterScope scratch_scope(assembler->GetVIXLAssembler()); + vixl32::Register temp2 = scratch_scope.Acquire(); + __ Ldrd(temp, temp1, MemOperand(str, offset)); + __ Ldrd(temp2, out, MemOperand(arg, offset)); + __ Cmp(temp, temp2); + __ B(ne, &return_false, /* far_label */ false); + __ Cmp(temp1, out); + __ B(ne, &return_false, /* far_label */ false); + offset += 2u * sizeof(uint32_t); + remaining_bytes -= 2u * sizeof(uint32_t); + } + if (remaining_bytes != 0u) { + __ Ldr(temp, MemOperand(str, offset)); + __ Ldr(out, MemOperand(arg, offset)); + __ Cmp(temp, out); + __ B(ne, &return_false, /* far_label */ false); + } + } else { + // Return true if both strings are empty. Even with string compression `count == 0` means empty. + static_assert(static_cast<uint32_t>(mirror::StringCompressionFlag::kCompressed) == 0u, + "Expecting 0=compressed, 1=uncompressed"); + __ CompareAndBranchIfZero(temp, &return_true, /* far_target */ false); - // Store offset of string value in preparation for comparison loop. - __ Mov(temp1, value_offset); + if (mirror::kUseStringCompression) { + // For string compression, calculate the number of bytes to compare (not chars). + // This could in theory exceed INT32_MAX, so treat temp as unsigned. + __ Lsrs(temp, temp, 1u); // Extract length and check compression flag. + ExactAssemblyScope aas(assembler->GetVIXLAssembler(), + 2 * kMaxInstructionSizeInBytes, + CodeBufferCheckScope::kMaximumSize); + __ it(cs); // If uncompressed, + __ add(cs, temp, temp, temp); // double the byte count. + } - // Loop to compare strings 4 bytes at a time starting at the front of the string. - // Ok to do this because strings are zero-padded to kObjectAlignment. - __ Bind(&loop); - __ Ldr(out, MemOperand(str, temp1)); - __ Ldr(temp2, MemOperand(arg, temp1)); - __ Add(temp1, temp1, Operand::From(sizeof(uint32_t))); - __ Cmp(out, temp2); - __ B(ne, &return_false, /* far_target */ false); - // With string compression, we have compared 4 bytes, otherwise 2 chars. - __ Subs(temp, temp, mirror::kUseStringCompression ? 4 : 2); - __ B(hi, &loop, /* far_target */ false); + vixl32::Register temp1 = RegisterFrom(locations->GetTemp(1)); + UseScratchRegisterScope scratch_scope(assembler->GetVIXLAssembler()); + vixl32::Register temp2 = scratch_scope.Acquire(); + + // Store offset of string value in preparation for comparison loop. + __ Mov(temp1, value_offset); + + // Loop to compare strings 4 bytes at a time starting at the front of the string. + __ Bind(&loop); + __ Ldr(out, MemOperand(str, temp1)); + __ Ldr(temp2, MemOperand(arg, temp1)); + __ Add(temp1, temp1, Operand::From(sizeof(uint32_t))); + __ Cmp(out, temp2); + __ B(ne, &return_false, /* far_target */ false); + // With string compression, we have compared 4 bytes, otherwise 2 chars. + __ Subs(temp, temp, mirror::kUseStringCompression ? 4 : 2); + __ B(hi, &loop, /* far_target */ false); + } // Return true and exit the function. // If loop does not result in returning false, we return true. diff --git a/test/021-string2/src/Main.java b/test/021-string2/src/Main.java index 3b81d8e623..c713aa43a6 100644 --- a/test/021-string2/src/Main.java +++ b/test/021-string2/src/Main.java @@ -556,12 +556,24 @@ public class Main { Assert.assertTrue($noinline$equalsConstString0("")); Assert.assertFalse($noinline$equalsConstString0("1")); + Assert.assertTrue($noinline$equalsConstString3("012")); + Assert.assertFalse($noinline$equalsConstString3("01")); + Assert.assertFalse($noinline$equalsConstString3("0123")); + Assert.assertFalse($noinline$equalsConstString3("01x")); + Assert.assertFalse($noinline$equalsConstString3("01\u0440")); + Assert.assertTrue($noinline$equalsConstString7("0123456")); Assert.assertFalse($noinline$equalsConstString7("012345")); Assert.assertFalse($noinline$equalsConstString7("01234567")); Assert.assertFalse($noinline$equalsConstString7("012345x")); Assert.assertFalse($noinline$equalsConstString7("012345\u0440")); + Assert.assertTrue($noinline$equalsConstString12("012345678901")); + Assert.assertFalse($noinline$equalsConstString12("01234567890")); + Assert.assertFalse($noinline$equalsConstString12("0123456789012")); + Assert.assertFalse($noinline$equalsConstString12("01234567890x")); + Assert.assertFalse($noinline$equalsConstString12("01234567890\u0440")); + Assert.assertTrue($noinline$equalsConstString14("01234567890123")); Assert.assertFalse($noinline$equalsConstString14("0123456789012")); Assert.assertFalse($noinline$equalsConstString14("012345678901234")); @@ -587,12 +599,24 @@ public class Main { Assert.assertFalse( $noinline$equalsConstString35("0123456789012345678901234567890123\u0440")); + Assert.assertTrue($noinline$equalsConstNonAsciiString3("\u044012")); + Assert.assertFalse($noinline$equalsConstNonAsciiString3("\u04401")); + Assert.assertFalse($noinline$equalsConstNonAsciiString3("\u0440123")); + Assert.assertFalse($noinline$equalsConstNonAsciiString3("\u04401x")); + Assert.assertFalse($noinline$equalsConstNonAsciiString3("012")); + Assert.assertTrue($noinline$equalsConstNonAsciiString7("\u0440123456")); Assert.assertFalse($noinline$equalsConstNonAsciiString7("\u044012345")); Assert.assertFalse($noinline$equalsConstNonAsciiString7("\u04401234567")); Assert.assertFalse($noinline$equalsConstNonAsciiString7("\u044012345x")); Assert.assertFalse($noinline$equalsConstNonAsciiString7("0123456")); + Assert.assertTrue($noinline$equalsConstNonAsciiString12("\u044012345678901")); + Assert.assertFalse($noinline$equalsConstNonAsciiString12("\u04401234567890")); + Assert.assertFalse($noinline$equalsConstNonAsciiString12("\u0440123456789012")); + Assert.assertFalse($noinline$equalsConstNonAsciiString12("\u04401234567890x")); + Assert.assertFalse($noinline$equalsConstNonAsciiString12("012345678901")); + Assert.assertTrue($noinline$equalsConstNonAsciiString14("\u04401234567890123")); Assert.assertFalse($noinline$equalsConstNonAsciiString14("\u0440123456789012")); Assert.assertFalse($noinline$equalsConstNonAsciiString14("\u044012345678901234")); @@ -631,12 +655,24 @@ public class Main { Assert.assertTrue($noinline$constString0Equals("")); Assert.assertFalse($noinline$constString0Equals("1")); + Assert.assertTrue($noinline$constString3Equals("012")); + Assert.assertFalse($noinline$constString3Equals("01")); + Assert.assertFalse($noinline$constString3Equals("0123")); + Assert.assertFalse($noinline$constString3Equals("01x")); + Assert.assertFalse($noinline$constString3Equals("01\u0440")); + Assert.assertTrue($noinline$constString7Equals("0123456")); Assert.assertFalse($noinline$constString7Equals("012345")); Assert.assertFalse($noinline$constString7Equals("01234567")); Assert.assertFalse($noinline$constString7Equals("012345x")); Assert.assertFalse($noinline$constString7Equals("012345\u0440")); + Assert.assertTrue($noinline$constString12Equals("012345678901")); + Assert.assertFalse($noinline$constString12Equals("01234567890")); + Assert.assertFalse($noinline$constString12Equals("0123456789012")); + Assert.assertFalse($noinline$constString12Equals("01234567890x")); + Assert.assertFalse($noinline$constString12Equals("01234567890\u0440")); + Assert.assertTrue($noinline$constString14Equals("01234567890123")); Assert.assertFalse($noinline$constString14Equals("0123456789012")); Assert.assertFalse($noinline$constString14Equals("012345678901234")); @@ -662,12 +698,24 @@ public class Main { Assert.assertFalse( $noinline$constString35Equals("0123456789012345678901234567890123\u0040")); + Assert.assertTrue($noinline$constNonAsciiString3Equals("\u044012")); + Assert.assertFalse($noinline$constNonAsciiString3Equals("\u04401")); + Assert.assertFalse($noinline$constNonAsciiString3Equals("\u0440123")); + Assert.assertFalse($noinline$constNonAsciiString3Equals("\u04401x")); + Assert.assertFalse($noinline$constNonAsciiString3Equals("0123456")); + Assert.assertTrue($noinline$constNonAsciiString7Equals("\u0440123456")); Assert.assertFalse($noinline$constNonAsciiString7Equals("\u044012345")); Assert.assertFalse($noinline$constNonAsciiString7Equals("\u04401234567")); Assert.assertFalse($noinline$constNonAsciiString7Equals("\u044012345x")); Assert.assertFalse($noinline$constNonAsciiString7Equals("0123456")); + Assert.assertTrue($noinline$constNonAsciiString12Equals("\u044012345678901")); + Assert.assertFalse($noinline$constNonAsciiString12Equals("\u04401234567890")); + Assert.assertFalse($noinline$constNonAsciiString12Equals("\u0440123456789012")); + Assert.assertFalse($noinline$constNonAsciiString12Equals("\u04401234567890x")); + Assert.assertFalse($noinline$constNonAsciiString12Equals("012345678901")); + Assert.assertTrue($noinline$constNonAsciiString14Equals("\u04401234567890123")); Assert.assertFalse($noinline$constNonAsciiString14Equals("\u0440123456789012")); Assert.assertFalse($noinline$constNonAsciiString14Equals("\u044012345678901234")); @@ -708,134 +756,138 @@ public class Main { } public static boolean $noinline$equalsConstString0(String s) { - if (doThrow) { throw new Error(); } return s.equals(""); } + public static boolean $noinline$equalsConstString3(String s) { + return s.equals("012"); + } + public static boolean $noinline$equalsConstString7(String s) { - if (doThrow) { throw new Error(); } return s.equals("0123456"); } + public static boolean $noinline$equalsConstString12(String s) { + return s.equals("012345678901"); + } + public static boolean $noinline$equalsConstString14(String s) { - if (doThrow) { throw new Error(); } return s.equals("01234567890123"); } public static boolean $noinline$equalsConstString24(String s) { - if (doThrow) { throw new Error(); } return s.equals("012345678901234567890123"); } public static boolean $noinline$equalsConstString29(String s) { - if (doThrow) { throw new Error(); } return s.equals("01234567890123456789012345678"); } public static boolean $noinline$equalsConstString35(String s) { - if (doThrow) { throw new Error(); } return s.equals("01234567890123456789012345678901234"); } + public static boolean $noinline$equalsConstNonAsciiString3(String s) { + return s.equals("\u044012"); + } + public static boolean $noinline$equalsConstNonAsciiString7(String s) { - if (doThrow) { throw new Error(); } return s.equals("\u0440123456"); } + public static boolean $noinline$equalsConstNonAsciiString12(String s) { + return s.equals("\u044012345678901"); + } + public static boolean $noinline$equalsConstNonAsciiString14(String s) { - if (doThrow) { throw new Error(); } return s.equals("\u04401234567890123"); } public static boolean $noinline$equalsConstNonAsciiString24(String s) { - if (doThrow) { throw new Error(); } return s.equals("\u044012345678901234567890123"); } public static boolean $noinline$equalsConstNonAsciiString29(String s) { - if (doThrow) { throw new Error(); } return s.equals("\u04401234567890123456789012345678"); } public static boolean $noinline$equalsConstNonAsciiString35(String s) { - if (doThrow) { throw new Error(); } return s.equals("\u04401234567890123456789012345678901234"); } public static boolean $noinline$constString0Equals(String s) { - if (doThrow) { throw new Error(); } return s.equals(""); } + public static boolean $noinline$constString3Equals(String s) { + return "012".equals(s); + } + public static boolean $noinline$constString7Equals(String s) { - if (doThrow) { throw new Error(); } return "0123456".equals(s); } + public static boolean $noinline$constString12Equals(String s) { + return "012345678901".equals(s); + } + public static boolean $noinline$constString14Equals(String s) { - if (doThrow) { throw new Error(); } return "01234567890123".equals(s); } public static boolean $noinline$constString24Equals(String s) { - if (doThrow) { throw new Error(); } return "012345678901234567890123".equals(s); } public static boolean $noinline$constString29Equals(String s) { - if (doThrow) { throw new Error(); } return "01234567890123456789012345678".equals(s); } public static boolean $noinline$constString35Equals(String s) { - if (doThrow) { throw new Error(); } return "01234567890123456789012345678901234".equals(s); } + public static boolean $noinline$constNonAsciiString3Equals(String s) { + return "\u044012".equals(s); + } + public static boolean $noinline$constNonAsciiString7Equals(String s) { - if (doThrow) { throw new Error(); } return "\u0440123456".equals(s); } + public static boolean $noinline$constNonAsciiString12Equals(String s) { + return "\u044012345678901".equals(s); + } + public static boolean $noinline$constNonAsciiString14Equals(String s) { - if (doThrow) { throw new Error(); } return "\u04401234567890123".equals(s); } public static boolean $noinline$constNonAsciiString24Equals(String s) { - if (doThrow) { throw new Error(); } return "\u044012345678901234567890123".equals(s); } public static boolean $noinline$constNonAsciiString29Equals(String s) { - if (doThrow) { throw new Error(); } return "\u04401234567890123456789012345678".equals(s); } public static boolean $noinline$constNonAsciiString35Equals(String s) { - if (doThrow) { throw new Error(); } return "\u04401234567890123456789012345678901234".equals(s); } public static int $noinline$compareTo(String lhs, String rhs) { - if (doThrow) { throw new Error(); } return lhs.compareTo(rhs); } public static boolean $noinline$equals(String lhs, String rhs) { - if (doThrow) { throw new Error(); } return lhs.equals(rhs); } public static int $noinline$indexOf(String lhs, int ch) { - if (doThrow) { throw new Error(); } return lhs.indexOf(ch); } public static int $noinline$indexOf(String lhs, int ch, int fromIndex) { - if (doThrow) { throw new Error(); } return lhs.indexOf(ch, fromIndex); } - - public static boolean doThrow = false; } |