diff options
-rw-r--r-- | libs/binder/RpcServer.cpp | 14 | ||||
-rw-r--r-- | libs/binder/RpcSession.cpp | 20 | ||||
-rw-r--r-- | libs/binder/RpcState.cpp | 39 | ||||
-rw-r--r-- | libs/binder/RpcState.h | 4 | ||||
-rw-r--r-- | libs/binder/include/binder/RpcServer.h | 3 | ||||
-rw-r--r-- | libs/binder/include/binder/RpcSession.h | 20 |
6 files changed, 56 insertions, 44 deletions
diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp index e3bf2a5e36..bff5543c9b 100644 --- a/libs/binder/RpcServer.cpp +++ b/libs/binder/RpcServer.cpp @@ -192,10 +192,10 @@ bool RpcServer::shutdown() { } mShutdownTrigger->trigger(); - while (mJoinThreadRunning || !mConnectingThreads.empty()) { + while (mJoinThreadRunning || !mConnectingThreads.empty() || !mSessions.empty()) { ALOGI("Waiting for RpcServer to shut down. Join thread running: %d, Connecting threads: " - "%zu", - mJoinThreadRunning, mConnectingThreads.size()); + "%zu, Sessions: %zu", + mJoinThreadRunning, mConnectingThreads.size(), mSessions.size()); mShutdownCv.wait(_l); } @@ -278,7 +278,8 @@ void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clie server->mSessionIdCounter++; session = RpcSession::make(); - session->setForServer(wp<RpcServer>(server), server->mSessionIdCounter); + session->setForServer(wp<RpcServer>(server), server->mSessionIdCounter, + server->mShutdownTrigger); server->mSessions[server->mSessionIdCounter] = session; } else { @@ -344,6 +345,11 @@ void RpcServer::onSessionTerminating(const sp<RpcSession>& session) { (void)mSessions.erase(it); } +void RpcServer::onSessionThreadEnding(const sp<RpcSession>& session) { + (void)session; + mShutdownCv.notify_all(); +} + bool RpcServer::hasServer() { LOG_ALWAYS_FATAL_IF(!mAgreedExperimental, "no!"); std::lock_guard<std::mutex> _l(mLock); diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp index 9f26a33335..7c458c123a 100644 --- a/libs/binder/RpcSession.cpp +++ b/libs/binder/RpcSession.cpp @@ -207,12 +207,19 @@ void RpcSession::join(unique_fd client) { LOG_ALWAYS_FATAL_IF(!removeServerConnection(connection), "bad state: connection object guaranteed to be in list"); + sp<RpcServer> server; { std::lock_guard<std::mutex> _l(mMutex); auto it = mThreads.find(std::this_thread::get_id()); LOG_ALWAYS_FATAL_IF(it == mThreads.end()); it->second.detach(); mThreads.erase(it); + + server = mForServer.promote(); + } + + if (server != nullptr) { + server->onSessionThreadEnding(sp<RpcSession>::fromExisting(this)); } } @@ -314,14 +321,25 @@ bool RpcSession::setupOneSocketClient(const RpcSocketAddress& addr, int32_t id) void RpcSession::addClientConnection(unique_fd fd) { std::lock_guard<std::mutex> _l(mMutex); + + if (mShutdownTrigger == nullptr) { + mShutdownTrigger = FdTrigger::make(); + } + sp<RpcConnection> session = sp<RpcConnection>::make(); session->fd = std::move(fd); mClientConnections.push_back(session); } -void RpcSession::setForServer(const wp<RpcServer>& server, int32_t sessionId) { +void RpcSession::setForServer(const wp<RpcServer>& server, int32_t sessionId, + const std::shared_ptr<FdTrigger>& shutdownTrigger) { + LOG_ALWAYS_FATAL_IF(mForServer.unsafe_get() != nullptr); + LOG_ALWAYS_FATAL_IF(mShutdownTrigger != nullptr); + LOG_ALWAYS_FATAL_IF(shutdownTrigger == nullptr); + mId = sessionId; mForServer = server; + mShutdownTrigger = shutdownTrigger; } sp<RpcSession::RpcConnection> RpcSession::assignServerToThisThread(unique_fd fd) { diff --git a/libs/binder/RpcState.cpp b/libs/binder/RpcState.cpp index 230de6f0ef..6483486340 100644 --- a/libs/binder/RpcState.cpp +++ b/libs/binder/RpcState.cpp @@ -229,30 +229,22 @@ bool RpcState::rpcSend(const base::unique_fd& fd, const char* what, const void* return true; } -bool RpcState::rpcRec(const base::unique_fd& fd, const char* what, void* data, size_t size) { +bool RpcState::rpcRec(const base::unique_fd& fd, const sp<RpcSession>& session, const char* what, + void* data, size_t size) { if (size > std::numeric_limits<ssize_t>::max()) { ALOGE("Cannot rec %s at size %zu (too big)", what, size); terminate(); return false; } - ssize_t recd = TEMP_FAILURE_RETRY(recv(fd.get(), data, size, MSG_WAITALL | MSG_NOSIGNAL)); - - if (recd < 0 || recd != static_cast<ssize_t>(size)) { - terminate(); - - if (recd == 0 && errno == 0) { - LOG_RPC_DETAIL("No more data when trying to read %s on fd %d", what, fd.get()); - return false; - } - - ALOGE("Failed to read %s (received %zd of %zu bytes) on fd %d, error: %s", what, recd, size, - fd.get(), strerror(errno)); + if (status_t status = session->mShutdownTrigger->interruptableReadFully(fd.get(), data, size); + status != OK) { + ALOGE("Failed to read %s (%zu bytes) on fd %d, error: %s", what, size, fd.get(), + statusToString(status).c_str()); return false; - } else { - LOG_RPC_DETAIL("Received %s on fd %d: %s", what, fd.get(), hexString(data, size).c_str()); } + LOG_RPC_DETAIL("Received %s on fd %d: %s", what, fd.get(), hexString(data, size).c_str()); return true; } @@ -398,7 +390,7 @@ status_t RpcState::waitForReply(const base::unique_fd& fd, const sp<RpcSession>& Parcel* reply) { RpcWireHeader command; while (true) { - if (!rpcRec(fd, "command header", &command, sizeof(command))) { + if (!rpcRec(fd, session, "command header", &command, sizeof(command))) { return DEAD_OBJECT; } @@ -413,7 +405,7 @@ status_t RpcState::waitForReply(const base::unique_fd& fd, const sp<RpcSession>& return NO_MEMORY; } - if (!rpcRec(fd, "reply body", data.data(), command.bodySize)) { + if (!rpcRec(fd, session, "reply body", data.data(), command.bodySize)) { return DEAD_OBJECT; } @@ -465,7 +457,7 @@ status_t RpcState::getAndExecuteCommand(const base::unique_fd& fd, const sp<RpcS LOG_RPC_DETAIL("getAndExecuteCommand on fd %d", fd.get()); RpcWireHeader command; - if (!rpcRec(fd, "command header", &command, sizeof(command))) { + if (!rpcRec(fd, session, "command header", &command, sizeof(command))) { return DEAD_OBJECT; } @@ -493,7 +485,7 @@ status_t RpcState::processServerCommand(const base::unique_fd& fd, const sp<RpcS case RPC_COMMAND_TRANSACT: return processTransact(fd, session, command); case RPC_COMMAND_DEC_STRONG: - return processDecStrong(fd, command); + return processDecStrong(fd, session, command); } // We should always know the version of the opposing side, and since the @@ -513,7 +505,7 @@ status_t RpcState::processTransact(const base::unique_fd& fd, const sp<RpcSessio if (!transactionData.valid()) { return NO_MEMORY; } - if (!rpcRec(fd, "transaction body", transactionData.data(), transactionData.size())) { + if (!rpcRec(fd, session, "transaction body", transactionData.data(), transactionData.size())) { return DEAD_OBJECT; } @@ -626,7 +618,7 @@ status_t RpcState::processTransactInternal(const base::unique_fd& fd, const sp<R // // sessions associated with servers must have an ID // (hence abort) - int32_t id = session->getPrivateAccessorForId().get().value(); + int32_t id = session->mId.value(); replyStatus = reply.writeInt32(id); break; } @@ -721,14 +713,15 @@ status_t RpcState::processTransactInternal(const base::unique_fd& fd, const sp<R return OK; } -status_t RpcState::processDecStrong(const base::unique_fd& fd, const RpcWireHeader& command) { +status_t RpcState::processDecStrong(const base::unique_fd& fd, const sp<RpcSession>& session, + const RpcWireHeader& command) { LOG_ALWAYS_FATAL_IF(command.command != RPC_COMMAND_DEC_STRONG, "command: %d", command.command); CommandData commandData(command.bodySize); if (!commandData.valid()) { return NO_MEMORY; } - if (!rpcRec(fd, "dec ref body", commandData.data(), commandData.size())) { + if (!rpcRec(fd, session, "dec ref body", commandData.data(), commandData.size())) { return DEAD_OBJECT; } diff --git a/libs/binder/RpcState.h b/libs/binder/RpcState.h index 31f8a22065..f913925adb 100644 --- a/libs/binder/RpcState.h +++ b/libs/binder/RpcState.h @@ -117,7 +117,8 @@ private: [[nodiscard]] bool rpcSend(const base::unique_fd& fd, const char* what, const void* data, size_t size); - [[nodiscard]] bool rpcRec(const base::unique_fd& fd, const char* what, void* data, size_t size); + [[nodiscard]] bool rpcRec(const base::unique_fd& fd, const sp<RpcSession>& session, + const char* what, void* data, size_t size); [[nodiscard]] status_t waitForReply(const base::unique_fd& fd, const sp<RpcSession>& session, Parcel* reply); @@ -130,6 +131,7 @@ private: const sp<RpcSession>& session, CommandData transactionData); [[nodiscard]] status_t processDecStrong(const base::unique_fd& fd, + const sp<RpcSession>& session, const RpcWireHeader& command); struct BinderNode { diff --git a/libs/binder/include/binder/RpcServer.h b/libs/binder/include/binder/RpcServer.h index 50770f12d2..178459d2fe 100644 --- a/libs/binder/include/binder/RpcServer.h +++ b/libs/binder/include/binder/RpcServer.h @@ -150,6 +150,7 @@ public: // internal use only void onSessionTerminating(const sp<RpcSession>& session); + void onSessionThreadEnding(const sp<RpcSession>& session); private: friend sp<RpcServer>; @@ -171,7 +172,7 @@ private: wp<IBinder> mRootObjectWeak; std::map<int32_t, sp<RpcSession>> mSessions; int32_t mSessionIdCounter = 0; - std::unique_ptr<RpcSession::FdTrigger> mShutdownTrigger; + std::shared_ptr<RpcSession::FdTrigger> mShutdownTrigger; std::condition_variable mShutdownCv; }; diff --git a/libs/binder/include/binder/RpcSession.h b/libs/binder/include/binder/RpcSession.h index d6b796f281..d46f27567c 100644 --- a/libs/binder/include/binder/RpcSession.h +++ b/libs/binder/include/binder/RpcSession.h @@ -94,27 +94,16 @@ public: // internal only const std::unique_ptr<RpcState>& state() { return mState; } - class PrivateAccessorForId { - private: - friend class RpcSession; - friend class RpcState; - explicit PrivateAccessorForId(const RpcSession* session) : mSession(session) {} - - const std::optional<int32_t> get() { return mSession->mId; } - - const RpcSession* mSession; - }; - PrivateAccessorForId getPrivateAccessorForId() const { return PrivateAccessorForId(this); } - private: - friend PrivateAccessorForId; friend sp<RpcSession>; friend RpcServer; + friend RpcState; RpcSession(); /** This is not a pipe. */ struct FdTrigger { static std::unique_ptr<FdTrigger> make(); + /** * poll() on this fd for POLLHUP to get notification when trigger is called */ @@ -167,7 +156,8 @@ private: bool setupSocketClient(const RpcSocketAddress& address); bool setupOneSocketClient(const RpcSocketAddress& address, int32_t sessionId); void addClientConnection(base::unique_fd fd); - void setForServer(const wp<RpcServer>& server, int32_t sessionId); + void setForServer(const wp<RpcServer>& server, int32_t sessionId, + const std::shared_ptr<FdTrigger>& shutdownTrigger); sp<RpcConnection> assignServerToThisThread(base::unique_fd fd); bool removeServerConnection(const sp<RpcConnection>& connection); @@ -218,6 +208,8 @@ private: // TODO(b/183988761): this shouldn't be guessable std::optional<int32_t> mId; + std::shared_ptr<FdTrigger> mShutdownTrigger; + std::unique_ptr<RpcState> mState; std::mutex mMutex; // for all below |