summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Philip Quinn <pquinn@google.com> 2023-02-08 22:50:59 -0800
committer Philip Quinn <pquinn@google.com> 2023-02-09 22:33:05 -0800
commitcb3229aaf2233ebb917d967a6e73d48cce1a1480 (patch)
treeff19ed484f5f0e967f8182d972b8bdea6a4f3802
parentda6a448e2dfbfa7f13ce243e9273ceb9bcda4388 (diff)
Use mmap to read TFLite model.
The buffers in the model file are used directly by TFLite, and so a small memory saving can be achieved by backing those memory pages with the file itself. Bug: 267050081 Test: atest libinput_tests Change-Id: I743a3c94477d4bb778b6e0c4b4890a44f4e19aa4
-rw-r--r--include/input/TfLiteMotionPredictor.h6
-rw-r--r--libs/input/TfLiteMotionPredictor.cpp41
2 files changed, 32 insertions, 15 deletions
diff --git a/include/input/TfLiteMotionPredictor.h b/include/input/TfLiteMotionPredictor.h
index 6e9afc314b..54e2851a7a 100644
--- a/include/input/TfLiteMotionPredictor.h
+++ b/include/input/TfLiteMotionPredictor.h
@@ -22,8 +22,8 @@
#include <memory>
#include <optional>
#include <span>
-#include <string>
+#include <android-base/mapped_file.h>
#include <input/RingBuffer.h>
#include <tensorflow/lite/core/api/error_reporter.h>
@@ -124,7 +124,7 @@ public:
std::span<const float> outputPressure() const;
private:
- explicit TfLiteMotionPredictorModel(std::string model);
+ explicit TfLiteMotionPredictorModel(std::unique_ptr<android::base::MappedFile> model);
void allocateTensors();
void attachInputTensors();
@@ -140,7 +140,7 @@ private:
const TfLiteTensor* mOutputPhi = nullptr;
const TfLiteTensor* mOutputPressure = nullptr;
- std::string mFlatBuffer;
+ std::unique_ptr<android::base::MappedFile> mFlatBuffer;
std::unique_ptr<tflite::ErrorReporter> mErrorReporter;
std::unique_ptr<tflite::FlatBufferModel> mModel;
std::unique_ptr<tflite::Interpreter> mInterpreter;
diff --git a/libs/input/TfLiteMotionPredictor.cpp b/libs/input/TfLiteMotionPredictor.cpp
index fbb7106540..10510d675c 100644
--- a/libs/input/TfLiteMotionPredictor.cpp
+++ b/libs/input/TfLiteMotionPredictor.cpp
@@ -17,19 +17,21 @@
#define LOG_TAG "TfLiteMotionPredictor"
#include <input/TfLiteMotionPredictor.h>
+#include <fcntl.h>
+#include <sys/mman.h>
+#include <unistd.h>
+
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
-#include <fstream>
-#include <ios>
-#include <iterator>
#include <memory>
#include <span>
-#include <string>
#include <type_traits>
#include <utility>
+#include <android-base/logging.h>
+#include <android-base/mapped_file.h>
#define ATRACE_TAG ATRACE_TAG_INPUT
#include <cutils/trace.h>
#include <log/log.h>
@@ -206,21 +208,36 @@ void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp,
std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create(
const char* modelPath) {
- std::ifstream f(modelPath, std::ios::binary);
- LOG_ALWAYS_FATAL_IF(!f, "Could not read model from %s", modelPath);
+ const int fd = open(modelPath, O_RDONLY);
+ if (fd == -1) {
+ PLOG(FATAL) << "Could not read model from " << modelPath;
+ }
+
+ const off_t fdSize = lseek(fd, 0, SEEK_END);
+ if (fdSize == -1) {
+ PLOG(FATAL) << "Failed to determine file size";
+ }
- std::string data;
- data.assign(std::istreambuf_iterator<char>(f), std::istreambuf_iterator<char>());
+ std::unique_ptr<android::base::MappedFile> modelBuffer =
+ android::base::MappedFile::FromFd(fd, /*offset=*/0, fdSize, PROT_READ);
+ if (!modelBuffer) {
+ PLOG(FATAL) << "Failed to mmap model";
+ }
+ if (close(fd) == -1) {
+ PLOG(FATAL) << "Failed to close model fd";
+ }
return std::unique_ptr<TfLiteMotionPredictorModel>(
- new TfLiteMotionPredictorModel(std::move(data)));
+ new TfLiteMotionPredictorModel(std::move(modelBuffer)));
}
-TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(std::string model)
+TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(
+ std::unique_ptr<android::base::MappedFile> model)
: mFlatBuffer(std::move(model)) {
+ CHECK(mFlatBuffer);
mErrorReporter = std::make_unique<LoggingErrorReporter>();
- mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer.data(),
- mFlatBuffer.length(),
+ mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer->data(),
+ mFlatBuffer->size(),
/*extra_verifier=*/nullptr,
mErrorReporter.get());
LOG_ALWAYS_FATAL_IF(!mModel);