diff options
| author | 2023-07-25 16:54:08 +0000 | |
|---|---|---|
| committer | 2023-07-25 16:54:08 +0000 | |
| commit | 3c60252e14227672185711d56d65d2573bbb034b (patch) | |
| tree | 84d361865b12f8dc52c5d278b2e60763c72de7fc | |
| parent | 027baa7a4c9792e920ae32813444f6e936cebb56 (diff) | |
| parent | 107ce707b5eaf8758806d456d07832c1db477555 (diff) | |
Merge "Update motion prediction model." into udc-qpr-dev
| -rw-r--r-- | data/etc/input/motion_predictor_config.xml | 15 | ||||
| -rw-r--r-- | data/etc/input/motion_predictor_model.tflite | bin | 34080 -> 179532 bytes | |||
| -rw-r--r-- | include/input/TfLiteMotionPredictor.h | 15 | ||||
| -rw-r--r-- | libs/input/MotionPredictor.cpp | 19 | ||||
| -rw-r--r-- | libs/input/TfLiteMotionPredictor.cpp | 36 | ||||
| -rw-r--r-- | libs/input/tests/MotionPredictor_test.cpp | 11 |
6 files changed, 74 insertions, 22 deletions
diff --git a/data/etc/input/motion_predictor_config.xml b/data/etc/input/motion_predictor_config.xml index 03dfd63cbd..39772aece2 100644 --- a/data/etc/input/motion_predictor_config.xml +++ b/data/etc/input/motion_predictor_config.xml @@ -16,5 +16,20 @@ <motion-predictor> <!-- The time interval (ns) between the model's predictions. --> <prediction-interval>4166666</prediction-interval> <!-- 4.167 ms = ~240 Hz --> + <!-- The noise floor (px) for predicted distances. + + As the model is trained stochastically, there is some expected minimum + variability in its output. This can be a UX issue when the input device + is moving slowly and the variability is large relative to the magnitude + of the motion. In these cases, it is better to inhibit the prediction, + rather than show noisy predictions (and there is little benefit to + prediction anyway). + + The value for this parameter should at least be close to the maximum + predicted distance when the input device is held stationary (i.e. the + expected minimum variability), and perhaps a little larger to capture + the UX issue mentioned above. + --> + <distance-noise-floor>0.2</distance-noise-floor> </motion-predictor> diff --git a/data/etc/input/motion_predictor_model.tflite b/data/etc/input/motion_predictor_model.tflite Binary files differindex 10b3c8b114..45fc162cd1 100644 --- a/data/etc/input/motion_predictor_model.tflite +++ b/data/etc/input/motion_predictor_model.tflite diff --git a/include/input/TfLiteMotionPredictor.h b/include/input/TfLiteMotionPredictor.h index fbd60261b2..2edc138f67 100644 --- a/include/input/TfLiteMotionPredictor.h +++ b/include/input/TfLiteMotionPredictor.h @@ -99,6 +99,14 @@ private: // A TFLite model for generating motion predictions. class TfLiteMotionPredictorModel { public: + struct Config { + // The time between predictions. + nsecs_t predictionInterval = 0; + // The noise floor for predictions. + // Distances (r) less than this should be discarded as noise. + float distanceNoiseFloor = 0; + }; + // Creates a model from an encoded Flatbuffer model. static std::unique_ptr<TfLiteMotionPredictorModel> create(); @@ -110,8 +118,7 @@ public: // Returns the length of the model's output buffers. size_t outputLength() const; - // Returns the time interval between predictions. - nsecs_t predictionInterval() const { return mPredictionInterval; } + const Config& config() const { return mConfig; } // Executes the model. // Returns true if the model successfully executed and the output tensors can be read. @@ -132,7 +139,7 @@ public: private: explicit TfLiteMotionPredictorModel(std::unique_ptr<android::base::MappedFile> model, - nsecs_t predictionInterval); + Config config); void allocateTensors(); void attachInputTensors(); @@ -154,7 +161,7 @@ private: std::unique_ptr<tflite::Interpreter> mInterpreter; tflite::SignatureRunner* mRunner = nullptr; - const nsecs_t mPredictionInterval = 0; + const Config mConfig = {}; }; } // namespace android diff --git a/libs/input/MotionPredictor.cpp b/libs/input/MotionPredictor.cpp index 68e688817b..c2ea35c6bf 100644 --- a/libs/input/MotionPredictor.cpp +++ b/libs/input/MotionPredictor.cpp @@ -138,7 +138,8 @@ android::base::Result<void> MotionPredictor::record(const MotionEvent& event) { // Pass input event to the MetricsManager. if (!mMetricsManager) { mMetricsManager = - std::make_optional<MotionPredictorMetricsManager>(mModel->predictionInterval(), + std::make_optional<MotionPredictorMetricsManager>(mModel->config() + .predictionInterval, mModel->outputLength()); } mMetricsManager->onRecord(event); @@ -184,8 +185,18 @@ std::unique_ptr<MotionEvent> MotionPredictor::predict(nsecs_t timestamp) { const int64_t futureTime = timestamp + mPredictionTimestampOffsetNanos; for (int i = 0; i < predictedR.size() && predictionTime <= futureTime; ++i) { - // TODO(b/266747654): Stop predictions if confidence and/or predicted pressure are below - // some thresholds. + if (predictedR[i] < mModel->config().distanceNoiseFloor) { + // Stop predicting when the predicted output is below the model's noise floor. + // + // We assume that all subsequent predictions in the batch are unreliable because later + // predictions are conditional on earlier predictions, and a state of noise is not a + // good basis for prediction. + // + // The UX trade-off is that this potentially sacrifices some predictions when the input + // device starts to speed up, but avoids producing noisy predictions as it slows down. + break; + } + // TODO(b/266747654): Stop predictions if confidence is < some threshold. const TfLiteMotionPredictorSample::Point predictedPoint = convertPrediction(axisFrom, axisTo, predictedR[i], predictedPhi[i]); @@ -197,7 +208,7 @@ std::unique_ptr<MotionEvent> MotionPredictor::predict(nsecs_t timestamp) { coords.setAxisValue(AMOTION_EVENT_AXIS_Y, predictedPoint.y); coords.setAxisValue(AMOTION_EVENT_AXIS_PRESSURE, predictedPressure[i]); - predictionTime += mModel->predictionInterval(); + predictionTime += mModel->config().predictionInterval; if (i == 0) { hasPredictions = true; prediction->initialize(InputEvent::nextId(), event.getDeviceId(), event.getSource(), diff --git a/libs/input/TfLiteMotionPredictor.cpp b/libs/input/TfLiteMotionPredictor.cpp index 9f4aaa8337..5984b4d3b9 100644 --- a/libs/input/TfLiteMotionPredictor.cpp +++ b/libs/input/TfLiteMotionPredictor.cpp @@ -100,6 +100,16 @@ int64_t parseXMLInt64(const tinyxml2::XMLElement& configRoot, const char* elemen return value; } +float parseXMLFloat(const tinyxml2::XMLElement& configRoot, const char* elementName) { + const tinyxml2::XMLElement* element = configRoot.FirstChildElement(elementName); + LOG_ALWAYS_FATAL_IF(!element, "Could not find '%s' element", elementName); + + float value = 0; + LOG_ALWAYS_FATAL_IF(element->QueryFloatText(&value) != tinyxml2::XML_SUCCESS, + "Failed to parse %s: %s", elementName, element->GetText()); + return value; +} + // A TFLite ErrorReporter that logs to logcat. class LoggingErrorReporter : public tflite::ErrorReporter { public: @@ -152,6 +162,7 @@ std::unique_ptr<tflite::OpResolver> createOpResolver() { ::tflite::ops::builtin::Register_CONCATENATION()); resolver->AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED, ::tflite::ops::builtin::Register_FULLY_CONNECTED()); + resolver->AddBuiltin(::tflite::BuiltinOperator_GELU, ::tflite::ops::builtin::Register_GELU()); return resolver; } @@ -208,13 +219,7 @@ void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp, float phi = 0; float orientation = 0; - // Ignore the sample if there is no movement. These samples can occur when there's change to a - // property other than the coordinates and pollute the input to the model. - if (r == 0) { - return; - } - - if (!mAxisFrom) { // Second point. + if (!mAxisFrom && r > 0) { // Second point. // We can only determine the distance from the first point, and not any // angle. However, if the second point forms an axis, the orientation can // be transformed relative to that axis. @@ -235,8 +240,10 @@ void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp, } // Update the axis for the next point. - mAxisFrom = mAxisTo; - mAxisTo = sample; + if (r > 0) { + mAxisFrom = mAxisTo; + mAxisTo = sample; + } // Push the current sample onto the end of the input buffers. mInputR.pushBack(r); @@ -272,15 +279,18 @@ std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create() // Parse configuration file. const tinyxml2::XMLElement* configRoot = configDocument.FirstChildElement("motion-predictor"); LOG_ALWAYS_FATAL_IF(!configRoot); - const nsecs_t predictionInterval = parseXMLInt64(*configRoot, "prediction-interval"); + Config config{ + .predictionInterval = parseXMLInt64(*configRoot, "prediction-interval"), + .distanceNoiseFloor = parseXMLFloat(*configRoot, "distance-noise-floor"), + }; return std::unique_ptr<TfLiteMotionPredictorModel>( - new TfLiteMotionPredictorModel(std::move(modelBuffer), predictionInterval)); + new TfLiteMotionPredictorModel(std::move(modelBuffer), std::move(config))); } TfLiteMotionPredictorModel::TfLiteMotionPredictorModel( - std::unique_ptr<android::base::MappedFile> model, nsecs_t predictionInterval) - : mFlatBuffer(std::move(model)), mPredictionInterval(predictionInterval) { + std::unique_ptr<android::base::MappedFile> model, Config config) + : mFlatBuffer(std::move(model)), mConfig(std::move(config)) { CHECK(mFlatBuffer); mErrorReporter = std::make_unique<LoggingErrorReporter>(); mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer->data(), diff --git a/libs/input/tests/MotionPredictor_test.cpp b/libs/input/tests/MotionPredictor_test.cpp index 7a62f5ec58..4ac7ae920e 100644 --- a/libs/input/tests/MotionPredictor_test.cpp +++ b/libs/input/tests/MotionPredictor_test.cpp @@ -72,11 +72,20 @@ TEST(MotionPredictorTest, IsPredictionAvailable) { ASSERT_FALSE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_TOUCHSCREEN)); } +TEST(MotionPredictorTest, StationaryNoiseFloor) { + MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/1, + []() { return true /*enable prediction*/; }); + predictor.record(getMotionEvent(DOWN, 0, 1, 30ms)); + predictor.record(getMotionEvent(MOVE, 0, 1, 35ms)); // No movement. + std::unique_ptr<MotionEvent> predicted = predictor.predict(40 * NSEC_PER_MSEC); + ASSERT_EQ(nullptr, predicted); +} + TEST(MotionPredictorTest, Offset) { MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/1, []() { return true /*enable prediction*/; }); predictor.record(getMotionEvent(DOWN, 0, 1, 30ms)); - predictor.record(getMotionEvent(MOVE, 0, 2, 35ms)); + predictor.record(getMotionEvent(MOVE, 0, 5, 35ms)); // Move enough to overcome the noise floor. std::unique_ptr<MotionEvent> predicted = predictor.predict(40 * NSEC_PER_MSEC); ASSERT_NE(nullptr, predicted); ASSERT_GE(predicted->getEventTime(), 41); |