diff options
14 files changed, 1131 insertions, 29 deletions
| diff --git a/services/inputflinger/InputFilter.cpp b/services/inputflinger/InputFilter.cpp index 72c6f1a73b..1ada5e5678 100644 --- a/services/inputflinger/InputFilter.cpp +++ b/services/inputflinger/InputFilter.cpp @@ -118,6 +118,15 @@ void InputFilter::setAccessibilityBounceKeysThreshold(nsecs_t threshold) {      }  } +void InputFilter::setAccessibilitySlowKeysThreshold(nsecs_t threshold) { +    std::scoped_lock _l(mLock); + +    if (mConfig.slowKeysThresholdNs != threshold) { +        mConfig.slowKeysThresholdNs = threshold; +        notifyConfigurationChangedLocked(); +    } +} +  void InputFilter::setAccessibilityStickyKeysEnabled(bool enabled) {      std::scoped_lock _l(mLock); diff --git a/services/inputflinger/InputFilter.h b/services/inputflinger/InputFilter.h index 153d29dd53..4ddc9f4f6b 100644 --- a/services/inputflinger/InputFilter.h +++ b/services/inputflinger/InputFilter.h @@ -35,6 +35,7 @@ public:       */      virtual void dump(std::string& dump) = 0;      virtual void setAccessibilityBounceKeysThreshold(nsecs_t threshold) = 0; +    virtual void setAccessibilitySlowKeysThreshold(nsecs_t threshold) = 0;      virtual void setAccessibilityStickyKeysEnabled(bool enabled) = 0;  }; @@ -61,6 +62,7 @@ public:      void notifyDeviceReset(const NotifyDeviceResetArgs& args) override;      void notifyPointerCaptureChanged(const NotifyPointerCaptureChangedArgs& args) override;      void setAccessibilityBounceKeysThreshold(nsecs_t threshold) override; +    void setAccessibilitySlowKeysThreshold(nsecs_t threshold) override;      void setAccessibilityStickyKeysEnabled(bool enabled) override;      void dump(std::string& dump) override; diff --git a/services/inputflinger/InputFilterCallbacks.cpp b/services/inputflinger/InputFilterCallbacks.cpp index a8759b7cbd..6c3144230f 100644 --- a/services/inputflinger/InputFilterCallbacks.cpp +++ b/services/inputflinger/InputFilterCallbacks.cpp @@ -17,6 +17,11 @@  #define LOG_TAG "InputFilterCallbacks"  #include "InputFilterCallbacks.h" +#include <aidl/com/android/server/inputflinger/BnInputThread.h> +#include <android/binder_auto_utils.h> +#include <utils/StrongPointer.h> +#include <utils/Thread.h> +#include <functional>  namespace android { @@ -29,6 +34,47 @@ NotifyKeyArgs keyEventToNotifyKeyArgs(const AidlKeyEvent& event) {                           event.scanCode, event.metaState, event.downTime);  } +namespace { + +using namespace aidl::com::android::server::inputflinger; + +class InputFilterThreadImpl : public Thread { +public: +    explicit InputFilterThreadImpl(std::function<void()> loop) +          : Thread(/*canCallJava=*/true), mThreadLoop(loop) {} + +    ~InputFilterThreadImpl() {} + +private: +    std::function<void()> mThreadLoop; + +    bool threadLoop() override { +        mThreadLoop(); +        return true; +    } +}; + +class InputFilterThread : public BnInputThread { +public: +    InputFilterThread(std::shared_ptr<IInputThreadCallback> callback) : mCallback(callback) { +        mThread = sp<InputFilterThreadImpl>::make([this]() { loopOnce(); }); +        mThread->run("InputFilterThread", ANDROID_PRIORITY_URGENT_DISPLAY); +    } + +    ndk::ScopedAStatus finish() override { +        mThread->requestExit(); +        return ndk::ScopedAStatus::ok(); +    } + +private: +    sp<Thread> mThread; +    std::shared_ptr<IInputThreadCallback> mCallback; + +    void loopOnce() { LOG_ALWAYS_FATAL_IF(!mCallback->loopOnce().isOk()); } +}; + +} // namespace +  InputFilterCallbacks::InputFilterCallbacks(InputListenerInterface& listener,                                             InputFilterPolicyInterface& policy)        : mNextListener(listener), mPolicy(policy) {} @@ -49,6 +95,13 @@ ndk::ScopedAStatus InputFilterCallbacks::onModifierStateChanged(int32_t modifier      return ndk::ScopedAStatus::ok();  } +ndk::ScopedAStatus InputFilterCallbacks::createInputFilterThread( +        const std::shared_ptr<IInputThreadCallback>& callback, +        std::shared_ptr<IInputThread>* aidl_return) { +    *aidl_return = ndk::SharedRefBase::make<InputFilterThread>(callback); +    return ndk::ScopedAStatus::ok(); +} +  uint32_t InputFilterCallbacks::getModifierState() {      std::scoped_lock _l(mLock);      return mStickyModifierState.modifierState; diff --git a/services/inputflinger/InputFilterCallbacks.h b/services/inputflinger/InputFilterCallbacks.h index 31c160aeb9..a74955b5c6 100644 --- a/services/inputflinger/InputFilterCallbacks.h +++ b/services/inputflinger/InputFilterCallbacks.h @@ -19,6 +19,7 @@  #include <aidl/com/android/server/inputflinger/IInputFlingerRust.h>  #include <android/binder_auto_utils.h>  #include <utils/Mutex.h> +#include <memory>  #include <mutex>  #include "InputFilterPolicyInterface.h"  #include "InputListener.h" @@ -31,6 +32,9 @@ namespace android {  using IInputFilter = aidl::com::android::server::inputflinger::IInputFilter;  using AidlKeyEvent = aidl::com::android::server::inputflinger::KeyEvent; +using aidl::com::android::server::inputflinger::IInputThread; +using IInputThreadCallback = +        aidl::com::android::server::inputflinger::IInputThread::IInputThreadCallback;  class InputFilterCallbacks : public IInputFilter::BnInputFilterCallbacks {  public: @@ -53,6 +57,9 @@ private:      ndk::ScopedAStatus sendKeyEvent(const AidlKeyEvent& event) override;      ndk::ScopedAStatus onModifierStateChanged(int32_t modifierState,                                                int32_t lockedModifierState) override; +    ndk::ScopedAStatus createInputFilterThread( +            const std::shared_ptr<IInputThreadCallback>& callback, +            std::shared_ptr<IInputThread>* aidl_return) override;  };  } // namespace android
\ No newline at end of file diff --git a/services/inputflinger/aidl/com/android/server/inputflinger/IInputFilter.aidl b/services/inputflinger/aidl/com/android/server/inputflinger/IInputFilter.aidl index 2921d30b22..994d1c4b1a 100644 --- a/services/inputflinger/aidl/com/android/server/inputflinger/IInputFilter.aidl +++ b/services/inputflinger/aidl/com/android/server/inputflinger/IInputFilter.aidl @@ -17,6 +17,8 @@  package com.android.server.inputflinger;  import com.android.server.inputflinger.DeviceInfo; +import com.android.server.inputflinger.IInputThread; +import com.android.server.inputflinger.IInputThread.IInputThreadCallback;  import com.android.server.inputflinger.InputFilterConfiguration;  import com.android.server.inputflinger.KeyEvent; @@ -36,6 +38,9 @@ interface IInputFilter {          /** Sends back modifier state */          void onModifierStateChanged(int modifierState, int lockedModifierState); + +        /** Creates an Input filter thread */ +        IInputThread createInputFilterThread(in IInputThreadCallback callback);      }      /** Returns if InputFilter is enabled */ diff --git a/services/inputflinger/aidl/com/android/server/inputflinger/IInputThread.aidl b/services/inputflinger/aidl/com/android/server/inputflinger/IInputThread.aidl new file mode 100644 index 0000000000..2f6b8fc6ff --- /dev/null +++ b/services/inputflinger/aidl/com/android/server/inputflinger/IInputThread.aidl @@ -0,0 +1,45 @@ +/* + * Copyright 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + *      http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.server.inputflinger; + +/** Interface to handle and run things on an InputThread +  * Exposes main functionality of InputThread.h to rust which internally used system/core/libutils +  * infrastructure. +  * +  * <p> +  * NOTE: Tried using rust provided threading infrastructure but that uses std::thread which doesn't +  * have JNI support and can't call into Java policy that we use currently. libutils provided +  * Thread.h also recommends against using std::thread and using the provided infrastructure that +  * already provides way of attaching JniEnv to the created thread. So, we are using this interface +  * to expose the InputThread infrastructure to rust. +  * </p> +  * TODO(b/321769871): Implement the threading infrastructure with JniEnv support in rust +  */ +interface IInputThread { +    /** Finish input thread (if not running, this call does nothing) */ +    void finish(); + +    /** Callbacks from C++ to call into inputflinger rust components */ +    interface IInputThreadCallback { +        /** +          * The created thread will keep looping and calling this function. +          * It's the responsibility of RUST component to appropriately put the thread to sleep and +          * wake according to the use case. +          */ +        void loopOnce(); +    } +}
\ No newline at end of file diff --git a/services/inputflinger/aidl/com/android/server/inputflinger/InputFilterConfiguration.aidl b/services/inputflinger/aidl/com/android/server/inputflinger/InputFilterConfiguration.aidl index 38b161203b..9984a6a9ea 100644 --- a/services/inputflinger/aidl/com/android/server/inputflinger/InputFilterConfiguration.aidl +++ b/services/inputflinger/aidl/com/android/server/inputflinger/InputFilterConfiguration.aidl @@ -22,6 +22,8 @@ package com.android.server.inputflinger;  parcelable InputFilterConfiguration {      // Threshold value for Bounce keys filter (check bounce_keys_filter.rs)      long bounceKeysThresholdNs; -    // If sticky keys filter is enabled +    // If sticky keys filter is enabled (check sticky_keys_filter.rs)      boolean stickyKeysEnabled; +    // Threshold value for Slow keys filter (check slow_keys_filter.rs) +    long slowKeysThresholdNs;  }
\ No newline at end of file diff --git a/services/inputflinger/rust/Android.bp b/services/inputflinger/rust/Android.bp index 2803805619..9e6dbe432d 100644 --- a/services/inputflinger/rust/Android.bp +++ b/services/inputflinger/rust/Android.bp @@ -42,6 +42,7 @@ rust_defaults {          "libbinder_rs",          "liblog_rust",          "liblogger", +        "libnix",      ],      host_supported: true,  } diff --git a/services/inputflinger/rust/bounce_keys_filter.rs b/services/inputflinger/rust/bounce_keys_filter.rs index 894b881638..2d5039a1b1 100644 --- a/services/inputflinger/rust/bounce_keys_filter.rs +++ b/services/inputflinger/rust/bounce_keys_filter.rs @@ -118,6 +118,10 @@ impl Filter for BounceKeysFilter {          }          self.next.notify_devices_changed(device_infos);      } + +    fn destroy(&mut self) { +        self.next.destroy(); +    }  }  #[cfg(test)] diff --git a/services/inputflinger/rust/input_filter.rs b/services/inputflinger/rust/input_filter.rs index e94a71fbf8..a544fa36ae 100644 --- a/services/inputflinger/rust/input_filter.rs +++ b/services/inputflinger/rust/input_filter.rs @@ -22,11 +22,14 @@ use binder::{Interface, Strong};  use com_android_server_inputflinger::aidl::com::android::server::inputflinger::{      DeviceInfo::DeviceInfo,      IInputFilter::{IInputFilter, IInputFilterCallbacks::IInputFilterCallbacks}, +    IInputThread::{IInputThread, IInputThreadCallback::IInputThreadCallback},      InputFilterConfiguration::InputFilterConfiguration,      KeyEvent::KeyEvent,  };  use crate::bounce_keys_filter::BounceKeysFilter; +use crate::input_filter_thread::InputFilterThread; +use crate::slow_keys_filter::SlowKeysFilter;  use crate::sticky_keys_filter::StickyKeysFilter;  use log::{error, info};  use std::sync::{Arc, Mutex, RwLock}; @@ -35,6 +38,7 @@ use std::sync::{Arc, Mutex, RwLock};  pub trait Filter {      fn notify_key(&mut self, event: &KeyEvent);      fn notify_devices_changed(&mut self, device_infos: &[DeviceInfo]); +    fn destroy(&mut self);  }  struct InputFilterState { @@ -50,6 +54,7 @@ pub struct InputFilter {      // Access to mutable references to mutable state (includes access to filters, enabled, etc.) is      // guarded by Mutex for thread safety      state: Mutex<InputFilterState>, +    input_filter_thread: InputFilterThread,  }  impl Interface for InputFilter {} @@ -67,7 +72,11 @@ impl InputFilter {          first_filter: Box<dyn Filter + Send + Sync>,          callbacks: Arc<RwLock<Strong<dyn IInputFilterCallbacks>>>,      ) -> InputFilter { -        Self { callbacks, state: Mutex::new(InputFilterState { first_filter, enabled: false }) } +        Self { +            callbacks: callbacks.clone(), +            state: Mutex::new(InputFilterState { first_filter, enabled: false }), +            input_filter_thread: InputFilterThread::new(InputFilterThreadCreator::new(callbacks)), +        }      }  } @@ -89,24 +98,36 @@ impl IInputFilter for InputFilter {      }      fn notifyConfigurationChanged(&self, config: &InputFilterConfiguration) -> binder::Result<()> { -        let mut state = self.state.lock().unwrap(); -        let mut first_filter: Box<dyn Filter + Send + Sync> = -            Box::new(BaseFilter::new(self.callbacks.clone())); -        if config.stickyKeysEnabled { -            first_filter = Box::new(StickyKeysFilter::new( -                first_filter, -                ModifierStateListener::new(self.callbacks.clone()), -            )); -            state.enabled = true; -            info!("Sticky keys filter is installed"); -        } -        if config.bounceKeysThresholdNs > 0 { -            first_filter = -                Box::new(BounceKeysFilter::new(first_filter, config.bounceKeysThresholdNs)); -            state.enabled = true; -            info!("Bounce keys filter is installed"); +        { +            let mut state = self.state.lock().unwrap(); +            state.first_filter.destroy(); +            let mut first_filter: Box<dyn Filter + Send + Sync> = +                Box::new(BaseFilter::new(self.callbacks.clone())); +            if config.stickyKeysEnabled { +                first_filter = Box::new(StickyKeysFilter::new( +                    first_filter, +                    ModifierStateListener::new(self.callbacks.clone()), +                )); +                state.enabled = true; +                info!("Sticky keys filter is installed"); +            } +            if config.slowKeysThresholdNs > 0 { +                first_filter = Box::new(SlowKeysFilter::new( +                    first_filter, +                    config.slowKeysThresholdNs, +                    self.input_filter_thread.clone(), +                )); +                state.enabled = true; +                info!("Slow keys filter is installed"); +            } +            if config.bounceKeysThresholdNs > 0 { +                first_filter = +                    Box::new(BounceKeysFilter::new(first_filter, config.bounceKeysThresholdNs)); +                state.enabled = true; +                info!("Bounce keys filter is installed"); +            } +            state.first_filter = first_filter;          } -        state.first_filter = first_filter;          Result::Ok(())      }  } @@ -132,27 +153,51 @@ impl Filter for BaseFilter {      fn notify_devices_changed(&mut self, _device_infos: &[DeviceInfo]) {          // do nothing      } -} -pub struct ModifierStateListener { -    callbacks: Arc<RwLock<Strong<dyn IInputFilterCallbacks>>>, +    fn destroy(&mut self) { +        // do nothing +    }  } +/// This struct wraps around IInputFilterCallbacks restricting access to only +/// {@code onModifierStateChanged()} method of the callback. +#[derive(Clone)] +pub struct ModifierStateListener(Arc<RwLock<Strong<dyn IInputFilterCallbacks>>>); +  impl ModifierStateListener { -    /// Create a new InputFilter instance.      pub fn new(callbacks: Arc<RwLock<Strong<dyn IInputFilterCallbacks>>>) -> ModifierStateListener { -        Self { callbacks } +        Self(callbacks)      }      pub fn modifier_state_changed(&self, modifier_state: u32, locked_modifier_state: u32) {          let _ = self -            .callbacks +            .0              .read()              .unwrap()              .onModifierStateChanged(modifier_state as i32, locked_modifier_state as i32);      }  } +/// This struct wraps around IInputFilterCallbacks restricting access to only +/// {@code createInputFilterThread()} method of the callback. +#[derive(Clone)] +pub struct InputFilterThreadCreator(Arc<RwLock<Strong<dyn IInputFilterCallbacks>>>); + +impl InputFilterThreadCreator { +    pub fn new( +        callbacks: Arc<RwLock<Strong<dyn IInputFilterCallbacks>>>, +    ) -> InputFilterThreadCreator { +        Self(callbacks) +    } + +    pub fn create( +        &self, +        input_thread_callback: &Strong<dyn IInputThreadCallback>, +    ) -> Strong<dyn IInputThread> { +        self.0.read().unwrap().createInputFilterThread(input_thread_callback).unwrap() +    } +} +  #[cfg(test)]  mod tests {      use crate::input_filter::{ @@ -218,7 +263,7 @@ mod tests {          let input_filter = InputFilter::new(Strong::new(Box::new(test_callbacks)));          let result = input_filter.notifyConfigurationChanged(&InputFilterConfiguration {              bounceKeysThresholdNs: 100, -            stickyKeysEnabled: false, +            ..Default::default()          });          assert!(result.is_ok());          let result = input_filter.isEnabled(); @@ -231,8 +276,8 @@ mod tests {          let test_callbacks = TestCallbacks::new();          let input_filter = InputFilter::new(Strong::new(Box::new(test_callbacks)));          let result = input_filter.notifyConfigurationChanged(&InputFilterConfiguration { -            bounceKeysThresholdNs: 0,              stickyKeysEnabled: true, +            ..Default::default()          });          assert!(result.is_ok());          let result = input_filter.isEnabled(); @@ -240,6 +285,33 @@ mod tests {          assert!(result.unwrap());      } +    #[test] +    fn test_notify_configuration_changed_enabled_slow_keys() { +        let test_callbacks = TestCallbacks::new(); +        let input_filter = InputFilter::new(Strong::new(Box::new(test_callbacks))); +        let result = input_filter.notifyConfigurationChanged(&InputFilterConfiguration { +            slowKeysThresholdNs: 100, +            ..Default::default() +        }); +        assert!(result.is_ok()); +        let result = input_filter.isEnabled(); +        assert!(result.is_ok()); +        assert!(result.unwrap()); +    } + +    #[test] +    fn test_notify_configuration_changed_destroys_existing_filters() { +        let test_filter = TestFilter::new(); +        let test_callbacks = TestCallbacks::new(); +        let input_filter = InputFilter::create_input_filter( +            Box::new(test_filter.clone()), +            Arc::new(RwLock::new(Strong::new(Box::new(test_callbacks)))), +        ); +        let _ = input_filter +            .notifyConfigurationChanged(&InputFilterConfiguration { ..Default::default() }); +        assert!(test_filter.is_destroy_called()); +    } +      fn create_key_event() -> KeyEvent {          KeyEvent {              id: 1, @@ -271,6 +343,7 @@ pub mod test_filter {      struct TestFilterInner {          is_device_changed_called: bool,          last_event: Option<KeyEvent>, +        is_destroy_called: bool,      }      #[derive(Default, Clone)] @@ -296,6 +369,10 @@ pub mod test_filter {          pub fn is_device_changed_called(&self) -> bool {              self.0.read().unwrap().is_device_changed_called          } + +        pub fn is_destroy_called(&self) -> bool { +            self.0.read().unwrap().is_destroy_called +        }      }      impl Filter for TestFilter { @@ -305,14 +382,19 @@ pub mod test_filter {          fn notify_devices_changed(&mut self, _device_infos: &[DeviceInfo]) {              self.inner().is_device_changed_called = true;          } +        fn destroy(&mut self) { +            self.inner().is_destroy_called = true; +        }      }  }  #[cfg(test)]  pub mod test_callbacks { -    use binder::Interface; +    use binder::{BinderFeatures, Interface, Strong};      use com_android_server_inputflinger::aidl::com::android::server::inputflinger::{ -        IInputFilter::IInputFilterCallbacks::IInputFilterCallbacks, KeyEvent::KeyEvent, +        IInputFilter::IInputFilterCallbacks::IInputFilterCallbacks, +        IInputThread::{BnInputThread, IInputThread, IInputThreadCallback::IInputThreadCallback}, +        KeyEvent::KeyEvent,      };      use std::sync::{Arc, RwLock, RwLockWriteGuard}; @@ -321,6 +403,7 @@ pub mod test_callbacks {          last_modifier_state: u32,          last_locked_modifier_state: u32,          last_event: Option<KeyEvent>, +        test_thread: Option<TestThread>,      }      #[derive(Default, Clone)] @@ -354,6 +437,17 @@ pub mod test_callbacks {          pub fn get_last_locked_modifier_state(&self) -> u32 {              self.0.read().unwrap().last_locked_modifier_state          } + +        pub fn is_thread_created(&self) -> bool { +            self.0.read().unwrap().test_thread.is_some() +        } + +        pub fn is_thread_finished(&self) -> bool { +            if let Some(test_thread) = &self.0.read().unwrap().test_thread { +                return test_thread.is_finish_called(); +            } +            false +        }      }      impl IInputFilterCallbacks for TestCallbacks { @@ -371,5 +465,45 @@ pub mod test_callbacks {              self.inner().last_locked_modifier_state = locked_modifier_state as u32;              Result::Ok(())          } + +        fn createInputFilterThread( +            &self, +            _callback: &Strong<dyn IInputThreadCallback>, +        ) -> std::result::Result<Strong<dyn IInputThread>, binder::Status> { +            let test_thread = TestThread::new(); +            self.inner().test_thread = Some(test_thread.clone()); +            Result::Ok(BnInputThread::new_binder(test_thread, BinderFeatures::default())) +        } +    } + +    #[derive(Default)] +    struct TestThreadInner { +        is_finish_called: bool, +    } + +    #[derive(Default, Clone)] +    struct TestThread(Arc<RwLock<TestThreadInner>>); + +    impl Interface for TestThread {} + +    impl TestThread { +        pub fn new() -> Self { +            Default::default() +        } + +        fn inner(&self) -> RwLockWriteGuard<'_, TestThreadInner> { +            self.0.write().unwrap() +        } + +        pub fn is_finish_called(&self) -> bool { +            self.0.read().unwrap().is_finish_called +        } +    } + +    impl IInputThread for TestThread { +        fn finish(&self) -> binder::Result<()> { +            self.inner().is_finish_called = true; +            Result::Ok(()) +        }      }  } diff --git a/services/inputflinger/rust/input_filter_thread.rs b/services/inputflinger/rust/input_filter_thread.rs new file mode 100644 index 0000000000..2d503aee70 --- /dev/null +++ b/services/inputflinger/rust/input_filter_thread.rs @@ -0,0 +1,452 @@ +/* + * Copyright 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + *      http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! Input filter thread implementation in rust. +//! Using IInputFilter.aidl interface to create ever looping thread with JNI support, rest of +//! thread handling is done from rust side. +//! +//! NOTE: Tried using rust provided threading infrastructure but that uses std::thread which doesn't +//! have JNI support and can't call into Java policy that we use currently. libutils provided +//! Thread.h also recommends against using std::thread and using the provided infrastructure that +//! already provides way of attaching JniEnv to the created thread. So, we are using an AIDL +//! interface to expose the InputThread infrastructure to rust. + +use crate::input_filter::InputFilterThreadCreator; +use binder::{BinderFeatures, Interface, Strong}; +use com_android_server_inputflinger::aidl::com::android::server::inputflinger::IInputThread::{ +    IInputThread, IInputThreadCallback::BnInputThreadCallback, +    IInputThreadCallback::IInputThreadCallback, +}; +use log::{debug, error}; +use nix::{sys::time::TimeValLike, time::clock_gettime, time::ClockId}; +use std::sync::{Arc, RwLock, RwLockWriteGuard}; +use std::time::Duration; +use std::{thread, thread::Thread}; + +/// Interface to receive callback from Input filter thread +pub trait ThreadCallback { +    /// Calls back after the requested timeout expires. +    /// {@see InputFilterThread.request_timeout_at_time(...)} +    /// +    /// NOTE: In case of multiple requests, the timeout request which is earliest in time, will be +    /// fulfilled and notified to all the listeners. It's up to the listeners to re-request another +    /// timeout in the future. +    fn notify_timeout_expired(&self, when_nanos: i64); +    /// Unique name for the listener, which will be used to uniquely identify the listener. +    fn name(&self) -> &str; +} + +#[derive(Clone)] +pub struct InputFilterThread { +    thread_creator: InputFilterThreadCreator, +    thread_callback_handler: ThreadCallbackHandler, +    inner: Arc<RwLock<InputFilterThreadInner>>, +} + +struct InputFilterThreadInner { +    cpp_thread: Option<Strong<dyn IInputThread>>, +    looper: Option<Thread>, +    next_timeout: i64, +    is_finishing: bool, +} + +impl InputFilterThread { +    /// Create a new InputFilterThread instance. +    /// NOTE: This will create a new thread. Clone the existing instance to reuse the same thread. +    pub fn new(thread_creator: InputFilterThreadCreator) -> InputFilterThread { +        Self { +            thread_creator, +            thread_callback_handler: ThreadCallbackHandler::new(), +            inner: Arc::new(RwLock::new(InputFilterThreadInner { +                cpp_thread: None, +                looper: None, +                next_timeout: i64::MAX, +                is_finishing: false, +            })), +        } +    } + +    /// Listener requesting a timeout in future will receive a callback at or before the requested +    /// time on the input filter thread. +    /// {@see ThreadCallback.notify_timeout_expired(...)} +    pub fn request_timeout_at_time(&self, when_nanos: i64) { +        let filter_thread = &mut self.filter_thread(); +        if when_nanos < filter_thread.next_timeout { +            filter_thread.next_timeout = when_nanos; +            if let Some(looper) = &filter_thread.looper { +                looper.unpark(); +            } +        } +    } + +    /// Registers a callback listener. +    /// +    /// NOTE: If a listener with the same name already exists when registering using +    /// {@see InputFilterThread.register_thread_callback(...)}, we will ignore the listener. You +    /// must clear any previously registered listeners using +    /// {@see InputFilterThread.unregister_thread_callback(...) before registering the new listener. +    /// +    /// NOTE: Also, registering a callback will start the looper if not already started. +    pub fn register_thread_callback(&self, callback: Box<dyn ThreadCallback + Send + Sync>) { +        self.thread_callback_handler.register_thread_callback(callback); +        self.start(); +    } + +    /// Unregisters a callback listener. +    /// +    /// NOTE: Unregistering a callback will stop the looper if not other callback registered. +    pub fn unregister_thread_callback(&self, callback: Box<dyn ThreadCallback + Send + Sync>) { +        self.thread_callback_handler.unregister_thread_callback(callback); +        // Stop the thread if no registered callbacks exist. We will recreate the thread when new +        // callbacks are registered. +        let has_callbacks = self.thread_callback_handler.has_callbacks(); +        if !has_callbacks { +            self.stop(); +        } +    } + +    fn start(&self) { +        debug!("InputFilterThread: start thread"); +        let filter_thread = &mut self.filter_thread(); +        if filter_thread.cpp_thread.is_none() { +            filter_thread.cpp_thread = Some(self.thread_creator.create( +                &BnInputThreadCallback::new_binder(self.clone(), BinderFeatures::default()), +            )); +            filter_thread.looper = None; +            filter_thread.is_finishing = false; +        } +    } + +    fn stop(&self) { +        debug!("InputFilterThread: stop thread"); +        let filter_thread = &mut self.filter_thread(); +        filter_thread.is_finishing = true; +        if let Some(looper) = &filter_thread.looper { +            looper.unpark(); +        } +        if let Some(cpp_thread) = &filter_thread.cpp_thread { +            let _ = cpp_thread.finish(); +        } +        // Clear all references +        filter_thread.cpp_thread = None; +        filter_thread.looper = None; +    } + +    fn loop_once(&self, now: i64) { +        let mut wake_up_time = i64::MAX; +        let mut timeout_expired = false; +        { +            // acquire thread lock +            let filter_thread = &mut self.filter_thread(); +            if filter_thread.is_finishing { +                // Thread is finishing so don't block processing on it and let it loop. +                return; +            } +            if filter_thread.next_timeout != i64::MAX { +                if filter_thread.next_timeout <= now { +                    timeout_expired = true; +                    filter_thread.next_timeout = i64::MAX; +                } else { +                    wake_up_time = filter_thread.next_timeout; +                } +            } +            if filter_thread.looper.is_none() { +                filter_thread.looper = Some(std::thread::current()); +            } +        } // release thread lock +        if timeout_expired { +            self.thread_callback_handler.notify_timeout_expired(now); +        } +        if wake_up_time == i64::MAX { +            thread::park(); +        } else { +            let duration_now = Duration::from_nanos(now as u64); +            let duration_wake_up = Duration::from_nanos(wake_up_time as u64); +            thread::park_timeout(duration_wake_up - duration_now); +        } +    } + +    fn filter_thread(&self) -> RwLockWriteGuard<'_, InputFilterThreadInner> { +        self.inner.write().unwrap() +    } +} + +impl Interface for InputFilterThread {} + +impl IInputThreadCallback for InputFilterThread { +    fn loopOnce(&self) -> binder::Result<()> { +        self.loop_once(clock_gettime(ClockId::CLOCK_MONOTONIC).unwrap().num_nanoseconds()); +        Result::Ok(()) +    } +} + +#[derive(Default, Clone)] +struct ThreadCallbackHandler(Arc<RwLock<ThreadCallbackHandlerInner>>); + +#[derive(Default)] +struct ThreadCallbackHandlerInner { +    callbacks: Vec<Box<dyn ThreadCallback + Send + Sync>>, +} + +impl ThreadCallbackHandler { +    fn new() -> Self { +        Default::default() +    } + +    fn has_callbacks(&self) -> bool { +        !&self.0.read().unwrap().callbacks.is_empty() +    } + +    fn register_thread_callback(&self, callback: Box<dyn ThreadCallback + Send + Sync>) { +        let callbacks = &mut self.0.write().unwrap().callbacks; +        if callbacks.iter().any(|x| x.name() == callback.name()) { +            error!( +                "InputFilterThread: register_thread_callback, callback {:?} already exists!", +                callback.name() +            ); +            return; +        } +        debug!( +            "InputFilterThread: register_thread_callback, callback {:?} added!", +            callback.name() +        ); +        callbacks.push(callback); +    } + +    fn unregister_thread_callback(&self, callback: Box<dyn ThreadCallback + Send + Sync>) { +        let callbacks = &mut self.0.write().unwrap().callbacks; +        if let Some(index) = callbacks.iter().position(|x| x.name() == callback.name()) { +            callbacks.remove(index); +            debug!( +                "InputFilterThread: unregister_thread_callback, callback {:?} removed!", +                callback.name() +            ); +            return; +        } +        error!( +            "InputFilterThread: unregister_thread_callback, callback {:?} doesn't exist", +            callback.name() +        ); +    } + +    fn notify_timeout_expired(&self, when_nanos: i64) { +        let callbacks = &self.0.read().unwrap().callbacks; +        for callback in callbacks.iter() { +            callback.notify_timeout_expired(when_nanos); +        } +    } +} + +#[cfg(test)] +mod tests { +    use crate::input_filter::test_callbacks::TestCallbacks; +    use crate::input_filter_thread::{ +        test_thread::TestThread, test_thread_callback::TestThreadCallback, +    }; + +    #[test] +    fn test_register_callback_creates_cpp_thread() { +        let test_callbacks = TestCallbacks::new(); +        let test_thread = TestThread::new(test_callbacks.clone()); +        let test_thread_callback = TestThreadCallback::new(); +        test_thread.register_thread_callback(test_thread_callback); +        assert!(test_callbacks.is_thread_created()); +    } + +    #[test] +    fn test_unregister_callback_finishes_cpp_thread() { +        let test_callbacks = TestCallbacks::new(); +        let test_thread = TestThread::new(test_callbacks.clone()); +        let test_thread_callback = TestThreadCallback::new(); +        test_thread.register_thread_callback(test_thread_callback.clone()); +        test_thread.unregister_thread_callback(test_thread_callback); +        assert!(test_callbacks.is_thread_finished()); +    } + +    #[test] +    fn test_notify_timeout_called_after_timeout_expired() { +        let test_callbacks = TestCallbacks::new(); +        let test_thread = TestThread::new(test_callbacks.clone()); +        let test_thread_callback = TestThreadCallback::new(); +        test_thread.register_thread_callback(test_thread_callback.clone()); +        test_thread.start_looper(); + +        test_thread.request_timeout_at_time(500); +        test_thread.dispatch_next(); + +        test_thread.move_time_forward(500); + +        test_thread.stop_looper(); +        assert!(test_thread_callback.is_notify_timeout_called()); +    } + +    #[test] +    fn test_notify_timeout_not_called_before_timeout_expired() { +        let test_callbacks = TestCallbacks::new(); +        let test_thread = TestThread::new(test_callbacks.clone()); +        let test_thread_callback = TestThreadCallback::new(); +        test_thread.register_thread_callback(test_thread_callback.clone()); +        test_thread.start_looper(); + +        test_thread.request_timeout_at_time(500); +        test_thread.dispatch_next(); + +        test_thread.move_time_forward(100); + +        test_thread.stop_looper(); +        assert!(!test_thread_callback.is_notify_timeout_called()); +    } +} + +#[cfg(test)] +pub mod test_thread { + +    use crate::input_filter::{test_callbacks::TestCallbacks, InputFilterThreadCreator}; +    use crate::input_filter_thread::{test_thread_callback::TestThreadCallback, InputFilterThread}; +    use binder::Strong; +    use std::sync::{ +        atomic::AtomicBool, atomic::AtomicI64, atomic::Ordering, Arc, RwLock, RwLockWriteGuard, +    }; +    use std::time::Duration; + +    #[derive(Clone)] +    pub struct TestThread { +        input_thread: InputFilterThread, +        inner: Arc<RwLock<TestThreadInner>>, +        exit_flag: Arc<AtomicBool>, +        now: Arc<AtomicI64>, +    } + +    struct TestThreadInner { +        join_handle: Option<std::thread::JoinHandle<()>>, +    } + +    impl TestThread { +        pub fn new(callbacks: TestCallbacks) -> TestThread { +            Self { +                input_thread: InputFilterThread::new(InputFilterThreadCreator::new(Arc::new( +                    RwLock::new(Strong::new(Box::new(callbacks))), +                ))), +                inner: Arc::new(RwLock::new(TestThreadInner { join_handle: None })), +                exit_flag: Arc::new(AtomicBool::new(false)), +                now: Arc::new(AtomicI64::new(0)), +            } +        } + +        fn inner(&self) -> RwLockWriteGuard<'_, TestThreadInner> { +            self.inner.write().unwrap() +        } + +        pub fn get_input_thread(&self) -> InputFilterThread { +            self.input_thread.clone() +        } + +        pub fn register_thread_callback(&self, thread_callback: TestThreadCallback) { +            self.input_thread.register_thread_callback(Box::new(thread_callback)); +        } + +        pub fn unregister_thread_callback(&self, thread_callback: TestThreadCallback) { +            self.input_thread.unregister_thread_callback(Box::new(thread_callback)); +        } + +        pub fn start_looper(&self) { +            self.exit_flag.store(false, Ordering::Relaxed); +            let clone = self.clone(); +            let join_handle = std::thread::Builder::new() +                .name("test_thread".to_string()) +                .spawn(move || { +                    while !clone.exit_flag.load(Ordering::Relaxed) { +                        clone.loop_once(); +                    } +                }) +                .unwrap(); +            self.inner().join_handle = Some(join_handle); +            // Sleep until the looper thread starts +            std::thread::sleep(Duration::from_millis(10)); +        } + +        pub fn stop_looper(&self) { +            self.exit_flag.store(true, Ordering::Relaxed); +            { +                let mut inner = self.inner(); +                if let Some(join_handle) = &inner.join_handle { +                    join_handle.thread().unpark(); +                } +                inner.join_handle.take().map(std::thread::JoinHandle::join); +                inner.join_handle = None; +            } +            self.exit_flag.store(false, Ordering::Relaxed); +        } + +        pub fn move_time_forward(&self, value: i64) { +            let _ = self.now.fetch_add(value, Ordering::Relaxed); +            self.dispatch_next(); +        } + +        pub fn dispatch_next(&self) { +            if let Some(join_handle) = &self.inner().join_handle { +                join_handle.thread().unpark(); +            } +            // Sleep until the looper thread runs a loop +            std::thread::sleep(Duration::from_millis(10)); +        } + +        fn loop_once(&self) { +            self.input_thread.loop_once(self.now.load(Ordering::Relaxed)); +        } + +        pub fn request_timeout_at_time(&self, when_nanos: i64) { +            self.input_thread.request_timeout_at_time(when_nanos); +        } +    } +} + +#[cfg(test)] +pub mod test_thread_callback { +    use crate::input_filter_thread::ThreadCallback; +    use std::sync::{Arc, RwLock, RwLockWriteGuard}; + +    #[derive(Default)] +    struct TestThreadCallbackInner { +        is_notify_timeout_called: bool, +    } + +    #[derive(Default, Clone)] +    pub struct TestThreadCallback(Arc<RwLock<TestThreadCallbackInner>>); + +    impl TestThreadCallback { +        pub fn new() -> Self { +            Default::default() +        } + +        fn inner(&self) -> RwLockWriteGuard<'_, TestThreadCallbackInner> { +            self.0.write().unwrap() +        } + +        pub fn is_notify_timeout_called(&self) -> bool { +            self.0.read().unwrap().is_notify_timeout_called +        } +    } + +    impl ThreadCallback for TestThreadCallback { +        fn notify_timeout_expired(&self, _when_nanos: i64) { +            self.inner().is_notify_timeout_called = true; +        } +        fn name(&self) -> &str { +            "TestThreadCallback" +        } +    } +} diff --git a/services/inputflinger/rust/lib.rs b/services/inputflinger/rust/lib.rs index fa16898835..da72134065 100644 --- a/services/inputflinger/rust/lib.rs +++ b/services/inputflinger/rust/lib.rs @@ -21,6 +21,8 @@  mod bounce_keys_filter;  mod input_filter; +mod input_filter_thread; +mod slow_keys_filter;  mod sticky_keys_filter;  use crate::input_filter::InputFilter; diff --git a/services/inputflinger/rust/slow_keys_filter.rs b/services/inputflinger/rust/slow_keys_filter.rs new file mode 100644 index 0000000000..01165b57fa --- /dev/null +++ b/services/inputflinger/rust/slow_keys_filter.rs @@ -0,0 +1,382 @@ +/* + * Copyright 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + *      http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! Slow keys input filter implementation. +//! Slow keys is an accessibility feature to aid users who have physical disabilities, that allows +//! the user to specify the duration for which one must press-and-hold a key before the system +//! accepts the keypress. +use crate::input_filter::Filter; +use crate::input_filter_thread::{InputFilterThread, ThreadCallback}; +use android_hardware_input_common::aidl::android::hardware::input::common::Source::Source; +use com_android_server_inputflinger::aidl::com::android::server::inputflinger::{ +    DeviceInfo::DeviceInfo, KeyEvent::KeyEvent, KeyEventAction::KeyEventAction, +}; +use log::debug; +use std::collections::HashSet; +use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}; + +#[derive(Debug)] +struct OngoingKeyDown { +    scancode: i32, +    device_id: i32, +    down_time: i64, +} + +struct SlowKeysFilterInner { +    next: Box<dyn Filter + Send + Sync>, +    slow_key_threshold_ns: i64, +    external_devices: HashSet<i32>, +    // This tracks KeyEvents that are blocked by Slow keys filter and will be passed through if the +    // press duration exceeds the slow keys threshold. +    pending_down_events: Vec<KeyEvent>, +    // This tracks KeyEvent streams that have press duration greater than the slow keys threshold, +    // hence any future ACTION_DOWN (if repeats are handled on HW side) or ACTION_UP are allowed to +    // pass through without waiting. +    ongoing_down_events: Vec<OngoingKeyDown>, +    input_filter_thread: InputFilterThread, +} + +#[derive(Clone)] +pub struct SlowKeysFilter(Arc<RwLock<SlowKeysFilterInner>>); + +impl SlowKeysFilter { +    /// Create a new SlowKeysFilter instance. +    pub fn new( +        next: Box<dyn Filter + Send + Sync>, +        slow_key_threshold_ns: i64, +        input_filter_thread: InputFilterThread, +    ) -> SlowKeysFilter { +        let filter = Self(Arc::new(RwLock::new(SlowKeysFilterInner { +            next, +            slow_key_threshold_ns, +            external_devices: HashSet::new(), +            pending_down_events: Vec::new(), +            ongoing_down_events: Vec::new(), +            input_filter_thread: input_filter_thread.clone(), +        }))); +        input_filter_thread.register_thread_callback(Box::new(filter.clone())); +        filter +    } + +    fn read_inner(&self) -> RwLockReadGuard<'_, SlowKeysFilterInner> { +        self.0.read().unwrap() +    } + +    fn write_inner(&self) -> RwLockWriteGuard<'_, SlowKeysFilterInner> { +        self.0.write().unwrap() +    } + +    fn request_next_callback(&self) { +        let slow_filter = &self.read_inner(); +        if slow_filter.pending_down_events.is_empty() { +            return; +        } +        if let Some(event) = slow_filter.pending_down_events.iter().min_by_key(|x| x.downTime) { +            slow_filter.input_filter_thread.request_timeout_at_time(event.downTime); +        } +    } +} + +impl Filter for SlowKeysFilter { +    fn notify_key(&mut self, event: &KeyEvent) { +        { +            // acquire write lock +            let mut slow_filter = self.write_inner(); +            if !(slow_filter.external_devices.contains(&event.deviceId) +                && event.source == Source::KEYBOARD) +            { +                slow_filter.next.notify_key(event); +                return; +            } +            // Pass all events through if key down has already been processed +            // Do update the downtime before sending the events through +            if let Some(index) = slow_filter +                .ongoing_down_events +                .iter() +                .position(|x| x.device_id == event.deviceId && x.scancode == event.scanCode) +            { +                let mut new_event = *event; +                new_event.downTime = slow_filter.ongoing_down_events[index].down_time; +                slow_filter.next.notify_key(&new_event); +                if event.action == KeyEventAction::UP { +                    slow_filter.ongoing_down_events.remove(index); +                } +                return; +            } +            match event.action { +                KeyEventAction::DOWN => { +                    if slow_filter +                        .pending_down_events +                        .iter() +                        .any(|x| x.deviceId == event.deviceId && x.scanCode == event.scanCode) +                    { +                        debug!("Dropping key down event since another pending down event exists"); +                        return; +                    } +                    let mut pending_event = *event; +                    pending_event.downTime += slow_filter.slow_key_threshold_ns; +                    pending_event.eventTime = pending_event.downTime; +                    slow_filter.pending_down_events.push(pending_event); +                } +                KeyEventAction::UP => { +                    debug!("Dropping key up event due to insufficient press duration"); +                    if let Some(index) = slow_filter +                        .pending_down_events +                        .iter() +                        .position(|x| x.deviceId == event.deviceId && x.scanCode == event.scanCode) +                    { +                        slow_filter.pending_down_events.remove(index); +                    } +                } +                _ => (), +            } +        } // release write lock +        self.request_next_callback(); +    } + +    fn notify_devices_changed(&mut self, device_infos: &[DeviceInfo]) { +        let mut slow_filter = self.write_inner(); +        slow_filter +            .pending_down_events +            .retain(|event| device_infos.iter().any(|x| event.deviceId == x.deviceId)); +        slow_filter +            .ongoing_down_events +            .retain(|event| device_infos.iter().any(|x| event.device_id == x.deviceId)); +        slow_filter.external_devices.clear(); +        for device_info in device_infos { +            if device_info.external { +                slow_filter.external_devices.insert(device_info.deviceId); +            } +        } +        slow_filter.next.notify_devices_changed(device_infos); +    } + +    fn destroy(&mut self) { +        let mut slow_filter = self.write_inner(); +        slow_filter.input_filter_thread.unregister_thread_callback(Box::new(self.clone())); +        slow_filter.next.destroy(); +    } +} + +impl ThreadCallback for SlowKeysFilter { +    fn notify_timeout_expired(&self, when_nanos: i64) { +        { +            // acquire write lock +            let slow_filter = &mut self.write_inner(); +            for event in slow_filter.pending_down_events.clone() { +                if event.downTime <= when_nanos { +                    slow_filter.next.notify_key(&event); +                    slow_filter.ongoing_down_events.push(OngoingKeyDown { +                        scancode: event.scanCode, +                        device_id: event.deviceId, +                        down_time: event.downTime, +                    }); +                } +            } +            slow_filter.pending_down_events.retain(|event| event.downTime > when_nanos); +        } // release write lock +        self.request_next_callback(); +    } + +    fn name(&self) -> &str { +        "slow_keys_filter" +    } +} + +#[cfg(test)] +mod tests { +    use crate::input_filter::{test_callbacks::TestCallbacks, test_filter::TestFilter, Filter}; +    use crate::input_filter_thread::test_thread::TestThread; +    use crate::slow_keys_filter::SlowKeysFilter; +    use android_hardware_input_common::aidl::android::hardware::input::common::Source::Source; +    use com_android_server_inputflinger::aidl::com::android::server::inputflinger::{ +        DeviceInfo::DeviceInfo, KeyEvent::KeyEvent, KeyEventAction::KeyEventAction, +    }; + +    static BASE_KEY_EVENT: KeyEvent = KeyEvent { +        id: 1, +        deviceId: 1, +        downTime: 0, +        readTime: 0, +        eventTime: 0, +        source: Source::KEYBOARD, +        displayId: 0, +        policyFlags: 0, +        action: KeyEventAction::DOWN, +        flags: 0, +        keyCode: 1, +        scanCode: 0, +        metaState: 0, +    }; + +    #[test] +    fn test_is_notify_key_for_internal_keyboard_not_blocked() { +        let test_callbacks = TestCallbacks::new(); +        let test_thread = TestThread::new(test_callbacks.clone()); +        let next = TestFilter::new(); +        let mut filter = setup_filter_with_internal_device( +            Box::new(next.clone()), +            test_thread.clone(), +            1,   /* device_id */ +            100, /* threshold */ +        ); +        test_thread.start_looper(); + +        let event = KeyEvent { action: KeyEventAction::DOWN, ..BASE_KEY_EVENT }; +        filter.notify_key(&event); +        assert_eq!(next.last_event().unwrap(), event); +    } + +    #[test] +    fn test_is_notify_key_for_external_stylus_not_blocked() { +        let test_callbacks = TestCallbacks::new(); +        let test_thread = TestThread::new(test_callbacks.clone()); +        let next = TestFilter::new(); +        let mut filter = setup_filter_with_external_device( +            Box::new(next.clone()), +            test_thread.clone(), +            1,   /* device_id */ +            100, /* threshold */ +        ); +        test_thread.start_looper(); + +        let event = +            KeyEvent { action: KeyEventAction::DOWN, source: Source::STYLUS, ..BASE_KEY_EVENT }; +        filter.notify_key(&event); +        assert_eq!(next.last_event().unwrap(), event); +    } + +    #[test] +    fn test_notify_key_for_external_keyboard_when_key_pressed_for_threshold_time() { +        let test_callbacks = TestCallbacks::new(); +        let test_thread = TestThread::new(test_callbacks.clone()); +        let next = TestFilter::new(); +        let mut filter = setup_filter_with_external_device( +            Box::new(next.clone()), +            test_thread.clone(), +            1,   /* device_id */ +            100, /* threshold */ +        ); +        test_thread.start_looper(); + +        filter.notify_key(&KeyEvent { action: KeyEventAction::DOWN, ..BASE_KEY_EVENT }); +        assert!(next.last_event().is_none()); +        test_thread.dispatch_next(); + +        test_thread.move_time_forward(100); + +        test_thread.stop_looper(); +        assert_eq!( +            next.last_event().unwrap(), +            KeyEvent { +                action: KeyEventAction::DOWN, +                downTime: 100, +                eventTime: 100, +                ..BASE_KEY_EVENT +            } +        ); +    } + +    #[test] +    fn test_notify_key_for_external_keyboard_when_key_not_pressed_for_threshold_time() { +        let test_callbacks = TestCallbacks::new(); +        let test_thread = TestThread::new(test_callbacks.clone()); +        let next = TestFilter::new(); +        let mut filter = setup_filter_with_external_device( +            Box::new(next.clone()), +            test_thread.clone(), +            1,   /* device_id */ +            100, /* threshold */ +        ); +        test_thread.start_looper(); + +        filter.notify_key(&KeyEvent { action: KeyEventAction::DOWN, ..BASE_KEY_EVENT }); +        test_thread.dispatch_next(); + +        test_thread.move_time_forward(10); + +        filter.notify_key(&KeyEvent { action: KeyEventAction::UP, ..BASE_KEY_EVENT }); +        test_thread.dispatch_next(); + +        test_thread.stop_looper(); +        assert!(next.last_event().is_none()); +    } + +    #[test] +    fn test_notify_key_for_external_keyboard_when_device_removed_before_threshold_time() { +        let test_callbacks = TestCallbacks::new(); +        let test_thread = TestThread::new(test_callbacks.clone()); +        let next = TestFilter::new(); +        let mut filter = setup_filter_with_external_device( +            Box::new(next.clone()), +            test_thread.clone(), +            1,   /* device_id */ +            100, /* threshold */ +        ); +        test_thread.start_looper(); + +        filter.notify_key(&KeyEvent { action: KeyEventAction::DOWN, ..BASE_KEY_EVENT }); +        assert!(next.last_event().is_none()); +        test_thread.dispatch_next(); + +        filter.notify_devices_changed(&[]); +        test_thread.dispatch_next(); + +        test_thread.move_time_forward(100); + +        test_thread.stop_looper(); +        assert!(next.last_event().is_none()); +    } + +    fn setup_filter_with_external_device( +        next: Box<dyn Filter + Send + Sync>, +        test_thread: TestThread, +        device_id: i32, +        threshold: i64, +    ) -> SlowKeysFilter { +        setup_filter_with_devices( +            next, +            test_thread, +            &[DeviceInfo { deviceId: device_id, external: true }], +            threshold, +        ) +    } + +    fn setup_filter_with_internal_device( +        next: Box<dyn Filter + Send + Sync>, +        test_thread: TestThread, +        device_id: i32, +        threshold: i64, +    ) -> SlowKeysFilter { +        setup_filter_with_devices( +            next, +            test_thread, +            &[DeviceInfo { deviceId: device_id, external: false }], +            threshold, +        ) +    } + +    fn setup_filter_with_devices( +        next: Box<dyn Filter + Send + Sync>, +        test_thread: TestThread, +        devices: &[DeviceInfo], +        threshold: i64, +    ) -> SlowKeysFilter { +        let mut filter = SlowKeysFilter::new(next, threshold, test_thread.get_input_thread()); +        filter.notify_devices_changed(devices); +        filter +    } +} diff --git a/services/inputflinger/rust/sticky_keys_filter.rs b/services/inputflinger/rust/sticky_keys_filter.rs index da581b82bf..6c2277c813 100644 --- a/services/inputflinger/rust/sticky_keys_filter.rs +++ b/services/inputflinger/rust/sticky_keys_filter.rs @@ -142,6 +142,10 @@ impl Filter for StickyKeysFilter {          }          self.next.notify_devices_changed(device_infos);      } + +    fn destroy(&mut self) { +        self.next.destroy(); +    }  }  fn is_modifier_key(keycode: i32) -> bool { |