diff options
author | 2024-11-11 21:49:51 +0000 | |
---|---|---|
committer | 2024-11-20 23:57:51 +0000 | |
commit | 53463397c3b84bb71d7fe155a8ada86fdcc0e96b (patch) | |
tree | f56a833a21b8347ba608befaeb416c109917d2bd | |
parent | 08ee19997d0ad4fab38465ef878b666c9fffb203 (diff) |
Fix One Euro filter's units of computation
Changed the units that the One Euro filter uses to compute the filtered
coordinates. This was causing a crash because if two timestamps were
sufficiently close to each other, by the time of implicitly converting
from nanoseconds to seconds, they were considered equal. This led to a
zero division when calculating the sampling frequency. Now, everything
is handled in the scale of nanoseconds, and conversion are done if and
only if they're necessary.
Bug: 297226446
Flag: EXEMPT bugfix
Test: TEST=libinput_tests; m $TEST && $ANDROID_HOST_OUT/nativetest64/$TEST/$TEST
Change-Id: I7fced6db447074cccb3d938eb9dc7a9707433f53
-rw-r--r-- | include/input/CoordinateFilter.h | 2 | ||||
-rw-r--r-- | include/input/OneEuroFilter.h | 10 | ||||
-rw-r--r-- | libs/input/CoordinateFilter.cpp | 2 | ||||
-rw-r--r-- | libs/input/OneEuroFilter.cpp | 34 | ||||
-rw-r--r-- | libs/input/tests/Android.bp | 1 | ||||
-rw-r--r-- | libs/input/tests/InputConsumerFilteredResampling_test.cpp | 218 | ||||
-rw-r--r-- | libs/input/tests/OneEuroFilter_test.cpp | 5 | ||||
-rw-r--r-- | libs/input/tests/TestEventMatchers.h | 9 |
8 files changed, 258 insertions, 23 deletions
diff --git a/include/input/CoordinateFilter.h b/include/input/CoordinateFilter.h index f36472dc8c..8f2e605e85 100644 --- a/include/input/CoordinateFilter.h +++ b/include/input/CoordinateFilter.h @@ -44,7 +44,7 @@ public: * the previous call. * @param coords Coordinates to be overwritten by the corresponding filtered coordinates. */ - void filter(std::chrono::duration<float> timestamp, PointerCoords& coords); + void filter(std::chrono::nanoseconds timestamp, PointerCoords& coords); private: OneEuroFilter mXFilter; diff --git a/include/input/OneEuroFilter.h b/include/input/OneEuroFilter.h index a0168e4f91..bdd82b2ee8 100644 --- a/include/input/OneEuroFilter.h +++ b/include/input/OneEuroFilter.h @@ -56,7 +56,7 @@ public: * provided in the previous call. * @param rawPosition Position to be filtered. */ - float filter(std::chrono::duration<float> timestamp, float rawPosition); + float filter(std::chrono::nanoseconds timestamp, float rawPosition); private: /** @@ -67,7 +67,7 @@ private: /** * Slope of the cutoff frequency criterion. This is the term scaling the absolute value of the - * filtered signal's speed. The data member is dimensionless, that is, it does not have units. + * filtered signal's speed. Units are 1 / position. */ const float mBeta; @@ -78,9 +78,9 @@ private: const float mSpeedCutoffFreq; /** - * The timestamp from the previous call. Units are seconds. + * The timestamp from the previous call. */ - std::optional<std::chrono::duration<float>> mPrevTimestamp; + std::optional<std::chrono::nanoseconds> mPrevTimestamp; /** * The raw position from the previous call. @@ -88,7 +88,7 @@ private: std::optional<float> mPrevRawPosition; /** - * The filtered velocity from the previous call. Units are position per second. + * The filtered velocity from the previous call. Units are position per nanosecond. */ std::optional<float> mPrevFilteredVelocity; diff --git a/libs/input/CoordinateFilter.cpp b/libs/input/CoordinateFilter.cpp index d231474577..a32685bd53 100644 --- a/libs/input/CoordinateFilter.cpp +++ b/libs/input/CoordinateFilter.cpp @@ -23,7 +23,7 @@ namespace android { CoordinateFilter::CoordinateFilter(float minCutoffFreq, float beta) : mXFilter{minCutoffFreq, beta}, mYFilter{minCutoffFreq, beta} {} -void CoordinateFilter::filter(std::chrono::duration<float> timestamp, PointerCoords& coords) { +void CoordinateFilter::filter(std::chrono::nanoseconds timestamp, PointerCoords& coords) { coords.setAxisValue(AMOTION_EVENT_AXIS_X, mXFilter.filter(timestamp, coords.getX())); coords.setAxisValue(AMOTION_EVENT_AXIS_Y, mYFilter.filter(timestamp, coords.getY())); } diff --git a/libs/input/OneEuroFilter.cpp b/libs/input/OneEuroFilter.cpp index 400d7c9ab0..7b0d104da1 100644 --- a/libs/input/OneEuroFilter.cpp +++ b/libs/input/OneEuroFilter.cpp @@ -25,16 +25,24 @@ namespace android { namespace { +using namespace std::literals::chrono_literals; + +const float kHertzPerGigahertz = 1E9f; +const float kGigahertzPerHertz = 1E-9f; + +// filteredSpeed's units are position per nanosecond. beta's units are 1 / position. inline float cutoffFreq(float minCutoffFreq, float beta, float filteredSpeed) { - return minCutoffFreq + beta * std::abs(filteredSpeed); + return kHertzPerGigahertz * + ((minCutoffFreq * kGigahertzPerHertz) + beta * std::abs(filteredSpeed)); } -inline float smoothingFactor(std::chrono::duration<float> samplingPeriod, float cutoffFreq) { - return samplingPeriod.count() / (samplingPeriod.count() + (1.0 / (2.0 * M_PI * cutoffFreq))); +inline float smoothingFactor(std::chrono::nanoseconds samplingPeriod, float cutoffFreq) { + const float constant = 2.0f * M_PI * samplingPeriod.count() * (cutoffFreq * kGigahertzPerHertz); + return constant / (constant + 1); } -inline float lowPassFilter(float rawPosition, float prevFilteredPosition, float smoothingFactor) { - return smoothingFactor * rawPosition + (1 - smoothingFactor) * prevFilteredPosition; +inline float lowPassFilter(float rawValue, float prevFilteredValue, float smoothingFactor) { + return smoothingFactor * rawValue + (1 - smoothingFactor) * prevFilteredValue; } } // namespace @@ -42,17 +50,17 @@ inline float lowPassFilter(float rawPosition, float prevFilteredPosition, float OneEuroFilter::OneEuroFilter(float minCutoffFreq, float beta, float speedCutoffFreq) : mMinCutoffFreq{minCutoffFreq}, mBeta{beta}, mSpeedCutoffFreq{speedCutoffFreq} {} -float OneEuroFilter::filter(std::chrono::duration<float> timestamp, float rawPosition) { - LOG_IF(FATAL, mPrevFilteredPosition.has_value() && (timestamp <= *mPrevTimestamp)) - << "Timestamp must be greater than mPrevTimestamp"; +float OneEuroFilter::filter(std::chrono::nanoseconds timestamp, float rawPosition) { + LOG_IF(FATAL, mPrevTimestamp.has_value() && (*mPrevTimestamp >= timestamp)) + << "Timestamp must be greater than mPrevTimestamp. Timestamp: " << timestamp.count() + << "ns. mPrevTimestamp: " << mPrevTimestamp->count() << "ns"; - const std::chrono::duration<float> samplingPeriod = (mPrevTimestamp.has_value()) - ? (timestamp - *mPrevTimestamp) - : std::chrono::duration<float>{1.0}; + const std::chrono::nanoseconds samplingPeriod = + (mPrevTimestamp.has_value()) ? (timestamp - *mPrevTimestamp) : 1s; const float rawVelocity = (mPrevFilteredPosition.has_value()) - ? ((rawPosition - *mPrevFilteredPosition) / samplingPeriod.count()) - : 0.0; + ? ((rawPosition - *mPrevFilteredPosition) / (samplingPeriod.count())) + : 0.0f; const float speedSmoothingFactor = smoothingFactor(samplingPeriod, mSpeedCutoffFreq); diff --git a/libs/input/tests/Android.bp b/libs/input/tests/Android.bp index 46e819061f..d1c564d020 100644 --- a/libs/input/tests/Android.bp +++ b/libs/input/tests/Android.bp @@ -17,6 +17,7 @@ cc_test { "IdGenerator_test.cpp", "InputChannel_test.cpp", "InputConsumer_test.cpp", + "InputConsumerFilteredResampling_test.cpp", "InputConsumerResampling_test.cpp", "InputDevice_test.cpp", "InputEvent_test.cpp", diff --git a/libs/input/tests/InputConsumerFilteredResampling_test.cpp b/libs/input/tests/InputConsumerFilteredResampling_test.cpp new file mode 100644 index 0000000000..757cd18a38 --- /dev/null +++ b/libs/input/tests/InputConsumerFilteredResampling_test.cpp @@ -0,0 +1,218 @@ +/** + * Copyright 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <input/InputConsumerNoResampling.h> + +#include <chrono> +#include <iostream> +#include <memory> +#include <queue> + +#include <TestEventMatchers.h> +#include <TestInputChannel.h> +#include <android-base/logging.h> +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include <input/Input.h> +#include <input/InputEventBuilders.h> +#include <input/Resampler.h> +#include <utils/Looper.h> +#include <utils/StrongPointer.h> + +namespace android { +namespace { + +using std::chrono::nanoseconds; + +using ::testing::AllOf; +using ::testing::Matcher; + +const int32_t ACTION_DOWN = AMOTION_EVENT_ACTION_DOWN; +const int32_t ACTION_MOVE = AMOTION_EVENT_ACTION_MOVE; + +struct Pointer { + int32_t id{0}; + ToolType toolType{ToolType::FINGER}; + float x{0.0f}; + float y{0.0f}; + bool isResampled{false}; + + PointerBuilder asPointerBuilder() const { + return PointerBuilder{id, toolType}.x(x).y(y).isResampled(isResampled); + } +}; + +} // namespace + +class InputConsumerFilteredResamplingTest : public ::testing::Test, public InputConsumerCallbacks { +protected: + InputConsumerFilteredResamplingTest() + : mClientTestChannel{std::make_shared<TestInputChannel>("TestChannel")}, + mLooper{sp<Looper>::make(/*allowNonCallbacks=*/false)} { + Looper::setForThread(mLooper); + mConsumer = std::make_unique< + InputConsumerNoResampling>(mClientTestChannel, mLooper, *this, []() { + return std::make_unique<FilteredLegacyResampler>(/*minCutoffFreq=*/4.7, /*beta=*/0.01); + }); + } + + void invokeLooperCallback() const { + sp<LooperCallback> callback; + ASSERT_TRUE(mLooper->getFdStateDebug(mClientTestChannel->getFd(), /*ident=*/nullptr, + /*events=*/nullptr, &callback, /*data=*/nullptr)); + ASSERT_NE(callback, nullptr); + callback->handleEvent(mClientTestChannel->getFd(), ALOOPER_EVENT_INPUT, /*data=*/nullptr); + } + + void assertOnBatchedInputEventPendingWasCalled() { + ASSERT_GT(mOnBatchedInputEventPendingInvocationCount, 0UL) + << "onBatchedInputEventPending was not called"; + --mOnBatchedInputEventPendingInvocationCount; + } + + void assertReceivedMotionEvent(const Matcher<MotionEvent>& matcher) { + ASSERT_TRUE(!mMotionEvents.empty()) << "No motion events were received"; + std::unique_ptr<MotionEvent> motionEvent = std::move(mMotionEvents.front()); + mMotionEvents.pop(); + ASSERT_NE(motionEvent, nullptr) << "The consumed motion event must not be nullptr"; + EXPECT_THAT(*motionEvent, matcher); + } + + InputMessage nextPointerMessage(nanoseconds eventTime, int32_t action, const Pointer& pointer); + + std::shared_ptr<TestInputChannel> mClientTestChannel; + sp<Looper> mLooper; + std::unique_ptr<InputConsumerNoResampling> mConsumer; + + // Batched input events + std::queue<std::unique_ptr<KeyEvent>> mKeyEvents; + std::queue<std::unique_ptr<MotionEvent>> mMotionEvents; + std::queue<std::unique_ptr<FocusEvent>> mFocusEvents; + std::queue<std::unique_ptr<CaptureEvent>> mCaptureEvents; + std::queue<std::unique_ptr<DragEvent>> mDragEvents; + std::queue<std::unique_ptr<TouchModeEvent>> mTouchModeEvents; + +private: + // InputConsumer callbacks + void onKeyEvent(std::unique_ptr<KeyEvent> event, uint32_t seq) override { + mKeyEvents.push(std::move(event)); + mConsumer->finishInputEvent(seq, /*handled=*/true); + } + + void onMotionEvent(std::unique_ptr<MotionEvent> event, uint32_t seq) override { + mMotionEvents.push(std::move(event)); + mConsumer->finishInputEvent(seq, /*handled=*/true); + } + + void onBatchedInputEventPending(int32_t pendingBatchSource) override { + if (!mConsumer->probablyHasInput()) { + ADD_FAILURE() << "Should deterministically have input because there is a batch"; + } + ++mOnBatchedInputEventPendingInvocationCount; + } + + void onFocusEvent(std::unique_ptr<FocusEvent> event, uint32_t seq) override { + mFocusEvents.push(std::move(event)); + mConsumer->finishInputEvent(seq, /*handled=*/true); + } + + void onCaptureEvent(std::unique_ptr<CaptureEvent> event, uint32_t seq) override { + mCaptureEvents.push(std::move(event)); + mConsumer->finishInputEvent(seq, /*handled=*/true); + } + + void onDragEvent(std::unique_ptr<DragEvent> event, uint32_t seq) override { + mDragEvents.push(std::move(event)); + mConsumer->finishInputEvent(seq, /*handled=*/true); + } + + void onTouchModeEvent(std::unique_ptr<TouchModeEvent> event, uint32_t seq) override { + mTouchModeEvents.push(std::move(event)); + mConsumer->finishInputEvent(seq, /*handled=*/true); + } + + uint32_t mLastSeq{0}; + size_t mOnBatchedInputEventPendingInvocationCount{0}; +}; + +InputMessage InputConsumerFilteredResamplingTest::nextPointerMessage(nanoseconds eventTime, + int32_t action, + const Pointer& pointer) { + ++mLastSeq; + return InputMessageBuilder{InputMessage::Type::MOTION, mLastSeq} + .eventTime(eventTime.count()) + .source(AINPUT_SOURCE_TOUCHSCREEN) + .action(action) + .pointer(pointer.asPointerBuilder()) + .build(); +} + +TEST_F(InputConsumerFilteredResamplingTest, NeighboringTimestampsDoNotResultInZeroDivision) { + mClientTestChannel->enqueueMessage( + nextPointerMessage(0ms, ACTION_DOWN, Pointer{.x = 0.0f, .y = 0.0f})); + + invokeLooperCallback(); + + assertReceivedMotionEvent(AllOf(WithMotionAction(ACTION_DOWN), WithSampleCount(1))); + + const std::chrono::nanoseconds initialTime{56'821'700'000'000}; + + mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 4'929'000ns, ACTION_MOVE, + Pointer{.x = 1.0f, .y = 1.0f})); + mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 9'352'000ns, ACTION_MOVE, + Pointer{.x = 2.0f, .y = 2.0f})); + mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 14'531'000ns, ACTION_MOVE, + Pointer{.x = 3.0f, .y = 3.0f})); + + invokeLooperCallback(); + mConsumer->consumeBatchedInputEvents(initialTime.count() + 18'849'395 /*ns*/); + + assertOnBatchedInputEventPendingWasCalled(); + // Three samples are expected. The first two of the batch, and the resampled one. The + // coordinates of the resampled sample are hardcoded because the matcher requires them. However, + // the primary intention here is to check that the last sample is resampled. + assertReceivedMotionEvent(AllOf(WithMotionAction(ACTION_MOVE), WithSampleCount(3), + WithSample(/*sampleIndex=*/2, + Sample{initialTime + 13'849'395ns, + {PointerArgs{.x = 1.3286f, + .y = 1.3286f, + .isResampled = true}}}))); + + mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 20'363'000ns, ACTION_MOVE, + Pointer{.x = 4.0f, .y = 4.0f})); + mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 25'745'000ns, ACTION_MOVE, + Pointer{.x = 5.0f, .y = 5.0f})); + // This sample is part of the stream of messages, but should not be consumed because its + // timestamp is greater than the ajusted frame time. + mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 31'337'000ns, ACTION_MOVE, + Pointer{.x = 6.0f, .y = 6.0f})); + + invokeLooperCallback(); + mConsumer->consumeBatchedInputEvents(initialTime.count() + 35'516'062 /*ns*/); + + assertOnBatchedInputEventPendingWasCalled(); + // Four samples are expected because the last sample of the previous batch was not consumed. + assertReceivedMotionEvent(AllOf(WithMotionAction(ACTION_MOVE), WithSampleCount(4))); + + mClientTestChannel->assertFinishMessage(/*seq=*/1, /*handled=*/true); + mClientTestChannel->assertFinishMessage(/*seq=*/2, /*handled=*/true); + mClientTestChannel->assertFinishMessage(/*seq=*/3, /*handled=*/true); + mClientTestChannel->assertFinishMessage(/*seq=*/4, /*handled=*/true); + mClientTestChannel->assertFinishMessage(/*seq=*/5, /*handled=*/true); + mClientTestChannel->assertFinishMessage(/*seq=*/6, /*handled=*/true); +} + +} // namespace android diff --git a/libs/input/tests/OneEuroFilter_test.cpp b/libs/input/tests/OneEuroFilter_test.cpp index 270e789c84..8645508ea7 100644 --- a/libs/input/tests/OneEuroFilter_test.cpp +++ b/libs/input/tests/OneEuroFilter_test.cpp @@ -98,7 +98,10 @@ protected: std::vector<Sample> filteredSignal; for (const Sample& sample : signal) { filteredSignal.push_back( - Sample{sample.timestamp, mFilter.filter(sample.timestamp, sample.value)}); + Sample{sample.timestamp, + mFilter.filter(std::chrono::duration_cast<std::chrono::nanoseconds>( + sample.timestamp), + sample.value)}); } return filteredSignal; } diff --git a/libs/input/tests/TestEventMatchers.h b/libs/input/tests/TestEventMatchers.h index 3589de599f..de96600f66 100644 --- a/libs/input/tests/TestEventMatchers.h +++ b/libs/input/tests/TestEventMatchers.h @@ -17,6 +17,7 @@ #pragma once #include <chrono> +#include <cmath> #include <ostream> #include <vector> @@ -150,14 +151,18 @@ public: ++pointerIndex) { const PointerCoords& pointerCoords = *(motionEvent.getHistoricalRawPointerCoords(pointerIndex, mSampleIndex)); - if ((pointerCoords.getX() != mSample.pointers[pointerIndex].x) || - (pointerCoords.getY() != mSample.pointers[pointerIndex].y)) { + + if ((std::abs(pointerCoords.getX() - mSample.pointers[pointerIndex].x) > + MotionEvent::ROUNDING_PRECISION) || + (std::abs(pointerCoords.getY() - mSample.pointers[pointerIndex].y) > + MotionEvent::ROUNDING_PRECISION)) { *os << "sample coordinates mismatch at pointer index " << pointerIndex << ". sample: (" << pointerCoords.getX() << ", " << pointerCoords.getY() << ") expected: (" << mSample.pointers[pointerIndex].x << ", " << mSample.pointers[pointerIndex].y << ")"; return false; } + if (motionEvent.isResampled(pointerIndex, mSampleIndex) != mSample.pointers[pointerIndex].isResampled) { *os << "resampling flag mismatch. sample: " |