diff options
Diffstat (limited to 'libs')
| -rw-r--r-- | libs/binder/RpcTransportTls.cpp | 2 | ||||
| -rw-r--r-- | libs/binder/tests/binderRpcTest.cpp | 147 |
2 files changed, 121 insertions, 28 deletions
diff --git a/libs/binder/RpcTransportTls.cpp b/libs/binder/RpcTransportTls.cpp index d40cfc84fc..63f93391e5 100644 --- a/libs/binder/RpcTransportTls.cpp +++ b/libs/binder/RpcTransportTls.cpp @@ -347,7 +347,7 @@ status_t RpcTransportTls::isTriggered(FdTrigger* fdTrigger) { ALOGE("%s: %s", __PRETTY_FUNCTION__, ret.error().message().c_str()); return ret.error().code() == 0 ? UNKNOWN_ERROR : -ret.error().code(); } - return OK; + return *ret ? -ECANCELED : OK; } status_t RpcTransportTls::interruptableWriteFully(FdTrigger* fdTrigger, const void* data, diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp index a4e37adde9..2fd63a36f7 100644 --- a/libs/binder/tests/binderRpcTest.cpp +++ b/libs/binder/tests/binderRpcTest.cpp @@ -53,6 +53,7 @@ #include "RpcCertificateVerifierSimple.h" using namespace std::chrono_literals; +using namespace std::placeholders; using testing::AssertionFailure; using testing::AssertionResult; using testing::AssertionSuccess; @@ -1444,7 +1445,7 @@ public: PrintToString(certificateFormat); } void TearDown() override { - for (auto& server : mServers) server->shutdown(); + for (auto& server : mServers) server->shutdownAndWait(); } // A server that handles client socket connections. @@ -1452,7 +1453,7 @@ public: public: explicit Server() {} Server(Server&&) = default; - ~Server() { shutdown(); } + ~Server() { shutdownAndWait(); } [[nodiscard]] AssertionResult setUp() { auto [socketType, rpcSecurity, certificateFormat] = GetParam(); auto rpcServer = RpcServer::make(newFactory(rpcSecurity)); @@ -1536,17 +1537,17 @@ public: ASSERT_TRUE(acceptedFd.ok()); auto serverTransport = mCtx->newTransport(std::move(acceptedFd), mFdTrigger.get()); if (serverTransport == nullptr) return; // handshake failed - std::string message(kMessage); - ASSERT_EQ(OK, - serverTransport->interruptableWriteFully(mFdTrigger.get(), message.data(), - message.size())); + ASSERT_TRUE(mPostConnect(serverTransport.get(), mFdTrigger.get())); } - void shutdown() { - mFdTrigger->trigger(); - if (mThread != nullptr) { - mThread->join(); - mThread = nullptr; - } + void shutdownAndWait() { + shutdown(); + join(); + } + void shutdown() { mFdTrigger->trigger(); } + + void setPostConnect( + std::function<AssertionResult(RpcTransport*, FdTrigger* fdTrigger)> fn) { + mPostConnect = std::move(fn); } private: @@ -1558,6 +1559,26 @@ public: std::shared_ptr<RpcCertificateVerifierSimple> mCertVerifier = std::make_shared<RpcCertificateVerifierSimple>(); bool mSetup = false; + // The function invoked after connection and handshake. By default, it is + // |defaultPostConnect| that sends |kMessage| to the client. + std::function<AssertionResult(RpcTransport*, FdTrigger* fdTrigger)> mPostConnect = + Server::defaultPostConnect; + + void join() { + if (mThread != nullptr) { + mThread->join(); + mThread = nullptr; + } + } + + static AssertionResult defaultPostConnect(RpcTransport* serverTransport, + FdTrigger* fdTrigger) { + std::string message(kMessage); + auto status = serverTransport->interruptableWriteFully(fdTrigger, message.data(), + message.size()); + if (status != OK) return AssertionFailure() << statusToString(status); + return AssertionSuccess(); + } }; class Client { @@ -1566,8 +1587,6 @@ public: Client(Client&&) = default; [[nodiscard]] AssertionResult setUp() { auto [socketType, rpcSecurity, certificateFormat] = GetParam(); - mFd = mConnectToServer(); - if (!mFd.ok()) return AssertionFailure() << "Cannot connect to server"; mFdTrigger = FdTrigger::make(); mCtx = newFactory(rpcSecurity, mCertVerifier)->newClientCtx(); if (mCtx == nullptr) return AssertionFailure() << "newClientCtx"; @@ -1577,24 +1596,35 @@ public: std::shared_ptr<RpcCertificateVerifierSimple> getCertVerifier() const { return mCertVerifier; } + // connect() and do handshake + bool setUpTransport() { + mFd = mConnectToServer(); + if (!mFd.ok()) return AssertionFailure() << "Cannot connect to server"; + mClientTransport = mCtx->newTransport(std::move(mFd), mFdTrigger.get()); + return mClientTransport != nullptr; + } + AssertionResult readMessage(const std::string& expectedMessage = kMessage) { + LOG_ALWAYS_FATAL_IF(mClientTransport == nullptr, "setUpTransport not called or failed"); + std::string readMessage(expectedMessage.size(), '\0'); + status_t readStatus = + mClientTransport->interruptableReadFully(mFdTrigger.get(), readMessage.data(), + readMessage.size()); + if (readStatus != OK) { + return AssertionFailure() << statusToString(readStatus); + } + if (readMessage != expectedMessage) { + return AssertionFailure() + << "Expected " << expectedMessage << ", actual " << readMessage; + } + return AssertionSuccess(); + } void run(bool handshakeOk = true, bool readOk = true) { - auto clientTransport = mCtx->newTransport(std::move(mFd), mFdTrigger.get()); - if (clientTransport == nullptr) { + if (!setUpTransport()) { ASSERT_FALSE(handshakeOk) << "newTransport returns nullptr, but it shouldn't"; return; } ASSERT_TRUE(handshakeOk) << "newTransport does not return nullptr, but it should"; - std::string expectedMessage(kMessage); - std::string readMessage(expectedMessage.size(), '\0'); - status_t readStatus = - clientTransport->interruptableReadFully(mFdTrigger.get(), readMessage.data(), - readMessage.size()); - if (readOk) { - ASSERT_EQ(OK, readStatus); - ASSERT_EQ(readMessage, expectedMessage); - } else { - ASSERT_NE(OK, readStatus); - } + ASSERT_EQ(readOk, readMessage()); } private: @@ -1604,6 +1634,7 @@ public: std::unique_ptr<RpcTransportCtx> mCtx; std::shared_ptr<RpcCertificateVerifierSimple> mCertVerifier = std::make_shared<RpcCertificateVerifierSimple>(); + std::unique_ptr<RpcTransport> mClientTransport; }; // Make A trust B. @@ -1729,6 +1760,68 @@ TEST_P(RpcTransportTest, MaliciousClient) { maliciousClient.run(true, readOk); } +TEST_P(RpcTransportTest, Trigger) { + std::string msg2 = ", world!"; + std::mutex writeMutex; + std::condition_variable writeCv; + bool shouldContinueWriting = false; + auto serverPostConnect = [&](RpcTransport* serverTransport, FdTrigger* fdTrigger) { + std::string message(kMessage); + auto status = + serverTransport->interruptableWriteFully(fdTrigger, message.data(), message.size()); + if (status != OK) return AssertionFailure() << statusToString(status); + + { + std::unique_lock<std::mutex> lock(writeMutex); + if (!writeCv.wait_for(lock, 3s, [&] { return shouldContinueWriting; })) { + return AssertionFailure() << "write barrier not cleared in time!"; + } + } + + status = serverTransport->interruptableWriteFully(fdTrigger, msg2.data(), msg2.size()); + if (status != -ECANCELED) + return AssertionFailure() << "When FdTrigger is shut down, interruptableWriteFully " + "should return -ECANCELLED, but it is " + << statusToString(status); + return AssertionSuccess(); + }; + + auto server = mServers.emplace_back(std::make_unique<Server>()).get(); + ASSERT_TRUE(server->setUp()); + + // Set up client + Client client(server->getConnectToServerFn()); + ASSERT_TRUE(client.setUp()); + + // Exchange keys + ASSERT_EQ(OK, trust(&client, server)); + ASSERT_EQ(OK, trust(server, &client)); + + server->setPostConnect(serverPostConnect); + + // Start server + server->start(); + // connect() to server and do handshake + ASSERT_TRUE(client.setUpTransport()); + // read the first message. This confirms that server has finished handshake and start handling + // client fd. Server thread should pause at waitForWriteBarrier. + ASSERT_TRUE(client.readMessage(kMessage)); + // Trigger server shutdown after server starts handling client FD. This ensures that the second + // write is on an FdTrigger that has been shut down. + server->shutdown(); + // Continues server thread to write the second message. + { + std::unique_lock<std::mutex> lock(writeMutex); + shouldContinueWriting = true; + lock.unlock(); + writeCv.notify_all(); + } + // After this line, server thread unblocks and attempts to write the second message, but + // shutdown is triggered, so write should failed with -ECANCELLED. See |serverPostConnect|. + // On the client side, second read fails with DEAD_OBJECT + ASSERT_FALSE(client.readMessage(msg2)); +} + std::vector<RpcCertificateFormat> testRpcCertificateFormats() { return { RpcCertificateFormat::PEM, |