From 89714ce0759dabe11b53495c77ef164c6de1c677 Mon Sep 17 00:00:00 2001 From: Andrey Epin Date: Tue, 16 May 2023 15:22:32 -0700 Subject: Limit the number of parallel preview loadings in the image preview loader. With ag/23192997 the preview cache size effectively increased the total number of parallel image requests to a large enough value to DDoS Files app's content provider i.e. when sharing a large number of images, cache pre-population requests overlap with the actual preview loadings and causing some of them to fail to load. Update ImagePreviewImageLoader to use a Semaphore to limit the total number of parallel preview loadings. Fix: 283000541 Test: manual testing Change-Id: I6152f6e589a8b36a4810d617633017b72202e66f --- .../contentpreview/ImagePreviewImageLoader.kt | 29 +++- .../contentpreview/ImagePreviewImageLoaderTest.kt | 177 ++++++++++++++++++++- 2 files changed, 196 insertions(+), 10 deletions(-) (limited to 'java') diff --git a/java/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoader.kt b/java/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoader.kt index 89b79a0a..22dd1125 100644 --- a/java/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoader.kt +++ b/java/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoader.kt @@ -26,28 +26,42 @@ import androidx.annotation.VisibleForTesting import androidx.collection.LruCache import androidx.lifecycle.Lifecycle import androidx.lifecycle.coroutineScope +import java.util.function.Consumer import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Deferred import kotlinx.coroutines.isActive import kotlinx.coroutines.launch -import java.util.function.Consumer +import kotlinx.coroutines.sync.Semaphore private const val TAG = "ImagePreviewImageLoader" /** - * Implements preview image loading for the content preview UI. Provides requests deduplication and - * image caching. + * Implements preview image loading for the content preview UI. Provides requests deduplication, + * image caching, and a limit on the number of parallel loadings. */ @VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE) -class ImagePreviewImageLoader( +class ImagePreviewImageLoader +@VisibleForTesting +constructor( private val scope: CoroutineScope, thumbnailSize: Int, private val contentResolver: ContentResolver, cacheSize: Int, + // TODO: consider providing a scope with the dispatcher configured with + // [CoroutineDispatcher#limitedParallelism] instead + private val contentResolverSemaphore: Semaphore, ) : ImageLoader { + constructor( + scope: CoroutineScope, + thumbnailSize: Int, + contentResolver: ContentResolver, + cacheSize: Int, + maxSimultaneousRequests: Int = 4 + ) : this(scope, thumbnailSize, contentResolver, cacheSize, Semaphore(maxSimultaneousRequests)) + private val thumbnailSize: Size = Size(thumbnailSize, thumbnailSize) private val lock = Any() @@ -103,13 +117,16 @@ class ImagePreviewImageLoader( } } - private fun RequestRecord.loadBitmap() { + private suspend fun RequestRecord.loadBitmap() { + contentResolverSemaphore.acquire() val bitmap = try { contentResolver.loadThumbnail(uri, thumbnailSize, null) } catch (t: Throwable) { Log.d(TAG, "failed to load $uri preview", t) null + } finally { + contentResolverSemaphore.release() } complete(bitmap) } @@ -136,4 +153,4 @@ class ImagePreviewImageLoader( val deferred: CompletableDeferred, @GuardedBy("lock") var caching: Boolean ) -} \ No newline at end of file +} diff --git a/java/tests/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoaderTest.kt b/java/tests/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoaderTest.kt index 184401a0..6e57c289 100644 --- a/java/tests/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoaderTest.kt +++ b/java/tests/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoaderTest.kt @@ -27,20 +27,33 @@ import com.android.intentresolver.any import com.android.intentresolver.anyOrNull import com.android.intentresolver.mock import com.android.intentresolver.whenever +import com.google.common.truth.Truth.assertThat +import java.util.ArrayDeque +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit.MILLISECONDS +import java.util.concurrent.TimeUnit.SECONDS +import java.util.concurrent.atomic.AtomicInteger +import kotlin.coroutines.CoroutineContext import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineStart.UNDISPATCHED import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.Runnable import kotlinx.coroutines.async import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.launch import kotlinx.coroutines.plus +import kotlinx.coroutines.sync.Semaphore import kotlinx.coroutines.test.StandardTestDispatcher import kotlinx.coroutines.test.TestCoroutineScheduler import kotlinx.coroutines.test.UnconfinedTestDispatcher import kotlinx.coroutines.test.resetMain import kotlinx.coroutines.test.runTest import kotlinx.coroutines.test.setMain +import kotlinx.coroutines.yield import org.junit.After import org.junit.Before import org.junit.Test @@ -72,7 +85,7 @@ class ImagePreviewImageLoaderTest { lifecycleOwner.lifecycle.coroutineScope + dispatcher, imageSize.width, contentResolver, - 1, + cacheSize = 1, ) } @@ -118,7 +131,7 @@ class ImagePreviewImageLoaderTest { lifecycleOwner.lifecycle.coroutineScope + dispatcher, imageSize.width, contentResolver, - 1, + cacheSize = 1, ) coroutineScope { launch(start = UNDISPATCHED) { testSubject(uriOne, false) } @@ -164,7 +177,7 @@ class ImagePreviewImageLoaderTest { lifecycleOwner.lifecycle.coroutineScope + dispatcher, imageSize.width, contentResolver, - 1 + cacheSize = 1, ) coroutineScope { val deferred = async(start = UNDISPATCHED) { testSubject(uriOne, false) } @@ -183,7 +196,7 @@ class ImagePreviewImageLoaderTest { lifecycleOwner.lifecycle.coroutineScope + dispatcher, imageSize.width, contentResolver, - 1 + cacheSize = 1, ) coroutineScope { launch(start = UNDISPATCHED) { testSubject(uriOne, false) } @@ -194,4 +207,160 @@ class ImagePreviewImageLoaderTest { verify(contentResolver, times(1)).loadThumbnail(uriOne, imageSize, null) } + + @Test + fun invoke_semaphoreGuardsContentResolverCalls() = runTest { + val contentResolver = + mock { + whenever(loadThumbnail(any(), any(), anyOrNull())) + .thenThrow(SecurityException("test")) + } + val acquireCount = AtomicInteger() + val releaseCount = AtomicInteger() + val testSemaphore = + object : Semaphore { + override val availablePermits: Int + get() = error("Unexpected invocation") + + override suspend fun acquire() { + acquireCount.getAndIncrement() + } + + override fun tryAcquire(): Boolean { + error("Unexpected invocation") + } + + override fun release() { + releaseCount.getAndIncrement() + } + } + + val testSubject = + ImagePreviewImageLoader( + lifecycleOwner.lifecycle.coroutineScope + dispatcher, + imageSize.width, + contentResolver, + cacheSize = 1, + testSemaphore, + ) + testSubject(uriOne, false) + + verify(contentResolver, times(1)).loadThumbnail(uriOne, imageSize, null) + assertThat(acquireCount.get()).isEqualTo(1) + assertThat(releaseCount.get()).isEqualTo(1) + } + + @Test + fun invoke_semaphoreIsReleasedAfterContentResolverFailure() = runTest { + val semaphoreDeferred = CompletableDeferred() + val releaseCount = AtomicInteger() + val testSemaphore = + object : Semaphore { + override val availablePermits: Int + get() = error("Unexpected invocation") + + override suspend fun acquire() { + semaphoreDeferred.await() + } + + override fun tryAcquire(): Boolean { + error("Unexpected invocation") + } + + override fun release() { + releaseCount.getAndIncrement() + } + } + + val testSubject = + ImagePreviewImageLoader( + lifecycleOwner.lifecycle.coroutineScope + dispatcher, + imageSize.width, + contentResolver, + cacheSize = 1, + testSemaphore, + ) + launch(start = UNDISPATCHED) { testSubject(uriOne, false) } + + verify(contentResolver, never()).loadThumbnail(any(), any(), anyOrNull()) + + semaphoreDeferred.complete(Unit) + + verify(contentResolver, times(1)).loadThumbnail(uriOne, imageSize, null) + assertThat(releaseCount.get()).isEqualTo(1) + } + + @Test + fun invoke_multipleSimultaneousCalls_limitOnNumberOfSimultaneousOutgoingCallsIsRespected() { + val requestCount = 4 + val thumbnailCallsCdl = CountDownLatch(requestCount) + val pendingThumbnailCalls = ArrayDeque() + val contentResolver = + mock { + whenever(loadThumbnail(any(), any(), anyOrNull())).thenAnswer { + val latch = CountDownLatch(1) + synchronized(pendingThumbnailCalls) { pendingThumbnailCalls.offer(latch) } + thumbnailCallsCdl.countDown() + latch.await() + bitmap + } + } + val name = "LoadImage" + val maxSimultaneousRequests = 2 + val threadsStartedCdl = CountDownLatch(requestCount) + val dispatcher = NewThreadDispatcher(name) { threadsStartedCdl.countDown() } + val testSubject = + ImagePreviewImageLoader( + lifecycleOwner.lifecycle.coroutineScope + dispatcher + CoroutineName(name), + imageSize.width, + contentResolver, + cacheSize = 1, + maxSimultaneousRequests, + ) + runTest { + repeat(requestCount) { + launch { testSubject(Uri.parse("content://org.pkg.app/image-$it.png")) } + } + yield() + // wait for all requests to be dispatched + assertThat(threadsStartedCdl.await(5, SECONDS)).isTrue() + + assertThat(thumbnailCallsCdl.await(100, MILLISECONDS)).isFalse() + synchronized(pendingThumbnailCalls) { + assertThat(pendingThumbnailCalls.size).isEqualTo(maxSimultaneousRequests) + } + + pendingThumbnailCalls.poll()?.countDown() + assertThat(thumbnailCallsCdl.await(100, MILLISECONDS)).isFalse() + synchronized(pendingThumbnailCalls) { + assertThat(pendingThumbnailCalls.size).isEqualTo(maxSimultaneousRequests) + } + + pendingThumbnailCalls.poll()?.countDown() + assertThat(thumbnailCallsCdl.await(100, MILLISECONDS)).isTrue() + synchronized(pendingThumbnailCalls) { + assertThat(pendingThumbnailCalls.size).isEqualTo(maxSimultaneousRequests) + } + for (cdl in pendingThumbnailCalls) { + cdl.countDown() + } + } + } +} + +private class NewThreadDispatcher( + private val coroutineName: String, + private val launchedCallback: () -> Unit +) : CoroutineDispatcher() { + override fun isDispatchNeeded(context: CoroutineContext): Boolean = true + + override fun dispatch(context: CoroutineContext, block: Runnable) { + Thread { + if (coroutineName == context[CoroutineName.Key]?.name) { + launchedCallback() + } + block.run() + } + .start() + } } -- cgit v1.2.3-59-g8ed1b