Add current metadata to RestoreCoordinator state

so we know which backup version we need to expect during restore
diff --git a/app/src/main/java/com/stevesoltys/seedvault/restore/RestorableBackup.kt b/app/src/main/java/com/stevesoltys/seedvault/restore/RestorableBackup.kt
index 643ae84..ddcaeb8 100644
--- a/app/src/main/java/com/stevesoltys/seedvault/restore/RestorableBackup.kt
+++ b/app/src/main/java/com/stevesoltys/seedvault/restore/RestorableBackup.kt
@@ -3,7 +3,7 @@
 import com.stevesoltys.seedvault.metadata.BackupMetadata
 import com.stevesoltys.seedvault.metadata.PackageMetadataMap
 
-data class RestorableBackup(private val backupMetadata: BackupMetadata) {
+data class RestorableBackup(val backupMetadata: BackupMetadata) {
 
     val name: String
         get() = backupMetadata.deviceName
diff --git a/app/src/main/java/com/stevesoltys/seedvault/restore/RestoreViewModel.kt b/app/src/main/java/com/stevesoltys/seedvault/restore/RestoreViewModel.kt
index 6a0c15a..99f8c20 100644
--- a/app/src/main/java/com/stevesoltys/seedvault/restore/RestoreViewModel.kt
+++ b/app/src/main/java/com/stevesoltys/seedvault/restore/RestoreViewModel.kt
@@ -148,6 +148,7 @@
     }
 
     override fun onRestorableBackupClicked(restorableBackup: RestorableBackup) {
+        restoreCoordinator.beforeStartRestore(restorableBackup.backupMetadata)
         mChosenRestorableBackup.value = restorableBackup
         mDisplayFragment.setEvent(RESTORE_APPS)
     }
diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/ConfigurableBackupTransport.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/ConfigurableBackupTransport.kt
index 2d39c3d..3592666 100644
--- a/app/src/main/java/com/stevesoltys/seedvault/transport/ConfigurableBackupTransport.kt
+++ b/app/src/main/java/com/stevesoltys/seedvault/transport/ConfigurableBackupTransport.kt
@@ -204,8 +204,8 @@
         return restoreCoordinator.getCurrentRestoreSet()
     }
 
-    override fun startRestore(token: Long, packages: Array<PackageInfo>): Int {
-        return restoreCoordinator.startRestore(token, packages)
+    override fun startRestore(token: Long, packages: Array<PackageInfo>): Int = runBlocking {
+        restoreCoordinator.startRestore(token, packages)
     }
 
     override fun getNextFullRestoreDataChunk(socket: ParcelFileDescriptor): Int = runBlocking {
diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinator.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinator.kt
index 7889a4e..4c77e92 100644
--- a/app/src/main/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinator.kt
+++ b/app/src/main/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinator.kt
@@ -2,7 +2,6 @@
 
 import android.app.backup.BackupTransport.TRANSPORT_ERROR
 import android.app.backup.BackupTransport.TRANSPORT_OK
-import android.app.backup.IBackupManager
 import android.app.backup.RestoreDescription
 import android.app.backup.RestoreDescription.NO_MORE_PACKAGES
 import android.app.backup.RestoreDescription.TYPE_FULL_STREAM
@@ -12,7 +11,6 @@
 import android.content.pm.PackageInfo
 import android.os.ParcelFileDescriptor
 import android.util.Log
-import androidx.collection.LongSparseArray
 import com.stevesoltys.seedvault.MAGIC_PACKAGE_MANAGER
 import com.stevesoltys.seedvault.R
 import com.stevesoltys.seedvault.header.UnsupportedVersionException
@@ -31,7 +29,8 @@
     /**
      * Optional [PackageInfo] for single package restore, to reduce data needed to read for @pm@
      */
-    val pmPackageInfo: PackageInfo?
+    val pmPackageInfo: PackageInfo?,
+    val backupMetadata: BackupMetadata
 ) {
     var currentPackage: String? = null
 }
@@ -51,7 +50,7 @@
 ) {
 
     private var state: RestoreCoordinatorState? = null
-    private var backupMetadata: LongSparseArray<BackupMetadata>? = null
+    private var backupMetadata: BackupMetadata? = null
     private val failedPackages = ArrayList<String>()
 
     suspend fun getAvailableMetadata(): Map<Long, BackupMetadata>? {
@@ -114,18 +113,26 @@
     }
 
     /**
+     * Call this before starting the restore as an optimization to prevent re-fetching metadata.
+     */
+    fun beforeStartRestore(backupMetadata: BackupMetadata) {
+        this.backupMetadata = backupMetadata
+    }
+
+    /**
      * Start restoring application data from backup.
      * After calling this function,
      * there will be alternate calls to [nextRestorePackage] and [getRestoreData]
      * to walk through the actual application data.
      *
-     * @param token A backup token as returned by [getAvailableRestoreSets] or [getCurrentRestoreSet].
+     * @param token A backup token as returned by [getAvailableRestoreSets]
+     * or [getCurrentRestoreSet].
      * @param packages List of applications to restore (if data is available).
      * Application data will be restored in the order given.
      * @return One of [TRANSPORT_OK] (OK so far, call [nextRestorePackage])
      * or [TRANSPORT_ERROR] (an error occurred, the restore should be aborted and rescheduled).
      */
-    fun startRestore(token: Long, packages: Array<out PackageInfo>): Int {
+    suspend fun startRestore(token: Long, packages: Array<out PackageInfo>): Int {
         check(state == null) { "Started new restore with existing state: $state" }
         Log.i(TAG, "Start restore with ${packages.map { info -> info.packageName }}")
 
@@ -151,7 +158,13 @@
                 packages[1]
             } else null
 
-        state = RestoreCoordinatorState(token, packages.iterator(), pmPackageInfo)
+        val metadata = if (backupMetadata?.token == token) {
+            backupMetadata!! // if token matches, backupMetadata is non-null
+        } else {
+            getAvailableMetadata()?.get(token) ?: return TRANSPORT_ERROR
+        }
+        state = RestoreCoordinatorState(token, packages.iterator(), pmPackageInfo, metadata)
+        backupMetadata = null
         failedPackages.clear()
         return TRANSPORT_OK
     }
@@ -269,18 +282,6 @@
         state = null
     }
 
-    /**
-     * Call this after calling [IBackupManager.getAvailableRestoreTokenForUser]
-     * to retrieve additional [BackupMetadata] that is not available in [RestoreSet].
-     *
-     * It will also clear the saved metadata, so that subsequent calls will return null.
-     */
-    fun getAndClearBackupMetadata(): LongSparseArray<BackupMetadata>? {
-        val result = backupMetadata
-        backupMetadata = null
-        return result
-    }
-
     fun isFailedPackage(packageName: String) = packageName in failedPackages
 
     // TODO this is plugin specific, needs to be factored out when supporting different plugins
diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/CoordinatorIntegrationTest.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/CoordinatorIntegrationTest.kt
index c2e0c39..ee5084b 100644
--- a/app/src/test/java/com/stevesoltys/seedvault/transport/CoordinatorIntegrationTest.kt
+++ b/app/src/test/java/com/stevesoltys/seedvault/transport/CoordinatorIntegrationTest.kt
@@ -181,6 +181,7 @@
         assertEquals(TRANSPORT_OK, backup.finishBackup())
 
         // start restore
+        restore.beforeStartRestore(metadata)
         assertEquals(TRANSPORT_OK, restore.startRestore(token, arrayOf(packageInfo)))
 
         // find data for K/V backup
@@ -251,6 +252,7 @@
         assertEquals(TRANSPORT_OK, backup.finishBackup())
 
         // start restore
+        restore.beforeStartRestore(metadata)
         assertEquals(TRANSPORT_OK, restore.startRestore(token, arrayOf(packageInfo)))
 
         // find data for K/V backup
@@ -311,6 +313,7 @@
         assertEquals(TRANSPORT_OK, backup.finishBackup())
 
         // start restore
+        restore.beforeStartRestore(metadata)
         assertEquals(TRANSPORT_OK, restore.startRestore(token, arrayOf(packageInfo)))
 
         // find data only for full backup
diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/TransportTest.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/TransportTest.kt
index 712ee85..6691804 100644
--- a/app/src/test/java/com/stevesoltys/seedvault/transport/TransportTest.kt
+++ b/app/src/test/java/com/stevesoltys/seedvault/transport/TransportTest.kt
@@ -10,6 +10,8 @@
 import com.stevesoltys.seedvault.Clock
 import com.stevesoltys.seedvault.MAGIC_PACKAGE_MANAGER
 import com.stevesoltys.seedvault.crypto.Crypto
+import com.stevesoltys.seedvault.getRandomString
+import com.stevesoltys.seedvault.metadata.BackupMetadata
 import com.stevesoltys.seedvault.metadata.MetadataManager
 import com.stevesoltys.seedvault.settings.SettingsManager
 import io.mockk.every
@@ -41,6 +43,12 @@
     protected val pmPackageInfo = PackageInfo().apply {
         packageName = MAGIC_PACKAGE_MANAGER
     }
+    protected val metadata = BackupMetadata(
+        token = token,
+        androidVersion = Random.nextInt(),
+        androidIncremental = getRandomString(),
+        deviceName = getRandomString()
+    )
 
     init {
         mockkStatic(Log::class)
diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinatorTest.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinatorTest.kt
index 9318f92..bafef62 100644
--- a/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinatorTest.kt
+++ b/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinatorTest.kt
@@ -10,7 +10,6 @@
 import android.os.ParcelFileDescriptor
 import com.stevesoltys.seedvault.coAssertThrows
 import com.stevesoltys.seedvault.getRandomString
-import com.stevesoltys.seedvault.metadata.BackupMetadata
 import com.stevesoltys.seedvault.metadata.EncryptedBackupMetadata
 import com.stevesoltys.seedvault.metadata.MetadataReader
 import com.stevesoltys.seedvault.metadata.PackageMetadata
@@ -26,7 +25,6 @@
 import kotlinx.coroutines.runBlocking
 import org.junit.jupiter.api.Assertions.assertEquals
 import org.junit.jupiter.api.Assertions.assertNotNull
-import org.junit.jupiter.api.Assertions.assertNull
 import org.junit.jupiter.api.Assertions.assertThrows
 import org.junit.jupiter.api.Assertions.fail
 import org.junit.jupiter.api.Test
@@ -69,18 +67,13 @@
     @Test
     fun `getAvailableRestoreSets() builds set from plugin response`() = runBlocking {
         val encryptedMetadata = EncryptedBackupMetadata(token, inputStream)
-        val metadata = BackupMetadata(
-            token = token,
-            androidVersion = Random.nextInt(),
-            androidIncremental = getRandomString(),
-            deviceName = getRandomString()
-        )
 
         coEvery { plugin.getAvailableBackups() } returns sequenceOf(
             encryptedMetadata,
-            encryptedMetadata
+            EncryptedBackupMetadata(token + 1, inputStream)
         )
         every { metadataReader.readMetadata(inputStream, token) } returns metadata
+        every { metadataReader.readMetadata(inputStream, token + 1) } returns metadata
         every { inputStream.close() } just Runs
 
         val sets = restore.getAvailableRestoreSets() ?: fail()
@@ -97,68 +90,102 @@
     }
 
     @Test
-    fun `startRestore() returns OK`() {
+    fun `startRestore() returns OK`() = runBlocking {
+        restore.beforeStartRestore(metadata)
         assertEquals(TRANSPORT_OK, restore.startRestore(token, packageInfoArray))
     }
 
     @Test
-    fun `startRestore() can not be called twice`() {
+    fun `startRestore() fetches metadata if missing`() = runBlocking {
+        coEvery { plugin.getAvailableBackups() } returns sequenceOf(
+            EncryptedBackupMetadata(token, inputStream),
+            EncryptedBackupMetadata(token + 1, inputStream)
+        )
+        every { metadataReader.readMetadata(inputStream, token) } returns metadata
+        every { metadataReader.readMetadata(inputStream, token + 1) } returns metadata
+        every { inputStream.close() } just Runs
+
+        assertEquals(TRANSPORT_OK, restore.startRestore(token, packageInfoArray))
+    }
+
+    @Test
+    fun `startRestore() errors if metadata is not matching token`() = runBlocking {
+        coEvery { plugin.getAvailableBackups() } returns sequenceOf(
+            EncryptedBackupMetadata(token + 42, inputStream)
+        )
+        every { metadataReader.readMetadata(inputStream, token + 42) } returns metadata
+        every { inputStream.close() } just Runs
+
+        assertEquals(TRANSPORT_ERROR, restore.startRestore(token, packageInfoArray))
+    }
+
+    @Test
+    fun `startRestore() can not be called twice`() = runBlocking {
+        restore.beforeStartRestore(metadata)
         assertEquals(TRANSPORT_OK, restore.startRestore(token, packageInfoArray))
         assertThrows(IllegalStateException::class.javaObjectType) {
-            restore.startRestore(token, packageInfoArray)
+            runBlocking {
+                restore.startRestore(token, packageInfoArray)
+            }
         }
+        Unit
     }
 
     @Test
-    fun `startRestore() can be be called again after restore finished`() {
+    fun `startRestore() can be be called again after restore finished`() = runBlocking {
+        restore.beforeStartRestore(metadata)
         assertEquals(TRANSPORT_OK, restore.startRestore(token, packageInfoArray))
 
         every { full.hasState() } returns false
         restore.finishRestore()
 
+        restore.beforeStartRestore(metadata)
         assertEquals(TRANSPORT_OK, restore.startRestore(token, packageInfoArray))
     }
 
     @Test
-    fun `startRestore() optimized auto-restore with removed storage shows notification`() {
-        every { settingsManager.getStorage() } returns storage
-        every { storage.isUnavailableUsb(context) } returns true
-        every { metadataManager.getPackageMetadata(packageName) } returns PackageMetadata(42L)
-        every { storage.name } returns storageName
-        every {
-            notificationManager.onRemovableStorageNotAvailableForRestore(
-                packageName,
-                storageName
-            )
-        } just Runs
+    fun `startRestore() optimized auto-restore with removed storage shows notification`() =
+        runBlocking {
+            every { settingsManager.getStorage() } returns storage
+            every { storage.isUnavailableUsb(context) } returns true
+            every { metadataManager.getPackageMetadata(packageName) } returns PackageMetadata(42L)
+            every { storage.name } returns storageName
+            every {
+                notificationManager.onRemovableStorageNotAvailableForRestore(
+                    packageName,
+                    storageName
+                )
+            } just Runs
 
-        assertEquals(TRANSPORT_ERROR, restore.startRestore(token, pmPackageInfoArray))
+            assertEquals(TRANSPORT_ERROR, restore.startRestore(token, pmPackageInfoArray))
 
-        verify(exactly = 1) {
-            notificationManager.onRemovableStorageNotAvailableForRestore(
-                packageName,
-                storageName
-            )
+            verify(exactly = 1) {
+                notificationManager.onRemovableStorageNotAvailableForRestore(
+                    packageName,
+                    storageName
+                )
+            }
         }
-    }
 
     @Test
-    fun `startRestore() optimized auto-restore with available storage shows no notification`() {
-        every { settingsManager.getStorage() } returns storage
-        every { storage.isUnavailableUsb(context) } returns false
+    fun `startRestore() optimized auto-restore with available storage shows no notification`() =
+        runBlocking {
+            every { settingsManager.getStorage() } returns storage
+            every { storage.isUnavailableUsb(context) } returns false
 
-        assertEquals(TRANSPORT_OK, restore.startRestore(token, pmPackageInfoArray))
+            restore.beforeStartRestore(metadata)
+            assertEquals(TRANSPORT_OK, restore.startRestore(token, pmPackageInfoArray))
 
-        verify(exactly = 0) {
-            notificationManager.onRemovableStorageNotAvailableForRestore(
-                packageName,
-                storageName
-            )
+            verify(exactly = 0) {
+                notificationManager.onRemovableStorageNotAvailableForRestore(
+                    packageName,
+                    storageName
+                )
+            }
         }
-    }
 
     @Test
-    fun `startRestore() with removed storage shows no notification`() {
+    fun `startRestore() with removed storage shows no notification`() = runBlocking {
         every { settingsManager.getStorage() } returns storage
         every { storage.isUnavailableUsb(context) } returns true
         every { metadataManager.getPackageMetadata(packageName) } returns null
@@ -182,6 +209,7 @@
 
     @Test
     fun `nextRestorePackage() returns KV description and takes precedence`() = runBlocking {
+        restore.beforeStartRestore(metadata)
         restore.startRestore(token, packageInfoArray)
 
         coEvery { kv.hasDataForPackage(token, packageInfo) } returns true
@@ -193,6 +221,7 @@
 
     @Test
     fun `nextRestorePackage() returns full description if no KV data found`() = runBlocking {
+        restore.beforeStartRestore(metadata)
         restore.startRestore(token, packageInfoArray)
 
         coEvery { kv.hasDataForPackage(token, packageInfo) } returns false
@@ -205,6 +234,7 @@
 
     @Test
     fun `nextRestorePackage() returns NO_MORE_PACKAGES if data found`() = runBlocking {
+        restore.beforeStartRestore(metadata)
         restore.startRestore(token, packageInfoArray)
 
         coEvery { kv.hasDataForPackage(token, packageInfo) } returns false
@@ -215,6 +245,7 @@
 
     @Test
     fun `nextRestorePackage() returns all packages from startRestore()`() = runBlocking {
+        restore.beforeStartRestore(metadata)
         restore.startRestore(token, packageInfoArray2)
 
         coEvery { kv.hasDataForPackage(token, packageInfo) } returns true
@@ -234,22 +265,24 @@
     }
 
     @Test
-    fun `when kv#hasDataForPackage() throws return null`() = runBlocking {
+    fun `when kv#hasDataForPackage() throws, it tries next package`() = runBlocking {
+        restore.beforeStartRestore(metadata)
         restore.startRestore(token, packageInfoArray)
 
         coEvery { kv.hasDataForPackage(token, packageInfo) } throws IOException()
 
-        assertNull(restore.nextRestorePackage())
+        assertEquals(NO_MORE_PACKAGES, restore.nextRestorePackage())
     }
 
     @Test
-    fun `when full#hasDataForPackage() throws return null`() = runBlocking {
+    fun `when full#hasDataForPackage() throws, it tries next package`() = runBlocking {
+        restore.beforeStartRestore(metadata)
         restore.startRestore(token, packageInfoArray)
 
         coEvery { kv.hasDataForPackage(token, packageInfo) } returns false
         coEvery { full.hasDataForPackage(token, packageInfo) } throws IOException()
 
-        assertNull(restore.nextRestorePackage())
+        assertEquals(NO_MORE_PACKAGES, restore.nextRestorePackage())
     }
 
     @Test
diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreV0IntegrationTest.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreV0IntegrationTest.kt
index 92bc1b0..b0ee983 100644
--- a/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreV0IntegrationTest.kt
+++ b/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreV0IntegrationTest.kt
@@ -61,7 +61,7 @@
         kvRestore,
         fullRestore,
         metadataReader
-    )
+    ).apply { beforeStartRestore(metadata) }
 
     private val fileDescriptor = mockk<ParcelFileDescriptor>(relaxed = true)
     private val appData = ("562AB665C3543120FC794D7CDA3AC18E5959235A4D" +