diff options
| -rw-r--r-- | runtime/check_jni.cc | 186 |
1 files changed, 172 insertions, 14 deletions
diff --git a/runtime/check_jni.cc b/runtime/check_jni.cc index 4172b89244..b6ad5473ff 100644 --- a/runtime/check_jni.cc +++ b/runtime/check_jni.cc @@ -66,6 +66,8 @@ namespace art { #define kFlag_Invocation 0x8000 // Part of the invocation interface (JavaVM*). #define kFlag_ForceTrace 0x80000000 // Add this to a JNI function's flags if you want to trace every call. + +class VarArgs; /* * Java primitive types: * B - jbyte @@ -126,6 +128,116 @@ union JniValueType { jshort S; const void* V; // void jboolean Z; + const VarArgs* va; +}; + +/* + * A structure containing all the information needed to validate varargs arguments. + * + * Note that actually getting the arguments from this structure mutates it so should only be done on + * owned copies. + */ +class VarArgs { + public: + VarArgs(jmethodID m, va_list var) : m_(m), type_(kTypeVaList), cnt_(0) { + va_copy(vargs_, var); + } + + VarArgs(jmethodID m, const jvalue* vals) : m_(m), type_(kTypePtr), cnt_(0), ptr_(vals) {} + + ~VarArgs() { + if (type_ == kTypeVaList) { + va_end(vargs_); + } + } + + VarArgs(VarArgs&& other) { + m_ = other.m_; + cnt_ = other.cnt_; + type_ = other.type_; + if (other.type_ == kTypeVaList) { + va_copy(vargs_, other.vargs_); + } else { + ptr_ = other.ptr_; + } + } + + // This method is const because we need to ensure that one only uses the GetValue method on an + // owned copy of the VarArgs. This is because getting the next argument from a va_list is a + // mutating operation. Therefore we pass around these VarArgs with the 'const' qualifier and when + // we want to use one we need to Clone() it. + VarArgs Clone() const { + if (type_ == kTypeVaList) { + // const_cast needed to make sure the compiler is okay with va_copy, which (being a macro) is + // messed up if the source argument is not the exact type 'va_list'. + return VarArgs(m_, cnt_, const_cast<VarArgs*>(this)->vargs_); + } else { + return VarArgs(m_, cnt_, ptr_); + } + } + + jmethodID GetMethodID() const { + return m_; + } + + JniValueType GetValue(char fmt) { + JniValueType o; + if (type_ == kTypeVaList) { + switch (fmt) { + case 'Z': o.Z = static_cast<jboolean>(va_arg(vargs_, jint)); break; + case 'B': o.B = static_cast<jbyte>(va_arg(vargs_, jint)); break; + case 'C': o.C = static_cast<jchar>(va_arg(vargs_, jint)); break; + case 'S': o.S = static_cast<jshort>(va_arg(vargs_, jint)); break; + case 'I': o.I = va_arg(vargs_, jint); break; + case 'J': o.J = va_arg(vargs_, jlong); break; + case 'F': o.F = static_cast<jfloat>(va_arg(vargs_, jdouble)); break; + case 'D': o.D = va_arg(vargs_, jdouble); break; + case 'L': o.L = va_arg(vargs_, jobject); break; + default: + LOG(FATAL) << "Illegal type format char " << fmt; + UNREACHABLE(); + } + } else { + CHECK(type_ == kTypePtr); + jvalue v = ptr_[cnt_]; + cnt_++; + switch (fmt) { + case 'Z': o.Z = v.z; break; + case 'B': o.B = v.b; break; + case 'C': o.C = v.c; break; + case 'S': o.S = v.s; break; + case 'I': o.I = v.i; break; + case 'J': o.J = v.j; break; + case 'F': o.F = v.f; break; + case 'D': o.D = v.d; break; + case 'L': o.L = v.l; break; + default: + LOG(FATAL) << "Illegal type format char " << fmt; + UNREACHABLE(); + } + } + return o; + } + + private: + VarArgs(jmethodID m, uint32_t cnt, va_list var) : m_(m), type_(kTypeVaList), cnt_(cnt) { + va_copy(vargs_, var); + } + + VarArgs(jmethodID m, uint32_t cnt, const jvalue* vals) : m_(m), type_(kTypePtr), cnt_(cnt), ptr_(vals) {} + + enum VarArgsType { + kTypeVaList, + kTypePtr, + }; + + jmethodID m_; + VarArgsType type_; + uint32_t cnt_; + union { + va_list vargs_; + const jvalue* ptr_; + }; }; class ScopedCheck { @@ -339,7 +451,7 @@ class ScopedCheck { * z - jsize (for lengths; use i if negative values are okay) * v - JavaVM* * E - JNIEnv* - * . - no argument; just print "..." (used for varargs JNI calls) + * . - VarArgs* for Jni calls with variable length arguments * * Use the kFlag_NullableUtf flag where 'u' field(s) are nullable. */ @@ -736,11 +848,35 @@ class ScopedCheck { return CheckThread(arg.E); case 'L': // jobject return CheckInstance(soa, kObject, arg.L, true); + case '.': // A VarArgs list + return CheckVarArgs(soa, arg.va); default: return CheckNonHeapValue(fmt, arg); } } + bool CheckVarArgs(ScopedObjectAccess& soa, const VarArgs* args_p) + SHARED_REQUIRES(Locks::mutator_lock_) { + CHECK(args_p != nullptr); + VarArgs args(args_p->Clone()); + ArtMethod* m = CheckMethodID(soa, args.GetMethodID()); + if (m == nullptr) { + return false; + } + uint32_t len = 0; + const char* shorty = m->GetShorty(&len); + // Skip the return type + CHECK_GE(len, 1u); + len--; + shorty++; + for (uint32_t i = 0; i < len; i++) { + if (!CheckPossibleHeapValue(soa, shorty[i], args.GetValue(shorty[i]))) { + return false; + } + } + return true; + } + bool CheckNonHeapValue(char fmt, JniValueType arg) { switch (fmt) { case 'p': // TODO: pointer - null or readable? @@ -833,6 +969,24 @@ class ScopedCheck { } break; } + case '.': { + const VarArgs* va = arg.va; + VarArgs args(va->Clone()); + ArtMethod* m = soa.DecodeMethod(args.GetMethodID()); + uint32_t len; + const char* shorty = m->GetShorty(&len); + CHECK_GE(len, 1u); + // Skip past return value. + len--; + shorty++; + // Remove the previous ', ' from the message. + msg->erase(msg->length() - 2); + for (uint32_t i = 0; i < len; i++) { + *msg += ", "; + TracePossibleHeapValue(soa, entry, shorty[i], args.GetValue(shorty[i]), msg); + } + break; + } default: TraceNonHeapValue(fmt, arg, msg); break; @@ -1836,8 +1990,9 @@ class CheckJNI { static jobject NewObjectV(JNIEnv* env, jclass c, jmethodID mid, va_list vargs) { ScopedObjectAccess soa(env); ScopedCheck sc(kFlag_Default, __FUNCTION__); - JniValueType args[3] = {{.E = env}, {.c = c}, {.m = mid}}; - if (sc.Check(soa, true, "Ecm", args) && sc.CheckInstantiableNonArray(soa, c) && + VarArgs rest(mid, vargs); + JniValueType args[4] = {{.E = env}, {.c = c}, {.m = mid}, {.va = &rest}}; + if (sc.Check(soa, true, "Ecm.", args) && sc.CheckInstantiableNonArray(soa, c) && sc.CheckConstructor(soa, mid)) { JniValueType result; result.L = baseEnv(env)->NewObjectV(env, c, mid, vargs); @@ -1859,8 +2014,9 @@ class CheckJNI { static jobject NewObjectA(JNIEnv* env, jclass c, jmethodID mid, jvalue* vargs) { ScopedObjectAccess soa(env); ScopedCheck sc(kFlag_Default, __FUNCTION__); - JniValueType args[3] = {{.E = env}, {.c = c}, {.m = mid}}; - if (sc.Check(soa, true, "Ecm", args) && sc.CheckInstantiableNonArray(soa, c) && + VarArgs rest(mid, vargs); + JniValueType args[4] = {{.E = env}, {.c = c}, {.m = mid}, {.va = &rest}}; + if (sc.Check(soa, true, "Ecm.", args) && sc.CheckInstantiableNonArray(soa, c) && sc.CheckConstructor(soa, mid)) { JniValueType result; result.L = baseEnv(env)->NewObjectA(env, c, mid, vargs); @@ -2689,25 +2845,25 @@ class CheckJNI { } static bool CheckCallArgs(ScopedObjectAccess& soa, ScopedCheck& sc, JNIEnv* env, jobject obj, - jclass c, jmethodID mid, InvokeType invoke) + jclass c, jmethodID mid, InvokeType invoke, const VarArgs* vargs) SHARED_REQUIRES(Locks::mutator_lock_) { bool checked; switch (invoke) { case kVirtual: { DCHECK(c == nullptr); - JniValueType args[3] = {{.E = env}, {.L = obj}, {.m = mid}}; - checked = sc.Check(soa, true, "ELm", args); + JniValueType args[4] = {{.E = env}, {.L = obj}, {.m = mid}, {.va = vargs}}; + checked = sc.Check(soa, true, "ELm.", args); break; } case kDirect: { - JniValueType args[4] = {{.E = env}, {.L = obj}, {.c = c}, {.m = mid}}; - checked = sc.Check(soa, true, "ELcm", args); + JniValueType args[5] = {{.E = env}, {.L = obj}, {.c = c}, {.m = mid}, {.va = vargs}}; + checked = sc.Check(soa, true, "ELcm.", args); break; } case kStatic: { DCHECK(obj == nullptr); - JniValueType args[3] = {{.E = env}, {.c = c}, {.m = mid}}; - checked = sc.Check(soa, true, "Ecm", args); + JniValueType args[4] = {{.E = env}, {.c = c}, {.m = mid}, {.va = vargs}}; + checked = sc.Check(soa, true, "Ecm.", args); break; } default: @@ -2724,7 +2880,8 @@ class CheckJNI { ScopedObjectAccess soa(env); ScopedCheck sc(kFlag_Default, function_name); JniValueType result; - if (CheckCallArgs(soa, sc, env, obj, c, mid, invoke) && + VarArgs rest(mid, vargs); + if (CheckCallArgs(soa, sc, env, obj, c, mid, invoke, &rest) && sc.CheckMethodAndSig(soa, obj, c, mid, type, invoke)) { const char* result_check; switch (type) { @@ -2907,7 +3064,8 @@ class CheckJNI { ScopedObjectAccess soa(env); ScopedCheck sc(kFlag_Default, function_name); JniValueType result; - if (CheckCallArgs(soa, sc, env, obj, c, mid, invoke) && + VarArgs rest(mid, vargs); + if (CheckCallArgs(soa, sc, env, obj, c, mid, invoke, &rest) && sc.CheckMethodAndSig(soa, obj, c, mid, type, invoke)) { const char* result_check; switch (type) { |