diff options
author | 2023-11-14 14:47:10 -0800 | |
---|---|---|
committer | 2023-12-07 17:59:27 -0800 | |
commit | 7b26dbea1d677b8d783f64ac1be23274898547ec (patch) | |
tree | 277cc4d130d720b752978f404f40d4d0535a0515 | |
parent | 192968f908e22a8c086481805382c438703c5f68 (diff) |
Pass all input events to MetricsManager
The MetricsManager needs to receive UP/CANCEL events to trigger
atom reporting. I must have moved these lines around during the
refactor and overlooked this mistake.
This change also modifies MotionPredictor and MetricsManager to
hold a "ReportAtomFunction" to facilitate testing.
Test: `statsd_testdrive 718` shows atoms reported with `adb shell setenforce 0`.
Test: `atest frameworks/native/libs/input/tests/MotionPredictor_test.cpp -c` passes.
Test: `atest frameworks/native/libs/input/tests/MotionPredictorMetricsManager_test.cpp -c` passes.
Bug: 311066949
Change-Id: Icbb709bbb7cf548512e0d9aa062783d554b857e3
-rw-r--r-- | include/input/MotionPredictor.h | 16 | ||||
-rw-r--r-- | include/input/MotionPredictorMetricsManager.h | 47 | ||||
-rw-r--r-- | libs/input/MotionPredictor.cpp | 19 | ||||
-rw-r--r-- | libs/input/MotionPredictorMetricsManager.cpp | 55 | ||||
-rw-r--r-- | libs/input/tests/MotionPredictorMetricsManager_test.cpp | 79 | ||||
-rw-r--r-- | libs/input/tests/MotionPredictor_test.cpp | 31 |
6 files changed, 151 insertions, 96 deletions
diff --git a/include/input/MotionPredictor.h b/include/input/MotionPredictor.h index 8797962886..3b6e40183f 100644 --- a/include/input/MotionPredictor.h +++ b/include/input/MotionPredictor.h @@ -19,6 +19,7 @@ #include <cstdint> #include <memory> #include <mutex> +#include <optional> #include <string> #include <unordered_map> @@ -57,20 +58,23 @@ static inline bool isMotionPredictionEnabled() { */ class MotionPredictor { public: + using ReportAtomFunction = MotionPredictorMetricsManager::ReportAtomFunction; + /** * Parameters: * predictionTimestampOffsetNanos: additional, constant shift to apply to the target * prediction time. The prediction will target the time t=(prediction time + * predictionTimestampOffsetNanos). * - * modelPath: filesystem path to a TfLiteMotionPredictorModel flatbuffer, or nullptr to use the - * default model path. - * - * checkEnableMotionPredition: the function to check whether the prediction should run. Used to + * checkEnableMotionPrediction: the function to check whether the prediction should run. Used to * provide an additional way of turning prediction on and off. Can be toggled at runtime. + * + * reportAtomFunction: the function that will be called to report prediction metrics. If + * omitted, the implementation will choose a default metrics reporting mechanism. */ MotionPredictor(nsecs_t predictionTimestampOffsetNanos, - std::function<bool()> checkEnableMotionPrediction = isMotionPredictionEnabled); + std::function<bool()> checkEnableMotionPrediction = isMotionPredictionEnabled, + ReportAtomFunction reportAtomFunction = {}); /** * Record the actual motion received by the view. This event will be used for calculating the @@ -95,6 +99,8 @@ private: std::optional<MotionEvent> mLastEvent; std::optional<MotionPredictorMetricsManager> mMetricsManager; + + const ReportAtomFunction mReportAtomFunction; }; } // namespace android diff --git a/include/input/MotionPredictorMetricsManager.h b/include/input/MotionPredictorMetricsManager.h index 12e50ba3b4..38472d8df7 100644 --- a/include/input/MotionPredictorMetricsManager.h +++ b/include/input/MotionPredictorMetricsManager.h @@ -18,7 +18,6 @@ #include <cstdint> #include <functional> #include <limits> -#include <optional> #include <vector> #include <input/Input.h> // for MotionEvent @@ -37,15 +36,33 @@ namespace android { * * This class stores AggregatedStrokeMetrics, updating them as new MotionEvents are passed in. When * onRecord receives an UP or CANCEL event, this indicates the end of the stroke, and the final - * AtomFields are computed and reported to the stats library. + * AtomFields are computed and reported to the stats library. The number of atoms reported is equal + * to the value of `maxNumPredictions` passed to the constructor. Each atom corresponds to one + * "prediction time bucket" — the amount of time into the future being predicted. * * If mMockLoggedAtomFields is set, the batch of AtomFields that are reported to the stats library * for one stroke are also stored in mMockLoggedAtomFields at the time they're reported. */ class MotionPredictorMetricsManager { public: - // Note: the MetricsManager assumes that the input interval equals the prediction interval. - MotionPredictorMetricsManager(nsecs_t predictionInterval, size_t maxNumPredictions); + struct AtomFields; + + using ReportAtomFunction = std::function<void(const AtomFields&)>; + + static void defaultReportAtomFunction(const AtomFields& atomFields); + + // Parameters: + // • predictionInterval: the time interval between successive prediction target timestamps. + // Note: the MetricsManager assumes that the input interval equals the prediction interval. + // • maxNumPredictions: the maximum number of distinct target timestamps the prediction model + // will generate predictions for. The MetricsManager reports this many atoms per stroke. + // • [Optional] reportAtomFunction: the function that will be called to report metrics. If + // omitted (or if an empty function is given), the `stats_write(…)` function from the Android + // stats library will be used. + MotionPredictorMetricsManager( + nsecs_t predictionInterval, + size_t maxNumPredictions, + ReportAtomFunction reportAtomFunction = defaultReportAtomFunction); // This method should be called once for each call to MotionPredictor::record, receiving the // forwarded MotionEvent argument. @@ -121,7 +138,7 @@ public: // magnitude makes it unobtainable in practice.) static const int NO_DATA_SENTINEL = std::numeric_limits<int32_t>::min(); - // Final metrics reported in the atom. + // Final metric values reported in the atom. struct AtomFields { int deltaTimeBucketMilliseconds = 0; @@ -140,15 +157,6 @@ public: int scaleInvariantOffTrajectoryRmse = NO_DATA_SENTINEL; // millipixels }; - // Allow tests to pass in a mock AtomFields pointer. - // - // When metrics are reported to the stats library on stroke end, they will also be written to - // mockLoggedAtomFields, overwriting existing data. The size of mockLoggedAtomFields will equal - // the number of calls to stats_write for that stroke. - void setMockLoggedAtomFields(std::vector<AtomFields>* mockLoggedAtomFields) { - mMockLoggedAtomFields = mockLoggedAtomFields; - } - private: // The interval between consecutive predictions' target timestamps. We assume that the input // interval also equals this value. @@ -172,11 +180,7 @@ private: std::vector<AggregatedStrokeMetrics> mAggregatedMetrics; std::vector<AtomFields> mAtomFields; - // Non-owning pointer to the location of mock AtomFields. If present, will be filled with the - // values reported to stats_write on each batch of reported metrics. - // - // This pointer must remain valid as long as the MotionPredictorMetricsManager exists. - std::vector<AtomFields>* mMockLoggedAtomFields = nullptr; + const ReportAtomFunction mReportAtomFunction; // Helper methods for the implementation of onRecord and onPredict. @@ -196,10 +200,7 @@ private: // Computes the atom fields to mAtomFields from the values in mAggregatedMetrics. void computeAtomFields(); - // Reports the metrics given by the current data in mAtomFields: - // • If on an Android device, reports the metrics to stats_write. - // • If mMockLoggedAtomFields is present, it will be overwritten with logged metrics, with one - // AtomFields element per call to stats_write. + // Reports the current data in mAtomFields by calling mReportAtomFunction. void reportMetrics(); }; diff --git a/libs/input/MotionPredictor.cpp b/libs/input/MotionPredictor.cpp index 412931bc41..c4e3ff6dee 100644 --- a/libs/input/MotionPredictor.cpp +++ b/libs/input/MotionPredictor.cpp @@ -60,9 +60,11 @@ TfLiteMotionPredictorSample::Point convertPrediction( // --- MotionPredictor --- MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos, - std::function<bool()> checkMotionPredictionEnabled) + std::function<bool()> checkMotionPredictionEnabled, + ReportAtomFunction reportAtomFunction) : mPredictionTimestampOffsetNanos(predictionTimestampOffsetNanos), - mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)) {} + mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)), + mReportAtomFunction(reportAtomFunction) {} android::base::Result<void> MotionPredictor::record(const MotionEvent& event) { if (mLastEvent && mLastEvent->getDeviceId() != event.getDeviceId()) { @@ -90,6 +92,13 @@ android::base::Result<void> MotionPredictor::record(const MotionEvent& event) { mBuffers = std::make_unique<TfLiteMotionPredictorBuffers>(mModel->inputLength()); } + // Pass input event to the MetricsManager. + if (!mMetricsManager) { + mMetricsManager.emplace(mModel->config().predictionInterval, mModel->outputLength(), + mReportAtomFunction); + } + mMetricsManager->onRecord(event); + const int32_t action = event.getActionMasked(); if (action == AMOTION_EVENT_ACTION_UP || action == AMOTION_EVENT_ACTION_CANCEL) { ALOGD_IF(isDebug(), "End of event stream"); @@ -135,12 +144,6 @@ android::base::Result<void> MotionPredictor::record(const MotionEvent& event) { } mLastEvent->copyFrom(&event, /*keepHistory=*/false); - // Pass input event to the MetricsManager. - if (!mMetricsManager) { - mMetricsManager.emplace(mModel->config().predictionInterval, mModel->outputLength()); - } - mMetricsManager->onRecord(event); - return {}; } diff --git a/libs/input/MotionPredictorMetricsManager.cpp b/libs/input/MotionPredictorMetricsManager.cpp index 67b103290f..0412d08181 100644 --- a/libs/input/MotionPredictorMetricsManager.cpp +++ b/libs/input/MotionPredictorMetricsManager.cpp @@ -46,13 +46,36 @@ inline constexpr float PATH_LENGTH_EPSILON = 0.001; } // namespace -MotionPredictorMetricsManager::MotionPredictorMetricsManager(nsecs_t predictionInterval, - size_t maxNumPredictions) +void MotionPredictorMetricsManager::defaultReportAtomFunction( + const MotionPredictorMetricsManager::AtomFields& atomFields) { + // Call stats_write logging function only on Android targets (not supported on host). +#ifdef __ANDROID__ + android::stats::libinput:: + stats_write(android::stats::libinput::STYLUS_PREDICTION_METRICS_REPORTED, + /*stylus_vendor_id=*/0, + /*stylus_product_id=*/0, + atomFields.deltaTimeBucketMilliseconds, + atomFields.alongTrajectoryErrorMeanMillipixels, + atomFields.alongTrajectoryErrorStdMillipixels, + atomFields.offTrajectoryRmseMillipixels, + atomFields.pressureRmseMilliunits, + atomFields.highVelocityAlongTrajectoryRmse, + atomFields.highVelocityOffTrajectoryRmse, + atomFields.scaleInvariantAlongTrajectoryRmse, + atomFields.scaleInvariantOffTrajectoryRmse); +#endif +} + +MotionPredictorMetricsManager::MotionPredictorMetricsManager( + nsecs_t predictionInterval, + size_t maxNumPredictions, + ReportAtomFunction reportAtomFunction) : mPredictionInterval(predictionInterval), mMaxNumPredictions(maxNumPredictions), mRecentGroundTruthPoints(maxNumPredictions + 1), mAggregatedMetrics(maxNumPredictions), - mAtomFields(maxNumPredictions) {} + mAtomFields(maxNumPredictions), + mReportAtomFunction(reportAtomFunction ? reportAtomFunction : defaultReportAtomFunction) {} void MotionPredictorMetricsManager::onRecord(const MotionEvent& inputEvent) { // Convert MotionEvent to GroundTruthPoint. @@ -81,8 +104,8 @@ void MotionPredictorMetricsManager::onRecord(const MotionEvent& inputEvent) { if (mRecentGroundTruthPoints.size() >= 2) { computeAtomFields(); reportMetrics(); - break; } + break; } } } @@ -345,28 +368,10 @@ void MotionPredictorMetricsManager::computeAtomFields() { } void MotionPredictorMetricsManager::reportMetrics() { - // Report one atom for each time bucket. + LOG_ALWAYS_FATAL_IF(!mReportAtomFunction); + // Report one atom for each prediction time bucket. for (size_t i = 0; i < mAtomFields.size(); ++i) { - // Call stats_write logging function only on Android targets (not supported on host). -#ifdef __ANDROID__ - android::stats::libinput:: - stats_write(android::stats::libinput::STYLUS_PREDICTION_METRICS_REPORTED, - /*stylus_vendor_id=*/0, - /*stylus_product_id=*/0, mAtomFields[i].deltaTimeBucketMilliseconds, - mAtomFields[i].alongTrajectoryErrorMeanMillipixels, - mAtomFields[i].alongTrajectoryErrorStdMillipixels, - mAtomFields[i].offTrajectoryRmseMillipixels, - mAtomFields[i].pressureRmseMilliunits, - mAtomFields[i].highVelocityAlongTrajectoryRmse, - mAtomFields[i].highVelocityOffTrajectoryRmse, - mAtomFields[i].scaleInvariantAlongTrajectoryRmse, - mAtomFields[i].scaleInvariantOffTrajectoryRmse); -#endif - } - - // Set mock atom fields, if available. - if (mMockLoggedAtomFields != nullptr) { - *mMockLoggedAtomFields = mAtomFields; + mReportAtomFunction(mAtomFields[i]); } } diff --git a/libs/input/tests/MotionPredictorMetricsManager_test.cpp b/libs/input/tests/MotionPredictorMetricsManager_test.cpp index b420a5a4e7..31cc1459fc 100644 --- a/libs/input/tests/MotionPredictorMetricsManager_test.cpp +++ b/libs/input/tests/MotionPredictorMetricsManager_test.cpp @@ -39,6 +39,7 @@ using ::testing::Matches; using GroundTruthPoint = MotionPredictorMetricsManager::GroundTruthPoint; using PredictionPoint = MotionPredictorMetricsManager::PredictionPoint; using AtomFields = MotionPredictorMetricsManager::AtomFields; +using ReportAtomFunction = MotionPredictorMetricsManager::ReportAtomFunction; inline constexpr int NANOS_PER_MILLIS = 1'000'000; @@ -664,9 +665,16 @@ TEST(ErrorComputationHelperTest, ComputePressureRmsesSimpleTest) { // --- MotionPredictorMetricsManager tests. --- -// Helper function that instantiates a MetricsManager with the given mock logged AtomFields. Takes -// vectors of ground truth and prediction points of the same length, and passes these points to the -// MetricsManager. The format of these vectors is expected to be: +// Creates a mock atom reporting function that appends the reported atom to the given vector. +ReportAtomFunction createMockReportAtomFunction(std::vector<AtomFields>& reportedAtomFields) { + return [&reportedAtomFields](const AtomFields& atomFields) -> void { + reportedAtomFields.push_back(atomFields); + }; +} + +// Helper function that instantiates a MetricsManager that reports metrics to outReportedAtomFields. +// Takes vectors of ground truth and prediction points of the same length, and passes these points +// to the MetricsManager. The format of these vectors is expected to be: // • groundTruthPoints: chronologically-ordered ground truth points, with at least 2 elements. // • predictionPoints: the first index points to a vector of predictions corresponding to the // source ground truth point with the same index. @@ -678,15 +686,16 @@ TEST(ErrorComputationHelperTest, ComputePressureRmsesSimpleTest) { // prediction sets (that is, excluding the first and last). Thus, groundTruthPoints and // predictionPoints should have size at least TEST_MAX_NUM_PREDICTIONS + 2. // -// The passed-in outAtomFields will contain the logged AtomFields when the function returns. +// When the function returns, outReportedAtomFields will contain the reported AtomFields. // // This function returns void so that it can use test assertions. void runMetricsManager(const std::vector<GroundTruthPoint>& groundTruthPoints, const std::vector<std::vector<PredictionPoint>>& predictionPoints, - std::vector<AtomFields>& outAtomFields) { + std::vector<AtomFields>& outReportedAtomFields) { MotionPredictorMetricsManager metricsManager(TEST_PREDICTION_INTERVAL_NANOS, - TEST_MAX_NUM_PREDICTIONS); - metricsManager.setMockLoggedAtomFields(&outAtomFields); + TEST_MAX_NUM_PREDICTIONS, + createMockReportAtomFunction( + outReportedAtomFields)); // Validate structure of groundTruthPoints and predictionPoints. ASSERT_EQ(predictionPoints.size(), groundTruthPoints.size()); @@ -712,18 +721,18 @@ void runMetricsManager(const std::vector<GroundTruthPoint>& groundTruthPoints, // • Input: no prediction data. // • Expectation: no metrics should be logged. TEST(MotionPredictorMetricsManagerTest, NoPredictions) { - std::vector<AtomFields> mockLoggedAtomFields; + std::vector<AtomFields> reportedAtomFields; MotionPredictorMetricsManager metricsManager(TEST_PREDICTION_INTERVAL_NANOS, - TEST_MAX_NUM_PREDICTIONS); - metricsManager.setMockLoggedAtomFields(&mockLoggedAtomFields); + TEST_MAX_NUM_PREDICTIONS, + createMockReportAtomFunction(reportedAtomFields)); metricsManager.onRecord(makeMotionEvent( GroundTruthPoint{{.position = Eigen::Vector2f(0, 0), .pressure = 0}, .timestamp = 0})); metricsManager.onRecord(makeLiftMotionEvent()); - // Check that mockLoggedAtomFields is still empty (as it was initialized empty), ensuring that + // Check that reportedAtomFields is still empty (as it was initialized empty), ensuring that // no metrics were logged. - EXPECT_EQ(0u, mockLoggedAtomFields.size()); + EXPECT_EQ(0u, reportedAtomFields.size()); } // Perfect predictions test: @@ -744,14 +753,14 @@ TEST(MotionPredictorMetricsManagerTest, ConstantGroundTruthPerfectPredictions) { groundTruthPoint.timestamp += TEST_PREDICTION_INTERVAL_NANOS; } - std::vector<AtomFields> atomFields; - runMetricsManager(groundTruthPoints, predictionPoints, atomFields); + std::vector<AtomFields> reportedAtomFields; + runMetricsManager(groundTruthPoints, predictionPoints, reportedAtomFields); - ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, atomFields.size()); + ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, reportedAtomFields.size()); // Check that errors are all zero, or NO_DATA_SENTINEL for unreported metrics. - for (size_t i = 0; i < atomFields.size(); ++i) { + for (size_t i = 0; i < reportedAtomFields.size(); ++i) { SCOPED_TRACE(testing::Message() << "i = " << i); - const AtomFields& atom = atomFields[i]; + const AtomFields& atom = reportedAtomFields[i]; const nsecs_t deltaTimeBucketNanos = TEST_PREDICTION_INTERVAL_NANOS * (i + 1); EXPECT_EQ(deltaTimeBucketNanos / NANOS_PER_MILLIS, atom.deltaTimeBucketMilliseconds); // General errors: reported for every time bucket. @@ -764,7 +773,7 @@ TEST(MotionPredictorMetricsManagerTest, ConstantGroundTruthPerfectPredictions) { EXPECT_EQ(NO_DATA_SENTINEL, atom.highVelocityAlongTrajectoryRmse); EXPECT_EQ(NO_DATA_SENTINEL, atom.highVelocityOffTrajectoryRmse); // Scale-invariant errors: reported only for the last time bucket. - if (i + 1 == atomFields.size()) { + if (i + 1 == reportedAtomFields.size()) { EXPECT_EQ(0, atom.scaleInvariantAlongTrajectoryRmse); EXPECT_EQ(0, atom.scaleInvariantOffTrajectoryRmse); } else { @@ -801,14 +810,14 @@ TEST(MotionPredictorMetricsManagerTest, QuadraticPressureLinearPredictions) { computePressureRmses(groundTruthPoints, predictionPoints); // Run test. - std::vector<AtomFields> atomFields; - runMetricsManager(groundTruthPoints, predictionPoints, atomFields); + std::vector<AtomFields> reportedAtomFields; + runMetricsManager(groundTruthPoints, predictionPoints, reportedAtomFields); // Check logged metrics match expectations. - ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, atomFields.size()); - for (size_t i = 0; i < atomFields.size(); ++i) { + ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, reportedAtomFields.size()); + for (size_t i = 0; i < reportedAtomFields.size(); ++i) { SCOPED_TRACE(testing::Message() << "i = " << i); - const AtomFields& atom = atomFields[i]; + const AtomFields& atom = reportedAtomFields[i]; // Check time bucket delta matches expectation based on index and prediction interval. const nsecs_t deltaTimeBucketNanos = TEST_PREDICTION_INTERVAL_NANOS * (i + 1); EXPECT_EQ(deltaTimeBucketNanos / NANOS_PER_MILLIS, atom.deltaTimeBucketMilliseconds); @@ -845,14 +854,14 @@ TEST(MotionPredictorMetricsManagerTest, QuadraticPositionLinearPredictionsGenera computeGeneralPositionErrors(groundTruthPoints, predictionPoints); // Run test. - std::vector<AtomFields> atomFields; - runMetricsManager(groundTruthPoints, predictionPoints, atomFields); + std::vector<AtomFields> reportedAtomFields; + runMetricsManager(groundTruthPoints, predictionPoints, reportedAtomFields); // Check logged metrics match expectations. - ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, atomFields.size()); - for (size_t i = 0; i < atomFields.size(); ++i) { + ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, reportedAtomFields.size()); + for (size_t i = 0; i < reportedAtomFields.size(); ++i) { SCOPED_TRACE(testing::Message() << "i = " << i); - const AtomFields& atom = atomFields[i]; + const AtomFields& atom = reportedAtomFields[i]; // Check time bucket delta matches expectation based on index and prediction interval. const nsecs_t deltaTimeBucketNanos = TEST_PREDICTION_INTERVAL_NANOS * (i + 1); EXPECT_EQ(deltaTimeBucketNanos / NANOS_PER_MILLIS, atom.deltaTimeBucketMilliseconds); @@ -896,14 +905,14 @@ TEST(MotionPredictorMetricsManagerTest, CounterclockwiseOctagonGroundTruthLinear computeGeneralPositionErrors(groundTruthPoints, predictionPoints); // Run test. - std::vector<AtomFields> atomFields; - runMetricsManager(groundTruthPoints, predictionPoints, atomFields); + std::vector<AtomFields> reportedAtomFields; + runMetricsManager(groundTruthPoints, predictionPoints, reportedAtomFields); // Check logged metrics match expectations. - ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, atomFields.size()); - for (size_t i = 0; i < atomFields.size(); ++i) { + ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, reportedAtomFields.size()); + for (size_t i = 0; i < reportedAtomFields.size(); ++i) { SCOPED_TRACE(testing::Message() << "i = " << i); - const AtomFields& atom = atomFields[i]; + const AtomFields& atom = reportedAtomFields[i]; const nsecs_t deltaTimeBucketNanos = TEST_PREDICTION_INTERVAL_NANOS * (i + 1); EXPECT_EQ(deltaTimeBucketNanos / NANOS_PER_MILLIS, atom.deltaTimeBucketMilliseconds); @@ -926,7 +935,7 @@ TEST(MotionPredictorMetricsManagerTest, CounterclockwiseOctagonGroundTruthLinear // to general errors (where reported). // // As above, use absolute value for RMSE, since it must be non-negative. - if (i + 2 >= atomFields.size()) { + if (i + 2 >= reportedAtomFields.size()) { EXPECT_NEAR(static_cast<int>( 1000 * std::abs(generalPositionErrors[i].alongTrajectoryErrorMean)), atom.highVelocityAlongTrajectoryRmse, 1); @@ -946,7 +955,7 @@ TEST(MotionPredictorMetricsManagerTest, CounterclockwiseOctagonGroundTruthLinear // to scale-invariant errors by dividing by `strokeVelocty * TEST_MAX_NUM_PREDICTIONS`. // // As above, use absolute value for RMSE, since it must be non-negative. - if (i + 1 == atomFields.size()) { + if (i + 1 == reportedAtomFields.size()) { const float pathLength = strokeVelocity * TEST_MAX_NUM_PREDICTIONS; std::vector<float> alongTrajectoryAbsoluteErrors; std::vector<float> offTrajectoryAbsoluteErrors; diff --git a/libs/input/tests/MotionPredictor_test.cpp b/libs/input/tests/MotionPredictor_test.cpp index 4ac7ae920e..33431146ea 100644 --- a/libs/input/tests/MotionPredictor_test.cpp +++ b/libs/input/tests/MotionPredictor_test.cpp @@ -147,4 +147,35 @@ TEST(MotionPredictorTest, FlagDisablesPrediction) { ASSERT_FALSE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_TOUCHSCREEN)); } +using AtomFields = MotionPredictorMetricsManager::AtomFields; +using ReportAtomFunction = MotionPredictorMetricsManager::ReportAtomFunction; + +// Creates a mock atom reporting function that appends the reported atom to the given vector. +// The passed-in pointer must not be nullptr. +ReportAtomFunction createMockReportAtomFunction(std::vector<AtomFields>* reportedAtomFields) { + return [reportedAtomFields](const AtomFields& atomFields) -> void { + reportedAtomFields->push_back(atomFields); + }; +} + +TEST(MotionPredictorMetricsManagerIntegrationTest, ReportsMetrics) { + std::vector<AtomFields> reportedAtomFields; + MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0, + []() { return true /*enable prediction*/; }, + createMockReportAtomFunction(&reportedAtomFields)); + + ASSERT_TRUE(predictor.record(getMotionEvent(DOWN, 1, 1, 0ms, /*deviceId=*/0)).ok()); + ASSERT_TRUE(predictor.record(getMotionEvent(MOVE, 2, 2, 4ms, /*deviceId=*/0)).ok()); + ASSERT_TRUE(predictor.record(getMotionEvent(MOVE, 3, 3, 8ms, /*deviceId=*/0)).ok()); + ASSERT_TRUE(predictor.record(getMotionEvent(MOVE, 4, 4, 12ms, /*deviceId=*/0)).ok()); + ASSERT_TRUE(predictor.record(getMotionEvent(MOVE, 5, 5, 16ms, /*deviceId=*/0)).ok()); + ASSERT_TRUE(predictor.record(getMotionEvent(MOVE, 6, 6, 20ms, /*deviceId=*/0)).ok()); + ASSERT_TRUE(predictor.record(getMotionEvent(UP, 7, 7, 24ms, /*deviceId=*/0)).ok()); + + // The number of atoms reported should equal the number of prediction time buckets, which is + // given by the prediction model's output length. For now, this value is always 5, and we + // hardcode it because it's not publicly accessible from the MotionPredictor. + EXPECT_EQ(5u, reportedAtomFields.size()); +} + } // namespace android |