summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Cody Heiner <codyheiner@google.com> 2023-03-30 18:41:45 -0700
committer Cody Heiner <codyheiner@google.com> 2023-04-05 19:31:11 -0700
commitdbd14eb7bcfc3255455c343da1b52f99bbd864ff (patch)
treedcd7cdebf8a35e34f905ab111475e1f39f8cb68c
parent1f428c8affd9ca3951b01e88cb93545f109f78f3 (diff)
Add outputLength method
Test: build succeeds, `atest libinput-tests` passes. Bug: 268245099 Change-Id: I030da703a907eef44e85d186144eddc53b5998cc
-rw-r--r--include/input/TfLiteMotionPredictor.h3
-rw-r--r--libs/input/TfLiteMotionPredictor.cpp4
-rw-r--r--libs/input/tests/TfLiteMotionPredictor_test.cpp6
3 files changed, 11 insertions, 2 deletions
diff --git a/include/input/TfLiteMotionPredictor.h b/include/input/TfLiteMotionPredictor.h
index 7de551b417..a340bd0575 100644
--- a/include/input/TfLiteMotionPredictor.h
+++ b/include/input/TfLiteMotionPredictor.h
@@ -106,6 +106,9 @@ public:
// Returns the length of the model's input buffers.
size_t inputLength() const;
+ // Returns the length of the model's output buffers.
+ size_t outputLength() const;
+
// Executes the model.
// Returns true if the model successfully executed and the output tensors can be read.
bool invoke();
diff --git a/libs/input/TfLiteMotionPredictor.cpp b/libs/input/TfLiteMotionPredictor.cpp
index 3b061d1cf1..85fa176129 100644
--- a/libs/input/TfLiteMotionPredictor.cpp
+++ b/libs/input/TfLiteMotionPredictor.cpp
@@ -346,6 +346,10 @@ size_t TfLiteMotionPredictorModel::inputLength() const {
return getTensorBuffer<const float>(mInputR).size();
}
+size_t TfLiteMotionPredictorModel::outputLength() const {
+ return getTensorBuffer<const float>(mOutputR).size();
+}
+
std::span<float> TfLiteMotionPredictorModel::inputR() {
return getTensorBuffer<float>(mInputR);
}
diff --git a/libs/input/tests/TfLiteMotionPredictor_test.cpp b/libs/input/tests/TfLiteMotionPredictor_test.cpp
index 6e76ac1e52..b5ed9e4430 100644
--- a/libs/input/tests/TfLiteMotionPredictor_test.cpp
+++ b/libs/input/tests/TfLiteMotionPredictor_test.cpp
@@ -139,8 +139,10 @@ TEST(TfLiteMotionPredictorTest, ModelInputOutputLength) {
ASSERT_TRUE(model->invoke());
- ASSERT_EQ(model->outputR().size(), model->outputPhi().size());
- ASSERT_EQ(model->outputR().size(), model->outputPressure().size());
+ const int outputLength = model->outputLength();
+ ASSERT_EQ(outputLength, model->outputR().size());
+ ASSERT_EQ(outputLength, model->outputPhi().size());
+ ASSERT_EQ(outputLength, model->outputPressure().size());
}
TEST(TfLiteMotionPredictorTest, ModelOutput) {