diff options
| author | 2021-05-01 23:14:04 +0000 | |
|---|---|---|
| committer | 2021-05-07 00:20:20 -0700 | |
| commit | ae02a1f49c1b0ae49af4331be1117b37616c0adb (patch) | |
| tree | 09f6dc17c663dc36065f18edfff1364cd268d02d | |
| parent | 96e84d982fc1e50998983bed031a005140fb0355 (diff) | |
Store connections by token instead of by fd
The connections are currently stored by fd. If a connection is removed
via 'removeInputChannel', it is possible to re-create the same
connection and have it keyed by the same fd. When this happens, a race
condition may occur where a socket hangup on this fd would cause the
removal of a newly registered connection.
In this refactor, the connections are no longer stored by fd. The looper
interface for adding fds has two versions:
1) the old one that we are currently using, which is marked as 'do not
use'
2) the new one where a callback object is provided instead.
In this CL, we switch to the new version of the callback.
There is now also no need to store the inputchannels in a separate
structure, because we can use the connections collection that's now
keyed by token to find them.
In a future refactor, we should switch to using 'unique_ptr' for the
inputchannels. Most of the time when we are looking for an input
channel, we are actually interested in finding the corresponding
connection.
If we switch Connection to shared_ptr, we can also look into switching
LooperEventCallback to store a weak pointer to a connection instead of
storing the connection token. This should speed up the handling of
events, by avoiding a map lookup.
Test: ./reinitinput.sh. Observe that it doesnt finish after this patch
Test: atest inputflinger_tests
Bug: 182478748
Change-Id: I601f765eebfadcaeff3661a10a10c4a4f0477389
| -rw-r--r-- | include/input/InputTransport.h | 2 | ||||
| -rw-r--r-- | services/inputflinger/dispatcher/InputDispatcher.cpp | 220 | ||||
| -rw-r--r-- | services/inputflinger/dispatcher/InputDispatcher.h | 11 |
3 files changed, 110 insertions, 123 deletions
diff --git a/include/input/InputTransport.h b/include/input/InputTransport.h index ff3367839d..360dfbfd73 100644 --- a/include/input/InputTransport.h +++ b/include/input/InputTransport.h @@ -229,7 +229,7 @@ public: InputChannel(const InputChannel& other) : mName(other.mName), mFd(::dup(other.mFd)), mToken(other.mToken){}; InputChannel(const std::string name, android::base::unique_fd fd, sp<IBinder> token); - virtual ~InputChannel(); + ~InputChannel() override; /** * Create a pair of input channels. * The two returned input channels are equivalent, and are labeled as "server" and "client" diff --git a/services/inputflinger/dispatcher/InputDispatcher.cpp b/services/inputflinger/dispatcher/InputDispatcher.cpp index 16cb7d7a51..9a43ed9b2e 100644 --- a/services/inputflinger/dispatcher/InputDispatcher.cpp +++ b/services/inputflinger/dispatcher/InputDispatcher.cpp @@ -283,27 +283,6 @@ static V getValueByKey(const std::unordered_map<K, V>& map, K key) { return it != map.end() ? it->second : V{}; } -/** - * Find the entry in std::unordered_map by value, and remove it. - * If more than one entry has the same value, then all matching - * key-value pairs will be removed. - * - * Return true if at least one value has been removed. - */ -template <typename K, typename V> -static bool removeByValue(std::unordered_map<K, V>& map, const V& value) { - bool removed = false; - for (auto it = map.begin(); it != map.end();) { - if (it->second == value) { - it = map.erase(it); - removed = true; - } else { - it++; - } - } - return removed; -} - static bool haveSameToken(const sp<InputWindowHandle>& first, const sp<InputWindowHandle>& second) { if (first == second) { return true; @@ -507,8 +486,8 @@ InputDispatcher::~InputDispatcher() { drainInboundQueueLocked(); } - while (!mConnectionsByFd.empty()) { - sp<Connection> connection = mConnectionsByFd.begin()->second; + while (!mConnectionsByToken.empty()) { + sp<Connection> connection = mConnectionsByToken.begin()->second; removeInputChannel(connection->inputChannel->getConnectionToken()); } } @@ -3297,86 +3276,78 @@ void InputDispatcher::releaseDispatchEntry(DispatchEntry* dispatchEntry) { delete dispatchEntry; } -int InputDispatcher::handleReceiveCallback(int fd, int events, void* data) { - InputDispatcher* d = static_cast<InputDispatcher*>(data); - - { // acquire lock - std::scoped_lock _l(d->mLock); - - if (d->mConnectionsByFd.find(fd) == d->mConnectionsByFd.end()) { - ALOGE("Received spurious receive callback for unknown input channel. " - "fd=%d, events=0x%x", - fd, events); - return 0; // remove the callback - } - - bool notify; - sp<Connection> connection = d->mConnectionsByFd[fd]; - if (!(events & (ALOOPER_EVENT_ERROR | ALOOPER_EVENT_HANGUP))) { - if (!(events & ALOOPER_EVENT_INPUT)) { - ALOGW("channel '%s' ~ Received spurious callback for unhandled poll event. " - "events=0x%x", - connection->getInputChannelName().c_str(), events); - return 1; - } +int InputDispatcher::handleReceiveCallback(int events, sp<IBinder> connectionToken) { + std::scoped_lock _l(mLock); + sp<Connection> connection = getConnectionLocked(connectionToken); + if (connection == nullptr) { + ALOGW("Received looper callback for unknown input channel token %p. events=0x%x", + connectionToken.get(), events); + return 0; // remove the callback + } - nsecs_t currentTime = now(); - bool gotOne = false; - status_t status = OK; - for (;;) { - Result<InputPublisher::ConsumerResponse> result = - connection->inputPublisher.receiveConsumerResponse(); - if (!result.ok()) { - status = result.error().code(); - break; - } + bool notify; + if (!(events & (ALOOPER_EVENT_ERROR | ALOOPER_EVENT_HANGUP))) { + if (!(events & ALOOPER_EVENT_INPUT)) { + ALOGW("channel '%s' ~ Received spurious callback for unhandled poll event. " + "events=0x%x", + connection->getInputChannelName().c_str(), events); + return 1; + } - if (std::holds_alternative<InputPublisher::Finished>(*result)) { - const InputPublisher::Finished& finish = - std::get<InputPublisher::Finished>(*result); - d->finishDispatchCycleLocked(currentTime, connection, finish.seq, - finish.handled, finish.consumeTime); - } else if (std::holds_alternative<InputPublisher::Timeline>(*result)) { - // TODO(b/167947340): Report this data to LatencyTracker - } - gotOne = true; - } - if (gotOne) { - d->runCommandsLockedInterruptible(); - if (status == WOULD_BLOCK) { - return 1; - } + nsecs_t currentTime = now(); + bool gotOne = false; + status_t status = OK; + for (;;) { + Result<InputPublisher::ConsumerResponse> result = + connection->inputPublisher.receiveConsumerResponse(); + if (!result.ok()) { + status = result.error().code(); + break; } - notify = status != DEAD_OBJECT || !connection->monitor; - if (notify) { - ALOGE("channel '%s' ~ Failed to receive finished signal. status=%s(%d)", - connection->getInputChannelName().c_str(), statusToString(status).c_str(), - status); + if (std::holds_alternative<InputPublisher::Finished>(*result)) { + const InputPublisher::Finished& finish = + std::get<InputPublisher::Finished>(*result); + finishDispatchCycleLocked(currentTime, connection, finish.seq, finish.handled, + finish.consumeTime); + } else if (std::holds_alternative<InputPublisher::Timeline>(*result)) { + // TODO(b/167947340): Report this data to LatencyTracker } - } else { - // Monitor channels are never explicitly unregistered. - // We do it automatically when the remote endpoint is closed so don't warn about them. - const bool stillHaveWindowHandle = - d->getWindowHandleLocked(connection->inputChannel->getConnectionToken()) != - nullptr; - notify = !connection->monitor && stillHaveWindowHandle; - if (notify) { - ALOGW("channel '%s' ~ Consumer closed input channel or an error occurred. " - "events=0x%x", - connection->getInputChannelName().c_str(), events); + gotOne = true; + } + if (gotOne) { + runCommandsLockedInterruptible(); + if (status == WOULD_BLOCK) { + return 1; } } - // Remove the channel. - d->removeInputChannelLocked(connection->inputChannel->getConnectionToken(), notify); - return 0; // remove the callback - } // release lock + notify = status != DEAD_OBJECT || !connection->monitor; + if (notify) { + ALOGE("channel '%s' ~ Failed to receive finished signal. status=%s(%d)", + connection->getInputChannelName().c_str(), statusToString(status).c_str(), + status); + } + } else { + // Monitor channels are never explicitly unregistered. + // We do it automatically when the remote endpoint is closed so don't warn about them. + const bool stillHaveWindowHandle = + getWindowHandleLocked(connection->inputChannel->getConnectionToken()) != nullptr; + notify = !connection->monitor && stillHaveWindowHandle; + if (notify) { + ALOGW("channel '%s' ~ Consumer closed input channel or an error occurred. events=0x%x", + connection->getInputChannelName().c_str(), events); + } + } + + // Remove the channel. + removeInputChannelLocked(connection->inputChannel->getConnectionToken(), notify); + return 0; // remove the callback } void InputDispatcher::synthesizeCancelationEventsForAllConnectionsLocked( const CancelationOptions& options) { - for (const auto& [fd, connection] : mConnectionsByFd) { + for (const auto& [token, connection] : mConnectionsByToken) { synthesizeCancelationEventsForConnectionLocked(connection, options); } } @@ -4342,11 +4313,11 @@ bool InputDispatcher::hasResponsiveConnectionLocked(InputWindowHandle& windowHan std::shared_ptr<InputChannel> InputDispatcher::getInputChannelLocked( const sp<IBinder>& token) const { - size_t count = mInputChannelsByToken.count(token); - if (count == 0) { + auto connectionIt = mConnectionsByToken.find(token); + if (connectionIt == mConnectionsByToken.end()) { return nullptr; } - return mInputChannelsByToken.at(token); + return connectionIt->second->inputChannel; } void InputDispatcher::updateWindowHandlesForDisplayLocked( @@ -4996,13 +4967,13 @@ void InputDispatcher::dumpDispatchStateLocked(std::string& dump) { dump += INDENT "ReplacedKeys: <empty>\n"; } - if (!mConnectionsByFd.empty()) { + if (!mConnectionsByToken.empty()) { dump += INDENT "Connections:\n"; - for (const auto& pair : mConnectionsByFd) { - const sp<Connection>& connection = pair.second; + for (const auto& [token, connection] : mConnectionsByToken) { dump += StringPrintf(INDENT2 "%i: channelName='%s', windowName='%s', " "status=%s, monitor=%s, responsive=%s\n", - pair.first, connection->getInputChannelName().c_str(), + connection->inputChannel->getFd().get(), + connection->getInputChannelName().c_str(), connection->getWindowName().c_str(), connection->getStatusLabel(), toString(connection->monitor), toString(connection->responsive)); @@ -5050,14 +5021,23 @@ void InputDispatcher::dumpMonitors(std::string& dump, const std::vector<Monitor> } } +class LooperEventCallback : public LooperCallback { +public: + LooperEventCallback(std::function<int(int events)> callback) : mCallback(callback) {} + int handleEvent(int /*fd*/, int events, void* /*data*/) override { return mCallback(events); } + +private: + std::function<int(int events)> mCallback; +}; + Result<std::unique_ptr<InputChannel>> InputDispatcher::createInputChannel(const std::string& name) { #if DEBUG_CHANNEL_CREATION ALOGD("channel '%s' ~ createInputChannel", name.c_str()); #endif - std::shared_ptr<InputChannel> serverChannel; + std::unique_ptr<InputChannel> serverChannel; std::unique_ptr<InputChannel> clientChannel; - status_t result = openInputChannelPair(name, serverChannel, clientChannel); + status_t result = InputChannel::openInputChannelPair(name, serverChannel, clientChannel); if (result) { return base::Error(result) << "Failed to open input channel pair with name " << name; @@ -5065,13 +5045,20 @@ Result<std::unique_ptr<InputChannel>> InputDispatcher::createInputChannel(const { // acquire lock std::scoped_lock _l(mLock); - sp<Connection> connection = new Connection(serverChannel, false /*monitor*/, mIdGenerator); - + const sp<IBinder>& token = serverChannel->getConnectionToken(); int fd = serverChannel->getFd(); - mConnectionsByFd[fd] = connection; - mInputChannelsByToken[serverChannel->getConnectionToken()] = serverChannel; + sp<Connection> connection = + new Connection(std::move(serverChannel), false /*monitor*/, mIdGenerator); - mLooper->addFd(fd, 0, ALOOPER_EVENT_INPUT, handleReceiveCallback, this); + if (mConnectionsByToken.find(token) != mConnectionsByToken.end()) { + ALOGE("Created a new connection, but the token %p is already known", token.get()); + } + mConnectionsByToken.emplace(token, connection); + + std::function<int(int events)> callback = std::bind(&InputDispatcher::handleReceiveCallback, + this, std::placeholders::_1, token); + + mLooper->addFd(fd, 0, ALOOPER_EVENT_INPUT, new LooperEventCallback(callback), nullptr); } // release lock // Wake the looper because some connections have changed. @@ -5099,18 +5086,21 @@ Result<std::unique_ptr<InputChannel>> InputDispatcher::createInputMonitor(int32_ } sp<Connection> connection = new Connection(serverChannel, true /*monitor*/, mIdGenerator); - + const sp<IBinder>& token = serverChannel->getConnectionToken(); const int fd = serverChannel->getFd(); - mConnectionsByFd[fd] = connection; - mInputChannelsByToken[serverChannel->getConnectionToken()] = serverChannel; + + if (mConnectionsByToken.find(token) != mConnectionsByToken.end()) { + ALOGE("Created a new connection, but the token %p is already known", token.get()); + } + mConnectionsByToken.emplace(token, connection); + std::function<int(int events)> callback = std::bind(&InputDispatcher::handleReceiveCallback, + this, std::placeholders::_1, token); auto& monitorsByDisplay = isGestureMonitor ? mGestureMonitorsByDisplay : mGlobalMonitorsByDisplay; monitorsByDisplay[displayId].emplace_back(serverChannel, pid); - mLooper->addFd(fd, 0, ALOOPER_EVENT_INPUT, handleReceiveCallback, this); - ALOGI("Created monitor %s for display %" PRId32 ", gesture=%s, pid=%" PRId32, name.c_str(), - displayId, toString(isGestureMonitor), pid); + mLooper->addFd(fd, 0, ALOOPER_EVENT_INPUT, new LooperEventCallback(callback), nullptr); } // Wake the looper because some connections have changed. @@ -5143,7 +5133,6 @@ status_t InputDispatcher::removeInputChannelLocked(const sp<IBinder>& connection } removeConnectionLocked(connection); - mInputChannelsByToken.erase(connectionToken); if (connection->monitor) { removeMonitorChannelLocked(connectionToken); @@ -5301,9 +5290,8 @@ sp<Connection> InputDispatcher::getConnectionLocked(const sp<IBinder>& inputConn return nullptr; } - for (const auto& pair : mConnectionsByFd) { - const sp<Connection>& connection = pair.second; - if (connection->inputChannel->getConnectionToken() == inputConnectionToken) { + for (const auto& [token, connection] : mConnectionsByToken) { + if (token == inputConnectionToken) { return connection; } } @@ -5321,7 +5309,7 @@ std::string InputDispatcher::getConnectionNameLocked(const sp<IBinder>& connecti void InputDispatcher::removeConnectionLocked(const sp<Connection>& connection) { mAnrTracker.eraseToken(connection->inputChannel->getConnectionToken()); - removeByValue(mConnectionsByFd, connection); + mConnectionsByToken.erase(connection->inputChannel->getConnectionToken()); } void InputDispatcher::onDispatchCycleFinishedLocked(nsecs_t currentTime, diff --git a/services/inputflinger/dispatcher/InputDispatcher.h b/services/inputflinger/dispatcher/InputDispatcher.h index 7ab4fd76dc..7ba03e8063 100644 --- a/services/inputflinger/dispatcher/InputDispatcher.h +++ b/services/inputflinger/dispatcher/InputDispatcher.h @@ -211,9 +211,6 @@ private: bool addPortalWindows = false, bool ignoreDragWindow = false) REQUIRES(mLock); - // All registered connections mapped by channel file descriptor. - std::unordered_map<int, sp<Connection>> mConnectionsByFd GUARDED_BY(mLock); - sp<Connection> getConnectionLocked(const sp<IBinder>& inputConnectionToken) const REQUIRES(mLock); @@ -225,8 +222,10 @@ private: struct StrongPointerHash { std::size_t operator()(const sp<T>& b) const { return std::hash<T*>{}(b.get()); } }; - std::unordered_map<sp<IBinder>, std::shared_ptr<InputChannel>, StrongPointerHash<IBinder>> - mInputChannelsByToken GUARDED_BY(mLock); + + // All registered connections mapped by input channel token. + std::unordered_map<sp<IBinder>, sp<Connection>, StrongPointerHash<IBinder>> mConnectionsByToken + GUARDED_BY(mLock); // Finds the display ID of the gesture monitor identified by the provided token. std::optional<int32_t> findGestureMonitorDisplayByTokenLocked(const sp<IBinder>& token) @@ -544,7 +543,7 @@ private: bool notify) REQUIRES(mLock); void drainDispatchQueue(std::deque<DispatchEntry*>& queue); void releaseDispatchEntry(DispatchEntry* dispatchEntry); - static int handleReceiveCallback(int fd, int events, void* data); + int handleReceiveCallback(int events, sp<IBinder> connectionToken); // The action sent should only be of type AMOTION_EVENT_* void dispatchPointerDownOutsideFocus(uint32_t source, int32_t action, const sp<IBinder>& newToken) REQUIRES(mLock); |