From 107ce707b5eaf8758806d456d07832c1db477555 Mon Sep 17 00:00:00 2001 From: Philip Quinn Date: Fri, 14 Jul 2023 13:07:13 -0700 Subject: Update motion prediction model. Input events with no movement (r = 0) are now included in the buffer so that the model can accurately determine when the input device has become stationary, and a noise floor is added to prevent spurious predictions when this happens. Benchmark results: Old: timeRecordAndPredict_mean (ns): 17990 timeRecordAndPredict_median (ns): 18024 timeRecordAndPredict_min (ns): 17606 timeRecordAndPredict_standardDeviation: 345 New: timeRecordAndPredict_mean (ns): 38394 timeRecordAndPredict_median (ns): 38476 timeRecordAndPredict_min (ns): 38083 timeRecordAndPredict_standardDeviation: 187 Bug: 288354672 PiperOrigin-RevId: 549064247 Test: predictions are visible in the motionprediction test app Test: atest CtsInputTestCases Test: atest MotionPredictorBenchmark MotionPredictorTest Test: atest --host libinput_tests Change-Id: I6c3917591323d7117c4ee2e91abf6c6004178f19 --- data/etc/input/motion_predictor_config.xml | 15 +++++++++++ data/etc/input/motion_predictor_model.tflite | Bin 34080 -> 179532 bytes include/input/TfLiteMotionPredictor.h | 15 ++++++++--- libs/input/MotionPredictor.cpp | 19 +++++++++++--- libs/input/TfLiteMotionPredictor.cpp | 36 +++++++++++++++++---------- libs/input/tests/MotionPredictor_test.cpp | 11 +++++++- 6 files changed, 74 insertions(+), 22 deletions(-) diff --git a/data/etc/input/motion_predictor_config.xml b/data/etc/input/motion_predictor_config.xml index 03dfd63cbd..39772aece2 100644 --- a/data/etc/input/motion_predictor_config.xml +++ b/data/etc/input/motion_predictor_config.xml @@ -16,5 +16,20 @@ 4166666 + + 0.2 diff --git a/data/etc/input/motion_predictor_model.tflite b/data/etc/input/motion_predictor_model.tflite index 10b3c8b114..45fc162cd1 100644 Binary files a/data/etc/input/motion_predictor_model.tflite and b/data/etc/input/motion_predictor_model.tflite differ diff --git a/include/input/TfLiteMotionPredictor.h b/include/input/TfLiteMotionPredictor.h index fbd60261b2..2edc138f67 100644 --- a/include/input/TfLiteMotionPredictor.h +++ b/include/input/TfLiteMotionPredictor.h @@ -99,6 +99,14 @@ private: // A TFLite model for generating motion predictions. class TfLiteMotionPredictorModel { public: + struct Config { + // The time between predictions. + nsecs_t predictionInterval = 0; + // The noise floor for predictions. + // Distances (r) less than this should be discarded as noise. + float distanceNoiseFloor = 0; + }; + // Creates a model from an encoded Flatbuffer model. static std::unique_ptr create(); @@ -110,8 +118,7 @@ public: // Returns the length of the model's output buffers. size_t outputLength() const; - // Returns the time interval between predictions. - nsecs_t predictionInterval() const { return mPredictionInterval; } + const Config& config() const { return mConfig; } // Executes the model. // Returns true if the model successfully executed and the output tensors can be read. @@ -132,7 +139,7 @@ public: private: explicit TfLiteMotionPredictorModel(std::unique_ptr model, - nsecs_t predictionInterval); + Config config); void allocateTensors(); void attachInputTensors(); @@ -154,7 +161,7 @@ private: std::unique_ptr mInterpreter; tflite::SignatureRunner* mRunner = nullptr; - const nsecs_t mPredictionInterval = 0; + const Config mConfig = {}; }; } // namespace android diff --git a/libs/input/MotionPredictor.cpp b/libs/input/MotionPredictor.cpp index 68e688817b..c2ea35c6bf 100644 --- a/libs/input/MotionPredictor.cpp +++ b/libs/input/MotionPredictor.cpp @@ -138,7 +138,8 @@ android::base::Result MotionPredictor::record(const MotionEvent& event) { // Pass input event to the MetricsManager. if (!mMetricsManager) { mMetricsManager = - std::make_optional(mModel->predictionInterval(), + std::make_optional(mModel->config() + .predictionInterval, mModel->outputLength()); } mMetricsManager->onRecord(event); @@ -184,8 +185,18 @@ std::unique_ptr MotionPredictor::predict(nsecs_t timestamp) { const int64_t futureTime = timestamp + mPredictionTimestampOffsetNanos; for (int i = 0; i < predictedR.size() && predictionTime <= futureTime; ++i) { - // TODO(b/266747654): Stop predictions if confidence and/or predicted pressure are below - // some thresholds. + if (predictedR[i] < mModel->config().distanceNoiseFloor) { + // Stop predicting when the predicted output is below the model's noise floor. + // + // We assume that all subsequent predictions in the batch are unreliable because later + // predictions are conditional on earlier predictions, and a state of noise is not a + // good basis for prediction. + // + // The UX trade-off is that this potentially sacrifices some predictions when the input + // device starts to speed up, but avoids producing noisy predictions as it slows down. + break; + } + // TODO(b/266747654): Stop predictions if confidence is < some threshold. const TfLiteMotionPredictorSample::Point predictedPoint = convertPrediction(axisFrom, axisTo, predictedR[i], predictedPhi[i]); @@ -197,7 +208,7 @@ std::unique_ptr MotionPredictor::predict(nsecs_t timestamp) { coords.setAxisValue(AMOTION_EVENT_AXIS_Y, predictedPoint.y); coords.setAxisValue(AMOTION_EVENT_AXIS_PRESSURE, predictedPressure[i]); - predictionTime += mModel->predictionInterval(); + predictionTime += mModel->config().predictionInterval; if (i == 0) { hasPredictions = true; prediction->initialize(InputEvent::nextId(), event.getDeviceId(), event.getSource(), diff --git a/libs/input/TfLiteMotionPredictor.cpp b/libs/input/TfLiteMotionPredictor.cpp index 9f4aaa8337..5984b4d3b9 100644 --- a/libs/input/TfLiteMotionPredictor.cpp +++ b/libs/input/TfLiteMotionPredictor.cpp @@ -100,6 +100,16 @@ int64_t parseXMLInt64(const tinyxml2::XMLElement& configRoot, const char* elemen return value; } +float parseXMLFloat(const tinyxml2::XMLElement& configRoot, const char* elementName) { + const tinyxml2::XMLElement* element = configRoot.FirstChildElement(elementName); + LOG_ALWAYS_FATAL_IF(!element, "Could not find '%s' element", elementName); + + float value = 0; + LOG_ALWAYS_FATAL_IF(element->QueryFloatText(&value) != tinyxml2::XML_SUCCESS, + "Failed to parse %s: %s", elementName, element->GetText()); + return value; +} + // A TFLite ErrorReporter that logs to logcat. class LoggingErrorReporter : public tflite::ErrorReporter { public: @@ -152,6 +162,7 @@ std::unique_ptr createOpResolver() { ::tflite::ops::builtin::Register_CONCATENATION()); resolver->AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED, ::tflite::ops::builtin::Register_FULLY_CONNECTED()); + resolver->AddBuiltin(::tflite::BuiltinOperator_GELU, ::tflite::ops::builtin::Register_GELU()); return resolver; } @@ -208,13 +219,7 @@ void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp, float phi = 0; float orientation = 0; - // Ignore the sample if there is no movement. These samples can occur when there's change to a - // property other than the coordinates and pollute the input to the model. - if (r == 0) { - return; - } - - if (!mAxisFrom) { // Second point. + if (!mAxisFrom && r > 0) { // Second point. // We can only determine the distance from the first point, and not any // angle. However, if the second point forms an axis, the orientation can // be transformed relative to that axis. @@ -235,8 +240,10 @@ void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp, } // Update the axis for the next point. - mAxisFrom = mAxisTo; - mAxisTo = sample; + if (r > 0) { + mAxisFrom = mAxisTo; + mAxisTo = sample; + } // Push the current sample onto the end of the input buffers. mInputR.pushBack(r); @@ -272,15 +279,18 @@ std::unique_ptr TfLiteMotionPredictorModel::create() // Parse configuration file. const tinyxml2::XMLElement* configRoot = configDocument.FirstChildElement("motion-predictor"); LOG_ALWAYS_FATAL_IF(!configRoot); - const nsecs_t predictionInterval = parseXMLInt64(*configRoot, "prediction-interval"); + Config config{ + .predictionInterval = parseXMLInt64(*configRoot, "prediction-interval"), + .distanceNoiseFloor = parseXMLFloat(*configRoot, "distance-noise-floor"), + }; return std::unique_ptr( - new TfLiteMotionPredictorModel(std::move(modelBuffer), predictionInterval)); + new TfLiteMotionPredictorModel(std::move(modelBuffer), std::move(config))); } TfLiteMotionPredictorModel::TfLiteMotionPredictorModel( - std::unique_ptr model, nsecs_t predictionInterval) - : mFlatBuffer(std::move(model)), mPredictionInterval(predictionInterval) { + std::unique_ptr model, Config config) + : mFlatBuffer(std::move(model)), mConfig(std::move(config)) { CHECK(mFlatBuffer); mErrorReporter = std::make_unique(); mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer->data(), diff --git a/libs/input/tests/MotionPredictor_test.cpp b/libs/input/tests/MotionPredictor_test.cpp index 7a62f5ec58..4ac7ae920e 100644 --- a/libs/input/tests/MotionPredictor_test.cpp +++ b/libs/input/tests/MotionPredictor_test.cpp @@ -72,11 +72,20 @@ TEST(MotionPredictorTest, IsPredictionAvailable) { ASSERT_FALSE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_TOUCHSCREEN)); } +TEST(MotionPredictorTest, StationaryNoiseFloor) { + MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/1, + []() { return true /*enable prediction*/; }); + predictor.record(getMotionEvent(DOWN, 0, 1, 30ms)); + predictor.record(getMotionEvent(MOVE, 0, 1, 35ms)); // No movement. + std::unique_ptr predicted = predictor.predict(40 * NSEC_PER_MSEC); + ASSERT_EQ(nullptr, predicted); +} + TEST(MotionPredictorTest, Offset) { MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/1, []() { return true /*enable prediction*/; }); predictor.record(getMotionEvent(DOWN, 0, 1, 30ms)); - predictor.record(getMotionEvent(MOVE, 0, 2, 35ms)); + predictor.record(getMotionEvent(MOVE, 0, 5, 35ms)); // Move enough to overcome the noise floor. std::unique_ptr predicted = predictor.predict(40 * NSEC_PER_MSEC); ASSERT_NE(nullptr, predicted); ASSERT_GE(predicted->getEventTime(), 41); -- cgit v1.2.3-59-g8ed1b