diff options
-rw-r--r-- | libs/binder/IPCThreadState.cpp | 47 | ||||
-rw-r--r-- | libs/binder/RpcServer.cpp | 67 | ||||
-rw-r--r-- | libs/binder/include/binder/IPCThreadState.h | 31 | ||||
-rw-r--r-- | libs/binder/include/binder/RpcServer.h | 35 | ||||
-rw-r--r-- | libs/binder/tests/binderLibTest.cpp | 31 | ||||
-rw-r--r-- | libs/binder/tests/binderRpcTest.cpp | 50 |
6 files changed, 245 insertions, 16 deletions
diff --git a/libs/binder/IPCThreadState.cpp b/libs/binder/IPCThreadState.cpp index ef7fd44419..d2919e7f4d 100644 --- a/libs/binder/IPCThreadState.cpp +++ b/libs/binder/IPCThreadState.cpp @@ -366,19 +366,46 @@ status_t IPCThreadState::clearLastError() pid_t IPCThreadState::getCallingPid() const { + checkContextIsBinderForUse(__func__); return mCallingPid; } const char* IPCThreadState::getCallingSid() const { + checkContextIsBinderForUse(__func__); return mCallingSid; } uid_t IPCThreadState::getCallingUid() const { + checkContextIsBinderForUse(__func__); return mCallingUid; } +const IPCThreadState::SpGuard* IPCThreadState::pushGetCallingSpGuard(const SpGuard* guard) { + const SpGuard* orig = mServingStackPointerGuard; + mServingStackPointerGuard = guard; + return orig; +} + +void IPCThreadState::restoreGetCallingSpGuard(const SpGuard* guard) { + mServingStackPointerGuard = guard; +} + +void IPCThreadState::checkContextIsBinderForUse(const char* use) const { + if (LIKELY(mServingStackPointerGuard == nullptr)) return; + + if (!mServingStackPointer || mServingStackPointerGuard->address < mServingStackPointer) { + LOG_ALWAYS_FATAL("In context %s, %s does not make sense (binder sp: %p, guard: %p).", + mServingStackPointerGuard->context, use, mServingStackPointer, + mServingStackPointerGuard->address); + } + + // in the case mServingStackPointer is deeper in the stack than the guard, + // we must be serving a binder transaction (maybe nested). This is a binder + // context, so we don't abort +} + int64_t IPCThreadState::clearCallingIdentity() { // ignore mCallingSid for legacy reasons @@ -847,15 +874,15 @@ status_t IPCThreadState::clearDeathNotification(int32_t handle, BpBinder* proxy) } IPCThreadState::IPCThreadState() - : mProcess(ProcessState::self()), - mServingStackPointer(nullptr), - mWorkSource(kUnsetWorkSource), - mPropagateWorkSource(false), - mIsLooper(false), - mStrictModePolicy(0), - mLastTransactionBinderFlags(0), - mCallRestriction(mProcess->mCallRestriction) -{ + : mProcess(ProcessState::self()), + mServingStackPointer(nullptr), + mServingStackPointerGuard(nullptr), + mWorkSource(kUnsetWorkSource), + mPropagateWorkSource(false), + mIsLooper(false), + mStrictModePolicy(0), + mLastTransactionBinderFlags(0), + mCallRestriction(mProcess->mCallRestriction) { pthread_setspecific(gTLS, this); clearCaller(); mIn.setDataCapacity(256); @@ -1230,7 +1257,7 @@ status_t IPCThreadState::executeCommand(int32_t cmd) tr.offsets_size/sizeof(binder_size_t), freeBuffer); const void* origServingStackPointer = mServingStackPointer; - mServingStackPointer = &origServingStackPointer; // anything on the stack + mServingStackPointer = __builtin_frame_address(0); const pid_t origPid = mCallingPid; const char* origSid = mCallingSid; diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp index 59659bd0a6..e31aea021f 100644 --- a/libs/binder/RpcServer.cpp +++ b/libs/binder/RpcServer.cpp @@ -16,19 +16,21 @@ #define LOG_TAG "RpcServer" +#include <poll.h> #include <sys/socket.h> #include <sys/un.h> #include <thread> #include <vector> +#include <android-base/macros.h> #include <android-base/scopeguard.h> #include <binder/Parcel.h> #include <binder/RpcServer.h> #include <log/log.h> -#include "RpcState.h" #include "RpcSocketAddress.h" +#include "RpcState.h" #include "RpcWireFormat.h" namespace android { @@ -99,7 +101,7 @@ bool RpcServer::setupInetServer(unsigned int port, unsigned int* assignedPort) { void RpcServer::setMaxThreads(size_t threads) { LOG_ALWAYS_FATAL_IF(threads <= 0, "RpcServer is useless without threads"); - LOG_ALWAYS_FATAL_IF(mStarted, "must be called before started"); + LOG_ALWAYS_FATAL_IF(mJoinThreadRunning, "Cannot set max threads while running"); mMaxThreads = threads; } @@ -126,16 +128,61 @@ sp<IBinder> RpcServer::getRootObject() { return ret; } +std::unique_ptr<RpcServer::FdTrigger> RpcServer::FdTrigger::make() { + auto ret = std::make_unique<RpcServer::FdTrigger>(); + if (!android::base::Pipe(&ret->mRead, &ret->mWrite)) return nullptr; + return ret; +} + +void RpcServer::FdTrigger::trigger() { + mWrite.reset(); +} + void RpcServer::join() { + LOG_ALWAYS_FATAL_IF(!mAgreedExperimental, "no!"); + + { + std::lock_guard<std::mutex> _l(mLock); + LOG_ALWAYS_FATAL_IF(!mServer.ok(), "RpcServer must be setup to join."); + LOG_ALWAYS_FATAL_IF(mShutdownTrigger != nullptr, "Already joined"); + mJoinThreadRunning = true; + mShutdownTrigger = FdTrigger::make(); + LOG_ALWAYS_FATAL_IF(mShutdownTrigger == nullptr, "Cannot create join signaler"); + } + while (true) { - (void)acceptOne(); + pollfd pfd[]{{.fd = mServer.get(), .events = POLLIN, .revents = 0}, + {.fd = mShutdownTrigger->readFd().get(), .events = POLLHUP, .revents = 0}}; + int ret = TEMP_FAILURE_RETRY(poll(pfd, arraysize(pfd), -1)); + if (ret < 0) { + ALOGE("Could not poll socket: %s", strerror(errno)); + continue; + } + if (ret == 0) { + continue; + } + if (pfd[1].revents & POLLHUP) { + LOG_RPC_DETAIL("join() exiting because shutdown requested."); + break; + } + + (void)acceptOneNoCheck(); + } + + { + std::lock_guard<std::mutex> _l(mLock); + mJoinThreadRunning = false; } + mShutdownCv.notify_all(); } bool RpcServer::acceptOne() { LOG_ALWAYS_FATAL_IF(!mAgreedExperimental, "no!"); - LOG_ALWAYS_FATAL_IF(!hasServer(), "RpcServer must be setup to join."); + LOG_ALWAYS_FATAL_IF(!hasServer(), "RpcServer must be setup to acceptOne."); + return acceptOneNoCheck(); +} +bool RpcServer::acceptOneNoCheck() { unique_fd clientFd( TEMP_FAILURE_RETRY(accept4(mServer.get(), nullptr, nullptr /*length*/, SOCK_CLOEXEC))); @@ -156,6 +203,18 @@ bool RpcServer::acceptOne() { return true; } +bool RpcServer::shutdown() { + LOG_ALWAYS_FATAL_IF(!mAgreedExperimental, "no!"); + std::unique_lock<std::mutex> _l(mLock); + if (mShutdownTrigger == nullptr) return false; + + mShutdownTrigger->trigger(); + while (mJoinThreadRunning) mShutdownCv.wait(_l); + + mShutdownTrigger = nullptr; + return true; +} + std::vector<sp<RpcSession>> RpcServer::listSessions() { std::lock_guard<std::mutex> _l(mLock); std::vector<sp<RpcSession>> sessions; diff --git a/libs/binder/include/binder/IPCThreadState.h b/libs/binder/include/binder/IPCThreadState.h index 23a0cb0148..ee661a5948 100644 --- a/libs/binder/include/binder/IPCThreadState.h +++ b/libs/binder/include/binder/IPCThreadState.h @@ -81,6 +81,36 @@ public: */ uid_t getCallingUid() const; + /** + * Make it an abort to rely on getCalling* for a section of + * execution. + * + * Usage: + * IPCThreadState::SpGuard guard { + * .address = __builtin_frame_address(0), + * .context = "...", + * }; + * const auto* orig = pushGetCallingSpGuard(&guard); + * { + * // will abort if you call getCalling*, unless you are + * // serving a nested binder transaction + * } + * restoreCallingSpGuard(orig); + */ + struct SpGuard { + const void* address; + const char* context; + }; + const SpGuard* pushGetCallingSpGuard(const SpGuard* guard); + void restoreGetCallingSpGuard(const SpGuard* guard); + /** + * Used internally by getCalling*. Can also be used to assert that + * you are in a binder context (getCalling* is valid). This is + * intentionally not exposed as a boolean API since code should be + * written to know its environment. + */ + void checkContextIsBinderForUse(const char* use) const; + void setStrictModePolicy(int32_t policy); int32_t getStrictModePolicy() const; @@ -203,6 +233,7 @@ private: Parcel mOut; status_t mLastError; const void* mServingStackPointer; + const SpGuard* mServingStackPointerGuard; pid_t mCallingPid; const char* mCallingSid; uid_t mCallingUid; diff --git a/libs/binder/include/binder/RpcServer.h b/libs/binder/include/binder/RpcServer.h index 8f0c6fd5e1..49734007ac 100644 --- a/libs/binder/include/binder/RpcServer.h +++ b/libs/binder/include/binder/RpcServer.h @@ -119,11 +119,22 @@ public: /** * You must have at least one client session before calling this. * - * TODO(b/185167543): way to shut down? + * If a client needs to actively terminate join, call shutdown() in a separate thread. + * + * At any given point, there can only be one thread calling join(). */ void join(); /** + * Shut down any existing join(). Return true if successfully shut down, false otherwise + * (e.g. no join() is running). Will wait for the server to be fully + * shutdown. + * + * TODO(b/185167543): wait for sessions to shutdown as well + */ + [[nodiscard]] bool shutdown(); + + /** * Accept one connection on this server. You must have at least one client * session before calling this. */ @@ -142,14 +153,31 @@ public: void onSessionTerminating(const sp<RpcSession>& session); private: + /** This is not a pipe. */ + struct FdTrigger { + static std::unique_ptr<FdTrigger> make(); + /** + * poll() on this fd for POLLHUP to get notification when trigger is called + */ + base::borrowed_fd readFd() const { return mRead; } + /** + * Close the write end of the pipe so that the read end receives POLLHUP. + */ + void trigger(); + + private: + base::unique_fd mWrite; + base::unique_fd mRead; + }; + friend sp<RpcServer>; RpcServer(); void establishConnection(sp<RpcServer>&& session, base::unique_fd clientFd); bool setupSocketServer(const RpcSocketAddress& address); + [[nodiscard]] bool acceptOneNoCheck(); bool mAgreedExperimental = false; - bool mStarted = false; // TODO(b/185167543): support dynamically added clients size_t mMaxThreads = 1; base::unique_fd mServer; // socket we are accepting sessions on @@ -159,6 +187,9 @@ private: wp<IBinder> mRootObjectWeak; std::map<int32_t, sp<RpcSession>> mSessions; int32_t mSessionIdCounter = 0; + bool mJoinThreadRunning = false; + std::unique_ptr<FdTrigger> mShutdownTrigger; + std::condition_variable mShutdownCv; }; } // namespace android diff --git a/libs/binder/tests/binderLibTest.cpp b/libs/binder/tests/binderLibTest.cpp index 0c3fbcd2da..7679b46c4e 100644 --- a/libs/binder/tests/binderLibTest.cpp +++ b/libs/binder/tests/binderLibTest.cpp @@ -73,6 +73,7 @@ enum BinderLibTestTranscationCode { BINDER_LIB_TEST_REGISTER_SERVER, BINDER_LIB_TEST_ADD_SERVER, BINDER_LIB_TEST_ADD_POLL_SERVER, + BINDER_LIB_TEST_USE_CALLING_GUARD_TRANSACTION, BINDER_LIB_TEST_CALL_BACK, BINDER_LIB_TEST_CALL_BACK_VERIFY_BUF, BINDER_LIB_TEST_DELAYED_CALL_BACK, @@ -604,6 +605,13 @@ TEST_F(BinderLibTest, CallBack) EXPECT_THAT(callBack->getResult(), StatusEq(NO_ERROR)); } +TEST_F(BinderLibTest, BinderCallContextGuard) { + sp<IBinder> binder = addServer(); + Parcel data, reply; + EXPECT_THAT(binder->transact(BINDER_LIB_TEST_USE_CALLING_GUARD_TRANSACTION, data, &reply), + StatusEq(DEAD_OBJECT)); +} + TEST_F(BinderLibTest, AddServer) { sp<IBinder> server = addServer(); @@ -1262,6 +1270,21 @@ class BinderLibTestService : public BBinder pthread_mutex_unlock(&m_serverWaitMutex); return ret; } + case BINDER_LIB_TEST_USE_CALLING_GUARD_TRANSACTION: { + IPCThreadState::SpGuard spGuard{ + .address = __builtin_frame_address(0), + .context = "GuardInBinderTransaction", + }; + const IPCThreadState::SpGuard *origGuard = + IPCThreadState::self()->pushGetCallingSpGuard(&spGuard); + + // if the guard works, this should abort + (void)IPCThreadState::self()->getCallingPid(); + + IPCThreadState::self()->restoreGetCallingSpGuard(origGuard); + return NO_ERROR; + } + case BINDER_LIB_TEST_GETPID: reply->writeInt32(getpid()); return NO_ERROR; @@ -1489,6 +1512,14 @@ int run_server(int index, int readypipefd, bool usePoll) { binderLibTestServiceName += String16(binderserversuffix); + // Testing to make sure that calls that we are serving can use getCallin* + // even though we don't here. + IPCThreadState::SpGuard spGuard{ + .address = __builtin_frame_address(0), + .context = "main server thread", + }; + (void)IPCThreadState::self()->pushGetCallingSpGuard(&spGuard); + status_t ret; sp<IServiceManager> sm = defaultServiceManager(); BinderLibTestService* testServicePtr; diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp index a96deb5d27..fb0ffdb19c 100644 --- a/libs/binder/tests/binderRpcTest.cpp +++ b/libs/binder/tests/binderRpcTest.cpp @@ -40,6 +40,8 @@ #include "../RpcState.h" // for debugging #include "../vm_sockets.h" // for VMADDR_* +using namespace std::chrono_literals; + namespace android { TEST(BinderRpcParcel, EntireParcelFormatted) { @@ -970,6 +972,54 @@ TEST_P(BinderRpcServerRootObject, WeakRootObject) { INSTANTIATE_TEST_CASE_P(BinderRpc, BinderRpcServerRootObject, ::testing::Combine(::testing::Bool(), ::testing::Bool())); +class OneOffSignal { +public: + // If notify() was previously called, or is called within |duration|, return true; else false. + template <typename R, typename P> + bool wait(std::chrono::duration<R, P> duration) { + std::unique_lock<std::mutex> lock(mMutex); + return mCv.wait_for(lock, duration, [this] { return mValue; }); + } + void notify() { + std::unique_lock<std::mutex> lock(mMutex); + mValue = true; + lock.unlock(); + mCv.notify_all(); + } + +private: + std::mutex mMutex; + std::condition_variable mCv; + bool mValue = false; +}; + +TEST(BinderRpc, Shutdown) { + auto addr = allocateSocketAddress(); + unlink(addr.c_str()); + auto server = RpcServer::make(); + server->iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction(); + ASSERT_TRUE(server->setupUnixDomainServer(addr.c_str())); + auto joinEnds = std::make_shared<OneOffSignal>(); + + // If things are broken and the thread never stops, don't block other tests. Because the thread + // may run after the test finishes, it must not access the stack memory of the test. Hence, + // shared pointers are passed. + std::thread([server, joinEnds] { + server->join(); + joinEnds->notify(); + }).detach(); + + bool shutdown = false; + for (int i = 0; i < 10 && !shutdown; i++) { + usleep(300 * 1000); // 300ms; total 3s + if (server->shutdown()) shutdown = true; + } + ASSERT_TRUE(shutdown) << "server->shutdown() never returns true"; + + ASSERT_TRUE(joinEnds->wait(2s)) + << "After server->shutdown() returns true, join() did not stop after 2s"; +} + } // namespace android int main(int argc, char** argv) { |