diff options
| author | 2017-11-30 15:01:40 +0100 | |
|---|---|---|
| committer | 2017-12-01 10:37:01 +0100 | |
| commit | bbe43dfd97c01364e46df452be4c99536d64e4fb (patch) | |
| tree | 2964448e11b27f7561efd369200eeee6b21eedb5 | |
| parent | 05013b377266f9e4e2651c6aa819960479dc3676 (diff) | |
Storage refactor for EntityConfidence
Caching the sorted entity list so users don't need to be careful to cache
the result of getEntities (previously dont by TextSelection and
TextClassification, but not TextLink). Also switched to ArrayMap as it's
better suited for small maps like the ones generated by the classifier.
Test: Ran FrameworksCoreTests
Change-Id: I08cc9f72146ccab88b6a3624f3775a366c814f7a
4 files changed, 43 insertions, 52 deletions
diff --git a/core/java/android/view/textclassifier/EntityConfidence.java b/core/java/android/view/textclassifier/EntityConfidence.java index 0589d204ac3f..19660d95e927 100644 --- a/core/java/android/view/textclassifier/EntityConfidence.java +++ b/core/java/android/view/textclassifier/EntityConfidence.java @@ -18,13 +18,12 @@ package android.view.textclassifier; import android.annotation.FloatRange; import android.annotation.NonNull; +import android.util.ArrayMap; import com.android.internal.util.Preconditions; import java.util.ArrayList; import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; import java.util.List; import java.util.Map; @@ -36,42 +35,43 @@ import java.util.Map; */ final class EntityConfidence<T> { - private final Map<T, Float> mEntityConfidence = new HashMap<>(); - - private final Comparator<T> mEntityComparator = (e1, e2) -> { - float score1 = mEntityConfidence.get(e1); - float score2 = mEntityConfidence.get(e2); - if (score1 > score2) { - return -1; - } - if (score1 < score2) { - return 1; - } - return 0; - }; + private final ArrayMap<T, Float> mEntityConfidence = new ArrayMap<>(); + private final ArrayList<T> mSortedEntities = new ArrayList<>(); EntityConfidence() {} EntityConfidence(@NonNull EntityConfidence<T> source) { Preconditions.checkNotNull(source); mEntityConfidence.putAll(source.mEntityConfidence); + mSortedEntities.addAll(source.mSortedEntities); } /** - * Sets an entity type for the classified text and assigns a confidence score. + * Constructs an EntityConfidence from a map of entity to confidence. * - * @param confidenceScore a value from 0 (low confidence) to 1 (high confidence). - * 0 implies the entity does not exist for the classified text. - * Values greater than 1 are clamped to 1. + * Map entries that have 0 confidence are removed, and values greater than 1 are clamped to 1. + * + * @param source a map from entity to a confidence value in the range 0 (low confidence) to + * 1 (high confidence). */ - public void setEntityType( - @NonNull T type, @FloatRange(from = 0.0, to = 1.0) float confidenceScore) { - Preconditions.checkNotNull(type); - if (confidenceScore > 0) { - mEntityConfidence.put(type, Math.min(1, confidenceScore)); - } else { - mEntityConfidence.remove(type); + EntityConfidence(@NonNull Map<T, Float> source) { + Preconditions.checkNotNull(source); + + // Prune non-existent entities and clamp to 1. + mEntityConfidence.ensureCapacity(source.size()); + for (Map.Entry<T, Float> it : source.entrySet()) { + if (it.getValue() <= 0) continue; + mEntityConfidence.put(it.getKey(), Math.min(1, it.getValue())); } + + // Create a list of entities sorted by decreasing confidence for getEntities(). + mSortedEntities.ensureCapacity(mEntityConfidence.size()); + mSortedEntities.addAll(mEntityConfidence.keySet()); + mSortedEntities.sort((e1, e2) -> { + float score1 = mEntityConfidence.get(e1); + float score2 = mEntityConfidence.get(e2); + return Float.compare(score2, score1); + }); } /** @@ -80,10 +80,7 @@ final class EntityConfidence<T> { */ @NonNull public List<T> getEntities() { - List<T> entities = new ArrayList<>(mEntityConfidence.size()); - entities.addAll(mEntityConfidence.keySet()); - entities.sort(mEntityComparator); - return Collections.unmodifiableList(entities); + return Collections.unmodifiableList(mSortedEntities); } /** diff --git a/core/java/android/view/textclassifier/TextClassification.java b/core/java/android/view/textclassifier/TextClassification.java index f675c355638c..89163238ea4d 100644 --- a/core/java/android/view/textclassifier/TextClassification.java +++ b/core/java/android/view/textclassifier/TextClassification.java @@ -24,6 +24,7 @@ import android.content.Context; import android.content.Intent; import android.graphics.drawable.Drawable; import android.os.LocaleList; +import android.util.ArrayMap; import android.view.View.OnClickListener; import android.view.textclassifier.TextClassifier.EntityType; @@ -32,6 +33,7 @@ import com.android.internal.util.Preconditions; import java.util.ArrayList; import java.util.List; import java.util.Locale; +import java.util.Map; /** * Information for generating a widget to handle classified text. @@ -95,7 +97,6 @@ public final class TextClassification { @NonNull private final List<Intent> mIntents; @NonNull private final List<OnClickListener> mOnClickListeners; @NonNull private final EntityConfidence<String> mEntityConfidence; - @NonNull private final List<String> mEntities; private int mLogType; @NonNull private final String mVersionInfo; @@ -105,7 +106,7 @@ public final class TextClassification { @NonNull List<String> labels, @NonNull List<Intent> intents, @NonNull List<OnClickListener> onClickListeners, - @NonNull EntityConfidence<String> entityConfidence, + @NonNull Map<String, Float> entityConfidence, int logType, @NonNull String versionInfo) { Preconditions.checkArgument(labels.size() == intents.size()); @@ -117,7 +118,6 @@ public final class TextClassification { mIntents = intents; mOnClickListeners = onClickListeners; mEntityConfidence = new EntityConfidence<>(entityConfidence); - mEntities = mEntityConfidence.getEntities(); mLogType = logType; mVersionInfo = versionInfo; } @@ -135,7 +135,7 @@ public final class TextClassification { */ @IntRange(from = 0) public int getEntityCount() { - return mEntities.size(); + return mEntityConfidence.getEntities().size(); } /** @@ -147,7 +147,7 @@ public final class TextClassification { */ @NonNull public @EntityType String getEntity(int index) { - return mEntities.get(index); + return mEntityConfidence.getEntities().get(index); } /** @@ -311,8 +311,7 @@ public final class TextClassification { @NonNull private final List<String> mLabels = new ArrayList<>(); @NonNull private final List<Intent> mIntents = new ArrayList<>(); @NonNull private final List<OnClickListener> mOnClickListeners = new ArrayList<>(); - @NonNull private final EntityConfidence<String> mEntityConfidence = - new EntityConfidence<>(); + @NonNull private final Map<String, Float> mEntityConfidence = new ArrayMap<>(); private int mLogType; @NonNull private String mVersionInfo = ""; @@ -334,7 +333,7 @@ public final class TextClassification { public Builder setEntityType( @NonNull @EntityType String type, @FloatRange(from = 0.0, to = 1.0) float confidenceScore) { - mEntityConfidence.setEntityType(type, confidenceScore); + mEntityConfidence.put(type, confidenceScore); return this; } diff --git a/core/java/android/view/textclassifier/TextLinks.java b/core/java/android/view/textclassifier/TextLinks.java index 76748d2b191a..0e039e35367e 100644 --- a/core/java/android/view/textclassifier/TextLinks.java +++ b/core/java/android/view/textclassifier/TextLinks.java @@ -103,11 +103,7 @@ public final class TextLinks { mOriginalText = originalText; mStart = start; mEnd = end; - mEntityScores = new EntityConfidence<>(); - - for (Map.Entry<String, Float> entry : entityScores.entrySet()) { - mEntityScores.setEntityType(entry.getKey(), entry.getValue()); - } + mEntityScores = new EntityConfidence<>(entityScores); } /** diff --git a/core/java/android/view/textclassifier/TextSelection.java b/core/java/android/view/textclassifier/TextSelection.java index 480b27a73fc1..ced4018bcd82 100644 --- a/core/java/android/view/textclassifier/TextSelection.java +++ b/core/java/android/view/textclassifier/TextSelection.java @@ -21,12 +21,13 @@ import android.annotation.IntRange; import android.annotation.NonNull; import android.annotation.Nullable; import android.os.LocaleList; +import android.util.ArrayMap; import android.view.textclassifier.TextClassifier.EntityType; import com.android.internal.util.Preconditions; -import java.util.List; import java.util.Locale; +import java.util.Map; /** * Information about where text selection should be. @@ -36,7 +37,6 @@ public final class TextSelection { private final int mStartIndex; private final int mEndIndex; @NonNull private final EntityConfidence<String> mEntityConfidence; - @NonNull private final List<String> mEntities; @NonNull private final String mLogSource; @NonNull private final String mVersionInfo; @@ -46,7 +46,6 @@ public final class TextSelection { mStartIndex = startIndex; mEndIndex = endIndex; mEntityConfidence = new EntityConfidence<>(entityConfidence); - mEntities = mEntityConfidence.getEntities(); mLogSource = logSource; mVersionInfo = versionInfo; } @@ -70,7 +69,7 @@ public final class TextSelection { */ @IntRange(from = 0) public int getEntityCount() { - return mEntities.size(); + return mEntityConfidence.getEntities().size(); } /** @@ -82,7 +81,7 @@ public final class TextSelection { */ @NonNull public @EntityType String getEntity(int index) { - return mEntities.get(index); + return mEntityConfidence.getEntities().get(index); } /** @@ -126,8 +125,7 @@ public final class TextSelection { private final int mStartIndex; private final int mEndIndex; - @NonNull private final EntityConfidence<String> mEntityConfidence = - new EntityConfidence<>(); + @NonNull private final Map<String, Float> mEntityConfidence = new ArrayMap<>(); @NonNull private String mLogSource = ""; @NonNull private String mVersionInfo = ""; @@ -154,7 +152,7 @@ public final class TextSelection { public Builder setEntityType( @NonNull @EntityType String type, @FloatRange(from = 0.0, to = 1.0) float confidenceScore) { - mEntityConfidence.setEntityType(type, confidenceScore); + mEntityConfidence.put(type, confidenceScore); return this; } @@ -181,7 +179,8 @@ public final class TextSelection { */ public TextSelection build() { return new TextSelection( - mStartIndex, mEndIndex, mEntityConfidence, mLogSource, mVersionInfo); + mStartIndex, mEndIndex, new EntityConfidence<>(mEntityConfidence), mLogSource, + mVersionInfo); } } |