summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Philip Quinn <pquinn@google.com> 2023-06-26 14:15:15 -0700
committer Philip Quinn <pquinn@google.com> 2023-06-26 17:25:10 -0700
commitf84fa49e72f14b14669bf99bad92ed05169e5cf1 (patch)
tree43f9f4304c53b8f5818f152a60257cfc84f3d22e
parent5d8faa4526cfe2a1fed4647234fabf5e260120b3 (diff)
Move MotionPredictor config to an XML file alongside the model.
Test: atest libinput_tests Fixes: 266747937 Change-Id: Ic5ec548d2edc8bad5e8b88aaf8511cd297a89275
-rw-r--r--data/etc/input/Android.bp15
-rw-r--r--data/etc/input/motion_predictor_config.xml20
-rw-r--r--data/etc/input/motion_predictor_model.tflite (renamed from data/etc/input/motion_predictor_model.fb)bin34080 -> 34080 bytes
-rw-r--r--include/input/TfLiteMotionPredictor.h9
-rw-r--r--libs/input/Android.bp2
-rw-r--r--libs/input/MotionPredictor.cpp6
-rw-r--r--libs/input/TfLiteMotionPredictor.cpp40
-rw-r--r--libs/input/tests/Android.bp3
8 files changed, 79 insertions, 16 deletions
diff --git a/data/etc/input/Android.bp b/data/etc/input/Android.bp
index 90f3c6b49a..b662491272 100644
--- a/data/etc/input/Android.bp
+++ b/data/etc/input/Android.bp
@@ -3,12 +3,21 @@ package {
}
filegroup {
- name: "motion_predictor_model.fb",
- srcs: ["motion_predictor_model.fb"],
+ name: "motion_predictor_model",
+ srcs: [
+ "motion_predictor_model.tflite",
+ "motion_predictor_config.xml",
+ ],
}
prebuilt_etc {
name: "motion_predictor_model_prebuilt",
filename_from_src: true,
- src: "motion_predictor_model.fb",
+ src: "motion_predictor_model.tflite",
+}
+
+prebuilt_etc {
+ name: "motion_predictor_model_config",
+ filename_from_src: true,
+ src: "motion_predictor_config.xml",
}
diff --git a/data/etc/input/motion_predictor_config.xml b/data/etc/input/motion_predictor_config.xml
new file mode 100644
index 0000000000..03dfd63cbd
--- /dev/null
+++ b/data/etc/input/motion_predictor_config.xml
@@ -0,0 +1,20 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!-- Copyright (C) 2023 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<motion-predictor>
+ <!-- The time interval (ns) between the model's predictions. -->
+ <prediction-interval>4166666</prediction-interval> <!-- 4.167 ms = ~240 Hz -->
+</motion-predictor>
+
diff --git a/data/etc/input/motion_predictor_model.fb b/data/etc/input/motion_predictor_model.tflite
index 10b3c8b114..10b3c8b114 100644
--- a/data/etc/input/motion_predictor_model.fb
+++ b/data/etc/input/motion_predictor_model.tflite
Binary files differ
diff --git a/include/input/TfLiteMotionPredictor.h b/include/input/TfLiteMotionPredictor.h
index a340bd0575..fbd60261b2 100644
--- a/include/input/TfLiteMotionPredictor.h
+++ b/include/input/TfLiteMotionPredictor.h
@@ -25,6 +25,7 @@
#include <android-base/mapped_file.h>
#include <input/RingBuffer.h>
+#include <utils/Timers.h>
#include <tensorflow/lite/core/api/error_reporter.h>
#include <tensorflow/lite/interpreter.h>
@@ -109,6 +110,9 @@ public:
// Returns the length of the model's output buffers.
size_t outputLength() const;
+ // Returns the time interval between predictions.
+ nsecs_t predictionInterval() const { return mPredictionInterval; }
+
// Executes the model.
// Returns true if the model successfully executed and the output tensors can be read.
bool invoke();
@@ -127,7 +131,8 @@ public:
std::span<const float> outputPressure() const;
private:
- explicit TfLiteMotionPredictorModel(std::unique_ptr<android::base::MappedFile> model);
+ explicit TfLiteMotionPredictorModel(std::unique_ptr<android::base::MappedFile> model,
+ nsecs_t predictionInterval);
void allocateTensors();
void attachInputTensors();
@@ -148,6 +153,8 @@ private:
std::unique_ptr<tflite::FlatBufferModel> mModel;
std::unique_ptr<tflite::Interpreter> mInterpreter;
tflite::SignatureRunner* mRunner = nullptr;
+
+ const nsecs_t mPredictionInterval = 0;
};
} // namespace android
diff --git a/libs/input/Android.bp b/libs/input/Android.bp
index 9ac1829b93..8a17d8a831 100644
--- a/libs/input/Android.bp
+++ b/libs/input/Android.bp
@@ -215,6 +215,7 @@ cc_library {
"libcutils",
"liblog",
"libPlatformProperties",
+ "libtinyxml2",
"libvintf",
],
@@ -271,6 +272,7 @@ cc_library {
required: [
"motion_predictor_model_prebuilt",
+ "motion_predictor_model_config",
],
},
host: {
diff --git a/libs/input/MotionPredictor.cpp b/libs/input/MotionPredictor.cpp
index 3037573538..947a95610a 100644
--- a/libs/input/MotionPredictor.cpp
+++ b/libs/input/MotionPredictor.cpp
@@ -36,9 +36,6 @@
namespace android {
namespace {
-const int64_t PREDICTION_INTERVAL_NANOS =
- 12500000 / 3; // TODO(b/266747937): Get this from the model.
-
/**
* Log debug messages about predictions.
* Enable this via "adb shell setprop log.tag.MotionPredictor DEBUG"
@@ -189,7 +186,7 @@ std::unique_ptr<MotionEvent> MotionPredictor::predict(nsecs_t timestamp) {
// TODO(b/266747654): Stop predictions if predicted pressure is < some threshold.
coords.setAxisValue(AMOTION_EVENT_AXIS_PRESSURE, predictedPressure[i]);
- predictionTime += PREDICTION_INTERVAL_NANOS;
+ predictionTime += mModel->predictionInterval();
if (i == 0) {
hasPredictions = true;
prediction->initialize(InputEvent::nextId(), event.getDeviceId(), event.getSource(),
@@ -208,7 +205,6 @@ std::unique_ptr<MotionEvent> MotionPredictor::predict(nsecs_t timestamp) {
axisFrom = axisTo;
axisTo = point;
}
- // TODO(b/266747511): Interpolate to futureTime?
if (!hasPredictions) {
return nullptr;
}
diff --git a/libs/input/TfLiteMotionPredictor.cpp b/libs/input/TfLiteMotionPredictor.cpp
index 85fa176129..9f4aaa8337 100644
--- a/libs/input/TfLiteMotionPredictor.cpp
+++ b/libs/input/TfLiteMotionPredictor.cpp
@@ -36,6 +36,7 @@
#define ATRACE_TAG ATRACE_TAG_INPUT
#include <cutils/trace.h>
#include <log/log.h>
+#include <utils/Timers.h>
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h"
@@ -44,6 +45,8 @@
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/mutable_op_resolver.h"
+#include "tinyxml2.h"
+
namespace android {
namespace {
@@ -72,16 +75,31 @@ bool fileExists(const char* filename) {
std::string getModelPath() {
#if defined(__ANDROID__)
- static const char* oemModel = "/vendor/etc/motion_predictor_model.fb";
+ static const char* oemModel = "/vendor/etc/motion_predictor_model.tflite";
if (fileExists(oemModel)) {
return oemModel;
}
- return "/system/etc/motion_predictor_model.fb";
+ return "/system/etc/motion_predictor_model.tflite";
#else
- return base::GetExecutableDirectory() + "/motion_predictor_model.fb";
+ return base::GetExecutableDirectory() + "/motion_predictor_model.tflite";
#endif
}
+std::string getConfigPath() {
+ // The config file should be alongside the model file.
+ return base::Dirname(getModelPath()) + "/motion_predictor_config.xml";
+}
+
+int64_t parseXMLInt64(const tinyxml2::XMLElement& configRoot, const char* elementName) {
+ const tinyxml2::XMLElement* element = configRoot.FirstChildElement(elementName);
+ LOG_ALWAYS_FATAL_IF(!element, "Could not find '%s' element", elementName);
+
+ int64_t value = 0;
+ LOG_ALWAYS_FATAL_IF(element->QueryInt64Text(&value) != tinyxml2::XML_SUCCESS,
+ "Failed to parse %s: %s", elementName, element->GetText());
+ return value;
+}
+
// A TFLite ErrorReporter that logs to logcat.
class LoggingErrorReporter : public tflite::ErrorReporter {
public:
@@ -246,13 +264,23 @@ std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create()
PLOG(FATAL) << "Failed to mmap model";
}
+ const std::string configPath = getConfigPath();
+ tinyxml2::XMLDocument configDocument;
+ LOG_ALWAYS_FATAL_IF(configDocument.LoadFile(configPath.c_str()) != tinyxml2::XML_SUCCESS,
+ "Failed to load config file from %s", configPath.c_str());
+
+ // Parse configuration file.
+ const tinyxml2::XMLElement* configRoot = configDocument.FirstChildElement("motion-predictor");
+ LOG_ALWAYS_FATAL_IF(!configRoot);
+ const nsecs_t predictionInterval = parseXMLInt64(*configRoot, "prediction-interval");
+
return std::unique_ptr<TfLiteMotionPredictorModel>(
- new TfLiteMotionPredictorModel(std::move(modelBuffer)));
+ new TfLiteMotionPredictorModel(std::move(modelBuffer), predictionInterval));
}
TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(
- std::unique_ptr<android::base::MappedFile> model)
- : mFlatBuffer(std::move(model)) {
+ std::unique_ptr<android::base::MappedFile> model, nsecs_t predictionInterval)
+ : mFlatBuffer(std::move(model)), mPredictionInterval(predictionInterval) {
CHECK(mFlatBuffer);
mErrorReporter = std::make_unique<LoggingErrorReporter>();
mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer->data(),
diff --git a/libs/input/tests/Android.bp b/libs/input/tests/Android.bp
index 6aae25d6d7..cadac88030 100644
--- a/libs/input/tests/Android.bp
+++ b/libs/input/tests/Android.bp
@@ -64,12 +64,13 @@ cc_test {
"libcutils",
"liblog",
"libPlatformProperties",
+ "libtinyxml2",
"libutils",
"libvintf",
],
data: [
"data/*",
- ":motion_predictor_model.fb",
+ ":motion_predictor_model",
],
test_options: {
unit_test: true,