author Shai Barack <shayba@google.com> 2025-02-05 11:24:10 -0800
committer Android (Google) Code Review <android-gerrit@google.com> 2025-02-05 11:24:10 -0800
commit ef93f386acb5b2f0ebcb3b04d6e32fa79b0801d3 (patch)
tree e0a35f61a4d2b119c4627bde11c43ac3a9076637
parent e1d2a135efccf87f3e820ae50de90c88c70395fa (diff)
parent c59db84c455570af8a314f426357355499c804d5 (diff)
Merge "Remove synchronization (and potential priority inversions) from RateLimitingCache" into main
-rw-r--r--  core/java/com/android/internal/util/RateLimitingCache.java             76
-rw-r--r--  core/tests/coretests/src/com/android/internal/util/RateLimitingCacheTest.java  145
2 files changed, 192 insertions, 29 deletions
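The change replaces the synchronized block in get() with an immutable snapshot published through an AtomicReference. As a rough orientation before the diff, here is a minimal sketch of that pattern, assuming a simplified cache with no rate limiting; the class and helper names below are illustrative, not code from the patch.

import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

final class SnapshotCacheSketch<V> {
    // Immutable snapshot of the last fetched value and when it was fetched.
    private static final class Snapshot<T> {
        final T value;
        final long timestamp;
        Snapshot(T value, long timestamp) {
            this.value = value;
            this.timestamp = timestamp;
        }
    }

    private final AtomicReference<Snapshot<V>> mRef = new AtomicReference<>();
    private final long mPeriodMillis;

    SnapshotCacheSketch(long periodMillis) {
        mPeriodMillis = periodMillis;
    }

    V get(Supplier<V> fetcher) {
        final Snapshot<V> seen = mRef.get();
        final long now = System.currentTimeMillis();
        if (seen == null || now - seen.timestamp >= mPeriodMillis) {
            // The fetch runs outside any lock, so one thread's slow fetch never
            // blocks (or priority-inverts) the other callers.
            final Snapshot<V> fresh = new Snapshot<>(fetcher.get(), now);
            // Only the first thread to CAS publishes its snapshot; losers keep
            // the winner's work rather than overwriting it.
            mRef.compareAndSet(seen, fresh);
        }
        return mRef.get().value;
    }
}

The patch below applies the same idea to RateLimitingCache while keeping the fixed-window rate limiting, by carrying the per-window counter inside the published snapshot (CachedValue).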
diff --git a/core/java/com/android/internal/util/RateLimitingCache.java b/core/java/com/android/internal/util/RateLimitingCache.java
index 9916076fd0ef..956d5d680fe7 100644
--- a/core/java/com/android/internal/util/RateLimitingCache.java
+++ b/core/java/com/android/internal/util/RateLimitingCache.java
@@ -17,6 +17,8 @@
package com.android.internal.util;
import android.os.SystemClock;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
/**
* A speed/rate limiting cache that's used to cache a value to be returned as long as period hasn't
@@ -30,6 +32,12 @@ import android.os.SystemClock;
* and then the cached value is returned for the remainder of the period. It uses a simple fixed
* window method to track rate. Use a window and count appropriate for bursts of calls and for
* high latency/cost of the AIDL call.
+ * <p>
+ * This class is thread-safe. When multiple threads call get() while the cached value is stale,
+ * they will all fetch a new value; this prevents a thread performing a slow fetch from blocking
+ * other threads that need a fresh value. In such circumstances it is possible to exceed
+ * <code>count</code> calls in a given period by up to the number of threads concurrently
+ * attempting to fetch a fresh value, minus one.
*
* @param <Value> The type of the return value
* @hide
@@ -37,12 +45,11 @@ import android.os.SystemClock;
@android.ravenwood.annotation.RavenwoodKeepWholeClass
public class RateLimitingCache<Value> {
- private volatile Value mCurrentValue;
- private volatile long mLastTimestamp; // Can be last fetch time or window start of fetch time
private final long mPeriodMillis; // window size
private final int mLimit; // max per window
- private int mCount = 0; // current count within window
- private long mRandomOffset; // random offset to avoid batching of AIDL calls at window boundary
+ // random offset to avoid batching of AIDL calls at window boundary
+ private final long mRandomOffset;
+ private final AtomicReference<CachedValue> mCachedValue = new AtomicReference<>();
/**
* The interface to fetch the actual value, if the cache is null or expired.
@@ -56,6 +63,12 @@ public class RateLimitingCache<Value> {
V fetchValue();
}
+ class CachedValue {
+ Value value;
+ long timestamp;
+ AtomicInteger count; // current count within window
+ }
+
/**
* Create a speed limiting cache that returns the same value until periodMillis has passed
* and then fetches a new value via the {@link ValueFetcher}.
@@ -83,6 +96,8 @@ public class RateLimitingCache<Value> {
mLimit = count;
if (mLimit > 1 && periodMillis > 1) {
mRandomOffset = (long) (Math.random() * (periodMillis / 2));
+ } else {
+ mRandomOffset = 0;
}
}
@@ -102,34 +117,39 @@ public class RateLimitingCache<Value> {
* @return the cached or latest value
*/
public Value get(ValueFetcher<Value> query) {
- // If the value never changes
- if (mPeriodMillis < 0 && mLastTimestamp != 0) {
- return mCurrentValue;
- }
+ CachedValue cached = mCachedValue.get();
- synchronized (this) {
- // Get the current time and add a random offset to avoid colliding with other
- // caches with similar harmonic window boundaries
- final long now = getTime() + mRandomOffset;
- final boolean newWindow = now - mLastTimestamp >= mPeriodMillis;
- if (newWindow || mCount < mLimit) {
- // Fetch a new value
- mCurrentValue = query.fetchValue();
+ // If the value never changes and there is a previous cached value, return it
+ if (mPeriodMillis < 0 && cached != null && cached.timestamp != 0) {
+ return cached.value;
+ }
- // If rate limiting, set timestamp to start of this window
- if (mLimit > 1) {
- mLastTimestamp = now - (now % mPeriodMillis);
- } else {
- mLastTimestamp = now;
- }
+ // Get the current time and add a random offset to avoid colliding with other
+ // caches with similar harmonic window boundaries
+ final long now = getTime() + mRandomOffset;
+ final boolean newWindow = cached == null || now - cached.timestamp >= mPeriodMillis;
+ if (newWindow || cached.count.getAndIncrement() < mLimit) {
+ // Fetch a new value
+ Value freshValue = query.fetchValue();
+ long freshTimestamp = now;
+ // If rate limiting, set timestamp to start of this window
+ if (mLimit > 1) {
+ freshTimestamp = now - (now % mPeriodMillis);
+ }
- if (newWindow) {
- mCount = 1;
- } else {
- mCount++;
- }
+ CachedValue freshCached = new CachedValue();
+ freshCached.value = freshValue;
+ freshCached.timestamp = freshTimestamp;
+ if (newWindow) {
+ freshCached.count = new AtomicInteger(1);
+ } else {
+ freshCached.count = cached.count;
}
- return mCurrentValue;
+
+ // If the CAS fails, another thread beat us to it; in that case we don't
+ // overwrite that thread's work.
+ mCachedValue.compareAndSet(cached, freshCached);
}
+ return mCachedValue.get().value;
}
}
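For reference, a hedged usage sketch of the resulting API; this is not part of the change, and the wrapper class and expensiveQuery() placeholder below are hypothetical.

import com.android.internal.util.RateLimitingCache;

public class CachedReader {
    // At most 10 fetches per 1000 ms window; other calls within the window
    // return the cached value.
    private final RateLimitingCache<Integer> mCache = new RateLimitingCache<>(1000, 10);

    // Placeholder for a slow cross-process (AIDL) call.
    private int expensiveQuery() {
        return (int) (System.currentTimeMillis() % 1000);
    }

    public int read() {
        // The method reference supplies ValueFetcher<Integer>.fetchValue().
        return mCache.get(this::expensiveQuery);
    }
}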
diff --git a/core/tests/coretests/src/com/android/internal/util/RateLimitingCacheTest.java b/core/tests/coretests/src/com/android/internal/util/RateLimitingCacheTest.java
index 7734148bac5e..52ff79da26ea 100644
--- a/core/tests/coretests/src/com/android/internal/util/RateLimitingCacheTest.java
+++ b/core/tests/coretests/src/com/android/internal/util/RateLimitingCacheTest.java
@@ -18,9 +18,15 @@ package com.android.internal.util;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import android.os.SystemClock;
import androidx.test.runner.AndroidJUnit4;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -38,7 +44,7 @@ public class RateLimitingCacheTest {
mCounter = -1;
}
- RateLimitingCache.ValueFetcher<Integer> mFetcher = () -> {
+ private final RateLimitingCache.ValueFetcher<Integer> mFetcher = () -> {
return ++mCounter;
};
@@ -120,6 +126,143 @@ public class RateLimitingCacheTest {
}
/**
+ * Exercises concurrent access to the cache.
+ */
+ @Test
+ public void testMultipleThreads() throws InterruptedException {
+ final long periodMillis = 1000;
+ final int maxCountPerPeriod = 10;
+ final RateLimitingCache<Integer> s =
+ new RateLimitingCache<>(periodMillis, maxCountPerPeriod);
+
+ Thread t1 = new Thread(() -> {
+ for (int i = 0; i < 100; i++) {
+ s.get(mFetcher);
+ }
+ });
+ Thread t2 = new Thread(() -> {
+ for (int i = 0; i < 100; i++) {
+ s.get(mFetcher);
+ }
+ });
+
+ final long startTimeMillis = SystemClock.elapsedRealtime();
+ t1.start();
+ t2.start();
+ t1.join();
+ t2.join();
+ final long endTimeMillis = SystemClock.elapsedRealtime();
+
+ final long periodsElapsed = 1 + ((endTimeMillis - startTimeMillis) / periodMillis);
+ final long expected = Math.min(100 + 100, periodsElapsed * maxCountPerPeriod) - 1;
+ assertEquals(expected, mCounter);
+ }
+
+ /**
+ * Multiple threads calling get() on the cache while the cached value is stale are all allowed
+ * to fetch, regardless of the rate limit.
+ * This prevents a thread performing a slow fetch from blocking other threads that need a fresh
+ * value.
+ */
+ @Test
+ public void testMultipleThreads_oneThreadIsSlow() throws InterruptedException {
+ final long periodMillis = 1000;
+ final int maxCountPerPeriod = 1;
+ final RateLimitingCache<Integer> s =
+ new RateLimitingCache<>(periodMillis, maxCountPerPeriod);
+
+ final CountDownLatch latch1 = new CountDownLatch(2);
+ final CountDownLatch latch2 = new CountDownLatch(1);
+ final AtomicInteger counter = new AtomicInteger();
+ final RateLimitingCache.ValueFetcher<Integer> fetcher = () -> {
+ latch1.countDown();
+ try {
+ latch2.await();
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ return counter.incrementAndGet();
+ };
+
+ Thread t1 = new Thread(() -> {
+ for (int i = 0; i < 100; i++) {
+ s.get(fetcher);
+ }
+ });
+ Thread t2 = new Thread(() -> {
+ for (int i = 0; i < 100; i++) {
+ s.get(fetcher);
+ }
+ });
+
+ t1.start();
+ t2.start();
+ // Both threads should be admitted to fetch because there is no fresh cached value,
+ // even though this exceeds the rate limit of at most 1 call per period.
+ // Wait for both threads to be fetching.
+ latch1.await();
+ // Allow the fetcher to return.
+ latch2.countDown();
+ // Wait for both threads to finish their fetches.
+ t1.join();
+ t2.join();
+
+ assertEquals(2, counter.get());
+ }
+
+ /**
+ * Even if multiple threads race to refresh the cache, only one thread gets to set a new value.
+ * This ensures, among other things, that the cache never returns values that were fetched out
+ * of order.
+ */
+ @Test
+ public void testMultipleThreads_cachedValueNeverGoesBackInTime() throws InterruptedException {
+ final long periodMillis = 10;
+ final int maxCountPerPeriod = 3;
+ final RateLimitingCache<Integer> s =
+ new RateLimitingCache<>(periodMillis, maxCountPerPeriod);
+ final AtomicInteger counter = new AtomicInteger();
+ final RateLimitingCache.ValueFetcher<Integer> fetcher = () -> {
+ // Note that this fetcher has a side effect, which is strictly not allowed for
+ // RateLimitingCache users, but we make an exception for the purpose of this test.
+ return counter.incrementAndGet();
+ };
+
+ // Make three threads that spin on getting from the cache
+ final AtomicBoolean shouldRun = new AtomicBoolean(true);
+ Runnable worker = new Runnable() {
+ @Override
+ public void run() {
+ while (shouldRun.get()) {
+ s.get(fetcher);
+ }
+ }
+ };
+ Thread t1 = new Thread(worker);
+ Thread t2 = new Thread(worker);
+ Thread t3 = new Thread(worker);
+ t1.start();
+ t2.start();
+ t3.start();
+
+ // Keep getting values until a sufficiently high value is reached, verifying along the way
+ // that the observed values are monotonically non-decreasing.
+ int lastSeen = 0;
+ while (lastSeen < 10000) {
+ int value = s.get(fetcher);
+ if (value < lastSeen) {
+ fail("Unexpectedly saw decreasing value " + value + " after " + lastSeen);
+ }
+ lastSeen = value;
+ }
+
+ shouldRun.set(false);
+ t1.join();
+ t2.join();
+ t3.join();
+ }
+
+ /**
* Helper to make repeated calls every 5 millis to verify the number of expected fetches for
* the given parameters.
* @param cache the cache object