summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Jan Althaus <jalt@google.com> 2017-11-30 15:01:40 +0100
committer Jan Althaus <jalt@google.com> 2017-12-01 10:37:01 +0100
commitbbe43dfd97c01364e46df452be4c99536d64e4fb (patch)
tree2964448e11b27f7561efd369200eeee6b21eedb5
parent05013b377266f9e4e2651c6aa819960479dc3676 (diff)
Storage refactor for EntityConfidence
Caching the sorted entity list so users don't need to be careful to cache the result of getEntities (previously dont by TextSelection and TextClassification, but not TextLink). Also switched to ArrayMap as it's better suited for small maps like the ones generated by the classifier. Test: Ran FrameworksCoreTests Change-Id: I08cc9f72146ccab88b6a3624f3775a366c814f7a
-rw-r--r--core/java/android/view/textclassifier/EntityConfidence.java57
-rw-r--r--core/java/android/view/textclassifier/TextClassification.java15
-rw-r--r--core/java/android/view/textclassifier/TextLinks.java6
-rw-r--r--core/java/android/view/textclassifier/TextSelection.java17
4 files changed, 43 insertions, 52 deletions
diff --git a/core/java/android/view/textclassifier/EntityConfidence.java b/core/java/android/view/textclassifier/EntityConfidence.java
index 0589d204ac3f..19660d95e927 100644
--- a/core/java/android/view/textclassifier/EntityConfidence.java
+++ b/core/java/android/view/textclassifier/EntityConfidence.java
@@ -18,13 +18,12 @@ package android.view.textclassifier;
import android.annotation.FloatRange;
import android.annotation.NonNull;
+import android.util.ArrayMap;
import com.android.internal.util.Preconditions;
import java.util.ArrayList;
import java.util.Collections;
-import java.util.Comparator;
-import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -36,42 +35,43 @@ import java.util.Map;
*/
final class EntityConfidence<T> {
- private final Map<T, Float> mEntityConfidence = new HashMap<>();
-
- private final Comparator<T> mEntityComparator = (e1, e2) -> {
- float score1 = mEntityConfidence.get(e1);
- float score2 = mEntityConfidence.get(e2);
- if (score1 > score2) {
- return -1;
- }
- if (score1 < score2) {
- return 1;
- }
- return 0;
- };
+ private final ArrayMap<T, Float> mEntityConfidence = new ArrayMap<>();
+ private final ArrayList<T> mSortedEntities = new ArrayList<>();
EntityConfidence() {}
EntityConfidence(@NonNull EntityConfidence<T> source) {
Preconditions.checkNotNull(source);
mEntityConfidence.putAll(source.mEntityConfidence);
+ mSortedEntities.addAll(source.mSortedEntities);
}
/**
- * Sets an entity type for the classified text and assigns a confidence score.
+ * Constructs an EntityConfidence from a map of entity to confidence.
*
- * @param confidenceScore a value from 0 (low confidence) to 1 (high confidence).
- * 0 implies the entity does not exist for the classified text.
- * Values greater than 1 are clamped to 1.
+ * Map entries that have 0 confidence are removed, and values greater than 1 are clamped to 1.
+ *
+ * @param source a map from entity to a confidence value in the range 0 (low confidence) to
+ * 1 (high confidence).
*/
- public void setEntityType(
- @NonNull T type, @FloatRange(from = 0.0, to = 1.0) float confidenceScore) {
- Preconditions.checkNotNull(type);
- if (confidenceScore > 0) {
- mEntityConfidence.put(type, Math.min(1, confidenceScore));
- } else {
- mEntityConfidence.remove(type);
+ EntityConfidence(@NonNull Map<T, Float> source) {
+ Preconditions.checkNotNull(source);
+
+ // Prune non-existent entities and clamp to 1.
+ mEntityConfidence.ensureCapacity(source.size());
+ for (Map.Entry<T, Float> it : source.entrySet()) {
+ if (it.getValue() <= 0) continue;
+ mEntityConfidence.put(it.getKey(), Math.min(1, it.getValue()));
}
+
+ // Create a list of entities sorted by decreasing confidence for getEntities().
+ mSortedEntities.ensureCapacity(mEntityConfidence.size());
+ mSortedEntities.addAll(mEntityConfidence.keySet());
+ mSortedEntities.sort((e1, e2) -> {
+ float score1 = mEntityConfidence.get(e1);
+ float score2 = mEntityConfidence.get(e2);
+ return Float.compare(score2, score1);
+ });
}
/**
@@ -80,10 +80,7 @@ final class EntityConfidence<T> {
*/
@NonNull
public List<T> getEntities() {
- List<T> entities = new ArrayList<>(mEntityConfidence.size());
- entities.addAll(mEntityConfidence.keySet());
- entities.sort(mEntityComparator);
- return Collections.unmodifiableList(entities);
+ return Collections.unmodifiableList(mSortedEntities);
}
/**
diff --git a/core/java/android/view/textclassifier/TextClassification.java b/core/java/android/view/textclassifier/TextClassification.java
index f675c355638c..89163238ea4d 100644
--- a/core/java/android/view/textclassifier/TextClassification.java
+++ b/core/java/android/view/textclassifier/TextClassification.java
@@ -24,6 +24,7 @@ import android.content.Context;
import android.content.Intent;
import android.graphics.drawable.Drawable;
import android.os.LocaleList;
+import android.util.ArrayMap;
import android.view.View.OnClickListener;
import android.view.textclassifier.TextClassifier.EntityType;
@@ -32,6 +33,7 @@ import com.android.internal.util.Preconditions;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
+import java.util.Map;
/**
* Information for generating a widget to handle classified text.
@@ -95,7 +97,6 @@ public final class TextClassification {
@NonNull private final List<Intent> mIntents;
@NonNull private final List<OnClickListener> mOnClickListeners;
@NonNull private final EntityConfidence<String> mEntityConfidence;
- @NonNull private final List<String> mEntities;
private int mLogType;
@NonNull private final String mVersionInfo;
@@ -105,7 +106,7 @@ public final class TextClassification {
@NonNull List<String> labels,
@NonNull List<Intent> intents,
@NonNull List<OnClickListener> onClickListeners,
- @NonNull EntityConfidence<String> entityConfidence,
+ @NonNull Map<String, Float> entityConfidence,
int logType,
@NonNull String versionInfo) {
Preconditions.checkArgument(labels.size() == intents.size());
@@ -117,7 +118,6 @@ public final class TextClassification {
mIntents = intents;
mOnClickListeners = onClickListeners;
mEntityConfidence = new EntityConfidence<>(entityConfidence);
- mEntities = mEntityConfidence.getEntities();
mLogType = logType;
mVersionInfo = versionInfo;
}
@@ -135,7 +135,7 @@ public final class TextClassification {
*/
@IntRange(from = 0)
public int getEntityCount() {
- return mEntities.size();
+ return mEntityConfidence.getEntities().size();
}
/**
@@ -147,7 +147,7 @@ public final class TextClassification {
*/
@NonNull
public @EntityType String getEntity(int index) {
- return mEntities.get(index);
+ return mEntityConfidence.getEntities().get(index);
}
/**
@@ -311,8 +311,7 @@ public final class TextClassification {
@NonNull private final List<String> mLabels = new ArrayList<>();
@NonNull private final List<Intent> mIntents = new ArrayList<>();
@NonNull private final List<OnClickListener> mOnClickListeners = new ArrayList<>();
- @NonNull private final EntityConfidence<String> mEntityConfidence =
- new EntityConfidence<>();
+ @NonNull private final Map<String, Float> mEntityConfidence = new ArrayMap<>();
private int mLogType;
@NonNull private String mVersionInfo = "";
@@ -334,7 +333,7 @@ public final class TextClassification {
public Builder setEntityType(
@NonNull @EntityType String type,
@FloatRange(from = 0.0, to = 1.0) float confidenceScore) {
- mEntityConfidence.setEntityType(type, confidenceScore);
+ mEntityConfidence.put(type, confidenceScore);
return this;
}
diff --git a/core/java/android/view/textclassifier/TextLinks.java b/core/java/android/view/textclassifier/TextLinks.java
index 76748d2b191a..0e039e35367e 100644
--- a/core/java/android/view/textclassifier/TextLinks.java
+++ b/core/java/android/view/textclassifier/TextLinks.java
@@ -103,11 +103,7 @@ public final class TextLinks {
mOriginalText = originalText;
mStart = start;
mEnd = end;
- mEntityScores = new EntityConfidence<>();
-
- for (Map.Entry<String, Float> entry : entityScores.entrySet()) {
- mEntityScores.setEntityType(entry.getKey(), entry.getValue());
- }
+ mEntityScores = new EntityConfidence<>(entityScores);
}
/**
diff --git a/core/java/android/view/textclassifier/TextSelection.java b/core/java/android/view/textclassifier/TextSelection.java
index 480b27a73fc1..ced4018bcd82 100644
--- a/core/java/android/view/textclassifier/TextSelection.java
+++ b/core/java/android/view/textclassifier/TextSelection.java
@@ -21,12 +21,13 @@ import android.annotation.IntRange;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.os.LocaleList;
+import android.util.ArrayMap;
import android.view.textclassifier.TextClassifier.EntityType;
import com.android.internal.util.Preconditions;
-import java.util.List;
import java.util.Locale;
+import java.util.Map;
/**
* Information about where text selection should be.
@@ -36,7 +37,6 @@ public final class TextSelection {
private final int mStartIndex;
private final int mEndIndex;
@NonNull private final EntityConfidence<String> mEntityConfidence;
- @NonNull private final List<String> mEntities;
@NonNull private final String mLogSource;
@NonNull private final String mVersionInfo;
@@ -46,7 +46,6 @@ public final class TextSelection {
mStartIndex = startIndex;
mEndIndex = endIndex;
mEntityConfidence = new EntityConfidence<>(entityConfidence);
- mEntities = mEntityConfidence.getEntities();
mLogSource = logSource;
mVersionInfo = versionInfo;
}
@@ -70,7 +69,7 @@ public final class TextSelection {
*/
@IntRange(from = 0)
public int getEntityCount() {
- return mEntities.size();
+ return mEntityConfidence.getEntities().size();
}
/**
@@ -82,7 +81,7 @@ public final class TextSelection {
*/
@NonNull
public @EntityType String getEntity(int index) {
- return mEntities.get(index);
+ return mEntityConfidence.getEntities().get(index);
}
/**
@@ -126,8 +125,7 @@ public final class TextSelection {
private final int mStartIndex;
private final int mEndIndex;
- @NonNull private final EntityConfidence<String> mEntityConfidence =
- new EntityConfidence<>();
+ @NonNull private final Map<String, Float> mEntityConfidence = new ArrayMap<>();
@NonNull private String mLogSource = "";
@NonNull private String mVersionInfo = "";
@@ -154,7 +152,7 @@ public final class TextSelection {
public Builder setEntityType(
@NonNull @EntityType String type,
@FloatRange(from = 0.0, to = 1.0) float confidenceScore) {
- mEntityConfidence.setEntityType(type, confidenceScore);
+ mEntityConfidence.put(type, confidenceScore);
return this;
}
@@ -181,7 +179,8 @@ public final class TextSelection {
*/
public TextSelection build() {
return new TextSelection(
- mStartIndex, mEndIndex, mEntityConfidence, mLogSource, mVersionInfo);
+ mStartIndex, mEndIndex, new EntityConfidence<>(mEntityConfidence), mLogSource,
+ mVersionInfo);
}
}