diff options
-rw-r--r-- | compiler/optimizing/code_generator_x86_64.cc | 1 | ||||
-rw-r--r-- | compiler/optimizing/code_generator_x86_64.h | 2 | ||||
-rw-r--r-- | compiler/optimizing/instruction_builder.cc | 10 | ||||
-rw-r--r-- | compiler/optimizing/intrinsics_x86_64.cc | 27 | ||||
-rw-r--r-- | compiler/optimizing/nodes.h | 3 | ||||
-rw-r--r-- | test/2277-methodhandle-invokeexact/expected-stdout.txt | 28 | ||||
-rw-r--r-- | test/2277-methodhandle-invokeexact/src-multidex/Multi.java | 9 | ||||
-rw-r--r-- | test/2277-methodhandle-invokeexact/src/Main.java | 104 |
8 files changed, 105 insertions, 79 deletions
diff --git a/compiler/optimizing/code_generator_x86_64.cc b/compiler/optimizing/code_generator_x86_64.cc index e1c3c9f426..6a7f9b1264 100644 --- a/compiler/optimizing/code_generator_x86_64.cc +++ b/compiler/optimizing/code_generator_x86_64.cc @@ -57,7 +57,6 @@ class GcRoot; namespace x86_64 { static constexpr int kCurrentMethodStackOffset = 0; -static constexpr Register kMethodRegisterArgument = RDI; // The compare/jump sequence will generate about (1.5 * num_entries) instructions. A jump // table version generates 7 instructions and num_entries literals. Compare/jump sequence will // generates less code/data with a small num_entries. diff --git a/compiler/optimizing/code_generator_x86_64.h b/compiler/optimizing/code_generator_x86_64.h index ad4a60e091..38758148f1 100644 --- a/compiler/optimizing/code_generator_x86_64.h +++ b/compiler/optimizing/code_generator_x86_64.h @@ -28,6 +28,8 @@ namespace art HIDDEN { namespace x86_64 { +static constexpr Register kMethodRegisterArgument = RDI; + // Use a local definition to prevent copying mistakes. static constexpr size_t kX86_64WordSize = static_cast<size_t>(kX86_64PointerSize); diff --git a/compiler/optimizing/instruction_builder.cc b/compiler/optimizing/instruction_builder.cc index d7553dd14f..356322e85b 100644 --- a/compiler/optimizing/instruction_builder.cc +++ b/compiler/optimizing/instruction_builder.cc @@ -1391,9 +1391,13 @@ bool HInstructionBuilder::BuildInvokePolymorphic(uint32_t dex_pc, MethodReference method_reference(&graph_->GetDexFile(), method_idx); + // MethodHandle.invokeExact intrinsic needs to check whether call-site matches with MethodHandle's + // type. To do that, MethodType corresponding to the call-site is passed as an extra input. + // Other invoke-polymorphic calls do not need it. bool is_invoke_exact = static_cast<Intrinsics>(resolved_method->GetIntrinsic()) == Intrinsics::kMethodHandleInvokeExact; + // Currently intrinsic works for MethodHandle targeting invoke-virtual calls only. bool can_be_virtual = number_of_arguments >= 2 && DataType::FromShorty(shorty[1]) == DataType::Type::kReference; @@ -1414,7 +1418,7 @@ bool HInstructionBuilder::BuildInvokePolymorphic(uint32_t dex_pc, return false; } - DCHECK_EQ(invoke->AsInvokePolymorphic()->CanTargetInvokeVirtual(), can_be_intrinsified); + DCHECK_EQ(invoke->AsInvokePolymorphic()->CanHaveFastPath(), can_be_intrinsified); if (invoke->GetIntrinsic() != Intrinsics::kNone && invoke->GetIntrinsic() != Intrinsics::kMethodHandleInvoke && @@ -1896,7 +1900,9 @@ bool HInstructionBuilder::SetupInvokeArguments(HInstruction* invoke, if (invoke->IsInvokePolymorphic()) { HInvokePolymorphic* invoke_polymorphic = invoke->AsInvokePolymorphic(); - if (invoke_polymorphic->CanTargetInvokeVirtual()) { + // MethodHandle.invokeExact intrinsic expects MethodType corresponding to the call-site as an + // extra input to determine whether to throw WrongMethodTypeException or execute target method. + if (invoke_polymorphic->CanHaveFastPath()) { HLoadMethodType* load_method_type = new (allocator_) HLoadMethodType(graph_->GetCurrentMethod(), invoke_polymorphic->GetProtoIndex(), diff --git a/compiler/optimizing/intrinsics_x86_64.cc b/compiler/optimizing/intrinsics_x86_64.cc index 2c9272d403..d085d2c469 100644 --- a/compiler/optimizing/intrinsics_x86_64.cc +++ b/compiler/optimizing/intrinsics_x86_64.cc @@ -145,10 +145,12 @@ class ReadBarrierSystemArrayCopySlowPathX86_64 : public SlowPathCode { DISALLOW_COPY_AND_ASSIGN(ReadBarrierSystemArrayCopySlowPathX86_64); }; -// invoke-polymorphic's slow-path which does not move arguments. +// The MethodHandle.invokeExact intrinsic sets up arguments to match the target method call. If we +// need to go to the slow path, we call art_quick_invoke_polymorphic_with_hidden_receiver, which +// expects the MethodHandle object in RDI (in place of the actual ArtMethod). class InvokePolymorphicSlowPathX86_64 : public SlowPathCode { public: - explicit InvokePolymorphicSlowPathX86_64(HInstruction* instruction, CpuRegister method_handle) + InvokePolymorphicSlowPathX86_64(HInstruction* instruction, CpuRegister method_handle) : SlowPathCode(instruction), method_handle_(method_handle) { DCHECK(instruction->IsInvokePolymorphic()); } @@ -159,6 +161,7 @@ class InvokePolymorphicSlowPathX86_64 : public SlowPathCode { __ Bind(GetEntryLabel()); SaveLiveRegisters(codegen, instruction_->GetLocations()); + // Passing `MethodHandle` object as hidden argument. __ movq(CpuRegister(RDI), method_handle_); x86_64_codegen->InvokeRuntime(QuickEntrypointEnum::kQuickInvokePolymorphicWithHiddenReceiver, instruction_, @@ -4099,7 +4102,7 @@ void IntrinsicLocationsBuilderX86_64::VisitMethodHandleInvokeExact(HInvoke* invo // Don't emit intrinsic code for MethodHandle.invokeExact when it certainly does not target // invoke-virtual: if invokeExact is called w/o arguments or if the first argument in that // call is not a reference. - if (!invoke->AsInvokePolymorphic()->CanTargetInvokeVirtual()) { + if (!invoke->AsInvokePolymorphic()->CanHaveFastPath()) { return; } ArenaAllocator* allocator = invoke->GetBlock()->GetGraph()->GetAllocator(); @@ -4120,13 +4123,11 @@ void IntrinsicLocationsBuilderX86_64::VisitMethodHandleInvokeExact(HInvoke* invo // The last input is MethodType object corresponding to the call-site. locations->SetInAt(number_of_args, Location::RequiresRegister()); - // We use a fixed-register temporary to pass the target method. - locations->AddTemp(calling_convention.GetMethodLocation()); locations->AddTemp(Location::RequiresRegister()); } void IntrinsicCodeGeneratorX86_64::VisitMethodHandleInvokeExact(HInvoke* invoke) { - DCHECK(invoke->AsInvokePolymorphic()->CanTargetInvokeVirtual()); + DCHECK(invoke->AsInvokePolymorphic()->CanHaveFastPath()); LocationSummary* locations = invoke->GetLocations(); CpuRegister method_handle = locations->InAt(0).AsRegister<CpuRegister>(); @@ -4139,9 +4140,10 @@ void IntrinsicCodeGeneratorX86_64::VisitMethodHandleInvokeExact(HInvoke* invoke) Address method_handle_kind = Address(method_handle, mirror::MethodHandle::HandleKindOffset()); // If it is not InvokeVirtual then go to slow path. - // Even if MethodHandle's kind is kInvokeVirtual underlying method still can be an interface or - // direct method (that's what current `MethodHandles$Lookup.findVirtual` is doing). We don't check - // whether `method` is an interface method explicitly: in that case the subtype check will fail. + // Even if MethodHandle's kind is kInvokeVirtual, the underlying method can still be an interface + // or a direct method (that's what current `MethodHandles$Lookup.findVirtual` is doing). We don't + // check whether `method` is an interface method explicitly: in that case the subtype check below + // will fail. // TODO(b/297147201): check whether it can be more precise and what d8/r8 can produce. __ cmpl(method_handle_kind, Immediate(mirror::MethodHandle::Kind::kInvokeVirtual)); __ j(kNotEqual, slow_path->GetEntryLabel()); @@ -4153,16 +4155,17 @@ void IntrinsicCodeGeneratorX86_64::VisitMethodHandleInvokeExact(HInvoke* invoke) __ cmpl(call_site_type, Address(method_handle, mirror::MethodHandle::MethodTypeOffset())); __ j(kNotEqual, slow_path->GetEntryLabel()); - CpuRegister method = locations->GetTemp(0).AsRegister<CpuRegister>(); + CpuRegister method = CpuRegister(kMethodRegisterArgument); - // Find method to call. + // Get method to call. __ movq(method, Address(method_handle, mirror::MethodHandle::ArtFieldOrMethodOffset())); CpuRegister receiver = locations->InAt(1).AsRegister<CpuRegister>(); // Using vtable_index register as temporary in subtype check. It will be overridden later. // If `method` is an interface method this check will fail. - CpuRegister vtable_index = locations->GetTemp(1).AsRegister<CpuRegister>(); + CpuRegister vtable_index = locations->GetTemp(0).AsRegister<CpuRegister>(); + // We deliberately avoid the read barrier, letting the slow path handle the false negatives. GenerateSubTypeObjectCheckNoReadBarrier(codegen_, slow_path, receiver, diff --git a/compiler/optimizing/nodes.h b/compiler/optimizing/nodes.h index ffddd25843..eb6d9ecad4 100644 --- a/compiler/optimizing/nodes.h +++ b/compiler/optimizing/nodes.h @@ -4939,7 +4939,8 @@ class HInvokePolymorphic final : public HInvoke { dex::ProtoIndex GetProtoIndex() { return proto_idx_; } - bool CanTargetInvokeVirtual() const { + // Whether we can do direct invocation of the method handle. + bool CanHaveFastPath() const { return GetIntrinsic() == Intrinsics::kMethodHandleInvokeExact && GetNumberOfArguments() >= 2 && InputAt(1)->GetType() == DataType::Type::kReference; diff --git a/test/2277-methodhandle-invokeexact/expected-stdout.txt b/test/2277-methodhandle-invokeexact/expected-stdout.txt index d924f9903d..e69de29bb2 100644 --- a/test/2277-methodhandle-invokeexact/expected-stdout.txt +++ b/test/2277-methodhandle-invokeexact/expected-stdout.txt @@ -1,28 +0,0 @@ -in voidMethod -A.returnInt()=42 -A.returnDouble()=42.0 -in I.defaultMethod -in A.overrideMe -I'm private interface method -A.privateReturnInt()=1042 -B.privateReturnInt()=9999 -((B) new A()).privateReturnInt()=9999 -I am from throwException -A.staticMethod()=staticMethod -Multi.$noinline$getMethodHandle().invokeExact(nonEmpty)=hello -Multi: mh.invokeExact(nonEmpty)=1001 -Sums.sum(1)=1 -Sums.sum(1, 2)=3 -Sums.sum(1, 2, 3)=6 -Sums.sum(1, 2, 3, 4)=10 -Sums.sum(1, 2, 3, 4, 5)=15 -Sums.sum(1, 2, 3, 4, 5, 6)=21 -Sums.sum(1, 2, 3, 4, 5, 6, 7)=28 -Sums.sum(1, 2, 3, 4, 5, 6, 7, 8)=36 -Sums.sum(1, 2, 3, 4, 5, 6, 7, 8, 9)=45 -Sums.sum(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)=55 -Sums.sum(1, 2L)=3 -Sums.sum(1, 2L, 3, 4L)=10 -Sums.sum(1, 2L, 3, 4L, 5, 6L)=21 -Sums.sum(1, 2L, 3, 4L, 5, 6L, 7, 8L)=36 -Sums.sum(1, 2L, 3, 4L, 5, 6L, 7, 8L, 9, 10L)=55 diff --git a/test/2277-methodhandle-invokeexact/src-multidex/Multi.java b/test/2277-methodhandle-invokeexact/src-multidex/Multi.java index 43601a947e..5f2dcf0093 100644 --- a/test/2277-methodhandle-invokeexact/src-multidex/Multi.java +++ b/test/2277-methodhandle-invokeexact/src-multidex/Multi.java @@ -18,6 +18,7 @@ import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; import java.lang.invoke.WrongMethodTypeException; +import java.util.Objects; import java.util.Optional; public class Multi { @@ -29,7 +30,7 @@ public class Multi { public static void $noinline$testMHFromMain(MethodHandle mh) throws Throwable { Optional<Integer> nonEmpty = Optional.<Integer>of(1001); Object result = (Object) mh.invokeExact(nonEmpty); - System.out.println("Multi: mh.invokeExact(nonEmpty)=" + result); + assertEquals("Expected 1001, but got " + result, 1001, result); try { mh.invokeExact(nonEmpty); @@ -37,6 +38,12 @@ public class Multi { } catch (WrongMethodTypeException expected) {} } + private static void assertEquals(String msg, Object expected, Object actual) { + if (!Objects.equals(expected, actual)) { + fail(msg); + } + } + private static void fail(String msg) { throw new AssertionError(msg); } diff --git a/test/2277-methodhandle-invokeexact/src/Main.java b/test/2277-methodhandle-invokeexact/src/Main.java index 729857e47a..9a52885c8f 100644 --- a/test/2277-methodhandle-invokeexact/src/Main.java +++ b/test/2277-methodhandle-invokeexact/src/Main.java @@ -20,10 +20,13 @@ import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodHandle; import java.lang.invoke.WrongMethodTypeException; import java.util.Arrays; +import java.util.Objects; import java.util.Optional; public class Main { + private static String STATUS = ""; + public static void main(String[] args) throws Throwable { $noinline$testNoArgsCalls(); $noinline$testMethodHandleFromOtherDex(); @@ -35,8 +38,7 @@ public class Main { MethodHandle mh = Multi.$noinline$getMethodHandle(); Optional<String> nonEmpty = Optional.<String>of("hello"); Object returnedObject = mh.invokeExact(nonEmpty); - System.out.println( - "Multi.$noinline$getMethodHandle().invokeExact(nonEmpty)=" + returnedObject); + assertEquals("hello", returnedObject); try { mh.invokeExact(nonEmpty); @@ -46,10 +48,13 @@ public class Main { private static void $noinline$testNoArgsCalls() throws Throwable { VOID_METHOD.invokeExact(new A()); + assertEquals("A.voidMethod", STATUS); + int returnedInt = (int) RETURN_INT.invokeExact(new A()); - System.out.println("A.returnInt()=" + returnedInt); + assertEquals(42, returnedInt); + double returnedDouble = (double) RETURN_DOUBLE.invokeExact(new A()); - System.out.println("A.returnDouble()=" + returnedDouble); + assertEquals(42.0d, returnedDouble); try { INTERFACE_DEFAULT_METHOD.invokeExact(new A()); @@ -57,23 +62,27 @@ public class Main { } catch (WrongMethodTypeException ignored) {} INTERFACE_DEFAULT_METHOD.invokeExact((I) new A()); + assertEquals("I.defaultMethod", STATUS); + OVERWRITTEN_INTERFACE_DEFAULT_METHOD.invokeExact((I) new A()); + assertEquals("A.overrideMe", STATUS); - System.out.println((String) PRIVATE_INTERFACE_METHOD.invokeExact((I) new A())); + assertEquals("boo", (String) PRIVATE_INTERFACE_METHOD.invokeExact((I) new A())); int privateIntA = (int) A_PRIVATE_RETURN_INT.invokeExact(new A()); - System.out.println("A.privateReturnInt()=" + privateIntA); + assertEquals(1042, privateIntA); int privateIntB = (int) B_PRIVATE_RETURN_INT.invokeExact(new B()); - System.out.println("B.privateReturnInt()=" + privateIntB); + assertEquals(9999, privateIntB); + privateIntB = (int) B_PRIVATE_RETURN_INT.invokeExact((B) new A()); - System.out.println("((B) new A()).privateReturnInt()=" + privateIntB); + assertEquals(9999, privateIntB); try { EXCEPTION_THROWING_METHOD.invokeExact(new A()); unreachable("Target method always throws"); - } catch (RuntimeException e) { - System.out.println(e.getMessage()); + } catch (MyRuntimeException expected) { + assertEquals("A.throwException", STATUS); } try { @@ -82,54 +91,78 @@ public class Main { } catch (WrongMethodTypeException ignored) {} String returnedString = (String) STATIC_METHOD.invokeExact(new A()); - System.out.println("A.staticMethod()=" + returnedString); + assertEquals("staticMethod", returnedString); } private static void $noinline$testWithArgs() throws Throwable { int sum = (int) SUM_I.invokeExact(new Sums(), 1); - System.out.println("Sums.sum(1)=" + sum); + assertEquals(1, sum); sum = (int) SUM_2I.invokeExact(new Sums(), 1, 2); - System.out.println("Sums.sum(1, 2)=" + sum); + assertEquals(3, sum); sum = (int) SUM_3I.invokeExact(new Sums(), 1, 2, 3); - System.out.println("Sums.sum(1, 2, 3)=" + sum); + assertEquals(6, sum); sum = (int) SUM_4I.invokeExact(new Sums(), 1, 2, 3, 4); - System.out.println("Sums.sum(1, 2, 3, 4)=" + sum); + assertEquals(10, sum); sum = (int) SUM_5I.invokeExact(new Sums(), 1, 2, 3, 4, 5); - System.out.println("Sums.sum(1, 2, 3, 4, 5)=" + sum); + assertEquals(15, sum); sum = (int) SUM_6I.invokeExact(new Sums(), 1, 2, 3, 4, 5, 6); - System.out.println("Sums.sum(1, 2, 3, 4, 5, 6)=" + sum); + assertEquals(21, sum); sum = (int) SUM_7I.invokeExact(new Sums(), 1, 2, 3, 4, 5, 6, 7); - System.out.println("Sums.sum(1, 2, 3, 4, 5, 6, 7)=" + sum); + assertEquals(28, sum); sum = (int) SUM_8I.invokeExact(new Sums(), 1, 2, 3, 4, 5, 6, 7, 8); - System.out.println("Sums.sum(1, 2, 3, 4, 5, 6, 7, 8)=" + sum); + assertEquals(36, sum); sum = (int) SUM_9I.invokeExact(new Sums(), 1, 2, 3, 4, 5, 6, 7, 8, 9); - System.out.println("Sums.sum(1, 2, 3, 4, 5, 6, 7, 8, 9)=" + sum); + assertEquals(45, sum); sum = (int) SUM_10I.invokeExact(new Sums(), 1, 2, 3, 4, 5, 6, 7, 8, 9, 10); - System.out.println("Sums.sum(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)=" + sum); + assertEquals(55, sum); long lsum = (long) SUM_IJ.invokeExact(new Sums(), 1, 2L); - System.out.println("Sums.sum(1, 2L)=" + lsum); + assertEquals(3L, lsum); lsum = (long) SUM_2IJ.invokeExact(new Sums(), 1, 2L, 3, 4L); - System.out.println("Sums.sum(1, 2L, 3, 4L)=" + lsum); + assertEquals(10L, lsum); lsum = (long) SUM_3IJ.invokeExact(new Sums(), 1, 2L, 3, 4L, 5, 6L); - System.out.println("Sums.sum(1, 2L, 3, 4L, 5, 6L)=" + lsum); + assertEquals(21L, lsum); lsum = (long) SUM_4IJ.invokeExact(new Sums(), 1, 2L, 3, 4L, 5, 6L, 7, 8L); - System.out.println("Sums.sum(1, 2L, 3, 4L, 5, 6L, 7, 8L)=" + lsum); + assertEquals(36L, lsum); lsum = (long) SUM_5IJ.invokeExact(new Sums(), 1, 2L, 3, 4L, 5, 6L, 7, 8L, 9, 10L); - System.out.println("Sums.sum(1, 2L, 3, 4L, 5, 6L, 7, 8L, 9, 10L)=" + lsum); + assertEquals(55L, lsum); + } + + private static void assertEquals(Object expected, Object actual) { + if (!Objects.equals(expected, actual)) { + throw new AssertionError("Expected: " + expected + ", got: " + actual); + } + } + + private static void assertEquals(int expected, int actual) { + if (expected != actual) { + throw new AssertionError("Expected: " + expected + ", got: " + actual); + } + } + + private static void assertEquals(long expected, long actual) { + if (expected != actual) { + throw new AssertionError("Expected: " + expected + ", got: " + actual); + } + } + + private static void assertEquals(double expected, double actual) { + if (expected != actual) { + throw new AssertionError("Expected: " + expected + ", got: " + actual); + } } private static void unreachable(String msg) { @@ -251,7 +284,7 @@ public class Main { static interface I { public default void defaultMethod() { - System.out.println("in I.defaultMethod"); + STATUS = "I.defaultMethod"; } public default void overrideMe() { @@ -259,23 +292,26 @@ public class Main { } private String innerPrivateMethod() { - return "I'm private interface method"; + return "boo"; } } + static class MyRuntimeException extends RuntimeException {} + static class A extends B implements I { public int field; public void voidMethod() { - System.out.println("in voidMethod"); - } - - public void throwException() { - throw new RuntimeException("I am from throwException"); + STATUS = "A.voidMethod"; } @Override public void overrideMe() { - System.out.println("in A.overrideMe"); + STATUS = "A.overrideMe"; + } + + public void throwException() { + STATUS = "A.throwException"; + throw new MyRuntimeException(); } public double returnDouble() { |