diff options
| author | 2022-01-05 01:32:58 +0000 | |
|---|---|---|
| committer | 2022-01-05 01:32:58 +0000 | |
| commit | 2c25d94025fde6cf7ad80f7a799c7ff7f81b46b6 (patch) | |
| tree | eacdd4bbef7d0f7df0fdd3dcf2c7df85c2e3c57b | |
| parent | 2692efa2371af720c7d92fddcdb39e8031eea960 (diff) | |
| parent | e627a40c7e4bed19850ee556cb5c4c7300e5b099 (diff) | |
Merge "binder: Eliminate a data copy in RPC transport operations" am: 7dc506f27f am: 8e5a9c89be am: e627a40c7e
Original change: https://android-review.googlesource.com/c/platform/frameworks/native/+/1917958
Change-Id: Idd582e849674945da732d57b004989fade91cddb
| -rw-r--r-- | libs/binder/RpcServer.cpp | 13 | ||||
| -rw-r--r-- | libs/binder/RpcSession.cpp | 8 | ||||
| -rw-r--r-- | libs/binder/RpcState.cpp | 115 | ||||
| -rw-r--r-- | libs/binder/RpcState.h | 10 | ||||
| -rw-r--r-- | libs/binder/RpcTransportRaw.cpp | 65 | ||||
| -rw-r--r-- | libs/binder/RpcTransportTls.cpp | 95 | ||||
| -rw-r--r-- | libs/binder/include/binder/RpcTransport.h | 9 | ||||
| -rw-r--r-- | libs/binder/tests/binderRpcTest.cpp | 17 | 
8 files changed, 185 insertions, 147 deletions
diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp index 93ed50e986..ace5cd5052 100644 --- a/libs/binder/RpcServer.cpp +++ b/libs/binder/RpcServer.cpp @@ -287,8 +287,8 @@ void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clie      RpcConnectionHeader header;      if (status == OK) { -        status = client->interruptableReadFully(server->mShutdownTrigger.get(), &header, -                                                sizeof(header), {}); +        iovec iov{&header, sizeof(header)}; +        status = client->interruptableReadFully(server->mShutdownTrigger.get(), &iov, 1, {});          if (status != OK) {              ALOGE("Failed to read ID for client connecting to RPC server: %s",                    statusToString(status).c_str()); @@ -301,8 +301,9 @@ void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clie          if (header.sessionIdSize > 0) {              if (header.sessionIdSize == kSessionIdBytes) {                  sessionId.resize(header.sessionIdSize); -                status = client->interruptableReadFully(server->mShutdownTrigger.get(), -                                                        sessionId.data(), sessionId.size(), {}); +                iovec iov{sessionId.data(), sessionId.size()}; +                status = +                        client->interruptableReadFully(server->mShutdownTrigger.get(), &iov, 1, {});                  if (status != OK) {                      ALOGE("Failed to read session ID for client connecting to RPC server: %s",                            statusToString(status).c_str()); @@ -331,8 +332,8 @@ void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clie                      .version = protocolVersion,              }; -            status = client->interruptableWriteFully(server->mShutdownTrigger.get(), &response, -                                                     sizeof(response), {}); +            iovec iov{&response, sizeof(response)}; +            status = client->interruptableWriteFully(server->mShutdownTrigger.get(), &iov, 1, {});              if (status != OK) {                  ALOGE("Failed to send new session response: %s", statusToString(status).c_str());                  // still need to cleanup before we can return diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp index a5a2bb1017..b84395e7cb 100644 --- a/libs/binder/RpcSession.cpp +++ b/libs/binder/RpcSession.cpp @@ -615,8 +615,9 @@ status_t RpcSession::initAndAddConnection(unique_fd fd, const std::vector<uint8_          header.options |= RPC_CONNECTION_OPTION_INCOMING;      } +    iovec headerIov{&header, sizeof(header)};      auto sendHeaderStatus = -            server->interruptableWriteFully(mShutdownTrigger.get(), &header, sizeof(header), {}); +            server->interruptableWriteFully(mShutdownTrigger.get(), &headerIov, 1, {});      if (sendHeaderStatus != OK) {          ALOGE("Could not write connection header to socket: %s",                statusToString(sendHeaderStatus).c_str()); @@ -624,9 +625,10 @@ status_t RpcSession::initAndAddConnection(unique_fd fd, const std::vector<uint8_      }      if (sessionId.size() > 0) { +        iovec sessionIov{const_cast<void*>(static_cast<const void*>(sessionId.data())), +                         sessionId.size()};          auto sendSessionIdStatus = -                server->interruptableWriteFully(mShutdownTrigger.get(), sessionId.data(), -                                                sessionId.size(), {}); +                server->interruptableWriteFully(mShutdownTrigger.get(), &sessionIov, 1, {});          if (sendSessionIdStatus != OK) {              ALOGE("Could not write session ID ('%s') to socket: %s",                    base::HexString(sessionId.data(), sessionId.size()).c_str(), diff --git a/libs/binder/RpcState.cpp b/libs/binder/RpcState.cpp index 09b3d68626..6286c9c7bc 100644 --- a/libs/binder/RpcState.cpp +++ b/libs/binder/RpcState.cpp @@ -19,6 +19,7 @@  #include "RpcState.h"  #include <android-base/hex.h> +#include <android-base/macros.h>  #include <android-base/scopeguard.h>  #include <binder/BpBinder.h>  #include <binder/IPCThreadState.h> @@ -309,22 +310,18 @@ RpcState::CommandData::CommandData(size_t size) : mSize(size) {  }  status_t RpcState::rpcSend(const sp<RpcSession::RpcConnection>& connection, -                           const sp<RpcSession>& session, const char* what, const void* data, -                           size_t size, const std::function<status_t()>& altPoll) { -    LOG_RPC_DETAIL("Sending %s on RpcTransport %p: %s", what, connection->rpcTransport.get(), -                   android::base::HexString(data, size).c_str()); - -    if (size > std::numeric_limits<ssize_t>::max()) { -        ALOGE("Cannot send %s at size %zu (too big)", what, size); -        (void)session->shutdownAndWait(false); -        return BAD_VALUE; +                           const sp<RpcSession>& session, const char* what, iovec* iovs, +                           size_t niovs, const std::function<status_t()>& altPoll) { +    for (size_t i = 0; i < niovs; i++) { +        LOG_RPC_DETAIL("Sending %s on RpcTransport %p: %s", what, connection->rpcTransport.get(), +                       android::base::HexString(iovs[i].iov_base, iovs[i].iov_len).c_str());      }      if (status_t status =                  connection->rpcTransport->interruptableWriteFully(session->mShutdownTrigger.get(), -                                                                  data, size, altPoll); +                                                                  iovs, niovs, altPoll);          status != OK) { -        LOG_RPC_DETAIL("Failed to write %s (%zu bytes) on RpcTransport %p, error: %s", what, size, +        LOG_RPC_DETAIL("Failed to write %s (%zu iovs) on RpcTransport %p, error: %s", what, niovs,                         connection->rpcTransport.get(), statusToString(status).c_str());          (void)session->shutdownAndWait(false);          return status; @@ -334,34 +331,30 @@ status_t RpcState::rpcSend(const sp<RpcSession::RpcConnection>& connection,  }  status_t RpcState::rpcRec(const sp<RpcSession::RpcConnection>& connection, -                          const sp<RpcSession>& session, const char* what, void* data, -                          size_t size) { -    if (size > std::numeric_limits<ssize_t>::max()) { -        ALOGE("Cannot rec %s at size %zu (too big)", what, size); -        (void)session->shutdownAndWait(false); -        return BAD_VALUE; -    } - +                          const sp<RpcSession>& session, const char* what, iovec* iovs, +                          size_t niovs) {      if (status_t status =                  connection->rpcTransport->interruptableReadFully(session->mShutdownTrigger.get(), -                                                                 data, size, {}); +                                                                 iovs, niovs, {});          status != OK) { -        LOG_RPC_DETAIL("Failed to read %s (%zu bytes) on RpcTransport %p, error: %s", what, size, +        LOG_RPC_DETAIL("Failed to read %s (%zu iovs) on RpcTransport %p, error: %s", what, niovs,                         connection->rpcTransport.get(), statusToString(status).c_str());          (void)session->shutdownAndWait(false);          return status;      } -    LOG_RPC_DETAIL("Received %s on RpcTransport %p: %s", what, connection->rpcTransport.get(), -                   android::base::HexString(data, size).c_str()); +    for (size_t i = 0; i < niovs; i++) { +        LOG_RPC_DETAIL("Received %s on RpcTransport %p: %s", what, connection->rpcTransport.get(), +                       android::base::HexString(iovs[i].iov_base, iovs[i].iov_len).c_str()); +    }      return OK;  }  status_t RpcState::readNewSessionResponse(const sp<RpcSession::RpcConnection>& connection,                                            const sp<RpcSession>& session, uint32_t* version) {      RpcNewSessionResponse response; -    if (status_t status = -                rpcRec(connection, session, "new session response", &response, sizeof(response)); +    iovec iov{&response, sizeof(response)}; +    if (status_t status = rpcRec(connection, session, "new session response", &iov, 1);          status != OK) {          return status;      } @@ -374,14 +367,15 @@ status_t RpcState::sendConnectionInit(const sp<RpcSession::RpcConnection>& conne      RpcOutgoingConnectionInit init{              .msg = RPC_CONNECTION_INIT_OKAY,      }; -    return rpcSend(connection, session, "connection init", &init, sizeof(init)); +    iovec iov{&init, sizeof(init)}; +    return rpcSend(connection, session, "connection init", &iov, 1);  }  status_t RpcState::readConnectionInit(const sp<RpcSession::RpcConnection>& connection,                                        const sp<RpcSession>& session) {      RpcOutgoingConnectionInit init; -    if (status_t status = rpcRec(connection, session, "connection init", &init, sizeof(init)); -        status != OK) +    iovec iov{&init, sizeof(init)}; +    if (status_t status = rpcRec(connection, session, "connection init", &iov, 1); status != OK)          return status;      static_assert(sizeof(init.msg) == sizeof(RPC_CONNECTION_INIT_OKAY)); @@ -514,17 +508,6 @@ status_t RpcState::transactAddress(const sp<RpcSession::RpcConnection>& connecti              .flags = flags,              .asyncNumber = asyncNumber,      }; -    CommandData transactionData(sizeof(RpcWireHeader) + sizeof(RpcWireTransaction) + -                                data.dataSize()); -    if (!transactionData.valid()) { -        return NO_MEMORY; -    } - -    memcpy(transactionData.data() + 0, &command, sizeof(RpcWireHeader)); -    memcpy(transactionData.data() + sizeof(RpcWireHeader), &transaction, -           sizeof(RpcWireTransaction)); -    memcpy(transactionData.data() + sizeof(RpcWireHeader) + sizeof(RpcWireTransaction), data.data(), -           data.dataSize());      constexpr size_t kWaitMaxUs = 1000000;      constexpr size_t kWaitLogUs = 10000; @@ -550,8 +533,13 @@ status_t RpcState::transactAddress(const sp<RpcSession::RpcConnection>& connecti          return drainCommands(connection, session, CommandType::CONTROL_ONLY);      }; -    if (status_t status = rpcSend(connection, session, "transaction", transactionData.data(), -                                  transactionData.size(), drainRefs); +    iovec iovs[]{ +            {&command, sizeof(RpcWireHeader)}, +            {&transaction, sizeof(RpcWireTransaction)}, +            {const_cast<uint8_t*>(data.data()), data.dataSize()}, +    }; +    if (status_t status = +                rpcSend(connection, session, "transaction", iovs, arraysize(iovs), drainRefs);          status != OK) {          // TODO(b/167966510): need to undo onBinderLeaving - we know the          // refcount isn't successfully transferred. @@ -584,8 +572,8 @@ status_t RpcState::waitForReply(const sp<RpcSession::RpcConnection>& connection,                                  const sp<RpcSession>& session, Parcel* reply) {      RpcWireHeader command;      while (true) { -        if (status_t status = rpcRec(connection, session, "command header (for reply)", &command, -                                     sizeof(command)); +        iovec iov{&command, sizeof(command)}; +        if (status_t status = rpcRec(connection, session, "command header (for reply)", &iov, 1);              status != OK)              return status; @@ -599,8 +587,8 @@ status_t RpcState::waitForReply(const sp<RpcSession::RpcConnection>& connection,      CommandData data(command.bodySize);      if (!data.valid()) return NO_MEMORY; -    if (status_t status = rpcRec(connection, session, "reply body", data.data(), command.bodySize); -        status != OK) +    iovec iov{data.data(), command.bodySize}; +    if (status_t status = rpcRec(connection, session, "reply body", &iov, 1); status != OK)          return status;      if (command.bodySize < sizeof(RpcWireReply)) { @@ -653,11 +641,8 @@ status_t RpcState::sendDecStrongToTarget(const sp<RpcSession::RpcConnection>& co              .command = RPC_COMMAND_DEC_STRONG,              .bodySize = sizeof(RpcDecStrong),      }; -    if (status_t status = rpcSend(connection, session, "dec ref header", &cmd, sizeof(cmd)); -        status != OK) -        return status; - -    return rpcSend(connection, session, "dec ref body", &body, sizeof(body)); +    iovec iovs[]{{&cmd, sizeof(cmd)}, {&body, sizeof(body)}}; +    return rpcSend(connection, session, "dec ref", iovs, arraysize(iovs));  }  status_t RpcState::getAndExecuteCommand(const sp<RpcSession::RpcConnection>& connection, @@ -665,8 +650,8 @@ status_t RpcState::getAndExecuteCommand(const sp<RpcSession::RpcConnection>& con      LOG_RPC_DETAIL("getAndExecuteCommand on RpcTransport %p", connection->rpcTransport.get());      RpcWireHeader command; -    if (status_t status = rpcRec(connection, session, "command header (for server)", &command, -                                 sizeof(command)); +    iovec iov{&command, sizeof(command)}; +    if (status_t status = rpcRec(connection, session, "command header (for server)", &iov, 1);          status != OK)          return status; @@ -726,9 +711,8 @@ status_t RpcState::processTransact(const sp<RpcSession::RpcConnection>& connecti      if (!transactionData.valid()) {          return NO_MEMORY;      } -    if (status_t status = rpcRec(connection, session, "transaction body", transactionData.data(), -                                 transactionData.size()); -        status != OK) +    iovec iov{transactionData.data(), transactionData.size()}; +    if (status_t status = rpcRec(connection, session, "transaction body", &iov, 1); status != OK)          return status;      return processTransactInternal(connection, session, std::move(transactionData)); @@ -965,16 +949,12 @@ processTransactInternalTailCall:              .status = replyStatus,      }; -    CommandData replyData(sizeof(RpcWireHeader) + sizeof(RpcWireReply) + reply.dataSize()); -    if (!replyData.valid()) { -        return NO_MEMORY; -    } -    memcpy(replyData.data() + 0, &cmdReply, sizeof(RpcWireHeader)); -    memcpy(replyData.data() + sizeof(RpcWireHeader), &rpcReply, sizeof(RpcWireReply)); -    memcpy(replyData.data() + sizeof(RpcWireHeader) + sizeof(RpcWireReply), reply.data(), -           reply.dataSize()); - -    return rpcSend(connection, session, "reply", replyData.data(), replyData.size()); +    iovec iovs[]{ +            {&cmdReply, sizeof(RpcWireHeader)}, +            {&rpcReply, sizeof(RpcWireReply)}, +            {const_cast<uint8_t*>(reply.data()), reply.dataSize()}, +    }; +    return rpcSend(connection, session, "reply", iovs, arraysize(iovs));  }  status_t RpcState::processDecStrong(const sp<RpcSession::RpcConnection>& connection, @@ -985,9 +965,8 @@ status_t RpcState::processDecStrong(const sp<RpcSession::RpcConnection>& connect      if (!commandData.valid()) {          return NO_MEMORY;      } -    if (status_t status = -                rpcRec(connection, session, "dec ref body", commandData.data(), commandData.size()); -        status != OK) +    iovec iov{commandData.data(), commandData.size()}; +    if (status_t status = rpcRec(connection, session, "dec ref body", &iov, 1); status != OK)          return status;      if (command.bodySize != sizeof(RpcDecStrong)) { diff --git a/libs/binder/RpcState.h b/libs/binder/RpcState.h index dba0a43060..5cad394a2f 100644 --- a/libs/binder/RpcState.h +++ b/libs/binder/RpcState.h @@ -24,6 +24,8 @@  #include <optional>  #include <queue> +#include <sys/uio.h> +  namespace android {  struct RpcWireHeader; @@ -177,12 +179,12 @@ private:      };      [[nodiscard]] status_t rpcSend(const sp<RpcSession::RpcConnection>& connection, -                                   const sp<RpcSession>& session, const char* what, -                                   const void* data, size_t size, +                                   const sp<RpcSession>& session, const char* what, iovec* iovs, +                                   size_t niovs,                                     const std::function<status_t()>& altPoll = nullptr);      [[nodiscard]] status_t rpcRec(const sp<RpcSession::RpcConnection>& connection, -                                  const sp<RpcSession>& session, const char* what, void* data, -                                  size_t size); +                                  const sp<RpcSession>& session, const char* what, iovec* iovs, +                                  size_t niovs);      [[nodiscard]] status_t waitForReply(const sp<RpcSession::RpcConnection>& connection,                                          const sp<RpcSession>& session, Parcel* reply); diff --git a/libs/binder/RpcTransportRaw.cpp b/libs/binder/RpcTransportRaw.cpp index 7669518954..2182e1868e 100644 --- a/libs/binder/RpcTransportRaw.cpp +++ b/libs/binder/RpcTransportRaw.cpp @@ -43,12 +43,10 @@ public:          return ret;      } -    template <typename Buffer, typename SendOrReceive> -    status_t interruptableReadOrWrite(FdTrigger* fdTrigger, Buffer buffer, size_t size, +    template <typename SendOrReceive> +    status_t interruptableReadOrWrite(FdTrigger* fdTrigger, iovec* iovs, size_t niovs,                                        SendOrReceive sendOrReceiveFun, const char* funName,                                        int16_t event, const std::function<status_t()>& altPoll) { -        const Buffer end = buffer + size; -          MAYBE_WAIT_IN_FLAKE_MODE;          // Since we didn't poll, we need to manually check to see if it was triggered. Otherwise, we @@ -57,26 +55,61 @@ public:              return DEAD_OBJECT;          } +        // If iovs has one or more empty vectors at the end and +        // we somehow advance past all the preceding vectors and +        // pass some or all of the empty ones to sendmsg/recvmsg, +        // the call will return processSize == 0. In that case +        // we should be returning OK but instead return DEAD_OBJECT. +        // To avoid this problem, we make sure here that the last +        // vector at iovs[niovs - 1] has a non-zero length. +        while (niovs > 0 && iovs[niovs - 1].iov_len == 0) { +            niovs--; +        } +        if (niovs == 0) { +            // The vectors are all empty, so we have nothing to send. +            return OK; +        } +          bool havePolled = false;          while (true) { -            ssize_t processSize = TEMP_FAILURE_RETRY( -                    sendOrReceiveFun(mSocket.get(), buffer, end - buffer, MSG_NOSIGNAL)); +            msghdr msg{ +                    .msg_iov = iovs, +                    .msg_iovlen = niovs, +            }; +            ssize_t processSize = +                    TEMP_FAILURE_RETRY(sendOrReceiveFun(mSocket.get(), &msg, MSG_NOSIGNAL));              if (processSize < 0) {                  int savedErrno = errno;                  // Still return the error on later passes, since it would expose                  // a problem with polling -                if (havePolled || -                    (!havePolled && savedErrno != EAGAIN && savedErrno != EWOULDBLOCK)) { +                if (havePolled || (savedErrno != EAGAIN && savedErrno != EWOULDBLOCK)) {                      LOG_RPC_DETAIL("RpcTransport %s(): %s", funName, strerror(savedErrno));                      return -savedErrno;                  }              } else if (processSize == 0) {                  return DEAD_OBJECT;              } else { -                buffer += processSize; -                if (buffer == end) { +                while (processSize > 0 && niovs > 0) { +                    auto& iov = iovs[0]; +                    if (static_cast<size_t>(processSize) < iov.iov_len) { +                        // Advance the base of the current iovec +                        iov.iov_base = reinterpret_cast<char*>(iov.iov_base) + processSize; +                        iov.iov_len -= processSize; +                        break; +                    } + +                    // The current iovec was fully written +                    processSize -= iov.iov_len; +                    iovs++; +                    niovs--; +                } +                if (niovs == 0) { +                    LOG_ALWAYS_FATAL_IF(processSize > 0, +                                        "Reached the end of iovecs " +                                        "with %zd bytes remaining", +                                        processSize);                      return OK;                  }              } @@ -95,16 +128,16 @@ public:          }      } -    status_t interruptableWriteFully(FdTrigger* fdTrigger, const void* data, size_t size, +    status_t interruptableWriteFully(FdTrigger* fdTrigger, iovec* iovs, size_t niovs,                                       const std::function<status_t()>& altPoll) override { -        return interruptableReadOrWrite(fdTrigger, reinterpret_cast<const uint8_t*>(data), size, -                                        send, "send", POLLOUT, altPoll); +        return interruptableReadOrWrite(fdTrigger, iovs, niovs, sendmsg, "sendmsg", POLLOUT, +                                        altPoll);      } -    status_t interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size, +    status_t interruptableReadFully(FdTrigger* fdTrigger, iovec* iovs, size_t niovs,                                      const std::function<status_t()>& altPoll) override { -        return interruptableReadOrWrite(fdTrigger, reinterpret_cast<uint8_t*>(data), size, recv, -                                        "recv", POLLIN, altPoll); +        return interruptableReadOrWrite(fdTrigger, iovs, niovs, recvmsg, "recvmsg", POLLIN, +                                        altPoll);      }  private: diff --git a/libs/binder/RpcTransportTls.cpp b/libs/binder/RpcTransportTls.cpp index 7f810b17ba..c05ea1512f 100644 --- a/libs/binder/RpcTransportTls.cpp +++ b/libs/binder/RpcTransportTls.cpp @@ -275,9 +275,9 @@ public:      RpcTransportTls(android::base::unique_fd socket, Ssl ssl)            : mSocket(std::move(socket)), mSsl(std::move(ssl)) {}      Result<size_t> peek(void* buf, size_t size) override; -    status_t interruptableWriteFully(FdTrigger* fdTrigger, const void* data, size_t size, +    status_t interruptableWriteFully(FdTrigger* fdTrigger, iovec* iovs, size_t niovs,                                       const std::function<status_t()>& altPoll) override; -    status_t interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size, +    status_t interruptableReadFully(FdTrigger* fdTrigger, iovec* iovs, size_t niovs,                                      const std::function<status_t()>& altPoll) override;  private: @@ -303,68 +303,83 @@ Result<size_t> RpcTransportTls::peek(void* buf, size_t size) {      return ret;  } -status_t RpcTransportTls::interruptableWriteFully(FdTrigger* fdTrigger, const void* data, -                                                  size_t size, +status_t RpcTransportTls::interruptableWriteFully(FdTrigger* fdTrigger, iovec* iovs, size_t niovs,                                                    const std::function<status_t()>& altPoll) { -    auto buffer = reinterpret_cast<const uint8_t*>(data); -    const uint8_t* end = buffer + size; -      MAYBE_WAIT_IN_FLAKE_MODE;      // Before doing any I/O, check trigger once. This ensures the trigger is checked at least      // once. The trigger is also checked via triggerablePoll() after every SSL_write().      if (fdTrigger->isTriggered()) return DEAD_OBJECT; -    while (buffer < end) { -        size_t todo = std::min<size_t>(end - buffer, std::numeric_limits<int>::max()); -        auto [writeSize, errorQueue] = mSsl.call(SSL_write, buffer, todo); -        if (writeSize > 0) { -            buffer += writeSize; -            errorQueue.clear(); +    size_t size = 0; +    for (size_t i = 0; i < niovs; i++) { +        const iovec& iov = iovs[i]; +        if (iov.iov_len == 0) {              continue;          } -        // SSL_write() should never return 0 unless BIO_write were to return 0. -        int sslError = mSsl.getError(writeSize); -        // TODO(b/195788248): BIO should contain the FdTrigger, and send(2) / recv(2) should be -        //   triggerablePoll()-ed. Then additionalEvent is no longer necessary. -        status_t pollStatus = errorQueue.pollForSslError(mSocket.get(), sslError, fdTrigger, -                                                         "SSL_write", POLLIN, altPoll); -        if (pollStatus != OK) return pollStatus; -        // Do not advance buffer. Try SSL_write() again. +        size += iov.iov_len; + +        auto buffer = reinterpret_cast<const uint8_t*>(iov.iov_base); +        const uint8_t* end = buffer + iov.iov_len; +        while (buffer < end) { +            size_t todo = std::min<size_t>(end - buffer, std::numeric_limits<int>::max()); +            auto [writeSize, errorQueue] = mSsl.call(SSL_write, buffer, todo); +            if (writeSize > 0) { +                buffer += writeSize; +                errorQueue.clear(); +                continue; +            } +            // SSL_write() should never return 0 unless BIO_write were to return 0. +            int sslError = mSsl.getError(writeSize); +            // TODO(b/195788248): BIO should contain the FdTrigger, and send(2) / recv(2) should be +            //   triggerablePoll()-ed. Then additionalEvent is no longer necessary. +            status_t pollStatus = errorQueue.pollForSslError(mSocket.get(), sslError, fdTrigger, +                                                             "SSL_write", POLLIN, altPoll); +            if (pollStatus != OK) return pollStatus; +            // Do not advance buffer. Try SSL_write() again. +        }      }      LOG_TLS_DETAIL("TLS: Sent %zu bytes!", size);      return OK;  } -status_t RpcTransportTls::interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size, +status_t RpcTransportTls::interruptableReadFully(FdTrigger* fdTrigger, iovec* iovs, size_t niovs,                                                   const std::function<status_t()>& altPoll) { -    auto buffer = reinterpret_cast<uint8_t*>(data); -    uint8_t* end = buffer + size; -      MAYBE_WAIT_IN_FLAKE_MODE;      // Before doing any I/O, check trigger once. This ensures the trigger is checked at least      // once. The trigger is also checked via triggerablePoll() after every SSL_write().      if (fdTrigger->isTriggered()) return DEAD_OBJECT; -    while (buffer < end) { -        size_t todo = std::min<size_t>(end - buffer, std::numeric_limits<int>::max()); -        auto [readSize, errorQueue] = mSsl.call(SSL_read, buffer, todo); -        if (readSize > 0) { -            buffer += readSize; -            errorQueue.clear(); +    size_t size = 0; +    for (size_t i = 0; i < niovs; i++) { +        const iovec& iov = iovs[i]; +        if (iov.iov_len == 0) {              continue;          } -        if (readSize == 0) { -            // SSL_read() only returns 0 on EOF. -            errorQueue.clear(); -            return DEAD_OBJECT; +        size += iov.iov_len; + +        auto buffer = reinterpret_cast<uint8_t*>(iov.iov_base); +        const uint8_t* end = buffer + iov.iov_len; +        while (buffer < end) { +            size_t todo = std::min<size_t>(end - buffer, std::numeric_limits<int>::max()); +            auto [readSize, errorQueue] = mSsl.call(SSL_read, buffer, todo); +            if (readSize > 0) { +                buffer += readSize; +                errorQueue.clear(); +                continue; +            } +            if (readSize == 0) { +                // SSL_read() only returns 0 on EOF. +                errorQueue.clear(); +                return DEAD_OBJECT; +            } +            int sslError = mSsl.getError(readSize); +            status_t pollStatus = errorQueue.pollForSslError(mSocket.get(), sslError, fdTrigger, +                                                             "SSL_read", 0, altPoll); +            if (pollStatus != OK) return pollStatus; +            // Do not advance buffer. Try SSL_read() again.          } -        int sslError = mSsl.getError(readSize); -        status_t pollStatus = errorQueue.pollForSslError(mSocket.get(), sslError, fdTrigger, -                                                         "SSL_read", 0, altPoll); -        if (pollStatus != OK) return pollStatus; -        // Do not advance buffer. Try SSL_read() again.      }      LOG_TLS_DETAIL("TLS: Received %zu bytes!", size);      return OK; diff --git a/libs/binder/include/binder/RpcTransport.h b/libs/binder/include/binder/RpcTransport.h index db8b5e920e..348bfebf15 100644 --- a/libs/binder/include/binder/RpcTransport.h +++ b/libs/binder/include/binder/RpcTransport.h @@ -28,6 +28,8 @@  #include <binder/RpcCertificateFormat.h> +#include <sys/uio.h> +  namespace android {  class FdTrigger; @@ -44,6 +46,9 @@ public:      /**       * Read (or write), but allow to be interrupted by a trigger.       * +     * iovs - array of iovecs to perform the operation on. The elements +     * of the array may be modified by this method. +     *       * altPoll - function to be called instead of polling, when needing to wait       * to read/write data. If this returns an error, that error is returned from       * this function. @@ -53,10 +58,10 @@ public:       *   error - interrupted (failure or trigger)       */      [[nodiscard]] virtual status_t interruptableWriteFully( -            FdTrigger *fdTrigger, const void *buf, size_t size, +            FdTrigger *fdTrigger, iovec *iovs, size_t niovs,              const std::function<status_t()> &altPoll) = 0;      [[nodiscard]] virtual status_t interruptableReadFully( -            FdTrigger *fdTrigger, void *buf, size_t size, +            FdTrigger *fdTrigger, iovec *iovs, size_t niovs,              const std::function<status_t()> &altPoll) = 0;  protected: diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp index 5a96b7835c..ca68b99e35 100644 --- a/libs/binder/tests/binderRpcTest.cpp +++ b/libs/binder/tests/binderRpcTest.cpp @@ -1674,8 +1674,8 @@ public:          static AssertionResult defaultPostConnect(RpcTransport* serverTransport,                                                    FdTrigger* fdTrigger) {              std::string message(kMessage); -            auto status = serverTransport->interruptableWriteFully(fdTrigger, message.data(), -                                                                   message.size(), {}); +            iovec messageIov{message.data(), message.size()}; +            auto status = serverTransport->interruptableWriteFully(fdTrigger, &messageIov, 1, {});              if (status != OK) return AssertionFailure() << statusToString(status);              return AssertionSuccess();          } @@ -1706,9 +1706,9 @@ public:          AssertionResult readMessage(const std::string& expectedMessage = kMessage) {              LOG_ALWAYS_FATAL_IF(mClientTransport == nullptr, "setUpTransport not called or failed");              std::string readMessage(expectedMessage.size(), '\0'); -            status_t readStatus = -                    mClientTransport->interruptableReadFully(mFdTrigger.get(), readMessage.data(), -                                                             readMessage.size(), {}); +            iovec readMessageIov{readMessage.data(), readMessage.size()}; +            status_t readStatus = mClientTransport->interruptableReadFully(mFdTrigger.get(), +                                                                           &readMessageIov, 1, {});              if (readStatus != OK) {                  return AssertionFailure() << statusToString(readStatus);              } @@ -1902,8 +1902,8 @@ TEST_P(RpcTransportTest, Trigger) {      bool shouldContinueWriting = false;      auto serverPostConnect = [&](RpcTransport* serverTransport, FdTrigger* fdTrigger) {          std::string message(RpcTransportTestUtils::kMessage); -        auto status = serverTransport->interruptableWriteFully(fdTrigger, message.data(), -                                                               message.size(), {}); +        iovec messageIov{message.data(), message.size()}; +        auto status = serverTransport->interruptableWriteFully(fdTrigger, &messageIov, 1, {});          if (status != OK) return AssertionFailure() << statusToString(status);          { @@ -1913,7 +1913,8 @@ TEST_P(RpcTransportTest, Trigger) {              }          } -        status = serverTransport->interruptableWriteFully(fdTrigger, msg2.data(), msg2.size(), {}); +        iovec msg2Iov{msg2.data(), msg2.size()}; +        status = serverTransport->interruptableWriteFully(fdTrigger, &msg2Iov, 1, {});          if (status != DEAD_OBJECT)              return AssertionFailure() << "When FdTrigger is shut down, interruptableWriteFully "                                           "should return DEAD_OBJECT, but it is "  |