diff options
3 files changed, 654 insertions, 238 deletions
diff --git a/core/java/android/view/textclassifier/ModelFileManager.java b/core/java/android/view/textclassifier/ModelFileManager.java new file mode 100644 index 000000000000..adea1259b943 --- /dev/null +++ b/core/java/android/view/textclassifier/ModelFileManager.java @@ -0,0 +1,291 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package android.view.textclassifier; + +import static android.view.textclassifier.TextClassifier.DEFAULT_LOG_TAG; + +import android.annotation.Nullable; +import android.os.LocaleList; +import android.os.ParcelFileDescriptor; +import android.text.TextUtils; + +import com.android.internal.annotations.VisibleForTesting; +import com.android.internal.util.Preconditions; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.StringJoiner; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Manages model files that are listed by the model files supplier. + * @hide + */ +@VisibleForTesting(visibility = VisibleForTesting.Visibility.PACKAGE) +public final class ModelFileManager { + private final Object mLock = new Object(); + private final Supplier<List<ModelFile>> mModelFileSupplier; + + private List<ModelFile> mModelFiles; + + public ModelFileManager(Supplier<List<ModelFile>> modelFileSupplier) { + mModelFileSupplier = Preconditions.checkNotNull(modelFileSupplier); + } + + /** + * Returns an unmodifiable list of model files listed by the given model files supplier. + * <p> + * The result is cached. + */ + public List<ModelFile> listModelFiles() { + synchronized (mLock) { + if (mModelFiles == null) { + mModelFiles = Collections.unmodifiableList(mModelFileSupplier.get()); + } + return mModelFiles; + } + } + + /** + * Returns the best model file for the given localelist, {@code null} if nothing is found. + * + * @param localeList the required locales, use {@code null} if there is no preference. + */ + public ModelFile findBestModelFile(@Nullable LocaleList localeList) { + // Specified localeList takes priority over the system default, so it is listed first. + final String languages = localeList == null || localeList.isEmpty() + ? LocaleList.getDefault().toLanguageTags() + : localeList.toLanguageTags() + "," + LocaleList.getDefault().toLanguageTags(); + final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages); + + ModelFile bestModel = null; + for (ModelFile model : listModelFiles()) { + if (model.isAnyLanguageSupported(languageRangeList)) { + if (model.isPreferredTo(bestModel)) { + bestModel = model; + } + } + } + return bestModel; + } + + /** + * Default implementation of the model file supplier. + */ + public static final class ModelFileSupplierImpl implements Supplier<List<ModelFile>> { + private final File mUpdatedModelFile; + private final File mFactoryModelDir; + private final Pattern mModelFilenamePattern; + private final Function<Integer, Integer> mVersionSupplier; + private final Function<Integer, String> mSupportedLocalesSupplier; + + public ModelFileSupplierImpl( + File factoryModelDir, + String factoryModelFileNameRegex, + File updatedModelFile, + Function<Integer, Integer> versionSupplier, + Function<Integer, String> supportedLocalesSupplier) { + mUpdatedModelFile = Preconditions.checkNotNull(updatedModelFile); + mFactoryModelDir = Preconditions.checkNotNull(factoryModelDir); + mModelFilenamePattern = Pattern.compile( + Preconditions.checkNotNull(factoryModelFileNameRegex)); + mVersionSupplier = Preconditions.checkNotNull(versionSupplier); + mSupportedLocalesSupplier = Preconditions.checkNotNull(supportedLocalesSupplier); + } + + @Override + public List<ModelFile> get() { + final List<ModelFile> modelFiles = new ArrayList<>(); + // The update model has the highest precedence. + if (mUpdatedModelFile.exists()) { + final ModelFile updatedModel = createModelFile(mUpdatedModelFile); + if (updatedModel != null) { + modelFiles.add(updatedModel); + } + } + // Factory models should never have overlapping locales, so the order doesn't matter. + if (mFactoryModelDir.exists() && mFactoryModelDir.isDirectory()) { + final File[] files = mFactoryModelDir.listFiles(); + for (File file : files) { + final Matcher matcher = mModelFilenamePattern.matcher(file.getName()); + if (matcher.matches() && file.isFile()) { + final ModelFile model = createModelFile(file); + if (model != null) { + modelFiles.add(model); + } + } + } + } + return modelFiles; + } + + /** Returns null if the path did not point to a compatible model. */ + @Nullable + private ModelFile createModelFile(File file) { + if (!file.exists()) { + return null; + } + ParcelFileDescriptor modelFd = null; + try { + modelFd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY); + if (modelFd == null) { + return null; + } + final int modelFdInt = modelFd.getFd(); + final int version = mVersionSupplier.apply(modelFdInt); + final String supportedLocalesStr = mSupportedLocalesSupplier.apply(modelFdInt); + if (supportedLocalesStr.isEmpty()) { + Log.d(DEFAULT_LOG_TAG, "Ignoring " + file.getAbsolutePath()); + return null; + } + final List<Locale> supportedLocales = new ArrayList<>(); + for (String langTag : supportedLocalesStr.split(",")) { + supportedLocales.add(Locale.forLanguageTag(langTag)); + } + return new ModelFile( + file, + version, + supportedLocales, + ModelFile.LANGUAGE_INDEPENDENT.equals(supportedLocalesStr)); + } catch (FileNotFoundException e) { + Log.e(DEFAULT_LOG_TAG, "Failed to find " + file.getAbsolutePath(), e); + return null; + } finally { + maybeCloseAndLogError(modelFd); + } + } + + /** + * Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur. + */ + private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) { + if (fd == null) { + return; + } + try { + fd.close(); + } catch (IOException e) { + Log.e(DEFAULT_LOG_TAG, "Error closing file.", e); + } + } + + } + + /** + * Describes TextClassifier model files on disk. + */ + public static final class ModelFile { + public static final String LANGUAGE_INDEPENDENT = "*"; + + private final File mFile; + private final int mVersion; + private final List<Locale> mSupportedLocales; + private final boolean mLanguageIndependent; + + public ModelFile(File file, int version, List<Locale> supportedLocales, + boolean languageIndependent) { + mFile = Preconditions.checkNotNull(file); + mVersion = version; + mSupportedLocales = Preconditions.checkNotNull(supportedLocales); + mLanguageIndependent = languageIndependent; + } + + /** Returns the absolute path to the model file. */ + public String getPath() { + return mFile.getAbsolutePath(); + } + + /** Returns a name to use for id generation, effectively the name of the model file. */ + public String getName() { + return mFile.getName(); + } + + /** Returns the version tag in the model's metadata. */ + public int getVersion() { + return mVersion; + } + + /** Returns whether the language supports any language in the given ranges. */ + public boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) { + Preconditions.checkNotNull(languageRanges); + return mLanguageIndependent || Locale.lookup(languageRanges, mSupportedLocales) != null; + } + + /** Returns an immutable lists of supported locales. */ + public List<Locale> getSupportedLocales() { + return Collections.unmodifiableList(mSupportedLocales); + } + + /** + * Returns if this model file is preferred to the given one. + */ + public boolean isPreferredTo(@Nullable ModelFile model) { + // A model is preferred to no model. + if (model == null) { + return true; + } + + // A language-specific model is preferred to a language independent + // model. + if (!mLanguageIndependent && model.mLanguageIndependent) { + return true; + } + + // A higher-version model is preferred. + if (mVersion > model.getVersion()) { + return true; + } + return false; + } + + @Override + public int hashCode() { + return Objects.hash(getPath()); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (other instanceof ModelFile) { + final ModelFile otherModel = (ModelFile) other; + return TextUtils.equals(getPath(), otherModel.getPath()); + } + return false; + } + + @Override + public String toString() { + final StringJoiner localesJoiner = new StringJoiner(","); + for (Locale locale : mSupportedLocales) { + localesJoiner.add(locale.toLanguageTag()); + } + return String.format(Locale.US, + "ModelFile { path=%s name=%s version=%d locales=%s }", + getPath(), getName(), mVersion, localesJoiner.toString()); + } + } +} diff --git a/core/java/android/view/textclassifier/TextClassifierImpl.java b/core/java/android/view/textclassifier/TextClassifierImpl.java index 7f1e443f4aa5..159bfaa2ab26 100644 --- a/core/java/android/view/textclassifier/TextClassifierImpl.java +++ b/core/java/android/view/textclassifier/TextClassifierImpl.java @@ -58,16 +58,12 @@ import java.time.ZonedDateTime; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Objects; -import java.util.StringJoiner; import java.util.concurrent.TimeUnit; -import java.util.regex.Matcher; -import java.util.regex.Pattern; /** * Default implementation of the {@link TextClassifier} interface. @@ -81,13 +77,18 @@ import java.util.regex.Pattern; public final class TextClassifierImpl implements TextClassifier { private static final String LOG_TAG = DEFAULT_LOG_TAG; - private static final String MODEL_DIR = "/etc/textclassifier/"; - private static final String MODEL_FILE_REGEX = "textclassifier\\.(.*)\\.model"; - private static final String UPDATED_MODEL_FILE_PATH = - "/data/misc/textclassifier/textclassifier.model"; - private static final String LANG_ID_MODEL_FILE_PATH = "/etc/textclassifier/lang_id.model"; - private static final String UPDATED_LANG_ID_MODEL_FILE_PATH = - "/data/misc/textclassifier/lang_id.model"; + + private static final File FACTORY_MODEL_DIR = new File("/etc/textclassifier/"); + // Annotator + private static final String ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX = + "textclassifier\\.(.*)\\.model"; + private static final File ANNOTATOR_UPDATED_MODEL_FILE = + new File("/data/misc/textclassifier/textclassifier.model"); + + // LangID + private static final String LANG_ID_FACTORY_MODEL_FILENAME_REGEX = "lang_id.model"; + private static final File UPDATED_LANG_ID_MODEL_FILE = + new File("/data/misc/textclassifier/lang_id.model"); private final Context mContext; private final TextClassifier mFallback; @@ -95,9 +96,7 @@ public final class TextClassifierImpl implements TextClassifier { private final Object mLock = new Object(); @GuardedBy("mLock") // Do not access outside this lock. - private List<ModelFile> mAllModelFiles; - @GuardedBy("mLock") // Do not access outside this lock. - private ModelFile mModel; + private ModelFileManager.ModelFile mAnnotatorModelInUse; @GuardedBy("mLock") // Do not access outside this lock. private AnnotatorModel mAnnotatorImpl; @GuardedBy("mLock") // Do not access outside this lock. @@ -109,12 +108,29 @@ public final class TextClassifierImpl implements TextClassifier { private final TextClassificationConstants mSettings; + private final ModelFileManager mAnnotatorModelFileManager; + private final ModelFileManager mLangIdModelFileManager; + public TextClassifierImpl( Context context, TextClassificationConstants settings, TextClassifier fallback) { mContext = Preconditions.checkNotNull(context); mFallback = Preconditions.checkNotNull(fallback); mSettings = Preconditions.checkNotNull(settings); mGenerateLinksLogger = new GenerateLinksLogger(mSettings.getGenerateLinksLogSampleRate()); + mAnnotatorModelFileManager = new ModelFileManager( + new ModelFileManager.ModelFileSupplierImpl( + FACTORY_MODEL_DIR, + ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX, + ANNOTATOR_UPDATED_MODEL_FILE, + AnnotatorModel::getVersion, + AnnotatorModel::getLocales)); + mLangIdModelFileManager = new ModelFileManager( + new ModelFileManager.ModelFileSupplierImpl( + FACTORY_MODEL_DIR, + LANG_ID_FACTORY_MODEL_FILENAME_REGEX, + UPDATED_LANG_ID_MODEL_FILE, + fd -> -1, // TODO: Replace this with LangIdModel.getVersion(fd) + fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT)); } public TextClassifierImpl(Context context, TextClassificationConstants settings) { @@ -334,22 +350,24 @@ public final class TextClassifierImpl implements TextClassifier { throws FileNotFoundException { synchronized (mLock) { localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList; - final ModelFile bestModel = findBestModelLocked(localeList); + final ModelFileManager.ModelFile bestModel = + mAnnotatorModelFileManager.findBestModelFile(localeList); if (bestModel == null) { - throw new FileNotFoundException("No model for " + localeList.toLanguageTags()); + throw new FileNotFoundException( + "No annotator model for " + localeList.toLanguageTags()); } - if (mAnnotatorImpl == null || !Objects.equals(mModel, bestModel)) { + if (mAnnotatorImpl == null || !Objects.equals(mAnnotatorModelInUse, bestModel)) { Log.d(DEFAULT_LOG_TAG, "Loading " + bestModel); destroyAnnotatorImplIfExistsLocked(); - final ParcelFileDescriptor fd = ParcelFileDescriptor.open( + final ParcelFileDescriptor pfd = ParcelFileDescriptor.open( new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY); try { - if (fd != null) { - mAnnotatorImpl = new AnnotatorModel(fd.getFd()); - mModel = bestModel; + if (pfd != null) { + mAnnotatorImpl = new AnnotatorModel(pfd.getFd()); + mAnnotatorModelInUse = bestModel; } } finally { - maybeCloseAndLogError(fd); + maybeCloseAndLogError(pfd); } } return mAnnotatorImpl; @@ -367,40 +385,19 @@ public final class TextClassifierImpl implements TextClassifier { private LangIdModel getLangIdImpl() throws FileNotFoundException { synchronized (mLock) { if (mLangIdImpl == null) { - ParcelFileDescriptor factoryFd = null; - ParcelFileDescriptor updateFd = null; + final ModelFileManager.ModelFile bestModel = + mLangIdModelFileManager.findBestModelFile(LocaleList.getEmptyLocaleList()); + if (bestModel == null) { + throw new FileNotFoundException("No LangID model is found"); + } + final ParcelFileDescriptor pfd = ParcelFileDescriptor.open( + new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY); try { - int factoryVersion = -1; - int updateVersion = factoryVersion; - final File factoryFile = new File(LANG_ID_MODEL_FILE_PATH); - if (factoryFile.exists()) { - factoryFd = ParcelFileDescriptor.open( - factoryFile, ParcelFileDescriptor.MODE_READ_ONLY); - // TODO: Uncomment when method is implemented: - // if (factoryFd != null) { - // factoryVersion = LangIdModel.getVersion(factoryFd.getFd()); - // } - } - final File updateFile = new File(UPDATED_LANG_ID_MODEL_FILE_PATH); - if (updateFile.exists()) { - updateFd = ParcelFileDescriptor.open( - updateFile, ParcelFileDescriptor.MODE_READ_ONLY); - // TODO: Uncomment when method is implemented: - // if (updateFd != null) { - // updateVersion = LangIdModel.getVersion(updateFd.getFd()); - // } - } - - if (updateVersion > factoryVersion) { - mLangIdImpl = new LangIdModel(updateFd.getFd()); - } else if (factoryFd != null) { - mLangIdImpl = new LangIdModel(factoryFd.getFd()); - } else { - throw new FileNotFoundException("Language detection model not found"); + if (pfd != null) { + mLangIdImpl = new LangIdModel(pfd.getFd()); } } finally { - maybeCloseAndLogError(factoryFd); - maybeCloseAndLogError(updateFd); + maybeCloseAndLogError(pfd); } } return mLangIdImpl; @@ -409,8 +406,9 @@ public final class TextClassifierImpl implements TextClassifier { private String createId(String text, int start, int end) { synchronized (mLock) { - return SelectionSessionLogger.createId(text, start, end, mContext, mModel.getVersion(), - mModel.getSupportedLocales()); + return SelectionSessionLogger.createId(text, start, end, mContext, + mAnnotatorModelInUse.getVersion(), + mAnnotatorModelInUse.getSupportedLocales()); } } @@ -418,66 +416,6 @@ public final class TextClassifierImpl implements TextClassifier { return (locales == null) ? "" : locales.toLanguageTags(); } - /** - * Finds the most appropriate model to use for the given target locale list. - * - * The basic logic is: we ignore all models that don't support any of the target locales. For - * the remaining candidates, we take the update model unless its version number is lower than - * the factory version. It's assumed that factory models do not have overlapping locale ranges - * and conflict resolution between these models hence doesn't matter. - */ - @GuardedBy("mLock") // Do not call outside this lock. - @Nullable - private ModelFile findBestModelLocked(LocaleList localeList) { - // Specified localeList takes priority over the system default, so it is listed first. - final String languages = localeList.isEmpty() - ? LocaleList.getDefault().toLanguageTags() - : localeList.toLanguageTags() + "," + LocaleList.getDefault().toLanguageTags(); - final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages); - - ModelFile bestModel = null; - for (ModelFile model : listAllModelsLocked()) { - if (model.isAnyLanguageSupported(languageRangeList)) { - if (model.isPreferredTo(bestModel)) { - bestModel = model; - } - } - } - return bestModel; - } - - /** Returns a list of all model files available, in order of precedence. */ - @GuardedBy("mLock") // Do not call outside this lock. - private List<ModelFile> listAllModelsLocked() { - if (mAllModelFiles == null) { - final List<ModelFile> allModels = new ArrayList<>(); - // The update model has the highest precedence. - if (new File(UPDATED_MODEL_FILE_PATH).exists()) { - final ModelFile updatedModel = ModelFile.fromPath(UPDATED_MODEL_FILE_PATH); - if (updatedModel != null) { - allModels.add(updatedModel); - } - } - // Factory models should never have overlapping locales, so the order doesn't matter. - final File modelsDir = new File(MODEL_DIR); - if (modelsDir.exists() && modelsDir.isDirectory()) { - final File[] modelFiles = modelsDir.listFiles(); - final Pattern modelFilenamePattern = Pattern.compile(MODEL_FILE_REGEX); - for (File modelFile : modelFiles) { - final Matcher matcher = modelFilenamePattern.matcher(modelFile.getName()); - if (matcher.matches() && modelFile.isFile()) { - final ModelFile model = ModelFile.fromPath(modelFile.getAbsolutePath()); - if (model != null) { - allModels.add(model); - } - } - } - } - mAllModelFiles = allModels; - } - return mAllModelFiles; - } - private TextClassification createClassificationResult( AnnotatorModel.ClassificationResult[] classifications, String text, int start, int end, @Nullable Instant referenceTime) { @@ -523,12 +461,18 @@ public final class TextClassifierImpl implements TextClassifier { @Override public void dump(@NonNull IndentingPrintWriter printWriter) { synchronized (mLock) { - listAllModelsLocked(); printWriter.println("TextClassifierImpl:"); printWriter.increaseIndent(); - printWriter.println("Model file(s):"); + printWriter.println("Annotator model file(s):"); printWriter.increaseIndent(); - for (ModelFile modelFile : mAllModelFiles) { + for (ModelFileManager.ModelFile modelFile : + mAnnotatorModelFileManager.listModelFiles()) { + printWriter.println(modelFile.toString()); + } + printWriter.decreaseIndent(); + printWriter.println("LangID model file(s):"); + for (ModelFileManager.ModelFile modelFile : + mLangIdModelFileManager.listModelFiles()) { printWriter.println(modelFile.toString()); } printWriter.decreaseIndent(); @@ -554,126 +498,6 @@ public final class TextClassifierImpl implements TextClassifier { } /** - * Describes TextClassifier model files on disk. - */ - private static final class ModelFile { - - private final String mPath; - private final String mName; - private final int mVersion; - private final List<Locale> mSupportedLocales; - private final boolean mLanguageIndependent; - - /** Returns null if the path did not point to a compatible model. */ - static @Nullable ModelFile fromPath(String path) { - final File file = new File(path); - if (!file.exists()) { - return null; - } - ParcelFileDescriptor modelFd = null; - try { - modelFd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY); - if (modelFd == null) { - return null; - } - final int version = AnnotatorModel.getVersion(modelFd.getFd()); - final String supportedLocalesStr = AnnotatorModel.getLocales(modelFd.getFd()); - if (supportedLocalesStr.isEmpty()) { - Log.d(DEFAULT_LOG_TAG, "Ignoring " + file.getAbsolutePath()); - return null; - } - final boolean languageIndependent = supportedLocalesStr.equals("*"); - final List<Locale> supportedLocales = new ArrayList<>(); - for (String langTag : supportedLocalesStr.split(",")) { - supportedLocales.add(Locale.forLanguageTag(langTag)); - } - return new ModelFile(path, file.getName(), version, supportedLocales, - languageIndependent); - } catch (FileNotFoundException e) { - Log.e(DEFAULT_LOG_TAG, "Failed to peek " + file.getAbsolutePath(), e); - return null; - } finally { - maybeCloseAndLogError(modelFd); - } - } - - /** The absolute path to the model file. */ - String getPath() { - return mPath; - } - - /** A name to use for id generation. Effectively the name of the model file. */ - String getName() { - return mName; - } - - /** Returns the version tag in the model's metadata. */ - int getVersion() { - return mVersion; - } - - /** Returns whether the language supports any language in the given ranges. */ - boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) { - return mLanguageIndependent || Locale.lookup(languageRanges, mSupportedLocales) != null; - } - - /** All locales supported by the model. */ - List<Locale> getSupportedLocales() { - return Collections.unmodifiableList(mSupportedLocales); - } - - public boolean isPreferredTo(ModelFile model) { - // A model is preferred to no model. - if (model == null) { - return true; - } - - // A language-specific model is preferred to a language independent - // model. - if (!mLanguageIndependent && model.mLanguageIndependent) { - return true; - } - - // A higher-version model is preferred. - if (getVersion() > model.getVersion()) { - return true; - } - return false; - } - - @Override - public boolean equals(Object other) { - if (this == other) { - return true; - } - if (other instanceof ModelFile) { - final ModelFile otherModel = (ModelFile) other; - return mPath.equals(otherModel.mPath); - } - return false; - } - - @Override - public String toString() { - final StringJoiner localesJoiner = new StringJoiner(","); - for (Locale locale : mSupportedLocales) { - localesJoiner.add(locale.toLanguageTag()); - } - return String.format(Locale.US, "ModelFile { path=%s name=%s version=%d locales=%s }", - mPath, mName, mVersion, localesJoiner.toString()); - } - - private ModelFile(String path, String name, int version, List<Locale> supportedLocales, - boolean languageIndependent) { - mPath = path; - mName = name; - mVersion = version; - mSupportedLocales = supportedLocales; - mLanguageIndependent = languageIndependent; - } - } - - /** * Helper class to store the information from which RemoteActions are built. */ private static final class LabeledIntent { diff --git a/core/tests/coretests/src/android/view/textclassifier/ModelFileManagerTest.java b/core/tests/coretests/src/android/view/textclassifier/ModelFileManagerTest.java new file mode 100644 index 000000000000..51e5aec8b219 --- /dev/null +++ b/core/tests/coretests/src/android/view/textclassifier/ModelFileManagerTest.java @@ -0,0 +1,301 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package android.view.textclassifier; + +import static com.google.common.truth.Truth.assertThat; + +import static org.mockito.Mockito.when; + +import android.os.LocaleList; +import android.support.test.InstrumentationRegistry; +import android.support.test.filters.SmallTest; +import android.support.test.runner.AndroidJUnit4; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.File; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +@SmallTest +@RunWith(AndroidJUnit4.class) +public class ModelFileManagerTest { + + @Mock + private Supplier<List<ModelFileManager.ModelFile>> mModelFileSupplier; + private ModelFileManager.ModelFileSupplierImpl mModelFileSupplierImpl; + private ModelFileManager mModelFileManager; + private File mRootTestDir; + private File mFactoryModelDir; + private File mUpdatedModelFile; + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + mModelFileManager = new ModelFileManager(mModelFileSupplier); + mRootTestDir = InstrumentationRegistry.getContext().getCacheDir(); + mFactoryModelDir = new File(mRootTestDir, "factory"); + mUpdatedModelFile = new File(mRootTestDir, "updated.model"); + + mModelFileSupplierImpl = + new ModelFileManager.ModelFileSupplierImpl( + mFactoryModelDir, + "test\\d.model", + mUpdatedModelFile, + fd -> 1, + fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT + ); + + mRootTestDir.mkdirs(); + mFactoryModelDir.mkdirs(); + } + + @After + public void removeTestDir() { + recursiveDelete(mRootTestDir); + } + + @Test + public void get() { + ModelFileManager.ModelFile modelFile = + new ModelFileManager.ModelFile( + new File("/path/a"), 1, Collections.emptyList(), true); + when(mModelFileSupplier.get()).thenReturn(Collections.singletonList(modelFile)); + + List<ModelFileManager.ModelFile> modelFiles = mModelFileManager.listModelFiles(); + + assertThat(modelFiles).hasSize(1); + assertThat(modelFiles.get(0)).isEqualTo(modelFile); + } + + @Test + public void findBestModel_versionCode() { + ModelFileManager.ModelFile olderModelFile = + new ModelFileManager.ModelFile( + new File("/path/a"), 1, + Collections.emptyList(), true); + + ModelFileManager.ModelFile newerModelFile = + new ModelFileManager.ModelFile( + new File("/path/b"), 2, + Collections.emptyList(), true); + when(mModelFileSupplier.get()) + .thenReturn(Arrays.asList(olderModelFile, newerModelFile)); + + ModelFileManager.ModelFile bestModelFile = + mModelFileManager.findBestModelFile(LocaleList.getEmptyLocaleList()); + + assertThat(bestModelFile).isEqualTo(newerModelFile); + } + + @Test + public void findBestModel_languageDependentModelIsPreferred() { + Locale locale = Locale.forLanguageTag("ja"); + ModelFileManager.ModelFile languageIndependentModelFile = + new ModelFileManager.ModelFile( + new File("/path/a"), 1, + Collections.emptyList(), true); + + ModelFileManager.ModelFile languageDependentModelFile = + new ModelFileManager.ModelFile( + new File("/path/b"), 1, + Collections.singletonList(locale), false); + when(mModelFileSupplier.get()) + .thenReturn( + Arrays.asList(languageIndependentModelFile, languageDependentModelFile)); + + ModelFileManager.ModelFile bestModelFile = + mModelFileManager.findBestModelFile( + LocaleList.forLanguageTags(locale.toLanguageTag())); + assertThat(bestModelFile).isEqualTo(languageDependentModelFile); + } + + @Test + public void findBestModel_useIndependentWhenNoLanguageModelMatch() { + Locale locale = Locale.forLanguageTag("ja"); + ModelFileManager.ModelFile languageIndependentModelFile = + new ModelFileManager.ModelFile( + new File("/path/a"), 1, + Collections.emptyList(), true); + + ModelFileManager.ModelFile languageDependentModelFile = + new ModelFileManager.ModelFile( + new File("/path/b"), 1, + Collections.singletonList(locale), false); + + when(mModelFileSupplier.get()) + .thenReturn( + Arrays.asList(languageIndependentModelFile, languageDependentModelFile)); + + ModelFileManager.ModelFile bestModelFile = + mModelFileManager.findBestModelFile( + LocaleList.forLanguageTags("zh-hk")); + assertThat(bestModelFile).isEqualTo(languageIndependentModelFile); + } + + @Test + public void findBestModel_languageIsMoreImportantThanVersion() { + ModelFileManager.ModelFile matchButOlderModel = + new ModelFileManager.ModelFile( + new File("/path/a"), 1, + Collections.singletonList(Locale.forLanguageTag("fr")), false); + + ModelFileManager.ModelFile mismatchButNewerModel = + new ModelFileManager.ModelFile( + new File("/path/b"), 2, + Collections.singletonList(Locale.forLanguageTag("ja")), false); + + when(mModelFileSupplier.get()) + .thenReturn( + Arrays.asList(matchButOlderModel, mismatchButNewerModel)); + + ModelFileManager.ModelFile bestModelFile = + mModelFileManager.findBestModelFile( + LocaleList.forLanguageTags("fr")); + assertThat(bestModelFile).isEqualTo(matchButOlderModel); + } + + @Test + public void modelFileEquals() { + ModelFileManager.ModelFile modelA = + new ModelFileManager.ModelFile( + new File("/path/a"), 1, + Collections.singletonList(Locale.forLanguageTag("ja")), false); + + ModelFileManager.ModelFile modelB = + new ModelFileManager.ModelFile( + new File("/path/a"), 1, + Collections.singletonList(Locale.forLanguageTag("ja")), false); + + assertThat(modelA).isEqualTo(modelB); + } + + @Test + public void modelFile_different() { + ModelFileManager.ModelFile modelA = + new ModelFileManager.ModelFile( + new File("/path/a"), 1, + Collections.singletonList(Locale.forLanguageTag("ja")), false); + + ModelFileManager.ModelFile modelB = + new ModelFileManager.ModelFile( + new File("/path/b"), 1, + Collections.singletonList(Locale.forLanguageTag("ja")), false); + + assertThat(modelA).isNotEqualTo(modelB); + } + + + @Test + public void modelFile_getPath() { + ModelFileManager.ModelFile modelA = + new ModelFileManager.ModelFile( + new File("/path/a"), 1, + Collections.singletonList(Locale.forLanguageTag("ja")), false); + + assertThat(modelA.getPath()).isEqualTo("/path/a"); + } + + @Test + public void modelFile_getName() { + ModelFileManager.ModelFile modelA = + new ModelFileManager.ModelFile( + new File("/path/a"), 1, + Collections.singletonList(Locale.forLanguageTag("ja")), false); + + assertThat(modelA.getName()).isEqualTo("a"); + } + + @Test + public void modelFile_isPreferredTo_languageDependentIsBetter() { + ModelFileManager.ModelFile modelA = + new ModelFileManager.ModelFile( + new File("/path/a"), 1, + Collections.singletonList(Locale.forLanguageTag("ja")), false); + + ModelFileManager.ModelFile modelB = + new ModelFileManager.ModelFile( + new File("/path/b"), 2, + Collections.emptyList(), true); + + assertThat(modelA.isPreferredTo(modelB)).isTrue(); + } + + @Test + public void modelFile_isPreferredTo_version() { + ModelFileManager.ModelFile modelA = + new ModelFileManager.ModelFile( + new File("/path/a"), 2, + Collections.singletonList(Locale.forLanguageTag("ja")), false); + + ModelFileManager.ModelFile modelB = + new ModelFileManager.ModelFile( + new File("/path/b"), 1, + Collections.emptyList(), false); + + assertThat(modelA.isPreferredTo(modelB)).isTrue(); + } + + @Test + public void testFileSupplierImpl_updatedFileOnly() throws IOException { + mUpdatedModelFile.createNewFile(); + File model1 = new File(mFactoryModelDir, "test1.model"); + model1.createNewFile(); + File model2 = new File(mFactoryModelDir, "test2.model"); + model2.createNewFile(); + new File(mFactoryModelDir, "not_match_regex.model").createNewFile(); + + List<ModelFileManager.ModelFile> modelFiles = mModelFileSupplierImpl.get(); + List<String> modelFilePaths = + modelFiles + .stream() + .map(modelFile -> modelFile.getPath()) + .collect(Collectors.toList()); + + assertThat(modelFiles).hasSize(3); + assertThat(modelFilePaths).containsExactly( + mUpdatedModelFile.getAbsolutePath(), + model1.getAbsolutePath(), + model2.getAbsolutePath()); + } + + @Test + public void testFileSupplierImpl_empty() { + mFactoryModelDir.delete(); + List<ModelFileManager.ModelFile> modelFiles = mModelFileSupplierImpl.get(); + + assertThat(modelFiles).hasSize(0); + } + + private static void recursiveDelete(File f) { + if (f.isDirectory()) { + for (File innerFile : f.listFiles()) { + recursiveDelete(innerFile); + } + } + f.delete(); + } +} |