diff options
author | 2024-07-04 11:14:18 +0000 | |
---|---|---|
committer | 2024-07-15 08:10:01 +0000 | |
commit | 5e0e7cf9d7918fac53e21a088e2f7d914cde9b8f (patch) | |
tree | 58476614b8067ea56578eda64f7f8a809858c4e4 | |
parent | cda4744d221364232c5860397de3b317cc401687 (diff) |
Add smoothing to jerk calculations and updated jerk thresholds.
Test: atest libinput_tests
Test: atest CtsInputTestCases
Test: atest MotionPredictorBenchmark MotionPredictorTest
Test: Using stylus in a drawing app and seeing the jerk logs.
Bug: 266747654
Bug: 353161308
Flag: com.android.input.flags.enable_prediction_pruning_via_jerk_thresholding
Change-Id: I3d6c47d94d66e5ff2b33474acbca72daca051242
-rw-r--r-- | data/etc/input/motion_predictor_config.xml | 7 | ||||
-rw-r--r-- | include/input/MotionPredictor.h | 13 | ||||
-rw-r--r-- | include/input/TfLiteMotionPredictor.h | 3 | ||||
-rw-r--r-- | libs/input/MotionPredictor.cpp | 32 | ||||
-rw-r--r-- | libs/input/TfLiteMotionPredictor.cpp | 1 | ||||
-rw-r--r-- | libs/input/tests/MotionPredictor_test.cpp | 28 |
6 files changed, 72 insertions, 12 deletions
diff --git a/data/etc/input/motion_predictor_config.xml b/data/etc/input/motion_predictor_config.xml index c3f2fedc71..14540ec512 100644 --- a/data/etc/input/motion_predictor_config.xml +++ b/data/etc/input/motion_predictor_config.xml @@ -35,7 +35,10 @@ The jerk thresholds are based on normalized dt = 1 calculations. --> - <low-jerk>1.0</low-jerk> - <high-jerk>1.1</high-jerk> + <low-jerk>1.5</low-jerk> + <high-jerk>2.0</high-jerk> + + <!-- The forget factor in the first-order IIR filter for jerk smoothing --> + <jerk-forget-factor>0.25</jerk-forget-factor> </motion-predictor> diff --git a/include/input/MotionPredictor.h b/include/input/MotionPredictor.h index f71503988f..2f1ef86428 100644 --- a/include/input/MotionPredictor.h +++ b/include/input/MotionPredictor.h @@ -56,12 +56,20 @@ public: // acceleration) and has the units of d^3p/dt^3. std::optional<float> jerkMagnitude() const; + // forgetFactor is the coefficient of the first-order IIR filter for jerk. A factor of 1 results + // in no smoothing. + void setForgetFactor(float forgetFactor); + float getForgetFactor() const; + private: const bool mNormalizedDt; + // Coefficient of first-order IIR filter to smooth jerk calculation. + float mForgetFactor = 1; RingBuffer<int64_t> mTimestamps{4}; std::array<float, 4> mXDerivatives{}; // [x, x', x'', x'''] std::array<float, 4> mYDerivatives{}; // [y, y', y'', y'''] + float mJerkMagnitude; }; /** @@ -116,6 +124,11 @@ public: bool isPredictionAvailable(int32_t deviceId, int32_t source); + /** + * Currently used to expose config constants in testing. + */ + const TfLiteMotionPredictorModel::Config& getModelConfig(); + private: const nsecs_t mPredictionTimestampOffsetNanos; const std::function<bool()> mCheckMotionPredictionEnabled; diff --git a/include/input/TfLiteMotionPredictor.h b/include/input/TfLiteMotionPredictor.h index 728a8e1e39..08a4330d27 100644 --- a/include/input/TfLiteMotionPredictor.h +++ b/include/input/TfLiteMotionPredictor.h @@ -110,6 +110,9 @@ public: // High jerk means more predictions will be pruned, vice versa for low. float lowJerk = 0; float highJerk = 0; + + // Coefficient for the first-order IIR filter for jerk calculation. + float jerkForgetFactor = 1; }; // Creates a model from an encoded Flatbuffer model. diff --git a/libs/input/MotionPredictor.cpp b/libs/input/MotionPredictor.cpp index 5b61d3953f..9204b95745 100644 --- a/libs/input/MotionPredictor.cpp +++ b/libs/input/MotionPredictor.cpp @@ -75,6 +75,9 @@ float normalizeRange(float x, float min, float max) { JerkTracker::JerkTracker(bool normalizedDt) : mNormalizedDt(normalizedDt) {} void JerkTracker::pushSample(int64_t timestamp, float xPos, float yPos) { + // If we previously had full samples, we have a previous jerk calculation + // to do weighted smoothing. + const bool applySmoothing = mTimestamps.size() == mTimestamps.capacity(); mTimestamps.pushBack(timestamp); const int numSamples = mTimestamps.size(); @@ -115,6 +118,16 @@ void JerkTracker::pushSample(int64_t timestamp, float xPos, float yPos) { } } + if (numSamples == static_cast<int>(mTimestamps.capacity())) { + float newJerkMagnitude = std::hypot(newXDerivatives[3], newYDerivatives[3]); + ALOGD_IF(isDebug(), "raw jerk: %f", newJerkMagnitude); + if (applySmoothing) { + mJerkMagnitude = mJerkMagnitude + (mForgetFactor * (newJerkMagnitude - mJerkMagnitude)); + } else { + mJerkMagnitude = newJerkMagnitude; + } + } + std::swap(newXDerivatives, mXDerivatives); std::swap(newYDerivatives, mYDerivatives); } @@ -125,11 +138,19 @@ void JerkTracker::reset() { std::optional<float> JerkTracker::jerkMagnitude() const { if (mTimestamps.size() == mTimestamps.capacity()) { - return std::hypot(mXDerivatives[3], mYDerivatives[3]); + return mJerkMagnitude; } return std::nullopt; } +void JerkTracker::setForgetFactor(float forgetFactor) { + mForgetFactor = forgetFactor; +} + +float JerkTracker::getForgetFactor() const { + return mForgetFactor; +} + // --- MotionPredictor --- MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos, @@ -159,6 +180,7 @@ android::base::Result<void> MotionPredictor::record(const MotionEvent& event) { if (!mModel) { mModel = TfLiteMotionPredictorModel::create(); LOG_ALWAYS_FATAL_IF(!mModel); + mJerkTracker.setForgetFactor(mModel->config().jerkForgetFactor); } if (!mBuffers) { @@ -357,4 +379,12 @@ bool MotionPredictor::isPredictionAvailable(int32_t /*deviceId*/, int32_t source return true; } +const TfLiteMotionPredictorModel::Config& MotionPredictor::getModelConfig() { + if (!mModel) { + mModel = TfLiteMotionPredictorModel::create(); + LOG_ALWAYS_FATAL_IF(!mModel); + } + return mModel->config(); +} + } // namespace android diff --git a/libs/input/TfLiteMotionPredictor.cpp b/libs/input/TfLiteMotionPredictor.cpp index b843a4bbf6..b401c985e6 100644 --- a/libs/input/TfLiteMotionPredictor.cpp +++ b/libs/input/TfLiteMotionPredictor.cpp @@ -283,6 +283,7 @@ std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create() .distanceNoiseFloor = parseXMLFloat(*configRoot, "distance-noise-floor"), .lowJerk = parseXMLFloat(*configRoot, "low-jerk"), .highJerk = parseXMLFloat(*configRoot, "high-jerk"), + .jerkForgetFactor = parseXMLFloat(*configRoot, "jerk-forget-factor"), }; return std::unique_ptr<TfLiteMotionPredictorModel>( diff --git a/libs/input/tests/MotionPredictor_test.cpp b/libs/input/tests/MotionPredictor_test.cpp index d077760757..5bd5794795 100644 --- a/libs/input/tests/MotionPredictor_test.cpp +++ b/libs/input/tests/MotionPredictor_test.cpp @@ -88,6 +88,7 @@ TEST(JerkTrackerTest, JerkReadiness) { TEST(JerkTrackerTest, JerkCalculationNormalizedDtTrue) { JerkTracker jerkTracker(true); + jerkTracker.setForgetFactor(.5); jerkTracker.pushSample(/*timestamp=*/0, 20, 50); jerkTracker.pushSample(/*timestamp=*/1, 25, 53); jerkTracker.pushSample(/*timestamp=*/2, 30, 60); @@ -118,11 +119,14 @@ TEST(JerkTrackerTest, JerkCalculationNormalizedDtTrue) { * y'': 3 -> -15 * y''': -18 */ - EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), std::hypot(-50, -18)); + const float newJerk = (1 - jerkTracker.getForgetFactor()) * std::hypot(10, -1) + + jerkTracker.getForgetFactor() * std::hypot(-50, -18); + EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), newJerk); } TEST(JerkTrackerTest, JerkCalculationNormalizedDtFalse) { JerkTracker jerkTracker(false); + jerkTracker.setForgetFactor(.5); jerkTracker.pushSample(/*timestamp=*/0, 20, 50); jerkTracker.pushSample(/*timestamp=*/10, 25, 53); jerkTracker.pushSample(/*timestamp=*/20, 30, 60); @@ -153,7 +157,9 @@ TEST(JerkTrackerTest, JerkCalculationNormalizedDtFalse) { * y'': .03 -> -.125 (delta above, divide by 10) * y''': -.0155 (delta above, divide by 10) */ - EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), std::hypot(-.0375, -.0155)); + const float newJerk = (1 - jerkTracker.getForgetFactor()) * std::hypot(.01, -.001) + + jerkTracker.getForgetFactor() * std::hypot(-.0375, -.0155); + EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), newJerk); } TEST(JerkTrackerTest, JerkCalculationAfterReset) { @@ -291,15 +297,19 @@ TEST_WITH_FLAGS( MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0, []() { return true /*enable prediction*/; }); - // Jerk is medium (1.05 normalized, which is halfway between LOW_JANK and HIGH_JANK) - predictor.record(getMotionEvent(DOWN, 0, 5.2, 20ms)); - predictor.record(getMotionEvent(MOVE, 0, 11.5, 30ms)); - predictor.record(getMotionEvent(MOVE, 0, 22, 40ms)); - predictor.record(getMotionEvent(MOVE, 0, 37.75, 50ms)); - predictor.record(getMotionEvent(MOVE, 0, 59.8, 60ms)); + const float mediumJerk = + (predictor.getModelConfig().lowJerk + predictor.getModelConfig().highJerk) / 2; + const float a = 3; // initial acceleration + const float b = 4; // initial velocity + const float c = 5; // initial position + predictor.record(getMotionEvent(DOWN, 0, c, 20ms)); + predictor.record(getMotionEvent(MOVE, 0, c + b, 30ms)); + predictor.record(getMotionEvent(MOVE, 0, c + 2 * b + a, 40ms)); + predictor.record(getMotionEvent(MOVE, 0, c + 3 * b + 3 * a + mediumJerk, 50ms)); + predictor.record(getMotionEvent(MOVE, 0, c + 4 * b + 6 * a + 4 * mediumJerk, 60ms)); std::unique_ptr<MotionEvent> predicted = predictor.predict(82 * NSEC_PER_MSEC); EXPECT_NE(nullptr, predicted); - // Halfway between LOW_JANK and HIGH_JANK means that half of the predictions + // Halfway between LOW_JERK and HIGH_JERK means that half of the predictions // will be pruned. If model prediction window is close enough to predict() // call time window, then half of the model predictions (5/2 -> 2) will be // ouputted. |