diff options
author | 2023-02-08 22:50:59 -0800 | |
---|---|---|
committer | 2023-02-09 22:33:05 -0800 | |
commit | cb3229aaf2233ebb917d967a6e73d48cce1a1480 (patch) | |
tree | ff19ed484f5f0e967f8182d972b8bdea6a4f3802 | |
parent | da6a448e2dfbfa7f13ce243e9273ceb9bcda4388 (diff) |
Use mmap to read TFLite model.
The buffers in the model file are used directly by TFLite, and so a
small memory saving can be achieved by backing those memory pages with
the file itself.
Bug: 267050081
Test: atest libinput_tests
Change-Id: I743a3c94477d4bb778b6e0c4b4890a44f4e19aa4
-rw-r--r-- | include/input/TfLiteMotionPredictor.h | 6 | ||||
-rw-r--r-- | libs/input/TfLiteMotionPredictor.cpp | 41 |
2 files changed, 32 insertions, 15 deletions
diff --git a/include/input/TfLiteMotionPredictor.h b/include/input/TfLiteMotionPredictor.h index 6e9afc314b..54e2851a7a 100644 --- a/include/input/TfLiteMotionPredictor.h +++ b/include/input/TfLiteMotionPredictor.h @@ -22,8 +22,8 @@ #include <memory> #include <optional> #include <span> -#include <string> +#include <android-base/mapped_file.h> #include <input/RingBuffer.h> #include <tensorflow/lite/core/api/error_reporter.h> @@ -124,7 +124,7 @@ public: std::span<const float> outputPressure() const; private: - explicit TfLiteMotionPredictorModel(std::string model); + explicit TfLiteMotionPredictorModel(std::unique_ptr<android::base::MappedFile> model); void allocateTensors(); void attachInputTensors(); @@ -140,7 +140,7 @@ private: const TfLiteTensor* mOutputPhi = nullptr; const TfLiteTensor* mOutputPressure = nullptr; - std::string mFlatBuffer; + std::unique_ptr<android::base::MappedFile> mFlatBuffer; std::unique_ptr<tflite::ErrorReporter> mErrorReporter; std::unique_ptr<tflite::FlatBufferModel> mModel; std::unique_ptr<tflite::Interpreter> mInterpreter; diff --git a/libs/input/TfLiteMotionPredictor.cpp b/libs/input/TfLiteMotionPredictor.cpp index fbb7106540..10510d675c 100644 --- a/libs/input/TfLiteMotionPredictor.cpp +++ b/libs/input/TfLiteMotionPredictor.cpp @@ -17,19 +17,21 @@ #define LOG_TAG "TfLiteMotionPredictor" #include <input/TfLiteMotionPredictor.h> +#include <fcntl.h> +#include <sys/mman.h> +#include <unistd.h> + #include <algorithm> #include <cmath> #include <cstddef> #include <cstdint> -#include <fstream> -#include <ios> -#include <iterator> #include <memory> #include <span> -#include <string> #include <type_traits> #include <utility> +#include <android-base/logging.h> +#include <android-base/mapped_file.h> #define ATRACE_TAG ATRACE_TAG_INPUT #include <cutils/trace.h> #include <log/log.h> @@ -206,21 +208,36 @@ void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp, std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create( const char* modelPath) { - std::ifstream f(modelPath, std::ios::binary); - LOG_ALWAYS_FATAL_IF(!f, "Could not read model from %s", modelPath); + const int fd = open(modelPath, O_RDONLY); + if (fd == -1) { + PLOG(FATAL) << "Could not read model from " << modelPath; + } + + const off_t fdSize = lseek(fd, 0, SEEK_END); + if (fdSize == -1) { + PLOG(FATAL) << "Failed to determine file size"; + } - std::string data; - data.assign(std::istreambuf_iterator<char>(f), std::istreambuf_iterator<char>()); + std::unique_ptr<android::base::MappedFile> modelBuffer = + android::base::MappedFile::FromFd(fd, /*offset=*/0, fdSize, PROT_READ); + if (!modelBuffer) { + PLOG(FATAL) << "Failed to mmap model"; + } + if (close(fd) == -1) { + PLOG(FATAL) << "Failed to close model fd"; + } return std::unique_ptr<TfLiteMotionPredictorModel>( - new TfLiteMotionPredictorModel(std::move(data))); + new TfLiteMotionPredictorModel(std::move(modelBuffer))); } -TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(std::string model) +TfLiteMotionPredictorModel::TfLiteMotionPredictorModel( + std::unique_ptr<android::base::MappedFile> model) : mFlatBuffer(std::move(model)) { + CHECK(mFlatBuffer); mErrorReporter = std::make_unique<LoggingErrorReporter>(); - mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer.data(), - mFlatBuffer.length(), + mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer->data(), + mFlatBuffer->size(), /*extra_verifier=*/nullptr, mErrorReporter.get()); LOG_ALWAYS_FATAL_IF(!mModel); |