diff options
| -rw-r--r-- | core/java/android/view/textclassifier/TextClassifierImpl.java | 172 | ||||
| -rw-r--r-- | core/tests/coretests/src/android/view/textclassifier/TextClassificationManagerTest.java | 37 |
2 files changed, 168 insertions, 41 deletions
diff --git a/core/java/android/view/textclassifier/TextClassifierImpl.java b/core/java/android/view/textclassifier/TextClassifierImpl.java index 3e240cfdb69f..7f1e443f4aa5 100644 --- a/core/java/android/view/textclassifier/TextClassifierImpl.java +++ b/core/java/android/view/textclassifier/TextClassifierImpl.java @@ -31,6 +31,7 @@ import android.content.Intent; import android.content.pm.PackageManager; import android.content.pm.ResolveInfo; import android.graphics.drawable.Icon; +import android.icu.util.ULocale; import android.net.Uri; import android.os.Bundle; import android.os.LocaleList; @@ -45,6 +46,7 @@ import com.android.internal.util.IndentingPrintWriter; import com.android.internal.util.Preconditions; import com.google.android.textclassifier.AnnotatorModel; +import com.google.android.textclassifier.LangIdModel; import java.io.File; import java.io.FileNotFoundException; @@ -83,6 +85,9 @@ public final class TextClassifierImpl implements TextClassifier { private static final String MODEL_FILE_REGEX = "textclassifier\\.(.*)\\.model"; private static final String UPDATED_MODEL_FILE_PATH = "/data/misc/textclassifier/textclassifier.model"; + private static final String LANG_ID_MODEL_FILE_PATH = "/etc/textclassifier/lang_id.model"; + private static final String UPDATED_LANG_ID_MODEL_FILE_PATH = + "/data/misc/textclassifier/lang_id.model"; private final Context mContext; private final TextClassifier mFallback; @@ -94,7 +99,9 @@ public final class TextClassifierImpl implements TextClassifier { @GuardedBy("mLock") // Do not access outside this lock. private ModelFile mModel; @GuardedBy("mLock") // Do not access outside this lock. - private AnnotatorModel mNative; + private AnnotatorModel mAnnotatorImpl; + @GuardedBy("mLock") // Do not access outside this lock. + private LangIdModel mLangIdImpl; private final Object mLoggerLock = new Object(); @GuardedBy("mLoggerLock") // Do not access outside this lock. @@ -127,14 +134,15 @@ public final class TextClassifierImpl implements TextClassifier { && rangeLength <= mSettings.getSuggestSelectionMaxRangeLength()) { final String localesString = concatenateLocales(request.getDefaultLocales()); final ZonedDateTime refTime = ZonedDateTime.now(); - final AnnotatorModel nativeImpl = getNative(request.getDefaultLocales()); + final AnnotatorModel annotatorImpl = + getAnnotatorImpl(request.getDefaultLocales()); final int start; final int end; if (mSettings.isModelDarkLaunchEnabled() && !request.isDarkLaunchAllowed()) { start = request.getStartIndex(); end = request.getEndIndex(); } else { - final int[] startEnd = nativeImpl.suggestSelection( + final int[] startEnd = annotatorImpl.suggestSelection( string, request.getStartIndex(), request.getEndIndex(), new AnnotatorModel.SelectionOptions(localesString)); start = startEnd[0]; @@ -145,7 +153,7 @@ public final class TextClassifierImpl implements TextClassifier { && start <= request.getStartIndex() && end >= request.getEndIndex()) { final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end); final AnnotatorModel.ClassificationResult[] results = - nativeImpl.classifyText( + annotatorImpl.classifyText( string, start, end, new AnnotatorModel.ClassificationOptions( refTime.toInstant().toEpochMilli(), @@ -187,7 +195,7 @@ public final class TextClassifierImpl implements TextClassifier { final ZonedDateTime refTime = request.getReferenceTime() != null ? request.getReferenceTime() : ZonedDateTime.now(); final AnnotatorModel.ClassificationResult[] results = - getNative(request.getDefaultLocales()) + getAnnotatorImpl(request.getDefaultLocales()) .classifyText( string, request.getStartIndex(), request.getEndIndex(), new AnnotatorModel.ClassificationOptions( @@ -230,10 +238,10 @@ public final class TextClassifierImpl implements TextClassifier { ? request.getEntityConfig().resolveEntityListModifications( getEntitiesForHints(request.getEntityConfig().getHints())) : mSettings.getEntityListDefault(); - final AnnotatorModel nativeImpl = - getNative(request.getDefaultLocales()); + final AnnotatorModel annotatorImpl = + getAnnotatorImpl(request.getDefaultLocales()); final AnnotatorModel.AnnotatedSpan[] annotations = - nativeImpl.annotate( + annotatorImpl.annotate( textString, new AnnotatorModel.AnnotationOptions( refTime.toInstant().toEpochMilli(), @@ -288,6 +296,7 @@ public final class TextClassifierImpl implements TextClassifier { } } + /** @inheritDoc */ @Override public void onSelectionEvent(SelectionEvent event) { Preconditions.checkNotNull(event); @@ -299,7 +308,29 @@ public final class TextClassifierImpl implements TextClassifier { } } - private AnnotatorModel getNative(LocaleList localeList) + /** @inheritDoc */ + @Override + public TextLanguage detectLanguage(@NonNull TextLanguage.Request request) { + Preconditions.checkNotNull(request); + Utils.checkMainThread(); + try { + final TextLanguage.Builder builder = new TextLanguage.Builder(); + final LangIdModel.LanguageResult[] langResults = + getLangIdImpl().detectLanguages(request.getText().toString()); + for (int i = 0; i < langResults.length; i++) { + builder.putLocale( + ULocale.forLanguageTag(langResults[i].getLanguage()), + langResults[i].getScore()); + } + return builder.build(); + } catch (Throwable t) { + // Avoid throwing from this method. Log the error. + Log.e(LOG_TAG, "Error detecting text language.", t); + } + return mFallback.detectLanguage(request); + } + + private AnnotatorModel getAnnotatorImpl(LocaleList localeList) throws FileNotFoundException { synchronized (mLock) { localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList; @@ -307,31 +338,79 @@ public final class TextClassifierImpl implements TextClassifier { if (bestModel == null) { throw new FileNotFoundException("No model for " + localeList.toLanguageTags()); } - if (mNative == null || !Objects.equals(mModel, bestModel)) { + if (mAnnotatorImpl == null || !Objects.equals(mModel, bestModel)) { Log.d(DEFAULT_LOG_TAG, "Loading " + bestModel); - destroyNativeIfExistsLocked(); + destroyAnnotatorImplIfExistsLocked(); final ParcelFileDescriptor fd = ParcelFileDescriptor.open( new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY); - mNative = new AnnotatorModel(fd.getFd()); - closeAndLogError(fd); - mModel = bestModel; + try { + if (fd != null) { + mAnnotatorImpl = new AnnotatorModel(fd.getFd()); + mModel = bestModel; + } + } finally { + maybeCloseAndLogError(fd); + } } - return mNative; + return mAnnotatorImpl; } } - private String createId(String text, int start, int end) { + @GuardedBy("mLock") // Do not call outside this lock. + private void destroyAnnotatorImplIfExistsLocked() { + if (mAnnotatorImpl != null) { + mAnnotatorImpl.close(); + mAnnotatorImpl = null; + } + } + + private LangIdModel getLangIdImpl() throws FileNotFoundException { synchronized (mLock) { - return SelectionSessionLogger.createId(text, start, end, mContext, mModel.getVersion(), - mModel.getSupportedLocales()); + if (mLangIdImpl == null) { + ParcelFileDescriptor factoryFd = null; + ParcelFileDescriptor updateFd = null; + try { + int factoryVersion = -1; + int updateVersion = factoryVersion; + final File factoryFile = new File(LANG_ID_MODEL_FILE_PATH); + if (factoryFile.exists()) { + factoryFd = ParcelFileDescriptor.open( + factoryFile, ParcelFileDescriptor.MODE_READ_ONLY); + // TODO: Uncomment when method is implemented: + // if (factoryFd != null) { + // factoryVersion = LangIdModel.getVersion(factoryFd.getFd()); + // } + } + final File updateFile = new File(UPDATED_LANG_ID_MODEL_FILE_PATH); + if (updateFile.exists()) { + updateFd = ParcelFileDescriptor.open( + updateFile, ParcelFileDescriptor.MODE_READ_ONLY); + // TODO: Uncomment when method is implemented: + // if (updateFd != null) { + // updateVersion = LangIdModel.getVersion(updateFd.getFd()); + // } + } + + if (updateVersion > factoryVersion) { + mLangIdImpl = new LangIdModel(updateFd.getFd()); + } else if (factoryFd != null) { + mLangIdImpl = new LangIdModel(factoryFd.getFd()); + } else { + throw new FileNotFoundException("Language detection model not found"); + } + } finally { + maybeCloseAndLogError(factoryFd); + maybeCloseAndLogError(updateFd); + } + } + return mLangIdImpl; } } - @GuardedBy("mLock") // Do not call outside this lock. - private void destroyNativeIfExistsLocked() { - if (mNative != null) { - mNative.close(); - mNative = null; + private String createId(String text, int start, int end) { + synchronized (mLock) { + return SelectionSessionLogger.createId(text, start, end, mContext, mModel.getVersion(), + mModel.getSupportedLocales()); } } @@ -407,20 +486,19 @@ public final class TextClassifierImpl implements TextClassifier { .setText(classifiedText); final int size = classifications.length; - AnnotatorModel.ClassificationResult highestScoringResult = null; - float highestScore = Float.MIN_VALUE; + AnnotatorModel.ClassificationResult highestScoringResult = + size > 0 ? classifications[0] : null; for (int i = 0; i < size; i++) { builder.setEntityType(classifications[i].getCollection(), classifications[i].getScore()); - if (classifications[i].getScore() > highestScore) { + if (classifications[i].getScore() > highestScoringResult.getScore()) { highestScoringResult = classifications[i]; - highestScore = classifications[i].getScore(); } } boolean isPrimaryAction = true; for (LabeledIntent labeledIntent : IntentFactory.create( - mContext, referenceTime, highestScoringResult, classifiedText)) { + mContext, classifiedText, referenceTime, highestScoringResult)) { final RemoteAction action = labeledIntent.asRemoteAction(mContext); if (action == null) { continue; @@ -461,9 +539,13 @@ public final class TextClassifierImpl implements TextClassifier { } /** - * Closes the ParcelFileDescriptor and logs any errors that occur. + * Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur. */ - private static void closeAndLogError(ParcelFileDescriptor fd) { + private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) { + if (fd == null) { + return; + } + try { fd.close(); } catch (IOException e) { @@ -485,12 +567,17 @@ public final class TextClassifierImpl implements TextClassifier { /** Returns null if the path did not point to a compatible model. */ static @Nullable ModelFile fromPath(String path) { final File file = new File(path); + if (!file.exists()) { + return null; + } + ParcelFileDescriptor modelFd = null; try { - final ParcelFileDescriptor modelFd = ParcelFileDescriptor.open( - file, ParcelFileDescriptor.MODE_READ_ONLY); + modelFd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY); + if (modelFd == null) { + return null; + } final int version = AnnotatorModel.getVersion(modelFd.getFd()); - final String supportedLocalesStr = - AnnotatorModel.getLocales(modelFd.getFd()); + final String supportedLocalesStr = AnnotatorModel.getLocales(modelFd.getFd()); if (supportedLocalesStr.isEmpty()) { Log.d(DEFAULT_LOG_TAG, "Ignoring " + file.getAbsolutePath()); return null; @@ -500,12 +587,13 @@ public final class TextClassifierImpl implements TextClassifier { for (String langTag : supportedLocalesStr.split(",")) { supportedLocales.add(Locale.forLanguageTag(langTag)); } - closeAndLogError(modelFd); return new ModelFile(path, file.getName(), version, supportedLocales, languageIndependent); } catch (FileNotFoundException e) { Log.e(DEFAULT_LOG_TAG, "Failed to peek " + file.getAbsolutePath(), e); return null; + } finally { + maybeCloseAndLogError(modelFd); } } @@ -557,12 +645,12 @@ public final class TextClassifierImpl implements TextClassifier { public boolean equals(Object other) { if (this == other) { return true; - } else if (other == null || !ModelFile.class.isAssignableFrom(other.getClass())) { - return false; - } else { + } + if (other instanceof ModelFile) { final ModelFile otherModel = (ModelFile) other; return mPath.equals(otherModel.mPath); } + return false; } @Override @@ -677,10 +765,12 @@ public final class TextClassifierImpl implements TextClassifier { @NonNull public static List<LabeledIntent> create( Context context, + String text, @Nullable Instant referenceTime, - AnnotatorModel.ClassificationResult classification, - String text) { - final String type = classification.getCollection().trim().toLowerCase(Locale.ENGLISH); + @Nullable AnnotatorModel.ClassificationResult classification) { + final String type = classification != null + ? classification.getCollection().trim().toLowerCase(Locale.ENGLISH) + : null; text = text.trim(); switch (type) { case TextClassifier.TYPE_EMAIL: diff --git a/core/tests/coretests/src/android/view/textclassifier/TextClassificationManagerTest.java b/core/tests/coretests/src/android/view/textclassifier/TextClassificationManagerTest.java index e891fc9d6134..8646c685c998 100644 --- a/core/tests/coretests/src/android/view/textclassifier/TextClassificationManagerTest.java +++ b/core/tests/coretests/src/android/view/textclassifier/TextClassificationManagerTest.java @@ -307,6 +307,24 @@ public class TextClassificationManagerTest { } @Test + public void testDetectLanguage() { + if (isTextClassifierDisabled()) return; + String text = "This is English text"; + TextLanguage.Request request = new TextLanguage.Request.Builder(text).build(); + TextLanguage textLanguage = mClassifier.detectLanguage(request); + assertThat(textLanguage, isTextLanguage("en")); + } + + @Test + public void testDetectLanguage_japanese() { + if (isTextClassifierDisabled()) return; + String text = "これは日本語のテキストです"; + TextLanguage.Request request = new TextLanguage.Request.Builder(text).build(); + TextLanguage textLanguage = mClassifier.detectLanguage(request); + assertThat(textLanguage, isTextLanguage("ja")); + } + + @Test public void testSetTextClassifier() { TextClassifier classifier = mock(TextClassifier.class); mTcm.setTextClassifier(classifier); @@ -444,4 +462,23 @@ public class TextClassificationManagerTest { } }; } + + private static Matcher<TextLanguage> isTextLanguage(final String languageTag) { + return new BaseMatcher<TextLanguage>() { + @Override + public boolean matches(Object o) { + if (o instanceof TextLanguage) { + TextLanguage result = (TextLanguage) o; + return result.getLocaleHypothesisCount() > 0 + && languageTag.equals(result.getLocale(0).toLanguageTag()); + } + return false; + } + + @Override + public void describeTo(Description description) { + description.appendText("locale=").appendValue(languageTag); + } + }; + } } |