summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/java/android/view/textclassifier/TextClassifierImpl.java172
-rw-r--r--core/tests/coretests/src/android/view/textclassifier/TextClassificationManagerTest.java37
2 files changed, 168 insertions, 41 deletions
diff --git a/core/java/android/view/textclassifier/TextClassifierImpl.java b/core/java/android/view/textclassifier/TextClassifierImpl.java
index 3e240cfdb69f..7f1e443f4aa5 100644
--- a/core/java/android/view/textclassifier/TextClassifierImpl.java
+++ b/core/java/android/view/textclassifier/TextClassifierImpl.java
@@ -31,6 +31,7 @@ import android.content.Intent;
import android.content.pm.PackageManager;
import android.content.pm.ResolveInfo;
import android.graphics.drawable.Icon;
+import android.icu.util.ULocale;
import android.net.Uri;
import android.os.Bundle;
import android.os.LocaleList;
@@ -45,6 +46,7 @@ import com.android.internal.util.IndentingPrintWriter;
import com.android.internal.util.Preconditions;
import com.google.android.textclassifier.AnnotatorModel;
+import com.google.android.textclassifier.LangIdModel;
import java.io.File;
import java.io.FileNotFoundException;
@@ -83,6 +85,9 @@ public final class TextClassifierImpl implements TextClassifier {
private static final String MODEL_FILE_REGEX = "textclassifier\\.(.*)\\.model";
private static final String UPDATED_MODEL_FILE_PATH =
"/data/misc/textclassifier/textclassifier.model";
+ private static final String LANG_ID_MODEL_FILE_PATH = "/etc/textclassifier/lang_id.model";
+ private static final String UPDATED_LANG_ID_MODEL_FILE_PATH =
+ "/data/misc/textclassifier/lang_id.model";
private final Context mContext;
private final TextClassifier mFallback;
@@ -94,7 +99,9 @@ public final class TextClassifierImpl implements TextClassifier {
@GuardedBy("mLock") // Do not access outside this lock.
private ModelFile mModel;
@GuardedBy("mLock") // Do not access outside this lock.
- private AnnotatorModel mNative;
+ private AnnotatorModel mAnnotatorImpl;
+ @GuardedBy("mLock") // Do not access outside this lock.
+ private LangIdModel mLangIdImpl;
private final Object mLoggerLock = new Object();
@GuardedBy("mLoggerLock") // Do not access outside this lock.
@@ -127,14 +134,15 @@ public final class TextClassifierImpl implements TextClassifier {
&& rangeLength <= mSettings.getSuggestSelectionMaxRangeLength()) {
final String localesString = concatenateLocales(request.getDefaultLocales());
final ZonedDateTime refTime = ZonedDateTime.now();
- final AnnotatorModel nativeImpl = getNative(request.getDefaultLocales());
+ final AnnotatorModel annotatorImpl =
+ getAnnotatorImpl(request.getDefaultLocales());
final int start;
final int end;
if (mSettings.isModelDarkLaunchEnabled() && !request.isDarkLaunchAllowed()) {
start = request.getStartIndex();
end = request.getEndIndex();
} else {
- final int[] startEnd = nativeImpl.suggestSelection(
+ final int[] startEnd = annotatorImpl.suggestSelection(
string, request.getStartIndex(), request.getEndIndex(),
new AnnotatorModel.SelectionOptions(localesString));
start = startEnd[0];
@@ -145,7 +153,7 @@ public final class TextClassifierImpl implements TextClassifier {
&& start <= request.getStartIndex() && end >= request.getEndIndex()) {
final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
final AnnotatorModel.ClassificationResult[] results =
- nativeImpl.classifyText(
+ annotatorImpl.classifyText(
string, start, end,
new AnnotatorModel.ClassificationOptions(
refTime.toInstant().toEpochMilli(),
@@ -187,7 +195,7 @@ public final class TextClassifierImpl implements TextClassifier {
final ZonedDateTime refTime = request.getReferenceTime() != null
? request.getReferenceTime() : ZonedDateTime.now();
final AnnotatorModel.ClassificationResult[] results =
- getNative(request.getDefaultLocales())
+ getAnnotatorImpl(request.getDefaultLocales())
.classifyText(
string, request.getStartIndex(), request.getEndIndex(),
new AnnotatorModel.ClassificationOptions(
@@ -230,10 +238,10 @@ public final class TextClassifierImpl implements TextClassifier {
? request.getEntityConfig().resolveEntityListModifications(
getEntitiesForHints(request.getEntityConfig().getHints()))
: mSettings.getEntityListDefault();
- final AnnotatorModel nativeImpl =
- getNative(request.getDefaultLocales());
+ final AnnotatorModel annotatorImpl =
+ getAnnotatorImpl(request.getDefaultLocales());
final AnnotatorModel.AnnotatedSpan[] annotations =
- nativeImpl.annotate(
+ annotatorImpl.annotate(
textString,
new AnnotatorModel.AnnotationOptions(
refTime.toInstant().toEpochMilli(),
@@ -288,6 +296,7 @@ public final class TextClassifierImpl implements TextClassifier {
}
}
+ /** @inheritDoc */
@Override
public void onSelectionEvent(SelectionEvent event) {
Preconditions.checkNotNull(event);
@@ -299,7 +308,29 @@ public final class TextClassifierImpl implements TextClassifier {
}
}
- private AnnotatorModel getNative(LocaleList localeList)
+ /** @inheritDoc */
+ @Override
+ public TextLanguage detectLanguage(@NonNull TextLanguage.Request request) {
+ Preconditions.checkNotNull(request);
+ Utils.checkMainThread();
+ try {
+ final TextLanguage.Builder builder = new TextLanguage.Builder();
+ final LangIdModel.LanguageResult[] langResults =
+ getLangIdImpl().detectLanguages(request.getText().toString());
+ for (int i = 0; i < langResults.length; i++) {
+ builder.putLocale(
+ ULocale.forLanguageTag(langResults[i].getLanguage()),
+ langResults[i].getScore());
+ }
+ return builder.build();
+ } catch (Throwable t) {
+ // Avoid throwing from this method. Log the error.
+ Log.e(LOG_TAG, "Error detecting text language.", t);
+ }
+ return mFallback.detectLanguage(request);
+ }
+
+ private AnnotatorModel getAnnotatorImpl(LocaleList localeList)
throws FileNotFoundException {
synchronized (mLock) {
localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList;
@@ -307,31 +338,79 @@ public final class TextClassifierImpl implements TextClassifier {
if (bestModel == null) {
throw new FileNotFoundException("No model for " + localeList.toLanguageTags());
}
- if (mNative == null || !Objects.equals(mModel, bestModel)) {
+ if (mAnnotatorImpl == null || !Objects.equals(mModel, bestModel)) {
Log.d(DEFAULT_LOG_TAG, "Loading " + bestModel);
- destroyNativeIfExistsLocked();
+ destroyAnnotatorImplIfExistsLocked();
final ParcelFileDescriptor fd = ParcelFileDescriptor.open(
new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
- mNative = new AnnotatorModel(fd.getFd());
- closeAndLogError(fd);
- mModel = bestModel;
+ try {
+ if (fd != null) {
+ mAnnotatorImpl = new AnnotatorModel(fd.getFd());
+ mModel = bestModel;
+ }
+ } finally {
+ maybeCloseAndLogError(fd);
+ }
}
- return mNative;
+ return mAnnotatorImpl;
}
}
- private String createId(String text, int start, int end) {
+ @GuardedBy("mLock") // Do not call outside this lock.
+ private void destroyAnnotatorImplIfExistsLocked() {
+ if (mAnnotatorImpl != null) {
+ mAnnotatorImpl.close();
+ mAnnotatorImpl = null;
+ }
+ }
+
+ private LangIdModel getLangIdImpl() throws FileNotFoundException {
synchronized (mLock) {
- return SelectionSessionLogger.createId(text, start, end, mContext, mModel.getVersion(),
- mModel.getSupportedLocales());
+ if (mLangIdImpl == null) {
+ ParcelFileDescriptor factoryFd = null;
+ ParcelFileDescriptor updateFd = null;
+ try {
+ int factoryVersion = -1;
+ int updateVersion = factoryVersion;
+ final File factoryFile = new File(LANG_ID_MODEL_FILE_PATH);
+ if (factoryFile.exists()) {
+ factoryFd = ParcelFileDescriptor.open(
+ factoryFile, ParcelFileDescriptor.MODE_READ_ONLY);
+ // TODO: Uncomment when method is implemented:
+ // if (factoryFd != null) {
+ // factoryVersion = LangIdModel.getVersion(factoryFd.getFd());
+ // }
+ }
+ final File updateFile = new File(UPDATED_LANG_ID_MODEL_FILE_PATH);
+ if (updateFile.exists()) {
+ updateFd = ParcelFileDescriptor.open(
+ updateFile, ParcelFileDescriptor.MODE_READ_ONLY);
+ // TODO: Uncomment when method is implemented:
+ // if (updateFd != null) {
+ // updateVersion = LangIdModel.getVersion(updateFd.getFd());
+ // }
+ }
+
+ if (updateVersion > factoryVersion) {
+ mLangIdImpl = new LangIdModel(updateFd.getFd());
+ } else if (factoryFd != null) {
+ mLangIdImpl = new LangIdModel(factoryFd.getFd());
+ } else {
+ throw new FileNotFoundException("Language detection model not found");
+ }
+ } finally {
+ maybeCloseAndLogError(factoryFd);
+ maybeCloseAndLogError(updateFd);
+ }
+ }
+ return mLangIdImpl;
}
}
- @GuardedBy("mLock") // Do not call outside this lock.
- private void destroyNativeIfExistsLocked() {
- if (mNative != null) {
- mNative.close();
- mNative = null;
+ private String createId(String text, int start, int end) {
+ synchronized (mLock) {
+ return SelectionSessionLogger.createId(text, start, end, mContext, mModel.getVersion(),
+ mModel.getSupportedLocales());
}
}
@@ -407,20 +486,19 @@ public final class TextClassifierImpl implements TextClassifier {
.setText(classifiedText);
final int size = classifications.length;
- AnnotatorModel.ClassificationResult highestScoringResult = null;
- float highestScore = Float.MIN_VALUE;
+ AnnotatorModel.ClassificationResult highestScoringResult =
+ size > 0 ? classifications[0] : null;
for (int i = 0; i < size; i++) {
builder.setEntityType(classifications[i].getCollection(),
classifications[i].getScore());
- if (classifications[i].getScore() > highestScore) {
+ if (classifications[i].getScore() > highestScoringResult.getScore()) {
highestScoringResult = classifications[i];
- highestScore = classifications[i].getScore();
}
}
boolean isPrimaryAction = true;
for (LabeledIntent labeledIntent : IntentFactory.create(
- mContext, referenceTime, highestScoringResult, classifiedText)) {
+ mContext, classifiedText, referenceTime, highestScoringResult)) {
final RemoteAction action = labeledIntent.asRemoteAction(mContext);
if (action == null) {
continue;
@@ -461,9 +539,13 @@ public final class TextClassifierImpl implements TextClassifier {
}
/**
- * Closes the ParcelFileDescriptor and logs any errors that occur.
+ * Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur.
*/
- private static void closeAndLogError(ParcelFileDescriptor fd) {
+ private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
+ if (fd == null) {
+ return;
+ }
+
try {
fd.close();
} catch (IOException e) {
@@ -485,12 +567,17 @@ public final class TextClassifierImpl implements TextClassifier {
/** Returns null if the path did not point to a compatible model. */
static @Nullable ModelFile fromPath(String path) {
final File file = new File(path);
+ if (!file.exists()) {
+ return null;
+ }
+ ParcelFileDescriptor modelFd = null;
try {
- final ParcelFileDescriptor modelFd = ParcelFileDescriptor.open(
- file, ParcelFileDescriptor.MODE_READ_ONLY);
+ modelFd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
+ if (modelFd == null) {
+ return null;
+ }
final int version = AnnotatorModel.getVersion(modelFd.getFd());
- final String supportedLocalesStr =
- AnnotatorModel.getLocales(modelFd.getFd());
+ final String supportedLocalesStr = AnnotatorModel.getLocales(modelFd.getFd());
if (supportedLocalesStr.isEmpty()) {
Log.d(DEFAULT_LOG_TAG, "Ignoring " + file.getAbsolutePath());
return null;
@@ -500,12 +587,13 @@ public final class TextClassifierImpl implements TextClassifier {
for (String langTag : supportedLocalesStr.split(",")) {
supportedLocales.add(Locale.forLanguageTag(langTag));
}
- closeAndLogError(modelFd);
return new ModelFile(path, file.getName(), version, supportedLocales,
languageIndependent);
} catch (FileNotFoundException e) {
Log.e(DEFAULT_LOG_TAG, "Failed to peek " + file.getAbsolutePath(), e);
return null;
+ } finally {
+ maybeCloseAndLogError(modelFd);
}
}
@@ -557,12 +645,12 @@ public final class TextClassifierImpl implements TextClassifier {
public boolean equals(Object other) {
if (this == other) {
return true;
- } else if (other == null || !ModelFile.class.isAssignableFrom(other.getClass())) {
- return false;
- } else {
+ }
+ if (other instanceof ModelFile) {
final ModelFile otherModel = (ModelFile) other;
return mPath.equals(otherModel.mPath);
}
+ return false;
}
@Override
@@ -677,10 +765,12 @@ public final class TextClassifierImpl implements TextClassifier {
@NonNull
public static List<LabeledIntent> create(
Context context,
+ String text,
@Nullable Instant referenceTime,
- AnnotatorModel.ClassificationResult classification,
- String text) {
- final String type = classification.getCollection().trim().toLowerCase(Locale.ENGLISH);
+ @Nullable AnnotatorModel.ClassificationResult classification) {
+ final String type = classification != null
+ ? classification.getCollection().trim().toLowerCase(Locale.ENGLISH)
+ : null;
text = text.trim();
switch (type) {
case TextClassifier.TYPE_EMAIL:
diff --git a/core/tests/coretests/src/android/view/textclassifier/TextClassificationManagerTest.java b/core/tests/coretests/src/android/view/textclassifier/TextClassificationManagerTest.java
index e891fc9d6134..8646c685c998 100644
--- a/core/tests/coretests/src/android/view/textclassifier/TextClassificationManagerTest.java
+++ b/core/tests/coretests/src/android/view/textclassifier/TextClassificationManagerTest.java
@@ -307,6 +307,24 @@ public class TextClassificationManagerTest {
}
@Test
+ public void testDetectLanguage() {
+ if (isTextClassifierDisabled()) return;
+ String text = "This is English text";
+ TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
+ TextLanguage textLanguage = mClassifier.detectLanguage(request);
+ assertThat(textLanguage, isTextLanguage("en"));
+ }
+
+ @Test
+ public void testDetectLanguage_japanese() {
+ if (isTextClassifierDisabled()) return;
+ String text = "これは日本語のテキストです";
+ TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
+ TextLanguage textLanguage = mClassifier.detectLanguage(request);
+ assertThat(textLanguage, isTextLanguage("ja"));
+ }
+
+ @Test
public void testSetTextClassifier() {
TextClassifier classifier = mock(TextClassifier.class);
mTcm.setTextClassifier(classifier);
@@ -444,4 +462,23 @@ public class TextClassificationManagerTest {
}
};
}
+
+ private static Matcher<TextLanguage> isTextLanguage(final String languageTag) {
+ return new BaseMatcher<TextLanguage>() {
+ @Override
+ public boolean matches(Object o) {
+ if (o instanceof TextLanguage) {
+ TextLanguage result = (TextLanguage) o;
+ return result.getLocaleHypothesisCount() > 0
+ && languageTag.equals(result.getLocale(0).toLanguageTag());
+ }
+ return false;
+ }
+
+ @Override
+ public void describeTo(Description description) {
+ description.appendText("locale=").appendValue(languageTag);
+ }
+ };
+ }
}