diff options
author | 2023-06-26 14:15:15 -0700 | |
---|---|---|
committer | 2023-06-26 17:25:10 -0700 | |
commit | f84fa49e72f14b14669bf99bad92ed05169e5cf1 (patch) | |
tree | 43f9f4304c53b8f5818f152a60257cfc84f3d22e | |
parent | 5d8faa4526cfe2a1fed4647234fabf5e260120b3 (diff) |
Move MotionPredictor config to an XML file alongside the model.
Test: atest libinput_tests
Fixes: 266747937
Change-Id: Ic5ec548d2edc8bad5e8b88aaf8511cd297a89275
-rw-r--r-- | data/etc/input/Android.bp | 15 | ||||
-rw-r--r-- | data/etc/input/motion_predictor_config.xml | 20 | ||||
-rw-r--r-- | data/etc/input/motion_predictor_model.tflite (renamed from data/etc/input/motion_predictor_model.fb) | bin | 34080 -> 34080 bytes | |||
-rw-r--r-- | include/input/TfLiteMotionPredictor.h | 9 | ||||
-rw-r--r-- | libs/input/Android.bp | 2 | ||||
-rw-r--r-- | libs/input/MotionPredictor.cpp | 6 | ||||
-rw-r--r-- | libs/input/TfLiteMotionPredictor.cpp | 40 | ||||
-rw-r--r-- | libs/input/tests/Android.bp | 3 |
8 files changed, 79 insertions, 16 deletions
diff --git a/data/etc/input/Android.bp b/data/etc/input/Android.bp index 90f3c6b49a..b662491272 100644 --- a/data/etc/input/Android.bp +++ b/data/etc/input/Android.bp @@ -3,12 +3,21 @@ package { } filegroup { - name: "motion_predictor_model.fb", - srcs: ["motion_predictor_model.fb"], + name: "motion_predictor_model", + srcs: [ + "motion_predictor_model.tflite", + "motion_predictor_config.xml", + ], } prebuilt_etc { name: "motion_predictor_model_prebuilt", filename_from_src: true, - src: "motion_predictor_model.fb", + src: "motion_predictor_model.tflite", +} + +prebuilt_etc { + name: "motion_predictor_model_config", + filename_from_src: true, + src: "motion_predictor_config.xml", } diff --git a/data/etc/input/motion_predictor_config.xml b/data/etc/input/motion_predictor_config.xml new file mode 100644 index 0000000000..03dfd63cbd --- /dev/null +++ b/data/etc/input/motion_predictor_config.xml @@ -0,0 +1,20 @@ +<?xml version="1.0" encoding="utf-8"?> +<!-- Copyright (C) 2023 The Android Open Source Project + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--> +<motion-predictor> + <!-- The time interval (ns) between the model's predictions. --> + <prediction-interval>4166666</prediction-interval> <!-- 4.167 ms = ~240 Hz --> +</motion-predictor> + diff --git a/data/etc/input/motion_predictor_model.fb b/data/etc/input/motion_predictor_model.tflite Binary files differindex 10b3c8b114..10b3c8b114 100644 --- a/data/etc/input/motion_predictor_model.fb +++ b/data/etc/input/motion_predictor_model.tflite diff --git a/include/input/TfLiteMotionPredictor.h b/include/input/TfLiteMotionPredictor.h index a340bd0575..fbd60261b2 100644 --- a/include/input/TfLiteMotionPredictor.h +++ b/include/input/TfLiteMotionPredictor.h @@ -25,6 +25,7 @@ #include <android-base/mapped_file.h> #include <input/RingBuffer.h> +#include <utils/Timers.h> #include <tensorflow/lite/core/api/error_reporter.h> #include <tensorflow/lite/interpreter.h> @@ -109,6 +110,9 @@ public: // Returns the length of the model's output buffers. size_t outputLength() const; + // Returns the time interval between predictions. + nsecs_t predictionInterval() const { return mPredictionInterval; } + // Executes the model. // Returns true if the model successfully executed and the output tensors can be read. bool invoke(); @@ -127,7 +131,8 @@ public: std::span<const float> outputPressure() const; private: - explicit TfLiteMotionPredictorModel(std::unique_ptr<android::base::MappedFile> model); + explicit TfLiteMotionPredictorModel(std::unique_ptr<android::base::MappedFile> model, + nsecs_t predictionInterval); void allocateTensors(); void attachInputTensors(); @@ -148,6 +153,8 @@ private: std::unique_ptr<tflite::FlatBufferModel> mModel; std::unique_ptr<tflite::Interpreter> mInterpreter; tflite::SignatureRunner* mRunner = nullptr; + + const nsecs_t mPredictionInterval = 0; }; } // namespace android diff --git a/libs/input/Android.bp b/libs/input/Android.bp index 9ac1829b93..8a17d8a831 100644 --- a/libs/input/Android.bp +++ b/libs/input/Android.bp @@ -215,6 +215,7 @@ cc_library { "libcutils", "liblog", "libPlatformProperties", + "libtinyxml2", "libvintf", ], @@ -271,6 +272,7 @@ cc_library { required: [ "motion_predictor_model_prebuilt", + "motion_predictor_model_config", ], }, host: { diff --git a/libs/input/MotionPredictor.cpp b/libs/input/MotionPredictor.cpp index 3037573538..947a95610a 100644 --- a/libs/input/MotionPredictor.cpp +++ b/libs/input/MotionPredictor.cpp @@ -36,9 +36,6 @@ namespace android { namespace { -const int64_t PREDICTION_INTERVAL_NANOS = - 12500000 / 3; // TODO(b/266747937): Get this from the model. - /** * Log debug messages about predictions. * Enable this via "adb shell setprop log.tag.MotionPredictor DEBUG" @@ -189,7 +186,7 @@ std::unique_ptr<MotionEvent> MotionPredictor::predict(nsecs_t timestamp) { // TODO(b/266747654): Stop predictions if predicted pressure is < some threshold. coords.setAxisValue(AMOTION_EVENT_AXIS_PRESSURE, predictedPressure[i]); - predictionTime += PREDICTION_INTERVAL_NANOS; + predictionTime += mModel->predictionInterval(); if (i == 0) { hasPredictions = true; prediction->initialize(InputEvent::nextId(), event.getDeviceId(), event.getSource(), @@ -208,7 +205,6 @@ std::unique_ptr<MotionEvent> MotionPredictor::predict(nsecs_t timestamp) { axisFrom = axisTo; axisTo = point; } - // TODO(b/266747511): Interpolate to futureTime? if (!hasPredictions) { return nullptr; } diff --git a/libs/input/TfLiteMotionPredictor.cpp b/libs/input/TfLiteMotionPredictor.cpp index 85fa176129..9f4aaa8337 100644 --- a/libs/input/TfLiteMotionPredictor.cpp +++ b/libs/input/TfLiteMotionPredictor.cpp @@ -36,6 +36,7 @@ #define ATRACE_TAG ATRACE_TAG_INPUT #include <cutils/trace.h> #include <log/log.h> +#include <utils/Timers.h> #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -44,6 +45,8 @@ #include "tensorflow/lite/model.h" #include "tensorflow/lite/mutable_op_resolver.h" +#include "tinyxml2.h" + namespace android { namespace { @@ -72,16 +75,31 @@ bool fileExists(const char* filename) { std::string getModelPath() { #if defined(__ANDROID__) - static const char* oemModel = "/vendor/etc/motion_predictor_model.fb"; + static const char* oemModel = "/vendor/etc/motion_predictor_model.tflite"; if (fileExists(oemModel)) { return oemModel; } - return "/system/etc/motion_predictor_model.fb"; + return "/system/etc/motion_predictor_model.tflite"; #else - return base::GetExecutableDirectory() + "/motion_predictor_model.fb"; + return base::GetExecutableDirectory() + "/motion_predictor_model.tflite"; #endif } +std::string getConfigPath() { + // The config file should be alongside the model file. + return base::Dirname(getModelPath()) + "/motion_predictor_config.xml"; +} + +int64_t parseXMLInt64(const tinyxml2::XMLElement& configRoot, const char* elementName) { + const tinyxml2::XMLElement* element = configRoot.FirstChildElement(elementName); + LOG_ALWAYS_FATAL_IF(!element, "Could not find '%s' element", elementName); + + int64_t value = 0; + LOG_ALWAYS_FATAL_IF(element->QueryInt64Text(&value) != tinyxml2::XML_SUCCESS, + "Failed to parse %s: %s", elementName, element->GetText()); + return value; +} + // A TFLite ErrorReporter that logs to logcat. class LoggingErrorReporter : public tflite::ErrorReporter { public: @@ -246,13 +264,23 @@ std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create() PLOG(FATAL) << "Failed to mmap model"; } + const std::string configPath = getConfigPath(); + tinyxml2::XMLDocument configDocument; + LOG_ALWAYS_FATAL_IF(configDocument.LoadFile(configPath.c_str()) != tinyxml2::XML_SUCCESS, + "Failed to load config file from %s", configPath.c_str()); + + // Parse configuration file. + const tinyxml2::XMLElement* configRoot = configDocument.FirstChildElement("motion-predictor"); + LOG_ALWAYS_FATAL_IF(!configRoot); + const nsecs_t predictionInterval = parseXMLInt64(*configRoot, "prediction-interval"); + return std::unique_ptr<TfLiteMotionPredictorModel>( - new TfLiteMotionPredictorModel(std::move(modelBuffer))); + new TfLiteMotionPredictorModel(std::move(modelBuffer), predictionInterval)); } TfLiteMotionPredictorModel::TfLiteMotionPredictorModel( - std::unique_ptr<android::base::MappedFile> model) - : mFlatBuffer(std::move(model)) { + std::unique_ptr<android::base::MappedFile> model, nsecs_t predictionInterval) + : mFlatBuffer(std::move(model)), mPredictionInterval(predictionInterval) { CHECK(mFlatBuffer); mErrorReporter = std::make_unique<LoggingErrorReporter>(); mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer->data(), diff --git a/libs/input/tests/Android.bp b/libs/input/tests/Android.bp index 6aae25d6d7..cadac88030 100644 --- a/libs/input/tests/Android.bp +++ b/libs/input/tests/Android.bp @@ -64,12 +64,13 @@ cc_test { "libcutils", "liblog", "libPlatformProperties", + "libtinyxml2", "libutils", "libvintf", ], data: [ "data/*", - ":motion_predictor_model.fb", + ":motion_predictor_model", ], test_options: { unit_test: true, |