diff options
author | 2021-12-10 08:41:54 +0000 | |
---|---|---|
committer | 2021-12-23 04:42:05 +0000 | |
commit | a39e4edeaa1635a6d6246bb9de470f046e4856e9 (patch) | |
tree | 4337a39f0b2fd9a61fcf1f8df2bef7745c5ba4fe | |
parent | 3183b5c2ccb259ec218420c8bfcd8f7e4f45c05e (diff) |
binder: Eliminate a data copy in RPC transport operations
Switch RpcTransportRaw to use sendmsg() and recvmsg() over
iovecs to send data from multiple buffers to avoid having
to copy all data into a single large buffer.
Bug: 202878542
Test: atest binderRpcTest
Change-Id: I8ba7fa815040555503160ae41888a0b0efe9e5d2
-rw-r--r-- | libs/binder/RpcServer.cpp | 13 | ||||
-rw-r--r-- | libs/binder/RpcSession.cpp | 8 | ||||
-rw-r--r-- | libs/binder/RpcState.cpp | 115 | ||||
-rw-r--r-- | libs/binder/RpcState.h | 10 | ||||
-rw-r--r-- | libs/binder/RpcTransportRaw.cpp | 65 | ||||
-rw-r--r-- | libs/binder/RpcTransportTls.cpp | 95 | ||||
-rw-r--r-- | libs/binder/include/binder/RpcTransport.h | 9 | ||||
-rw-r--r-- | libs/binder/tests/binderRpcTest.cpp | 17 |
8 files changed, 185 insertions, 147 deletions
diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp index 93ed50e986..ace5cd5052 100644 --- a/libs/binder/RpcServer.cpp +++ b/libs/binder/RpcServer.cpp @@ -287,8 +287,8 @@ void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clie RpcConnectionHeader header; if (status == OK) { - status = client->interruptableReadFully(server->mShutdownTrigger.get(), &header, - sizeof(header), {}); + iovec iov{&header, sizeof(header)}; + status = client->interruptableReadFully(server->mShutdownTrigger.get(), &iov, 1, {}); if (status != OK) { ALOGE("Failed to read ID for client connecting to RPC server: %s", statusToString(status).c_str()); @@ -301,8 +301,9 @@ void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clie if (header.sessionIdSize > 0) { if (header.sessionIdSize == kSessionIdBytes) { sessionId.resize(header.sessionIdSize); - status = client->interruptableReadFully(server->mShutdownTrigger.get(), - sessionId.data(), sessionId.size(), {}); + iovec iov{sessionId.data(), sessionId.size()}; + status = + client->interruptableReadFully(server->mShutdownTrigger.get(), &iov, 1, {}); if (status != OK) { ALOGE("Failed to read session ID for client connecting to RPC server: %s", statusToString(status).c_str()); @@ -331,8 +332,8 @@ void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clie .version = protocolVersion, }; - status = client->interruptableWriteFully(server->mShutdownTrigger.get(), &response, - sizeof(response), {}); + iovec iov{&response, sizeof(response)}; + status = client->interruptableWriteFully(server->mShutdownTrigger.get(), &iov, 1, {}); if (status != OK) { ALOGE("Failed to send new session response: %s", statusToString(status).c_str()); // still need to cleanup before we can return diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp index a5a2bb1017..b84395e7cb 100644 --- a/libs/binder/RpcSession.cpp +++ b/libs/binder/RpcSession.cpp @@ -615,8 +615,9 @@ status_t RpcSession::initAndAddConnection(unique_fd fd, const std::vector<uint8_ header.options |= RPC_CONNECTION_OPTION_INCOMING; } + iovec headerIov{&header, sizeof(header)}; auto sendHeaderStatus = - server->interruptableWriteFully(mShutdownTrigger.get(), &header, sizeof(header), {}); + server->interruptableWriteFully(mShutdownTrigger.get(), &headerIov, 1, {}); if (sendHeaderStatus != OK) { ALOGE("Could not write connection header to socket: %s", statusToString(sendHeaderStatus).c_str()); @@ -624,9 +625,10 @@ status_t RpcSession::initAndAddConnection(unique_fd fd, const std::vector<uint8_ } if (sessionId.size() > 0) { + iovec sessionIov{const_cast<void*>(static_cast<const void*>(sessionId.data())), + sessionId.size()}; auto sendSessionIdStatus = - server->interruptableWriteFully(mShutdownTrigger.get(), sessionId.data(), - sessionId.size(), {}); + server->interruptableWriteFully(mShutdownTrigger.get(), &sessionIov, 1, {}); if (sendSessionIdStatus != OK) { ALOGE("Could not write session ID ('%s') to socket: %s", base::HexString(sessionId.data(), sessionId.size()).c_str(), diff --git a/libs/binder/RpcState.cpp b/libs/binder/RpcState.cpp index 09b3d68626..6286c9c7bc 100644 --- a/libs/binder/RpcState.cpp +++ b/libs/binder/RpcState.cpp @@ -19,6 +19,7 @@ #include "RpcState.h" #include <android-base/hex.h> +#include <android-base/macros.h> #include <android-base/scopeguard.h> #include <binder/BpBinder.h> #include <binder/IPCThreadState.h> @@ -309,22 +310,18 @@ RpcState::CommandData::CommandData(size_t size) : mSize(size) { } status_t RpcState::rpcSend(const sp<RpcSession::RpcConnection>& connection, - const sp<RpcSession>& session, const char* what, const void* data, - size_t size, const std::function<status_t()>& altPoll) { - LOG_RPC_DETAIL("Sending %s on RpcTransport %p: %s", what, connection->rpcTransport.get(), - android::base::HexString(data, size).c_str()); - - if (size > std::numeric_limits<ssize_t>::max()) { - ALOGE("Cannot send %s at size %zu (too big)", what, size); - (void)session->shutdownAndWait(false); - return BAD_VALUE; + const sp<RpcSession>& session, const char* what, iovec* iovs, + size_t niovs, const std::function<status_t()>& altPoll) { + for (size_t i = 0; i < niovs; i++) { + LOG_RPC_DETAIL("Sending %s on RpcTransport %p: %s", what, connection->rpcTransport.get(), + android::base::HexString(iovs[i].iov_base, iovs[i].iov_len).c_str()); } if (status_t status = connection->rpcTransport->interruptableWriteFully(session->mShutdownTrigger.get(), - data, size, altPoll); + iovs, niovs, altPoll); status != OK) { - LOG_RPC_DETAIL("Failed to write %s (%zu bytes) on RpcTransport %p, error: %s", what, size, + LOG_RPC_DETAIL("Failed to write %s (%zu iovs) on RpcTransport %p, error: %s", what, niovs, connection->rpcTransport.get(), statusToString(status).c_str()); (void)session->shutdownAndWait(false); return status; @@ -334,34 +331,30 @@ status_t RpcState::rpcSend(const sp<RpcSession::RpcConnection>& connection, } status_t RpcState::rpcRec(const sp<RpcSession::RpcConnection>& connection, - const sp<RpcSession>& session, const char* what, void* data, - size_t size) { - if (size > std::numeric_limits<ssize_t>::max()) { - ALOGE("Cannot rec %s at size %zu (too big)", what, size); - (void)session->shutdownAndWait(false); - return BAD_VALUE; - } - + const sp<RpcSession>& session, const char* what, iovec* iovs, + size_t niovs) { if (status_t status = connection->rpcTransport->interruptableReadFully(session->mShutdownTrigger.get(), - data, size, {}); + iovs, niovs, {}); status != OK) { - LOG_RPC_DETAIL("Failed to read %s (%zu bytes) on RpcTransport %p, error: %s", what, size, + LOG_RPC_DETAIL("Failed to read %s (%zu iovs) on RpcTransport %p, error: %s", what, niovs, connection->rpcTransport.get(), statusToString(status).c_str()); (void)session->shutdownAndWait(false); return status; } - LOG_RPC_DETAIL("Received %s on RpcTransport %p: %s", what, connection->rpcTransport.get(), - android::base::HexString(data, size).c_str()); + for (size_t i = 0; i < niovs; i++) { + LOG_RPC_DETAIL("Received %s on RpcTransport %p: %s", what, connection->rpcTransport.get(), + android::base::HexString(iovs[i].iov_base, iovs[i].iov_len).c_str()); + } return OK; } status_t RpcState::readNewSessionResponse(const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session, uint32_t* version) { RpcNewSessionResponse response; - if (status_t status = - rpcRec(connection, session, "new session response", &response, sizeof(response)); + iovec iov{&response, sizeof(response)}; + if (status_t status = rpcRec(connection, session, "new session response", &iov, 1); status != OK) { return status; } @@ -374,14 +367,15 @@ status_t RpcState::sendConnectionInit(const sp<RpcSession::RpcConnection>& conne RpcOutgoingConnectionInit init{ .msg = RPC_CONNECTION_INIT_OKAY, }; - return rpcSend(connection, session, "connection init", &init, sizeof(init)); + iovec iov{&init, sizeof(init)}; + return rpcSend(connection, session, "connection init", &iov, 1); } status_t RpcState::readConnectionInit(const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session) { RpcOutgoingConnectionInit init; - if (status_t status = rpcRec(connection, session, "connection init", &init, sizeof(init)); - status != OK) + iovec iov{&init, sizeof(init)}; + if (status_t status = rpcRec(connection, session, "connection init", &iov, 1); status != OK) return status; static_assert(sizeof(init.msg) == sizeof(RPC_CONNECTION_INIT_OKAY)); @@ -514,17 +508,6 @@ status_t RpcState::transactAddress(const sp<RpcSession::RpcConnection>& connecti .flags = flags, .asyncNumber = asyncNumber, }; - CommandData transactionData(sizeof(RpcWireHeader) + sizeof(RpcWireTransaction) + - data.dataSize()); - if (!transactionData.valid()) { - return NO_MEMORY; - } - - memcpy(transactionData.data() + 0, &command, sizeof(RpcWireHeader)); - memcpy(transactionData.data() + sizeof(RpcWireHeader), &transaction, - sizeof(RpcWireTransaction)); - memcpy(transactionData.data() + sizeof(RpcWireHeader) + sizeof(RpcWireTransaction), data.data(), - data.dataSize()); constexpr size_t kWaitMaxUs = 1000000; constexpr size_t kWaitLogUs = 10000; @@ -550,8 +533,13 @@ status_t RpcState::transactAddress(const sp<RpcSession::RpcConnection>& connecti return drainCommands(connection, session, CommandType::CONTROL_ONLY); }; - if (status_t status = rpcSend(connection, session, "transaction", transactionData.data(), - transactionData.size(), drainRefs); + iovec iovs[]{ + {&command, sizeof(RpcWireHeader)}, + {&transaction, sizeof(RpcWireTransaction)}, + {const_cast<uint8_t*>(data.data()), data.dataSize()}, + }; + if (status_t status = + rpcSend(connection, session, "transaction", iovs, arraysize(iovs), drainRefs); status != OK) { // TODO(b/167966510): need to undo onBinderLeaving - we know the // refcount isn't successfully transferred. @@ -584,8 +572,8 @@ status_t RpcState::waitForReply(const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session, Parcel* reply) { RpcWireHeader command; while (true) { - if (status_t status = rpcRec(connection, session, "command header (for reply)", &command, - sizeof(command)); + iovec iov{&command, sizeof(command)}; + if (status_t status = rpcRec(connection, session, "command header (for reply)", &iov, 1); status != OK) return status; @@ -599,8 +587,8 @@ status_t RpcState::waitForReply(const sp<RpcSession::RpcConnection>& connection, CommandData data(command.bodySize); if (!data.valid()) return NO_MEMORY; - if (status_t status = rpcRec(connection, session, "reply body", data.data(), command.bodySize); - status != OK) + iovec iov{data.data(), command.bodySize}; + if (status_t status = rpcRec(connection, session, "reply body", &iov, 1); status != OK) return status; if (command.bodySize < sizeof(RpcWireReply)) { @@ -653,11 +641,8 @@ status_t RpcState::sendDecStrongToTarget(const sp<RpcSession::RpcConnection>& co .command = RPC_COMMAND_DEC_STRONG, .bodySize = sizeof(RpcDecStrong), }; - if (status_t status = rpcSend(connection, session, "dec ref header", &cmd, sizeof(cmd)); - status != OK) - return status; - - return rpcSend(connection, session, "dec ref body", &body, sizeof(body)); + iovec iovs[]{{&cmd, sizeof(cmd)}, {&body, sizeof(body)}}; + return rpcSend(connection, session, "dec ref", iovs, arraysize(iovs)); } status_t RpcState::getAndExecuteCommand(const sp<RpcSession::RpcConnection>& connection, @@ -665,8 +650,8 @@ status_t RpcState::getAndExecuteCommand(const sp<RpcSession::RpcConnection>& con LOG_RPC_DETAIL("getAndExecuteCommand on RpcTransport %p", connection->rpcTransport.get()); RpcWireHeader command; - if (status_t status = rpcRec(connection, session, "command header (for server)", &command, - sizeof(command)); + iovec iov{&command, sizeof(command)}; + if (status_t status = rpcRec(connection, session, "command header (for server)", &iov, 1); status != OK) return status; @@ -726,9 +711,8 @@ status_t RpcState::processTransact(const sp<RpcSession::RpcConnection>& connecti if (!transactionData.valid()) { return NO_MEMORY; } - if (status_t status = rpcRec(connection, session, "transaction body", transactionData.data(), - transactionData.size()); - status != OK) + iovec iov{transactionData.data(), transactionData.size()}; + if (status_t status = rpcRec(connection, session, "transaction body", &iov, 1); status != OK) return status; return processTransactInternal(connection, session, std::move(transactionData)); @@ -965,16 +949,12 @@ processTransactInternalTailCall: .status = replyStatus, }; - CommandData replyData(sizeof(RpcWireHeader) + sizeof(RpcWireReply) + reply.dataSize()); - if (!replyData.valid()) { - return NO_MEMORY; - } - memcpy(replyData.data() + 0, &cmdReply, sizeof(RpcWireHeader)); - memcpy(replyData.data() + sizeof(RpcWireHeader), &rpcReply, sizeof(RpcWireReply)); - memcpy(replyData.data() + sizeof(RpcWireHeader) + sizeof(RpcWireReply), reply.data(), - reply.dataSize()); - - return rpcSend(connection, session, "reply", replyData.data(), replyData.size()); + iovec iovs[]{ + {&cmdReply, sizeof(RpcWireHeader)}, + {&rpcReply, sizeof(RpcWireReply)}, + {const_cast<uint8_t*>(reply.data()), reply.dataSize()}, + }; + return rpcSend(connection, session, "reply", iovs, arraysize(iovs)); } status_t RpcState::processDecStrong(const sp<RpcSession::RpcConnection>& connection, @@ -985,9 +965,8 @@ status_t RpcState::processDecStrong(const sp<RpcSession::RpcConnection>& connect if (!commandData.valid()) { return NO_MEMORY; } - if (status_t status = - rpcRec(connection, session, "dec ref body", commandData.data(), commandData.size()); - status != OK) + iovec iov{commandData.data(), commandData.size()}; + if (status_t status = rpcRec(connection, session, "dec ref body", &iov, 1); status != OK) return status; if (command.bodySize != sizeof(RpcDecStrong)) { diff --git a/libs/binder/RpcState.h b/libs/binder/RpcState.h index dba0a43060..5cad394a2f 100644 --- a/libs/binder/RpcState.h +++ b/libs/binder/RpcState.h @@ -24,6 +24,8 @@ #include <optional> #include <queue> +#include <sys/uio.h> + namespace android { struct RpcWireHeader; @@ -177,12 +179,12 @@ private: }; [[nodiscard]] status_t rpcSend(const sp<RpcSession::RpcConnection>& connection, - const sp<RpcSession>& session, const char* what, - const void* data, size_t size, + const sp<RpcSession>& session, const char* what, iovec* iovs, + size_t niovs, const std::function<status_t()>& altPoll = nullptr); [[nodiscard]] status_t rpcRec(const sp<RpcSession::RpcConnection>& connection, - const sp<RpcSession>& session, const char* what, void* data, - size_t size); + const sp<RpcSession>& session, const char* what, iovec* iovs, + size_t niovs); [[nodiscard]] status_t waitForReply(const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session, Parcel* reply); diff --git a/libs/binder/RpcTransportRaw.cpp b/libs/binder/RpcTransportRaw.cpp index 7669518954..2182e1868e 100644 --- a/libs/binder/RpcTransportRaw.cpp +++ b/libs/binder/RpcTransportRaw.cpp @@ -43,12 +43,10 @@ public: return ret; } - template <typename Buffer, typename SendOrReceive> - status_t interruptableReadOrWrite(FdTrigger* fdTrigger, Buffer buffer, size_t size, + template <typename SendOrReceive> + status_t interruptableReadOrWrite(FdTrigger* fdTrigger, iovec* iovs, size_t niovs, SendOrReceive sendOrReceiveFun, const char* funName, int16_t event, const std::function<status_t()>& altPoll) { - const Buffer end = buffer + size; - MAYBE_WAIT_IN_FLAKE_MODE; // Since we didn't poll, we need to manually check to see if it was triggered. Otherwise, we @@ -57,26 +55,61 @@ public: return DEAD_OBJECT; } + // If iovs has one or more empty vectors at the end and + // we somehow advance past all the preceding vectors and + // pass some or all of the empty ones to sendmsg/recvmsg, + // the call will return processSize == 0. In that case + // we should be returning OK but instead return DEAD_OBJECT. + // To avoid this problem, we make sure here that the last + // vector at iovs[niovs - 1] has a non-zero length. + while (niovs > 0 && iovs[niovs - 1].iov_len == 0) { + niovs--; + } + if (niovs == 0) { + // The vectors are all empty, so we have nothing to send. + return OK; + } + bool havePolled = false; while (true) { - ssize_t processSize = TEMP_FAILURE_RETRY( - sendOrReceiveFun(mSocket.get(), buffer, end - buffer, MSG_NOSIGNAL)); + msghdr msg{ + .msg_iov = iovs, + .msg_iovlen = niovs, + }; + ssize_t processSize = + TEMP_FAILURE_RETRY(sendOrReceiveFun(mSocket.get(), &msg, MSG_NOSIGNAL)); if (processSize < 0) { int savedErrno = errno; // Still return the error on later passes, since it would expose // a problem with polling - if (havePolled || - (!havePolled && savedErrno != EAGAIN && savedErrno != EWOULDBLOCK)) { + if (havePolled || (savedErrno != EAGAIN && savedErrno != EWOULDBLOCK)) { LOG_RPC_DETAIL("RpcTransport %s(): %s", funName, strerror(savedErrno)); return -savedErrno; } } else if (processSize == 0) { return DEAD_OBJECT; } else { - buffer += processSize; - if (buffer == end) { + while (processSize > 0 && niovs > 0) { + auto& iov = iovs[0]; + if (static_cast<size_t>(processSize) < iov.iov_len) { + // Advance the base of the current iovec + iov.iov_base = reinterpret_cast<char*>(iov.iov_base) + processSize; + iov.iov_len -= processSize; + break; + } + + // The current iovec was fully written + processSize -= iov.iov_len; + iovs++; + niovs--; + } + if (niovs == 0) { + LOG_ALWAYS_FATAL_IF(processSize > 0, + "Reached the end of iovecs " + "with %zd bytes remaining", + processSize); return OK; } } @@ -95,16 +128,16 @@ public: } } - status_t interruptableWriteFully(FdTrigger* fdTrigger, const void* data, size_t size, + status_t interruptableWriteFully(FdTrigger* fdTrigger, iovec* iovs, size_t niovs, const std::function<status_t()>& altPoll) override { - return interruptableReadOrWrite(fdTrigger, reinterpret_cast<const uint8_t*>(data), size, - send, "send", POLLOUT, altPoll); + return interruptableReadOrWrite(fdTrigger, iovs, niovs, sendmsg, "sendmsg", POLLOUT, + altPoll); } - status_t interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size, + status_t interruptableReadFully(FdTrigger* fdTrigger, iovec* iovs, size_t niovs, const std::function<status_t()>& altPoll) override { - return interruptableReadOrWrite(fdTrigger, reinterpret_cast<uint8_t*>(data), size, recv, - "recv", POLLIN, altPoll); + return interruptableReadOrWrite(fdTrigger, iovs, niovs, recvmsg, "recvmsg", POLLIN, + altPoll); } private: diff --git a/libs/binder/RpcTransportTls.cpp b/libs/binder/RpcTransportTls.cpp index 7f810b17ba..c05ea1512f 100644 --- a/libs/binder/RpcTransportTls.cpp +++ b/libs/binder/RpcTransportTls.cpp @@ -275,9 +275,9 @@ public: RpcTransportTls(android::base::unique_fd socket, Ssl ssl) : mSocket(std::move(socket)), mSsl(std::move(ssl)) {} Result<size_t> peek(void* buf, size_t size) override; - status_t interruptableWriteFully(FdTrigger* fdTrigger, const void* data, size_t size, + status_t interruptableWriteFully(FdTrigger* fdTrigger, iovec* iovs, size_t niovs, const std::function<status_t()>& altPoll) override; - status_t interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size, + status_t interruptableReadFully(FdTrigger* fdTrigger, iovec* iovs, size_t niovs, const std::function<status_t()>& altPoll) override; private: @@ -303,68 +303,83 @@ Result<size_t> RpcTransportTls::peek(void* buf, size_t size) { return ret; } -status_t RpcTransportTls::interruptableWriteFully(FdTrigger* fdTrigger, const void* data, - size_t size, +status_t RpcTransportTls::interruptableWriteFully(FdTrigger* fdTrigger, iovec* iovs, size_t niovs, const std::function<status_t()>& altPoll) { - auto buffer = reinterpret_cast<const uint8_t*>(data); - const uint8_t* end = buffer + size; - MAYBE_WAIT_IN_FLAKE_MODE; // Before doing any I/O, check trigger once. This ensures the trigger is checked at least // once. The trigger is also checked via triggerablePoll() after every SSL_write(). if (fdTrigger->isTriggered()) return DEAD_OBJECT; - while (buffer < end) { - size_t todo = std::min<size_t>(end - buffer, std::numeric_limits<int>::max()); - auto [writeSize, errorQueue] = mSsl.call(SSL_write, buffer, todo); - if (writeSize > 0) { - buffer += writeSize; - errorQueue.clear(); + size_t size = 0; + for (size_t i = 0; i < niovs; i++) { + const iovec& iov = iovs[i]; + if (iov.iov_len == 0) { continue; } - // SSL_write() should never return 0 unless BIO_write were to return 0. - int sslError = mSsl.getError(writeSize); - // TODO(b/195788248): BIO should contain the FdTrigger, and send(2) / recv(2) should be - // triggerablePoll()-ed. Then additionalEvent is no longer necessary. - status_t pollStatus = errorQueue.pollForSslError(mSocket.get(), sslError, fdTrigger, - "SSL_write", POLLIN, altPoll); - if (pollStatus != OK) return pollStatus; - // Do not advance buffer. Try SSL_write() again. + size += iov.iov_len; + + auto buffer = reinterpret_cast<const uint8_t*>(iov.iov_base); + const uint8_t* end = buffer + iov.iov_len; + while (buffer < end) { + size_t todo = std::min<size_t>(end - buffer, std::numeric_limits<int>::max()); + auto [writeSize, errorQueue] = mSsl.call(SSL_write, buffer, todo); + if (writeSize > 0) { + buffer += writeSize; + errorQueue.clear(); + continue; + } + // SSL_write() should never return 0 unless BIO_write were to return 0. + int sslError = mSsl.getError(writeSize); + // TODO(b/195788248): BIO should contain the FdTrigger, and send(2) / recv(2) should be + // triggerablePoll()-ed. Then additionalEvent is no longer necessary. + status_t pollStatus = errorQueue.pollForSslError(mSocket.get(), sslError, fdTrigger, + "SSL_write", POLLIN, altPoll); + if (pollStatus != OK) return pollStatus; + // Do not advance buffer. Try SSL_write() again. + } } LOG_TLS_DETAIL("TLS: Sent %zu bytes!", size); return OK; } -status_t RpcTransportTls::interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size, +status_t RpcTransportTls::interruptableReadFully(FdTrigger* fdTrigger, iovec* iovs, size_t niovs, const std::function<status_t()>& altPoll) { - auto buffer = reinterpret_cast<uint8_t*>(data); - uint8_t* end = buffer + size; - MAYBE_WAIT_IN_FLAKE_MODE; // Before doing any I/O, check trigger once. This ensures the trigger is checked at least // once. The trigger is also checked via triggerablePoll() after every SSL_write(). if (fdTrigger->isTriggered()) return DEAD_OBJECT; - while (buffer < end) { - size_t todo = std::min<size_t>(end - buffer, std::numeric_limits<int>::max()); - auto [readSize, errorQueue] = mSsl.call(SSL_read, buffer, todo); - if (readSize > 0) { - buffer += readSize; - errorQueue.clear(); + size_t size = 0; + for (size_t i = 0; i < niovs; i++) { + const iovec& iov = iovs[i]; + if (iov.iov_len == 0) { continue; } - if (readSize == 0) { - // SSL_read() only returns 0 on EOF. - errorQueue.clear(); - return DEAD_OBJECT; + size += iov.iov_len; + + auto buffer = reinterpret_cast<uint8_t*>(iov.iov_base); + const uint8_t* end = buffer + iov.iov_len; + while (buffer < end) { + size_t todo = std::min<size_t>(end - buffer, std::numeric_limits<int>::max()); + auto [readSize, errorQueue] = mSsl.call(SSL_read, buffer, todo); + if (readSize > 0) { + buffer += readSize; + errorQueue.clear(); + continue; + } + if (readSize == 0) { + // SSL_read() only returns 0 on EOF. + errorQueue.clear(); + return DEAD_OBJECT; + } + int sslError = mSsl.getError(readSize); + status_t pollStatus = errorQueue.pollForSslError(mSocket.get(), sslError, fdTrigger, + "SSL_read", 0, altPoll); + if (pollStatus != OK) return pollStatus; + // Do not advance buffer. Try SSL_read() again. } - int sslError = mSsl.getError(readSize); - status_t pollStatus = errorQueue.pollForSslError(mSocket.get(), sslError, fdTrigger, - "SSL_read", 0, altPoll); - if (pollStatus != OK) return pollStatus; - // Do not advance buffer. Try SSL_read() again. } LOG_TLS_DETAIL("TLS: Received %zu bytes!", size); return OK; diff --git a/libs/binder/include/binder/RpcTransport.h b/libs/binder/include/binder/RpcTransport.h index db8b5e920e..348bfebf15 100644 --- a/libs/binder/include/binder/RpcTransport.h +++ b/libs/binder/include/binder/RpcTransport.h @@ -28,6 +28,8 @@ #include <binder/RpcCertificateFormat.h> +#include <sys/uio.h> + namespace android { class FdTrigger; @@ -44,6 +46,9 @@ public: /** * Read (or write), but allow to be interrupted by a trigger. * + * iovs - array of iovecs to perform the operation on. The elements + * of the array may be modified by this method. + * * altPoll - function to be called instead of polling, when needing to wait * to read/write data. If this returns an error, that error is returned from * this function. @@ -53,10 +58,10 @@ public: * error - interrupted (failure or trigger) */ [[nodiscard]] virtual status_t interruptableWriteFully( - FdTrigger *fdTrigger, const void *buf, size_t size, + FdTrigger *fdTrigger, iovec *iovs, size_t niovs, const std::function<status_t()> &altPoll) = 0; [[nodiscard]] virtual status_t interruptableReadFully( - FdTrigger *fdTrigger, void *buf, size_t size, + FdTrigger *fdTrigger, iovec *iovs, size_t niovs, const std::function<status_t()> &altPoll) = 0; protected: diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp index 5a96b7835c..ca68b99e35 100644 --- a/libs/binder/tests/binderRpcTest.cpp +++ b/libs/binder/tests/binderRpcTest.cpp @@ -1674,8 +1674,8 @@ public: static AssertionResult defaultPostConnect(RpcTransport* serverTransport, FdTrigger* fdTrigger) { std::string message(kMessage); - auto status = serverTransport->interruptableWriteFully(fdTrigger, message.data(), - message.size(), {}); + iovec messageIov{message.data(), message.size()}; + auto status = serverTransport->interruptableWriteFully(fdTrigger, &messageIov, 1, {}); if (status != OK) return AssertionFailure() << statusToString(status); return AssertionSuccess(); } @@ -1706,9 +1706,9 @@ public: AssertionResult readMessage(const std::string& expectedMessage = kMessage) { LOG_ALWAYS_FATAL_IF(mClientTransport == nullptr, "setUpTransport not called or failed"); std::string readMessage(expectedMessage.size(), '\0'); - status_t readStatus = - mClientTransport->interruptableReadFully(mFdTrigger.get(), readMessage.data(), - readMessage.size(), {}); + iovec readMessageIov{readMessage.data(), readMessage.size()}; + status_t readStatus = mClientTransport->interruptableReadFully(mFdTrigger.get(), + &readMessageIov, 1, {}); if (readStatus != OK) { return AssertionFailure() << statusToString(readStatus); } @@ -1902,8 +1902,8 @@ TEST_P(RpcTransportTest, Trigger) { bool shouldContinueWriting = false; auto serverPostConnect = [&](RpcTransport* serverTransport, FdTrigger* fdTrigger) { std::string message(RpcTransportTestUtils::kMessage); - auto status = serverTransport->interruptableWriteFully(fdTrigger, message.data(), - message.size(), {}); + iovec messageIov{message.data(), message.size()}; + auto status = serverTransport->interruptableWriteFully(fdTrigger, &messageIov, 1, {}); if (status != OK) return AssertionFailure() << statusToString(status); { @@ -1913,7 +1913,8 @@ TEST_P(RpcTransportTest, Trigger) { } } - status = serverTransport->interruptableWriteFully(fdTrigger, msg2.data(), msg2.size(), {}); + iovec msg2Iov{msg2.data(), msg2.size()}; + status = serverTransport->interruptableWriteFully(fdTrigger, &msg2Iov, 1, {}); if (status != DEAD_OBJECT) return AssertionFailure() << "When FdTrigger is shut down, interruptableWriteFully " "should return DEAD_OBJECT, but it is " |