diff options
-rw-r--r-- | libs/binder/FdTrigger.cpp | 64 | ||||
-rw-r--r-- | libs/binder/Parcel.cpp | 3 | ||||
-rw-r--r-- | libs/binder/RpcServer.cpp | 9 | ||||
-rw-r--r-- | libs/binder/RpcSession.cpp | 12 | ||||
-rw-r--r-- | libs/binder/RpcState.cpp | 160 | ||||
-rw-r--r-- | libs/binder/RpcState.h | 35 | ||||
-rw-r--r-- | libs/binder/RpcTransportRaw.cpp | 90 | ||||
-rw-r--r-- | libs/binder/RpcTransportTls.cpp | 43 | ||||
-rw-r--r-- | libs/binder/RpcWireFormat.h | 9 | ||||
-rw-r--r-- | libs/binder/include/binder/RpcSession.h | 3 | ||||
-rw-r--r-- | libs/binder/include/binder/RpcTransport.h | 15 | ||||
-rw-r--r-- | libs/binder/rust/src/native.rs | 4 | ||||
-rw-r--r-- | libs/binder/tests/binderRpcTest.cpp | 10 |
13 files changed, 325 insertions, 132 deletions
diff --git a/libs/binder/FdTrigger.cpp b/libs/binder/FdTrigger.cpp index 49f83ff346..5e22593f69 100644 --- a/libs/binder/FdTrigger.cpp +++ b/libs/binder/FdTrigger.cpp @@ -17,11 +17,13 @@ #define LOG_TAG "FdTrigger" #include <log/log.h> +#include "FdTrigger.h" + #include <poll.h> #include <android-base/macros.h> -#include "FdTrigger.h" +#include "RpcState.h" namespace android { std::unique_ptr<FdTrigger> FdTrigger::make() { @@ -42,21 +44,53 @@ bool FdTrigger::isTriggered() { } status_t FdTrigger::triggerablePoll(base::borrowed_fd fd, int16_t event) { - while (true) { - pollfd pfd[]{{.fd = fd.get(), .events = static_cast<int16_t>(event), .revents = 0}, - {.fd = mRead.get(), .events = 0, .revents = 0}}; - int ret = TEMP_FAILURE_RETRY(poll(pfd, arraysize(pfd), -1)); - if (ret < 0) { - return -errno; - } - if (ret == 0) { - continue; - } - if (pfd[1].revents & POLLHUP) { - return DEAD_OBJECT; - } - return pfd[0].revents & event ? OK : DEAD_OBJECT; + LOG_ALWAYS_FATAL_IF(event == 0, "triggerablePoll %d with event 0 is not allowed", fd.get()); + pollfd pfd[]{{.fd = fd.get(), .events = static_cast<int16_t>(event), .revents = 0}, + {.fd = mRead.get(), .events = 0, .revents = 0}}; + int ret = TEMP_FAILURE_RETRY(poll(pfd, arraysize(pfd), -1)); + if (ret < 0) { + return -errno; + } + LOG_ALWAYS_FATAL_IF(ret == 0, "poll(%d) returns 0 with infinite timeout", fd.get()); + + // At least one FD has events. Check them. + + // Detect explicit trigger(): DEAD_OBJECT + if (pfd[1].revents & POLLHUP) { + return DEAD_OBJECT; } + // See unknown flags in trigger FD's revents (POLLERR / POLLNVAL). + // Treat this error condition as UNKNOWN_ERROR. + if (pfd[1].revents != 0) { + ALOGE("Unknown revents on trigger FD %d: revents = %d", pfd[1].fd, pfd[1].revents); + return UNKNOWN_ERROR; + } + + // pfd[1].revents is 0, hence pfd[0].revents must be set, and only possible values are + // a subset of event | POLLHUP | POLLERR | POLLNVAL. + + // POLLNVAL: invalid FD number, e.g. not opened. + if (pfd[0].revents & POLLNVAL) { + return BAD_VALUE; + } + + // Error condition. It wouldn't be possible to do I/O on |fd| afterwards. + // Note: If this is the write end of a pipe then POLLHUP may also be set simultaneously. We + // still want DEAD_OBJECT in this case. + if (pfd[0].revents & POLLERR) { + LOG_RPC_DETAIL("poll() incoming FD %d results in revents = %d", pfd[0].fd, pfd[0].revents); + return DEAD_OBJECT; + } + + // Success condition; event flag(s) set. Even though POLLHUP may also be set, + // treat it as a success condition to ensure data is drained. + if (pfd[0].revents & event) { + return OK; + } + + // POLLHUP: Peer closed connection. Treat as DEAD_OBJECT. + // This is a very common case, so don't log. + return DEAD_OBJECT; } } // namespace android diff --git a/libs/binder/Parcel.cpp b/libs/binder/Parcel.cpp index 6644187507..6ce09226d9 100644 --- a/libs/binder/Parcel.cpp +++ b/libs/binder/Parcel.cpp @@ -291,6 +291,9 @@ status_t Parcel::unflattenBinder(sp<IBinder>* out) const if (status_t status = mSession->state()->onBinderEntering(mSession, addr, &binder); status != OK) return status; + if (status_t status = mSession->state()->flushExcessBinderRefs(mSession, addr, binder); + status != OK) + return status; } return finishUnflattenBinder(binder, out); diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp index 5733993b3b..ba2920e3ac 100644 --- a/libs/binder/RpcServer.cpp +++ b/libs/binder/RpcServer.cpp @@ -205,8 +205,11 @@ bool RpcServer::shutdown() { } mShutdownTrigger->trigger(); + for (auto& [id, session] : mSessions) { (void)id; + // server lock is a more general lock + std::lock_guard<std::mutex> _lSession(session->mMutex); session->mShutdownTrigger->trigger(); } @@ -275,7 +278,7 @@ void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clie RpcConnectionHeader header; if (status == OK) { status = client->interruptableReadFully(server->mShutdownTrigger.get(), &header, - sizeof(header)); + sizeof(header), {}); if (status != OK) { ALOGE("Failed to read ID for client connecting to RPC server: %s", statusToString(status).c_str()); @@ -288,7 +291,7 @@ void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clie if (header.sessionIdSize > 0) { sessionId.resize(header.sessionIdSize); status = client->interruptableReadFully(server->mShutdownTrigger.get(), - sessionId.data(), sessionId.size()); + sessionId.data(), sessionId.size(), {}); if (status != OK) { ALOGE("Failed to read session ID for client connecting to RPC server: %s", statusToString(status).c_str()); @@ -313,7 +316,7 @@ void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clie }; status = client->interruptableWriteFully(server->mShutdownTrigger.get(), &response, - sizeof(response)); + sizeof(response), {}); if (status != OK) { ALOGE("Failed to send new session response: %s", statusToString(status).c_str()); // still need to cleanup before we can return diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp index 38958c93cb..65f6bc68c9 100644 --- a/libs/binder/RpcSession.cpp +++ b/libs/binder/RpcSession.cpp @@ -217,15 +217,17 @@ status_t RpcSession::transact(const sp<IBinder>& binder, uint32_t code, const Pa } status_t RpcSession::sendDecStrong(const BpBinder* binder) { - return sendDecStrong(binder->getPrivateAccessor().rpcAddress()); + // target is 0 because this is used to free BpBinder objects + return sendDecStrongToTarget(binder->getPrivateAccessor().rpcAddress(), 0 /*target*/); } -status_t RpcSession::sendDecStrong(uint64_t address) { +status_t RpcSession::sendDecStrongToTarget(uint64_t address, size_t target) { ExclusiveConnection connection; status_t status = ExclusiveConnection::find(sp<RpcSession>::fromExisting(this), ConnectionUse::CLIENT_REFCOUNT, &connection); if (status != OK) return status; - return state()->sendDecStrong(connection.get(), sp<RpcSession>::fromExisting(this), address); + return state()->sendDecStrongToTarget(connection.get(), sp<RpcSession>::fromExisting(this), + address, target); } status_t RpcSession::readId() { @@ -558,7 +560,7 @@ status_t RpcSession::initAndAddConnection(unique_fd fd, const std::vector<uint8_ } auto sendHeaderStatus = - server->interruptableWriteFully(mShutdownTrigger.get(), &header, sizeof(header)); + server->interruptableWriteFully(mShutdownTrigger.get(), &header, sizeof(header), {}); if (sendHeaderStatus != OK) { ALOGE("Could not write connection header to socket: %s", statusToString(sendHeaderStatus).c_str()); @@ -568,7 +570,7 @@ status_t RpcSession::initAndAddConnection(unique_fd fd, const std::vector<uint8_ if (sessionId.size() > 0) { auto sendSessionIdStatus = server->interruptableWriteFully(mShutdownTrigger.get(), sessionId.data(), - sessionId.size()); + sessionId.size(), {}); if (sendSessionIdStatus != OK) { ALOGE("Could not write session ID ('%s') to socket: %s", base::HexString(sessionId.data(), sessionId.size()).c_str(), diff --git a/libs/binder/RpcState.cpp b/libs/binder/RpcState.cpp index df935fe9e3..86cc91c03e 100644 --- a/libs/binder/RpcState.cpp +++ b/libs/binder/RpcState.cpp @@ -152,7 +152,7 @@ status_t RpcState::onBinderEntering(const sp<RpcSession>& session, uint64_t addr return BAD_VALUE; } - std::unique_lock<std::mutex> _l(mNodeMutex); + std::lock_guard<std::mutex> _l(mNodeMutex); if (mTerminated) return DEAD_OBJECT; if (auto it = mNodeForAddress.find(address); it != mNodeForAddress.end()) { @@ -160,13 +160,7 @@ status_t RpcState::onBinderEntering(const sp<RpcSession>& session, uint64_t addr // implicitly have strong RPC refcount, since we received this binder it->second.timesRecd++; - - _l.unlock(); - - // We have timesRecd RPC refcounts, but we only need to hold on to one - // when we keep the object. All additional dec strongs are sent - // immediately, we wait to send the last one in BpBinder::onLastDecStrong. - return session->sendDecStrong(address); + return OK; } // we don't know about this binder, so the other side of the connection @@ -187,6 +181,36 @@ status_t RpcState::onBinderEntering(const sp<RpcSession>& session, uint64_t addr return OK; } +status_t RpcState::flushExcessBinderRefs(const sp<RpcSession>& session, uint64_t address, + const sp<IBinder>& binder) { + std::unique_lock<std::mutex> _l(mNodeMutex); + if (mTerminated) return DEAD_OBJECT; + + auto it = mNodeForAddress.find(address); + + LOG_ALWAYS_FATAL_IF(it == mNodeForAddress.end(), "Can't be deleted while we hold sp<>"); + LOG_ALWAYS_FATAL_IF(it->second.binder != binder, + "Caller of flushExcessBinderRefs using inconsistent arguments"); + + // if this is a local binder, then we want to get rid of all refcounts + // (tell the other process it can drop the binder when it wants to - we + // have a local sp<>, so we will drop it when we want to as well). if + // this is a remote binder, then we need to hold onto one refcount until + // it is dropped in BpBinder::onLastStrongRef + size_t targetRecd = binder->localBinder() ? 0 : 1; + + // We have timesRecd RPC refcounts, but we only need to hold on to one + // when we keep the object. All additional dec strongs are sent + // immediately, we wait to send the last one in BpBinder::onLastDecStrong. + if (it->second.timesRecd != targetRecd) { + _l.unlock(); + + return session->sendDecStrongToTarget(address, targetRecd); + } + + return OK; +} + size_t RpcState::countBinders() { std::lock_guard<std::mutex> _l(mNodeMutex); return mNodeForAddress.size(); @@ -283,7 +307,7 @@ RpcState::CommandData::CommandData(size_t size) : mSize(size) { status_t RpcState::rpcSend(const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session, const char* what, const void* data, - size_t size) { + size_t size, const std::function<status_t()>& altPoll) { LOG_RPC_DETAIL("Sending %s on RpcTransport %p: %s", what, connection->rpcTransport.get(), android::base::HexString(data, size).c_str()); @@ -295,7 +319,7 @@ status_t RpcState::rpcSend(const sp<RpcSession::RpcConnection>& connection, if (status_t status = connection->rpcTransport->interruptableWriteFully(session->mShutdownTrigger.get(), - data, size); + data, size, altPoll); status != OK) { LOG_RPC_DETAIL("Failed to write %s (%zu bytes) on RpcTransport %p, error: %s", what, size, connection->rpcTransport.get(), statusToString(status).c_str()); @@ -317,7 +341,7 @@ status_t RpcState::rpcRec(const sp<RpcSession::RpcConnection>& connection, if (status_t status = connection->rpcTransport->interruptableReadFully(session->mShutdownTrigger.get(), - data, size); + data, size, {}); status != OK) { LOG_RPC_DETAIL("Failed to read %s (%zu bytes) on RpcTransport %p, error: %s", what, size, connection->rpcTransport.get(), statusToString(status).c_str()); @@ -427,12 +451,16 @@ status_t RpcState::transact(const sp<RpcSession::RpcConnection>& connection, const sp<IBinder>& binder, uint32_t code, const Parcel& data, const sp<RpcSession>& session, Parcel* reply, uint32_t flags) { if (!data.isForRpc()) { - ALOGE("Refusing to send RPC with parcel not crafted for RPC"); + ALOGE("Refusing to send RPC with parcel not crafted for RPC call on binder %p code " + "%" PRIu32, + binder.get(), code); return BAD_TYPE; } if (data.objectsCount() != 0) { - ALOGE("Parcel at %p has attached objects but is being used in an RPC call", &data); + ALOGE("Parcel at %p has attached objects but is being used in an RPC call on binder %p " + "code %" PRIu32, + &data, binder.get(), code); return BAD_TYPE; } @@ -495,21 +523,44 @@ status_t RpcState::transactAddress(const sp<RpcSession::RpcConnection>& connecti memcpy(transactionData.data() + sizeof(RpcWireHeader) + sizeof(RpcWireTransaction), data.data(), data.dataSize()); + constexpr size_t kWaitMaxUs = 1000000; + constexpr size_t kWaitLogUs = 10000; + size_t waitUs = 0; + + // Oneway calls have no sync point, so if many are sent before, whether this + // is a twoway or oneway transaction, they may have filled up the socket. + // So, make sure we drain them before polling. + std::function<status_t()> drainRefs = [&] { + if (waitUs > kWaitLogUs) { + ALOGE("Cannot send command, trying to process pending refcounts. Waiting %zuus. Too " + "many oneway calls?", + waitUs); + } + + if (waitUs > 0) { + usleep(waitUs); + waitUs = std::min(kWaitMaxUs, waitUs * 2); + } else { + waitUs = 1; + } + + return drainCommands(connection, session, CommandType::CONTROL_ONLY); + }; + if (status_t status = rpcSend(connection, session, "transaction", transactionData.data(), - transactionData.size()); - status != OK) + transactionData.size(), drainRefs); + status != OK) { // TODO(b/167966510): need to undo onBinderLeaving - we know the // refcount isn't successfully transferred. return status; + } if (flags & IBinder::FLAG_ONEWAY) { LOG_RPC_DETAIL("Oneway command, so no longer waiting on RpcTransport %p", connection->rpcTransport.get()); // Do not wait on result. - // However, too many oneway calls may cause refcounts to build up and fill up the socket, - // so process those. - return drainCommands(connection, session, CommandType::CONTROL_ONLY); + return OK; } LOG_ALWAYS_FATAL_IF(reply == nullptr, "Reply parcel must be used for synchronous transaction."); @@ -567,32 +618,43 @@ status_t RpcState::waitForReply(const sp<RpcSession::RpcConnection>& connection, return OK; } -status_t RpcState::sendDecStrong(const sp<RpcSession::RpcConnection>& connection, - const sp<RpcSession>& session, uint64_t addr) { +status_t RpcState::sendDecStrongToTarget(const sp<RpcSession::RpcConnection>& connection, + const sp<RpcSession>& session, uint64_t addr, + size_t target) { + RpcDecStrong body = { + .address = RpcWireAddress::fromRaw(addr), + }; + { std::lock_guard<std::mutex> _l(mNodeMutex); if (mTerminated) return DEAD_OBJECT; // avoid fatal only, otherwise races auto it = mNodeForAddress.find(addr); LOG_ALWAYS_FATAL_IF(it == mNodeForAddress.end(), "Sending dec strong on unknown address %" PRIu64, addr); - LOG_ALWAYS_FATAL_IF(it->second.timesRecd <= 0, "Bad dec strong %" PRIu64, addr); - it->second.timesRecd--; + LOG_ALWAYS_FATAL_IF(it->second.timesRecd < target, "Can't dec count of %zu to %zu.", + it->second.timesRecd, target); + + // typically this happens when multiple threads send dec refs at the + // same time - the transactions will get combined automatically + if (it->second.timesRecd == target) return OK; + + body.amount = it->second.timesRecd - target; + it->second.timesRecd = target; + LOG_ALWAYS_FATAL_IF(nullptr != tryEraseNode(it), "Bad state. RpcState shouldn't own received binder"); } RpcWireHeader cmd = { .command = RPC_COMMAND_DEC_STRONG, - .bodySize = sizeof(RpcWireAddress), + .bodySize = sizeof(RpcDecStrong), }; if (status_t status = rpcSend(connection, session, "dec ref header", &cmd, sizeof(cmd)); status != OK) return status; - if (status_t status = rpcSend(connection, session, "dec ref body", &addr, sizeof(addr)); - status != OK) - return status; - return OK; + + return rpcSend(connection, session, "dec ref body", &body, sizeof(body)); } status_t RpcState::getAndExecuteCommand(const sp<RpcSession::RpcConnection>& connection, @@ -684,7 +746,7 @@ status_t RpcState::processTransactInternal(const sp<RpcSession::RpcConnection>& // for 'recursive' calls to this, we have already read and processed the // binder from the transaction data and taken reference counts into account, // so it is cached here. - sp<IBinder> targetRef; + sp<IBinder> target; processTransactInternalTailCall: if (transactionData.size() < sizeof(RpcWireTransaction)) { @@ -699,12 +761,9 @@ processTransactInternalTailCall: bool oneway = transaction->flags & IBinder::FLAG_ONEWAY; status_t replyStatus = OK; - sp<IBinder> target; if (addr != 0) { - if (!targetRef) { + if (!target) { replyStatus = onBinderEntering(session, addr, &target); - } else { - target = targetRef; } if (replyStatus != OK) { @@ -823,6 +882,12 @@ processTransactInternalTailCall: } } + // Binder refs are flushed for oneway calls only after all calls which are + // built up are executed. Otherwise, they fill up the binder buffer. + if (addr != 0 && replyStatus == OK && !oneway) { + replyStatus = flushExcessBinderRefs(session, addr, target); + } + if (oneway) { if (replyStatus != OK) { ALOGW("Oneway call failed with error: %d", replyStatus); @@ -865,12 +930,20 @@ processTransactInternalTailCall: // reset up arguments transactionData = std::move(todo.data); - targetRef = std::move(todo.ref); + LOG_ALWAYS_FATAL_IF(target != todo.ref, + "async list should be associated with a binder"); it->second.asyncTodo.pop(); goto processTransactInternalTailCall; } } + + // done processing all the async commands on this binder that we can, so + // write decstrongs on the binder + if (addr != 0 && replyStatus == OK) { + return flushExcessBinderRefs(session, addr, target); + } + return OK; } @@ -912,16 +985,15 @@ status_t RpcState::processDecStrong(const sp<RpcSession::RpcConnection>& connect status != OK) return status; - if (command.bodySize != sizeof(RpcWireAddress)) { - ALOGE("Expecting %zu but got %" PRId32 " bytes for RpcWireAddress. Terminating!", - sizeof(RpcWireAddress), command.bodySize); + if (command.bodySize != sizeof(RpcDecStrong)) { + ALOGE("Expecting %zu but got %" PRId32 " bytes for RpcDecStrong. Terminating!", + sizeof(RpcDecStrong), command.bodySize); (void)session->shutdownAndWait(false); return BAD_VALUE; } - RpcWireAddress* address = reinterpret_cast<RpcWireAddress*>(commandData.data()); - - uint64_t addr = RpcWireAddress::toRaw(*address); + RpcDecStrong* body = reinterpret_cast<RpcDecStrong*>(commandData.data()); + uint64_t addr = RpcWireAddress::toRaw(body->address); std::unique_lock<std::mutex> _l(mNodeMutex); auto it = mNodeForAddress.find(addr); if (it == mNodeForAddress.end()) { @@ -939,15 +1011,19 @@ status_t RpcState::processDecStrong(const sp<RpcSession::RpcConnection>& connect return BAD_VALUE; } - if (it->second.timesSent == 0) { - ALOGE("No record of sending binder, but requested decStrong: %" PRIu64, addr); + if (it->second.timesSent < body->amount) { + ALOGE("Record of sending binder %zu times, but requested decStrong for %" PRIu64 " of %u", + it->second.timesSent, addr, body->amount); return OK; } LOG_ALWAYS_FATAL_IF(it->second.sentRef == nullptr, "Inconsistent state, lost ref for %" PRIu64, addr); - it->second.timesSent--; + LOG_RPC_DETAIL("Processing dec strong of %" PRIu64 " by %u from %zu", addr, body->amount, + it->second.timesSent); + + it->second.timesSent -= body->amount; sp<IBinder> tempHold = tryEraseNode(it); _l.unlock(); tempHold = nullptr; // destructor may make binder calls on this session diff --git a/libs/binder/RpcState.h b/libs/binder/RpcState.h index dcfb5699e8..50de22bc29 100644 --- a/libs/binder/RpcState.h +++ b/libs/binder/RpcState.h @@ -82,8 +82,29 @@ public: uint64_t address, uint32_t code, const Parcel& data, const sp<RpcSession>& session, Parcel* reply, uint32_t flags); - [[nodiscard]] status_t sendDecStrong(const sp<RpcSession::RpcConnection>& connection, - const sp<RpcSession>& session, uint64_t address); + + /** + * The ownership model here carries an implicit strong refcount whenever a + * binder is sent across processes. Since we have a local strong count in + * sp<> over these objects, we only ever need to keep one of these. So, + * typically we tell the remote process that we drop all the implicit dec + * strongs, and we hold onto the last one. 'target' here is the target + * timesRecd (the number of remaining reference counts) we wish to keep. + * Typically this should be '0' or '1'. The target is used instead of an + * explicit decrement count in order to allow multiple threads to lower the + * number of counts simultaneously. Since we only lower the count to 0 when + * a binder is deleted, targets of '1' should only be sent when the caller + * owns a local strong reference to the binder. Larger targets may be used + * for testing, and to make the function generic, but generally this should + * be avoided because it would be hard to guarantee another thread doesn't + * lower the number of held refcounts to '1'. Note also, these refcounts + * must be sent actively. If they are sent when binders are deleted, this + * can cause leaks, since even remote binders carry an implicit strong ref + * when they are sent to another process. + */ + [[nodiscard]] status_t sendDecStrongToTarget(const sp<RpcSession::RpcConnection>& connection, + const sp<RpcSession>& session, uint64_t address, + size_t target); enum class CommandType { ANY, @@ -108,6 +129,13 @@ public: */ [[nodiscard]] status_t onBinderEntering(const sp<RpcSession>& session, uint64_t address, sp<IBinder>* out); + /** + * Called on incoming binders to update refcounting information. This should + * only be called when it is done as part of making progress on a + * transaction. + */ + [[nodiscard]] status_t flushExcessBinderRefs(const sp<RpcSession>& session, uint64_t address, + const sp<IBinder>& binder); size_t countBinders(); void dump(); @@ -149,7 +177,8 @@ private: [[nodiscard]] status_t rpcSend(const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session, const char* what, - const void* data, size_t size); + const void* data, size_t size, + const std::function<status_t()>& altPoll = nullptr); [[nodiscard]] status_t rpcRec(const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session, const char* what, void* data, size_t size); diff --git a/libs/binder/RpcTransportRaw.cpp b/libs/binder/RpcTransportRaw.cpp index 41f4a9f2bf..7669518954 100644 --- a/libs/binder/RpcTransportRaw.cpp +++ b/libs/binder/RpcTransportRaw.cpp @@ -43,56 +43,72 @@ public: return ret; } - status_t interruptableWriteFully(FdTrigger* fdTrigger, const void* data, size_t size) override { - const uint8_t* buffer = reinterpret_cast<const uint8_t*>(data); - const uint8_t* end = buffer + size; + template <typename Buffer, typename SendOrReceive> + status_t interruptableReadOrWrite(FdTrigger* fdTrigger, Buffer buffer, size_t size, + SendOrReceive sendOrReceiveFun, const char* funName, + int16_t event, const std::function<status_t()>& altPoll) { + const Buffer end = buffer + size; MAYBE_WAIT_IN_FLAKE_MODE; - status_t status; - while ((status = fdTrigger->triggerablePoll(mSocket.get(), POLLOUT)) == OK) { - ssize_t writeSize = - TEMP_FAILURE_RETRY(::send(mSocket.get(), buffer, end - buffer, MSG_NOSIGNAL)); - if (writeSize < 0) { - int savedErrno = errno; - LOG_RPC_DETAIL("RpcTransport send(): %s", strerror(savedErrno)); - return -savedErrno; - } - - if (writeSize == 0) return DEAD_OBJECT; - - buffer += writeSize; - if (buffer == end) return OK; + // Since we didn't poll, we need to manually check to see if it was triggered. Otherwise, we + // may never know we should be shutting down. + if (fdTrigger->isTriggered()) { + return DEAD_OBJECT; } - return status; - } - - status_t interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size) override { - uint8_t* buffer = reinterpret_cast<uint8_t*>(data); - uint8_t* end = buffer + size; - MAYBE_WAIT_IN_FLAKE_MODE; + bool havePolled = false; + while (true) { + ssize_t processSize = TEMP_FAILURE_RETRY( + sendOrReceiveFun(mSocket.get(), buffer, end - buffer, MSG_NOSIGNAL)); - status_t status; - while ((status = fdTrigger->triggerablePoll(mSocket.get(), POLLIN)) == OK) { - ssize_t readSize = - TEMP_FAILURE_RETRY(::recv(mSocket.get(), buffer, end - buffer, MSG_NOSIGNAL)); - if (readSize < 0) { + if (processSize < 0) { int savedErrno = errno; - LOG_RPC_DETAIL("RpcTransport recv(): %s", strerror(savedErrno)); - return -savedErrno; - } - if (readSize == 0) return DEAD_OBJECT; // EOF + // Still return the error on later passes, since it would expose + // a problem with polling + if (havePolled || + (!havePolled && savedErrno != EAGAIN && savedErrno != EWOULDBLOCK)) { + LOG_RPC_DETAIL("RpcTransport %s(): %s", funName, strerror(savedErrno)); + return -savedErrno; + } + } else if (processSize == 0) { + return DEAD_OBJECT; + } else { + buffer += processSize; + if (buffer == end) { + return OK; + } + } - buffer += readSize; - if (buffer == end) return OK; + if (altPoll) { + if (status_t status = altPoll(); status != OK) return status; + if (fdTrigger->isTriggered()) { + return DEAD_OBJECT; + } + } else { + if (status_t status = fdTrigger->triggerablePoll(mSocket.get(), event); + status != OK) + return status; + if (!havePolled) havePolled = true; + } } - return status; + } + + status_t interruptableWriteFully(FdTrigger* fdTrigger, const void* data, size_t size, + const std::function<status_t()>& altPoll) override { + return interruptableReadOrWrite(fdTrigger, reinterpret_cast<const uint8_t*>(data), size, + send, "send", POLLOUT, altPoll); + } + + status_t interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size, + const std::function<status_t()>& altPoll) override { + return interruptableReadOrWrite(fdTrigger, reinterpret_cast<uint8_t*>(data), size, recv, + "recv", POLLIN, altPoll); } private: - android::base::unique_fd mSocket; + base::unique_fd mSocket; }; // RpcTransportCtx with TLS disabled. diff --git a/libs/binder/RpcTransportTls.cpp b/libs/binder/RpcTransportTls.cpp index f8cd71d434..7f810b17ba 100644 --- a/libs/binder/RpcTransportTls.cpp +++ b/libs/binder/RpcTransportTls.cpp @@ -169,12 +169,13 @@ public: // If |sslError| is WANT_READ / WANT_WRITE, poll for POLLIN / POLLOUT respectively. Otherwise // return error. Also return error if |fdTrigger| is triggered before or during poll(). status_t pollForSslError(android::base::borrowed_fd fd, int sslError, FdTrigger* fdTrigger, - const char* fnString, int additionalEvent = 0) { + const char* fnString, int additionalEvent, + const std::function<status_t()>& altPoll) { switch (sslError) { case SSL_ERROR_WANT_READ: - return handlePoll(POLLIN | additionalEvent, fd, fdTrigger, fnString); + return handlePoll(POLLIN | additionalEvent, fd, fdTrigger, fnString, altPoll); case SSL_ERROR_WANT_WRITE: - return handlePoll(POLLOUT | additionalEvent, fd, fdTrigger, fnString); + return handlePoll(POLLOUT | additionalEvent, fd, fdTrigger, fnString, altPoll); case SSL_ERROR_SYSCALL: { auto queue = toString(); LOG_TLS_DETAIL("%s(): %s. Treating as DEAD_OBJECT. Error queue: %s", fnString, @@ -194,11 +195,17 @@ private: bool mHandled = false; status_t handlePoll(int event, android::base::borrowed_fd fd, FdTrigger* fdTrigger, - const char* fnString) { - status_t ret = fdTrigger->triggerablePoll(fd, event); + const char* fnString, const std::function<status_t()>& altPoll) { + status_t ret; + if (altPoll) { + ret = altPoll(); + if (fdTrigger->isTriggered()) ret = DEAD_OBJECT; + } else { + ret = fdTrigger->triggerablePoll(fd, event); + } + if (ret != OK && ret != DEAD_OBJECT) { - ALOGE("triggerablePoll error while poll()-ing after %s(): %s", fnString, - statusToString(ret).c_str()); + ALOGE("poll error while after %s(): %s", fnString, statusToString(ret).c_str()); } clear(); return ret; @@ -268,8 +275,10 @@ public: RpcTransportTls(android::base::unique_fd socket, Ssl ssl) : mSocket(std::move(socket)), mSsl(std::move(ssl)) {} Result<size_t> peek(void* buf, size_t size) override; - status_t interruptableWriteFully(FdTrigger* fdTrigger, const void* data, size_t size) override; - status_t interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size) override; + status_t interruptableWriteFully(FdTrigger* fdTrigger, const void* data, size_t size, + const std::function<status_t()>& altPoll) override; + status_t interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size, + const std::function<status_t()>& altPoll) override; private: android::base::unique_fd mSocket; @@ -295,7 +304,8 @@ Result<size_t> RpcTransportTls::peek(void* buf, size_t size) { } status_t RpcTransportTls::interruptableWriteFully(FdTrigger* fdTrigger, const void* data, - size_t size) { + size_t size, + const std::function<status_t()>& altPoll) { auto buffer = reinterpret_cast<const uint8_t*>(data); const uint8_t* end = buffer + size; @@ -317,8 +327,8 @@ status_t RpcTransportTls::interruptableWriteFully(FdTrigger* fdTrigger, const vo int sslError = mSsl.getError(writeSize); // TODO(b/195788248): BIO should contain the FdTrigger, and send(2) / recv(2) should be // triggerablePoll()-ed. Then additionalEvent is no longer necessary. - status_t pollStatus = - errorQueue.pollForSslError(mSocket.get(), sslError, fdTrigger, "SSL_write", POLLIN); + status_t pollStatus = errorQueue.pollForSslError(mSocket.get(), sslError, fdTrigger, + "SSL_write", POLLIN, altPoll); if (pollStatus != OK) return pollStatus; // Do not advance buffer. Try SSL_write() again. } @@ -326,7 +336,8 @@ status_t RpcTransportTls::interruptableWriteFully(FdTrigger* fdTrigger, const vo return OK; } -status_t RpcTransportTls::interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size) { +status_t RpcTransportTls::interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size, + const std::function<status_t()>& altPoll) { auto buffer = reinterpret_cast<uint8_t*>(data); uint8_t* end = buffer + size; @@ -350,8 +361,8 @@ status_t RpcTransportTls::interruptableReadFully(FdTrigger* fdTrigger, void* dat return DEAD_OBJECT; } int sslError = mSsl.getError(readSize); - status_t pollStatus = - errorQueue.pollForSslError(mSocket.get(), sslError, fdTrigger, "SSL_read"); + status_t pollStatus = errorQueue.pollForSslError(mSocket.get(), sslError, fdTrigger, + "SSL_read", 0, altPoll); if (pollStatus != OK) return pollStatus; // Do not advance buffer. Try SSL_read() again. } @@ -382,7 +393,7 @@ bool setFdAndDoHandshake(Ssl* ssl, android::base::borrowed_fd fd, FdTrigger* fdT } int sslError = ssl->getError(ret); status_t pollStatus = - errorQueue.pollForSslError(fd, sslError, fdTrigger, "SSL_do_handshake"); + errorQueue.pollForSslError(fd, sslError, fdTrigger, "SSL_do_handshake", 0, {}); if (pollStatus != OK) return false; } } diff --git a/libs/binder/RpcWireFormat.h b/libs/binder/RpcWireFormat.h index a87aa074a9..171550e620 100644 --- a/libs/binder/RpcWireFormat.h +++ b/libs/binder/RpcWireFormat.h @@ -85,7 +85,7 @@ enum : uint32_t { */ RPC_COMMAND_REPLY, /** - * follows is RpcWireAddress + * follows is RpcDecStrong * * note - this in the protocol directly instead of as a 'special * transaction' in order to keep it as lightweight as possible (we don't @@ -117,6 +117,13 @@ struct RpcWireHeader { }; static_assert(sizeof(RpcWireHeader) == 16); +struct RpcDecStrong { + RpcWireAddress address; + uint32_t amount; + uint32_t reserved; +}; +static_assert(sizeof(RpcDecStrong) == 16); + struct RpcWireTransaction { RpcWireAddress address; uint32_t code; diff --git a/libs/binder/include/binder/RpcSession.h b/libs/binder/include/binder/RpcSession.h index 6a29c05e36..12d448d1e4 100644 --- a/libs/binder/include/binder/RpcSession.h +++ b/libs/binder/include/binder/RpcSession.h @@ -176,7 +176,8 @@ private: friend RpcState; explicit RpcSession(std::unique_ptr<RpcTransportCtx> ctx); - [[nodiscard]] status_t sendDecStrong(uint64_t address); + // for 'target', see RpcState::sendDecStrongToTarget + [[nodiscard]] status_t sendDecStrongToTarget(uint64_t address, size_t target); class EventListener : public virtual RefBase { public: diff --git a/libs/binder/include/binder/RpcTransport.h b/libs/binder/include/binder/RpcTransport.h index 4fe2324d07..db8b5e920e 100644 --- a/libs/binder/include/binder/RpcTransport.h +++ b/libs/binder/include/binder/RpcTransport.h @@ -18,6 +18,7 @@ #pragma once +#include <functional> #include <memory> #include <string> @@ -43,14 +44,20 @@ public: /** * Read (or write), but allow to be interrupted by a trigger. * + * altPoll - function to be called instead of polling, when needing to wait + * to read/write data. If this returns an error, that error is returned from + * this function. + * * Return: * OK - succeeded in completely processing 'size' * error - interrupted (failure or trigger) */ - [[nodiscard]] virtual status_t interruptableWriteFully(FdTrigger *fdTrigger, const void *buf, - size_t size) = 0; - [[nodiscard]] virtual status_t interruptableReadFully(FdTrigger *fdTrigger, void *buf, - size_t size) = 0; + [[nodiscard]] virtual status_t interruptableWriteFully( + FdTrigger *fdTrigger, const void *buf, size_t size, + const std::function<status_t()> &altPoll) = 0; + [[nodiscard]] virtual status_t interruptableReadFully( + FdTrigger *fdTrigger, void *buf, size_t size, + const std::function<status_t()> &altPoll) = 0; protected: RpcTransport() = default; diff --git a/libs/binder/rust/src/native.rs b/libs/binder/rust/src/native.rs index e7c33960e2..a91092e2d4 100644 --- a/libs/binder/rust/src/native.rs +++ b/libs/binder/rust/src/native.rs @@ -441,6 +441,8 @@ unsafe impl<B: Remotable> AsNative<sys::AIBinder> for Binder<B> { /// /// Registers the given binder object with the given identifier. If successful, /// this service can then be retrieved using that identifier. +/// +/// This function will panic if the identifier contains a 0 byte (NUL). pub fn add_service(identifier: &str, mut binder: SpIBinder) -> Result<()> { let instance = CString::new(identifier).unwrap(); let status = unsafe { @@ -462,6 +464,8 @@ pub fn add_service(identifier: &str, mut binder: SpIBinder) -> Result<()> { /// /// If any service in the process is registered as lazy, all should be, otherwise /// the process may be shut down while a service is in use. +/// +/// This function will panic if the identifier contains a 0 byte (NUL). pub fn register_lazy_service(identifier: &str, mut binder: SpIBinder) -> Result<()> { let instance = CString::new(identifier).unwrap(); let status = unsafe { diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp index a2558f5aaf..a1058bced6 100644 --- a/libs/binder/tests/binderRpcTest.cpp +++ b/libs/binder/tests/binderRpcTest.cpp @@ -1573,7 +1573,7 @@ public: FdTrigger* fdTrigger) { std::string message(kMessage); auto status = serverTransport->interruptableWriteFully(fdTrigger, message.data(), - message.size()); + message.size(), {}); if (status != OK) return AssertionFailure() << statusToString(status); return AssertionSuccess(); } @@ -1606,7 +1606,7 @@ public: std::string readMessage(expectedMessage.size(), '\0'); status_t readStatus = mClientTransport->interruptableReadFully(mFdTrigger.get(), readMessage.data(), - readMessage.size()); + readMessage.size(), {}); if (readStatus != OK) { return AssertionFailure() << statusToString(readStatus); } @@ -1800,8 +1800,8 @@ TEST_P(RpcTransportTest, Trigger) { bool shouldContinueWriting = false; auto serverPostConnect = [&](RpcTransport* serverTransport, FdTrigger* fdTrigger) { std::string message(RpcTransportTestUtils::kMessage); - auto status = - serverTransport->interruptableWriteFully(fdTrigger, message.data(), message.size()); + auto status = serverTransport->interruptableWriteFully(fdTrigger, message.data(), + message.size(), {}); if (status != OK) return AssertionFailure() << statusToString(status); { @@ -1811,7 +1811,7 @@ TEST_P(RpcTransportTest, Trigger) { } } - status = serverTransport->interruptableWriteFully(fdTrigger, msg2.data(), msg2.size()); + status = serverTransport->interruptableWriteFully(fdTrigger, msg2.data(), msg2.size(), {}); if (status != DEAD_OBJECT) return AssertionFailure() << "When FdTrigger is shut down, interruptableWriteFully " "should return DEAD_OBJECT, but it is " |