diff options
| author | 2018-10-19 20:58:26 +0100 | |
|---|---|---|
| committer | 2018-10-24 15:55:12 +0100 | |
| commit | ee3a48eec0bd27aec554fdece6356f646706785c (patch) | |
| tree | fd1510f4cd0addeb17a94c62116851dddcefc822 | |
| parent | c2896a27fadc416458e883282bb0d8a0f81ee13a (diff) | |
Implement TextClassifierImpl.detectLanguage()
- Includes some fixes to handle null ParcelFileDescriptors.
- Closes fds immediately after the model has been loaded.
Bug: 116020587
Test: atest android.view.textclassifier.TextClassificationManagerTest
Change-Id: Ieb05d081847ac218d2a5b46db95cd512838f67ab
| -rw-r--r-- | core/java/android/view/textclassifier/TextClassifierImpl.java | 172 | ||||
| -rw-r--r-- | core/tests/coretests/src/android/view/textclassifier/TextClassificationManagerTest.java | 37 | 
2 files changed, 168 insertions, 41 deletions
diff --git a/core/java/android/view/textclassifier/TextClassifierImpl.java b/core/java/android/view/textclassifier/TextClassifierImpl.java index 3e240cfdb69f..7f1e443f4aa5 100644 --- a/core/java/android/view/textclassifier/TextClassifierImpl.java +++ b/core/java/android/view/textclassifier/TextClassifierImpl.java @@ -31,6 +31,7 @@ import android.content.Intent;  import android.content.pm.PackageManager;  import android.content.pm.ResolveInfo;  import android.graphics.drawable.Icon; +import android.icu.util.ULocale;  import android.net.Uri;  import android.os.Bundle;  import android.os.LocaleList; @@ -45,6 +46,7 @@ import com.android.internal.util.IndentingPrintWriter;  import com.android.internal.util.Preconditions;  import com.google.android.textclassifier.AnnotatorModel; +import com.google.android.textclassifier.LangIdModel;  import java.io.File;  import java.io.FileNotFoundException; @@ -83,6 +85,9 @@ public final class TextClassifierImpl implements TextClassifier {      private static final String MODEL_FILE_REGEX = "textclassifier\\.(.*)\\.model";      private static final String UPDATED_MODEL_FILE_PATH =              "/data/misc/textclassifier/textclassifier.model"; +    private static final String LANG_ID_MODEL_FILE_PATH = "/etc/textclassifier/lang_id.model"; +    private static final String UPDATED_LANG_ID_MODEL_FILE_PATH = +            "/data/misc/textclassifier/lang_id.model";      private final Context mContext;      private final TextClassifier mFallback; @@ -94,7 +99,9 @@ public final class TextClassifierImpl implements TextClassifier {      @GuardedBy("mLock") // Do not access outside this lock.      private ModelFile mModel;      @GuardedBy("mLock") // Do not access outside this lock. -    private AnnotatorModel mNative; +    private AnnotatorModel mAnnotatorImpl; +    @GuardedBy("mLock") // Do not access outside this lock. +    private LangIdModel mLangIdImpl;      private final Object mLoggerLock = new Object();      @GuardedBy("mLoggerLock") // Do not access outside this lock. @@ -127,14 +134,15 @@ public final class TextClassifierImpl implements TextClassifier {                      && rangeLength <= mSettings.getSuggestSelectionMaxRangeLength()) {                  final String localesString = concatenateLocales(request.getDefaultLocales());                  final ZonedDateTime refTime = ZonedDateTime.now(); -                final AnnotatorModel nativeImpl = getNative(request.getDefaultLocales()); +                final AnnotatorModel annotatorImpl = +                        getAnnotatorImpl(request.getDefaultLocales());                  final int start;                  final int end;                  if (mSettings.isModelDarkLaunchEnabled() && !request.isDarkLaunchAllowed()) {                      start = request.getStartIndex();                      end = request.getEndIndex();                  } else { -                    final int[] startEnd = nativeImpl.suggestSelection( +                    final int[] startEnd = annotatorImpl.suggestSelection(                              string, request.getStartIndex(), request.getEndIndex(),                              new AnnotatorModel.SelectionOptions(localesString));                      start = startEnd[0]; @@ -145,7 +153,7 @@ public final class TextClassifierImpl implements TextClassifier {                          && start <= request.getStartIndex() && end >= request.getEndIndex()) {                      final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);                      final AnnotatorModel.ClassificationResult[] results = -                            nativeImpl.classifyText( +                            annotatorImpl.classifyText(                                      string, start, end,                                      new AnnotatorModel.ClassificationOptions(                                              refTime.toInstant().toEpochMilli(), @@ -187,7 +195,7 @@ public final class TextClassifierImpl implements TextClassifier {                  final ZonedDateTime refTime = request.getReferenceTime() != null                          ? request.getReferenceTime() : ZonedDateTime.now();                  final AnnotatorModel.ClassificationResult[] results = -                        getNative(request.getDefaultLocales()) +                        getAnnotatorImpl(request.getDefaultLocales())                                  .classifyText(                                          string, request.getStartIndex(), request.getEndIndex(),                                          new AnnotatorModel.ClassificationOptions( @@ -230,10 +238,10 @@ public final class TextClassifierImpl implements TextClassifier {                      ? request.getEntityConfig().resolveEntityListModifications(                              getEntitiesForHints(request.getEntityConfig().getHints()))                      : mSettings.getEntityListDefault(); -            final AnnotatorModel nativeImpl = -                    getNative(request.getDefaultLocales()); +            final AnnotatorModel annotatorImpl = +                    getAnnotatorImpl(request.getDefaultLocales());              final AnnotatorModel.AnnotatedSpan[] annotations = -                    nativeImpl.annotate( +                    annotatorImpl.annotate(                          textString,                          new AnnotatorModel.AnnotationOptions(                                  refTime.toInstant().toEpochMilli(), @@ -288,6 +296,7 @@ public final class TextClassifierImpl implements TextClassifier {          }      } +    /** @inheritDoc */      @Override      public void onSelectionEvent(SelectionEvent event) {          Preconditions.checkNotNull(event); @@ -299,7 +308,29 @@ public final class TextClassifierImpl implements TextClassifier {          }      } -    private AnnotatorModel getNative(LocaleList localeList) +    /** @inheritDoc */ +    @Override +    public TextLanguage detectLanguage(@NonNull TextLanguage.Request request) { +        Preconditions.checkNotNull(request); +        Utils.checkMainThread(); +        try { +            final TextLanguage.Builder builder = new TextLanguage.Builder(); +            final LangIdModel.LanguageResult[] langResults = +                    getLangIdImpl().detectLanguages(request.getText().toString()); +            for (int i = 0; i < langResults.length; i++) { +                builder.putLocale( +                        ULocale.forLanguageTag(langResults[i].getLanguage()), +                        langResults[i].getScore()); +            } +            return builder.build(); +        } catch (Throwable t) { +            // Avoid throwing from this method. Log the error. +            Log.e(LOG_TAG, "Error detecting text language.", t); +        } +        return mFallback.detectLanguage(request); +    } + +    private AnnotatorModel getAnnotatorImpl(LocaleList localeList)              throws FileNotFoundException {          synchronized (mLock) {              localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList; @@ -307,31 +338,79 @@ public final class TextClassifierImpl implements TextClassifier {              if (bestModel == null) {                  throw new FileNotFoundException("No model for " + localeList.toLanguageTags());              } -            if (mNative == null || !Objects.equals(mModel, bestModel)) { +            if (mAnnotatorImpl == null || !Objects.equals(mModel, bestModel)) {                  Log.d(DEFAULT_LOG_TAG, "Loading " + bestModel); -                destroyNativeIfExistsLocked(); +                destroyAnnotatorImplIfExistsLocked();                  final ParcelFileDescriptor fd = ParcelFileDescriptor.open(                          new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY); -                mNative = new AnnotatorModel(fd.getFd()); -                closeAndLogError(fd); -                mModel = bestModel; +                try { +                    if (fd != null) { +                        mAnnotatorImpl = new AnnotatorModel(fd.getFd()); +                        mModel = bestModel; +                    } +                } finally { +                    maybeCloseAndLogError(fd); +                }              } -            return mNative; +            return mAnnotatorImpl;          }      } -    private String createId(String text, int start, int end) { +    @GuardedBy("mLock") // Do not call outside this lock. +    private void destroyAnnotatorImplIfExistsLocked() { +        if (mAnnotatorImpl != null) { +            mAnnotatorImpl.close(); +            mAnnotatorImpl = null; +        } +    } + +    private LangIdModel getLangIdImpl() throws FileNotFoundException {          synchronized (mLock) { -            return SelectionSessionLogger.createId(text, start, end, mContext, mModel.getVersion(), -                    mModel.getSupportedLocales()); +            if (mLangIdImpl == null) { +                ParcelFileDescriptor factoryFd = null; +                ParcelFileDescriptor updateFd = null; +                try { +                    int factoryVersion = -1; +                    int updateVersion = factoryVersion; +                    final File factoryFile = new File(LANG_ID_MODEL_FILE_PATH); +                    if (factoryFile.exists()) { +                        factoryFd = ParcelFileDescriptor.open( +                                factoryFile, ParcelFileDescriptor.MODE_READ_ONLY); +                        // TODO: Uncomment when method is implemented: +                        // if (factoryFd != null) { +                        //     factoryVersion = LangIdModel.getVersion(factoryFd.getFd()); +                        // } +                    } +                    final File updateFile = new File(UPDATED_LANG_ID_MODEL_FILE_PATH); +                    if (updateFile.exists()) { +                        updateFd = ParcelFileDescriptor.open( +                                updateFile, ParcelFileDescriptor.MODE_READ_ONLY); +                        // TODO: Uncomment when method is implemented: +                        // if (updateFd != null) { +                        //     updateVersion = LangIdModel.getVersion(updateFd.getFd()); +                        // } +                    } + +                    if (updateVersion > factoryVersion) { +                        mLangIdImpl = new LangIdModel(updateFd.getFd()); +                    } else if (factoryFd != null) { +                        mLangIdImpl = new LangIdModel(factoryFd.getFd()); +                    } else { +                        throw new FileNotFoundException("Language detection model not found"); +                    } +                } finally { +                    maybeCloseAndLogError(factoryFd); +                    maybeCloseAndLogError(updateFd); +                } +            } +            return mLangIdImpl;          }      } -    @GuardedBy("mLock") // Do not call outside this lock. -    private void destroyNativeIfExistsLocked() { -        if (mNative != null) { -            mNative.close(); -            mNative = null; +    private String createId(String text, int start, int end) { +        synchronized (mLock) { +            return SelectionSessionLogger.createId(text, start, end, mContext, mModel.getVersion(), +                    mModel.getSupportedLocales());          }      } @@ -407,20 +486,19 @@ public final class TextClassifierImpl implements TextClassifier {                  .setText(classifiedText);          final int size = classifications.length; -        AnnotatorModel.ClassificationResult highestScoringResult = null; -        float highestScore = Float.MIN_VALUE; +        AnnotatorModel.ClassificationResult highestScoringResult = +                size > 0 ? classifications[0] : null;          for (int i = 0; i < size; i++) {              builder.setEntityType(classifications[i].getCollection(),                                    classifications[i].getScore()); -            if (classifications[i].getScore() > highestScore) { +            if (classifications[i].getScore() > highestScoringResult.getScore()) {                  highestScoringResult = classifications[i]; -                highestScore = classifications[i].getScore();              }          }          boolean isPrimaryAction = true;          for (LabeledIntent labeledIntent : IntentFactory.create( -                mContext, referenceTime, highestScoringResult, classifiedText)) { +                mContext, classifiedText, referenceTime, highestScoringResult)) {              final RemoteAction action = labeledIntent.asRemoteAction(mContext);              if (action == null) {                  continue; @@ -461,9 +539,13 @@ public final class TextClassifierImpl implements TextClassifier {      }      /** -     * Closes the ParcelFileDescriptor and logs any errors that occur. +     * Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur.       */ -    private static void closeAndLogError(ParcelFileDescriptor fd) { +    private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) { +        if (fd == null) { +            return; +        } +          try {              fd.close();          } catch (IOException e) { @@ -485,12 +567,17 @@ public final class TextClassifierImpl implements TextClassifier {          /** Returns null if the path did not point to a compatible model. */          static @Nullable ModelFile fromPath(String path) {              final File file = new File(path); +            if (!file.exists()) { +                return null; +            } +            ParcelFileDescriptor modelFd = null;              try { -                final ParcelFileDescriptor modelFd = ParcelFileDescriptor.open( -                        file, ParcelFileDescriptor.MODE_READ_ONLY); +                modelFd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY); +                if (modelFd == null) { +                    return null; +                }                  final int version = AnnotatorModel.getVersion(modelFd.getFd()); -                final String supportedLocalesStr = -                        AnnotatorModel.getLocales(modelFd.getFd()); +                final String supportedLocalesStr = AnnotatorModel.getLocales(modelFd.getFd());                  if (supportedLocalesStr.isEmpty()) {                      Log.d(DEFAULT_LOG_TAG, "Ignoring " + file.getAbsolutePath());                      return null; @@ -500,12 +587,13 @@ public final class TextClassifierImpl implements TextClassifier {                  for (String langTag : supportedLocalesStr.split(",")) {                      supportedLocales.add(Locale.forLanguageTag(langTag));                  } -                closeAndLogError(modelFd);                  return new ModelFile(path, file.getName(), version, supportedLocales,                                       languageIndependent);              } catch (FileNotFoundException e) {                  Log.e(DEFAULT_LOG_TAG, "Failed to peek " + file.getAbsolutePath(), e);                  return null; +            } finally { +                maybeCloseAndLogError(modelFd);              }          } @@ -557,12 +645,12 @@ public final class TextClassifierImpl implements TextClassifier {          public boolean equals(Object other) {              if (this == other) {                  return true; -            } else if (other == null || !ModelFile.class.isAssignableFrom(other.getClass())) { -                return false; -            } else { +            } +            if (other instanceof ModelFile) {                  final ModelFile otherModel = (ModelFile) other;                  return mPath.equals(otherModel.mPath);              } +            return false;          }          @Override @@ -677,10 +765,12 @@ public final class TextClassifierImpl implements TextClassifier {          @NonNull          public static List<LabeledIntent> create(                  Context context, +                String text,                  @Nullable Instant referenceTime, -                AnnotatorModel.ClassificationResult classification, -                String text) { -            final String type = classification.getCollection().trim().toLowerCase(Locale.ENGLISH); +                @Nullable AnnotatorModel.ClassificationResult classification) { +            final String type = classification != null +                    ? classification.getCollection().trim().toLowerCase(Locale.ENGLISH) +                    : null;              text = text.trim();              switch (type) {                  case TextClassifier.TYPE_EMAIL: diff --git a/core/tests/coretests/src/android/view/textclassifier/TextClassificationManagerTest.java b/core/tests/coretests/src/android/view/textclassifier/TextClassificationManagerTest.java index e891fc9d6134..8646c685c998 100644 --- a/core/tests/coretests/src/android/view/textclassifier/TextClassificationManagerTest.java +++ b/core/tests/coretests/src/android/view/textclassifier/TextClassificationManagerTest.java @@ -307,6 +307,24 @@ public class TextClassificationManagerTest {      }      @Test +    public void testDetectLanguage() { +        if (isTextClassifierDisabled()) return; +        String text = "This is English text"; +        TextLanguage.Request request = new TextLanguage.Request.Builder(text).build(); +        TextLanguage textLanguage = mClassifier.detectLanguage(request); +        assertThat(textLanguage, isTextLanguage("en")); +    } + +    @Test +    public void testDetectLanguage_japanese() { +        if (isTextClassifierDisabled()) return; +        String text = "これは日本語のテキストです"; +        TextLanguage.Request request = new TextLanguage.Request.Builder(text).build(); +        TextLanguage textLanguage = mClassifier.detectLanguage(request); +        assertThat(textLanguage, isTextLanguage("ja")); +    } + +    @Test      public void testSetTextClassifier() {          TextClassifier classifier = mock(TextClassifier.class);          mTcm.setTextClassifier(classifier); @@ -444,4 +462,23 @@ public class TextClassificationManagerTest {              }          };      } + +    private static Matcher<TextLanguage> isTextLanguage(final String languageTag) { +        return new BaseMatcher<TextLanguage>() { +            @Override +            public boolean matches(Object o) { +                if (o instanceof TextLanguage) { +                    TextLanguage result = (TextLanguage) o; +                    return result.getLocaleHypothesisCount() > 0 +                            && languageTag.equals(result.getLocale(0).toLanguageTag()); +                } +                return false; +            } + +            @Override +            public void describeTo(Description description) { +                description.appendText("locale=").appendValue(languageTag); +            } +        }; +    }  }  |