summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Philip Quinn <pquinn@google.com> 2023-02-10 11:45:01 -0800
committer Philip Quinn <pquinn@google.com> 2023-02-10 11:46:38 -0800
commitbd66e62069a5495ee8ba8b9ff8c35d3a2075a06d (patch)
tree5762c35299ccce6890fae8a8c6f29075ded84d92
parentcb3229aaf2233ebb917d967a6e73d48cce1a1480 (diff)
Postpone loading the TFLite model until a supported event is recorded.
Bug: 267050081 Test: atest libinput_tests Change-Id: I09666da123a58786e8a6d47d4c29a475e92f2bbf
-rw-r--r--include/input/MotionPredictor.h2
-rw-r--r--libs/input/MotionPredictor.cpp11
2 files changed, 10 insertions, 3 deletions
diff --git a/include/input/MotionPredictor.h b/include/input/MotionPredictor.h
index 3fae4e6b68..68ebf75fc6 100644
--- a/include/input/MotionPredictor.h
+++ b/include/input/MotionPredictor.h
@@ -19,6 +19,7 @@
#include <cstdint>
#include <memory>
#include <mutex>
+#include <string>
#include <unordered_map>
#include <android-base/thread_annotations.h>
@@ -73,6 +74,7 @@ public:
private:
const nsecs_t mPredictionTimestampOffsetNanos;
+ const std::string mModelPath;
const std::function<bool()> mCheckMotionPredictionEnabled;
std::unique_ptr<TfLiteMotionPredictorModel> mModel;
diff --git a/libs/input/MotionPredictor.cpp b/libs/input/MotionPredictor.cpp
index 0f889e8128..7d11ef2575 100644
--- a/libs/input/MotionPredictor.cpp
+++ b/libs/input/MotionPredictor.cpp
@@ -65,9 +65,8 @@ TfLiteMotionPredictorSample::Point convertPrediction(
MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos, const char* modelPath,
std::function<bool()> checkMotionPredictionEnabled)
: mPredictionTimestampOffsetNanos(predictionTimestampOffsetNanos),
- mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)),
- mModel(TfLiteMotionPredictorModel::create(modelPath == nullptr ? DEFAULT_MODEL_PATH
- : modelPath)) {}
+ mModelPath(modelPath == nullptr ? DEFAULT_MODEL_PATH : modelPath),
+ mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)) {}
void MotionPredictor::record(const MotionEvent& event) {
if (!isPredictionAvailable(event.getDeviceId(), event.getSource())) {
@@ -76,6 +75,11 @@ void MotionPredictor::record(const MotionEvent& event) {
return;
}
+ // Initialise the model now that it's likely to be used.
+ if (!mModel) {
+ mModel = TfLiteMotionPredictorModel::create(mModelPath.c_str());
+ }
+
TfLiteMotionPredictorBuffers& buffers =
mDeviceBuffers.try_emplace(event.getDeviceId(), mModel->inputLength()).first->second;
@@ -130,6 +134,7 @@ std::vector<std::unique_ptr<MotionEvent>> MotionPredictor::predict(nsecs_t times
continue;
}
+ LOG_ALWAYS_FATAL_IF(!mModel);
buffer.copyTo(*mModel);
LOG_ALWAYS_FATAL_IF(!mModel->invoke());