summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/java/android/view/textclassifier/ModelFileManager.java291
-rw-r--r--core/java/android/view/textclassifier/TextClassifierImpl.java300
-rw-r--r--core/tests/coretests/src/android/view/textclassifier/ModelFileManagerTest.java301
3 files changed, 654 insertions, 238 deletions
diff --git a/core/java/android/view/textclassifier/ModelFileManager.java b/core/java/android/view/textclassifier/ModelFileManager.java
new file mode 100644
index 000000000000..adea1259b943
--- /dev/null
+++ b/core/java/android/view/textclassifier/ModelFileManager.java
@@ -0,0 +1,291 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package android.view.textclassifier;
+
+import static android.view.textclassifier.TextClassifier.DEFAULT_LOG_TAG;
+
+import android.annotation.Nullable;
+import android.os.LocaleList;
+import android.os.ParcelFileDescriptor;
+import android.text.TextUtils;
+
+import com.android.internal.annotations.VisibleForTesting;
+import com.android.internal.util.Preconditions;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Locale;
+import java.util.Objects;
+import java.util.StringJoiner;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+/**
+ * Manages model files that are listed by the model files supplier.
+ * @hide
+ */
+@VisibleForTesting(visibility = VisibleForTesting.Visibility.PACKAGE)
+public final class ModelFileManager {
+ private final Object mLock = new Object();
+ private final Supplier<List<ModelFile>> mModelFileSupplier;
+
+ private List<ModelFile> mModelFiles;
+
+ public ModelFileManager(Supplier<List<ModelFile>> modelFileSupplier) {
+ mModelFileSupplier = Preconditions.checkNotNull(modelFileSupplier);
+ }
+
+ /**
+ * Returns an unmodifiable list of model files listed by the given model files supplier.
+ * <p>
+ * The result is cached.
+ */
+ public List<ModelFile> listModelFiles() {
+ synchronized (mLock) {
+ if (mModelFiles == null) {
+ mModelFiles = Collections.unmodifiableList(mModelFileSupplier.get());
+ }
+ return mModelFiles;
+ }
+ }
+
+ /**
+ * Returns the best model file for the given localelist, {@code null} if nothing is found.
+ *
+ * @param localeList the required locales, use {@code null} if there is no preference.
+ */
+ public ModelFile findBestModelFile(@Nullable LocaleList localeList) {
+ // Specified localeList takes priority over the system default, so it is listed first.
+ final String languages = localeList == null || localeList.isEmpty()
+ ? LocaleList.getDefault().toLanguageTags()
+ : localeList.toLanguageTags() + "," + LocaleList.getDefault().toLanguageTags();
+ final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);
+
+ ModelFile bestModel = null;
+ for (ModelFile model : listModelFiles()) {
+ if (model.isAnyLanguageSupported(languageRangeList)) {
+ if (model.isPreferredTo(bestModel)) {
+ bestModel = model;
+ }
+ }
+ }
+ return bestModel;
+ }
+
+ /**
+ * Default implementation of the model file supplier.
+ */
+ public static final class ModelFileSupplierImpl implements Supplier<List<ModelFile>> {
+ private final File mUpdatedModelFile;
+ private final File mFactoryModelDir;
+ private final Pattern mModelFilenamePattern;
+ private final Function<Integer, Integer> mVersionSupplier;
+ private final Function<Integer, String> mSupportedLocalesSupplier;
+
+ public ModelFileSupplierImpl(
+ File factoryModelDir,
+ String factoryModelFileNameRegex,
+ File updatedModelFile,
+ Function<Integer, Integer> versionSupplier,
+ Function<Integer, String> supportedLocalesSupplier) {
+ mUpdatedModelFile = Preconditions.checkNotNull(updatedModelFile);
+ mFactoryModelDir = Preconditions.checkNotNull(factoryModelDir);
+ mModelFilenamePattern = Pattern.compile(
+ Preconditions.checkNotNull(factoryModelFileNameRegex));
+ mVersionSupplier = Preconditions.checkNotNull(versionSupplier);
+ mSupportedLocalesSupplier = Preconditions.checkNotNull(supportedLocalesSupplier);
+ }
+
+ @Override
+ public List<ModelFile> get() {
+ final List<ModelFile> modelFiles = new ArrayList<>();
+ // The update model has the highest precedence.
+ if (mUpdatedModelFile.exists()) {
+ final ModelFile updatedModel = createModelFile(mUpdatedModelFile);
+ if (updatedModel != null) {
+ modelFiles.add(updatedModel);
+ }
+ }
+ // Factory models should never have overlapping locales, so the order doesn't matter.
+ if (mFactoryModelDir.exists() && mFactoryModelDir.isDirectory()) {
+ final File[] files = mFactoryModelDir.listFiles();
+ for (File file : files) {
+ final Matcher matcher = mModelFilenamePattern.matcher(file.getName());
+ if (matcher.matches() && file.isFile()) {
+ final ModelFile model = createModelFile(file);
+ if (model != null) {
+ modelFiles.add(model);
+ }
+ }
+ }
+ }
+ return modelFiles;
+ }
+
+ /** Returns null if the path did not point to a compatible model. */
+ @Nullable
+ private ModelFile createModelFile(File file) {
+ if (!file.exists()) {
+ return null;
+ }
+ ParcelFileDescriptor modelFd = null;
+ try {
+ modelFd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
+ if (modelFd == null) {
+ return null;
+ }
+ final int modelFdInt = modelFd.getFd();
+ final int version = mVersionSupplier.apply(modelFdInt);
+ final String supportedLocalesStr = mSupportedLocalesSupplier.apply(modelFdInt);
+ if (supportedLocalesStr.isEmpty()) {
+ Log.d(DEFAULT_LOG_TAG, "Ignoring " + file.getAbsolutePath());
+ return null;
+ }
+ final List<Locale> supportedLocales = new ArrayList<>();
+ for (String langTag : supportedLocalesStr.split(",")) {
+ supportedLocales.add(Locale.forLanguageTag(langTag));
+ }
+ return new ModelFile(
+ file,
+ version,
+ supportedLocales,
+ ModelFile.LANGUAGE_INDEPENDENT.equals(supportedLocalesStr));
+ } catch (FileNotFoundException e) {
+ Log.e(DEFAULT_LOG_TAG, "Failed to find " + file.getAbsolutePath(), e);
+ return null;
+ } finally {
+ maybeCloseAndLogError(modelFd);
+ }
+ }
+
+ /**
+ * Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur.
+ */
+ private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
+ if (fd == null) {
+ return;
+ }
+ try {
+ fd.close();
+ } catch (IOException e) {
+ Log.e(DEFAULT_LOG_TAG, "Error closing file.", e);
+ }
+ }
+
+ }
+
+ /**
+ * Describes TextClassifier model files on disk.
+ */
+ public static final class ModelFile {
+ public static final String LANGUAGE_INDEPENDENT = "*";
+
+ private final File mFile;
+ private final int mVersion;
+ private final List<Locale> mSupportedLocales;
+ private final boolean mLanguageIndependent;
+
+ public ModelFile(File file, int version, List<Locale> supportedLocales,
+ boolean languageIndependent) {
+ mFile = Preconditions.checkNotNull(file);
+ mVersion = version;
+ mSupportedLocales = Preconditions.checkNotNull(supportedLocales);
+ mLanguageIndependent = languageIndependent;
+ }
+
+ /** Returns the absolute path to the model file. */
+ public String getPath() {
+ return mFile.getAbsolutePath();
+ }
+
+ /** Returns a name to use for id generation, effectively the name of the model file. */
+ public String getName() {
+ return mFile.getName();
+ }
+
+ /** Returns the version tag in the model's metadata. */
+ public int getVersion() {
+ return mVersion;
+ }
+
+ /** Returns whether the language supports any language in the given ranges. */
+ public boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) {
+ Preconditions.checkNotNull(languageRanges);
+ return mLanguageIndependent || Locale.lookup(languageRanges, mSupportedLocales) != null;
+ }
+
+ /** Returns an immutable lists of supported locales. */
+ public List<Locale> getSupportedLocales() {
+ return Collections.unmodifiableList(mSupportedLocales);
+ }
+
+ /**
+ * Returns if this model file is preferred to the given one.
+ */
+ public boolean isPreferredTo(@Nullable ModelFile model) {
+ // A model is preferred to no model.
+ if (model == null) {
+ return true;
+ }
+
+ // A language-specific model is preferred to a language independent
+ // model.
+ if (!mLanguageIndependent && model.mLanguageIndependent) {
+ return true;
+ }
+
+ // A higher-version model is preferred.
+ if (mVersion > model.getVersion()) {
+ return true;
+ }
+ return false;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(getPath());
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (this == other) {
+ return true;
+ }
+ if (other instanceof ModelFile) {
+ final ModelFile otherModel = (ModelFile) other;
+ return TextUtils.equals(getPath(), otherModel.getPath());
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ final StringJoiner localesJoiner = new StringJoiner(",");
+ for (Locale locale : mSupportedLocales) {
+ localesJoiner.add(locale.toLanguageTag());
+ }
+ return String.format(Locale.US,
+ "ModelFile { path=%s name=%s version=%d locales=%s }",
+ getPath(), getName(), mVersion, localesJoiner.toString());
+ }
+ }
+}
diff --git a/core/java/android/view/textclassifier/TextClassifierImpl.java b/core/java/android/view/textclassifier/TextClassifierImpl.java
index 7f1e443f4aa5..159bfaa2ab26 100644
--- a/core/java/android/view/textclassifier/TextClassifierImpl.java
+++ b/core/java/android/view/textclassifier/TextClassifierImpl.java
@@ -58,16 +58,12 @@ import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
-import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
-import java.util.StringJoiner;
import java.util.concurrent.TimeUnit;
-import java.util.regex.Matcher;
-import java.util.regex.Pattern;
/**
* Default implementation of the {@link TextClassifier} interface.
@@ -81,13 +77,18 @@ import java.util.regex.Pattern;
public final class TextClassifierImpl implements TextClassifier {
private static final String LOG_TAG = DEFAULT_LOG_TAG;
- private static final String MODEL_DIR = "/etc/textclassifier/";
- private static final String MODEL_FILE_REGEX = "textclassifier\\.(.*)\\.model";
- private static final String UPDATED_MODEL_FILE_PATH =
- "/data/misc/textclassifier/textclassifier.model";
- private static final String LANG_ID_MODEL_FILE_PATH = "/etc/textclassifier/lang_id.model";
- private static final String UPDATED_LANG_ID_MODEL_FILE_PATH =
- "/data/misc/textclassifier/lang_id.model";
+
+ private static final File FACTORY_MODEL_DIR = new File("/etc/textclassifier/");
+ // Annotator
+ private static final String ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX =
+ "textclassifier\\.(.*)\\.model";
+ private static final File ANNOTATOR_UPDATED_MODEL_FILE =
+ new File("/data/misc/textclassifier/textclassifier.model");
+
+ // LangID
+ private static final String LANG_ID_FACTORY_MODEL_FILENAME_REGEX = "lang_id.model";
+ private static final File UPDATED_LANG_ID_MODEL_FILE =
+ new File("/data/misc/textclassifier/lang_id.model");
private final Context mContext;
private final TextClassifier mFallback;
@@ -95,9 +96,7 @@ public final class TextClassifierImpl implements TextClassifier {
private final Object mLock = new Object();
@GuardedBy("mLock") // Do not access outside this lock.
- private List<ModelFile> mAllModelFiles;
- @GuardedBy("mLock") // Do not access outside this lock.
- private ModelFile mModel;
+ private ModelFileManager.ModelFile mAnnotatorModelInUse;
@GuardedBy("mLock") // Do not access outside this lock.
private AnnotatorModel mAnnotatorImpl;
@GuardedBy("mLock") // Do not access outside this lock.
@@ -109,12 +108,29 @@ public final class TextClassifierImpl implements TextClassifier {
private final TextClassificationConstants mSettings;
+ private final ModelFileManager mAnnotatorModelFileManager;
+ private final ModelFileManager mLangIdModelFileManager;
+
public TextClassifierImpl(
Context context, TextClassificationConstants settings, TextClassifier fallback) {
mContext = Preconditions.checkNotNull(context);
mFallback = Preconditions.checkNotNull(fallback);
mSettings = Preconditions.checkNotNull(settings);
mGenerateLinksLogger = new GenerateLinksLogger(mSettings.getGenerateLinksLogSampleRate());
+ mAnnotatorModelFileManager = new ModelFileManager(
+ new ModelFileManager.ModelFileSupplierImpl(
+ FACTORY_MODEL_DIR,
+ ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX,
+ ANNOTATOR_UPDATED_MODEL_FILE,
+ AnnotatorModel::getVersion,
+ AnnotatorModel::getLocales));
+ mLangIdModelFileManager = new ModelFileManager(
+ new ModelFileManager.ModelFileSupplierImpl(
+ FACTORY_MODEL_DIR,
+ LANG_ID_FACTORY_MODEL_FILENAME_REGEX,
+ UPDATED_LANG_ID_MODEL_FILE,
+ fd -> -1, // TODO: Replace this with LangIdModel.getVersion(fd)
+ fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT));
}
public TextClassifierImpl(Context context, TextClassificationConstants settings) {
@@ -334,22 +350,24 @@ public final class TextClassifierImpl implements TextClassifier {
throws FileNotFoundException {
synchronized (mLock) {
localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList;
- final ModelFile bestModel = findBestModelLocked(localeList);
+ final ModelFileManager.ModelFile bestModel =
+ mAnnotatorModelFileManager.findBestModelFile(localeList);
if (bestModel == null) {
- throw new FileNotFoundException("No model for " + localeList.toLanguageTags());
+ throw new FileNotFoundException(
+ "No annotator model for " + localeList.toLanguageTags());
}
- if (mAnnotatorImpl == null || !Objects.equals(mModel, bestModel)) {
+ if (mAnnotatorImpl == null || !Objects.equals(mAnnotatorModelInUse, bestModel)) {
Log.d(DEFAULT_LOG_TAG, "Loading " + bestModel);
destroyAnnotatorImplIfExistsLocked();
- final ParcelFileDescriptor fd = ParcelFileDescriptor.open(
+ final ParcelFileDescriptor pfd = ParcelFileDescriptor.open(
new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
try {
- if (fd != null) {
- mAnnotatorImpl = new AnnotatorModel(fd.getFd());
- mModel = bestModel;
+ if (pfd != null) {
+ mAnnotatorImpl = new AnnotatorModel(pfd.getFd());
+ mAnnotatorModelInUse = bestModel;
}
} finally {
- maybeCloseAndLogError(fd);
+ maybeCloseAndLogError(pfd);
}
}
return mAnnotatorImpl;
@@ -367,40 +385,19 @@ public final class TextClassifierImpl implements TextClassifier {
private LangIdModel getLangIdImpl() throws FileNotFoundException {
synchronized (mLock) {
if (mLangIdImpl == null) {
- ParcelFileDescriptor factoryFd = null;
- ParcelFileDescriptor updateFd = null;
+ final ModelFileManager.ModelFile bestModel =
+ mLangIdModelFileManager.findBestModelFile(LocaleList.getEmptyLocaleList());
+ if (bestModel == null) {
+ throw new FileNotFoundException("No LangID model is found");
+ }
+ final ParcelFileDescriptor pfd = ParcelFileDescriptor.open(
+ new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
try {
- int factoryVersion = -1;
- int updateVersion = factoryVersion;
- final File factoryFile = new File(LANG_ID_MODEL_FILE_PATH);
- if (factoryFile.exists()) {
- factoryFd = ParcelFileDescriptor.open(
- factoryFile, ParcelFileDescriptor.MODE_READ_ONLY);
- // TODO: Uncomment when method is implemented:
- // if (factoryFd != null) {
- // factoryVersion = LangIdModel.getVersion(factoryFd.getFd());
- // }
- }
- final File updateFile = new File(UPDATED_LANG_ID_MODEL_FILE_PATH);
- if (updateFile.exists()) {
- updateFd = ParcelFileDescriptor.open(
- updateFile, ParcelFileDescriptor.MODE_READ_ONLY);
- // TODO: Uncomment when method is implemented:
- // if (updateFd != null) {
- // updateVersion = LangIdModel.getVersion(updateFd.getFd());
- // }
- }
-
- if (updateVersion > factoryVersion) {
- mLangIdImpl = new LangIdModel(updateFd.getFd());
- } else if (factoryFd != null) {
- mLangIdImpl = new LangIdModel(factoryFd.getFd());
- } else {
- throw new FileNotFoundException("Language detection model not found");
+ if (pfd != null) {
+ mLangIdImpl = new LangIdModel(pfd.getFd());
}
} finally {
- maybeCloseAndLogError(factoryFd);
- maybeCloseAndLogError(updateFd);
+ maybeCloseAndLogError(pfd);
}
}
return mLangIdImpl;
@@ -409,8 +406,9 @@ public final class TextClassifierImpl implements TextClassifier {
private String createId(String text, int start, int end) {
synchronized (mLock) {
- return SelectionSessionLogger.createId(text, start, end, mContext, mModel.getVersion(),
- mModel.getSupportedLocales());
+ return SelectionSessionLogger.createId(text, start, end, mContext,
+ mAnnotatorModelInUse.getVersion(),
+ mAnnotatorModelInUse.getSupportedLocales());
}
}
@@ -418,66 +416,6 @@ public final class TextClassifierImpl implements TextClassifier {
return (locales == null) ? "" : locales.toLanguageTags();
}
- /**
- * Finds the most appropriate model to use for the given target locale list.
- *
- * The basic logic is: we ignore all models that don't support any of the target locales. For
- * the remaining candidates, we take the update model unless its version number is lower than
- * the factory version. It's assumed that factory models do not have overlapping locale ranges
- * and conflict resolution between these models hence doesn't matter.
- */
- @GuardedBy("mLock") // Do not call outside this lock.
- @Nullable
- private ModelFile findBestModelLocked(LocaleList localeList) {
- // Specified localeList takes priority over the system default, so it is listed first.
- final String languages = localeList.isEmpty()
- ? LocaleList.getDefault().toLanguageTags()
- : localeList.toLanguageTags() + "," + LocaleList.getDefault().toLanguageTags();
- final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);
-
- ModelFile bestModel = null;
- for (ModelFile model : listAllModelsLocked()) {
- if (model.isAnyLanguageSupported(languageRangeList)) {
- if (model.isPreferredTo(bestModel)) {
- bestModel = model;
- }
- }
- }
- return bestModel;
- }
-
- /** Returns a list of all model files available, in order of precedence. */
- @GuardedBy("mLock") // Do not call outside this lock.
- private List<ModelFile> listAllModelsLocked() {
- if (mAllModelFiles == null) {
- final List<ModelFile> allModels = new ArrayList<>();
- // The update model has the highest precedence.
- if (new File(UPDATED_MODEL_FILE_PATH).exists()) {
- final ModelFile updatedModel = ModelFile.fromPath(UPDATED_MODEL_FILE_PATH);
- if (updatedModel != null) {
- allModels.add(updatedModel);
- }
- }
- // Factory models should never have overlapping locales, so the order doesn't matter.
- final File modelsDir = new File(MODEL_DIR);
- if (modelsDir.exists() && modelsDir.isDirectory()) {
- final File[] modelFiles = modelsDir.listFiles();
- final Pattern modelFilenamePattern = Pattern.compile(MODEL_FILE_REGEX);
- for (File modelFile : modelFiles) {
- final Matcher matcher = modelFilenamePattern.matcher(modelFile.getName());
- if (matcher.matches() && modelFile.isFile()) {
- final ModelFile model = ModelFile.fromPath(modelFile.getAbsolutePath());
- if (model != null) {
- allModels.add(model);
- }
- }
- }
- }
- mAllModelFiles = allModels;
- }
- return mAllModelFiles;
- }
-
private TextClassification createClassificationResult(
AnnotatorModel.ClassificationResult[] classifications,
String text, int start, int end, @Nullable Instant referenceTime) {
@@ -523,12 +461,18 @@ public final class TextClassifierImpl implements TextClassifier {
@Override
public void dump(@NonNull IndentingPrintWriter printWriter) {
synchronized (mLock) {
- listAllModelsLocked();
printWriter.println("TextClassifierImpl:");
printWriter.increaseIndent();
- printWriter.println("Model file(s):");
+ printWriter.println("Annotator model file(s):");
printWriter.increaseIndent();
- for (ModelFile modelFile : mAllModelFiles) {
+ for (ModelFileManager.ModelFile modelFile :
+ mAnnotatorModelFileManager.listModelFiles()) {
+ printWriter.println(modelFile.toString());
+ }
+ printWriter.decreaseIndent();
+ printWriter.println("LangID model file(s):");
+ for (ModelFileManager.ModelFile modelFile :
+ mLangIdModelFileManager.listModelFiles()) {
printWriter.println(modelFile.toString());
}
printWriter.decreaseIndent();
@@ -554,126 +498,6 @@ public final class TextClassifierImpl implements TextClassifier {
}
/**
- * Describes TextClassifier model files on disk.
- */
- private static final class ModelFile {
-
- private final String mPath;
- private final String mName;
- private final int mVersion;
- private final List<Locale> mSupportedLocales;
- private final boolean mLanguageIndependent;
-
- /** Returns null if the path did not point to a compatible model. */
- static @Nullable ModelFile fromPath(String path) {
- final File file = new File(path);
- if (!file.exists()) {
- return null;
- }
- ParcelFileDescriptor modelFd = null;
- try {
- modelFd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
- if (modelFd == null) {
- return null;
- }
- final int version = AnnotatorModel.getVersion(modelFd.getFd());
- final String supportedLocalesStr = AnnotatorModel.getLocales(modelFd.getFd());
- if (supportedLocalesStr.isEmpty()) {
- Log.d(DEFAULT_LOG_TAG, "Ignoring " + file.getAbsolutePath());
- return null;
- }
- final boolean languageIndependent = supportedLocalesStr.equals("*");
- final List<Locale> supportedLocales = new ArrayList<>();
- for (String langTag : supportedLocalesStr.split(",")) {
- supportedLocales.add(Locale.forLanguageTag(langTag));
- }
- return new ModelFile(path, file.getName(), version, supportedLocales,
- languageIndependent);
- } catch (FileNotFoundException e) {
- Log.e(DEFAULT_LOG_TAG, "Failed to peek " + file.getAbsolutePath(), e);
- return null;
- } finally {
- maybeCloseAndLogError(modelFd);
- }
- }
-
- /** The absolute path to the model file. */
- String getPath() {
- return mPath;
- }
-
- /** A name to use for id generation. Effectively the name of the model file. */
- String getName() {
- return mName;
- }
-
- /** Returns the version tag in the model's metadata. */
- int getVersion() {
- return mVersion;
- }
-
- /** Returns whether the language supports any language in the given ranges. */
- boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) {
- return mLanguageIndependent || Locale.lookup(languageRanges, mSupportedLocales) != null;
- }
-
- /** All locales supported by the model. */
- List<Locale> getSupportedLocales() {
- return Collections.unmodifiableList(mSupportedLocales);
- }
-
- public boolean isPreferredTo(ModelFile model) {
- // A model is preferred to no model.
- if (model == null) {
- return true;
- }
-
- // A language-specific model is preferred to a language independent
- // model.
- if (!mLanguageIndependent && model.mLanguageIndependent) {
- return true;
- }
-
- // A higher-version model is preferred.
- if (getVersion() > model.getVersion()) {
- return true;
- }
- return false;
- }
-
- @Override
- public boolean equals(Object other) {
- if (this == other) {
- return true;
- }
- if (other instanceof ModelFile) {
- final ModelFile otherModel = (ModelFile) other;
- return mPath.equals(otherModel.mPath);
- }
- return false;
- }
-
- @Override
- public String toString() {
- final StringJoiner localesJoiner = new StringJoiner(",");
- for (Locale locale : mSupportedLocales) {
- localesJoiner.add(locale.toLanguageTag());
- }
- return String.format(Locale.US, "ModelFile { path=%s name=%s version=%d locales=%s }",
- mPath, mName, mVersion, localesJoiner.toString());
- }
-
- private ModelFile(String path, String name, int version, List<Locale> supportedLocales,
- boolean languageIndependent) {
- mPath = path;
- mName = name;
- mVersion = version;
- mSupportedLocales = supportedLocales;
- mLanguageIndependent = languageIndependent;
- }
- }
-
- /**
* Helper class to store the information from which RemoteActions are built.
*/
private static final class LabeledIntent {
diff --git a/core/tests/coretests/src/android/view/textclassifier/ModelFileManagerTest.java b/core/tests/coretests/src/android/view/textclassifier/ModelFileManagerTest.java
new file mode 100644
index 000000000000..51e5aec8b219
--- /dev/null
+++ b/core/tests/coretests/src/android/view/textclassifier/ModelFileManagerTest.java
@@ -0,0 +1,301 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package android.view.textclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import static org.mockito.Mockito.when;
+
+import android.os.LocaleList;
+import android.support.test.InstrumentationRegistry;
+import android.support.test.filters.SmallTest;
+import android.support.test.runner.AndroidJUnit4;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Locale;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class ModelFileManagerTest {
+
+ @Mock
+ private Supplier<List<ModelFileManager.ModelFile>> mModelFileSupplier;
+ private ModelFileManager.ModelFileSupplierImpl mModelFileSupplierImpl;
+ private ModelFileManager mModelFileManager;
+ private File mRootTestDir;
+ private File mFactoryModelDir;
+ private File mUpdatedModelFile;
+
+ @Before
+ public void setup() {
+ MockitoAnnotations.initMocks(this);
+ mModelFileManager = new ModelFileManager(mModelFileSupplier);
+ mRootTestDir = InstrumentationRegistry.getContext().getCacheDir();
+ mFactoryModelDir = new File(mRootTestDir, "factory");
+ mUpdatedModelFile = new File(mRootTestDir, "updated.model");
+
+ mModelFileSupplierImpl =
+ new ModelFileManager.ModelFileSupplierImpl(
+ mFactoryModelDir,
+ "test\\d.model",
+ mUpdatedModelFile,
+ fd -> 1,
+ fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT
+ );
+
+ mRootTestDir.mkdirs();
+ mFactoryModelDir.mkdirs();
+ }
+
+ @After
+ public void removeTestDir() {
+ recursiveDelete(mRootTestDir);
+ }
+
+ @Test
+ public void get() {
+ ModelFileManager.ModelFile modelFile =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"), 1, Collections.emptyList(), true);
+ when(mModelFileSupplier.get()).thenReturn(Collections.singletonList(modelFile));
+
+ List<ModelFileManager.ModelFile> modelFiles = mModelFileManager.listModelFiles();
+
+ assertThat(modelFiles).hasSize(1);
+ assertThat(modelFiles.get(0)).isEqualTo(modelFile);
+ }
+
+ @Test
+ public void findBestModel_versionCode() {
+ ModelFileManager.ModelFile olderModelFile =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"), 1,
+ Collections.emptyList(), true);
+
+ ModelFileManager.ModelFile newerModelFile =
+ new ModelFileManager.ModelFile(
+ new File("/path/b"), 2,
+ Collections.emptyList(), true);
+ when(mModelFileSupplier.get())
+ .thenReturn(Arrays.asList(olderModelFile, newerModelFile));
+
+ ModelFileManager.ModelFile bestModelFile =
+ mModelFileManager.findBestModelFile(LocaleList.getEmptyLocaleList());
+
+ assertThat(bestModelFile).isEqualTo(newerModelFile);
+ }
+
+ @Test
+ public void findBestModel_languageDependentModelIsPreferred() {
+ Locale locale = Locale.forLanguageTag("ja");
+ ModelFileManager.ModelFile languageIndependentModelFile =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"), 1,
+ Collections.emptyList(), true);
+
+ ModelFileManager.ModelFile languageDependentModelFile =
+ new ModelFileManager.ModelFile(
+ new File("/path/b"), 1,
+ Collections.singletonList(locale), false);
+ when(mModelFileSupplier.get())
+ .thenReturn(
+ Arrays.asList(languageIndependentModelFile, languageDependentModelFile));
+
+ ModelFileManager.ModelFile bestModelFile =
+ mModelFileManager.findBestModelFile(
+ LocaleList.forLanguageTags(locale.toLanguageTag()));
+ assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
+ }
+
+ @Test
+ public void findBestModel_useIndependentWhenNoLanguageModelMatch() {
+ Locale locale = Locale.forLanguageTag("ja");
+ ModelFileManager.ModelFile languageIndependentModelFile =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"), 1,
+ Collections.emptyList(), true);
+
+ ModelFileManager.ModelFile languageDependentModelFile =
+ new ModelFileManager.ModelFile(
+ new File("/path/b"), 1,
+ Collections.singletonList(locale), false);
+
+ when(mModelFileSupplier.get())
+ .thenReturn(
+ Arrays.asList(languageIndependentModelFile, languageDependentModelFile));
+
+ ModelFileManager.ModelFile bestModelFile =
+ mModelFileManager.findBestModelFile(
+ LocaleList.forLanguageTags("zh-hk"));
+ assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
+ }
+
+ @Test
+ public void findBestModel_languageIsMoreImportantThanVersion() {
+ ModelFileManager.ModelFile matchButOlderModel =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"), 1,
+ Collections.singletonList(Locale.forLanguageTag("fr")), false);
+
+ ModelFileManager.ModelFile mismatchButNewerModel =
+ new ModelFileManager.ModelFile(
+ new File("/path/b"), 2,
+ Collections.singletonList(Locale.forLanguageTag("ja")), false);
+
+ when(mModelFileSupplier.get())
+ .thenReturn(
+ Arrays.asList(matchButOlderModel, mismatchButNewerModel));
+
+ ModelFileManager.ModelFile bestModelFile =
+ mModelFileManager.findBestModelFile(
+ LocaleList.forLanguageTags("fr"));
+ assertThat(bestModelFile).isEqualTo(matchButOlderModel);
+ }
+
+ @Test
+ public void modelFileEquals() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"), 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")), false);
+
+ ModelFileManager.ModelFile modelB =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"), 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")), false);
+
+ assertThat(modelA).isEqualTo(modelB);
+ }
+
+ @Test
+ public void modelFile_different() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"), 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")), false);
+
+ ModelFileManager.ModelFile modelB =
+ new ModelFileManager.ModelFile(
+ new File("/path/b"), 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")), false);
+
+ assertThat(modelA).isNotEqualTo(modelB);
+ }
+
+
+ @Test
+ public void modelFile_getPath() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"), 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")), false);
+
+ assertThat(modelA.getPath()).isEqualTo("/path/a");
+ }
+
+ @Test
+ public void modelFile_getName() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"), 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")), false);
+
+ assertThat(modelA.getName()).isEqualTo("a");
+ }
+
+ @Test
+ public void modelFile_isPreferredTo_languageDependentIsBetter() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"), 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")), false);
+
+ ModelFileManager.ModelFile modelB =
+ new ModelFileManager.ModelFile(
+ new File("/path/b"), 2,
+ Collections.emptyList(), true);
+
+ assertThat(modelA.isPreferredTo(modelB)).isTrue();
+ }
+
+ @Test
+ public void modelFile_isPreferredTo_version() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"), 2,
+ Collections.singletonList(Locale.forLanguageTag("ja")), false);
+
+ ModelFileManager.ModelFile modelB =
+ new ModelFileManager.ModelFile(
+ new File("/path/b"), 1,
+ Collections.emptyList(), false);
+
+ assertThat(modelA.isPreferredTo(modelB)).isTrue();
+ }
+
+ @Test
+ public void testFileSupplierImpl_updatedFileOnly() throws IOException {
+ mUpdatedModelFile.createNewFile();
+ File model1 = new File(mFactoryModelDir, "test1.model");
+ model1.createNewFile();
+ File model2 = new File(mFactoryModelDir, "test2.model");
+ model2.createNewFile();
+ new File(mFactoryModelDir, "not_match_regex.model").createNewFile();
+
+ List<ModelFileManager.ModelFile> modelFiles = mModelFileSupplierImpl.get();
+ List<String> modelFilePaths =
+ modelFiles
+ .stream()
+ .map(modelFile -> modelFile.getPath())
+ .collect(Collectors.toList());
+
+ assertThat(modelFiles).hasSize(3);
+ assertThat(modelFilePaths).containsExactly(
+ mUpdatedModelFile.getAbsolutePath(),
+ model1.getAbsolutePath(),
+ model2.getAbsolutePath());
+ }
+
+ @Test
+ public void testFileSupplierImpl_empty() {
+ mFactoryModelDir.delete();
+ List<ModelFileManager.ModelFile> modelFiles = mModelFileSupplierImpl.get();
+
+ assertThat(modelFiles).hasSize(0);
+ }
+
+ private static void recursiveDelete(File f) {
+ if (f.isDirectory()) {
+ for (File innerFile : f.listFiles()) {
+ recursiveDelete(innerFile);
+ }
+ }
+ f.delete();
+ }
+}