diff options
| author | 2025-02-05 11:24:10 -0800 | |
|---|---|---|
| committer | 2025-02-05 11:24:10 -0800 | |
| commit | ef93f386acb5b2f0ebcb3b04d6e32fa79b0801d3 (patch) | |
| tree | e0a35f61a4d2b119c4627bde11c43ac3a9076637 | |
| parent | e1d2a135efccf87f3e820ae50de90c88c70395fa (diff) | |
| parent | c59db84c455570af8a314f426357355499c804d5 (diff) | |
Merge "Remove synchronization (and potential priority inversions) from RateLimitingCache" into main
| -rw-r--r-- | core/java/com/android/internal/util/RateLimitingCache.java | 76 | ||||
| -rw-r--r-- | core/tests/coretests/src/com/android/internal/util/RateLimitingCacheTest.java | 145 |
2 files changed, 192 insertions, 29 deletions
diff --git a/core/java/com/android/internal/util/RateLimitingCache.java b/core/java/com/android/internal/util/RateLimitingCache.java index 9916076fd0ef..956d5d680fe7 100644 --- a/core/java/com/android/internal/util/RateLimitingCache.java +++ b/core/java/com/android/internal/util/RateLimitingCache.java @@ -17,6 +17,8 @@ package com.android.internal.util; import android.os.SystemClock; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; /** * A speed/rate limiting cache that's used to cache a value to be returned as long as period hasn't @@ -30,6 +32,12 @@ import android.os.SystemClock; * and then the cached value is returned for the remainder of the period. It uses a simple fixed * window method to track rate. Use a window and count appropriate for bursts of calls and for * high latency/cost of the AIDL call. + * <p> + * This class is thread-safe. When multiple threads call get(), they will all fetch a new value + * if the cached value is stale. This is to prevent a slow getting thread from blocking other + * threads from getting a fresh value. In such circumsntaces it's possible to exceed + * <code>count</code> calls in a given period by up to the number of threads that are concurrently + * attempting to get a fresh value minus one. * * @param <Value> The type of the return value * @hide @@ -37,12 +45,11 @@ import android.os.SystemClock; @android.ravenwood.annotation.RavenwoodKeepWholeClass public class RateLimitingCache<Value> { - private volatile Value mCurrentValue; - private volatile long mLastTimestamp; // Can be last fetch time or window start of fetch time private final long mPeriodMillis; // window size private final int mLimit; // max per window - private int mCount = 0; // current count within window - private long mRandomOffset; // random offset to avoid batching of AIDL calls at window boundary + // random offset to avoid batching of AIDL calls at window boundary + private final long mRandomOffset; + private final AtomicReference<CachedValue> mCachedValue = new AtomicReference(); /** * The interface to fetch the actual value, if the cache is null or expired. @@ -56,6 +63,12 @@ public class RateLimitingCache<Value> { V fetchValue(); } + class CachedValue { + Value value; + long timestamp; + AtomicInteger count; // current count within window + } + /** * Create a speed limiting cache that returns the same value until periodMillis has passed * and then fetches a new value via the {@link ValueFetcher}. @@ -83,6 +96,8 @@ public class RateLimitingCache<Value> { mLimit = count; if (mLimit > 1 && periodMillis > 1) { mRandomOffset = (long) (Math.random() * (periodMillis / 2)); + } else { + mRandomOffset = 0; } } @@ -102,34 +117,39 @@ public class RateLimitingCache<Value> { * @return the cached or latest value */ public Value get(ValueFetcher<Value> query) { - // If the value never changes - if (mPeriodMillis < 0 && mLastTimestamp != 0) { - return mCurrentValue; - } + CachedValue cached = mCachedValue.get(); - synchronized (this) { - // Get the current time and add a random offset to avoid colliding with other - // caches with similar harmonic window boundaries - final long now = getTime() + mRandomOffset; - final boolean newWindow = now - mLastTimestamp >= mPeriodMillis; - if (newWindow || mCount < mLimit) { - // Fetch a new value - mCurrentValue = query.fetchValue(); + // If the value never changes and there is a previous cached value, return it + if (mPeriodMillis < 0 && cached != null && cached.timestamp != 0) { + return cached.value; + } - // If rate limiting, set timestamp to start of this window - if (mLimit > 1) { - mLastTimestamp = now - (now % mPeriodMillis); - } else { - mLastTimestamp = now; - } + // Get the current time and add a random offset to avoid colliding with other + // caches with similar harmonic window boundaries + final long now = getTime() + mRandomOffset; + final boolean newWindow = cached == null || now - cached.timestamp >= mPeriodMillis; + if (newWindow || cached.count.getAndIncrement() < mLimit) { + // Fetch a new value + Value freshValue = query.fetchValue(); + long freshTimestamp = now; + // If rate limiting, set timestamp to start of this window + if (mLimit > 1) { + freshTimestamp = now - (now % mPeriodMillis); + } - if (newWindow) { - mCount = 1; - } else { - mCount++; - } + CachedValue freshCached = new CachedValue(); + freshCached.value = freshValue; + freshCached.timestamp = freshTimestamp; + if (newWindow) { + freshCached.count = new AtomicInteger(1); + } else { + freshCached.count = cached.count; } - return mCurrentValue; + + // If we fail to CAS then it means that another thread beat us to it. + // In this case we don't override their work. + mCachedValue.compareAndSet(cached, freshCached); } + return mCachedValue.get().value; } } diff --git a/core/tests/coretests/src/com/android/internal/util/RateLimitingCacheTest.java b/core/tests/coretests/src/com/android/internal/util/RateLimitingCacheTest.java index 7734148bac5e..52ff79da26ea 100644 --- a/core/tests/coretests/src/com/android/internal/util/RateLimitingCacheTest.java +++ b/core/tests/coretests/src/com/android/internal/util/RateLimitingCacheTest.java @@ -18,9 +18,15 @@ package com.android.internal.util; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import android.os.SystemClock; import androidx.test.runner.AndroidJUnit4; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.CountDownLatch; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -38,7 +44,7 @@ public class RateLimitingCacheTest { mCounter = -1; } - RateLimitingCache.ValueFetcher<Integer> mFetcher = () -> { + private final RateLimitingCache.ValueFetcher<Integer> mFetcher = () -> { return ++mCounter; }; @@ -120,6 +126,143 @@ public class RateLimitingCacheTest { } /** + * Exercises concurrent access to the cache. + */ + @Test + public void testMultipleThreads() throws InterruptedException { + final long periodMillis = 1000; + final int maxCountPerPeriod = 10; + final RateLimitingCache<Integer> s = + new RateLimitingCache<>(periodMillis, maxCountPerPeriod); + + Thread t1 = new Thread(() -> { + for (int i = 0; i < 100; i++) { + s.get(mFetcher); + } + }); + Thread t2 = new Thread(() -> { + for (int i = 0; i < 100; i++) { + s.get(mFetcher); + } + }); + + final long startTimeMillis = SystemClock.elapsedRealtime(); + t1.start(); + t2.start(); + t1.join(); + t2.join(); + final long endTimeMillis = SystemClock.elapsedRealtime(); + + final long periodsElapsed = 1 + ((endTimeMillis - startTimeMillis) / periodMillis); + final long expected = Math.min(100 + 100, periodsElapsed * maxCountPerPeriod) - 1; + assertEquals(mCounter, expected); + } + + /** + * Multiple threads calling get() on the cache while the cached value is stale are allowed + * to fetch, regardless of the rate limiting. + * This is to prevent a slow getting thread from blocking other threads from getting a fresh + * value. + */ + @Test + public void testMultipleThreads_oneThreadIsSlow() throws InterruptedException { + final long periodMillis = 1000; + final int maxCountPerPeriod = 1; + final RateLimitingCache<Integer> s = + new RateLimitingCache<>(periodMillis, maxCountPerPeriod); + + final CountDownLatch latch1 = new CountDownLatch(2); + final CountDownLatch latch2 = new CountDownLatch(1); + final AtomicInteger counter = new AtomicInteger(); + final RateLimitingCache.ValueFetcher<Integer> fetcher = () -> { + latch1.countDown(); + try { + latch2.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return counter.incrementAndGet(); + }; + + Thread t1 = new Thread(() -> { + for (int i = 0; i < 100; i++) { + s.get(fetcher); + } + }); + Thread t2 = new Thread(() -> { + for (int i = 0; i < 100; i++) { + s.get(fetcher); + } + }); + + t1.start(); + t2.start(); + // Both threads should be admitted to fetch because there is no fresh cached value, + // even though this exceeds the rate limit of at most 1 call per period. + // Wait for both threads to be fetching. + latch1.await(); + // Allow the fetcher to return. + latch2.countDown(); + // Wait for both threads to finish their fetches. + t1.join(); + t2.join(); + + assertEquals(counter.get(), 2); + } + + /** + * Even if multiple threads race to refresh the cache, only one thread gets to set a new value. + * This ensures, among other things, that the cache never returns values that were fetched out + * of order. + */ + @Test + public void testMultipleThreads_cachedValueNeverGoesBackInTime() throws InterruptedException { + final long periodMillis = 10; + final int maxCountPerPeriod = 3; + final RateLimitingCache<Integer> s = + new RateLimitingCache<>(periodMillis, maxCountPerPeriod); + final AtomicInteger counter = new AtomicInteger(); + final RateLimitingCache.ValueFetcher<Integer> fetcher = () -> { + // Note that this fetcher has a side effect, which is strictly not allowed for + // RateLimitingCache users, but we make an exception for the purpose of this test. + return counter.incrementAndGet(); + }; + + // Make three threads that spin on getting from the cache + final AtomicBoolean shouldRun = new AtomicBoolean(true); + Runnable worker = new Runnable() { + @Override + public void run() { + while (shouldRun.get()) { + s.get(fetcher); + } + } + }; + Thread t1 = new Thread(worker); + Thread t2 = new Thread(worker); + Thread t3 = new Thread(worker); + t1.start(); + t2.start(); + t3.start(); + + // Get values until a sufficiently convincing high value while ensuring that values are + // monotonically non-decreasing. + int lastSeen = 0; + while (lastSeen < 10000) { + int value = s.get(fetcher); + if (value < lastSeen) { + fail("Unexpectedly saw decreasing value " + value + " after " + lastSeen); + } + lastSeen = value; + } + + shouldRun.set(false); + t1.join(); + t2.join(); + t3.join(); + } + + /** * Helper to make repeated calls every 5 millis to verify the number of expected fetches for * the given parameters. * @param cache the cache object |