From bbe43dfd97c01364e46df452be4c99536d64e4fb Mon Sep 17 00:00:00 2001 From: Jan Althaus Date: Thu, 30 Nov 2017 15:01:40 +0100 Subject: Storage refactor for EntityConfidence Caching the sorted entity list so users don't need to be careful to cache the result of getEntities (previously dont by TextSelection and TextClassification, but not TextLink). Also switched to ArrayMap as it's better suited for small maps like the ones generated by the classifier. Test: Ran FrameworksCoreTests Change-Id: I08cc9f72146ccab88b6a3624f3775a366c814f7a --- .../view/textclassifier/EntityConfidence.java | 57 ++++++++++------------ .../view/textclassifier/TextClassification.java | 15 +++--- .../android/view/textclassifier/TextLinks.java | 6 +-- .../android/view/textclassifier/TextSelection.java | 17 +++---- 4 files changed, 43 insertions(+), 52 deletions(-) diff --git a/core/java/android/view/textclassifier/EntityConfidence.java b/core/java/android/view/textclassifier/EntityConfidence.java index 0589d204ac3f..19660d95e927 100644 --- a/core/java/android/view/textclassifier/EntityConfidence.java +++ b/core/java/android/view/textclassifier/EntityConfidence.java @@ -18,13 +18,12 @@ package android.view.textclassifier; import android.annotation.FloatRange; import android.annotation.NonNull; +import android.util.ArrayMap; import com.android.internal.util.Preconditions; import java.util.ArrayList; import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; import java.util.List; import java.util.Map; @@ -36,42 +35,43 @@ import java.util.Map; */ final class EntityConfidence { - private final Map mEntityConfidence = new HashMap<>(); - - private final Comparator mEntityComparator = (e1, e2) -> { - float score1 = mEntityConfidence.get(e1); - float score2 = mEntityConfidence.get(e2); - if (score1 > score2) { - return -1; - } - if (score1 < score2) { - return 1; - } - return 0; - }; + private final ArrayMap mEntityConfidence = new ArrayMap<>(); + private final ArrayList mSortedEntities = new ArrayList<>(); EntityConfidence() {} EntityConfidence(@NonNull EntityConfidence source) { Preconditions.checkNotNull(source); mEntityConfidence.putAll(source.mEntityConfidence); + mSortedEntities.addAll(source.mSortedEntities); } /** - * Sets an entity type for the classified text and assigns a confidence score. + * Constructs an EntityConfidence from a map of entity to confidence. * - * @param confidenceScore a value from 0 (low confidence) to 1 (high confidence). - * 0 implies the entity does not exist for the classified text. - * Values greater than 1 are clamped to 1. + * Map entries that have 0 confidence are removed, and values greater than 1 are clamped to 1. + * + * @param source a map from entity to a confidence value in the range 0 (low confidence) to + * 1 (high confidence). */ - public void setEntityType( - @NonNull T type, @FloatRange(from = 0.0, to = 1.0) float confidenceScore) { - Preconditions.checkNotNull(type); - if (confidenceScore > 0) { - mEntityConfidence.put(type, Math.min(1, confidenceScore)); - } else { - mEntityConfidence.remove(type); + EntityConfidence(@NonNull Map source) { + Preconditions.checkNotNull(source); + + // Prune non-existent entities and clamp to 1. + mEntityConfidence.ensureCapacity(source.size()); + for (Map.Entry it : source.entrySet()) { + if (it.getValue() <= 0) continue; + mEntityConfidence.put(it.getKey(), Math.min(1, it.getValue())); } + + // Create a list of entities sorted by decreasing confidence for getEntities(). + mSortedEntities.ensureCapacity(mEntityConfidence.size()); + mSortedEntities.addAll(mEntityConfidence.keySet()); + mSortedEntities.sort((e1, e2) -> { + float score1 = mEntityConfidence.get(e1); + float score2 = mEntityConfidence.get(e2); + return Float.compare(score2, score1); + }); } /** @@ -80,10 +80,7 @@ final class EntityConfidence { */ @NonNull public List getEntities() { - List entities = new ArrayList<>(mEntityConfidence.size()); - entities.addAll(mEntityConfidence.keySet()); - entities.sort(mEntityComparator); - return Collections.unmodifiableList(entities); + return Collections.unmodifiableList(mSortedEntities); } /** diff --git a/core/java/android/view/textclassifier/TextClassification.java b/core/java/android/view/textclassifier/TextClassification.java index f675c355638c..89163238ea4d 100644 --- a/core/java/android/view/textclassifier/TextClassification.java +++ b/core/java/android/view/textclassifier/TextClassification.java @@ -24,6 +24,7 @@ import android.content.Context; import android.content.Intent; import android.graphics.drawable.Drawable; import android.os.LocaleList; +import android.util.ArrayMap; import android.view.View.OnClickListener; import android.view.textclassifier.TextClassifier.EntityType; @@ -32,6 +33,7 @@ import com.android.internal.util.Preconditions; import java.util.ArrayList; import java.util.List; import java.util.Locale; +import java.util.Map; /** * Information for generating a widget to handle classified text. @@ -95,7 +97,6 @@ public final class TextClassification { @NonNull private final List mIntents; @NonNull private final List mOnClickListeners; @NonNull private final EntityConfidence mEntityConfidence; - @NonNull private final List mEntities; private int mLogType; @NonNull private final String mVersionInfo; @@ -105,7 +106,7 @@ public final class TextClassification { @NonNull List labels, @NonNull List intents, @NonNull List onClickListeners, - @NonNull EntityConfidence entityConfidence, + @NonNull Map entityConfidence, int logType, @NonNull String versionInfo) { Preconditions.checkArgument(labels.size() == intents.size()); @@ -117,7 +118,6 @@ public final class TextClassification { mIntents = intents; mOnClickListeners = onClickListeners; mEntityConfidence = new EntityConfidence<>(entityConfidence); - mEntities = mEntityConfidence.getEntities(); mLogType = logType; mVersionInfo = versionInfo; } @@ -135,7 +135,7 @@ public final class TextClassification { */ @IntRange(from = 0) public int getEntityCount() { - return mEntities.size(); + return mEntityConfidence.getEntities().size(); } /** @@ -147,7 +147,7 @@ public final class TextClassification { */ @NonNull public @EntityType String getEntity(int index) { - return mEntities.get(index); + return mEntityConfidence.getEntities().get(index); } /** @@ -311,8 +311,7 @@ public final class TextClassification { @NonNull private final List mLabels = new ArrayList<>(); @NonNull private final List mIntents = new ArrayList<>(); @NonNull private final List mOnClickListeners = new ArrayList<>(); - @NonNull private final EntityConfidence mEntityConfidence = - new EntityConfidence<>(); + @NonNull private final Map mEntityConfidence = new ArrayMap<>(); private int mLogType; @NonNull private String mVersionInfo = ""; @@ -334,7 +333,7 @@ public final class TextClassification { public Builder setEntityType( @NonNull @EntityType String type, @FloatRange(from = 0.0, to = 1.0) float confidenceScore) { - mEntityConfidence.setEntityType(type, confidenceScore); + mEntityConfidence.put(type, confidenceScore); return this; } diff --git a/core/java/android/view/textclassifier/TextLinks.java b/core/java/android/view/textclassifier/TextLinks.java index 76748d2b191a..0e039e35367e 100644 --- a/core/java/android/view/textclassifier/TextLinks.java +++ b/core/java/android/view/textclassifier/TextLinks.java @@ -103,11 +103,7 @@ public final class TextLinks { mOriginalText = originalText; mStart = start; mEnd = end; - mEntityScores = new EntityConfidence<>(); - - for (Map.Entry entry : entityScores.entrySet()) { - mEntityScores.setEntityType(entry.getKey(), entry.getValue()); - } + mEntityScores = new EntityConfidence<>(entityScores); } /** diff --git a/core/java/android/view/textclassifier/TextSelection.java b/core/java/android/view/textclassifier/TextSelection.java index 480b27a73fc1..ced4018bcd82 100644 --- a/core/java/android/view/textclassifier/TextSelection.java +++ b/core/java/android/view/textclassifier/TextSelection.java @@ -21,12 +21,13 @@ import android.annotation.IntRange; import android.annotation.NonNull; import android.annotation.Nullable; import android.os.LocaleList; +import android.util.ArrayMap; import android.view.textclassifier.TextClassifier.EntityType; import com.android.internal.util.Preconditions; -import java.util.List; import java.util.Locale; +import java.util.Map; /** * Information about where text selection should be. @@ -36,7 +37,6 @@ public final class TextSelection { private final int mStartIndex; private final int mEndIndex; @NonNull private final EntityConfidence mEntityConfidence; - @NonNull private final List mEntities; @NonNull private final String mLogSource; @NonNull private final String mVersionInfo; @@ -46,7 +46,6 @@ public final class TextSelection { mStartIndex = startIndex; mEndIndex = endIndex; mEntityConfidence = new EntityConfidence<>(entityConfidence); - mEntities = mEntityConfidence.getEntities(); mLogSource = logSource; mVersionInfo = versionInfo; } @@ -70,7 +69,7 @@ public final class TextSelection { */ @IntRange(from = 0) public int getEntityCount() { - return mEntities.size(); + return mEntityConfidence.getEntities().size(); } /** @@ -82,7 +81,7 @@ public final class TextSelection { */ @NonNull public @EntityType String getEntity(int index) { - return mEntities.get(index); + return mEntityConfidence.getEntities().get(index); } /** @@ -126,8 +125,7 @@ public final class TextSelection { private final int mStartIndex; private final int mEndIndex; - @NonNull private final EntityConfidence mEntityConfidence = - new EntityConfidence<>(); + @NonNull private final Map mEntityConfidence = new ArrayMap<>(); @NonNull private String mLogSource = ""; @NonNull private String mVersionInfo = ""; @@ -154,7 +152,7 @@ public final class TextSelection { public Builder setEntityType( @NonNull @EntityType String type, @FloatRange(from = 0.0, to = 1.0) float confidenceScore) { - mEntityConfidence.setEntityType(type, confidenceScore); + mEntityConfidence.put(type, confidenceScore); return this; } @@ -181,7 +179,8 @@ public final class TextSelection { */ public TextSelection build() { return new TextSelection( - mStartIndex, mEndIndex, mEntityConfidence, mLogSource, mVersionInfo); + mStartIndex, mEndIndex, new EntityConfidence<>(mEntityConfidence), mLogSource, + mVersionInfo); } } -- cgit v1.2.3-59-g8ed1b