K/V restore using single file
diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/backup/BackupCoordinator.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/backup/BackupCoordinator.kt
index 31a7436..62c21b6 100644
--- a/app/src/main/java/com/stevesoltys/seedvault/transport/backup/BackupCoordinator.kt
+++ b/app/src/main/java/com/stevesoltys/seedvault/transport/backup/BackupCoordinator.kt
@@ -272,6 +272,7 @@
val salt = metadataManager.salt
val result = kv.performBackup(packageInfo, data, flags, token, salt)
if (result == TRANSPORT_OK && packageName == MAGIC_PACKAGE_MANAGER) {
+ // TODO move to finish backup of @pm@ so we can upload the DB before
// hook in here to back up APKs of apps that are otherwise not allowed for backup
backUpApksOfNotBackedUpPackages()
}
@@ -392,7 +393,9 @@
}
// getCurrentPackage() not-null because we have state
onPackageBackedUp(kv.getCurrentPackage()!!, BackupType.KV)
+ val isPmBackup = kv.getCurrentPackage()!!.packageName == MAGIC_PACKAGE_MANAGER
kv.finishBackup()
+ // TODO move @pm@ backup hook here
}
full.hasState() -> {
check(!kv.hasState()) {
diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/backup/KVDbManager.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/backup/KVDbManager.kt
index 7689220..f6d03e7 100644
--- a/app/src/main/java/com/stevesoltys/seedvault/transport/backup/KVDbManager.kt
+++ b/app/src/main/java/com/stevesoltys/seedvault/transport/backup/KVDbManager.kt
@@ -8,37 +8,59 @@
import android.provider.BaseColumns
import java.io.File
import java.io.FileInputStream
+import java.io.FileOutputStream
import java.io.InputStream
+import java.io.OutputStream
interface KvDbManager {
- fun getDb(packageName: String): KVDb
+ fun getDb(packageName: String, isRestore: Boolean = false): KVDb
+
+ /**
+ * Use only for backup.
+ */
fun getDbInputStream(packageName: String): InputStream
+
+ /**
+ * Use only for restore.
+ */
+ fun getDbOutputStream(packageName: String): OutputStream
+
+ /**
+ * Use only for backup.
+ */
fun existsDb(packageName: String): Boolean
- fun deleteDb(packageName: String): Boolean
+ fun deleteDb(packageName: String, isRestore: Boolean = false): Boolean
}
class KvDbManagerImpl(private val context: Context) : KvDbManager {
- override fun getDb(packageName: String): KVDb {
- return KVDbImpl(context, getFileName(packageName))
+ override fun getDb(packageName: String, isRestore: Boolean): KVDb {
+ return KVDbImpl(context, getFileName(packageName, isRestore))
}
- private fun getFileName(packageName: String) = "kv_$packageName.db"
+ private fun getFileName(packageName: String, isRestore: Boolean): String {
+ val prefix = if (isRestore) "restore_" else ""
+ return "${prefix}kv_$packageName.db"
+ }
- private fun getDbFile(packageName: String): File {
- return context.getDatabasePath(getFileName(packageName))
+ private fun getDbFile(packageName: String, isRestore: Boolean = false): File {
+ return context.getDatabasePath(getFileName(packageName, isRestore))
}
override fun getDbInputStream(packageName: String): InputStream {
return FileInputStream(getDbFile(packageName))
}
+ override fun getDbOutputStream(packageName: String): OutputStream {
+ return FileOutputStream(getDbFile(packageName, true))
+ }
+
override fun existsDb(packageName: String): Boolean {
return getDbFile(packageName).isFile
}
- override fun deleteDb(packageName: String): Boolean {
- return getDbFile(packageName).delete()
+ override fun deleteDb(packageName: String, isRestore: Boolean): Boolean {
+ return getDbFile(packageName, isRestore).delete()
}
}
diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/restore/KVRestore.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/restore/KVRestore.kt
index 5ee74ed..7af153f 100644
--- a/app/src/main/java/com/stevesoltys/seedvault/transport/restore/KVRestore.kt
+++ b/app/src/main/java/com/stevesoltys/seedvault/transport/restore/KVRestore.kt
@@ -15,19 +15,25 @@
import com.stevesoltys.seedvault.header.UnsupportedVersionException
import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.getADForKV
+import com.stevesoltys.seedvault.transport.backup.BackupPlugin
+import com.stevesoltys.seedvault.transport.backup.KVDb
+import com.stevesoltys.seedvault.transport.backup.KvDbManager
import libcore.io.IoUtils.closeQuietly
import java.io.IOException
import java.security.GeneralSecurityException
import java.util.ArrayList
+import java.util.zip.GZIPInputStream
import javax.crypto.AEADBadTagException
private class KVRestoreState(
val version: Byte,
val token: Long,
+ val name: String,
val packageInfo: PackageInfo,
/**
* Optional [PackageInfo] for single package restore, optimizes restore of @pm@
*/
+ @Deprecated("TODO remove?")
val pmPackageInfo: PackageInfo?
)
@@ -35,20 +41,25 @@
@Suppress("BlockingMethodInNonBlockingContext")
internal class KVRestore(
- private val plugin: KVRestorePlugin,
+ private val plugin: BackupPlugin,
+ private val legacyPlugin: KVRestorePlugin,
private val outputFactory: OutputFactory,
private val headerReader: HeaderReader,
- private val crypto: Crypto
+ private val crypto: Crypto,
+ private val dbManager: KvDbManager
) {
private var state: KVRestoreState? = null
/**
* Return true if there are records stored for the given package.
+ *
+ * Deprecated. Use only for v0 backups.
*/
@Throws(IOException::class)
+ @Deprecated("Use BackupPlugin#hasData() instead")
suspend fun hasDataForPackage(token: Long, packageInfo: PackageInfo): Boolean {
- return plugin.hasDataForPackage(token, packageInfo)
+ return legacyPlugin.hasDataForPackage(token, packageInfo)
}
/**
@@ -62,10 +73,11 @@
fun initializeState(
version: Byte,
token: Long,
+ name: String,
packageInfo: PackageInfo,
pmPackageInfo: PackageInfo? = null
) {
- state = KVRestoreState(version, token, packageInfo, pmPackageInfo)
+ state = KVRestoreState(version, token, name, packageInfo, pmPackageInfo)
}
/**
@@ -78,12 +90,66 @@
suspend fun getRestoreData(data: ParcelFileDescriptor): Int {
val state = this.state ?: throw IllegalStateException("no state")
+ // take legacy path for version 0
+ if (state.version == 0x00.toByte()) return getRestoreDataV0(state, data)
+
+ return try {
+ val db = getRestoreDb(state)
+ val out = outputFactory.getBackupDataOutput(data)
+ db.getAll().sortedBy { it.first }.forEach { (key, value) ->
+ val size = value.size
+ Log.v(TAG, " ... key=$key size=$size")
+ out.writeEntityHeader(key, size)
+ out.writeEntityData(value, size)
+ }
+ TRANSPORT_OK
+ } catch (e: UnsupportedVersionException) {
+ Log.e(TAG, "Unsupported version in backup: ${e.version}", e)
+ TRANSPORT_ERROR
+ } catch (e: IOException) {
+ Log.e(TAG, "Unable to process K/V backup database", e)
+ TRANSPORT_ERROR
+ } catch (e: GeneralSecurityException) {
+ Log.e(TAG, "General security exception while reading backup database", e)
+ TRANSPORT_ERROR
+ } catch (e: AEADBadTagException) {
+ Log.e(TAG, "Decryption failed", e)
+ TRANSPORT_ERROR
+ } finally {
+ dbManager.deleteDb(state.packageInfo.packageName, true)
+ this.state = null
+ closeQuietly(data)
+ }
+ }
+
+ @Throws(IOException::class, GeneralSecurityException::class, UnsupportedVersionException::class)
+ private suspend fun getRestoreDb(state: KVRestoreState): KVDb {
+ val packageName = state.packageInfo.packageName
+ plugin.getInputStream(state.token, state.name).use { inputStream ->
+ headerReader.readVersion(inputStream, state.version)
+ val ad = getADForKV(VERSION, packageName)
+ crypto.newDecryptingStream(inputStream, ad).use { decryptedStream ->
+ GZIPInputStream(decryptedStream).use { gzipStream ->
+ dbManager.getDbOutputStream(packageName).use { outputStream ->
+ gzipStream.copyTo(outputStream)
+ }
+ }
+ }
+ }
+ return dbManager.getDb(packageName, true)
+ }
+
+ //
+ // v0 restore legacy code below
+ //
+
+ private suspend fun getRestoreDataV0(state: KVRestoreState, data: ParcelFileDescriptor): Int {
// The restore set is the concatenation of the individual record blobs,
// each of which is a file in the package's directory.
// We return the data in lexical order sorted by key,
// so that apps which use synthetic keys like BLOB_1, BLOB_2, etc
// will see the date in the most obvious order.
- val sortedKeys = getSortedKeys(state.token, state.packageInfo)
+ val sortedKeys = getSortedKeysV0(state.token, state.packageInfo)
if (sortedKeys == null) {
// nextRestorePackage() ensures the dir exists, so this is an error
Log.e(TAG, "No keys for package: ${state.packageInfo.packageName}")
@@ -96,7 +162,7 @@
return try {
val dataOutput = outputFactory.getBackupDataOutput(data)
for (keyEntry in sortedKeys) {
- readAndWriteValue(state, keyEntry, dataOutput)
+ readAndWriteValueV0(state, keyEntry, dataOutput)
}
TRANSPORT_OK
} catch (e: IOException) {
@@ -105,9 +171,6 @@
} catch (e: SecurityException) {
Log.e(TAG, "Security exception while reading backup records", e)
TRANSPORT_ERROR
- } catch (e: GeneralSecurityException) {
- Log.e(TAG, "General security exception while reading backup records", e)
- TRANSPORT_ERROR
} catch (e: UnsupportedVersionException) {
Log.e(TAG, "Unsupported version in backup: ${e.version}", e)
TRANSPORT_ERROR
@@ -124,9 +187,9 @@
* Return a list of the records (represented by key files) in the given directory,
* sorted lexically by the Base64-decoded key file name, not by the on-disk filename.
*/
- private suspend fun getSortedKeys(token: Long, packageInfo: PackageInfo): List<DecodedKey>? {
+ private suspend fun getSortedKeysV0(token: Long, packageInfo: PackageInfo): List<DecodedKey>? {
val records: List<String> = try {
- plugin.listRecords(token, packageInfo)
+ legacyPlugin.listRecords(token, packageInfo)
} catch (e: IOException) {
return null
}
@@ -150,24 +213,18 @@
/**
* Read the encrypted value for the given key and write it to the given [BackupDataOutput].
*/
+ @Suppress("Deprecation")
@Throws(IOException::class, UnsupportedVersionException::class, GeneralSecurityException::class)
- private suspend fun readAndWriteValue(
+ private suspend fun readAndWriteValueV0(
state: KVRestoreState,
dKey: DecodedKey,
out: BackupDataOutput
- ) = plugin.getInputStreamForRecord(state.token, state.packageInfo, dKey.base64Key)
+ ) = legacyPlugin.getInputStreamForRecord(state.token, state.packageInfo, dKey.base64Key)
.use { inputStream ->
val version = headerReader.readVersion(inputStream, state.version)
val packageName = state.packageInfo.packageName
- val value = if (version == 0.toByte()) {
- crypto.decryptHeader(inputStream, version, packageName, dKey.key)
- crypto.decryptMultipleSegments(inputStream)
- } else {
- val ad = getADForKV(VERSION, packageName)
- crypto.newDecryptingStream(inputStream, ad).use { decryptedStream ->
- decryptedStream.readBytes()
- }
- }
+ crypto.decryptHeader(inputStream, version, packageName, dKey.key)
+ val value = crypto.decryptMultipleSegments(inputStream)
val size = value.size
Log.v(TAG, " ... key=${dKey.key} size=$size")
diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinator.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinator.kt
index 0859cda..586a846 100644
--- a/app/src/main/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinator.kt
+++ b/app/src/main/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinator.kt
@@ -207,7 +207,13 @@
val name = crypto.getNameForPackage(state.backupMetadata.salt, packageName)
if (plugin.hasData(state.token, name)) {
Log.i(TAG, "Found K/V data for $packageName.")
- kv.initializeState(version, state.token, packageInfo, state.pmPackageInfo)
+ kv.initializeState(
+ version = version,
+ token = state.token,
+ name = name,
+ packageInfo = packageInfo,
+ pmPackageInfo = state.pmPackageInfo
+ )
state.currentPackage = packageName
TYPE_KEY_VALUE
} else throw IOException("No data found for $packageName. Skipping.")
@@ -243,7 +249,7 @@
// check key/value data first and if available, don't even check for full data
kv.hasDataForPackage(state.token, packageInfo) -> {
Log.i(TAG, "Found K/V data for $packageName.")
- kv.initializeState(0x00, state.token, packageInfo, state.pmPackageInfo)
+ kv.initializeState(0x00, state.token, "", packageInfo, state.pmPackageInfo)
state.currentPackage = packageName
TYPE_KEY_VALUE
}
diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/restore/RestoreModule.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/restore/RestoreModule.kt
index 62756a8..7ef3c4c 100644
--- a/app/src/main/java/com/stevesoltys/seedvault/transport/restore/RestoreModule.kt
+++ b/app/src/main/java/com/stevesoltys/seedvault/transport/restore/RestoreModule.kt
@@ -5,7 +5,7 @@
val restoreModule = module {
single { OutputFactory() }
- single { KVRestore(get<RestorePlugin>().kvRestorePlugin, get(), get(), get()) }
+ single { KVRestore(get(), get<RestorePlugin>().kvRestorePlugin, get(), get(), get(), get()) }
single { FullRestore(get(), get<RestorePlugin>().fullRestorePlugin, get(), get(), get()) }
single {
RestoreCoordinator(androidContext(), get(), get(), get(), get(), get(), get(), get(), get())
diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/CoordinatorIntegrationTest.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/CoordinatorIntegrationTest.kt
index 73b2508..282a6ba 100644
--- a/app/src/test/java/com/stevesoltys/seedvault/transport/CoordinatorIntegrationTest.kt
+++ b/app/src/test/java/com/stevesoltys/seedvault/transport/CoordinatorIntegrationTest.kt
@@ -10,7 +10,6 @@
import com.stevesoltys.seedvault.crypto.CipherFactoryImpl
import com.stevesoltys.seedvault.crypto.CryptoImpl
import com.stevesoltys.seedvault.crypto.KeyManagerTestImpl
-import com.stevesoltys.seedvault.encodeBase64
import com.stevesoltys.seedvault.header.HeaderReaderImpl
import com.stevesoltys.seedvault.header.MAX_SEGMENT_CLEARTEXT_LENGTH
import com.stevesoltys.seedvault.metadata.BackupType
@@ -39,6 +38,7 @@
import io.mockk.every
import io.mockk.just
import io.mockk.mockk
+import io.mockk.verify
import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals
@@ -81,7 +81,14 @@
)
private val kvRestorePlugin = mockk<KVRestorePlugin>()
- private val kvRestore = KVRestore(kvRestorePlugin, outputFactory, headerReader, cryptoImpl)
+ private val kvRestore = KVRestore(
+ backupPlugin,
+ kvRestorePlugin,
+ outputFactory,
+ headerReader,
+ cryptoImpl,
+ dbManager
+ )
private val fullRestorePlugin = mockk<FullRestorePlugin>()
private val fullRestore =
FullRestore(backupPlugin, fullRestorePlugin, outputFactory, headerReader, cryptoImpl)
@@ -104,9 +111,7 @@
private val metadataOutputStream = ByteArrayOutputStream()
private val packageMetadata = PackageMetadata(time = 0L)
private val key = "RestoreKey"
- private val key64 = key.encodeBase64()
private val key2 = "RestoreKey2"
- private val key264 = key2.encodeBase64()
// as we use real crypto, we need a real name for packageInfo
private val realName = cryptoImpl.getNameForPackage(salt, packageInfo.packageName)
@@ -116,7 +121,6 @@
val value = CapturingSlot<ByteArray>()
val value2 = CapturingSlot<ByteArray>()
val bOutputStream = ByteArrayOutputStream()
- val bOutputStream2 = ByteArrayOutputStream()
every { settingsManager.getToken() } returns token
every { metadataManager.salt } returns salt
@@ -170,29 +174,21 @@
// restore finds the backed up key and writes the decrypted value
val backupDataOutput = mockk<BackupDataOutput>()
val rInputStream = ByteArrayInputStream(bOutputStream.toByteArray())
- val rInputStream2 = ByteArrayInputStream(bOutputStream2.toByteArray())
- coEvery { kvRestorePlugin.listRecords(token, packageInfo) } returns listOf(key64, key264)
+ coEvery { backupPlugin.getInputStream(token, name) } returns rInputStream
every { outputFactory.getBackupDataOutput(fileDescriptor) } returns backupDataOutput
- coEvery {
- kvRestorePlugin.getInputStreamForRecord(
- token,
- packageInfo,
- key64
- )
- } returns rInputStream
every { backupDataOutput.writeEntityHeader(key, appData.size) } returns 1137
every { backupDataOutput.writeEntityData(appData, appData.size) } returns appData.size
- coEvery {
- kvRestorePlugin.getInputStreamForRecord(
- token,
- packageInfo,
- key264
- )
- } returns rInputStream2
every { backupDataOutput.writeEntityHeader(key2, appData2.size) } returns 1137
every { backupDataOutput.writeEntityData(appData2, appData2.size) } returns appData2.size
assertEquals(TRANSPORT_OK, restore.getRestoreData(fileDescriptor))
+
+ verify {
+ backupDataOutput.writeEntityHeader(key, appData.size)
+ backupDataOutput.writeEntityData(appData, appData.size)
+ backupDataOutput.writeEntityHeader(key2, appData2.size)
+ backupDataOutput.writeEntityData(appData2, appData2.size)
+ }
}
@Test
@@ -246,19 +242,17 @@
// restore finds the backed up key and writes the decrypted value
val backupDataOutput = mockk<BackupDataOutput>()
val rInputStream = ByteArrayInputStream(bOutputStream.toByteArray())
- coEvery { kvRestorePlugin.listRecords(token, packageInfo) } returns listOf(key64)
+ coEvery { backupPlugin.getInputStream(token, name) } returns rInputStream
every { outputFactory.getBackupDataOutput(fileDescriptor) } returns backupDataOutput
- coEvery {
- kvRestorePlugin.getInputStreamForRecord(
- token,
- packageInfo,
- key64
- )
- } returns rInputStream
every { backupDataOutput.writeEntityHeader(key, appData.size) } returns 1137
every { backupDataOutput.writeEntityData(appData, appData.size) } returns appData.size
assertEquals(TRANSPORT_OK, restore.getRestoreData(fileDescriptor))
+
+ verify {
+ backupDataOutput.writeEntityHeader(key, appData.size)
+ backupDataOutput.writeEntityData(appData, appData.size)
+ }
}
@Test
diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/backup/TestKvDbManager.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/backup/TestKvDbManager.kt
index 34a5e0d..7173f2f 100644
--- a/app/src/test/java/com/stevesoltys/seedvault/transport/backup/TestKvDbManager.kt
+++ b/app/src/test/java/com/stevesoltys/seedvault/transport/backup/TestKvDbManager.kt
@@ -3,22 +3,29 @@
import com.stevesoltys.seedvault.getRandomString
import com.stevesoltys.seedvault.toByteArrayFromHex
import com.stevesoltys.seedvault.toHexString
-import junit.framework.Assert.assertEquals
-import junit.framework.Assert.assertFalse
-import junit.framework.Assert.assertNull
-import junit.framework.Assert.assertTrue
import org.json.JSONObject
import org.junit.jupiter.api.Assertions.assertArrayEquals
+import org.junit.jupiter.api.Assertions.assertEquals
+import org.junit.jupiter.api.Assertions.assertFalse
+import org.junit.jupiter.api.Assertions.assertNull
+import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test
import java.io.ByteArrayInputStream
+import java.io.ByteArrayOutputStream
import java.io.InputStream
+import java.io.OutputStream
import kotlin.random.Random
class TestKvDbManager : KvDbManager {
private var db: TestKVDb? = null
+ private val outputStream = ByteArrayOutputStream()
- override fun getDb(packageName: String): KVDb {
+ override fun getDb(packageName: String, isRestore: Boolean): KVDb {
+ if (isRestore) {
+ readDbFromStream(ByteArrayInputStream(outputStream.toByteArray()))
+ return this.db!!
+ }
return TestKVDb().apply { db = this }
}
@@ -26,11 +33,16 @@
return ByteArrayInputStream(db!!.serialize().toByteArray())
}
+ override fun getDbOutputStream(packageName: String): OutputStream {
+ outputStream.reset()
+ return outputStream
+ }
+
override fun existsDb(packageName: String): Boolean {
return db != null
}
- override fun deleteDb(packageName: String): Boolean {
+ override fun deleteDb(packageName: String, isRestore: Boolean): Boolean {
clearDb()
return true
}
diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/restore/KVRestoreTest.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/restore/KVRestoreTest.kt
index 05a8337..a2ae238 100644
--- a/app/src/test/java/com/stevesoltys/seedvault/transport/restore/KVRestoreTest.kt
+++ b/app/src/test/java/com/stevesoltys/seedvault/transport/restore/KVRestoreTest.kt
@@ -10,33 +10,51 @@
import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.VersionHeader
import com.stevesoltys.seedvault.header.getADForKV
+import com.stevesoltys.seedvault.transport.backup.BackupPlugin
+import com.stevesoltys.seedvault.transport.backup.KVDb
+import com.stevesoltys.seedvault.transport.backup.KvDbManager
import io.mockk.Runs
import io.mockk.coEvery
import io.mockk.every
import io.mockk.just
import io.mockk.mockk
import io.mockk.mockkStatic
+import io.mockk.verify
import io.mockk.verifyAll
import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
+import java.io.ByteArrayInputStream
+import java.io.ByteArrayOutputStream
import java.io.IOException
import java.io.InputStream
import java.security.GeneralSecurityException
+import java.util.zip.GZIPOutputStream
import kotlin.random.Random
@Suppress("BlockingMethodInNonBlockingContext")
internal class KVRestoreTest : RestoreTest() {
- private val plugin = mockk<KVRestorePlugin>()
+ private val plugin = mockk<BackupPlugin>()
+ private val legacyPlugin = mockk<KVRestorePlugin>()
+ private val dbManager = mockk<KvDbManager>()
private val output = mockk<BackupDataOutput>()
- private val restore = KVRestore(plugin, outputFactory, headerReader, crypto)
+ private val restore =
+ KVRestore(plugin, legacyPlugin, outputFactory, headerReader, crypto, dbManager)
+
+ private val db = mockk<KVDb>()
private val ad = getADForKV(VERSION, packageInfo.packageName)
private val key = "Restore Key"
private val key64 = key.encodeBase64()
private val key2 = "Restore Key2"
private val key264 = key2.encodeBase64()
+ private val data2 = getRandomByteArray()
+
+ private val outputStream = ByteArrayOutputStream().apply {
+ GZIPOutputStream(this).close()
+ }
+ private val decryptInputStream = ByteArrayInputStream(outputStream.toByteArray())
init {
// for InputStream#readBytes()
@@ -44,15 +62,6 @@
}
@Test
- fun `hasDataForPackage() delegates to plugin`() = runBlocking {
- val result = Random.nextBoolean()
-
- coEvery { plugin.hasDataForPackage(token, packageInfo) } returns result
-
- assertEquals(result, restore.hasDataForPackage(token, packageInfo))
- }
-
- @Test
fun `getRestoreData() throws without initializing state`() {
coAssertThrows(IllegalStateException::class.java) {
restore.getRestoreData(fileDescriptor)
@@ -60,22 +69,133 @@
}
@Test
- fun `listing records throws`() = runBlocking {
- restore.initializeState(VERSION, token, packageInfo)
+ fun `unexpected version aborts with error`() = runBlocking {
+ restore.initializeState(VERSION, token, name, packageInfo)
- coEvery { plugin.listRecords(token, packageInfo) } throws IOException()
+ coEvery { plugin.getInputStream(token, name) } returns inputStream
+ every {
+ headerReader.readVersion(inputStream, VERSION)
+ } throws UnsupportedVersionException(Byte.MAX_VALUE)
+ every { dbManager.deleteDb(packageInfo.packageName, true) } returns true
+ streamsGetClosed()
+
+ assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
+ verifyStreamWasClosed()
+ }
+
+ @Test
+ fun `newDecryptingStream throws`() = runBlocking {
+ restore.initializeState(VERSION, token, name, packageInfo)
+
+ coEvery { plugin.getInputStream(token, name) } returns inputStream
+ every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
+ every { crypto.newDecryptingStream(inputStream, ad) } throws GeneralSecurityException()
+ every { dbManager.deleteDb(packageInfo.packageName, true) } returns true
+ streamsGetClosed()
+
+ assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
+ verifyStreamWasClosed()
+
+ verifyAll {
+ dbManager.deleteDb(packageInfo.packageName, true)
+ }
+ }
+
+ @Test
+ fun `writeEntityHeader throws`() = runBlocking {
+ restore.initializeState(VERSION, token, name, packageInfo)
+
+ coEvery { plugin.getInputStream(token, name) } returns inputStream
+ every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
+ every { crypto.newDecryptingStream(inputStream, ad) } returns decryptInputStream
+ every {
+ dbManager.getDbOutputStream(packageInfo.packageName)
+ } returns ByteArrayOutputStream()
+ every { dbManager.getDb(packageInfo.packageName, true) } returns db
+ every { outputFactory.getBackupDataOutput(fileDescriptor) } returns output
+ every { db.getAll() } returns listOf(Pair(key, data))
+ every { output.writeEntityHeader(key, data.size) } throws IOException()
+ every { dbManager.deleteDb(packageInfo.packageName, true) } returns true
+ streamsGetClosed()
+
+ assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
+ verifyStreamWasClosed()
+
+ verify {
+ dbManager.deleteDb(packageInfo.packageName, true)
+ }
+ }
+
+ @Test
+ fun `two records get restored`() = runBlocking {
+ restore.initializeState(VERSION, token, name, packageInfo)
+
+ coEvery { plugin.getInputStream(token, name) } returns inputStream
+ every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
+ every { crypto.newDecryptingStream(inputStream, ad) } returns decryptInputStream
+ every {
+ dbManager.getDbOutputStream(packageInfo.packageName)
+ } returns ByteArrayOutputStream()
+ every { dbManager.getDb(packageInfo.packageName, true) } returns db
+ every { outputFactory.getBackupDataOutput(fileDescriptor) } returns output
+ every { db.getAll() } returns listOf(
+ Pair(key, data),
+ Pair(key2, data2)
+ )
+ every { output.writeEntityHeader(key, data.size) } returns 42
+ every { output.writeEntityData(data, data.size) } returns data.size
+ every { output.writeEntityHeader(key2, data2.size) } returns 42
+ every { output.writeEntityData(data2, data2.size) } returns data2.size
+
+ every { dbManager.deleteDb(packageInfo.packageName, true) } returns true
+ streamsGetClosed()
+
+ assertEquals(TRANSPORT_OK, restore.getRestoreData(fileDescriptor))
+ verifyStreamWasClosed()
+
+ verify {
+ output.writeEntityHeader(key, data.size)
+ output.writeEntityData(data, data.size)
+ output.writeEntityHeader(key2, data2.size)
+ output.writeEntityData(data2, data2.size)
+ dbManager.deleteDb(packageInfo.packageName, true)
+ }
+ }
+
+ //
+ // v0 legacy tests below
+ //
+
+ @Test
+ @Suppress("Deprecation")
+ fun `v0 hasDataForPackage() delegates to plugin`() = runBlocking {
+ val result = Random.nextBoolean()
+
+ coEvery { legacyPlugin.hasDataForPackage(token, packageInfo) } returns result
+
+ assertEquals(result, restore.hasDataForPackage(token, packageInfo))
+ }
+
+ @Test
+ @Suppress("Deprecation")
+ fun `v0 listing records throws`() = runBlocking {
+ restore.initializeState(0x00, token, name, packageInfo)
+
+ coEvery { legacyPlugin.listRecords(token, packageInfo) } throws IOException()
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
}
@Test
- fun `reading VersionHeader with unsupported version throws`() = runBlocking {
- restore.initializeState(VERSION, token, packageInfo)
+ fun `v0 reading VersionHeader with unsupported version throws`() = runBlocking {
+ restore.initializeState(0x00, token, name, packageInfo)
getRecordsAndOutput()
- coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
+ coEvery {
+ legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
+ } returns inputStream
every {
- headerReader.readVersion(inputStream, VERSION)
+ headerReader.readVersion(inputStream, 0x00)
} throws UnsupportedVersionException(unsupportedVersion)
streamsGetClosed()
@@ -84,12 +204,14 @@
}
@Test
- fun `error reading VersionHeader throws`() = runBlocking {
- restore.initializeState(VERSION, token, packageInfo)
+ fun `v0 error reading VersionHeader throws`() = runBlocking {
+ restore.initializeState(0x00, token, name, packageInfo)
getRecordsAndOutput()
- coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
- every { headerReader.readVersion(inputStream, VERSION) } throws IOException()
+ coEvery {
+ legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
+ } returns inputStream
+ every { headerReader.readVersion(inputStream, 0x00) } throws IOException()
streamsGetClosed()
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
@@ -97,13 +219,18 @@
}
@Test
- fun `decrypting stream throws`() = runBlocking {
- restore.initializeState(VERSION, token, packageInfo)
+ @Suppress("deprecation")
+ fun `v0 decrypting stream throws`() = runBlocking {
+ restore.initializeState(0x00, token, name, packageInfo)
getRecordsAndOutput()
- coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
- every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
- every { crypto.newDecryptingStream(inputStream, ad) } throws IOException()
+ coEvery {
+ legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
+ } returns inputStream
+ every { headerReader.readVersion(inputStream, 0x00) } returns 0x00
+ every {
+ crypto.decryptHeader(inputStream, 0x00, packageInfo.packageName, key)
+ } throws IOException()
streamsGetClosed()
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
@@ -111,13 +238,19 @@
}
@Test
- fun `decrypting stream throws security exception`() = runBlocking {
- restore.initializeState(VERSION, token, packageInfo)
+ @Suppress("deprecation")
+ fun `v0 decrypting stream throws security exception`() = runBlocking {
+ restore.initializeState(0x00, token, name, packageInfo)
getRecordsAndOutput()
- coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
- every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
- every { crypto.newDecryptingStream(inputStream, ad) } throws SecurityException()
+ coEvery {
+ legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
+ } returns inputStream
+ every { headerReader.readVersion(inputStream, 0x00) } returns 0x00
+ every {
+ crypto.decryptHeader(inputStream, 0x00, packageInfo.packageName, key)
+ } returns VersionHeader(0x00, packageInfo.packageName, key)
+ every { crypto.decryptMultipleSegments(inputStream) } throws IOException()
streamsGetClosed()
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
@@ -125,14 +258,19 @@
}
@Test
- fun `writing header throws`() = runBlocking {
- restore.initializeState(VERSION, token, packageInfo)
+ @Suppress("Deprecation")
+ fun `v0 writing header throws`() = runBlocking {
+ restore.initializeState(0, token, name, packageInfo)
getRecordsAndOutput()
- coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
- every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
- every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
- every { decryptedInputStream.readBytes() } returns data
+ coEvery {
+ legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
+ } returns inputStream
+ every { headerReader.readVersion(inputStream, 0) } returns 0
+ every {
+ crypto.decryptHeader(inputStream, 0x00, packageInfo.packageName, key)
+ } returns VersionHeader(0x00, packageInfo.packageName, key)
+ every { crypto.decryptMultipleSegments(inputStream) } returns data
every { output.writeEntityHeader(key, data.size) } throws IOException()
streamsGetClosed()
@@ -141,14 +279,19 @@
}
@Test
- fun `writing value throws`() = runBlocking {
- restore.initializeState(VERSION, token, packageInfo)
+ @Suppress("deprecation")
+ fun `v0 writing value throws`() = runBlocking {
+ restore.initializeState(0, token, name, packageInfo)
getRecordsAndOutput()
- coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
- every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
- every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
- every { decryptedInputStream.readBytes() } returns data
+ coEvery {
+ legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
+ } returns inputStream
+ every { headerReader.readVersion(inputStream, 0) } returns 0
+ every {
+ crypto.decryptHeader(inputStream, 0, packageInfo.packageName, key)
+ } returns VersionHeader(0, packageInfo.packageName, key)
+ every { crypto.decryptMultipleSegments(inputStream) } returns data
every { output.writeEntityHeader(key, data.size) } returns 42
every { output.writeEntityData(data, data.size) } throws IOException()
streamsGetClosed()
@@ -158,14 +301,19 @@
}
@Test
- fun `writing value succeeds`() = runBlocking {
- restore.initializeState(VERSION, token, packageInfo)
+ @Suppress("deprecation")
+ fun `v0 writing value succeeds`() = runBlocking {
+ restore.initializeState(0, token, name, packageInfo)
getRecordsAndOutput()
- coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
- every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
- every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
- every { decryptedInputStream.readBytes() } returns data
+ coEvery {
+ legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
+ } returns inputStream
+ every { headerReader.readVersion(inputStream, 0) } returns 0
+ every {
+ crypto.decryptHeader(inputStream, 0, packageInfo.packageName, key)
+ } returns VersionHeader(0, packageInfo.packageName, key)
+ every { crypto.decryptMultipleSegments(inputStream) } returns data
every { output.writeEntityHeader(key, data.size) } returns 42
every { output.writeEntityData(data, data.size) } returns data.size
streamsGetClosed()
@@ -175,14 +323,17 @@
}
@Test
- fun `writing value uses old v0 code`() = runBlocking {
- restore.initializeState(0.toByte(), token, packageInfo)
+ @Suppress("deprecation")
+ fun `v0 writing value uses old v0 code`() = runBlocking {
+ restore.initializeState(0, token, name, packageInfo)
getRecordsAndOutput()
- coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
- every { headerReader.readVersion(inputStream, 0.toByte()) } returns 0.toByte()
+ coEvery {
+ legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
+ } returns inputStream
+ every { headerReader.readVersion(inputStream, 0) } returns 0
every {
- crypto.decryptHeader(inputStream, 0.toByte(), packageInfo.packageName, key)
+ crypto.decryptHeader(inputStream, 0, packageInfo.packageName, key)
} returns VersionHeader(VERSION, packageInfo.packageName, key)
every { crypto.decryptMultipleSegments(inputStream) } returns data
every { output.writeEntityHeader(key, data.size) } returns 42
@@ -194,43 +345,35 @@
}
@Test
- fun `unexpected version aborts with error`() = runBlocking {
- restore.initializeState(Byte.MAX_VALUE, token, packageInfo)
-
- getRecordsAndOutput()
- coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
- every {
- headerReader.readVersion(inputStream, Byte.MAX_VALUE)
- } throws GeneralSecurityException()
- streamsGetClosed()
-
- assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
- verifyStreamWasClosed()
- }
-
- @Test
- fun `writing two values succeeds`() = runBlocking {
+ @Suppress("Deprecation")
+ fun `v0 writing two values succeeds`() = runBlocking {
val data2 = getRandomByteArray()
val inputStream2 = mockk<InputStream>()
- val decryptedInputStream2 = mockk<InputStream>()
- restore.initializeState(VERSION, token, packageInfo)
+ restore.initializeState(0, token, name, packageInfo)
getRecordsAndOutput(listOf(key64, key264))
// first key/value
- coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
- every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
- every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
- every { decryptedInputStream.readBytes() } returns data
+ coEvery {
+ legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
+ } returns inputStream
+ every { headerReader.readVersion(inputStream, 0) } returns 0
+ every {
+ crypto.decryptHeader(inputStream, 0, packageInfo.packageName, key)
+ } returns VersionHeader(0, packageInfo.packageName, key)
+ every { crypto.decryptMultipleSegments(inputStream) } returns data
every { output.writeEntityHeader(key, data.size) } returns 42
every { output.writeEntityData(data, data.size) } returns data.size
// second key/value
- coEvery { plugin.getInputStreamForRecord(token, packageInfo, key264) } returns inputStream2
- every { headerReader.readVersion(inputStream2, VERSION) } returns VERSION
- every { crypto.newDecryptingStream(inputStream2, ad) } returns decryptedInputStream2
- every { decryptedInputStream2.readBytes() } returns data2
+ coEvery {
+ legacyPlugin.getInputStreamForRecord(token, packageInfo, key264)
+ } returns inputStream2
+ every { headerReader.readVersion(inputStream2, 0) } returns 0
+ every {
+ crypto.decryptHeader(inputStream2, 0, packageInfo.packageName, key2)
+ } returns VersionHeader(0, packageInfo.packageName, key2)
+ every { crypto.decryptMultipleSegments(inputStream2) } returns data2
every { output.writeEntityHeader(key2, data2.size) } returns 42
every { output.writeEntityData(data2, data2.size) } returns data2.size
- every { decryptedInputStream2.close() } just Runs
every { inputStream2.close() } just Runs
streamsGetClosed()
@@ -238,12 +381,11 @@
}
private fun getRecordsAndOutput(recordKeys: List<String> = listOf(key64)) {
- coEvery { plugin.listRecords(token, packageInfo) } returns recordKeys
+ coEvery { legacyPlugin.listRecords(token, packageInfo) } returns recordKeys
every { outputFactory.getBackupDataOutput(fileDescriptor) } returns output
}
private fun streamsGetClosed() {
- every { decryptedInputStream.close() } just Runs
every { inputStream.close() } just Runs
every { fileDescriptor.close() } just Runs
}
diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinatorTest.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinatorTest.kt
index 72fe53a..357098f 100644
--- a/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinatorTest.kt
+++ b/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinatorTest.kt
@@ -223,19 +223,20 @@
every { crypto.getNameForPackage(metadata.salt, packageName) } returns name
coEvery { plugin.hasData(token, name) } returns true
- every { kv.initializeState(VERSION, token, packageInfo) } just Runs
+ every { kv.initializeState(VERSION, token, name, packageInfo) } just Runs
val expected = RestoreDescription(packageName, TYPE_KEY_VALUE)
assertEquals(expected, restore.nextRestorePackage())
}
@Test
+ @Suppress("Deprecation")
fun `v0 nextRestorePackage() returns KV description and takes precedence`() = runBlocking {
restore.beforeStartRestore(metadata.copy(version = 0x00))
restore.startRestore(token, packageInfoArray)
coEvery { kv.hasDataForPackage(token, packageInfo) } returns true
- every { kv.initializeState(0x00, token, packageInfo) } just Runs
+ every { kv.initializeState(0x00, token, "", packageInfo) } just Runs
val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE)
assertEquals(expected, restore.nextRestorePackage())
@@ -292,7 +293,7 @@
every { crypto.getNameForPackage(metadata.salt, packageName) } returns name
coEvery { plugin.hasData(token, name) } returns true
- every { kv.initializeState(VERSION, token, packageInfo) } just Runs
+ every { kv.initializeState(VERSION, token, name, packageInfo) } just Runs
val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE)
assertEquals(expected, restore.nextRestorePackage())
@@ -315,7 +316,7 @@
restore.startRestore(token, packageInfoArray2)
coEvery { kv.hasDataForPackage(token, packageInfo) } returns true
- every { kv.initializeState(0.toByte(), token, packageInfo) } just Runs
+ every { kv.initializeState(0.toByte(), token, "", packageInfo) } just Runs
val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE)
assertEquals(expected, restore.nextRestorePackage())
@@ -331,6 +332,7 @@
}
@Test
+ @Suppress("Deprecation")
fun `v0 when kv#hasDataForPackage() throws, it tries next package`() = runBlocking {
restore.beforeStartRestore(metadata.copy(version = 0x00))
restore.startRestore(token, packageInfoArray)
diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreV0IntegrationTest.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreV0IntegrationTest.kt
index 1b06ed7..07d4b5a 100644
--- a/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreV0IntegrationTest.kt
+++ b/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreV0IntegrationTest.kt
@@ -16,6 +16,7 @@
import com.stevesoltys.seedvault.toByteArrayFromHex
import com.stevesoltys.seedvault.transport.TransportTest
import com.stevesoltys.seedvault.transport.backup.BackupPlugin
+import com.stevesoltys.seedvault.transport.backup.KvDbManager
import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager
import io.mockk.coEvery
import io.mockk.every
@@ -44,12 +45,20 @@
private val cipherFactory = CipherFactoryImpl(keyManager)
private val headerReader = HeaderReaderImpl()
private val cryptoImpl = CryptoImpl(keyManager, cipherFactory, headerReader)
+ private val dbManager = mockk<KvDbManager>()
private val metadataReader = MetadataReaderImpl(cryptoImpl)
private val notificationManager = mockk<BackupNotificationManager>()
private val backupPlugin = mockk<BackupPlugin>()
private val kvRestorePlugin = mockk<KVRestorePlugin>()
- private val kvRestore = KVRestore(kvRestorePlugin, outputFactory, headerReader, cryptoImpl)
+ private val kvRestore = KVRestore(
+ backupPlugin,
+ kvRestorePlugin,
+ outputFactory,
+ headerReader,
+ cryptoImpl,
+ dbManager
+ )
private val fullRestorePlugin = mockk<FullRestorePlugin>()
private val fullRestore =
FullRestore(backupPlugin, fullRestorePlugin, outputFactory, headerReader, cryptoImpl)