-rw-r--r--  data/etc/input/motion_predictor_config.xml |  8
-rw-r--r--  include/input/TfLiteMotionPredictor.h      |  5
-rw-r--r--  libs/input/MotionPredictor.cpp              | 24
-rw-r--r--  libs/input/TfLiteMotionPredictor.cpp        |  2
-rw-r--r--  libs/input/tests/Android.bp                 |  1
-rw-r--r--  libs/input/tests/MotionPredictor_test.cpp   | 80
6 files changed, 104 insertions(+), 16 deletions(-)
diff --git a/data/etc/input/motion_predictor_config.xml b/data/etc/input/motion_predictor_config.xml
index 39772aece2..a20993f924 100644
--- a/data/etc/input/motion_predictor_config.xml
+++ b/data/etc/input/motion_predictor_config.xml
@@ -31,5 +31,13 @@
          the UX issue mentioned above. -->
     <distance-noise-floor>0.2</distance-noise-floor>
 
+    <!-- The low and high jerk thresholds for prediction pruning.
+
+         The jerk thresholds are based on normalized dt = 1 calculations, and
+         are taken from the KalmanPredictor implementation in Jetpack's
+         MotionEventPredictor (using its ACCURATE_LOW_JANK and ACCURATE_HIGH_JANK).
+    -->
+    <low-jerk>0.1</low-jerk>
+    <high-jerk>0.2</high-jerk>
 </motion-predictor>
diff --git a/include/input/TfLiteMotionPredictor.h b/include/input/TfLiteMotionPredictor.h
index 2edc138f67..728a8e1e39 100644
--- a/include/input/TfLiteMotionPredictor.h
+++ b/include/input/TfLiteMotionPredictor.h
@@ -105,6 +105,11 @@ public:
         // The noise floor for predictions.
         // Distances (r) less than this should be discarded as noise.
         float distanceNoiseFloor = 0;
+
+        // Low and high jerk thresholds (with normalized dt = 1) for predictions.
+        // Higher jerk means more predictions will be pruned; lower jerk means fewer.
+        float lowJerk = 0;
+        float highJerk = 0;
     };
 
     // Creates a model from an encoded Flatbuffer model.
diff --git a/libs/input/MotionPredictor.cpp b/libs/input/MotionPredictor.cpp
index 77292d4798..5b61d3953f 100644
--- a/libs/input/MotionPredictor.cpp
+++ b/libs/input/MotionPredictor.cpp
@@ -18,6 +18,7 @@
 
 #include <input/MotionPredictor.h>
 
+#include <algorithm>
 #include <array>
 #include <cinttypes>
 #include <cmath>
@@ -62,6 +63,11 @@ TfLiteMotionPredictorSample::Point convertPrediction(
     return {.x = axisTo.x + x_delta, .y = axisTo.y + y_delta};
 }
 
+float normalizeRange(float x, float min, float max) {
+    const float normalized = (x - min) / (max - min);
+    return std::min(1.0f, std::max(0.0f, normalized));
+}
+
 } // namespace
 
 // --- JerkTracker ---
@@ -255,6 +261,17 @@ std::unique_ptr<MotionEvent> MotionPredictor::predict(nsecs_t timestamp) {
     int64_t predictionTime = mBuffers->lastTimestamp();
     const int64_t futureTime = timestamp + mPredictionTimestampOffsetNanos;
 
+    const float jerkMagnitude = mJerkTracker.jerkMagnitude().value_or(0);
+    const float fractionKept =
+            1 - normalizeRange(jerkMagnitude, mModel->config().lowJerk, mModel->config().highJerk);
+    // float to ensure proper division below.
+    const float predictionTimeWindow = futureTime - predictionTime;
+    const int maxNumPredictions = static_cast<int>(
+            std::ceil(predictionTimeWindow / mModel->config().predictionInterval * fractionKept));
+    ALOGD_IF(isDebug(),
+             "jerk (d^3p/normalizedDt^3): %f, fraction of prediction window pruned: %f, max number "
+             "of predictions: %d",
+             jerkMagnitude, 1 - fractionKept, maxNumPredictions);
     for (size_t i = 0; i < static_cast<size_t>(predictedR.size()) && predictionTime <= futureTime;
          ++i) {
         if (predictedR[i] < mModel->config().distanceNoiseFloor) {
@@ -269,13 +286,12 @@ std::unique_ptr<MotionEvent> MotionPredictor::predict(nsecs_t timestamp) {
             break;
         }
         if (input_flags::enable_prediction_pruning_via_jerk_thresholding()) {
-            // TODO(b/266747654): Stop predictions if confidence is < some threshold.
-            // Arbitrarily high pruning index, will correct once jerk thresholding is implemented.
-            const size_t upperBoundPredictionIndex = std::numeric_limits<size_t>::max();
-            if (i > upperBoundPredictionIndex) {
+            if (i >= static_cast<size_t>(maxNumPredictions)) {
                 break;
             }
         }
+        // TODO(b/266747654): Stop predictions if confidence is < some
+        // threshold. Currently predictions are pruned via jerk thresholding.
         const TfLiteMotionPredictorSample::Point predictedPoint =
                 convertPrediction(axisFrom, axisTo, predictedR[i], predictedPhi[i]);
diff --git a/libs/input/TfLiteMotionPredictor.cpp b/libs/input/TfLiteMotionPredictor.cpp
index d17476e216..b843a4bbf6 100644
--- a/libs/input/TfLiteMotionPredictor.cpp
+++ b/libs/input/TfLiteMotionPredictor.cpp
@@ -281,6 +281,8 @@ std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create()
     Config config{
             .predictionInterval = parseXMLInt64(*configRoot, "prediction-interval"),
             .distanceNoiseFloor = parseXMLFloat(*configRoot, "distance-noise-floor"),
+            .lowJerk = parseXMLFloat(*configRoot, "low-jerk"),
+            .highJerk = parseXMLFloat(*configRoot, "high-jerk"),
     };
 
     return std::unique_ptr<TfLiteMotionPredictorModel>(
diff --git a/libs/input/tests/Android.bp b/libs/input/tests/Android.bp
index e67a65a114..ee140b72bd 100644
--- a/libs/input/tests/Android.bp
+++ b/libs/input/tests/Android.bp
@@ -36,6 +36,7 @@ cc_test {
         "tensorflow_headers",
     ],
     static_libs: [
+        "libflagtest",
        "libgmock",
        "libgui_window_info_static",
        "libinput",
diff --git a/libs/input/tests/MotionPredictor_test.cpp b/libs/input/tests/MotionPredictor_test.cpp
index f74874cfe9..dc38feffd6 100644
--- a/libs/input/tests/MotionPredictor_test.cpp
+++ b/libs/input/tests/MotionPredictor_test.cpp
@@ -14,9 +14,12 @@
  * limitations under the License.
  */
 
+// TODO(b/331815574): Decouple this test from assumed config values.
 #include <chrono>
 #include <cmath>
 
+#include <com_android_input_flags.h>
+#include <flag_macros.h>
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include <gui/constants.h>
@@ -197,18 +200,14 @@ TEST(MotionPredictorTest, Offset) {
 TEST(MotionPredictorTest, FollowsGesture) {
     MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
                               []() { return true /*enable prediction*/; });
-
-    // MOVE without a DOWN is ignored.
-    predictor.record(getMotionEvent(MOVE, 1, 3, 10ms));
-    EXPECT_EQ(nullptr, predictor.predict(20 * NSEC_PER_MSEC));
-
-    predictor.record(getMotionEvent(DOWN, 2, 5, 20ms));
-    predictor.record(getMotionEvent(MOVE, 2, 7, 30ms));
-    predictor.record(getMotionEvent(MOVE, 3, 9, 40ms));
-    EXPECT_NE(nullptr, predictor.predict(50 * NSEC_PER_MSEC));
-
-    predictor.record(getMotionEvent(UP, 4, 11, 50ms));
-    EXPECT_EQ(nullptr, predictor.predict(20 * NSEC_PER_MSEC));
+    predictor.record(getMotionEvent(DOWN, 3.75, 3, 20ms));
+    predictor.record(getMotionEvent(MOVE, 4.8, 3, 30ms));
+    predictor.record(getMotionEvent(MOVE, 6.2, 3, 40ms));
+    predictor.record(getMotionEvent(MOVE, 8, 3, 50ms));
+    EXPECT_NE(nullptr, predictor.predict(90 * NSEC_PER_MSEC));
+
+    predictor.record(getMotionEvent(UP, 10.25, 3, 60ms));
+    EXPECT_EQ(nullptr, predictor.predict(100 * NSEC_PER_MSEC));
 }
 
 TEST(MotionPredictorTest, MultipleDevicesNotSupported) {
@@ -250,6 +249,63 @@ TEST(MotionPredictorTest, FlagDisablesPrediction) {
     ASSERT_FALSE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_TOUCHSCREEN));
 }
 
+TEST_WITH_FLAGS(
+        MotionPredictorTest, LowJerkNoPruning,
+        REQUIRES_FLAGS_ENABLED(ACONFIG_FLAG(com::android::input::flags,
+                                            enable_prediction_pruning_via_jerk_thresholding))) {
+    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
+                              []() { return true /*enable prediction*/; });
+
+    // Jerk is low (0.05 normalized).
+    predictor.record(getMotionEvent(DOWN, 2, 7, 20ms));
+    predictor.record(getMotionEvent(MOVE, 2.75, 7, 30ms));
+    predictor.record(getMotionEvent(MOVE, 3.8, 7, 40ms));
+    predictor.record(getMotionEvent(MOVE, 5.2, 7, 50ms));
+    predictor.record(getMotionEvent(MOVE, 7, 7, 60ms));
+    std::unique_ptr<MotionEvent> predicted = predictor.predict(90 * NSEC_PER_MSEC);
+    EXPECT_NE(nullptr, predicted);
+    EXPECT_EQ(static_cast<size_t>(5), predicted->getHistorySize() + 1);
+}
+
+TEST_WITH_FLAGS(
+        MotionPredictorTest, HighJerkPredictionsPruned,
+        REQUIRES_FLAGS_ENABLED(ACONFIG_FLAG(com::android::input::flags,
+                                            enable_prediction_pruning_via_jerk_thresholding))) {
+    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
+                              []() { return true /*enable prediction*/; });
+
+    // Jerk is incredibly high.
+    predictor.record(getMotionEvent(DOWN, 0, 5, 20ms));
+    predictor.record(getMotionEvent(MOVE, 0, 70, 30ms));
+    predictor.record(getMotionEvent(MOVE, 0, 139, 40ms));
+    predictor.record(getMotionEvent(MOVE, 0, 1421, 50ms));
+    predictor.record(getMotionEvent(MOVE, 0, 41233, 60ms));
+    std::unique_ptr<MotionEvent> predicted = predictor.predict(90 * NSEC_PER_MSEC);
+    EXPECT_EQ(nullptr, predicted);
+}
+
+TEST_WITH_FLAGS(
+        MotionPredictorTest, MediumJerkPredictionsSomePruned,
+        REQUIRES_FLAGS_ENABLED(ACONFIG_FLAG(com::android::input::flags,
+                                            enable_prediction_pruning_via_jerk_thresholding))) {
+    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
+                              []() { return true /*enable prediction*/; });
+
+    // Jerk is medium (0.15 normalized, which is halfway between the low and
+    // high jerk thresholds).
+    predictor.record(getMotionEvent(DOWN, 0, 4, 20ms));
+    predictor.record(getMotionEvent(MOVE, 0, 6.25, 30ms));
+    predictor.record(getMotionEvent(MOVE, 0, 9.4, 40ms));
+    predictor.record(getMotionEvent(MOVE, 0, 13.6, 50ms));
+    predictor.record(getMotionEvent(MOVE, 0, 19, 60ms));
+    std::unique_ptr<MotionEvent> predicted = predictor.predict(82 * NSEC_PER_MSEC);
+    EXPECT_NE(nullptr, predicted);
+    // Halfway between the low and high jerk thresholds means that half of the
+    // predictions will be pruned. If the model prediction window is close
+    // enough to the predict() call's time window, then half of the model's
+    // five predictions (ceil(5/2) = 3) will be output.
+    EXPECT_EQ(static_cast<size_t>(3), predicted->getHistorySize() + 1);
+}
+
 using AtomFields = MotionPredictorMetricsManager::AtomFields;
 using ReportAtomFunction = MotionPredictorMetricsManager::ReportAtomFunction;
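Reviewer note: the pruning arithmetic in MotionPredictor::predict() is compact, so below is a minimal standalone sketch of the same computation for sanity-checking outside the build. The 0.1/0.2 thresholds match the XML in this change, but the ~240Hz prediction interval and the main() harness are illustrative assumptions, not values taken from this change; normalizeRange is copied from the helper added in MotionPredictor.cpp. For reference, the "normalized" jerk quoted in the test comments is the third finite difference of the positions with dt treated as 1: the medium-jerk gesture's y values (4, 6.25, 9.4, 13.6, 19) have constant third differences of 0.15, halfway between the two thresholds.

// Standalone sketch of the jerk-thresholding pruning math. Assumed values
// are marked; this is not the library API, just the computation above.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Mirrors the helper added in libs/input/MotionPredictor.cpp: clamps the
// normalized position of x within [min, max] to the range [0, 1].
float normalizeRange(float x, float min, float max) {
    const float normalized = (x - min) / (max - min);
    return std::min(1.0f, std::max(0.0f, normalized));
}

int main() {
    // Thresholds from motion_predictor_config.xml in this change; the
    // prediction interval is a hypothetical ~240Hz (~4.17ms), for
    // illustration only.
    const float lowJerk = 0.1f;
    const float highJerk = 0.2f;
    const int64_t predictionIntervalNanos = 4'166'667;

    // A 22ms prediction window, as in MediumJerkPredictionsSomePruned:
    // predict() at 82ms with the last recorded sample at 60ms.
    const int64_t predictionTimeWindowNanos = 22'000'000;

    for (const float jerkMagnitude : {0.05f, 0.15f, 0.25f}) {
        // Fraction of predictions kept: 1 at or below lowJerk, 0 at or
        // above highJerk, linearly interpolated in between.
        const float fractionKept = 1 - normalizeRange(jerkMagnitude, lowJerk, highJerk);
        const int maxNumPredictions = static_cast<int>(
                std::ceil(static_cast<float>(predictionTimeWindowNanos) /
                          predictionIntervalNanos * fractionKept));
        std::printf("jerk=%.2f -> fractionKept=%.2f, maxNumPredictions=%d\n",
                    jerkMagnitude, fractionKept, maxNumPredictions);
    }
    // Expected output with these assumed values:
    //   jerk=0.05 -> fractionKept=1.00, maxNumPredictions=6
    //   jerk=0.15 -> fractionKept=0.50, maxNumPredictions=3
    //   jerk=0.25 -> fractionKept=0.00, maxNumPredictions=0
    return 0;
}

With these assumed numbers, the medium-jerk case keeps half of a 22ms window: ceil(22 / 4.167 * 0.5) = ceil(2.64) = 3, which is consistent with the three output samples MediumJerkPredictionsSomePruned asserts.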