From 5e0e7cf9d7918fac53e21a088e2f7d914cde9b8f Mon Sep 17 00:00:00 2001 From: Derek Wu Date: Thu, 4 Jul 2024 11:14:18 +0000 Subject: Add smoothing to jerk calculations and updated jerk thresholds. Test: atest libinput_tests Test: atest CtsInputTestCases Test: atest MotionPredictorBenchmark MotionPredictorTest Test: Using stylus in a drawing app and seeing the jerk logs. Bug: 266747654 Bug: 353161308 Flag: com.android.input.flags.enable_prediction_pruning_via_jerk_thresholding Change-Id: I3d6c47d94d66e5ff2b33474acbca72daca051242 --- data/etc/input/motion_predictor_config.xml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'data') diff --git a/data/etc/input/motion_predictor_config.xml b/data/etc/input/motion_predictor_config.xml index c3f2fedc71..14540ec512 100644 --- a/data/etc/input/motion_predictor_config.xml +++ b/data/etc/input/motion_predictor_config.xml @@ -35,7 +35,10 @@ The jerk thresholds are based on normalized dt = 1 calculations. --> - 1.0 - 1.1 + 1.5 + 2.0 + + + 0.25 -- cgit v1.2.3-59-g8ed1b From cc6aec59b86cdc7e26f7d637716c0ad85694931c Mon Sep 17 00:00:00 2001 From: Derek Wu Date: Tue, 30 Jul 2024 11:59:56 +0000 Subject: Refactor JerkTracker and MotionPredictor for better testing. Changes include renaming forgetFactor to alpha. Test: atest libinput_tests Bug: 266747654 Bug: 353161308 Flag: com.android.input.flags.enable_prediction_pruning_via_jerk_thresholding Change-Id: Icd056d36a3d7894c6c9b4b957233002ad961a9a1 --- data/etc/input/motion_predictor_config.xml | 5 ++- include/input/MotionPredictor.h | 30 ++++++--------- include/input/TfLiteMotionPredictor.h | 2 +- libs/input/MotionPredictor.cpp | 62 +++++++++++++----------------- libs/input/TfLiteMotionPredictor.cpp | 2 +- libs/input/tests/MotionPredictor_test.cpp | 23 +++++------ 6 files changed, 56 insertions(+), 68 deletions(-) (limited to 'data') diff --git a/data/etc/input/motion_predictor_config.xml b/data/etc/input/motion_predictor_config.xml index 14540ec512..f593eda42d 100644 --- a/data/etc/input/motion_predictor_config.xml +++ b/data/etc/input/motion_predictor_config.xml @@ -38,7 +38,8 @@ 1.5 2.0 - - 0.25 + + 0.25 diff --git a/include/input/MotionPredictor.h b/include/input/MotionPredictor.h index 2f1ef86428..200c301ffe 100644 --- a/include/input/MotionPredictor.h +++ b/include/input/MotionPredictor.h @@ -43,7 +43,9 @@ static inline bool isMotionPredictionEnabled() { class JerkTracker { public: // Initialize the tracker. If normalizedDt is true, assume that each sample pushed has dt=1. - JerkTracker(bool normalizedDt); + // alpha is the coefficient of the first-order IIR filter for jerk. A factor of 1 results + // in no smoothing. + JerkTracker(bool normalizedDt, float alpha); // Add a position to the tracker and update derivative estimates. void pushSample(int64_t timestamp, float xPos, float yPos); @@ -56,15 +58,10 @@ public: // acceleration) and has the units of d^3p/dt^3. std::optional jerkMagnitude() const; - // forgetFactor is the coefficient of the first-order IIR filter for jerk. A factor of 1 results - // in no smoothing. - void setForgetFactor(float forgetFactor); - float getForgetFactor() const; - private: const bool mNormalizedDt; // Coefficient of first-order IIR filter to smooth jerk calculation. - float mForgetFactor = 1; + const float mAlpha; RingBuffer mTimestamps{4}; std::array mXDerivatives{}; // [x, x', x'', x'''] @@ -124,11 +121,6 @@ public: bool isPredictionAvailable(int32_t deviceId, int32_t source); - /** - * Currently used to expose config constants in testing. - */ - const TfLiteMotionPredictorModel::Config& getModelConfig(); - private: const nsecs_t mPredictionTimestampOffsetNanos; const std::function mCheckMotionPredictionEnabled; @@ -137,15 +129,17 @@ private: std::unique_ptr mBuffers; std::optional mLastEvent; - // mJerkTracker assumes normalized dt = 1 between recorded samples because - // the underlying mModel input also assumes fixed-interval samples. - // Normalized dt as 1 is also used to correspond with the similar Jank - // implementation from the JetPack MotionPredictor implementation. - JerkTracker mJerkTracker{true}; - std::optional mMetricsManager; + std::unique_ptr mJerkTracker; + + std::unique_ptr mMetricsManager; const ReportAtomFunction mReportAtomFunction; + + // Initialize prediction model and associated objects. + // Called during lazy initialization. + // TODO: b/210158587 Consider removing lazy initialization. + void initializeObjects(); }; } // namespace android diff --git a/include/input/TfLiteMotionPredictor.h b/include/input/TfLiteMotionPredictor.h index 08a4330d27..49e909ea55 100644 --- a/include/input/TfLiteMotionPredictor.h +++ b/include/input/TfLiteMotionPredictor.h @@ -112,7 +112,7 @@ public: float highJerk = 0; // Coefficient for the first-order IIR filter for jerk calculation. - float jerkForgetFactor = 1; + float jerkAlpha = 1; }; // Creates a model from an encoded Flatbuffer model. diff --git a/libs/input/MotionPredictor.cpp b/libs/input/MotionPredictor.cpp index 9c70535ef5..c61d3943e0 100644 --- a/libs/input/MotionPredictor.cpp +++ b/libs/input/MotionPredictor.cpp @@ -72,7 +72,8 @@ float normalizeRange(float x, float min, float max) { // --- JerkTracker --- -JerkTracker::JerkTracker(bool normalizedDt) : mNormalizedDt(normalizedDt) {} +JerkTracker::JerkTracker(bool normalizedDt, float alpha) + : mNormalizedDt(normalizedDt), mAlpha(alpha) {} void JerkTracker::pushSample(int64_t timestamp, float xPos, float yPos) { // If we previously had full samples, we have a previous jerk calculation @@ -122,7 +123,7 @@ void JerkTracker::pushSample(int64_t timestamp, float xPos, float yPos) { float newJerkMagnitude = std::hypot(newXDerivatives[3], newYDerivatives[3]); ALOGD_IF(isDebug(), "raw jerk: %f", newJerkMagnitude); if (applySmoothing) { - mJerkMagnitude = mJerkMagnitude + (mForgetFactor * (newJerkMagnitude - mJerkMagnitude)); + mJerkMagnitude = mJerkMagnitude + (mAlpha * (newJerkMagnitude - mJerkMagnitude)); } else { mJerkMagnitude = newJerkMagnitude; } @@ -143,14 +144,6 @@ std::optional JerkTracker::jerkMagnitude() const { return std::nullopt; } -void JerkTracker::setForgetFactor(float forgetFactor) { - mForgetFactor = forgetFactor; -} - -float JerkTracker::getForgetFactor() const { - return mForgetFactor; -} - // --- MotionPredictor --- MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos, @@ -160,6 +153,24 @@ MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos, mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)), mReportAtomFunction(reportAtomFunction) {} +void MotionPredictor::initializeObjects() { + mModel = TfLiteMotionPredictorModel::create(); + LOG_ALWAYS_FATAL_IF(!mModel); + + // mJerkTracker assumes normalized dt = 1 between recorded samples because + // the underlying mModel input also assumes fixed-interval samples. + // Normalized dt as 1 is also used to correspond with the similar Jank + // implementation from the JetPack MotionPredictor implementation. + mJerkTracker = std::make_unique(/*normalizedDt=*/true, mModel->config().jerkAlpha); + + mBuffers = std::make_unique(mModel->inputLength()); + + mMetricsManager = + std::make_unique(mModel->config().predictionInterval, + mModel->outputLength(), + mReportAtomFunction); +} + android::base::Result MotionPredictor::record(const MotionEvent& event) { if (mLastEvent && mLastEvent->getDeviceId() != event.getDeviceId()) { // We still have an active gesture for another device. The provided MotionEvent is not @@ -176,29 +187,18 @@ android::base::Result MotionPredictor::record(const MotionEvent& event) { return {}; } - // Initialise the model now that it's likely to be used. if (!mModel) { - mModel = TfLiteMotionPredictorModel::create(); - LOG_ALWAYS_FATAL_IF(!mModel); - mJerkTracker.setForgetFactor(mModel->config().jerkForgetFactor); - } - - if (!mBuffers) { - mBuffers = std::make_unique(mModel->inputLength()); + initializeObjects(); } // Pass input event to the MetricsManager. - if (!mMetricsManager) { - mMetricsManager.emplace(mModel->config().predictionInterval, mModel->outputLength(), - mReportAtomFunction); - } mMetricsManager->onRecord(event); const int32_t action = event.getActionMasked(); if (action == AMOTION_EVENT_ACTION_UP || action == AMOTION_EVENT_ACTION_CANCEL) { ALOGD_IF(isDebug(), "End of event stream"); mBuffers->reset(); - mJerkTracker.reset(); + mJerkTracker->reset(); mLastEvent.reset(); return {}; } else if (action != AMOTION_EVENT_ACTION_DOWN && action != AMOTION_EVENT_ACTION_MOVE) { @@ -233,9 +233,9 @@ android::base::Result MotionPredictor::record(const MotionEvent& event) { 0, i), .orientation = event.getHistoricalOrientation(0, i), }); - mJerkTracker.pushSample(event.getHistoricalEventTime(i), - coords->getAxisValue(AMOTION_EVENT_AXIS_X), - coords->getAxisValue(AMOTION_EVENT_AXIS_Y)); + mJerkTracker->pushSample(event.getHistoricalEventTime(i), + coords->getAxisValue(AMOTION_EVENT_AXIS_X), + coords->getAxisValue(AMOTION_EVENT_AXIS_Y)); } if (!mLastEvent) { @@ -283,7 +283,7 @@ std::unique_ptr MotionPredictor::predict(nsecs_t timestamp) { int64_t predictionTime = mBuffers->lastTimestamp(); const int64_t futureTime = timestamp + mPredictionTimestampOffsetNanos; - const float jerkMagnitude = mJerkTracker.jerkMagnitude().value_or(0); + const float jerkMagnitude = mJerkTracker->jerkMagnitude().value_or(0); const float fractionKept = 1 - normalizeRange(jerkMagnitude, mModel->config().lowJerk, mModel->config().highJerk); // float to ensure proper division below. @@ -379,12 +379,4 @@ bool MotionPredictor::isPredictionAvailable(int32_t /*deviceId*/, int32_t source return true; } -const TfLiteMotionPredictorModel::Config& MotionPredictor::getModelConfig() { - if (!mModel) { - mModel = TfLiteMotionPredictorModel::create(); - LOG_ALWAYS_FATAL_IF(!mModel); - } - return mModel->config(); -} - } // namespace android diff --git a/libs/input/TfLiteMotionPredictor.cpp b/libs/input/TfLiteMotionPredictor.cpp index b401c985e6..5250a9d2db 100644 --- a/libs/input/TfLiteMotionPredictor.cpp +++ b/libs/input/TfLiteMotionPredictor.cpp @@ -283,7 +283,7 @@ std::unique_ptr TfLiteMotionPredictorModel::create() .distanceNoiseFloor = parseXMLFloat(*configRoot, "distance-noise-floor"), .lowJerk = parseXMLFloat(*configRoot, "low-jerk"), .highJerk = parseXMLFloat(*configRoot, "high-jerk"), - .jerkForgetFactor = parseXMLFloat(*configRoot, "jerk-forget-factor"), + .jerkAlpha = parseXMLFloat(*configRoot, "jerk-alpha"), }; return std::unique_ptr( diff --git a/libs/input/tests/MotionPredictor_test.cpp b/libs/input/tests/MotionPredictor_test.cpp index 5bd5794795..106e686a81 100644 --- a/libs/input/tests/MotionPredictor_test.cpp +++ b/libs/input/tests/MotionPredictor_test.cpp @@ -70,7 +70,7 @@ static MotionEvent getMotionEvent(int32_t action, float x, float y, } TEST(JerkTrackerTest, JerkReadiness) { - JerkTracker jerkTracker(true); + JerkTracker jerkTracker(/*normalizedDt=*/true, /*alpha=*/1); EXPECT_FALSE(jerkTracker.jerkMagnitude()); jerkTracker.pushSample(/*timestamp=*/0, 20, 50); EXPECT_FALSE(jerkTracker.jerkMagnitude()); @@ -87,8 +87,8 @@ TEST(JerkTrackerTest, JerkReadiness) { } TEST(JerkTrackerTest, JerkCalculationNormalizedDtTrue) { - JerkTracker jerkTracker(true); - jerkTracker.setForgetFactor(.5); + const float alpha = .5; + JerkTracker jerkTracker(/*normalizedDt=*/true, alpha); jerkTracker.pushSample(/*timestamp=*/0, 20, 50); jerkTracker.pushSample(/*timestamp=*/1, 25, 53); jerkTracker.pushSample(/*timestamp=*/2, 30, 60); @@ -119,14 +119,13 @@ TEST(JerkTrackerTest, JerkCalculationNormalizedDtTrue) { * y'': 3 -> -15 * y''': -18 */ - const float newJerk = (1 - jerkTracker.getForgetFactor()) * std::hypot(10, -1) + - jerkTracker.getForgetFactor() * std::hypot(-50, -18); + const float newJerk = (1 - alpha) * std::hypot(10, -1) + alpha * std::hypot(-50, -18); EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), newJerk); } TEST(JerkTrackerTest, JerkCalculationNormalizedDtFalse) { - JerkTracker jerkTracker(false); - jerkTracker.setForgetFactor(.5); + const float alpha = .5; + JerkTracker jerkTracker(/*normalizedDt=*/false, alpha); jerkTracker.pushSample(/*timestamp=*/0, 20, 50); jerkTracker.pushSample(/*timestamp=*/10, 25, 53); jerkTracker.pushSample(/*timestamp=*/20, 30, 60); @@ -157,13 +156,12 @@ TEST(JerkTrackerTest, JerkCalculationNormalizedDtFalse) { * y'': .03 -> -.125 (delta above, divide by 10) * y''': -.0155 (delta above, divide by 10) */ - const float newJerk = (1 - jerkTracker.getForgetFactor()) * std::hypot(.01, -.001) + - jerkTracker.getForgetFactor() * std::hypot(-.0375, -.0155); + const float newJerk = (1 - alpha) * std::hypot(.01, -.001) + alpha * std::hypot(-.0375, -.0155); EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), newJerk); } TEST(JerkTrackerTest, JerkCalculationAfterReset) { - JerkTracker jerkTracker(true); + JerkTracker jerkTracker(/*normalizedDt=*/true, /*alpha=*/1); jerkTracker.pushSample(/*timestamp=*/0, 20, 50); jerkTracker.pushSample(/*timestamp=*/1, 25, 53); jerkTracker.pushSample(/*timestamp=*/2, 30, 60); @@ -297,8 +295,11 @@ TEST_WITH_FLAGS( MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0, []() { return true /*enable prediction*/; }); + // Create another instance of TfLiteMotionPredictorModel to read config details. + std::unique_ptr testTfLiteModel = + TfLiteMotionPredictorModel::create(); const float mediumJerk = - (predictor.getModelConfig().lowJerk + predictor.getModelConfig().highJerk) / 2; + (testTfLiteModel->config().lowJerk + testTfLiteModel->config().highJerk) / 2; const float a = 3; // initial acceleration const float b = 4; // initial velocity const float c = 5; // initial position -- cgit v1.2.3-59-g8ed1b