diff options
author | 2022-08-03 21:19:11 +0000 | |
---|---|---|
committer | 2022-08-26 08:41:15 +0000 | |
commit | 49d74cbbbd75ade6b7002cad417bd168b9cd0f0e (patch) | |
tree | ee9cea6b677f66d60a89712392430895249bed89 | |
parent | ec646539825294a4f33f94857d224a1946fb540f (diff) |
libbinder : Adding new type TransportFd
Adding a new struct TransportFd which will contain unique_fd and
polling state of file descriptor. This will be useful in detecting
if all the descriptors are being polled. unique_fd and borrowed_fd
are replaced in these changes.
Test: m
Test: m libbinder binderRpcTest && atest binderRpcTest
Test: trusty/vendor/google/aosp/scripts/build.py --test
"boot-test:com.android.trusty.binder.test" qemu-generic-arm64-test-debug
Bug: 218518615
Change-Id: Id108806b98184582e5d93186b3b1884017c441ea
-rw-r--r-- | libs/binder/FdTrigger.cpp | 18 | ||||
-rw-r--r-- | libs/binder/FdTrigger.h | 4 | ||||
-rw-r--r-- | libs/binder/RpcServer.cpp | 38 | ||||
-rw-r--r-- | libs/binder/RpcSession.cpp | 27 | ||||
-rw-r--r-- | libs/binder/RpcTransportRaw.cpp | 28 | ||||
-rw-r--r-- | libs/binder/RpcTransportTipcAndroid.cpp | 24 | ||||
-rw-r--r-- | libs/binder/RpcTransportTls.cpp | 31 | ||||
-rw-r--r-- | libs/binder/RpcTransportUtils.h | 6 | ||||
-rw-r--r-- | libs/binder/include/binder/RpcServer.h | 6 | ||||
-rw-r--r-- | libs/binder/include/binder/RpcSession.h | 2 | ||||
-rw-r--r-- | libs/binder/include/binder/RpcTransport.h | 45 | ||||
-rw-r--r-- | libs/binder/tests/binderRpcTest.cpp | 65 | ||||
-rw-r--r-- | libs/binder/trusty/RpcServerTrusty.cpp | 4 | ||||
-rw-r--r-- | libs/binder/trusty/RpcTransportTipcTrusty.cpp | 24 |
14 files changed, 222 insertions, 100 deletions
diff --git a/libs/binder/FdTrigger.cpp b/libs/binder/FdTrigger.cpp index d123fd1f2b..256d587866 100644 --- a/libs/binder/FdTrigger.cpp +++ b/libs/binder/FdTrigger.cpp @@ -22,6 +22,7 @@ #include <poll.h> #include <android-base/macros.h> +#include <android-base/scopeguard.h> #include "RpcState.h" namespace android { @@ -53,25 +54,34 @@ bool FdTrigger::isTriggered() { #endif } -status_t FdTrigger::triggerablePoll(base::borrowed_fd fd, int16_t event) { +status_t FdTrigger::triggerablePoll(const android::TransportFd& transportFd, int16_t event) { #ifdef BINDER_RPC_SINGLE_THREADED if (mTriggered) { return DEAD_OBJECT; } #endif - LOG_ALWAYS_FATAL_IF(event == 0, "triggerablePoll %d with event 0 is not allowed", fd.get()); + LOG_ALWAYS_FATAL_IF(event == 0, "triggerablePoll %d with event 0 is not allowed", + transportFd.fd.get()); pollfd pfd[]{ - {.fd = fd.get(), .events = static_cast<int16_t>(event), .revents = 0}, + {.fd = transportFd.fd.get(), .events = static_cast<int16_t>(event), .revents = 0}, #ifndef BINDER_RPC_SINGLE_THREADED {.fd = mRead.get(), .events = 0, .revents = 0}, #endif }; + + LOG_ALWAYS_FATAL_IF(transportFd.isInPollingState() == true, + "Only one thread should be polling on Fd!"); + + transportFd.setPollingState(true); + auto pollingStateGuard = + android::base::make_scope_guard([&]() { transportFd.setPollingState(false); }); + int ret = TEMP_FAILURE_RETRY(poll(pfd, arraysize(pfd), -1)); if (ret < 0) { return -errno; } - LOG_ALWAYS_FATAL_IF(ret == 0, "poll(%d) returns 0 with infinite timeout", fd.get()); + LOG_ALWAYS_FATAL_IF(ret == 0, "poll(%d) returns 0 with infinite timeout", transportFd.fd.get()); // At least one FD has events. Check them. diff --git a/libs/binder/FdTrigger.h b/libs/binder/FdTrigger.h index a25dc117f3..14a2749b8a 100644 --- a/libs/binder/FdTrigger.h +++ b/libs/binder/FdTrigger.h @@ -21,6 +21,8 @@ #include <android-base/unique_fd.h> #include <utils/Errors.h> +#include <binder/RpcTransport.h> + namespace android { /** This is not a pipe. */ @@ -53,7 +55,7 @@ public: * true - time to read! * false - trigger happened */ - [[nodiscard]] status_t triggerablePoll(base::borrowed_fd fd, int16_t event); + [[nodiscard]] status_t triggerablePoll(const android::TransportFd& transportFd, int16_t event); private: #ifdef BINDER_RPC_SINGLE_THREADED diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp index 49be4dd9eb..3fc6612f9f 100644 --- a/libs/binder/RpcServer.cpp +++ b/libs/binder/RpcServer.cpp @@ -86,7 +86,7 @@ status_t RpcServer::setupInetServer(const char* address, unsigned int port, LOG_ALWAYS_FATAL_IF(socketAddress.addr()->sa_family != AF_INET, "expecting inet"); sockaddr_in addr{}; socklen_t len = sizeof(addr); - if (0 != getsockname(mServer.get(), reinterpret_cast<sockaddr*>(&addr), &len)) { + if (0 != getsockname(mServer.fd.get(), reinterpret_cast<sockaddr*>(&addr), &len)) { int savedErrno = errno; ALOGE("Could not getsockname at %s: %s", socketAddress.toString().c_str(), strerror(savedErrno)); @@ -181,7 +181,7 @@ void RpcServer::join() { { RpcMutexLockGuard _l(mLock); - LOG_ALWAYS_FATAL_IF(!mServer.ok(), "RpcServer must be setup to join."); + LOG_ALWAYS_FATAL_IF(!mServer.fd.ok(), "RpcServer must be setup to join."); LOG_ALWAYS_FATAL_IF(mShutdownTrigger != nullptr, "Already joined"); mJoinThreadRunning = true; mShutdownTrigger = FdTrigger::make(); @@ -194,24 +194,24 @@ void RpcServer::join() { static_assert(addr.size() >= sizeof(sockaddr_storage), "kRpcAddressSize is too small"); socklen_t addrLen = addr.size(); - unique_fd clientFd( - TEMP_FAILURE_RETRY(accept4(mServer.get(), reinterpret_cast<sockaddr*>(addr.data()), - &addrLen, SOCK_CLOEXEC | SOCK_NONBLOCK))); + TransportFd clientSocket(unique_fd(TEMP_FAILURE_RETRY( + accept4(mServer.fd.get(), reinterpret_cast<sockaddr*>(addr.data()), &addrLen, + SOCK_CLOEXEC | SOCK_NONBLOCK)))); LOG_ALWAYS_FATAL_IF(addrLen > static_cast<socklen_t>(sizeof(sockaddr_storage)), "Truncated address"); - if (clientFd < 0) { + if (clientSocket.fd < 0) { ALOGE("Could not accept4 socket: %s", strerror(errno)); continue; } - LOG_RPC_DETAIL("accept4 on fd %d yields fd %d", mServer.get(), clientFd.get()); + LOG_RPC_DETAIL("accept4 on fd %d yields fd %d", mServer.fd.get(), clientSocket.fd.get()); { RpcMutexLockGuard _l(mLock); RpcMaybeThread thread = RpcMaybeThread(&RpcServer::establishConnection, - sp<RpcServer>::fromExisting(this), std::move(clientFd), addr, + sp<RpcServer>::fromExisting(this), std::move(clientSocket), addr, addrLen, RpcSession::join); auto& threadRef = mConnectingThreads[thread.get_id()]; @@ -296,7 +296,7 @@ size_t RpcServer::numUninitializedSessions() { } void RpcServer::establishConnection( - sp<RpcServer>&& server, base::unique_fd clientFd, std::array<uint8_t, kRpcAddressSize> addr, + sp<RpcServer>&& server, TransportFd clientFd, std::array<uint8_t, kRpcAddressSize> addr, size_t addrLen, std::function<void(sp<RpcSession>&&, RpcSession::PreJoinSetupResult&&)>&& joinFn) { // mShutdownTrigger can only be cleared once connection threads have joined. @@ -306,7 +306,7 @@ void RpcServer::establishConnection( status_t status = OK; - int clientFdForLog = clientFd.get(); + int clientFdForLog = clientFd.fd.get(); auto client = server->mCtx->newTransport(std::move(clientFd), server->mShutdownTrigger.get()); if (client == nullptr) { ALOGE("Dropping accept4()-ed socket because sslAccept fails"); @@ -488,15 +488,15 @@ status_t RpcServer::setupSocketServer(const RpcSocketAddress& addr) { LOG_RPC_DETAIL("Setting up socket server %s", addr.toString().c_str()); LOG_ALWAYS_FATAL_IF(hasServer(), "Each RpcServer can only have one server."); - unique_fd serverFd(TEMP_FAILURE_RETRY( - socket(addr.addr()->sa_family, SOCK_STREAM | SOCK_CLOEXEC | SOCK_NONBLOCK, 0))); - if (serverFd == -1) { + TransportFd transportFd(unique_fd(TEMP_FAILURE_RETRY( + socket(addr.addr()->sa_family, SOCK_STREAM | SOCK_CLOEXEC | SOCK_NONBLOCK, 0)))); + if (!transportFd.fd.ok()) { int savedErrno = errno; ALOGE("Could not create socket: %s", strerror(savedErrno)); return -savedErrno; } - if (0 != TEMP_FAILURE_RETRY(bind(serverFd.get(), addr.addr(), addr.addrSize()))) { + if (0 != TEMP_FAILURE_RETRY(bind(transportFd.fd.get(), addr.addr(), addr.addrSize()))) { int savedErrno = errno; ALOGE("Could not bind socket at %s: %s", addr.toString().c_str(), strerror(savedErrno)); return -savedErrno; @@ -506,7 +506,7 @@ status_t RpcServer::setupSocketServer(const RpcSocketAddress& addr) { // the backlog is increased to a large number. // TODO(b/189955605): Once we create threads dynamically & lazily, the backlog can be reduced // to 1. - if (0 != TEMP_FAILURE_RETRY(listen(serverFd.get(), 50 /*backlog*/))) { + if (0 != TEMP_FAILURE_RETRY(listen(transportFd.fd.get(), 50 /*backlog*/))) { int savedErrno = errno; ALOGE("Could not listen socket at %s: %s", addr.toString().c_str(), strerror(savedErrno)); return -savedErrno; @@ -514,7 +514,7 @@ status_t RpcServer::setupSocketServer(const RpcSocketAddress& addr) { LOG_RPC_DETAIL("Successfully setup socket server %s", addr.toString().c_str()); - if (status_t status = setupExternalServer(std::move(serverFd)); status != OK) { + if (status_t status = setupExternalServer(std::move(transportFd.fd)); status != OK) { ALOGE("Another thread has set up server while calling setupSocketServer. Race?"); return status; } @@ -542,17 +542,17 @@ void RpcServer::onSessionIncomingThreadEnded() { bool RpcServer::hasServer() { RpcMutexLockGuard _l(mLock); - return mServer.ok(); + return mServer.fd.ok(); } unique_fd RpcServer::releaseServer() { RpcMutexLockGuard _l(mLock); - return std::move(mServer); + return std::move(mServer.fd); } status_t RpcServer::setupExternalServer(base::unique_fd serverFd) { RpcMutexLockGuard _l(mLock); - if (mServer.ok()) { + if (mServer.fd.ok()) { ALOGE("Each RpcServer can only have one server."); return INVALID_OPERATION; } diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp index 8ddfa93c00..c05a177be4 100644 --- a/libs/binder/RpcSession.cpp +++ b/libs/binder/RpcSession.cpp @@ -162,7 +162,8 @@ status_t RpcSession::setupInetClient(const char* addr, unsigned int port) { return NAME_NOT_FOUND; } -status_t RpcSession::setupPreconnectedClient(unique_fd fd, std::function<unique_fd()>&& request) { +status_t RpcSession::setupPreconnectedClient(base::unique_fd fd, + std::function<unique_fd()>&& request) { return setupClient([&](const std::vector<uint8_t>& sessionId, bool incoming) -> status_t { if (!fd.ok()) { fd = request(); @@ -172,7 +173,9 @@ status_t RpcSession::setupPreconnectedClient(unique_fd fd, std::function<unique_ ALOGE("setupPreconnectedClient: %s", res.error().message().c_str()); return res.error().code() == 0 ? UNKNOWN_ERROR : -res.error().code(); } - status_t status = initAndAddConnection(std::move(fd), sessionId, incoming); + + TransportFd transportFd(std::move(fd)); + status_t status = initAndAddConnection(std::move(transportFd), sessionId, incoming); fd = unique_fd(); // Explicitly reset after move to avoid analyzer warning. return status; }); @@ -190,7 +193,8 @@ status_t RpcSession::addNullDebuggingClient() { return -savedErrno; } - auto server = mCtx->newTransport(std::move(serverFd), mShutdownTrigger.get()); + TransportFd transportFd(std::move(serverFd)); + auto server = mCtx->newTransport(std::move(transportFd), mShutdownTrigger.get()); if (server == nullptr) { ALOGE("Unable to set up RpcTransport"); return UNKNOWN_ERROR; @@ -572,12 +576,14 @@ status_t RpcSession::setupOneSocketConnection(const RpcSocketAddress& addr, return -savedErrno; } - if (0 != TEMP_FAILURE_RETRY(connect(serverFd.get(), addr.addr(), addr.addrSize()))) { + TransportFd transportFd(std::move(serverFd)); + + if (0 != TEMP_FAILURE_RETRY(connect(transportFd.fd.get(), addr.addr(), addr.addrSize()))) { int connErrno = errno; if (connErrno == EAGAIN || connErrno == EINPROGRESS) { // For non-blocking sockets, connect() may return EAGAIN (for unix domain socket) or // EINPROGRESS (for others). Call poll() and getsockopt() to get the error. - status_t pollStatus = mShutdownTrigger->triggerablePoll(serverFd, POLLOUT); + status_t pollStatus = mShutdownTrigger->triggerablePoll(transportFd, POLLOUT); if (pollStatus != OK) { ALOGE("Could not POLLOUT after connect() on non-blocking socket: %s", statusToString(pollStatus).c_str()); @@ -585,8 +591,8 @@ status_t RpcSession::setupOneSocketConnection(const RpcSocketAddress& addr, } // Set connErrno to the errno that connect() would have set if the fd were blocking. socklen_t connErrnoLen = sizeof(connErrno); - int ret = - getsockopt(serverFd.get(), SOL_SOCKET, SO_ERROR, &connErrno, &connErrnoLen); + int ret = getsockopt(transportFd.fd.get(), SOL_SOCKET, SO_ERROR, &connErrno, + &connErrnoLen); if (ret == -1) { int savedErrno = errno; ALOGE("Could not getsockopt() after connect() on non-blocking socket: %s. " @@ -608,16 +614,17 @@ status_t RpcSession::setupOneSocketConnection(const RpcSocketAddress& addr, return -connErrno; } } - LOG_RPC_DETAIL("Socket at %s client with fd %d", addr.toString().c_str(), serverFd.get()); + LOG_RPC_DETAIL("Socket at %s client with fd %d", addr.toString().c_str(), + transportFd.fd.get()); - return initAndAddConnection(std::move(serverFd), sessionId, incoming); + return initAndAddConnection(std::move(transportFd), sessionId, incoming); } ALOGE("Ran out of retries to connect to %s", addr.toString().c_str()); return UNKNOWN_ERROR; } -status_t RpcSession::initAndAddConnection(unique_fd fd, const std::vector<uint8_t>& sessionId, +status_t RpcSession::initAndAddConnection(TransportFd fd, const std::vector<uint8_t>& sessionId, bool incoming) { LOG_ALWAYS_FATAL_IF(mShutdownTrigger == nullptr); auto server = mCtx->newTransport(std::move(fd), mShutdownTrigger.get()); diff --git a/libs/binder/RpcTransportRaw.cpp b/libs/binder/RpcTransportRaw.cpp index 51326f6346..59e6869a1e 100644 --- a/libs/binder/RpcTransportRaw.cpp +++ b/libs/binder/RpcTransportRaw.cpp @@ -36,11 +36,11 @@ constexpr size_t kMaxFdsPerMsg = 253; // RpcTransport with TLS disabled. class RpcTransportRaw : public RpcTransport { public: - explicit RpcTransportRaw(android::base::unique_fd socket) : mSocket(std::move(socket)) {} + explicit RpcTransportRaw(android::TransportFd socket) : mSocket(std::move(socket)) {} status_t pollRead(void) override { uint8_t buf; ssize_t ret = TEMP_FAILURE_RETRY( - ::recv(mSocket.get(), &buf, sizeof(buf), MSG_PEEK | MSG_DONTWAIT)); + ::recv(mSocket.fd.get(), &buf, sizeof(buf), MSG_PEEK | MSG_DONTWAIT)); if (ret < 0) { int savedErrno = errno; if (savedErrno == EAGAIN || savedErrno == EWOULDBLOCK) { @@ -100,7 +100,7 @@ public: msg.msg_controllen = CMSG_SPACE(fdsByteSize); ssize_t processedSize = TEMP_FAILURE_RETRY( - sendmsg(mSocket.get(), &msg, MSG_NOSIGNAL | MSG_CMSG_CLOEXEC)); + sendmsg(mSocket.fd.get(), &msg, MSG_NOSIGNAL | MSG_CMSG_CLOEXEC)); if (processedSize > 0) { sentFds = true; } @@ -113,10 +113,10 @@ public: // non-negative int and can be cast to either. .msg_iovlen = static_cast<decltype(msg.msg_iovlen)>(niovs), }; - return TEMP_FAILURE_RETRY(sendmsg(mSocket.get(), &msg, MSG_NOSIGNAL)); + return TEMP_FAILURE_RETRY(sendmsg(mSocket.fd.get(), &msg, MSG_NOSIGNAL)); }; - return interruptableReadOrWrite(mSocket.get(), fdTrigger, iovs, niovs, send, "sendmsg", - POLLOUT, altPoll); + return interruptableReadOrWrite(mSocket, fdTrigger, iovs, niovs, send, "sendmsg", POLLOUT, + altPoll); } status_t interruptableReadFully( @@ -135,7 +135,7 @@ public: .msg_controllen = sizeof(msgControlBuf), }; ssize_t processSize = - TEMP_FAILURE_RETRY(recvmsg(mSocket.get(), &msg, MSG_NOSIGNAL)); + TEMP_FAILURE_RETRY(recvmsg(mSocket.fd.get(), &msg, MSG_NOSIGNAL)); if (processSize < 0) { return -1; } @@ -171,21 +171,23 @@ public: // non-negative int and can be cast to either. .msg_iovlen = static_cast<decltype(msg.msg_iovlen)>(niovs), }; - return TEMP_FAILURE_RETRY(recvmsg(mSocket.get(), &msg, MSG_NOSIGNAL)); + return TEMP_FAILURE_RETRY(recvmsg(mSocket.fd.get(), &msg, MSG_NOSIGNAL)); }; - return interruptableReadOrWrite(mSocket.get(), fdTrigger, iovs, niovs, recv, "recvmsg", - POLLIN, altPoll); + return interruptableReadOrWrite(mSocket, fdTrigger, iovs, niovs, recv, "recvmsg", POLLIN, + altPoll); } + virtual bool isWaiting() { return mSocket.isInPollingState(); } + private: - base::unique_fd mSocket; + android::TransportFd mSocket; }; // RpcTransportCtx with TLS disabled. class RpcTransportCtxRaw : public RpcTransportCtx { public: - std::unique_ptr<RpcTransport> newTransport(android::base::unique_fd fd, FdTrigger*) const { - return std::make_unique<RpcTransportRaw>(std::move(fd)); + std::unique_ptr<RpcTransport> newTransport(android::TransportFd socket, FdTrigger*) const { + return std::make_unique<RpcTransportRaw>(std::move(socket)); } std::vector<uint8_t> getCertificate(RpcCertificateFormat) const override { return {}; } }; diff --git a/libs/binder/RpcTransportTipcAndroid.cpp b/libs/binder/RpcTransportTipcAndroid.cpp index c82201b28b..2e7e931040 100644 --- a/libs/binder/RpcTransportTipcAndroid.cpp +++ b/libs/binder/RpcTransportTipcAndroid.cpp @@ -36,8 +36,7 @@ namespace { // RpcTransport for writing Trusty IPC clients in Android. class RpcTransportTipcAndroid : public RpcTransport { public: - explicit RpcTransportTipcAndroid(android::base::unique_fd socket) - : mSocket(std::move(socket)) {} + explicit RpcTransportTipcAndroid(android::TransportFd socket) : mSocket(std::move(socket)) {} status_t pollRead() override { if (mReadBufferPos < mReadBufferSize) { @@ -46,7 +45,7 @@ public: } // Trusty IPC device is not a socket, so MSG_PEEK is not available - pollfd pfd{.fd = mSocket.get(), .events = static_cast<int16_t>(POLLIN), .revents = 0}; + pollfd pfd{.fd = mSocket.fd.get(), .events = static_cast<int16_t>(POLLIN), .revents = 0}; ssize_t ret = TEMP_FAILURE_RETRY(::poll(&pfd, 1, 0)); if (ret < 0) { int savedErrno = errno; @@ -84,9 +83,9 @@ public: // to send any. LOG_ALWAYS_FATAL_IF(ancillaryFds != nullptr && !ancillaryFds->empty(), "File descriptors are not supported on Trusty yet"); - return TEMP_FAILURE_RETRY(tipc_send(mSocket.get(), iovs, niovs, nullptr, 0)); + return TEMP_FAILURE_RETRY(tipc_send(mSocket.fd.get(), iovs, niovs, nullptr, 0)); }; - return interruptableReadOrWrite(mSocket.get(), fdTrigger, iovs, niovs, writeFn, "tipc_send", + return interruptableReadOrWrite(mSocket, fdTrigger, iovs, niovs, writeFn, "tipc_send", POLLOUT, altPoll); } @@ -120,10 +119,12 @@ public: return processSize; }; - return interruptableReadOrWrite(mSocket.get(), fdTrigger, iovs, niovs, readFn, "read", - POLLIN, altPoll); + return interruptableReadOrWrite(mSocket, fdTrigger, iovs, niovs, readFn, "read", POLLIN, + altPoll); } + bool isWaiting() override { return mSocket.isInPollingState(); } + private: status_t fillReadBuffer() { if (mReadBufferPos < mReadBufferSize) { @@ -146,8 +147,8 @@ private: mReadBufferSize = 0; while (true) { - ssize_t processSize = - TEMP_FAILURE_RETRY(read(mSocket.get(), mReadBuffer.get(), mReadBufferCapacity)); + ssize_t processSize = TEMP_FAILURE_RETRY( + read(mSocket.fd.get(), mReadBuffer.get(), mReadBufferCapacity)); if (processSize == 0) { return DEAD_OBJECT; } else if (processSize < 0) { @@ -173,7 +174,7 @@ private: } } - base::unique_fd mSocket; + TransportFd mSocket; // For now, we copy all the input data into a temporary buffer because // we might get multiple interruptableReadFully calls per message, but @@ -192,8 +193,7 @@ private: // RpcTransportCtx for Trusty. class RpcTransportCtxTipcAndroid : public RpcTransportCtx { public: - std::unique_ptr<RpcTransport> newTransport(android::base::unique_fd fd, - FdTrigger*) const override { + std::unique_ptr<RpcTransport> newTransport(android::TransportFd fd, FdTrigger*) const override { return std::make_unique<RpcTransportTipcAndroid>(std::move(fd)); } std::vector<uint8_t> getCertificate(RpcCertificateFormat) const override { return {}; } diff --git a/libs/binder/RpcTransportTls.cpp b/libs/binder/RpcTransportTls.cpp index 09b5c17152..f18519c640 100644 --- a/libs/binder/RpcTransportTls.cpp +++ b/libs/binder/RpcTransportTls.cpp @@ -182,8 +182,8 @@ public: // If |sslError| is WANT_READ / WANT_WRITE, poll for POLLIN / POLLOUT respectively. Otherwise // return error. Also return error if |fdTrigger| is triggered before or during poll(). status_t pollForSslError( - android::base::borrowed_fd fd, int sslError, FdTrigger* fdTrigger, const char* fnString, - int additionalEvent, + const android::TransportFd& fd, int sslError, FdTrigger* fdTrigger, + const char* fnString, int additionalEvent, const std::optional<android::base::function_ref<status_t()>>& altPoll) { switch (sslError) { case SSL_ERROR_WANT_READ: @@ -198,7 +198,7 @@ public: private: bool mHandled = false; - status_t handlePoll(int event, android::base::borrowed_fd fd, FdTrigger* fdTrigger, + status_t handlePoll(int event, const android::TransportFd& fd, FdTrigger* fdTrigger, const char* fnString, const std::optional<android::base::function_ref<status_t()>>& altPoll) { status_t ret; @@ -277,7 +277,7 @@ private: class RpcTransportTls : public RpcTransport { public: - RpcTransportTls(android::base::unique_fd socket, Ssl ssl) + RpcTransportTls(TransportFd socket, Ssl ssl) : mSocket(std::move(socket)), mSsl(std::move(ssl)) {} status_t pollRead(void) override; status_t interruptableWriteFully( @@ -290,8 +290,10 @@ public: const std::optional<android::base::function_ref<status_t()>>& altPoll, std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds) override; + bool isWaiting() { return mSocket.isInPollingState(); }; + private: - android::base::unique_fd mSocket; + android::TransportFd mSocket; Ssl mSsl; }; @@ -350,7 +352,7 @@ status_t RpcTransportTls::interruptableWriteFully( int sslError = mSsl.getError(writeSize); // TODO(b/195788248): BIO should contain the FdTrigger, and send(2) / recv(2) should be // triggerablePoll()-ed. Then additionalEvent is no longer necessary. - status_t pollStatus = errorQueue.pollForSslError(mSocket.get(), sslError, fdTrigger, + status_t pollStatus = errorQueue.pollForSslError(mSocket, sslError, fdTrigger, "SSL_write", POLLIN, altPoll); if (pollStatus != OK) return pollStatus; // Do not advance buffer. Try SSL_write() again. @@ -398,7 +400,7 @@ status_t RpcTransportTls::interruptableReadFully( return DEAD_OBJECT; } int sslError = mSsl.getError(readSize); - status_t pollStatus = errorQueue.pollForSslError(mSocket.get(), sslError, fdTrigger, + status_t pollStatus = errorQueue.pollForSslError(mSocket, sslError, fdTrigger, "SSL_read", 0, altPoll); if (pollStatus != OK) return pollStatus; // Do not advance buffer. Try SSL_read() again. @@ -409,8 +411,8 @@ status_t RpcTransportTls::interruptableReadFully( } // For |ssl|, set internal FD to |fd|, and do handshake. Handshake is triggerable by |fdTrigger|. -bool setFdAndDoHandshake(Ssl* ssl, android::base::borrowed_fd fd, FdTrigger* fdTrigger) { - bssl::UniquePtr<BIO> bio = newSocketBio(fd); +bool setFdAndDoHandshake(Ssl* ssl, const android::TransportFd& socket, FdTrigger* fdTrigger) { + bssl::UniquePtr<BIO> bio = newSocketBio(socket.fd); TEST_AND_RETURN(false, bio != nullptr); auto [_, errorQueue] = ssl->call(SSL_set_bio, bio.get(), bio.get()); (void)bio.release(); // SSL_set_bio takes ownership. @@ -430,7 +432,7 @@ bool setFdAndDoHandshake(Ssl* ssl, android::base::borrowed_fd fd, FdTrigger* fdT return false; } int sslError = ssl->getError(ret); - status_t pollStatus = errorQueue.pollForSslError(fd, sslError, fdTrigger, + status_t pollStatus = errorQueue.pollForSslError(socket, sslError, fdTrigger, "SSL_do_handshake", 0, std::nullopt); if (pollStatus != OK) return false; } @@ -442,8 +444,7 @@ public: typename = std::enable_if_t<std::is_base_of_v<RpcTransportCtxTls, Impl>>> static std::unique_ptr<RpcTransportCtxTls> create( std::shared_ptr<RpcCertificateVerifier> verifier, RpcAuth* auth); - std::unique_ptr<RpcTransport> newTransport(android::base::unique_fd fd, - FdTrigger* fdTrigger) const override; + std::unique_ptr<RpcTransport> newTransport(TransportFd fd, FdTrigger* fdTrigger) const override; std::vector<uint8_t> getCertificate(RpcCertificateFormat) const override; protected: @@ -513,15 +514,15 @@ std::unique_ptr<RpcTransportCtxTls> RpcTransportCtxTls::create( return ret; } -std::unique_ptr<RpcTransport> RpcTransportCtxTls::newTransport(android::base::unique_fd fd, +std::unique_ptr<RpcTransport> RpcTransportCtxTls::newTransport(android::TransportFd socket, FdTrigger* fdTrigger) const { bssl::UniquePtr<SSL> ssl(SSL_new(mCtx.get())); TEST_AND_RETURN(nullptr, ssl != nullptr); Ssl wrapped(std::move(ssl)); preHandshake(&wrapped); - TEST_AND_RETURN(nullptr, setFdAndDoHandshake(&wrapped, fd, fdTrigger)); - return std::make_unique<RpcTransportTls>(std::move(fd), std::move(wrapped)); + TEST_AND_RETURN(nullptr, setFdAndDoHandshake(&wrapped, socket, fdTrigger)); + return std::make_unique<RpcTransportTls>(std::move(socket), std::move(wrapped)); } class RpcTransportCtxTlsServer : public RpcTransportCtxTls { diff --git a/libs/binder/RpcTransportUtils.h b/libs/binder/RpcTransportUtils.h index 00cb2af8c3..d0843c0cc0 100644 --- a/libs/binder/RpcTransportUtils.h +++ b/libs/binder/RpcTransportUtils.h @@ -25,8 +25,8 @@ namespace android { template <typename SendOrReceive> status_t interruptableReadOrWrite( - int socketFd, FdTrigger* fdTrigger, iovec* iovs, int niovs, SendOrReceive sendOrReceiveFun, - const char* funName, int16_t event, + const android::TransportFd& socket, FdTrigger* fdTrigger, iovec* iovs, int niovs, + SendOrReceive sendOrReceiveFun, const char* funName, int16_t event, const std::optional<android::base::function_ref<status_t()>>& altPoll) { MAYBE_WAIT_IN_FLAKE_MODE; @@ -99,7 +99,7 @@ status_t interruptableReadOrWrite( return DEAD_OBJECT; } } else { - if (status_t status = fdTrigger->triggerablePoll(socketFd, event); status != OK) + if (status_t status = fdTrigger->triggerablePoll(socket, event); status != OK) return status; if (!havePolled) havePolled = true; } diff --git a/libs/binder/include/binder/RpcServer.h b/libs/binder/include/binder/RpcServer.h index 52bda0e8ed..0a4e9d5a49 100644 --- a/libs/binder/include/binder/RpcServer.h +++ b/libs/binder/include/binder/RpcServer.h @@ -199,8 +199,8 @@ private: static constexpr size_t kRpcAddressSize = 128; static void establishConnection( - sp<RpcServer>&& server, base::unique_fd clientFd, - std::array<uint8_t, kRpcAddressSize> addr, size_t addrLen, + sp<RpcServer>&& server, TransportFd clientFd, std::array<uint8_t, kRpcAddressSize> addr, + size_t addrLen, std::function<void(sp<RpcSession>&&, RpcSession::PreJoinSetupResult&&)>&& joinFn); [[nodiscard]] status_t setupSocketServer(const RpcSocketAddress& address); @@ -210,7 +210,7 @@ private: // A mode is supported if the N'th bit is on, where N is the mode enum's value. std::bitset<8> mSupportedFileDescriptorTransportModes = std::bitset<8>().set( static_cast<size_t>(RpcSession::FileDescriptorTransportMode::NONE)); - base::unique_fd mServer; // socket we are accepting sessions on + TransportFd mServer; // socket we are accepting sessions on RpcMutex mLock; // for below std::unique_ptr<RpcMaybeThread> mJoinThread; diff --git a/libs/binder/include/binder/RpcSession.h b/libs/binder/include/binder/RpcSession.h index 428e27209f..392b02c7d3 100644 --- a/libs/binder/include/binder/RpcSession.h +++ b/libs/binder/include/binder/RpcSession.h @@ -269,7 +269,7 @@ private: [[nodiscard]] status_t setupOneSocketConnection(const RpcSocketAddress& address, const std::vector<uint8_t>& sessionId, bool incoming); - [[nodiscard]] status_t initAndAddConnection(base::unique_fd fd, + [[nodiscard]] status_t initAndAddConnection(TransportFd fd, const std::vector<uint8_t>& sessionId, bool incoming); [[nodiscard]] status_t addIncomingConnection(std::unique_ptr<RpcTransport> rpcTransport); diff --git a/libs/binder/include/binder/RpcTransport.h b/libs/binder/include/binder/RpcTransport.h index 5197ef9f4b..89f34f8979 100644 --- a/libs/binder/include/binder/RpcTransport.h +++ b/libs/binder/include/binder/RpcTransport.h @@ -30,12 +30,14 @@ #include <utils/Errors.h> #include <binder/RpcCertificateFormat.h> +#include <binder/RpcThreads.h> #include <sys/uio.h> namespace android { class FdTrigger; +struct TransportFd; // Represents a socket connection. // No thread-safety is guaranteed for these APIs. @@ -81,6 +83,15 @@ public: const std::optional<android::base::function_ref<status_t()>> &altPoll, std::vector<std::variant<base::unique_fd, base::borrowed_fd>> *ancillaryFds) = 0; + /** + * Check whether any threads are blocked while polling the transport + * for read operations + * Return: + * True - Specifies that there is active polling on transport. + * False - No active polling on transport + */ + [[nodiscard]] virtual bool isWaiting() = 0; + protected: RpcTransport() = default; }; @@ -96,7 +107,7 @@ public: // Implementation details: for TLS, this function may incur I/O. |fdTrigger| may be used // to interrupt I/O. This function blocks until handshake is finished. [[nodiscard]] virtual std::unique_ptr<RpcTransport> newTransport( - android::base::unique_fd fd, FdTrigger *fdTrigger) const = 0; + android::TransportFd fd, FdTrigger *fdTrigger) const = 0; // Return the preconfigured certificate of this context. // @@ -129,4 +140,36 @@ protected: RpcTransportCtxFactory() = default; }; +struct TransportFd { +private: + mutable bool isPolling{false}; + + void setPollingState(bool state) const { isPolling = state; } + +public: + base::unique_fd fd; + + TransportFd() = default; + explicit TransportFd(base::unique_fd &&descriptor) + : isPolling(false), fd(std::move(descriptor)) {} + + TransportFd(TransportFd &&transportFd) noexcept + : isPolling(transportFd.isPolling), fd(std::move(transportFd.fd)) {} + + TransportFd &operator=(TransportFd &&transportFd) noexcept { + fd = std::move(transportFd.fd); + isPolling = transportFd.isPolling; + return *this; + } + + TransportFd &operator=(base::unique_fd &&descriptor) noexcept { + fd = std::move(descriptor); + isPolling = false; + return *this; + } + + bool isInPollingState() const { return isPolling; } + friend class FdTrigger; +}; + } // namespace android diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp index 4c037b77db..cfdfdf8907 100644 --- a/libs/binder/tests/binderRpcTest.cpp +++ b/libs/binder/tests/binderRpcTest.cpp @@ -1773,7 +1773,7 @@ public: } } mFd = rpcServer->releaseServer(); - if (!mFd.ok()) return AssertionFailure() << "releaseServer returns invalid fd"; + if (!mFd.fd.ok()) return AssertionFailure() << "releaseServer returns invalid fd"; mCtx = newFactory(rpcSecurity, mCertVerifier, std::move(auth))->newServerCtx(); if (mCtx == nullptr) return AssertionFailure() << "newServerCtx"; mSetup = true; @@ -1794,7 +1794,7 @@ public: std::vector<std::thread> threads; while (OK == mFdTrigger->triggerablePoll(mFd, POLLIN)) { base::unique_fd acceptedFd( - TEMP_FAILURE_RETRY(accept4(mFd.get(), nullptr, nullptr /*length*/, + TEMP_FAILURE_RETRY(accept4(mFd.fd.get(), nullptr, nullptr /*length*/, SOCK_CLOEXEC | SOCK_NONBLOCK))); threads.emplace_back(&Server::handleOne, this, std::move(acceptedFd)); } @@ -1803,7 +1803,8 @@ public: } void handleOne(android::base::unique_fd acceptedFd) { ASSERT_TRUE(acceptedFd.ok()); - auto serverTransport = mCtx->newTransport(std::move(acceptedFd), mFdTrigger.get()); + TransportFd transportFd(std::move(acceptedFd)); + auto serverTransport = mCtx->newTransport(std::move(transportFd), mFdTrigger.get()); if (serverTransport == nullptr) return; // handshake failed ASSERT_TRUE(mPostConnect(serverTransport.get(), mFdTrigger.get())); } @@ -1822,7 +1823,7 @@ public: std::unique_ptr<std::thread> mThread; ConnectToServer mConnectToServer; std::unique_ptr<FdTrigger> mFdTrigger = FdTrigger::make(); - base::unique_fd mFd; + TransportFd mFd; std::unique_ptr<RpcTransportCtx> mCtx; std::shared_ptr<RpcCertificateVerifierSimple> mCertVerifier = std::make_shared<RpcCertificateVerifierSimple>(); @@ -1869,7 +1870,7 @@ public: // connect() and do handshake bool setUpTransport() { mFd = mConnectToServer(); - if (!mFd.ok()) return AssertionFailure() << "Cannot connect to server"; + if (!mFd.fd.ok()) return AssertionFailure() << "Cannot connect to server"; mClientTransport = mCtx->newTransport(std::move(mFd), mFdTrigger.get()); return mClientTransport != nullptr; } @@ -1898,9 +1899,11 @@ public: ASSERT_EQ(readOk, readMessage()); } + bool isTransportWaiting() { return mClientTransport->isWaiting(); } + private: ConnectToServer mConnectToServer; - base::unique_fd mFd; + TransportFd mFd; std::unique_ptr<FdTrigger> mFdTrigger = FdTrigger::make(); std::unique_ptr<RpcTransportCtx> mCtx; std::shared_ptr<RpcCertificateVerifierSimple> mCertVerifier = @@ -2147,6 +2150,56 @@ TEST_P(RpcTransportTest, Trigger) { ASSERT_FALSE(client.readMessage(msg2)); } +TEST_P(RpcTransportTest, CheckWaitingForRead) { + std::mutex readMutex; + std::condition_variable readCv; + bool shouldContinueReading = false; + // Server will write data on transport once its started + auto serverPostConnect = [&](RpcTransport* serverTransport, FdTrigger* fdTrigger) { + std::string message(RpcTransportTestUtils::kMessage); + iovec messageIov{message.data(), message.size()}; + auto status = serverTransport->interruptableWriteFully(fdTrigger, &messageIov, 1, + std::nullopt, nullptr); + if (status != OK) return AssertionFailure() << statusToString(status); + + { + std::unique_lock<std::mutex> lock(readMutex); + shouldContinueReading = true; + lock.unlock(); + readCv.notify_all(); + } + return AssertionSuccess(); + }; + + // Setup Server and client + auto server = std::make_unique<Server>(); + ASSERT_TRUE(server->setUp(GetParam())); + + Client client(server->getConnectToServerFn()); + ASSERT_TRUE(client.setUp(GetParam())); + + ASSERT_EQ(OK, trust(&client, server)); + ASSERT_EQ(OK, trust(server, &client)); + server->setPostConnect(serverPostConnect); + + server->start(); + ASSERT_TRUE(client.setUpTransport()); + { + // Wait till server writes data + std::unique_lock<std::mutex> lock(readMutex); + ASSERT_TRUE(readCv.wait_for(lock, 3s, [&] { return shouldContinueReading; })); + } + + // Since there is no read polling here, we will get polling count 0 + ASSERT_FALSE(client.isTransportWaiting()); + ASSERT_TRUE(client.readMessage(RpcTransportTestUtils::kMessage)); + // Thread should increment polling count, read and decrement polling count + // Again, polling count should be zero here + ASSERT_FALSE(client.isTransportWaiting()); + + server->shutdown(); +} + INSTANTIATE_TEST_CASE_P(BinderRpc, RpcTransportTest, ::testing::ValuesIn(RpcTransportTest::getRpcTranportTestParams()), RpcTransportTest::PrintParamInfo); diff --git a/libs/binder/trusty/RpcServerTrusty.cpp b/libs/binder/trusty/RpcServerTrusty.cpp index c789614c0b..9e48f811a9 100644 --- a/libs/binder/trusty/RpcServerTrusty.cpp +++ b/libs/binder/trusty/RpcServerTrusty.cpp @@ -118,10 +118,12 @@ int RpcServerTrusty::handleConnect(const tipc_port* port, handle_t chan, const u }; base::unique_fd clientFd(chan); + android::TransportFd transportFd(std::move(clientFd)); + std::array<uint8_t, RpcServer::kRpcAddressSize> addr; constexpr size_t addrLen = sizeof(*peer); memcpy(addr.data(), peer, addrLen); - RpcServer::establishConnection(sp(server->mRpcServer), std::move(clientFd), addr, addrLen, + RpcServer::establishConnection(sp(server->mRpcServer), std::move(transportFd), addr, addrLen, joinFn); return rc; diff --git a/libs/binder/trusty/RpcTransportTipcTrusty.cpp b/libs/binder/trusty/RpcTransportTipcTrusty.cpp index dc27eb929c..4777a5e0cc 100644 --- a/libs/binder/trusty/RpcTransportTipcTrusty.cpp +++ b/libs/binder/trusty/RpcTransportTipcTrusty.cpp @@ -33,7 +33,7 @@ namespace { // RpcTransport for Trusty. class RpcTransportTipcTrusty : public RpcTransport { public: - explicit RpcTransportTipcTrusty(android::base::unique_fd socket) : mSocket(std::move(socket)) {} + explicit RpcTransportTipcTrusty(android::TransportFd socket) : mSocket(std::move(socket)) {} ~RpcTransportTipcTrusty() { releaseMessage(); } status_t pollRead() override { @@ -64,7 +64,7 @@ public: .num_handles = 0, // TODO: add ancillaryFds .handles = nullptr, }; - ssize_t rc = send_msg(mSocket.get(), &msg); + ssize_t rc = send_msg(mSocket.fd.get(), &msg); if (rc == ERR_NOT_ENOUGH_BUFFER) { // Peer is blocked, wait until it unblocks. // TODO: when tipc supports a send-unblocked handler, @@ -72,7 +72,7 @@ public: // when the handler gets called by the library uevent uevt; do { - rc = ::wait(mSocket.get(), &uevt, INFINITE_TIME); + rc = ::wait(mSocket.fd.get(), &uevt, INFINITE_TIME); if (rc < 0) { return statusFromTrusty(rc); } @@ -83,7 +83,7 @@ public: // Retry the send, it should go through this time because // sending is now unblocked - rc = send_msg(mSocket.get(), &msg); + rc = send_msg(mSocket.fd.get(), &msg); } if (rc < 0) { return statusFromTrusty(rc); @@ -129,7 +129,7 @@ public: .num_handles = 0, // TODO: support ancillaryFds .handles = nullptr, }; - ssize_t rc = read_msg(mSocket.get(), mMessageInfo.id, mMessageOffset, &msg); + ssize_t rc = read_msg(mSocket.fd.get(), mMessageInfo.id, mMessageOffset, &msg); if (rc < 0) { return statusFromTrusty(rc); } @@ -169,6 +169,8 @@ public: } } + bool isWaiting() override { return mSocket.isInPollingState(); } + private: status_t ensureMessage(bool wait) { int rc; @@ -179,7 +181,7 @@ private: /* TODO: interruptible wait, maybe with a timeout??? */ uevent uevt; - rc = ::wait(mSocket.get(), &uevt, wait ? INFINITE_TIME : 0); + rc = ::wait(mSocket.fd.get(), &uevt, wait ? INFINITE_TIME : 0); if (rc < 0) { if (rc == ERR_TIMED_OUT && !wait) { // If we timed out with wait==false, then there's no message @@ -192,7 +194,7 @@ private: return OK; } - rc = get_msg(mSocket.get(), &mMessageInfo); + rc = get_msg(mSocket.fd.get(), &mMessageInfo); if (rc < 0) { return statusFromTrusty(rc); } @@ -204,12 +206,12 @@ private: void releaseMessage() { if (mHaveMessage) { - put_msg(mSocket.get(), mMessageInfo.id); + put_msg(mSocket.fd.get(), mMessageInfo.id); mHaveMessage = false; } } - base::unique_fd mSocket; + android::TransportFd mSocket; bool mHaveMessage = false; ipc_msg_info mMessageInfo; @@ -219,9 +221,9 @@ private: // RpcTransportCtx for Trusty. class RpcTransportCtxTipcTrusty : public RpcTransportCtx { public: - std::unique_ptr<RpcTransport> newTransport(android::base::unique_fd fd, + std::unique_ptr<RpcTransport> newTransport(android::TransportFd socket, FdTrigger*) const override { - return std::make_unique<RpcTransportTipcTrusty>(std::move(fd)); + return std::make_unique<RpcTransportTipcTrusty>(std::move(socket)); } std::vector<uint8_t> getCertificate(RpcCertificateFormat) const override { return {}; } }; |