author Philip Quinn <pquinn@google.com> 2023-07-25 16:54:08 +0000
committer Android (Google) Code Review <android-gerrit@google.com> 2023-07-25 16:54:08 +0000
commit 3c60252e14227672185711d56d65d2573bbb034b (patch)
tree 84d361865b12f8dc52c5d278b2e60763c72de7fc
parent 027baa7a4c9792e920ae32813444f6e936cebb56 (diff)
parent 107ce707b5eaf8758806d456d07832c1db477555 (diff)
Merge "Update motion prediction model." into udc-qpr-dev
-rw-r--r--  data/etc/input/motion_predictor_config.xml    15
-rw-r--r--  data/etc/input/motion_predictor_model.tflite  bin 34080 -> 179532 bytes
-rw-r--r--  include/input/TfLiteMotionPredictor.h         15
-rw-r--r--  libs/input/MotionPredictor.cpp                 19
-rw-r--r--  libs/input/TfLiteMotionPredictor.cpp           36
-rw-r--r--  libs/input/tests/MotionPredictor_test.cpp      11
6 files changed, 74 insertions, 22 deletions
diff --git a/data/etc/input/motion_predictor_config.xml b/data/etc/input/motion_predictor_config.xml
index 03dfd63cbd..39772aece2 100644
--- a/data/etc/input/motion_predictor_config.xml
+++ b/data/etc/input/motion_predictor_config.xml
@@ -16,5 +16,20 @@
<motion-predictor>
<!-- The time interval (ns) between the model's predictions. -->
<prediction-interval>4166666</prediction-interval> <!-- 4.167 ms = ~240 Hz -->
+ <!-- The noise floor (px) for predicted distances.
+
+ As the model is trained stochastically, there is some expected minimum
+ variability in its output. This can be a UX issue when the input device
+ is moving slowly and the variability is large relative to the magnitude
+ of the motion. In these cases, it is better to inhibit the prediction,
+ rather than show noisy predictions (and there is little benefit to
+ prediction anyway).
+
+ The value for this parameter should at least be close to the maximum
+ predicted distance when the input device is held stationary (i.e. the
+ expected minimum variability), and perhaps a little larger to capture
+ the UX issue mentioned above.
+ -->
+ <distance-noise-floor>0.2</distance-noise-floor>
</motion-predictor>
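The comment above explains why small predicted distances are discarded; the actual check lands in MotionPredictor.cpp later in this change. As a minimal standalone sketch of the same gate (not part of this patch; the helper name and the hard-coded value are illustrative, while the real code reads the floor from the parsed config):

#include <vector>

// Sketch only: mirrors the noise-floor gate described in the comment above.
constexpr float kDistanceNoiseFloorPx = 0.2f;  // matches <distance-noise-floor>

// Keep only the leading predictions that exceed the noise floor; once one
// falls below it, later predictions (which are conditioned on it) are dropped.
std::vector<float> filterPredictedDistances(const std::vector<float>& predictedR) {
    std::vector<float> kept;
    for (float r : predictedR) {
        if (r < kDistanceNoiseFloorPx) {
            break;  // a state of noise is not a good basis for further prediction
        }
        kept.push_back(r);
    }
    return kept;
}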
diff --git a/data/etc/input/motion_predictor_model.tflite b/data/etc/input/motion_predictor_model.tflite
index 10b3c8b114..45fc162cd1 100644
--- a/data/etc/input/motion_predictor_model.tflite
+++ b/data/etc/input/motion_predictor_model.tflite
Binary files differ
diff --git a/include/input/TfLiteMotionPredictor.h b/include/input/TfLiteMotionPredictor.h
index fbd60261b2..2edc138f67 100644
--- a/include/input/TfLiteMotionPredictor.h
+++ b/include/input/TfLiteMotionPredictor.h
@@ -99,6 +99,14 @@ private:
// A TFLite model for generating motion predictions.
class TfLiteMotionPredictorModel {
public:
+ struct Config {
+ // The time between predictions.
+ nsecs_t predictionInterval = 0;
+ // The noise floor for predictions.
+ // Distances (r) less than this should be discarded as noise.
+ float distanceNoiseFloor = 0;
+ };
+
// Creates a model from an encoded Flatbuffer model.
static std::unique_ptr<TfLiteMotionPredictorModel> create();
@@ -110,8 +118,7 @@ public:
// Returns the length of the model's output buffers.
size_t outputLength() const;
- // Returns the time interval between predictions.
- nsecs_t predictionInterval() const { return mPredictionInterval; }
+ const Config& config() const { return mConfig; }
// Executes the model.
// Returns true if the model successfully executed and the output tensors can be read.
@@ -132,7 +139,7 @@ public:
private:
explicit TfLiteMotionPredictorModel(std::unique_ptr<android::base::MappedFile> model,
- nsecs_t predictionInterval);
+ Config config);
void allocateTensors();
void attachInputTensors();
@@ -154,7 +161,7 @@ private:
std::unique_ptr<tflite::Interpreter> mInterpreter;
tflite::SignatureRunner* mRunner = nullptr;
- const nsecs_t mPredictionInterval = 0;
+ const Config mConfig = {};
};
} // namespace android
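For code outside this diff that still calls the removed predictionInterval() accessor, the change means going through config() instead. A hedged caller-side sketch (only TfLiteMotionPredictorModel, create(), config() and its two fields come from this header; the function wrapper is illustrative):

#include <input/TfLiteMotionPredictor.h>

void configureFromModel() {
    std::unique_ptr<android::TfLiteMotionPredictorModel> model =
            android::TfLiteMotionPredictorModel::create();
    // Was: model->predictionInterval()
    const nsecs_t interval = model->config().predictionInterval;
    // New in this change: the floor used to inhibit noisy predictions.
    const float noiseFloor = model->config().distanceNoiseFloor;
    (void)interval;
    (void)noiseFloor;
}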
diff --git a/libs/input/MotionPredictor.cpp b/libs/input/MotionPredictor.cpp
index 68e688817b..c2ea35c6bf 100644
--- a/libs/input/MotionPredictor.cpp
+++ b/libs/input/MotionPredictor.cpp
@@ -138,7 +138,8 @@ android::base::Result<void> MotionPredictor::record(const MotionEvent& event) {
// Pass input event to the MetricsManager.
if (!mMetricsManager) {
mMetricsManager =
- std::make_optional<MotionPredictorMetricsManager>(mModel->predictionInterval(),
+ std::make_optional<MotionPredictorMetricsManager>(mModel->config()
+ .predictionInterval,
mModel->outputLength());
}
mMetricsManager->onRecord(event);
@@ -184,8 +185,18 @@ std::unique_ptr<MotionEvent> MotionPredictor::predict(nsecs_t timestamp) {
const int64_t futureTime = timestamp + mPredictionTimestampOffsetNanos;
for (int i = 0; i < predictedR.size() && predictionTime <= futureTime; ++i) {
- // TODO(b/266747654): Stop predictions if confidence and/or predicted pressure are below
- // some thresholds.
+ if (predictedR[i] < mModel->config().distanceNoiseFloor) {
+ // Stop predicting when the predicted output is below the model's noise floor.
+ //
+ // We assume that all subsequent predictions in the batch are unreliable because later
+ // predictions are conditional on earlier predictions, and a state of noise is not a
+ // good basis for prediction.
+ //
+ // The UX trade-off is that this potentially sacrifices some predictions when the input
+ // device starts to speed up, but avoids producing noisy predictions as it slows down.
+ break;
+ }
+ // TODO(b/266747654): Stop predictions if confidence is < some threshold.
const TfLiteMotionPredictorSample::Point predictedPoint =
convertPrediction(axisFrom, axisTo, predictedR[i], predictedPhi[i]);
@@ -197,7 +208,7 @@ std::unique_ptr<MotionEvent> MotionPredictor::predict(nsecs_t timestamp) {
coords.setAxisValue(AMOTION_EVENT_AXIS_Y, predictedPoint.y);
coords.setAxisValue(AMOTION_EVENT_AXIS_PRESSURE, predictedPressure[i]);
- predictionTime += mModel->predictionInterval();
+ predictionTime += mModel->config().predictionInterval;
if (i == 0) {
hasPredictions = true;
prediction->initialize(InputEvent::nextId(), event.getDeviceId(), event.getSource(),
diff --git a/libs/input/TfLiteMotionPredictor.cpp b/libs/input/TfLiteMotionPredictor.cpp
index 9f4aaa8337..5984b4d3b9 100644
--- a/libs/input/TfLiteMotionPredictor.cpp
+++ b/libs/input/TfLiteMotionPredictor.cpp
@@ -100,6 +100,16 @@ int64_t parseXMLInt64(const tinyxml2::XMLElement& configRoot, const char* elemen
return value;
}
+float parseXMLFloat(const tinyxml2::XMLElement& configRoot, const char* elementName) {
+ const tinyxml2::XMLElement* element = configRoot.FirstChildElement(elementName);
+ LOG_ALWAYS_FATAL_IF(!element, "Could not find '%s' element", elementName);
+
+ float value = 0;
+ LOG_ALWAYS_FATAL_IF(element->QueryFloatText(&value) != tinyxml2::XML_SUCCESS,
+ "Failed to parse %s: %s", elementName, element->GetText());
+ return value;
+}
+
// A TFLite ErrorReporter that logs to logcat.
class LoggingErrorReporter : public tflite::ErrorReporter {
public:
@@ -152,6 +162,7 @@ std::unique_ptr<tflite::OpResolver> createOpResolver() {
::tflite::ops::builtin::Register_CONCATENATION());
resolver->AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED,
::tflite::ops::builtin::Register_FULLY_CONNECTED());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_GELU, ::tflite::ops::builtin::Register_GELU());
return resolver;
}
@@ -208,13 +219,7 @@ void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp,
float phi = 0;
float orientation = 0;
- // Ignore the sample if there is no movement. These samples can occur when there's change to a
- // property other than the coordinates and pollute the input to the model.
- if (r == 0) {
- return;
- }
-
- if (!mAxisFrom) { // Second point.
+ if (!mAxisFrom && r > 0) { // Second point.
// We can only determine the distance from the first point, and not any
// angle. However, if the second point forms an axis, the orientation can
// be transformed relative to that axis.
@@ -235,8 +240,10 @@ void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp,
}
// Update the axis for the next point.
- mAxisFrom = mAxisTo;
- mAxisTo = sample;
+ if (r > 0) {
+ mAxisFrom = mAxisTo;
+ mAxisTo = sample;
+ }
// Push the current sample onto the end of the input buffers.
mInputR.pushBack(r);
@@ -272,15 +279,18 @@ std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create()
// Parse configuration file.
const tinyxml2::XMLElement* configRoot = configDocument.FirstChildElement("motion-predictor");
LOG_ALWAYS_FATAL_IF(!configRoot);
- const nsecs_t predictionInterval = parseXMLInt64(*configRoot, "prediction-interval");
+ Config config{
+ .predictionInterval = parseXMLInt64(*configRoot, "prediction-interval"),
+ .distanceNoiseFloor = parseXMLFloat(*configRoot, "distance-noise-floor"),
+ };
return std::unique_ptr<TfLiteMotionPredictorModel>(
- new TfLiteMotionPredictorModel(std::move(modelBuffer), predictionInterval));
+ new TfLiteMotionPredictorModel(std::move(modelBuffer), std::move(config)));
}
TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(
- std::unique_ptr<android::base::MappedFile> model, nsecs_t predictionInterval)
- : mFlatBuffer(std::move(model)), mPredictionInterval(predictionInterval) {
+ std::unique_ptr<android::base::MappedFile> model, Config config)
+ : mFlatBuffer(std::move(model)), mConfig(std::move(config)) {
CHECK(mFlatBuffer);
mErrorReporter = std::make_unique<LoggingErrorReporter>();
mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer->data(),
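For reference, a hedged standalone sketch of reading the new <distance-noise-floor> value with tinyxml2, separate from the fatal-on-error parsing used in create() above (the function and its error handling are illustrative only):

#include <tinyxml2.h>

// Illustrative only: create() instead uses parseXMLFloat(), which aborts via
// LOG_ALWAYS_FATAL_IF on any parse failure rather than returning false.
bool readDistanceNoiseFloor(const char* configPath, float* outNoiseFloor) {
    tinyxml2::XMLDocument doc;
    if (doc.LoadFile(configPath) != tinyxml2::XML_SUCCESS) {
        return false;
    }
    const tinyxml2::XMLElement* root = doc.FirstChildElement("motion-predictor");
    if (root == nullptr) {
        return false;
    }
    const tinyxml2::XMLElement* floor = root->FirstChildElement("distance-noise-floor");
    return floor != nullptr && floor->QueryFloatText(outNoiseFloor) == tinyxml2::XML_SUCCESS;
}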
diff --git a/libs/input/tests/MotionPredictor_test.cpp b/libs/input/tests/MotionPredictor_test.cpp
index 7a62f5ec58..4ac7ae920e 100644
--- a/libs/input/tests/MotionPredictor_test.cpp
+++ b/libs/input/tests/MotionPredictor_test.cpp
@@ -72,11 +72,20 @@ TEST(MotionPredictorTest, IsPredictionAvailable) {
ASSERT_FALSE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_TOUCHSCREEN));
}
+TEST(MotionPredictorTest, StationaryNoiseFloor) {
+ MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/1,
+ []() { return true /*enable prediction*/; });
+ predictor.record(getMotionEvent(DOWN, 0, 1, 30ms));
+ predictor.record(getMotionEvent(MOVE, 0, 1, 35ms)); // No movement.
+ std::unique_ptr<MotionEvent> predicted = predictor.predict(40 * NSEC_PER_MSEC);
+ ASSERT_EQ(nullptr, predicted);
+}
+
TEST(MotionPredictorTest, Offset) {
MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/1,
[]() { return true /*enable prediction*/; });
predictor.record(getMotionEvent(DOWN, 0, 1, 30ms));
- predictor.record(getMotionEvent(MOVE, 0, 2, 35ms));
+ predictor.record(getMotionEvent(MOVE, 0, 5, 35ms)); // Move enough to overcome the noise floor.
std::unique_ptr<MotionEvent> predicted = predictor.predict(40 * NSEC_PER_MSEC);
ASSERT_NE(nullptr, predicted);
ASSERT_GE(predicted->getEventTime(), 41);