From c59db84c455570af8a314f426357355499c804d5 Mon Sep 17 00:00:00 2001 From: Shai Barack Date: Wed, 5 Feb 2025 07:35:05 -0800 Subject: Remove synchronization (and potential priority inversions) from RateLimitingCache Noticed this issue while looking at a trace of binder spam. Fetching the rate-limited value inside a synchronized block exposes all call sites to priority inversions. A priority inversion would happen for instance if an unimportant caller manages to fetch a value, fetching the value requires for instance a binder call, and when the calling unimportant thread becomes runnable again because the result is available that thread isn't scheduled for a very long time. In the meantime, another calling thread that is more important is blocked on entering the synchronized lock. This change removes the lock, and the opportunity for contention and inversion. This slightly changes the behavior of RateLimitingCache. Before this change, the rate limiter would strictly not permit more instances of getting the underlying value than is set. After this change, if there are at most N threads potentially calling into this cache then in the extreme we would admit at most limit + N - 1 calls to get the underlying value in a given window. Flag: EXEMPT bugfix Bug: 393503787 Change-Id: I5f392d54c10348fa1dac1c82f15900118325a679 --- .../android/internal/util/RateLimitingCache.java | 76 +++++++---- .../internal/util/RateLimitingCacheTest.java | 145 ++++++++++++++++++++- 2 files changed, 192 insertions(+), 29 deletions(-) diff --git a/core/java/com/android/internal/util/RateLimitingCache.java b/core/java/com/android/internal/util/RateLimitingCache.java index 9916076fd0ef..956d5d680fe7 100644 --- a/core/java/com/android/internal/util/RateLimitingCache.java +++ b/core/java/com/android/internal/util/RateLimitingCache.java @@ -17,6 +17,8 @@ package com.android.internal.util; import android.os.SystemClock; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; /** * A speed/rate limiting cache that's used to cache a value to be returned as long as period hasn't @@ -30,6 +32,12 @@ import android.os.SystemClock; * and then the cached value is returned for the remainder of the period. It uses a simple fixed * window method to track rate. Use a window and count appropriate for bursts of calls and for * high latency/cost of the AIDL call. + *

+ * This class is thread-safe. When multiple threads call get(), they will all fetch a new value + * if the cached value is stale. This is to prevent a slow getting thread from blocking other + * threads from getting a fresh value. In such circumsntaces it's possible to exceed + * count calls in a given period by up to the number of threads that are concurrently + * attempting to get a fresh value minus one. * * @param The type of the return value * @hide @@ -37,12 +45,11 @@ import android.os.SystemClock; @android.ravenwood.annotation.RavenwoodKeepWholeClass public class RateLimitingCache { - private volatile Value mCurrentValue; - private volatile long mLastTimestamp; // Can be last fetch time or window start of fetch time private final long mPeriodMillis; // window size private final int mLimit; // max per window - private int mCount = 0; // current count within window - private long mRandomOffset; // random offset to avoid batching of AIDL calls at window boundary + // random offset to avoid batching of AIDL calls at window boundary + private final long mRandomOffset; + private final AtomicReference mCachedValue = new AtomicReference(); /** * The interface to fetch the actual value, if the cache is null or expired. @@ -56,6 +63,12 @@ public class RateLimitingCache { V fetchValue(); } + class CachedValue { + Value value; + long timestamp; + AtomicInteger count; // current count within window + } + /** * Create a speed limiting cache that returns the same value until periodMillis has passed * and then fetches a new value via the {@link ValueFetcher}. @@ -83,6 +96,8 @@ public class RateLimitingCache { mLimit = count; if (mLimit > 1 && periodMillis > 1) { mRandomOffset = (long) (Math.random() * (periodMillis / 2)); + } else { + mRandomOffset = 0; } } @@ -102,34 +117,39 @@ public class RateLimitingCache { * @return the cached or latest value */ public Value get(ValueFetcher query) { - // If the value never changes - if (mPeriodMillis < 0 && mLastTimestamp != 0) { - return mCurrentValue; - } + CachedValue cached = mCachedValue.get(); - synchronized (this) { - // Get the current time and add a random offset to avoid colliding with other - // caches with similar harmonic window boundaries - final long now = getTime() + mRandomOffset; - final boolean newWindow = now - mLastTimestamp >= mPeriodMillis; - if (newWindow || mCount < mLimit) { - // Fetch a new value - mCurrentValue = query.fetchValue(); + // If the value never changes and there is a previous cached value, return it + if (mPeriodMillis < 0 && cached != null && cached.timestamp != 0) { + return cached.value; + } - // If rate limiting, set timestamp to start of this window - if (mLimit > 1) { - mLastTimestamp = now - (now % mPeriodMillis); - } else { - mLastTimestamp = now; - } + // Get the current time and add a random offset to avoid colliding with other + // caches with similar harmonic window boundaries + final long now = getTime() + mRandomOffset; + final boolean newWindow = cached == null || now - cached.timestamp >= mPeriodMillis; + if (newWindow || cached.count.getAndIncrement() < mLimit) { + // Fetch a new value + Value freshValue = query.fetchValue(); + long freshTimestamp = now; + // If rate limiting, set timestamp to start of this window + if (mLimit > 1) { + freshTimestamp = now - (now % mPeriodMillis); + } - if (newWindow) { - mCount = 1; - } else { - mCount++; - } + CachedValue freshCached = new CachedValue(); + freshCached.value = freshValue; + freshCached.timestamp = freshTimestamp; + if (newWindow) { + freshCached.count = new AtomicInteger(1); + } else { + freshCached.count = cached.count; } - return mCurrentValue; + + // If we fail to CAS then it means that another thread beat us to it. + // In this case we don't override their work. + mCachedValue.compareAndSet(cached, freshCached); } + return mCachedValue.get().value; } } diff --git a/core/tests/coretests/src/com/android/internal/util/RateLimitingCacheTest.java b/core/tests/coretests/src/com/android/internal/util/RateLimitingCacheTest.java index 7734148bac5e..52ff79da26ea 100644 --- a/core/tests/coretests/src/com/android/internal/util/RateLimitingCacheTest.java +++ b/core/tests/coretests/src/com/android/internal/util/RateLimitingCacheTest.java @@ -18,9 +18,15 @@ package com.android.internal.util; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import android.os.SystemClock; import androidx.test.runner.AndroidJUnit4; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.CountDownLatch; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -38,7 +44,7 @@ public class RateLimitingCacheTest { mCounter = -1; } - RateLimitingCache.ValueFetcher mFetcher = () -> { + private final RateLimitingCache.ValueFetcher mFetcher = () -> { return ++mCounter; }; @@ -119,6 +125,143 @@ public class RateLimitingCacheTest { assertCount(s, 2000, 20, 33); } + /** + * Exercises concurrent access to the cache. + */ + @Test + public void testMultipleThreads() throws InterruptedException { + final long periodMillis = 1000; + final int maxCountPerPeriod = 10; + final RateLimitingCache s = + new RateLimitingCache<>(periodMillis, maxCountPerPeriod); + + Thread t1 = new Thread(() -> { + for (int i = 0; i < 100; i++) { + s.get(mFetcher); + } + }); + Thread t2 = new Thread(() -> { + for (int i = 0; i < 100; i++) { + s.get(mFetcher); + } + }); + + final long startTimeMillis = SystemClock.elapsedRealtime(); + t1.start(); + t2.start(); + t1.join(); + t2.join(); + final long endTimeMillis = SystemClock.elapsedRealtime(); + + final long periodsElapsed = 1 + ((endTimeMillis - startTimeMillis) / periodMillis); + final long expected = Math.min(100 + 100, periodsElapsed * maxCountPerPeriod) - 1; + assertEquals(mCounter, expected); + } + + /** + * Multiple threads calling get() on the cache while the cached value is stale are allowed + * to fetch, regardless of the rate limiting. + * This is to prevent a slow getting thread from blocking other threads from getting a fresh + * value. + */ + @Test + public void testMultipleThreads_oneThreadIsSlow() throws InterruptedException { + final long periodMillis = 1000; + final int maxCountPerPeriod = 1; + final RateLimitingCache s = + new RateLimitingCache<>(periodMillis, maxCountPerPeriod); + + final CountDownLatch latch1 = new CountDownLatch(2); + final CountDownLatch latch2 = new CountDownLatch(1); + final AtomicInteger counter = new AtomicInteger(); + final RateLimitingCache.ValueFetcher fetcher = () -> { + latch1.countDown(); + try { + latch2.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return counter.incrementAndGet(); + }; + + Thread t1 = new Thread(() -> { + for (int i = 0; i < 100; i++) { + s.get(fetcher); + } + }); + Thread t2 = new Thread(() -> { + for (int i = 0; i < 100; i++) { + s.get(fetcher); + } + }); + + t1.start(); + t2.start(); + // Both threads should be admitted to fetch because there is no fresh cached value, + // even though this exceeds the rate limit of at most 1 call per period. + // Wait for both threads to be fetching. + latch1.await(); + // Allow the fetcher to return. + latch2.countDown(); + // Wait for both threads to finish their fetches. + t1.join(); + t2.join(); + + assertEquals(counter.get(), 2); + } + + /** + * Even if multiple threads race to refresh the cache, only one thread gets to set a new value. + * This ensures, among other things, that the cache never returns values that were fetched out + * of order. + */ + @Test + public void testMultipleThreads_cachedValueNeverGoesBackInTime() throws InterruptedException { + final long periodMillis = 10; + final int maxCountPerPeriod = 3; + final RateLimitingCache s = + new RateLimitingCache<>(periodMillis, maxCountPerPeriod); + final AtomicInteger counter = new AtomicInteger(); + final RateLimitingCache.ValueFetcher fetcher = () -> { + // Note that this fetcher has a side effect, which is strictly not allowed for + // RateLimitingCache users, but we make an exception for the purpose of this test. + return counter.incrementAndGet(); + }; + + // Make three threads that spin on getting from the cache + final AtomicBoolean shouldRun = new AtomicBoolean(true); + Runnable worker = new Runnable() { + @Override + public void run() { + while (shouldRun.get()) { + s.get(fetcher); + } + } + }; + Thread t1 = new Thread(worker); + Thread t2 = new Thread(worker); + Thread t3 = new Thread(worker); + t1.start(); + t2.start(); + t3.start(); + + // Get values until a sufficiently convincing high value while ensuring that values are + // monotonically non-decreasing. + int lastSeen = 0; + while (lastSeen < 10000) { + int value = s.get(fetcher); + if (value < lastSeen) { + fail("Unexpectedly saw decreasing value " + value + " after " + lastSeen); + } + lastSeen = value; + } + + shouldRun.set(false); + t1.join(); + t2.join(); + t3.join(); + } + /** * Helper to make repeated calls every 5 millis to verify the number of expected fetches for * the given parameters. -- cgit v1.2.3-59-g8ed1b