diff options
| author | 2023-10-19 17:59:10 +0000 | |
|---|---|---|
| committer | 2023-10-19 17:59:10 +0000 | |
| commit | 33be5720ee21ff68dbc22bca8842e33c02f60cd0 (patch) | |
| tree | 2ad91ea5451e5858739cc94ab95bbbc1ca9168b9 /libs/input | |
| parent | e96bb74514b703041fc336e03b0c9e84439249c3 (diff) | |
| parent | 5d53fcb96b3e837d91282d31b0819aa6412a8c9f (diff) | |
Merge "Merge 10952656" into aosp-main-future
Diffstat (limited to 'libs/input')
| -rw-r--r-- | libs/input/Android.bp | 212 | ||||
| -rw-r--r-- | libs/input/FromRustToCpp.cpp | 26 | ||||
| -rw-r--r-- | libs/input/InputEventLabels.cpp | 40 | ||||
| -rw-r--r-- | libs/input/InputTransport.cpp | 25 | ||||
| -rw-r--r-- | libs/input/InputVerifier.cpp | 118 | ||||
| -rw-r--r-- | libs/input/InputWrapper.hpp | 18 | ||||
| -rw-r--r-- | libs/input/MotionPredictor.cpp | 49 | ||||
| -rw-r--r-- | libs/input/MotionPredictorMetricsManager.cpp | 373 | ||||
| -rw-r--r-- | libs/input/TfLiteMotionPredictor.cpp | 68 | ||||
| -rw-r--r-- | libs/input/VelocityTracker.cpp | 23 | ||||
| -rw-r--r-- | libs/input/ffi/FromRustToCpp.h | 23 | ||||
| -rw-r--r-- | libs/input/input_verifier.rs | 422 | ||||
| -rw-r--r-- | libs/input/tests/Android.bp | 21 | ||||
| -rw-r--r-- | libs/input/tests/InputVerifier_test.cpp | 29 | ||||
| -rw-r--r-- | libs/input/tests/MotionPredictorMetricsManager_test.cpp | 972 | ||||
| -rw-r--r-- | libs/input/tests/MotionPredictor_test.cpp | 11 | ||||
| -rw-r--r-- | libs/input/tests/VelocityTracker_test.cpp | 5 |
17 files changed, 2264 insertions, 171 deletions
diff --git a/libs/input/Android.bp b/libs/input/Android.bp index 869458c407..022dfaddc1 100644 --- a/libs/input/Android.bp +++ b/libs/input/Android.bp @@ -33,6 +33,138 @@ filegroup { ], } +aidl_interface { + name: "inputconstants", + host_supported: true, + vendor_available: true, + unstable: true, + srcs: [ + ":inputconstants_aidl", + ], + + backend: { + rust: { + enabled: true, + }, + }, +} + +rust_bindgen { + name: "libinput_bindgen", + host_supported: true, + crate_name: "input_bindgen", + visibility: ["//frameworks/native/services/inputflinger"], + wrapper_src: "InputWrapper.hpp", + + include_dirs: [ + "frameworks/native/include", + ], + + source_stem: "bindings", + + bindgen_flags: [ + "--verbose", + "--allowlist-var=AMOTION_EVENT_FLAG_CANCELED", + "--allowlist-var=AMOTION_EVENT_ACTION_CANCEL", + "--allowlist-var=AMOTION_EVENT_ACTION_UP", + "--allowlist-var=AMOTION_EVENT_ACTION_POINTER_DOWN", + "--allowlist-var=AMOTION_EVENT_ACTION_DOWN", + "--allowlist-var=AMOTION_EVENT_ACTION_POINTER_INDEX_SHIFT", + "--allowlist-var=MAX_POINTER_ID", + ], + + static_libs: [ + "inputconstants-cpp", + "libui-types", + ], + shared_libs: ["libc++"], + header_libs: [ + "native_headers", + "jni_headers", + "flatbuffer_headers", + ], +} + +// Contains methods to help access C++ code from rust +cc_library_static { + name: "libinput_from_rust_to_cpp", + cpp_std: "c++20", + host_supported: true, + cflags: [ + "-Wall", + "-Wextra", + "-Werror", + ], + srcs: [ + "FromRustToCpp.cpp", + ], + + generated_headers: [ + "cxx-bridge-header", + ], + generated_sources: ["libinput_cxx_bridge_code"], + + shared_libs: [ + "libbase", + ], +} + +genrule { + name: "libinput_cxx_bridge_code", + tools: ["cxxbridge"], + cmd: "$(location cxxbridge) $(in) >> $(out)", + srcs: ["input_verifier.rs"], + out: ["inputverifier_generated.cpp"], +} + +genrule { + name: "libinput_cxx_bridge_header", + tools: ["cxxbridge"], + cmd: "$(location cxxbridge) $(in) --header >> $(out)", + srcs: ["input_verifier.rs"], + out: ["input_verifier.rs.h"], +} + +rust_defaults { + name: "libinput_rust_defaults", + srcs: ["input_verifier.rs"], + host_supported: true, + rustlibs: [ + "libbitflags", + "libcxx", + "libinput_bindgen", + "liblogger", + "liblog_rust", + "inputconstants-rust", + ], + + shared_libs: [ + "libbase", + "liblog", + ], +} + +rust_ffi_static { + name: "libinput_rust", + crate_name: "input", + defaults: ["libinput_rust_defaults"], +} + +rust_test { + name: "libinput_rust_test", + defaults: ["libinput_rust_defaults"], + whole_static_libs: [ + "libinput_from_rust_to_cpp", + ], + test_options: { + unit_test: true, + }, + test_suites: ["device_tests"], + sanitize: { + hwaddress: true, + }, +} + cc_library { name: "libinput", cpp_std: "c++20", @@ -44,6 +176,7 @@ cc_library { "-Wno-unused-parameter", ], srcs: [ + "FromRustToCpp.cpp", "Input.cpp", "InputDevice.cpp", "InputEventLabels.cpp", @@ -52,6 +185,7 @@ cc_library { "KeyCharacterMap.cpp", "KeyLayoutMap.cpp", "MotionPredictor.cpp", + "MotionPredictorMetricsManager.cpp", "PrintTools.cpp", "PropertyMap.cpp", "TfLiteMotionPredictor.cpp", @@ -65,19 +199,28 @@ cc_library { header_libs: [ "flatbuffer_headers", "jni_headers", + "libeigen", "tensorflow_headers", ], - export_header_lib_headers: ["jni_headers"], + export_header_lib_headers: [ + "jni_headers", + "libeigen", + ], generated_headers: [ + "cxx-bridge-header", + "libinput_cxx_bridge_header", "toolbox_input_labels", ], + generated_sources: ["libinput_cxx_bridge_code"], + shared_libs: [ "libbase", "libcutils", "liblog", "libPlatformProperties", + "libtinyxml2", "libvintf", ], @@ -85,21 +228,36 @@ cc_library { "-Wl,--exclude-libs=libtflite_static.a", ], + sanitize: { + undefined: true, + all_undefined: true, + misc_undefined: ["integer"], + }, + static_libs: [ + "inputconstants-cpp", "libui-types", "libtflite_static", ], + whole_static_libs: [ + "libinput_rust", + ], + export_static_lib_headers: [ "libui-types", ], + export_generated_headers: [ + "cxx-bridge-header", + "libinput_cxx_bridge_header", + ], + target: { android: { srcs: [ "InputTransport.cpp", "android/os/IInputFlinger.aidl", - ":inputconstants_aidl", ], export_shared_lib_headers: ["libbinder"], @@ -107,6 +265,10 @@ cc_library { shared_libs: [ "libutils", "libbinder", + // Stats logging library and its dependencies. + "libstatslog_libinput", + "libstatsbootstrap", + "android.os.statsbootstrap_aidl-cpp", ], static_libs: [ @@ -117,12 +279,9 @@ cc_library { "libgui_window_info_static", ], - sanitize: { - misc_undefined: ["integer"], - }, - required: [ "motion_predictor_model_prebuilt", + "motion_predictor_model_config", ], }, host: { @@ -138,12 +297,8 @@ cc_library { host_linux: { srcs: [ "InputTransport.cpp", - "android/os/IInputConstants.aidl", - "android/os/IInputFlinger.aidl", - "android/os/InputConfig.aidl", ], static_libs: [ - "libhostgraphics", "libgui_window_info_static", ], shared_libs: [ @@ -165,6 +320,43 @@ cc_library { }, } +// Use bootstrap version of stats logging library. +// libinput is a bootstrap process (starts early in the boot process), and thus can't use the normal +// `libstatslog` because that requires `libstatssocket`, which is only available later in the boot. +cc_library { + name: "libstatslog_libinput", + generated_sources: ["statslog_libinput.cpp"], + generated_headers: ["statslog_libinput.h"], + export_generated_headers: ["statslog_libinput.h"], + shared_libs: [ + "libbinder", + "libstatsbootstrap", + "libutils", + "android.os.statsbootstrap_aidl-cpp", + ], +} + +genrule { + name: "statslog_libinput.h", + tools: ["stats-log-api-gen"], + cmd: "$(location stats-log-api-gen) --header $(genDir)/statslog_libinput.h --module libinput" + + " --namespace android,stats,libinput --bootstrap", + out: [ + "statslog_libinput.h", + ], +} + +genrule { + name: "statslog_libinput.cpp", + tools: ["stats-log-api-gen"], + cmd: "$(location stats-log-api-gen) --cpp $(genDir)/statslog_libinput.cpp --module libinput" + + " --namespace android,stats,libinput --importHeader statslog_libinput.h" + + " --bootstrap", + out: [ + "statslog_libinput.cpp", + ], +} + cc_defaults { name: "libinput_fuzz_defaults", cpp_std: "c++20", diff --git a/libs/input/FromRustToCpp.cpp b/libs/input/FromRustToCpp.cpp new file mode 100644 index 0000000000..e4ce62e734 --- /dev/null +++ b/libs/input/FromRustToCpp.cpp @@ -0,0 +1,26 @@ +/* + * Copyright 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <android-base/logging.h> +#include <ffi/FromRustToCpp.h> + +namespace android { + +bool shouldLog(rust::Str tag) { + return android::base::ShouldLog(android::base::LogSeverity::DEBUG, tag.data()); +} + +} // namespace android diff --git a/libs/input/InputEventLabels.cpp b/libs/input/InputEventLabels.cpp index f99a7d640e..bade68629d 100644 --- a/libs/input/InputEventLabels.cpp +++ b/libs/input/InputEventLabels.cpp @@ -404,7 +404,8 @@ namespace android { DEFINE_AXIS(GESTURE_Y_OFFSET), \ DEFINE_AXIS(GESTURE_SCROLL_X_DISTANCE), \ DEFINE_AXIS(GESTURE_SCROLL_Y_DISTANCE), \ - DEFINE_AXIS(GESTURE_PINCH_SCALE_FACTOR) + DEFINE_AXIS(GESTURE_PINCH_SCALE_FACTOR), \ + DEFINE_AXIS(GESTURE_SWIPE_FINGER_COUNT) // NOTE: If you add new LEDs here, you must also add them to Input.h #define LEDS_SEQUENCE \ @@ -433,17 +434,14 @@ namespace android { // clang-format on // --- InputEventLookup --- -const std::unordered_map<std::string, int> InputEventLookup::KEYCODES = {KEYCODES_SEQUENCE}; -const std::vector<InputEventLabel> InputEventLookup::KEY_NAMES = {KEYCODES_SEQUENCE}; - -const std::unordered_map<std::string, int> InputEventLookup::AXES = {AXES_SEQUENCE}; - -const std::vector<InputEventLabel> InputEventLookup::AXES_NAMES = {AXES_SEQUENCE}; - -const std::unordered_map<std::string, int> InputEventLookup::LEDS = {LEDS_SEQUENCE}; - -const std::unordered_map<std::string, int> InputEventLookup::FLAGS = {FLAGS_SEQUENCE}; +InputEventLookup::InputEventLookup() + : KEYCODES({KEYCODES_SEQUENCE}), + KEY_NAMES({KEYCODES_SEQUENCE}), + AXES({AXES_SEQUENCE}), + AXES_NAMES({AXES_SEQUENCE}), + LEDS({LEDS_SEQUENCE}), + FLAGS({FLAGS_SEQUENCE}) {} std::optional<int> InputEventLookup::lookupValueByLabel( const std::unordered_map<std::string, int>& map, const char* literal) { @@ -461,30 +459,36 @@ const char* InputEventLookup::lookupLabelByValue(const std::vector<InputEventLab } std::optional<int> InputEventLookup::getKeyCodeByLabel(const char* label) { - return lookupValueByLabel(KEYCODES, label); + const auto& self = get(); + return self.lookupValueByLabel(self.KEYCODES, label); } const char* InputEventLookup::getLabelByKeyCode(int32_t keyCode) { - if (keyCode >= 0 && static_cast<size_t>(keyCode) < KEYCODES.size()) { - return lookupLabelByValue(KEY_NAMES, keyCode); + const auto& self = get(); + if (keyCode >= 0 && static_cast<size_t>(keyCode) < self.KEYCODES.size()) { + return get().lookupLabelByValue(self.KEY_NAMES, keyCode); } return nullptr; } std::optional<int> InputEventLookup::getKeyFlagByLabel(const char* label) { - return lookupValueByLabel(FLAGS, label); + const auto& self = get(); + return lookupValueByLabel(self.FLAGS, label); } std::optional<int> InputEventLookup::getAxisByLabel(const char* label) { - return lookupValueByLabel(AXES, label); + const auto& self = get(); + return lookupValueByLabel(self.AXES, label); } const char* InputEventLookup::getAxisLabel(int32_t axisId) { - return lookupLabelByValue(AXES_NAMES, axisId); + const auto& self = get(); + return lookupLabelByValue(self.AXES_NAMES, axisId); } std::optional<int> InputEventLookup::getLedByLabel(const char* label) { - return lookupValueByLabel(LEDS, label); + const auto& self = get(); + return lookupValueByLabel(self.LEDS, label); } namespace { diff --git a/libs/input/InputTransport.cpp b/libs/input/InputTransport.cpp index f6b4648d67..4d3d8bc31c 100644 --- a/libs/input/InputTransport.cpp +++ b/libs/input/InputTransport.cpp @@ -4,6 +4,7 @@ // Provides a shared memory transport for input events. // #define LOG_TAG "InputTransport" +#define ATRACE_TAG ATRACE_TAG_INPUT #include <errno.h> #include <fcntl.h> @@ -13,6 +14,7 @@ #include <sys/types.h> #include <unistd.h> +#include <android-base/logging.h> #include <android-base/properties.h> #include <android-base/stringprintf.h> #include <binder/Parcel.h> @@ -80,6 +82,7 @@ const bool DEBUG_RESAMPLING = } // namespace +using android::base::Result; using android::base::StringPrintf; namespace android { @@ -449,6 +452,13 @@ status_t InputChannel::sendMessage(const InputMessage* msg) { ALOGD_IF(DEBUG_CHANNEL_MESSAGES, "channel '%s' ~ sent message of type %s", mName.c_str(), ftl::enum_string(msg->header.type).c_str()); + + if (ATRACE_ENABLED()) { + std::string message = + StringPrintf("sendMessage(inputChannel=%s, seq=0x%" PRIx32 ", type=0x%" PRIx32 ")", + mName.c_str(), msg->header.seq, msg->header.type); + ATRACE_NAME(message.c_str()); + } return OK; } @@ -484,6 +494,13 @@ status_t InputChannel::receiveMessage(InputMessage* msg) { ALOGD_IF(DEBUG_CHANNEL_MESSAGES, "channel '%s' ~ received message of type %s", mName.c_str(), ftl::enum_string(msg->header.type).c_str()); + + if (ATRACE_ENABLED()) { + std::string message = StringPrintf("receiveMessage(inputChannel=%s, seq=0x%" PRIx32 + ", type=0x%" PRIx32 ")", + mName.c_str(), msg->header.seq, msg->header.type); + ATRACE_NAME(message.c_str()); + } return OK; } @@ -606,8 +623,12 @@ status_t InputPublisher::publishMotionEvent( ATRACE_NAME(message.c_str()); } if (verifyEvents()) { - mInputVerifier.processMovement(deviceId, action, pointerCount, pointerProperties, - pointerCoords, flags); + Result<void> result = + mInputVerifier.processMovement(deviceId, action, pointerCount, pointerProperties, + pointerCoords, flags); + if (!result.ok()) { + LOG(FATAL) << "Bad stream: " << result.error(); + } } if (debugTransportPublisher()) { std::string transformString; diff --git a/libs/input/InputVerifier.cpp b/libs/input/InputVerifier.cpp index eb758045cc..9745e89234 100644 --- a/libs/input/InputVerifier.cpp +++ b/libs/input/InputVerifier.cpp @@ -18,111 +18,35 @@ #include <android-base/logging.h> #include <input/InputVerifier.h> +#include "input_verifier.rs.h" -namespace android { +using android::base::Error; +using android::base::Result; +using android::input::RustPointerProperties; -/** - * Log all of the movements that are sent to this verifier. Helps to identify the streams that lead - * to inconsistent events. - * Enable this via "adb shell setprop log.tag.InputVerifierLogEvents DEBUG" - */ -static bool logEvents() { - return __android_log_is_loggable(ANDROID_LOG_DEBUG, LOG_TAG "LogEvents", ANDROID_LOG_INFO); -} +namespace android { // --- InputVerifier --- -InputVerifier::InputVerifier(const std::string& name) : mName(name){}; +InputVerifier::InputVerifier(const std::string& name) + : mVerifier(android::input::verifier::create(rust::String::lossy(name))){}; -void InputVerifier::processMovement(int32_t deviceId, int32_t action, uint32_t pointerCount, - const PointerProperties* pointerProperties, - const PointerCoords* pointerCoords, int32_t flags) { - if (logEvents()) { - LOG(ERROR) << "Processing " << MotionEvent::actionToString(action) << " for device " - << deviceId << " (" << pointerCount << " pointer" - << (pointerCount == 1 ? "" : "s") << ") on " << mName; +Result<void> InputVerifier::processMovement(int32_t deviceId, int32_t action, uint32_t pointerCount, + const PointerProperties* pointerProperties, + const PointerCoords* pointerCoords, int32_t flags) { + std::vector<RustPointerProperties> rpp; + for (size_t i = 0; i < pointerCount; i++) { + rpp.emplace_back(RustPointerProperties{.id = pointerProperties[i].id}); } - - switch (MotionEvent::getActionMasked(action)) { - case AMOTION_EVENT_ACTION_DOWN: { - auto [it, inserted] = mTouchingPointerIdsByDevice.insert({deviceId, {}}); - if (!inserted) { - LOG(FATAL) << "Got ACTION_DOWN, but already have touching pointers " << it->second - << " for device " << deviceId << " on " << mName; - } - it->second.set(pointerProperties[0].id); - break; - } - case AMOTION_EVENT_ACTION_POINTER_DOWN: { - auto it = mTouchingPointerIdsByDevice.find(deviceId); - if (it == mTouchingPointerIdsByDevice.end()) { - LOG(FATAL) << "Got POINTER_DOWN, but no touching pointers for device " << deviceId - << " on " << mName; - } - it->second.set(pointerProperties[MotionEvent::getActionIndex(action)].id); - break; - } - case AMOTION_EVENT_ACTION_MOVE: { - ensureTouchingPointersMatch(deviceId, pointerCount, pointerProperties, "MOVE"); - break; - } - case AMOTION_EVENT_ACTION_POINTER_UP: { - auto it = mTouchingPointerIdsByDevice.find(deviceId); - if (it == mTouchingPointerIdsByDevice.end()) { - LOG(FATAL) << "Got POINTER_UP, but no touching pointers for device " << deviceId - << " on " << mName; - } - it->second.reset(pointerProperties[MotionEvent::getActionIndex(action)].id); - break; - } - case AMOTION_EVENT_ACTION_UP: { - auto it = mTouchingPointerIdsByDevice.find(deviceId); - if (it == mTouchingPointerIdsByDevice.end()) { - LOG(FATAL) << "Got ACTION_UP, but no record for deviceId " << deviceId << " on " - << mName; - } - const auto& [_, touchingPointerIds] = *it; - if (touchingPointerIds.count() != 1) { - LOG(FATAL) << "Got ACTION_UP, but we have pointers: " << touchingPointerIds - << " for deviceId " << deviceId << " on " << mName; - } - const int32_t pointerId = pointerProperties[0].id; - if (!touchingPointerIds.test(pointerId)) { - LOG(FATAL) << "Got ACTION_UP, but pointerId " << pointerId - << " is not touching. Touching pointers: " << touchingPointerIds - << " for deviceId " << deviceId << " on " << mName; - } - mTouchingPointerIdsByDevice.erase(it); - break; - } - case AMOTION_EVENT_ACTION_CANCEL: { - if ((flags & AMOTION_EVENT_FLAG_CANCELED) != AMOTION_EVENT_FLAG_CANCELED) { - LOG(FATAL) << "For ACTION_CANCEL, must set FLAG_CANCELED"; - } - ensureTouchingPointersMatch(deviceId, pointerCount, pointerProperties, "CANCEL"); - mTouchingPointerIdsByDevice.erase(deviceId); - break; - } + rust::Slice<const RustPointerProperties> properties{rpp.data(), rpp.size()}; + rust::String errorMessage = + android::input::verifier::process_movement(*mVerifier, deviceId, action, properties, + flags); + if (errorMessage.empty()) { + return {}; + } else { + return Error() << errorMessage; } } -void InputVerifier::ensureTouchingPointersMatch(int32_t deviceId, uint32_t pointerCount, - const PointerProperties* pointerProperties, - const char* action) const { - auto it = mTouchingPointerIdsByDevice.find(deviceId); - if (it == mTouchingPointerIdsByDevice.end()) { - LOG(FATAL) << "Got " << action << ", but no touching pointers for device " << deviceId - << " on " << mName; - } - const auto& [_, touchingPointerIds] = *it; - for (size_t i = 0; i < pointerCount; i++) { - const int32_t pointerId = pointerProperties[i].id; - if (!touchingPointerIds.test(pointerId)) { - LOG(FATAL) << "Got " << action << " for pointerId " << pointerId - << " but the touching pointers are " << touchingPointerIds << " on " - << mName; - } - } -}; - } // namespace android diff --git a/libs/input/InputWrapper.hpp b/libs/input/InputWrapper.hpp new file mode 100644 index 0000000000..a01080d319 --- /dev/null +++ b/libs/input/InputWrapper.hpp @@ -0,0 +1,18 @@ +/* + * Copyright 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <android/input.h> +#include "input/Input.h" diff --git a/libs/input/MotionPredictor.cpp b/libs/input/MotionPredictor.cpp index abcca345d3..5736ad7eed 100644 --- a/libs/input/MotionPredictor.cpp +++ b/libs/input/MotionPredictor.cpp @@ -36,9 +36,6 @@ namespace android { namespace { -const int64_t PREDICTION_INTERVAL_NANOS = - 12500000 / 3; // TODO(b/266747937): Get this from the model. - /** * Log debug messages about predictions. * Enable this via "adb shell setprop log.tag.MotionPredictor DEBUG" @@ -70,7 +67,7 @@ MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos, android::base::Result<void> MotionPredictor::record(const MotionEvent& event) { if (mLastEvent && mLastEvent->getDeviceId() != event.getDeviceId()) { // We still have an active gesture for another device. The provided MotionEvent is not - // consistent the previous gesture. + // consistent with the previous gesture. LOG(ERROR) << "Inconsistent event stream: last event is " << *mLastEvent << ", but " << __func__ << " is called with " << event; return android::base::Error() @@ -86,9 +83,10 @@ android::base::Result<void> MotionPredictor::record(const MotionEvent& event) { // Initialise the model now that it's likely to be used. if (!mModel) { mModel = TfLiteMotionPredictorModel::create(); + LOG_ALWAYS_FATAL_IF(!mModel); } - if (mBuffers == nullptr) { + if (!mBuffers) { mBuffers = std::make_unique<TfLiteMotionPredictorBuffers>(mModel->inputLength()); } @@ -136,6 +134,13 @@ android::base::Result<void> MotionPredictor::record(const MotionEvent& event) { mLastEvent = MotionEvent(); } mLastEvent->copyFrom(&event, /*keepHistory=*/false); + + // Pass input event to the MetricsManager. + if (!mMetricsManager) { + mMetricsManager.emplace(mModel->config().predictionInterval, mModel->outputLength()); + } + mMetricsManager->onRecord(event); + return {}; } @@ -178,19 +183,30 @@ std::unique_ptr<MotionEvent> MotionPredictor::predict(nsecs_t timestamp) { for (size_t i = 0; i < static_cast<size_t>(predictedR.size()) && predictionTime <= futureTime; ++i) { - const TfLiteMotionPredictorSample::Point point = - convertPrediction(axisFrom, axisTo, predictedR[i], predictedPhi[i]); + if (predictedR[i] < mModel->config().distanceNoiseFloor) { + // Stop predicting when the predicted output is below the model's noise floor. + // + // We assume that all subsequent predictions in the batch are unreliable because later + // predictions are conditional on earlier predictions, and a state of noise is not a + // good basis for prediction. + // + // The UX trade-off is that this potentially sacrifices some predictions when the input + // device starts to speed up, but avoids producing noisy predictions as it slows down. + break; + } // TODO(b/266747654): Stop predictions if confidence is < some threshold. - ALOGD_IF(isDebug(), "prediction %zu: %f, %f", i, point.x, point.y); + const TfLiteMotionPredictorSample::Point predictedPoint = + convertPrediction(axisFrom, axisTo, predictedR[i], predictedPhi[i]); + + ALOGD_IF(isDebug(), "prediction %zu: %f, %f", i, predictedPoint.x, predictedPoint.y); PointerCoords coords; coords.clear(); - coords.setAxisValue(AMOTION_EVENT_AXIS_X, point.x); - coords.setAxisValue(AMOTION_EVENT_AXIS_Y, point.y); - // TODO(b/266747654): Stop predictions if predicted pressure is < some threshold. + coords.setAxisValue(AMOTION_EVENT_AXIS_X, predictedPoint.x); + coords.setAxisValue(AMOTION_EVENT_AXIS_Y, predictedPoint.y); coords.setAxisValue(AMOTION_EVENT_AXIS_PRESSURE, predictedPressure[i]); - predictionTime += PREDICTION_INTERVAL_NANOS; + predictionTime += mModel->config().predictionInterval; if (i == 0) { hasPredictions = true; prediction->initialize(InputEvent::nextId(), event.getDeviceId(), event.getSource(), @@ -207,12 +223,17 @@ std::unique_ptr<MotionEvent> MotionPredictor::predict(nsecs_t timestamp) { } axisFrom = axisTo; - axisTo = point; + axisTo = predictedPoint; } - // TODO(b/266747511): Interpolate to futureTime? + if (!hasPredictions) { return nullptr; } + + // Pass predictions to the MetricsManager. + LOG_ALWAYS_FATAL_IF(!mMetricsManager); + mMetricsManager->onPredict(*prediction); + return prediction; } diff --git a/libs/input/MotionPredictorMetricsManager.cpp b/libs/input/MotionPredictorMetricsManager.cpp new file mode 100644 index 0000000000..67b103290f --- /dev/null +++ b/libs/input/MotionPredictorMetricsManager.cpp @@ -0,0 +1,373 @@ +/* + * Copyright 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define LOG_TAG "MotionPredictorMetricsManager" + +#include <input/MotionPredictorMetricsManager.h> + +#include <algorithm> + +#include <android-base/logging.h> + +#include "Eigen/Core" +#include "Eigen/Geometry" + +#ifdef __ANDROID__ +#include <statslog_libinput.h> +#endif + +namespace android { +namespace { + +inline constexpr int NANOS_PER_SECOND = 1'000'000'000; // nanoseconds per second +inline constexpr int NANOS_PER_MILLIS = 1'000'000; // nanoseconds per millisecond + +// Velocity threshold at which we report "high-velocity" metrics, in pixels per second. +// This value was selected from manual experimentation, as a threshold that separates "fast" +// (semi-sloppy) handwriting from more careful medium to slow handwriting. +inline constexpr float HIGH_VELOCITY_THRESHOLD = 1100.0; + +// Small value to add to the path length when computing scale-invariant error to avoid division by +// zero. +inline constexpr float PATH_LENGTH_EPSILON = 0.001; + +} // namespace + +MotionPredictorMetricsManager::MotionPredictorMetricsManager(nsecs_t predictionInterval, + size_t maxNumPredictions) + : mPredictionInterval(predictionInterval), + mMaxNumPredictions(maxNumPredictions), + mRecentGroundTruthPoints(maxNumPredictions + 1), + mAggregatedMetrics(maxNumPredictions), + mAtomFields(maxNumPredictions) {} + +void MotionPredictorMetricsManager::onRecord(const MotionEvent& inputEvent) { + // Convert MotionEvent to GroundTruthPoint. + const PointerCoords* coords = inputEvent.getRawPointerCoords(/*pointerIndex=*/0); + LOG_ALWAYS_FATAL_IF(coords == nullptr); + const GroundTruthPoint groundTruthPoint{{.position = Eigen::Vector2f{coords->getY(), + coords->getX()}, + .pressure = + inputEvent.getPressure(/*pointerIndex=*/0)}, + .timestamp = inputEvent.getEventTime()}; + + // Handle event based on action type. + switch (inputEvent.getActionMasked()) { + case AMOTION_EVENT_ACTION_DOWN: { + clearStrokeData(); + incorporateNewGroundTruth(groundTruthPoint); + break; + } + case AMOTION_EVENT_ACTION_MOVE: { + incorporateNewGroundTruth(groundTruthPoint); + break; + } + case AMOTION_EVENT_ACTION_UP: + case AMOTION_EVENT_ACTION_CANCEL: { + // Only expect meaningful predictions when given at least two input points. + if (mRecentGroundTruthPoints.size() >= 2) { + computeAtomFields(); + reportMetrics(); + break; + } + } + } +} + +// Adds new predictions to mRecentPredictions and maintains the invariant that elements are +// sorted in ascending order of targetTimestamp. +void MotionPredictorMetricsManager::onPredict(const MotionEvent& predictionEvent) { + for (size_t i = 0; i < predictionEvent.getHistorySize() + 1; ++i) { + // Convert MotionEvent to PredictionPoint. + const PointerCoords* coords = + predictionEvent.getHistoricalRawPointerCoords(/*pointerIndex=*/0, i); + LOG_ALWAYS_FATAL_IF(coords == nullptr); + const nsecs_t targetTimestamp = predictionEvent.getHistoricalEventTime(i); + mRecentPredictions.push_back( + PredictionPoint{{.position = Eigen::Vector2f{coords->getY(), coords->getX()}, + .pressure = + predictionEvent.getHistoricalPressure(/*pointerIndex=*/0, + i)}, + .originTimestamp = mRecentGroundTruthPoints.back().timestamp, + .targetTimestamp = targetTimestamp}); + } + + std::sort(mRecentPredictions.begin(), mRecentPredictions.end()); +} + +void MotionPredictorMetricsManager::clearStrokeData() { + mRecentGroundTruthPoints.clear(); + mRecentPredictions.clear(); + std::fill(mAggregatedMetrics.begin(), mAggregatedMetrics.end(), AggregatedStrokeMetrics{}); + std::fill(mAtomFields.begin(), mAtomFields.end(), AtomFields{}); +} + +void MotionPredictorMetricsManager::incorporateNewGroundTruth( + const GroundTruthPoint& groundTruthPoint) { + // Note: this removes the oldest point if `mRecentGroundTruthPoints` is already at capacity. + mRecentGroundTruthPoints.pushBack(groundTruthPoint); + + // Remove outdated predictions – those that can never be matched with the current or any future + // ground truth points. We use fuzzy association for the timestamps here, because ground truth + // and prediction timestamps may not be perfectly synchronized. + const nsecs_t fuzzy_association_time_delta = mPredictionInterval / 4; + const auto firstCurrentIt = + std::find_if(mRecentPredictions.begin(), mRecentPredictions.end(), + [&groundTruthPoint, + fuzzy_association_time_delta](const PredictionPoint& prediction) { + return prediction.targetTimestamp > + groundTruthPoint.timestamp - fuzzy_association_time_delta; + }); + mRecentPredictions.erase(mRecentPredictions.begin(), firstCurrentIt); + + // Fuzzily match the new ground truth's timestamp to recent predictions' targetTimestamp and + // update the corresponding metrics. + for (const PredictionPoint& prediction : mRecentPredictions) { + if ((prediction.targetTimestamp > + groundTruthPoint.timestamp - fuzzy_association_time_delta) && + (prediction.targetTimestamp < + groundTruthPoint.timestamp + fuzzy_association_time_delta)) { + updateAggregatedMetrics(prediction); + } + } +} + +void MotionPredictorMetricsManager::updateAggregatedMetrics( + const PredictionPoint& predictionPoint) { + if (mRecentGroundTruthPoints.size() < 2) { + return; + } + + const GroundTruthPoint& latestGroundTruthPoint = mRecentGroundTruthPoints.back(); + const GroundTruthPoint& previousGroundTruthPoint = + mRecentGroundTruthPoints[mRecentGroundTruthPoints.size() - 2]; + // Calculate prediction error vector. + const Eigen::Vector2f groundTruthTrajectory = + latestGroundTruthPoint.position - previousGroundTruthPoint.position; + const Eigen::Vector2f predictionTrajectory = + predictionPoint.position - previousGroundTruthPoint.position; + const Eigen::Vector2f predictionError = predictionTrajectory - groundTruthTrajectory; + + // By default, prediction error counts fully as both off-trajectory and along-trajectory error. + // This serves as the fallback when the two most recent ground truth points are equal. + const float predictionErrorNorm = predictionError.norm(); + float alongTrajectoryError = predictionErrorNorm; + float offTrajectoryError = predictionErrorNorm; + if (groundTruthTrajectory.squaredNorm() > 0) { + // Rotate the prediction error vector by the angle of the ground truth trajectory vector. + // This yields a vector whose first component is the along-trajectory error and whose + // second component is the off-trajectory error. + const float theta = std::atan2(groundTruthTrajectory[1], groundTruthTrajectory[0]); + const Eigen::Vector2f rotatedPredictionError = Eigen::Rotation2Df(-theta) * predictionError; + alongTrajectoryError = rotatedPredictionError[0]; + offTrajectoryError = rotatedPredictionError[1]; + } + + // Compute the multiple of mPredictionInterval nearest to the amount of time into the + // future being predicted. This serves as the time bucket index into mAggregatedMetrics. + const float timestampDeltaFloat = + static_cast<float>(predictionPoint.targetTimestamp - predictionPoint.originTimestamp); + const size_t tIndex = + static_cast<size_t>(std::round(timestampDeltaFloat / mPredictionInterval - 1)); + + // Aggregate values into "general errors". + mAggregatedMetrics[tIndex].alongTrajectoryErrorSum += alongTrajectoryError; + mAggregatedMetrics[tIndex].alongTrajectorySumSquaredErrors += + alongTrajectoryError * alongTrajectoryError; + mAggregatedMetrics[tIndex].offTrajectorySumSquaredErrors += + offTrajectoryError * offTrajectoryError; + const float pressureError = predictionPoint.pressure - latestGroundTruthPoint.pressure; + mAggregatedMetrics[tIndex].pressureSumSquaredErrors += pressureError * pressureError; + ++mAggregatedMetrics[tIndex].generalErrorsCount; + + // Aggregate values into high-velocity metrics, if we are in one of the last two time buckets + // and the velocity is above the threshold. Velocity here is measured in pixels per second. + const float velocity = groundTruthTrajectory.norm() / + (static_cast<float>(latestGroundTruthPoint.timestamp - + previousGroundTruthPoint.timestamp) / + NANOS_PER_SECOND); + if ((tIndex + 2 >= mMaxNumPredictions) && (velocity > HIGH_VELOCITY_THRESHOLD)) { + mAggregatedMetrics[tIndex].highVelocityAlongTrajectorySse += + alongTrajectoryError * alongTrajectoryError; + mAggregatedMetrics[tIndex].highVelocityOffTrajectorySse += + offTrajectoryError * offTrajectoryError; + ++mAggregatedMetrics[tIndex].highVelocityErrorsCount; + } + + // Compute path length for scale-invariant errors. + float pathLength = 0; + for (size_t i = 1; i < mRecentGroundTruthPoints.size(); ++i) { + pathLength += + (mRecentGroundTruthPoints[i].position - mRecentGroundTruthPoints[i - 1].position) + .norm(); + } + // Avoid overweighting errors at the beginning of a stroke: compute the path length as if there + // were a full ground truth history by filling in missing segments with the average length. + // Note: the "- 1" is needed to translate from number of endpoints to number of segments. + pathLength *= static_cast<float>(mRecentGroundTruthPoints.capacity() - 1) / + (mRecentGroundTruthPoints.size() - 1); + pathLength += PATH_LENGTH_EPSILON; // Ensure path length is nonzero (>= PATH_LENGTH_EPSILON). + + // Compute and aggregate scale-invariant errors. + const float scaleInvariantAlongTrajectoryError = alongTrajectoryError / pathLength; + const float scaleInvariantOffTrajectoryError = offTrajectoryError / pathLength; + mAggregatedMetrics[tIndex].scaleInvariantAlongTrajectorySse += + scaleInvariantAlongTrajectoryError * scaleInvariantAlongTrajectoryError; + mAggregatedMetrics[tIndex].scaleInvariantOffTrajectorySse += + scaleInvariantOffTrajectoryError * scaleInvariantOffTrajectoryError; + ++mAggregatedMetrics[tIndex].scaleInvariantErrorsCount; +} + +void MotionPredictorMetricsManager::computeAtomFields() { + for (size_t i = 0; i < mAggregatedMetrics.size(); ++i) { + if (mAggregatedMetrics[i].generalErrorsCount == 0) { + // We have not received data corresponding to metrics for this time bucket. + continue; + } + + mAtomFields[i].deltaTimeBucketMilliseconds = + static_cast<int>(mPredictionInterval / NANOS_PER_MILLIS * (i + 1)); + + // Note: we need the "* 1000"s below because we report values in integral milli-units. + + { // General errors: reported for every time bucket. + const float alongTrajectoryErrorMean = mAggregatedMetrics[i].alongTrajectoryErrorSum / + mAggregatedMetrics[i].generalErrorsCount; + mAtomFields[i].alongTrajectoryErrorMeanMillipixels = + static_cast<int>(alongTrajectoryErrorMean * 1000); + + const float alongTrajectoryMse = mAggregatedMetrics[i].alongTrajectorySumSquaredErrors / + mAggregatedMetrics[i].generalErrorsCount; + // Take the max with 0 to avoid negative values caused by numerical instability. + const float alongTrajectoryErrorVariance = + std::max(0.0f, + alongTrajectoryMse - + alongTrajectoryErrorMean * alongTrajectoryErrorMean); + const float alongTrajectoryErrorStd = std::sqrt(alongTrajectoryErrorVariance); + mAtomFields[i].alongTrajectoryErrorStdMillipixels = + static_cast<int>(alongTrajectoryErrorStd * 1000); + + LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[i].offTrajectorySumSquaredErrors < 0, + "mAggregatedMetrics[%zu].offTrajectorySumSquaredErrors = %f should " + "not be negative", + i, mAggregatedMetrics[i].offTrajectorySumSquaredErrors); + const float offTrajectoryRmse = + std::sqrt(mAggregatedMetrics[i].offTrajectorySumSquaredErrors / + mAggregatedMetrics[i].generalErrorsCount); + mAtomFields[i].offTrajectoryRmseMillipixels = + static_cast<int>(offTrajectoryRmse * 1000); + + LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[i].pressureSumSquaredErrors < 0, + "mAggregatedMetrics[%zu].pressureSumSquaredErrors = %f should not " + "be negative", + i, mAggregatedMetrics[i].pressureSumSquaredErrors); + const float pressureRmse = std::sqrt(mAggregatedMetrics[i].pressureSumSquaredErrors / + mAggregatedMetrics[i].generalErrorsCount); + mAtomFields[i].pressureRmseMilliunits = static_cast<int>(pressureRmse * 1000); + } + + // High-velocity errors: reported only for last two time buckets. + // Check if we are in one of the last two time buckets, and there is high-velocity data. + if ((i + 2 >= mMaxNumPredictions) && (mAggregatedMetrics[i].highVelocityErrorsCount > 0)) { + LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[i].highVelocityAlongTrajectorySse < 0, + "mAggregatedMetrics[%zu].highVelocityAlongTrajectorySse = %f " + "should not be negative", + i, mAggregatedMetrics[i].highVelocityAlongTrajectorySse); + const float alongTrajectoryRmse = + std::sqrt(mAggregatedMetrics[i].highVelocityAlongTrajectorySse / + mAggregatedMetrics[i].highVelocityErrorsCount); + mAtomFields[i].highVelocityAlongTrajectoryRmse = + static_cast<int>(alongTrajectoryRmse * 1000); + + LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[i].highVelocityOffTrajectorySse < 0, + "mAggregatedMetrics[%zu].highVelocityOffTrajectorySse = %f should " + "not be negative", + i, mAggregatedMetrics[i].highVelocityOffTrajectorySse); + const float offTrajectoryRmse = + std::sqrt(mAggregatedMetrics[i].highVelocityOffTrajectorySse / + mAggregatedMetrics[i].highVelocityErrorsCount); + mAtomFields[i].highVelocityOffTrajectoryRmse = + static_cast<int>(offTrajectoryRmse * 1000); + } + + // Scale-invariant errors: reported only for the last time bucket, where the values + // represent an average across all time buckets. + if (i + 1 == mMaxNumPredictions) { + // Compute error averages. + float alongTrajectoryRmseSum = 0; + float offTrajectoryRmseSum = 0; + for (size_t j = 0; j < mAggregatedMetrics.size(); ++j) { + // If we have general errors (checked above), we should always also have + // scale-invariant errors. + LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[j].scaleInvariantErrorsCount == 0, + "mAggregatedMetrics[%zu].scaleInvariantErrorsCount is 0", j); + + LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[j].scaleInvariantAlongTrajectorySse < 0, + "mAggregatedMetrics[%zu].scaleInvariantAlongTrajectorySse = %f " + "should not be negative", + j, mAggregatedMetrics[j].scaleInvariantAlongTrajectorySse); + alongTrajectoryRmseSum += + std::sqrt(mAggregatedMetrics[j].scaleInvariantAlongTrajectorySse / + mAggregatedMetrics[j].scaleInvariantErrorsCount); + + LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[j].scaleInvariantOffTrajectorySse < 0, + "mAggregatedMetrics[%zu].scaleInvariantOffTrajectorySse = %f " + "should not be negative", + j, mAggregatedMetrics[j].scaleInvariantOffTrajectorySse); + offTrajectoryRmseSum += + std::sqrt(mAggregatedMetrics[j].scaleInvariantOffTrajectorySse / + mAggregatedMetrics[j].scaleInvariantErrorsCount); + } + + const float averageAlongTrajectoryRmse = + alongTrajectoryRmseSum / mAggregatedMetrics.size(); + mAtomFields.back().scaleInvariantAlongTrajectoryRmse = + static_cast<int>(averageAlongTrajectoryRmse * 1000); + + const float averageOffTrajectoryRmse = offTrajectoryRmseSum / mAggregatedMetrics.size(); + mAtomFields.back().scaleInvariantOffTrajectoryRmse = + static_cast<int>(averageOffTrajectoryRmse * 1000); + } + } +} + +void MotionPredictorMetricsManager::reportMetrics() { + // Report one atom for each time bucket. + for (size_t i = 0; i < mAtomFields.size(); ++i) { + // Call stats_write logging function only on Android targets (not supported on host). +#ifdef __ANDROID__ + android::stats::libinput:: + stats_write(android::stats::libinput::STYLUS_PREDICTION_METRICS_REPORTED, + /*stylus_vendor_id=*/0, + /*stylus_product_id=*/0, mAtomFields[i].deltaTimeBucketMilliseconds, + mAtomFields[i].alongTrajectoryErrorMeanMillipixels, + mAtomFields[i].alongTrajectoryErrorStdMillipixels, + mAtomFields[i].offTrajectoryRmseMillipixels, + mAtomFields[i].pressureRmseMilliunits, + mAtomFields[i].highVelocityAlongTrajectoryRmse, + mAtomFields[i].highVelocityOffTrajectoryRmse, + mAtomFields[i].scaleInvariantAlongTrajectoryRmse, + mAtomFields[i].scaleInvariantOffTrajectoryRmse); +#endif + } + + // Set mock atom fields, if available. + if (mMockLoggedAtomFields != nullptr) { + *mMockLoggedAtomFields = mAtomFields; + } +} + +} // namespace android diff --git a/libs/input/TfLiteMotionPredictor.cpp b/libs/input/TfLiteMotionPredictor.cpp index 8d10ff56b0..d17476e216 100644 --- a/libs/input/TfLiteMotionPredictor.cpp +++ b/libs/input/TfLiteMotionPredictor.cpp @@ -36,6 +36,7 @@ #define ATRACE_TAG ATRACE_TAG_INPUT #include <cutils/trace.h> #include <log/log.h> +#include <utils/Timers.h> #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -44,6 +45,8 @@ #include "tensorflow/lite/model.h" #include "tensorflow/lite/mutable_op_resolver.h" +#include "tinyxml2.h" + namespace android { namespace { @@ -72,16 +75,41 @@ bool fileExists(const char* filename) { std::string getModelPath() { #if defined(__ANDROID__) - static const char* oemModel = "/vendor/etc/motion_predictor_model.fb"; + static const char* oemModel = "/vendor/etc/motion_predictor_model.tflite"; if (fileExists(oemModel)) { return oemModel; } - return "/system/etc/motion_predictor_model.fb"; + return "/system/etc/motion_predictor_model.tflite"; #else - return base::GetExecutableDirectory() + "/motion_predictor_model.fb"; + return base::GetExecutableDirectory() + "/motion_predictor_model.tflite"; #endif } +std::string getConfigPath() { + // The config file should be alongside the model file. + return base::Dirname(getModelPath()) + "/motion_predictor_config.xml"; +} + +int64_t parseXMLInt64(const tinyxml2::XMLElement& configRoot, const char* elementName) { + const tinyxml2::XMLElement* element = configRoot.FirstChildElement(elementName); + LOG_ALWAYS_FATAL_IF(!element, "Could not find '%s' element", elementName); + + int64_t value = 0; + LOG_ALWAYS_FATAL_IF(element->QueryInt64Text(&value) != tinyxml2::XML_SUCCESS, + "Failed to parse %s: %s", elementName, element->GetText()); + return value; +} + +float parseXMLFloat(const tinyxml2::XMLElement& configRoot, const char* elementName) { + const tinyxml2::XMLElement* element = configRoot.FirstChildElement(elementName); + LOG_ALWAYS_FATAL_IF(!element, "Could not find '%s' element", elementName); + + float value = 0; + LOG_ALWAYS_FATAL_IF(element->QueryFloatText(&value) != tinyxml2::XML_SUCCESS, + "Failed to parse %s: %s", elementName, element->GetText()); + return value; +} + // A TFLite ErrorReporter that logs to logcat. class LoggingErrorReporter : public tflite::ErrorReporter { public: @@ -133,6 +161,7 @@ std::unique_ptr<tflite::OpResolver> createOpResolver() { ::tflite::ops::builtin::Register_CONCATENATION()); resolver->AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED, ::tflite::ops::builtin::Register_FULLY_CONNECTED()); + resolver->AddBuiltin(::tflite::BuiltinOperator_GELU, ::tflite::ops::builtin::Register_GELU()); return resolver; } @@ -189,13 +218,7 @@ void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp, float phi = 0; float orientation = 0; - // Ignore the sample if there is no movement. These samples can occur when there's change to a - // property other than the coordinates and pollute the input to the model. - if (r == 0) { - return; - } - - if (!mAxisFrom) { // Second point. + if (!mAxisFrom && r > 0) { // Second point. // We can only determine the distance from the first point, and not any // angle. However, if the second point forms an axis, the orientation can // be transformed relative to that axis. @@ -216,8 +239,10 @@ void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp, } // Update the axis for the next point. - mAxisFrom = mAxisTo; - mAxisTo = sample; + if (r > 0) { + mAxisFrom = mAxisTo; + mAxisTo = sample; + } // Push the current sample onto the end of the input buffers. mInputR.pushBack(r); @@ -245,13 +270,26 @@ std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create() PLOG(FATAL) << "Failed to mmap model"; } + const std::string configPath = getConfigPath(); + tinyxml2::XMLDocument configDocument; + LOG_ALWAYS_FATAL_IF(configDocument.LoadFile(configPath.c_str()) != tinyxml2::XML_SUCCESS, + "Failed to load config file from %s", configPath.c_str()); + + // Parse configuration file. + const tinyxml2::XMLElement* configRoot = configDocument.FirstChildElement("motion-predictor"); + LOG_ALWAYS_FATAL_IF(!configRoot); + Config config{ + .predictionInterval = parseXMLInt64(*configRoot, "prediction-interval"), + .distanceNoiseFloor = parseXMLFloat(*configRoot, "distance-noise-floor"), + }; + return std::unique_ptr<TfLiteMotionPredictorModel>( - new TfLiteMotionPredictorModel(std::move(modelBuffer))); + new TfLiteMotionPredictorModel(std::move(modelBuffer), std::move(config))); } TfLiteMotionPredictorModel::TfLiteMotionPredictorModel( - std::unique_ptr<android::base::MappedFile> model) - : mFlatBuffer(std::move(model)) { + std::unique_ptr<android::base::MappedFile> model, Config config) + : mFlatBuffer(std::move(model)), mConfig(std::move(config)) { CHECK(mFlatBuffer); mErrorReporter = std::make_unique<LoggingErrorReporter>(); mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer->data(), diff --git a/libs/input/VelocityTracker.cpp b/libs/input/VelocityTracker.cpp index 8551e5fa1c..078109a5b6 100644 --- a/libs/input/VelocityTracker.cpp +++ b/libs/input/VelocityTracker.cpp @@ -16,7 +16,9 @@ #define LOG_TAG "VelocityTracker" +#include <android-base/logging.h> #include <array> +#include <ftl/enum.h> #include <inttypes.h> #include <limits.h> #include <math.h> @@ -145,27 +147,19 @@ static std::string matrixToString(const float* a, uint32_t m, uint32_t n, bool r VelocityTracker::VelocityTracker(const Strategy strategy) : mLastEventTime(0), mCurrentPointerIdBits(0), mOverrideStrategy(strategy) {} -VelocityTracker::~VelocityTracker() { -} - bool VelocityTracker::isAxisSupported(int32_t axis) { return DEFAULT_STRATEGY_BY_AXIS.find(axis) != DEFAULT_STRATEGY_BY_AXIS.end(); } void VelocityTracker::configureStrategy(int32_t axis) { const bool isDifferentialAxis = DIFFERENTIAL_AXES.find(axis) != DIFFERENTIAL_AXES.end(); - - std::unique_ptr<VelocityTrackerStrategy> createdStrategy; - if (mOverrideStrategy != VelocityTracker::Strategy::DEFAULT) { - createdStrategy = createStrategy(mOverrideStrategy, /*deltaValues=*/isDifferentialAxis); + if (isDifferentialAxis || mOverrideStrategy == VelocityTracker::Strategy::DEFAULT) { + // Do not allow overrides of strategies for differential axes, for now. + mConfiguredStrategies[axis] = createStrategy(DEFAULT_STRATEGY_BY_AXIS.at(axis), + /*deltaValues=*/isDifferentialAxis); } else { - createdStrategy = createStrategy(DEFAULT_STRATEGY_BY_AXIS.at(axis), - /*deltaValues=*/isDifferentialAxis); + mConfiguredStrategies[axis] = createStrategy(mOverrideStrategy, /*deltaValues=*/false); } - - LOG_ALWAYS_FATAL_IF(createdStrategy == nullptr, - "Could not create velocity tracker strategy for axis '%" PRId32 "'!", axis); - mConfiguredStrategies[axis] = std::move(createdStrategy); } std::unique_ptr<VelocityTrackerStrategy> VelocityTracker::createStrategy( @@ -213,6 +207,9 @@ std::unique_ptr<VelocityTrackerStrategy> VelocityTracker::createStrategy( default: break; } + LOG(FATAL) << "Invalid strategy: " << ftl::enum_string(strategy) + << ", deltaValues = " << deltaValues; + return nullptr; } diff --git a/libs/input/ffi/FromRustToCpp.h b/libs/input/ffi/FromRustToCpp.h new file mode 100644 index 0000000000..889945c32b --- /dev/null +++ b/libs/input/ffi/FromRustToCpp.h @@ -0,0 +1,23 @@ +/* + * Copyright 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "rust/cxx.h" + +namespace android { + +bool shouldLog(rust::Str tag); + +} // namespace android diff --git a/libs/input/input_verifier.rs b/libs/input/input_verifier.rs new file mode 100644 index 0000000000..dd2ac4ca91 --- /dev/null +++ b/libs/input/input_verifier.rs @@ -0,0 +1,422 @@ +/* + * Copyright 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! Validate the incoming motion stream. +//! This class is not thread-safe. +//! State is stored in the "InputVerifier" object +//! that can be created via the 'create' method. +//! Usage: +//! Box<InputVerifier> verifier = create("inputChannel name"); +//! result = process_movement(verifier, ...); +//! if (result) { +//! crash(result.error_message()); +//! } + +use std::collections::HashMap; +use std::collections::HashSet; + +use bitflags::bitflags; +use log::info; + +#[cxx::bridge(namespace = "android::input")] +#[allow(unsafe_op_in_unsafe_fn)] +mod ffi { + #[namespace = "android"] + unsafe extern "C++" { + include!("ffi/FromRustToCpp.h"); + fn shouldLog(tag: &str) -> bool; + } + #[namespace = "android::input::verifier"] + extern "Rust" { + type InputVerifier; + + fn create(name: String) -> Box<InputVerifier>; + fn process_movement( + verifier: &mut InputVerifier, + device_id: i32, + action: u32, + pointer_properties: &[RustPointerProperties], + flags: i32, + ) -> String; + } + + pub struct RustPointerProperties { + id: i32, + } +} + +use crate::ffi::shouldLog; +use crate::ffi::RustPointerProperties; + +#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +struct DeviceId(i32); + +fn process_movement( + verifier: &mut InputVerifier, + device_id: i32, + action: u32, + pointer_properties: &[RustPointerProperties], + flags: i32, +) -> String { + let result = verifier.process_movement( + DeviceId(device_id), + action, + pointer_properties, + Flags::from_bits(flags).unwrap(), + ); + match result { + Ok(()) => "".to_string(), + Err(e) => e, + } +} + +fn create(name: String) -> Box<InputVerifier> { + Box::new(InputVerifier::new(&name)) +} + +#[repr(u32)] +enum MotionAction { + Down = input_bindgen::AMOTION_EVENT_ACTION_DOWN, + Up = input_bindgen::AMOTION_EVENT_ACTION_UP, + Move = input_bindgen::AMOTION_EVENT_ACTION_MOVE, + Cancel = input_bindgen::AMOTION_EVENT_ACTION_CANCEL, + Outside = input_bindgen::AMOTION_EVENT_ACTION_OUTSIDE, + PointerDown { action_index: usize } = input_bindgen::AMOTION_EVENT_ACTION_POINTER_DOWN, + PointerUp { action_index: usize } = input_bindgen::AMOTION_EVENT_ACTION_POINTER_UP, + HoverEnter = input_bindgen::AMOTION_EVENT_ACTION_HOVER_ENTER, + HoverMove = input_bindgen::AMOTION_EVENT_ACTION_HOVER_MOVE, + HoverExit = input_bindgen::AMOTION_EVENT_ACTION_HOVER_EXIT, + Scroll = input_bindgen::AMOTION_EVENT_ACTION_SCROLL, + ButtonPress = input_bindgen::AMOTION_EVENT_ACTION_BUTTON_PRESS, + ButtonRelease = input_bindgen::AMOTION_EVENT_ACTION_BUTTON_RELEASE, +} + +fn get_action_index(action: u32) -> usize { + let index = (action & input_bindgen::AMOTION_EVENT_ACTION_POINTER_INDEX_MASK) + >> input_bindgen::AMOTION_EVENT_ACTION_POINTER_INDEX_SHIFT; + index.try_into().unwrap() +} + +impl From<u32> for MotionAction { + fn from(action: u32) -> Self { + let action_masked = action & input_bindgen::AMOTION_EVENT_ACTION_MASK; + let action_index = get_action_index(action); + match action_masked { + input_bindgen::AMOTION_EVENT_ACTION_DOWN => MotionAction::Down, + input_bindgen::AMOTION_EVENT_ACTION_UP => MotionAction::Up, + input_bindgen::AMOTION_EVENT_ACTION_MOVE => MotionAction::Move, + input_bindgen::AMOTION_EVENT_ACTION_CANCEL => MotionAction::Cancel, + input_bindgen::AMOTION_EVENT_ACTION_OUTSIDE => MotionAction::Outside, + input_bindgen::AMOTION_EVENT_ACTION_POINTER_DOWN => { + MotionAction::PointerDown { action_index } + } + input_bindgen::AMOTION_EVENT_ACTION_POINTER_UP => { + MotionAction::PointerUp { action_index } + } + input_bindgen::AMOTION_EVENT_ACTION_HOVER_ENTER => MotionAction::HoverEnter, + input_bindgen::AMOTION_EVENT_ACTION_HOVER_MOVE => MotionAction::HoverMove, + input_bindgen::AMOTION_EVENT_ACTION_HOVER_EXIT => MotionAction::HoverExit, + input_bindgen::AMOTION_EVENT_ACTION_SCROLL => MotionAction::Scroll, + input_bindgen::AMOTION_EVENT_ACTION_BUTTON_PRESS => MotionAction::ButtonPress, + input_bindgen::AMOTION_EVENT_ACTION_BUTTON_RELEASE => MotionAction::ButtonRelease, + _ => panic!("Unknown action: {}", action), + } + } +} + +bitflags! { + struct Flags: i32 { + const CANCELED = input_bindgen::AMOTION_EVENT_FLAG_CANCELED; + } +} + +fn motion_action_to_string(action: u32) -> String { + match action.into() { + MotionAction::Down => "DOWN".to_string(), + MotionAction::Up => "UP".to_string(), + MotionAction::Move => "MOVE".to_string(), + MotionAction::Cancel => "CANCEL".to_string(), + MotionAction::Outside => "OUTSIDE".to_string(), + MotionAction::PointerDown { action_index } => { + format!("POINTER_DOWN({})", action_index) + } + MotionAction::PointerUp { action_index } => { + format!("POINTER_UP({})", action_index) + } + MotionAction::HoverMove => "HOVER_MOVE".to_string(), + MotionAction::Scroll => "SCROLL".to_string(), + MotionAction::HoverEnter => "HOVER_ENTER".to_string(), + MotionAction::HoverExit => "HOVER_EXIT".to_string(), + MotionAction::ButtonPress => "BUTTON_PRESS".to_string(), + MotionAction::ButtonRelease => "BUTTON_RELEASE".to_string(), + } +} + +/** + * Log all of the movements that are sent to this verifier. Helps to identify the streams that lead + * to inconsistent events. + * Enable this via "adb shell setprop log.tag.InputVerifierLogEvents DEBUG" + */ +fn log_events() -> bool { + shouldLog("InputVerifierLogEvents") +} + +struct InputVerifier { + name: String, + touching_pointer_ids_by_device: HashMap<DeviceId, HashSet<i32>>, +} + +impl InputVerifier { + fn new(name: &str) -> Self { + logger::init( + logger::Config::default() + .with_tag_on_device("InputVerifier") + .with_min_level(log::Level::Trace), + ); + Self { name: name.to_owned(), touching_pointer_ids_by_device: HashMap::new() } + } + + fn process_movement( + &mut self, + device_id: DeviceId, + action: u32, + pointer_properties: &[RustPointerProperties], + flags: Flags, + ) -> Result<(), String> { + if log_events() { + info!( + "Processing {} for device {:?} ({} pointer{}) on {}", + motion_action_to_string(action), + device_id, + pointer_properties.len(), + if pointer_properties.len() == 1 { "" } else { "s" }, + self.name + ); + } + + match action.into() { + MotionAction::Down => { + let it = self + .touching_pointer_ids_by_device + .entry(device_id) + .or_insert_with(HashSet::new); + let pointer_id = pointer_properties[0].id; + if it.contains(&pointer_id) { + return Err(format!( + "{}: Invalid DOWN event - pointers already down for device {:?}: {:?}", + self.name, device_id, it + )); + } + it.insert(pointer_id); + } + MotionAction::PointerDown { action_index } => { + if !self.touching_pointer_ids_by_device.contains_key(&device_id) { + return Err(format!( + "{}: Received POINTER_DOWN but no pointers are currently down \ + for device {:?}", + self.name, device_id + )); + } + let it = self.touching_pointer_ids_by_device.get_mut(&device_id).unwrap(); + let pointer_id = pointer_properties[action_index].id; + if it.contains(&pointer_id) { + return Err(format!( + "{}: Pointer with id={} not found in the properties", + self.name, pointer_id + )); + } + it.insert(pointer_id); + } + MotionAction::Move => { + if !self.ensure_touching_pointers_match(device_id, pointer_properties) { + return Err(format!( + "{}: ACTION_MOVE touching pointers don't match", + self.name + )); + } + } + MotionAction::PointerUp { action_index } => { + if !self.touching_pointer_ids_by_device.contains_key(&device_id) { + return Err(format!( + "{}: Received POINTER_UP but no pointers are currently down for device \ + {:?}", + self.name, device_id + )); + } + let it = self.touching_pointer_ids_by_device.get_mut(&device_id).unwrap(); + let pointer_id = pointer_properties[action_index].id; + it.remove(&pointer_id); + } + MotionAction::Up => { + if !self.touching_pointer_ids_by_device.contains_key(&device_id) { + return Err(format!( + "{} Received ACTION_UP but no pointers are currently down for device {:?}", + self.name, device_id + )); + } + let it = self.touching_pointer_ids_by_device.get_mut(&device_id).unwrap(); + if it.len() != 1 { + return Err(format!( + "{}: Got ACTION_UP, but we have pointers: {:?} for device {:?}", + self.name, it, device_id + )); + } + let pointer_id = pointer_properties[0].id; + if !it.contains(&pointer_id) { + return Err(format!( + "{}: Got ACTION_UP, but pointerId {} is not touching. Touching pointers:\ + {:?} for device {:?}", + self.name, pointer_id, it, device_id + )); + } + it.clear(); + } + MotionAction::Cancel => { + if flags.contains(Flags::CANCELED) { + return Err(format!( + "{}: For ACTION_CANCEL, must set FLAG_CANCELED", + self.name + )); + } + if !self.ensure_touching_pointers_match(device_id, pointer_properties) { + return Err(format!( + "{}: Got ACTION_CANCEL, but the pointers don't match. \ + Existing pointers: {:?}", + self.name, self.touching_pointer_ids_by_device + )); + } + self.touching_pointer_ids_by_device.remove(&device_id); + } + _ => return Ok(()), + } + Ok(()) + } + + fn ensure_touching_pointers_match( + &self, + device_id: DeviceId, + pointer_properties: &[RustPointerProperties], + ) -> bool { + let Some(pointers) = self.touching_pointer_ids_by_device.get(&device_id) else { + return false; + }; + + for pointer_property in pointer_properties.iter() { + let pointer_id = pointer_property.id; + if !pointers.contains(&pointer_id) { + return false; + } + } + true + } +} + +#[cfg(test)] +mod tests { + use crate::DeviceId; + use crate::Flags; + use crate::InputVerifier; + use crate::RustPointerProperties; + #[test] + fn single_pointer_stream() { + let mut verifier = InputVerifier::new("Test"); + let pointer_properties = Vec::from([RustPointerProperties { id: 0 }]); + assert!(verifier + .process_movement( + DeviceId(1), + input_bindgen::AMOTION_EVENT_ACTION_DOWN, + &pointer_properties, + Flags::empty(), + ) + .is_ok()); + assert!(verifier + .process_movement( + DeviceId(1), + input_bindgen::AMOTION_EVENT_ACTION_MOVE, + &pointer_properties, + Flags::empty(), + ) + .is_ok()); + assert!(verifier + .process_movement( + DeviceId(1), + input_bindgen::AMOTION_EVENT_ACTION_UP, + &pointer_properties, + Flags::empty(), + ) + .is_ok()); + } + + #[test] + fn multi_device_stream() { + let mut verifier = InputVerifier::new("Test"); + let pointer_properties = Vec::from([RustPointerProperties { id: 0 }]); + assert!(verifier + .process_movement( + DeviceId(1), + input_bindgen::AMOTION_EVENT_ACTION_DOWN, + &pointer_properties, + Flags::empty(), + ) + .is_ok()); + assert!(verifier + .process_movement( + DeviceId(1), + input_bindgen::AMOTION_EVENT_ACTION_MOVE, + &pointer_properties, + Flags::empty(), + ) + .is_ok()); + assert!(verifier + .process_movement( + DeviceId(2), + input_bindgen::AMOTION_EVENT_ACTION_DOWN, + &pointer_properties, + Flags::empty(), + ) + .is_ok()); + assert!(verifier + .process_movement( + DeviceId(2), + input_bindgen::AMOTION_EVENT_ACTION_MOVE, + &pointer_properties, + Flags::empty(), + ) + .is_ok()); + assert!(verifier + .process_movement( + DeviceId(1), + input_bindgen::AMOTION_EVENT_ACTION_UP, + &pointer_properties, + Flags::empty(), + ) + .is_ok()); + } + + #[test] + fn test_invalid_up() { + let mut verifier = InputVerifier::new("Test"); + let pointer_properties = Vec::from([RustPointerProperties { id: 0 }]); + assert!(verifier + .process_movement( + DeviceId(1), + input_bindgen::AMOTION_EVENT_ACTION_UP, + &pointer_properties, + Flags::empty(), + ) + .is_err()); + } +} diff --git a/libs/input/tests/Android.bp b/libs/input/tests/Android.bp index 42bdf57514..e7224ff752 100644 --- a/libs/input/tests/Android.bp +++ b/libs/input/tests/Android.bp @@ -18,7 +18,9 @@ cc_test { "InputDevice_test.cpp", "InputEvent_test.cpp", "InputPublisherAndConsumer_test.cpp", + "InputVerifier_test.cpp", "MotionPredictor_test.cpp", + "MotionPredictorMetricsManager_test.cpp", "RingBuffer_test.cpp", "TfLiteMotionPredictor_test.cpp", "TouchResampling_test.cpp", @@ -44,6 +46,7 @@ cc_test { "-Wno-unused-parameter", ], sanitize: { + hwaddress: true, undefined: true, all_undefined: true, diag: { @@ -56,17 +59,33 @@ cc_test { "libcutils", "liblog", "libPlatformProperties", + "libtinyxml2", "libutils", "libvintf", ], data: [ "data/*", - ":motion_predictor_model.fb", + ":motion_predictor_model", ], test_options: { unit_test: true, }, test_suites: ["device-tests"], + target: { + host: { + sanitize: { + address: true, + }, + }, + android: { + static_libs: [ + // Stats logging library and its dependencies. + "libstatslog_libinput", + "libstatsbootstrap", + "android.os.statsbootstrap_aidl-cpp", + ], + }, + }, } // NOTE: This is a compile time test, and does not need to be diff --git a/libs/input/tests/InputVerifier_test.cpp b/libs/input/tests/InputVerifier_test.cpp new file mode 100644 index 0000000000..e24fa6ed0b --- /dev/null +++ b/libs/input/tests/InputVerifier_test.cpp @@ -0,0 +1,29 @@ +/* + * Copyright 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <gtest/gtest.h> +#include <input/InputVerifier.h> +#include <string> + +namespace android { + +TEST(InputVerifierTest, CreationWithInvalidUtfStringDoesNotCrash) { + constexpr char bytes[] = {static_cast<char>(0xC0), static_cast<char>(0x80)}; + const std::string name(bytes, sizeof(bytes)); + InputVerifier verifier(name); +} + +} // namespace android diff --git a/libs/input/tests/MotionPredictorMetricsManager_test.cpp b/libs/input/tests/MotionPredictorMetricsManager_test.cpp new file mode 100644 index 0000000000..b420a5a4e7 --- /dev/null +++ b/libs/input/tests/MotionPredictorMetricsManager_test.cpp @@ -0,0 +1,972 @@ +/* + * Copyright 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <input/MotionPredictor.h> + +#include <cmath> +#include <cstddef> +#include <cstdint> +#include <numeric> +#include <vector> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include <input/InputEventBuilders.h> +#include <utils/Timers.h> // for nsecs_t + +#include "Eigen/Core" +#include "Eigen/Geometry" + +namespace android { +namespace { + +using ::testing::FloatNear; +using ::testing::Matches; + +using GroundTruthPoint = MotionPredictorMetricsManager::GroundTruthPoint; +using PredictionPoint = MotionPredictorMetricsManager::PredictionPoint; +using AtomFields = MotionPredictorMetricsManager::AtomFields; + +inline constexpr int NANOS_PER_MILLIS = 1'000'000; + +inline constexpr nsecs_t TEST_INITIAL_TIMESTAMP = 1'000'000'000; +inline constexpr size_t TEST_MAX_NUM_PREDICTIONS = 5; +inline constexpr nsecs_t TEST_PREDICTION_INTERVAL_NANOS = 12'500'000 / 3; // 1 / (240 hz) +inline constexpr int NO_DATA_SENTINEL = MotionPredictorMetricsManager::NO_DATA_SENTINEL; + +// Parameters: +// • arg: Eigen::Vector2f +// • target: Eigen::Vector2f +// • epsilon: float +MATCHER_P2(Vector2fNear, target, epsilon, "") { + return Matches(FloatNear(target[0], epsilon))(arg[0]) && + Matches(FloatNear(target[1], epsilon))(arg[1]); +} + +// Parameters: +// • arg: PredictionPoint +// • target: PredictionPoint +// • epsilon: float +MATCHER_P2(PredictionPointNear, target, epsilon, "") { + if (!Matches(Vector2fNear(target.position, epsilon))(arg.position)) { + *result_listener << "Position mismatch. Actual: (" << arg.position[0] << ", " + << arg.position[1] << "), expected: (" << target.position[0] << ", " + << target.position[1] << ")"; + return false; + } + if (!Matches(FloatNear(target.pressure, epsilon))(arg.pressure)) { + *result_listener << "Pressure mismatch. Actual: " << arg.pressure + << ", expected: " << target.pressure; + return false; + } + if (arg.originTimestamp != target.originTimestamp) { + *result_listener << "Origin timestamp mismatch. Actual: " << arg.originTimestamp + << ", expected: " << target.originTimestamp; + return false; + } + if (arg.targetTimestamp != target.targetTimestamp) { + *result_listener << "Target timestamp mismatch. Actual: " << arg.targetTimestamp + << ", expected: " << target.targetTimestamp; + return false; + } + return true; +} + +// --- Mathematical helper functions. --- + +template <typename T> +T average(std::vector<T> values) { + return std::accumulate(values.begin(), values.end(), T{}) / static_cast<T>(values.size()); +} + +template <typename T> +T standardDeviation(std::vector<T> values) { + T mean = average(values); + T accumulator = {}; + for (const T value : values) { + accumulator += value * value - mean * mean; + } + // Take the max with 0 to avoid negative values caused by numerical instability. + return std::sqrt(std::max(T{}, accumulator) / static_cast<T>(values.size())); +} + +template <typename T> +T rmse(std::vector<T> errors) { + T sse = {}; + for (const T error : errors) { + sse += error * error; + } + return std::sqrt(sse / static_cast<T>(errors.size())); +} + +TEST(MathematicalHelperFunctionTest, Average) { + std::vector<float> values{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + EXPECT_EQ(5.5f, average(values)); +} + +TEST(MathematicalHelperFunctionTest, StandardDeviation) { + // https://www.calculator.net/standard-deviation-calculator.html?numberinputs=10%2C+12%2C+23%2C+23%2C+16%2C+23%2C+21%2C+16 + std::vector<float> values{10, 12, 23, 23, 16, 23, 21, 16}; + EXPECT_FLOAT_EQ(4.8989794855664f, standardDeviation(values)); +} + +TEST(MathematicalHelperFunctionTest, Rmse) { + std::vector<float> errors{1, 5, 7, 7, 8, 20}; + EXPECT_FLOAT_EQ(9.899494937f, rmse(errors)); +} + +// --- MotionEvent-related helper functions. --- + +// Creates a MotionEvent corresponding to the given GroundTruthPoint. +MotionEvent makeMotionEvent(const GroundTruthPoint& groundTruthPoint) { + // Build single pointer of type STYLUS, with coordinates from groundTruthPoint. + PointerBuilder pointerBuilder = + PointerBuilder(/*id=*/0, ToolType::STYLUS) + .x(groundTruthPoint.position[1]) + .y(groundTruthPoint.position[0]) + .axis(AMOTION_EVENT_AXIS_PRESSURE, groundTruthPoint.pressure); + return MotionEventBuilder(/*action=*/AMOTION_EVENT_ACTION_MOVE, + /*source=*/AINPUT_SOURCE_CLASS_POINTER) + .eventTime(groundTruthPoint.timestamp) + .pointer(pointerBuilder) + .build(); +} + +// Creates a MotionEvent corresponding to the given sequence of PredictionPoints. +MotionEvent makeMotionEvent(const std::vector<PredictionPoint>& predictionPoints) { + // Build single pointer of type STYLUS, with coordinates from first prediction point. + PointerBuilder pointerBuilder = + PointerBuilder(/*id=*/0, ToolType::STYLUS) + .x(predictionPoints[0].position[1]) + .y(predictionPoints[0].position[0]) + .axis(AMOTION_EVENT_AXIS_PRESSURE, predictionPoints[0].pressure); + MotionEvent predictionEvent = + MotionEventBuilder( + /*action=*/AMOTION_EVENT_ACTION_MOVE, /*source=*/AINPUT_SOURCE_CLASS_POINTER) + .eventTime(predictionPoints[0].targetTimestamp) + .pointer(pointerBuilder) + .build(); + for (size_t i = 1; i < predictionPoints.size(); ++i) { + PointerCoords coords = + PointerBuilder(/*id=*/0, ToolType::STYLUS) + .x(predictionPoints[i].position[1]) + .y(predictionPoints[i].position[0]) + .axis(AMOTION_EVENT_AXIS_PRESSURE, predictionPoints[i].pressure) + .buildCoords(); + predictionEvent.addSample(predictionPoints[i].targetTimestamp, &coords); + } + return predictionEvent; +} + +// Creates a MotionEvent corresponding to a stylus lift (UP) ground truth event. +MotionEvent makeLiftMotionEvent() { + return MotionEventBuilder(/*action=*/AMOTION_EVENT_ACTION_UP, + /*source=*/AINPUT_SOURCE_CLASS_POINTER) + .pointer(PointerBuilder(/*id=*/0, ToolType::STYLUS)) + .build(); +} + +TEST(MakeMotionEventTest, MakeGroundTruthMotionEvent) { + const GroundTruthPoint groundTruthPoint{{.position = Eigen::Vector2f(10.0f, 20.0f), + .pressure = 0.6f}, + .timestamp = TEST_INITIAL_TIMESTAMP}; + const MotionEvent groundTruthMotionEvent = makeMotionEvent(groundTruthPoint); + + ASSERT_EQ(1u, groundTruthMotionEvent.getPointerCount()); + // Note: a MotionEvent's "history size" is one less than its number of samples. + ASSERT_EQ(0u, groundTruthMotionEvent.getHistorySize()); + EXPECT_EQ(groundTruthPoint.position[0], groundTruthMotionEvent.getRawPointerCoords(0)->getY()); + EXPECT_EQ(groundTruthPoint.position[1], groundTruthMotionEvent.getRawPointerCoords(0)->getX()); + EXPECT_EQ(groundTruthPoint.pressure, + groundTruthMotionEvent.getRawPointerCoords(0)->getAxisValue( + AMOTION_EVENT_AXIS_PRESSURE)); + EXPECT_EQ(AMOTION_EVENT_ACTION_MOVE, groundTruthMotionEvent.getAction()); +} + +TEST(MakeMotionEventTest, MakePredictionMotionEvent) { + const nsecs_t originTimestamp = TEST_INITIAL_TIMESTAMP; + const std::vector<PredictionPoint> + predictionPoints{{{.position = Eigen::Vector2f(10.0f, 20.0f), .pressure = 0.6f}, + .originTimestamp = originTimestamp, + .targetTimestamp = originTimestamp + 5 * NANOS_PER_MILLIS}, + {{.position = Eigen::Vector2f(11.0f, 22.0f), .pressure = 0.5f}, + .originTimestamp = originTimestamp, + .targetTimestamp = originTimestamp + 10 * NANOS_PER_MILLIS}, + {{.position = Eigen::Vector2f(12.0f, 24.0f), .pressure = 0.4f}, + .originTimestamp = originTimestamp, + .targetTimestamp = originTimestamp + 15 * NANOS_PER_MILLIS}}; + const MotionEvent predictionMotionEvent = makeMotionEvent(predictionPoints); + + ASSERT_EQ(1u, predictionMotionEvent.getPointerCount()); + // Note: a MotionEvent's "history size" is one less than its number of samples. + ASSERT_EQ(predictionPoints.size(), predictionMotionEvent.getHistorySize() + 1); + for (size_t i = 0; i < predictionPoints.size(); ++i) { + SCOPED_TRACE(testing::Message() << "i = " << i); + const PointerCoords coords = *predictionMotionEvent.getHistoricalRawPointerCoords( + /*pointerIndex=*/0, /*historicalIndex=*/i); + EXPECT_EQ(predictionPoints[i].position[0], coords.getY()); + EXPECT_EQ(predictionPoints[i].position[1], coords.getX()); + EXPECT_EQ(predictionPoints[i].pressure, coords.getAxisValue(AMOTION_EVENT_AXIS_PRESSURE)); + // Note: originTimestamp is discarded when converting PredictionPoint to MotionEvent. + EXPECT_EQ(predictionPoints[i].targetTimestamp, + predictionMotionEvent.getHistoricalEventTime(i)); + EXPECT_EQ(AMOTION_EVENT_ACTION_MOVE, predictionMotionEvent.getAction()); + } +} + +TEST(MakeMotionEventTest, MakeLiftMotionEvent) { + const MotionEvent liftMotionEvent = makeLiftMotionEvent(); + ASSERT_EQ(1u, liftMotionEvent.getPointerCount()); + // Note: a MotionEvent's "history size" is one less than its number of samples. + ASSERT_EQ(0u, liftMotionEvent.getHistorySize()); + EXPECT_EQ(AMOTION_EVENT_ACTION_UP, liftMotionEvent.getAction()); +} + +// --- Ground-truth-generation helper functions. --- + +std::vector<GroundTruthPoint> generateConstantGroundTruthPoints( + const GroundTruthPoint& groundTruthPoint, size_t numPoints) { + std::vector<GroundTruthPoint> groundTruthPoints; + nsecs_t timestamp = groundTruthPoint.timestamp; + for (size_t i = 0; i < numPoints; ++i) { + groundTruthPoints.emplace_back(groundTruthPoint); + groundTruthPoints.back().timestamp = timestamp; + timestamp += TEST_PREDICTION_INTERVAL_NANOS; + } + return groundTruthPoints; +} + +// This function uses the coordinate system (y, x), with +y pointing downwards and +x pointing +// rightwards. Angles are measured counterclockwise from down (+y). +std::vector<GroundTruthPoint> generateCircularArcGroundTruthPoints(Eigen::Vector2f initialPosition, + float initialAngle, + float velocity, + float turningAngle, + size_t numPoints) { + std::vector<GroundTruthPoint> groundTruthPoints; + // Create first point. + if (numPoints > 0) { + groundTruthPoints.push_back({{.position = initialPosition, .pressure = 0.0f}, + .timestamp = TEST_INITIAL_TIMESTAMP}); + } + float trajectoryAngle = initialAngle; // measured counterclockwise from +y axis. + for (size_t i = 1; i < numPoints; ++i) { + const Eigen::Vector2f trajectory = + Eigen::Rotation2D(trajectoryAngle) * Eigen::Vector2f(1, 0); + groundTruthPoints.push_back( + {{.position = groundTruthPoints.back().position + velocity * trajectory, + .pressure = 0.0f}, + .timestamp = groundTruthPoints.back().timestamp + TEST_PREDICTION_INTERVAL_NANOS}); + trajectoryAngle += turningAngle; + } + return groundTruthPoints; +} + +TEST(GenerateConstantGroundTruthPointsTest, BasicTest) { + const GroundTruthPoint groundTruthPoint{{.position = Eigen::Vector2f(10, 20), .pressure = 0.3f}, + .timestamp = TEST_INITIAL_TIMESTAMP}; + const std::vector<GroundTruthPoint> groundTruthPoints = + generateConstantGroundTruthPoints(groundTruthPoint, /*numPoints=*/3); + + ASSERT_EQ(3u, groundTruthPoints.size()); + // First point. + EXPECT_EQ(groundTruthPoints[0].position, groundTruthPoint.position); + EXPECT_EQ(groundTruthPoints[0].pressure, groundTruthPoint.pressure); + EXPECT_EQ(groundTruthPoints[0].timestamp, groundTruthPoint.timestamp); + // Second point. + EXPECT_EQ(groundTruthPoints[1].position, groundTruthPoint.position); + EXPECT_EQ(groundTruthPoints[1].pressure, groundTruthPoint.pressure); + EXPECT_GT(groundTruthPoints[1].timestamp, groundTruthPoints[0].timestamp); + // Third point. + EXPECT_EQ(groundTruthPoints[2].position, groundTruthPoint.position); + EXPECT_EQ(groundTruthPoints[2].pressure, groundTruthPoint.pressure); + EXPECT_GT(groundTruthPoints[2].timestamp, groundTruthPoints[1].timestamp); +} + +TEST(GenerateCircularArcGroundTruthTest, StraightLineUpwards) { + const std::vector<GroundTruthPoint> groundTruthPoints = generateCircularArcGroundTruthPoints( + /*initialPosition=*/Eigen::Vector2f(0, 0), + /*initialAngle=*/M_PI, + /*velocity=*/1.0f, + /*turningAngle=*/0.0f, + /*numPoints=*/3); + + ASSERT_EQ(3u, groundTruthPoints.size()); + EXPECT_THAT(groundTruthPoints[0].position, Vector2fNear(Eigen::Vector2f(0, 0), 1e-6)); + EXPECT_THAT(groundTruthPoints[1].position, Vector2fNear(Eigen::Vector2f(-1, 0), 1e-6)); + EXPECT_THAT(groundTruthPoints[2].position, Vector2fNear(Eigen::Vector2f(-2, 0), 1e-6)); + // Check that timestamps are increasing between consecutive ground truth points. + EXPECT_GT(groundTruthPoints[1].timestamp, groundTruthPoints[0].timestamp); + EXPECT_GT(groundTruthPoints[2].timestamp, groundTruthPoints[1].timestamp); +} + +TEST(GenerateCircularArcGroundTruthTest, CounterclockwiseSquare) { + // Generate points in a counterclockwise unit square starting pointing right. + const std::vector<GroundTruthPoint> groundTruthPoints = generateCircularArcGroundTruthPoints( + /*initialPosition=*/Eigen::Vector2f(10, 100), + /*initialAngle=*/M_PI_2, + /*velocity=*/1.0f, + /*turningAngle=*/M_PI_2, + /*numPoints=*/5); + + ASSERT_EQ(5u, groundTruthPoints.size()); + EXPECT_THAT(groundTruthPoints[0].position, Vector2fNear(Eigen::Vector2f(10, 100), 1e-6)); + EXPECT_THAT(groundTruthPoints[1].position, Vector2fNear(Eigen::Vector2f(10, 101), 1e-6)); + EXPECT_THAT(groundTruthPoints[2].position, Vector2fNear(Eigen::Vector2f(9, 101), 1e-6)); + EXPECT_THAT(groundTruthPoints[3].position, Vector2fNear(Eigen::Vector2f(9, 100), 1e-6)); + EXPECT_THAT(groundTruthPoints[4].position, Vector2fNear(Eigen::Vector2f(10, 100), 1e-6)); +} + +// --- Prediction-generation helper functions. --- + +// Creates a sequence of predictions with values equal to those of the given GroundTruthPoint. +std::vector<PredictionPoint> generateConstantPredictions(const GroundTruthPoint& groundTruthPoint) { + std::vector<PredictionPoint> predictions; + nsecs_t predictionTimestamp = groundTruthPoint.timestamp + TEST_PREDICTION_INTERVAL_NANOS; + for (size_t j = 0; j < TEST_MAX_NUM_PREDICTIONS; ++j) { + predictions.push_back(PredictionPoint{{.position = groundTruthPoint.position, + .pressure = groundTruthPoint.pressure}, + .originTimestamp = groundTruthPoint.timestamp, + .targetTimestamp = predictionTimestamp}); + predictionTimestamp += TEST_PREDICTION_INTERVAL_NANOS; + } + return predictions; +} + +// Generates TEST_MAX_NUM_PREDICTIONS predictions from the given most recent two ground truth points +// by linear extrapolation of position and pressure. The interval between consecutive predictions' +// timestamps is TEST_PREDICTION_INTERVAL_NANOS. +std::vector<PredictionPoint> generatePredictionsByLinearExtrapolation( + const GroundTruthPoint& firstGroundTruth, const GroundTruthPoint& secondGroundTruth) { + // Precompute deltas. + const Eigen::Vector2f trajectory = secondGroundTruth.position - firstGroundTruth.position; + const float deltaPressure = secondGroundTruth.pressure - firstGroundTruth.pressure; + // Compute predictions. + std::vector<PredictionPoint> predictions; + Eigen::Vector2f predictionPosition = secondGroundTruth.position; + float predictionPressure = secondGroundTruth.pressure; + nsecs_t predictionTargetTimestamp = secondGroundTruth.timestamp; + for (size_t i = 0; i < TEST_MAX_NUM_PREDICTIONS; ++i) { + predictionPosition += trajectory; + predictionPressure += deltaPressure; + predictionTargetTimestamp += TEST_PREDICTION_INTERVAL_NANOS; + predictions.push_back( + PredictionPoint{{.position = predictionPosition, .pressure = predictionPressure}, + .originTimestamp = secondGroundTruth.timestamp, + .targetTimestamp = predictionTargetTimestamp}); + } + return predictions; +} + +TEST(GeneratePredictionsTest, GenerateConstantPredictions) { + const GroundTruthPoint groundTruthPoint{{.position = Eigen::Vector2f(10, 20), .pressure = 0.3f}, + .timestamp = TEST_INITIAL_TIMESTAMP}; + const std::vector<PredictionPoint> predictionPoints = + generateConstantPredictions(groundTruthPoint); + + ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, predictionPoints.size()); + for (size_t i = 0; i < predictionPoints.size(); ++i) { + SCOPED_TRACE(testing::Message() << "i = " << i); + EXPECT_THAT(predictionPoints[i].position, Vector2fNear(groundTruthPoint.position, 1e-6)); + EXPECT_THAT(predictionPoints[i].pressure, FloatNear(groundTruthPoint.pressure, 1e-6)); + EXPECT_EQ(predictionPoints[i].originTimestamp, groundTruthPoint.timestamp); + EXPECT_EQ(predictionPoints[i].targetTimestamp, + groundTruthPoint.timestamp + + static_cast<nsecs_t>(i + 1) * TEST_PREDICTION_INTERVAL_NANOS); + } +} + +TEST(GeneratePredictionsTest, LinearExtrapolationFromTwoPoints) { + const nsecs_t initialTimestamp = TEST_INITIAL_TIMESTAMP; + const std::vector<PredictionPoint> predictionPoints = generatePredictionsByLinearExtrapolation( + GroundTruthPoint{{.position = Eigen::Vector2f(100, 200), .pressure = 0.9f}, + .timestamp = initialTimestamp}, + GroundTruthPoint{{.position = Eigen::Vector2f(105, 190), .pressure = 0.8f}, + .timestamp = initialTimestamp + TEST_PREDICTION_INTERVAL_NANOS}); + + ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, predictionPoints.size()); + const nsecs_t originTimestamp = initialTimestamp + TEST_PREDICTION_INTERVAL_NANOS; + EXPECT_THAT(predictionPoints[0], + PredictionPointNear(PredictionPoint{{.position = Eigen::Vector2f(110, 180), + .pressure = 0.7f}, + .originTimestamp = originTimestamp, + .targetTimestamp = originTimestamp + + TEST_PREDICTION_INTERVAL_NANOS}, + 0.001)); + EXPECT_THAT(predictionPoints[1], + PredictionPointNear(PredictionPoint{{.position = Eigen::Vector2f(115, 170), + .pressure = 0.6f}, + .originTimestamp = originTimestamp, + .targetTimestamp = originTimestamp + + 2 * TEST_PREDICTION_INTERVAL_NANOS}, + 0.001)); + EXPECT_THAT(predictionPoints[2], + PredictionPointNear(PredictionPoint{{.position = Eigen::Vector2f(120, 160), + .pressure = 0.5f}, + .originTimestamp = originTimestamp, + .targetTimestamp = originTimestamp + + 3 * TEST_PREDICTION_INTERVAL_NANOS}, + 0.001)); + EXPECT_THAT(predictionPoints[3], + PredictionPointNear(PredictionPoint{{.position = Eigen::Vector2f(125, 150), + .pressure = 0.4f}, + .originTimestamp = originTimestamp, + .targetTimestamp = originTimestamp + + 4 * TEST_PREDICTION_INTERVAL_NANOS}, + 0.001)); + EXPECT_THAT(predictionPoints[4], + PredictionPointNear(PredictionPoint{{.position = Eigen::Vector2f(130, 140), + .pressure = 0.3f}, + .originTimestamp = originTimestamp, + .targetTimestamp = originTimestamp + + 5 * TEST_PREDICTION_INTERVAL_NANOS}, + 0.001)); +} + +// Generates predictions by linear extrapolation for each consecutive pair of ground truth points +// (see the comment for the above function for further explanation). Returns a vector of vectors of +// prediction points, where the first index is the source ground truth index, and the second is the +// prediction target index. +// +// The returned vector has size equal to the input vector, and the first element of the returned +// vector is always empty. +std::vector<std::vector<PredictionPoint>> generateAllPredictionsByLinearExtrapolation( + const std::vector<GroundTruthPoint>& groundTruthPoints) { + std::vector<std::vector<PredictionPoint>> allPredictions; + allPredictions.emplace_back(); + for (size_t i = 1; i < groundTruthPoints.size(); ++i) { + allPredictions.push_back(generatePredictionsByLinearExtrapolation(groundTruthPoints[i - 1], + groundTruthPoints[i])); + } + return allPredictions; +} + +TEST(GeneratePredictionsTest, GenerateAllPredictions) { + const nsecs_t initialTimestamp = TEST_INITIAL_TIMESTAMP; + std::vector<GroundTruthPoint> + groundTruthPoints{GroundTruthPoint{{.position = Eigen::Vector2f(0, 0), + .pressure = 0.5f}, + .timestamp = initialTimestamp}, + GroundTruthPoint{{.position = Eigen::Vector2f(1, -1), + .pressure = 0.51f}, + .timestamp = initialTimestamp + + 2 * TEST_PREDICTION_INTERVAL_NANOS}, + GroundTruthPoint{{.position = Eigen::Vector2f(2, -2), + .pressure = 0.52f}, + .timestamp = initialTimestamp + + 3 * TEST_PREDICTION_INTERVAL_NANOS}}; + + const std::vector<std::vector<PredictionPoint>> allPredictions = + generateAllPredictionsByLinearExtrapolation(groundTruthPoints); + + // Check format of allPredictions data. + ASSERT_EQ(groundTruthPoints.size(), allPredictions.size()); + EXPECT_TRUE(allPredictions[0].empty()); + EXPECT_EQ(TEST_MAX_NUM_PREDICTIONS, allPredictions[1].size()); + EXPECT_EQ(TEST_MAX_NUM_PREDICTIONS, allPredictions[2].size()); + + // Check positions of predictions generated from first pair of ground truth points. + EXPECT_THAT(allPredictions[1][0].position, Vector2fNear(Eigen::Vector2f(2, -2), 1e-9)); + EXPECT_THAT(allPredictions[1][1].position, Vector2fNear(Eigen::Vector2f(3, -3), 1e-9)); + EXPECT_THAT(allPredictions[1][2].position, Vector2fNear(Eigen::Vector2f(4, -4), 1e-9)); + EXPECT_THAT(allPredictions[1][3].position, Vector2fNear(Eigen::Vector2f(5, -5), 1e-9)); + EXPECT_THAT(allPredictions[1][4].position, Vector2fNear(Eigen::Vector2f(6, -6), 1e-9)); + + // Check pressures of predictions generated from first pair of ground truth points. + EXPECT_FLOAT_EQ(0.52f, allPredictions[1][0].pressure); + EXPECT_FLOAT_EQ(0.53f, allPredictions[1][1].pressure); + EXPECT_FLOAT_EQ(0.54f, allPredictions[1][2].pressure); + EXPECT_FLOAT_EQ(0.55f, allPredictions[1][3].pressure); + EXPECT_FLOAT_EQ(0.56f, allPredictions[1][4].pressure); +} + +// --- Prediction error helper functions. --- + +struct GeneralPositionErrors { + float alongTrajectoryErrorMean; + float alongTrajectoryErrorStd; + float offTrajectoryRmse; +}; + +// Inputs: +// • Vector of ground truth points +// • Vector of vectors of prediction points, where the first index is the source ground truth +// index, and the second is the prediction target index. +// +// Returns a vector of GeneralPositionErrors, indexed by prediction time delta bucket. +std::vector<GeneralPositionErrors> computeGeneralPositionErrors( + const std::vector<GroundTruthPoint>& groundTruthPoints, + const std::vector<std::vector<PredictionPoint>>& predictionPoints) { + // Aggregate errors by time bucket (prediction target index). + std::vector<GeneralPositionErrors> generalPostitionErrors; + for (size_t predictionTargetIndex = 0; predictionTargetIndex < TEST_MAX_NUM_PREDICTIONS; + ++predictionTargetIndex) { + std::vector<float> alongTrajectoryErrors; + std::vector<float> alongTrajectorySquaredErrors; + std::vector<float> offTrajectoryErrors; + for (size_t sourceGroundTruthIndex = 1; sourceGroundTruthIndex < groundTruthPoints.size(); + ++sourceGroundTruthIndex) { + const size_t targetGroundTruthIndex = + sourceGroundTruthIndex + predictionTargetIndex + 1; + // Only include errors for points with a ground truth value. + if (targetGroundTruthIndex < groundTruthPoints.size()) { + const Eigen::Vector2f trajectory = + (groundTruthPoints[targetGroundTruthIndex].position - + groundTruthPoints[targetGroundTruthIndex - 1].position) + .normalized(); + const Eigen::Vector2f orthogonalTrajectory = + Eigen::Rotation2Df(M_PI_2) * trajectory; + const Eigen::Vector2f positionError = + predictionPoints[sourceGroundTruthIndex][predictionTargetIndex].position - + groundTruthPoints[targetGroundTruthIndex].position; + alongTrajectoryErrors.push_back(positionError.dot(trajectory)); + alongTrajectorySquaredErrors.push_back(alongTrajectoryErrors.back() * + alongTrajectoryErrors.back()); + offTrajectoryErrors.push_back(positionError.dot(orthogonalTrajectory)); + } + } + generalPostitionErrors.push_back( + {.alongTrajectoryErrorMean = average(alongTrajectoryErrors), + .alongTrajectoryErrorStd = standardDeviation(alongTrajectoryErrors), + .offTrajectoryRmse = rmse(offTrajectoryErrors)}); + } + return generalPostitionErrors; +} + +// Inputs: +// • Vector of ground truth points +// • Vector of vectors of prediction points, where the first index is the source ground truth +// index, and the second is the prediction target index. +// +// Returns a vector of pressure RMSEs, indexed by prediction time delta bucket. +std::vector<float> computePressureRmses( + const std::vector<GroundTruthPoint>& groundTruthPoints, + const std::vector<std::vector<PredictionPoint>>& predictionPoints) { + // Aggregate errors by time bucket (prediction target index). + std::vector<float> pressureRmses; + for (size_t predictionTargetIndex = 0; predictionTargetIndex < TEST_MAX_NUM_PREDICTIONS; + ++predictionTargetIndex) { + std::vector<float> pressureErrors; + for (size_t sourceGroundTruthIndex = 1; sourceGroundTruthIndex < groundTruthPoints.size(); + ++sourceGroundTruthIndex) { + const size_t targetGroundTruthIndex = + sourceGroundTruthIndex + predictionTargetIndex + 1; + // Only include errors for points with a ground truth value. + if (targetGroundTruthIndex < groundTruthPoints.size()) { + pressureErrors.push_back( + predictionPoints[sourceGroundTruthIndex][predictionTargetIndex].pressure - + groundTruthPoints[targetGroundTruthIndex].pressure); + } + } + pressureRmses.push_back(rmse(pressureErrors)); + } + return pressureRmses; +} + +TEST(ErrorComputationHelperTest, ComputeGeneralPositionErrorsSimpleTest) { + std::vector<GroundTruthPoint> groundTruthPoints = + generateConstantGroundTruthPoints(GroundTruthPoint{{.position = Eigen::Vector2f(0, 0), + .pressure = 0.0f}, + .timestamp = TEST_INITIAL_TIMESTAMP}, + /*numPoints=*/TEST_MAX_NUM_PREDICTIONS + 2); + groundTruthPoints[3].position = Eigen::Vector2f(1, 0); + groundTruthPoints[4].position = Eigen::Vector2f(1, 1); + groundTruthPoints[5].position = Eigen::Vector2f(1, 3); + groundTruthPoints[6].position = Eigen::Vector2f(2, 3); + + std::vector<std::vector<PredictionPoint>> predictionPoints = + generateAllPredictionsByLinearExtrapolation(groundTruthPoints); + + // The generated predictions look like: + // + // | Source | Target Ground Truth Index | + // | Index | 2 | 3 | 4 | 5 | 6 | + // |------------|--------|--------|--------|--------|--------| + // | 1 | (0, 0) | (0, 0) | (0, 0) | (0, 0) | (0, 0) | + // | 2 | | (0, 0) | (0, 0) | (0, 0) | (0, 0) | + // | 3 | | | (2, 0) | (3, 0) | (4, 0) | + // | 4 | | | | (1, 2) | (1, 3) | + // | 5 | | | | | (1, 5) | + // |---------------------------------------------------------| + // | Actual Ground Truth Values | + // | Position | (0, 0) | (1, 0) | (1, 1) | (1, 3) | (2, 3) | + // | Previous | (0, 0) | (0, 0) | (1, 0) | (1, 1) | (1, 3) | + // + // Note: this table organizes prediction targets by target ground truth index. Metrics are + // aggregated across points with the same prediction time bucket index, which is different. + // Each down-right diagonal from this table gives us points from a unique time bucket. + + // Initialize expected prediction errors from the table above. The first time bucket corresponds + // to the long diagonal of the table, and subsequent time buckets step up-right from there. + const std::vector<std::vector<float>> expectedAlongTrajectoryErrors{{0, -1, -1, -1, -1}, + {-1, -1, -3, -1}, + {-1, -3, 2}, + {-3, -2}, + {-2}}; + const std::vector<std::vector<float>> expectedOffTrajectoryErrors{{0, 0, 1, 0, 2}, + {0, 1, 2, 0}, + {1, 1, 3}, + {1, 3}, + {3}}; + + std::vector<GeneralPositionErrors> generalPositionErrors = + computeGeneralPositionErrors(groundTruthPoints, predictionPoints); + + ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, generalPositionErrors.size()); + for (size_t i = 0; i < generalPositionErrors.size(); ++i) { + SCOPED_TRACE(testing::Message() << "i = " << i); + EXPECT_FLOAT_EQ(average(expectedAlongTrajectoryErrors[i]), + generalPositionErrors[i].alongTrajectoryErrorMean); + EXPECT_FLOAT_EQ(standardDeviation(expectedAlongTrajectoryErrors[i]), + generalPositionErrors[i].alongTrajectoryErrorStd); + EXPECT_FLOAT_EQ(rmse(expectedOffTrajectoryErrors[i]), + generalPositionErrors[i].offTrajectoryRmse); + } +} + +TEST(ErrorComputationHelperTest, ComputePressureRmsesSimpleTest) { + // Generate ground truth points with pressures {0.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.5}. + // (We need TEST_MAX_NUM_PREDICTIONS + 2 to test all prediction time buckets.) + std::vector<GroundTruthPoint> groundTruthPoints = + generateConstantGroundTruthPoints(GroundTruthPoint{{.position = Eigen::Vector2f(0, 0), + .pressure = 0.0f}, + .timestamp = TEST_INITIAL_TIMESTAMP}, + /*numPoints=*/TEST_MAX_NUM_PREDICTIONS + 2); + for (size_t i = 4; i < groundTruthPoints.size(); ++i) { + groundTruthPoints[i].pressure = 0.5f; + } + + std::vector<std::vector<PredictionPoint>> predictionPoints = + generateAllPredictionsByLinearExtrapolation(groundTruthPoints); + + std::vector<float> pressureRmses = computePressureRmses(groundTruthPoints, predictionPoints); + + ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, pressureRmses.size()); + EXPECT_FLOAT_EQ(rmse(std::vector<float>{0.0f, 0.0f, -0.5f, 0.5f, 0.0f}), pressureRmses[0]); + EXPECT_FLOAT_EQ(rmse(std::vector<float>{0.0f, -0.5f, -0.5f, 1.0f}), pressureRmses[1]); + EXPECT_FLOAT_EQ(rmse(std::vector<float>{-0.5f, -0.5f, -0.5f}), pressureRmses[2]); + EXPECT_FLOAT_EQ(rmse(std::vector<float>{-0.5f, -0.5f}), pressureRmses[3]); + EXPECT_FLOAT_EQ(rmse(std::vector<float>{-0.5f}), pressureRmses[4]); +} + +// --- MotionPredictorMetricsManager tests. --- + +// Helper function that instantiates a MetricsManager with the given mock logged AtomFields. Takes +// vectors of ground truth and prediction points of the same length, and passes these points to the +// MetricsManager. The format of these vectors is expected to be: +// • groundTruthPoints: chronologically-ordered ground truth points, with at least 2 elements. +// • predictionPoints: the first index points to a vector of predictions corresponding to the +// source ground truth point with the same index. +// - The first element should be empty, because there are not expected to be predictions until +// we have received 2 ground truth points. +// - The last element may be empty, because there will be no future ground truth points to +// associate with those predictions (if not empty, it will be ignored). +// - To test all prediction buckets, there should be at least TEST_MAX_NUM_PREDICTIONS non-empty +// prediction sets (that is, excluding the first and last). Thus, groundTruthPoints and +// predictionPoints should have size at least TEST_MAX_NUM_PREDICTIONS + 2. +// +// The passed-in outAtomFields will contain the logged AtomFields when the function returns. +// +// This function returns void so that it can use test assertions. +void runMetricsManager(const std::vector<GroundTruthPoint>& groundTruthPoints, + const std::vector<std::vector<PredictionPoint>>& predictionPoints, + std::vector<AtomFields>& outAtomFields) { + MotionPredictorMetricsManager metricsManager(TEST_PREDICTION_INTERVAL_NANOS, + TEST_MAX_NUM_PREDICTIONS); + metricsManager.setMockLoggedAtomFields(&outAtomFields); + + // Validate structure of groundTruthPoints and predictionPoints. + ASSERT_EQ(predictionPoints.size(), groundTruthPoints.size()); + ASSERT_GE(groundTruthPoints.size(), 2u); + ASSERT_EQ(predictionPoints[0].size(), 0u); + for (size_t i = 1; i + 1 < predictionPoints.size(); ++i) { + SCOPED_TRACE(testing::Message() << "i = " << i); + ASSERT_EQ(predictionPoints[i].size(), TEST_MAX_NUM_PREDICTIONS); + } + + // Pass ground truth points and predictions (for all except first and last ground truth). + for (size_t i = 0; i < groundTruthPoints.size(); ++i) { + metricsManager.onRecord(makeMotionEvent(groundTruthPoints[i])); + if ((i > 0) && (i + 1 < predictionPoints.size())) { + metricsManager.onPredict(makeMotionEvent(predictionPoints[i])); + } + } + // Send a stroke-end event to trigger the logging call. + metricsManager.onRecord(makeLiftMotionEvent()); +} + +// Vacuous test: +// • Input: no prediction data. +// • Expectation: no metrics should be logged. +TEST(MotionPredictorMetricsManagerTest, NoPredictions) { + std::vector<AtomFields> mockLoggedAtomFields; + MotionPredictorMetricsManager metricsManager(TEST_PREDICTION_INTERVAL_NANOS, + TEST_MAX_NUM_PREDICTIONS); + metricsManager.setMockLoggedAtomFields(&mockLoggedAtomFields); + + metricsManager.onRecord(makeMotionEvent( + GroundTruthPoint{{.position = Eigen::Vector2f(0, 0), .pressure = 0}, .timestamp = 0})); + metricsManager.onRecord(makeLiftMotionEvent()); + + // Check that mockLoggedAtomFields is still empty (as it was initialized empty), ensuring that + // no metrics were logged. + EXPECT_EQ(0u, mockLoggedAtomFields.size()); +} + +// Perfect predictions test: +// • Input: constant input events, perfect predictions matching the input events. +// • Expectation: all error metrics should be zero, or NO_DATA_SENTINEL for "unreported" metrics. +// (For example, scale-invariant errors are only reported for the final time bucket.) +TEST(MotionPredictorMetricsManagerTest, ConstantGroundTruthPerfectPredictions) { + GroundTruthPoint groundTruthPoint{{.position = Eigen::Vector2f(10.0f, 20.0f), .pressure = 0.6f}, + .timestamp = TEST_INITIAL_TIMESTAMP}; + + // Generate ground truth and prediction points as described by the runMetricsManager comment. + std::vector<GroundTruthPoint> groundTruthPoints; + std::vector<std::vector<PredictionPoint>> predictionPoints; + for (size_t i = 0; i < TEST_MAX_NUM_PREDICTIONS + 2; ++i) { + groundTruthPoints.push_back(groundTruthPoint); + predictionPoints.push_back(i > 0 ? generateConstantPredictions(groundTruthPoint) + : std::vector<PredictionPoint>{}); + groundTruthPoint.timestamp += TEST_PREDICTION_INTERVAL_NANOS; + } + + std::vector<AtomFields> atomFields; + runMetricsManager(groundTruthPoints, predictionPoints, atomFields); + + ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, atomFields.size()); + // Check that errors are all zero, or NO_DATA_SENTINEL for unreported metrics. + for (size_t i = 0; i < atomFields.size(); ++i) { + SCOPED_TRACE(testing::Message() << "i = " << i); + const AtomFields& atom = atomFields[i]; + const nsecs_t deltaTimeBucketNanos = TEST_PREDICTION_INTERVAL_NANOS * (i + 1); + EXPECT_EQ(deltaTimeBucketNanos / NANOS_PER_MILLIS, atom.deltaTimeBucketMilliseconds); + // General errors: reported for every time bucket. + EXPECT_EQ(0, atom.alongTrajectoryErrorMeanMillipixels); + EXPECT_EQ(0, atom.alongTrajectoryErrorStdMillipixels); + EXPECT_EQ(0, atom.offTrajectoryRmseMillipixels); + EXPECT_EQ(0, atom.pressureRmseMilliunits); + // High-velocity errors: reported only for the last two time buckets. + // However, this data has zero velocity, so these metrics should all be NO_DATA_SENTINEL. + EXPECT_EQ(NO_DATA_SENTINEL, atom.highVelocityAlongTrajectoryRmse); + EXPECT_EQ(NO_DATA_SENTINEL, atom.highVelocityOffTrajectoryRmse); + // Scale-invariant errors: reported only for the last time bucket. + if (i + 1 == atomFields.size()) { + EXPECT_EQ(0, atom.scaleInvariantAlongTrajectoryRmse); + EXPECT_EQ(0, atom.scaleInvariantOffTrajectoryRmse); + } else { + EXPECT_EQ(NO_DATA_SENTINEL, atom.scaleInvariantAlongTrajectoryRmse); + EXPECT_EQ(NO_DATA_SENTINEL, atom.scaleInvariantOffTrajectoryRmse); + } + } +} + +TEST(MotionPredictorMetricsManagerTest, QuadraticPressureLinearPredictions) { + // Generate ground truth points. + // + // Ground truth pressures are a quadratically increasing function from some initial value. + const float initialPressure = 0.5f; + const float quadraticCoefficient = 0.01f; + std::vector<GroundTruthPoint> groundTruthPoints; + nsecs_t timestamp = TEST_INITIAL_TIMESTAMP; + // As described in the runMetricsManager comment, we should have TEST_MAX_NUM_PREDICTIONS + 2 + // ground truth points. + for (size_t i = 0; i < TEST_MAX_NUM_PREDICTIONS + 2; ++i) { + const float pressure = initialPressure + quadraticCoefficient * static_cast<float>(i * i); + groundTruthPoints.push_back( + GroundTruthPoint{{.position = Eigen::Vector2f(0, 0), .pressure = pressure}, + .timestamp = timestamp}); + timestamp += TEST_PREDICTION_INTERVAL_NANOS; + } + + // Note: the first index is the source ground truth index, and the second is the prediction + // target index. + std::vector<std::vector<PredictionPoint>> predictionPoints = + generateAllPredictionsByLinearExtrapolation(groundTruthPoints); + + const std::vector<float> pressureErrors = + computePressureRmses(groundTruthPoints, predictionPoints); + + // Run test. + std::vector<AtomFields> atomFields; + runMetricsManager(groundTruthPoints, predictionPoints, atomFields); + + // Check logged metrics match expectations. + ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, atomFields.size()); + for (size_t i = 0; i < atomFields.size(); ++i) { + SCOPED_TRACE(testing::Message() << "i = " << i); + const AtomFields& atom = atomFields[i]; + // Check time bucket delta matches expectation based on index and prediction interval. + const nsecs_t deltaTimeBucketNanos = TEST_PREDICTION_INTERVAL_NANOS * (i + 1); + EXPECT_EQ(deltaTimeBucketNanos / NANOS_PER_MILLIS, atom.deltaTimeBucketMilliseconds); + // Check pressure error matches expectation. + EXPECT_NEAR(static_cast<int>(1000 * pressureErrors[i]), atom.pressureRmseMilliunits, 1); + } +} + +TEST(MotionPredictorMetricsManagerTest, QuadraticPositionLinearPredictionsGeneralErrors) { + // Generate ground truth points. + // + // Each component of the ground truth positions are an independent quadratically increasing + // function from some initial value. + const Eigen::Vector2f initialPosition(200, 300); + const Eigen::Vector2f quadraticCoefficients(-2, 3); + std::vector<GroundTruthPoint> groundTruthPoints; + nsecs_t timestamp = TEST_INITIAL_TIMESTAMP; + // As described in the runMetricsManager comment, we should have TEST_MAX_NUM_PREDICTIONS + 2 + // ground truth points. + for (size_t i = 0; i < TEST_MAX_NUM_PREDICTIONS + 2; ++i) { + const Eigen::Vector2f position = + initialPosition + quadraticCoefficients * static_cast<float>(i * i); + groundTruthPoints.push_back( + GroundTruthPoint{{.position = position, .pressure = 0.5}, .timestamp = timestamp}); + timestamp += TEST_PREDICTION_INTERVAL_NANOS; + } + + // Note: the first index is the source ground truth index, and the second is the prediction + // target index. + std::vector<std::vector<PredictionPoint>> predictionPoints = + generateAllPredictionsByLinearExtrapolation(groundTruthPoints); + + std::vector<GeneralPositionErrors> generalPositionErrors = + computeGeneralPositionErrors(groundTruthPoints, predictionPoints); + + // Run test. + std::vector<AtomFields> atomFields; + runMetricsManager(groundTruthPoints, predictionPoints, atomFields); + + // Check logged metrics match expectations. + ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, atomFields.size()); + for (size_t i = 0; i < atomFields.size(); ++i) { + SCOPED_TRACE(testing::Message() << "i = " << i); + const AtomFields& atom = atomFields[i]; + // Check time bucket delta matches expectation based on index and prediction interval. + const nsecs_t deltaTimeBucketNanos = TEST_PREDICTION_INTERVAL_NANOS * (i + 1); + EXPECT_EQ(deltaTimeBucketNanos / NANOS_PER_MILLIS, atom.deltaTimeBucketMilliseconds); + // Check general position errors match expectation. + EXPECT_NEAR(static_cast<int>(1000 * generalPositionErrors[i].alongTrajectoryErrorMean), + atom.alongTrajectoryErrorMeanMillipixels, 1); + EXPECT_NEAR(static_cast<int>(1000 * generalPositionErrors[i].alongTrajectoryErrorStd), + atom.alongTrajectoryErrorStdMillipixels, 1); + EXPECT_NEAR(static_cast<int>(1000 * generalPositionErrors[i].offTrajectoryRmse), + atom.offTrajectoryRmseMillipixels, 1); + } +} + +// Counterclockwise regular octagonal section test: +// • Input – ground truth: constantly-spaced input events starting at a trajectory pointing exactly +// rightwards, and rotating by 45° counterclockwise after each input. +// • Input – predictions: simple linear extrapolations of previous two ground truth points. +// +// The code below uses the following terminology to distinguish references to ground truth events: +// • Source ground truth: the most recent ground truth point received at the time the prediction +// was made. +// • Target ground truth: the ground truth event that the prediction was attempting to match. +TEST(MotionPredictorMetricsManagerTest, CounterclockwiseOctagonGroundTruthLinearPredictions) { + // Select a stroke velocity that exceeds the high-velocity threshold of 1100 px/sec. + // For an input rate of 240 hz, 1100 px/sec * (1/240) sec/input ≈ 4.58 pixels per input. + const float strokeVelocity = 10; // pixels per input + + // As described in the runMetricsManager comment, we should have TEST_MAX_NUM_PREDICTIONS + 2 + // ground truth points. + std::vector<GroundTruthPoint> groundTruthPoints = generateCircularArcGroundTruthPoints( + /*initialPosition=*/Eigen::Vector2f(100, 100), + /*initialAngle=*/M_PI_2, + /*velocity=*/strokeVelocity, + /*turningAngle=*/-M_PI_4, + /*numPoints=*/TEST_MAX_NUM_PREDICTIONS + 2); + + std::vector<std::vector<PredictionPoint>> predictionPoints = + generateAllPredictionsByLinearExtrapolation(groundTruthPoints); + + std::vector<GeneralPositionErrors> generalPositionErrors = + computeGeneralPositionErrors(groundTruthPoints, predictionPoints); + + // Run test. + std::vector<AtomFields> atomFields; + runMetricsManager(groundTruthPoints, predictionPoints, atomFields); + + // Check logged metrics match expectations. + ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, atomFields.size()); + for (size_t i = 0; i < atomFields.size(); ++i) { + SCOPED_TRACE(testing::Message() << "i = " << i); + const AtomFields& atom = atomFields[i]; + const nsecs_t deltaTimeBucketNanos = TEST_PREDICTION_INTERVAL_NANOS * (i + 1); + EXPECT_EQ(deltaTimeBucketNanos / NANOS_PER_MILLIS, atom.deltaTimeBucketMilliseconds); + + // General errors: reported for every time bucket. + EXPECT_NEAR(static_cast<int>(1000 * generalPositionErrors[i].alongTrajectoryErrorMean), + atom.alongTrajectoryErrorMeanMillipixels, 1); + // We allow for some floating point error in standard deviation (0.02 pixels). + EXPECT_NEAR(1000 * generalPositionErrors[i].alongTrajectoryErrorStd, + atom.alongTrajectoryErrorStdMillipixels, 20); + // All position errors are equal, so the standard deviation should be approximately zero. + EXPECT_NEAR(0, atom.alongTrajectoryErrorStdMillipixels, 20); + // Absolute value for RMSE, since it must be non-negative. + EXPECT_NEAR(static_cast<int>(1000 * generalPositionErrors[i].offTrajectoryRmse), + atom.offTrajectoryRmseMillipixels, 1); + + // High-velocity errors: reported only for the last two time buckets. + // + // Since our input stroke velocity is chosen to be above the high-velocity threshold, all + // data contributes to high-velocity errors, and thus high-velocity errors should be equal + // to general errors (where reported). + // + // As above, use absolute value for RMSE, since it must be non-negative. + if (i + 2 >= atomFields.size()) { + EXPECT_NEAR(static_cast<int>( + 1000 * std::abs(generalPositionErrors[i].alongTrajectoryErrorMean)), + atom.highVelocityAlongTrajectoryRmse, 1); + EXPECT_NEAR(static_cast<int>(1000 * + std::abs(generalPositionErrors[i].offTrajectoryRmse)), + atom.highVelocityOffTrajectoryRmse, 1); + } else { + EXPECT_EQ(NO_DATA_SENTINEL, atom.highVelocityAlongTrajectoryRmse); + EXPECT_EQ(NO_DATA_SENTINEL, atom.highVelocityOffTrajectoryRmse); + } + + // Scale-invariant errors: reported only for the last time bucket, where the reported value + // is the aggregation across all time buckets. + // + // The MetricsManager stores mMaxNumPredictions recent ground truth segments. Our ground + // truth segments here all have a length of strokeVelocity, so we can convert general errors + // to scale-invariant errors by dividing by `strokeVelocty * TEST_MAX_NUM_PREDICTIONS`. + // + // As above, use absolute value for RMSE, since it must be non-negative. + if (i + 1 == atomFields.size()) { + const float pathLength = strokeVelocity * TEST_MAX_NUM_PREDICTIONS; + std::vector<float> alongTrajectoryAbsoluteErrors; + std::vector<float> offTrajectoryAbsoluteErrors; + for (size_t j = 0; j < TEST_MAX_NUM_PREDICTIONS; ++j) { + alongTrajectoryAbsoluteErrors.push_back( + std::abs(generalPositionErrors[j].alongTrajectoryErrorMean)); + offTrajectoryAbsoluteErrors.push_back( + std::abs(generalPositionErrors[j].offTrajectoryRmse)); + } + EXPECT_NEAR(static_cast<int>(1000 * average(alongTrajectoryAbsoluteErrors) / + pathLength), + atom.scaleInvariantAlongTrajectoryRmse, 1); + EXPECT_NEAR(static_cast<int>(1000 * average(offTrajectoryAbsoluteErrors) / pathLength), + atom.scaleInvariantOffTrajectoryRmse, 1); + } else { + EXPECT_EQ(NO_DATA_SENTINEL, atom.scaleInvariantAlongTrajectoryRmse); + EXPECT_EQ(NO_DATA_SENTINEL, atom.scaleInvariantOffTrajectoryRmse); + } + } +} + +} // namespace +} // namespace android diff --git a/libs/input/tests/MotionPredictor_test.cpp b/libs/input/tests/MotionPredictor_test.cpp index 7a62f5ec58..4ac7ae920e 100644 --- a/libs/input/tests/MotionPredictor_test.cpp +++ b/libs/input/tests/MotionPredictor_test.cpp @@ -72,11 +72,20 @@ TEST(MotionPredictorTest, IsPredictionAvailable) { ASSERT_FALSE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_TOUCHSCREEN)); } +TEST(MotionPredictorTest, StationaryNoiseFloor) { + MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/1, + []() { return true /*enable prediction*/; }); + predictor.record(getMotionEvent(DOWN, 0, 1, 30ms)); + predictor.record(getMotionEvent(MOVE, 0, 1, 35ms)); // No movement. + std::unique_ptr<MotionEvent> predicted = predictor.predict(40 * NSEC_PER_MSEC); + ASSERT_EQ(nullptr, predicted); +} + TEST(MotionPredictorTest, Offset) { MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/1, []() { return true /*enable prediction*/; }); predictor.record(getMotionEvent(DOWN, 0, 1, 30ms)); - predictor.record(getMotionEvent(MOVE, 0, 2, 35ms)); + predictor.record(getMotionEvent(MOVE, 0, 5, 35ms)); // Move enough to overcome the noise floor. std::unique_ptr<MotionEvent> predicted = predictor.predict(40 * NSEC_PER_MSEC); ASSERT_NE(nullptr, predicted); ASSERT_GE(predicted->getEventTime(), 41); diff --git a/libs/input/tests/VelocityTracker_test.cpp b/libs/input/tests/VelocityTracker_test.cpp index ae721093a0..73f25cc615 100644 --- a/libs/input/tests/VelocityTracker_test.cpp +++ b/libs/input/tests/VelocityTracker_test.cpp @@ -282,6 +282,11 @@ static void computeAndCheckAxisScrollVelocity( const std::vector<std::pair<std::chrono::nanoseconds, float>>& motions, std::optional<float> targetVelocity) { checkVelocity(computeVelocity(strategy, motions, AMOTION_EVENT_AXIS_SCROLL), targetVelocity); + // The strategy LSQ2 is not compatible with AXIS_SCROLL. In those situations, we should fall + // back to a strategy that supports differential axes. + checkVelocity(computeVelocity(VelocityTracker::Strategy::LSQ2, motions, + AMOTION_EVENT_AXIS_SCROLL), + targetVelocity); } static void computeAndCheckQuadraticEstimate(const std::vector<PlanarMotionEventEntry>& motions, |