diff options
-rw-r--r-- | libs/binder/RpcServer.cpp | 42 | ||||
-rw-r--r-- | libs/binder/RpcSession.cpp | 36 | ||||
-rw-r--r-- | libs/binder/RpcState.cpp | 12 | ||||
-rw-r--r-- | libs/binder/RpcState.h | 2 | ||||
-rw-r--r-- | libs/binder/RpcWireFormat.h | 13 | ||||
-rw-r--r-- | libs/binder/include/binder/RpcServer.h | 8 | ||||
-rw-r--r-- | libs/binder/include/binder/RpcSession.h | 12 | ||||
-rw-r--r-- | libs/binder/tests/binderRpcTest.cpp | 16 |
8 files changed, 131 insertions, 10 deletions
diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp index 200d923b6d..62ea187719 100644 --- a/libs/binder/RpcServer.cpp +++ b/libs/binder/RpcServer.cpp @@ -110,6 +110,10 @@ size_t RpcServer::getMaxThreads() { return mMaxThreads; } +void RpcServer::setProtocolVersion(uint32_t version) { + mProtocolVersion = version; +} + void RpcServer::setRootObject(const sp<IBinder>& binder) { std::lock_guard<std::mutex> _l(mLock); mRootObjectWeak = mRootObject = binder; @@ -245,13 +249,37 @@ void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clie RpcConnectionHeader header; status_t status = server->mShutdownTrigger->interruptableReadFully(clientFd.get(), &header, sizeof(header)); - bool idValid = status == OK; - if (!idValid) { + if (status != OK) { ALOGE("Failed to read ID for client connecting to RPC server: %s", statusToString(status).c_str()); // still need to cleanup before we can return } - bool incoming = header.options & RPC_CONNECTION_OPTION_INCOMING; + + bool incoming = false; + uint32_t protocolVersion = 0; + RpcAddress sessionId = RpcAddress::zero(); + bool requestingNewSession = false; + + if (status == OK) { + incoming = header.options & RPC_CONNECTION_OPTION_INCOMING; + protocolVersion = std::min(header.version, + server->mProtocolVersion.value_or(RPC_WIRE_PROTOCOL_VERSION)); + sessionId = RpcAddress::fromRawEmbedded(&header.sessionId); + requestingNewSession = sessionId.isZero(); + + if (requestingNewSession) { + RpcNewSessionResponse response{ + .version = protocolVersion, + }; + + status = server->mShutdownTrigger->interruptableWriteFully(clientFd.get(), &response, + sizeof(response)); + if (status != OK) { + ALOGE("Failed to send new session response: %s", statusToString(status).c_str()); + // still need to cleanup before we can return + } + } + } std::thread thisThread; sp<RpcSession> session; @@ -269,19 +297,16 @@ void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clie }; server->mConnectingThreads.erase(threadId); - if (!idValid || server->mShutdownTrigger->isTriggered()) { + if (status != OK || server->mShutdownTrigger->isTriggered()) { return; } - RpcAddress sessionId = RpcAddress::fromRawEmbedded(&header.sessionId); - - if (sessionId.isZero()) { + if (requestingNewSession) { if (incoming) { ALOGE("Cannot create a new session with an incoming connection, would leak"); return; } - sessionId = RpcAddress::zero(); size_t tries = 0; do { // don't block if there is some entropy issue @@ -295,6 +320,7 @@ void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clie session = RpcSession::make(); session->setMaxThreads(server->mMaxThreads); + if (!session->setProtocolVersion(protocolVersion)) return; if (!session->setForServer(server, sp<RpcServer::EventListener>::fromExisting( static_cast<RpcServer::EventListener*>( diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp index 1c376518f0..90ce4d6d3f 100644 --- a/libs/binder/RpcSession.cpp +++ b/libs/binder/RpcSession.cpp @@ -77,6 +77,25 @@ size_t RpcSession::getMaxThreads() { return mMaxThreads; } +bool RpcSession::setProtocolVersion(uint32_t version) { + if (version >= RPC_WIRE_PROTOCOL_VERSION_NEXT && + version != RPC_WIRE_PROTOCOL_VERSION_EXPERIMENTAL) { + ALOGE("Cannot start RPC session with version %u which is unknown (current protocol version " + "is %u).", + version, RPC_WIRE_PROTOCOL_VERSION); + return false; + } + + std::lock_guard<std::mutex> _l(mMutex); + mProtocolVersion = version; + return true; +} + +std::optional<uint32_t> RpcSession::getProtocolVersion() { + std::lock_guard<std::mutex> _l(mMutex); + return mProtocolVersion; +} + bool RpcSession::setupUnixDomainClient(const char* path) { return setupSocketClient(UnixSocketAddress(path)); } @@ -424,6 +443,18 @@ bool RpcSession::setupSocketClient(const RpcSocketAddress& addr) { if (!setupOneSocketConnection(addr, RpcAddress::zero(), false /*incoming*/)) return false; + { + ExclusiveConnection connection; + status_t status = ExclusiveConnection::find(sp<RpcSession>::fromExisting(this), + ConnectionUse::CLIENT, &connection); + if (status != OK) return false; + + uint32_t version; + status = state()->readNewSessionResponse(connection.get(), + sp<RpcSession>::fromExisting(this), &version); + if (!setProtocolVersion(version)) return false; + } + // TODO(b/189955605): we should add additional sessions dynamically // instead of all at once. // TODO(b/186470974): first risk of blocking @@ -484,7 +515,10 @@ bool RpcSession::setupOneSocketConnection(const RpcSocketAddress& addr, const Rp return false; } - RpcConnectionHeader header{.options = 0}; + RpcConnectionHeader header{ + .version = mProtocolVersion.value_or(RPC_WIRE_PROTOCOL_VERSION), + .options = 0, + }; memcpy(&header.sessionId, &id.viewRawEmbedded(), sizeof(RpcWireAddress)); if (incoming) header.options |= RPC_CONNECTION_OPTION_INCOMING; diff --git a/libs/binder/RpcState.cpp b/libs/binder/RpcState.cpp index 332c75f9e7..f3406bb10b 100644 --- a/libs/binder/RpcState.cpp +++ b/libs/binder/RpcState.cpp @@ -315,6 +315,18 @@ status_t RpcState::rpcRec(const sp<RpcSession::RpcConnection>& connection, return OK; } +status_t RpcState::readNewSessionResponse(const sp<RpcSession::RpcConnection>& connection, + const sp<RpcSession>& session, uint32_t* version) { + RpcNewSessionResponse response; + if (status_t status = + rpcRec(connection, session, "new session response", &response, sizeof(response)); + status != OK) { + return status; + } + *version = response.version; + return OK; +} + status_t RpcState::sendConnectionInit(const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session) { RpcOutgoingConnectionInit init{ diff --git a/libs/binder/RpcState.h b/libs/binder/RpcState.h index 5ac0b973f1..1446eecc64 100644 --- a/libs/binder/RpcState.h +++ b/libs/binder/RpcState.h @@ -60,6 +60,8 @@ public: RpcState(); ~RpcState(); + status_t readNewSessionResponse(const sp<RpcSession::RpcConnection>& connection, + const sp<RpcSession>& session, uint32_t* version); status_t sendConnectionInit(const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session); status_t readConnectionInit(const sp<RpcSession::RpcConnection>& connection, diff --git a/libs/binder/RpcWireFormat.h b/libs/binder/RpcWireFormat.h index 2a44c7af04..0f8efd2391 100644 --- a/libs/binder/RpcWireFormat.h +++ b/libs/binder/RpcWireFormat.h @@ -37,9 +37,20 @@ struct RpcWireAddress { * either as part of a new session or an existing session */ struct RpcConnectionHeader { + uint32_t version; // maximum supported by caller + uint8_t reserver0[4]; RpcWireAddress sessionId; uint8_t options; - uint8_t reserved[7]; + uint8_t reserved1[7]; +}; + +/** + * In response to an RpcConnectionHeader which corresponds to a new session, + * this returns information to the server. + */ +struct RpcNewSessionResponse { + uint32_t version; // maximum supported by callee <= maximum supported by caller + uint8_t reserved[4]; }; #define RPC_CONNECTION_INIT_OKAY "cci" diff --git a/libs/binder/include/binder/RpcServer.h b/libs/binder/include/binder/RpcServer.h index a8094dd081..40ff78cd37 100644 --- a/libs/binder/include/binder/RpcServer.h +++ b/libs/binder/include/binder/RpcServer.h @@ -105,6 +105,13 @@ public: size_t getMaxThreads(); /** + * By default, the latest protocol version which is supported by a client is + * used. However, this can be used in order to prevent newer protocol + * versions from ever being used. This is expected to be useful for testing. + */ + void setProtocolVersion(uint32_t version); + + /** * The root object can be retrieved by any client, without any * authentication. TODO(b/183988761) * @@ -164,6 +171,7 @@ private: bool mAgreedExperimental = false; size_t mMaxThreads = 1; + std::optional<uint32_t> mProtocolVersion; base::unique_fd mServer; // socket we are accepting sessions on std::mutex mLock; // for below diff --git a/libs/binder/include/binder/RpcSession.h b/libs/binder/include/binder/RpcSession.h index 2101df85c4..1f7c0291a9 100644 --- a/libs/binder/include/binder/RpcSession.h +++ b/libs/binder/include/binder/RpcSession.h @@ -37,6 +37,10 @@ class RpcServer; class RpcSocketAddress; class RpcState; +constexpr uint32_t RPC_WIRE_PROTOCOL_VERSION_NEXT = 0; +constexpr uint32_t RPC_WIRE_PROTOCOL_VERSION_EXPERIMENTAL = 0xF0000000; +constexpr uint32_t RPC_WIRE_PROTOCOL_VERSION = RPC_WIRE_PROTOCOL_VERSION_EXPERIMENTAL; + /** * This represents a session (group of connections) between a client * and a server. Multiple connections are needed for multiple parallel "binder" @@ -60,6 +64,13 @@ public: size_t getMaxThreads(); /** + * By default, the minimum of the supported versions of the client and the + * server will be used. Usually, this API should only be used for debugging. + */ + [[nodiscard]] bool setProtocolVersion(uint32_t version); + std::optional<uint32_t> getProtocolVersion(); + + /** * This should be called once per thread, matching 'join' in the remote * process. */ @@ -291,6 +302,7 @@ private: std::mutex mMutex; // for all below size_t mMaxThreads = 0; + std::optional<uint32_t> mProtocolVersion; std::condition_variable mAvailableConnectionCv; // for mWaitingThreads size_t mWaitingThreads = 0; diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp index 40ebd9c8bc..d5786bcbe1 100644 --- a/libs/binder/tests/binderRpcTest.cpp +++ b/libs/binder/tests/binderRpcTest.cpp @@ -47,6 +47,9 @@ using namespace std::chrono_literals; namespace android { +static_assert(RPC_WIRE_PROTOCOL_VERSION + 1 == RPC_WIRE_PROTOCOL_VERSION_NEXT || + RPC_WIRE_PROTOCOL_VERSION == RPC_WIRE_PROTOCOL_VERSION_EXPERIMENTAL); + TEST(BinderRpcParcel, EntireParcelFormatted) { Parcel p; p.writeInt32(3); @@ -67,6 +70,19 @@ TEST(BinderRpc, SetExternalServer) { ASSERT_EQ(sinkFd, retrieved.get()); } +TEST(BinderRpc, CannotUseNextWireVersion) { + auto session = RpcSession::make(); + EXPECT_FALSE(session->setProtocolVersion(RPC_WIRE_PROTOCOL_VERSION_NEXT)); + EXPECT_FALSE(session->setProtocolVersion(RPC_WIRE_PROTOCOL_VERSION_NEXT + 1)); + EXPECT_FALSE(session->setProtocolVersion(RPC_WIRE_PROTOCOL_VERSION_NEXT + 2)); + EXPECT_FALSE(session->setProtocolVersion(RPC_WIRE_PROTOCOL_VERSION_NEXT + 15)); +} + +TEST(BinderRpc, CanUseExperimentalWireVersion) { + auto session = RpcSession::make(); + EXPECT_TRUE(session->setProtocolVersion(RPC_WIRE_PROTOCOL_VERSION_EXPERIMENTAL)); +} + using android::binder::Status; #define EXPECT_OK(status) \ |