diff options
| author | 2023-02-10 11:45:01 -0800 | |
|---|---|---|
| committer | 2023-02-10 11:46:38 -0800 | |
| commit | bd66e62069a5495ee8ba8b9ff8c35d3a2075a06d (patch) | |
| tree | 5762c35299ccce6890fae8a8c6f29075ded84d92 | |
| parent | cb3229aaf2233ebb917d967a6e73d48cce1a1480 (diff) | |
Postpone loading the TFLite model until a supported event is recorded.
Bug: 267050081
Test: atest libinput_tests
Change-Id: I09666da123a58786e8a6d47d4c29a475e92f2bbf
| -rw-r--r-- | include/input/MotionPredictor.h | 2 | ||||
| -rw-r--r-- | libs/input/MotionPredictor.cpp | 11 |
2 files changed, 10 insertions, 3 deletions
diff --git a/include/input/MotionPredictor.h b/include/input/MotionPredictor.h index 3fae4e6b68..68ebf75fc6 100644 --- a/include/input/MotionPredictor.h +++ b/include/input/MotionPredictor.h @@ -19,6 +19,7 @@ #include <cstdint> #include <memory> #include <mutex> +#include <string> #include <unordered_map> #include <android-base/thread_annotations.h> @@ -73,6 +74,7 @@ public: private: const nsecs_t mPredictionTimestampOffsetNanos; + const std::string mModelPath; const std::function<bool()> mCheckMotionPredictionEnabled; std::unique_ptr<TfLiteMotionPredictorModel> mModel; diff --git a/libs/input/MotionPredictor.cpp b/libs/input/MotionPredictor.cpp index 0f889e8128..7d11ef2575 100644 --- a/libs/input/MotionPredictor.cpp +++ b/libs/input/MotionPredictor.cpp @@ -65,9 +65,8 @@ TfLiteMotionPredictorSample::Point convertPrediction( MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos, const char* modelPath, std::function<bool()> checkMotionPredictionEnabled) : mPredictionTimestampOffsetNanos(predictionTimestampOffsetNanos), - mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)), - mModel(TfLiteMotionPredictorModel::create(modelPath == nullptr ? DEFAULT_MODEL_PATH - : modelPath)) {} + mModelPath(modelPath == nullptr ? DEFAULT_MODEL_PATH : modelPath), + mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)) {} void MotionPredictor::record(const MotionEvent& event) { if (!isPredictionAvailable(event.getDeviceId(), event.getSource())) { @@ -76,6 +75,11 @@ void MotionPredictor::record(const MotionEvent& event) { return; } + // Initialise the model now that it's likely to be used. + if (!mModel) { + mModel = TfLiteMotionPredictorModel::create(mModelPath.c_str()); + } + TfLiteMotionPredictorBuffers& buffers = mDeviceBuffers.try_emplace(event.getDeviceId(), mModel->inputLength()).first->second; @@ -130,6 +134,7 @@ std::vector<std::unique_ptr<MotionEvent>> MotionPredictor::predict(nsecs_t times continue; } + LOG_ALWAYS_FATAL_IF(!mModel); buffer.copyTo(*mModel); LOG_ALWAYS_FATAL_IF(!mModel->invoke()); |