summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/java/com/android/internal/app/AbstractResolverComparator.java6
-rw-r--r--core/java/com/android/internal/app/AppPredictionServiceResolverComparator.java151
-rw-r--r--core/java/com/android/internal/app/ResolverComparatorModel.java57
-rw-r--r--core/java/com/android/internal/app/ResolverListAdapter.java8
-rw-r--r--core/java/com/android/internal/app/ResolverListController.java8
-rw-r--r--core/java/com/android/internal/app/ResolverRankerServiceResolverComparator.java241
-rw-r--r--core/tests/coretests/src/com/android/internal/app/AbstractResolverComparatorTest.java5
-rw-r--r--core/tests/coretests/src/com/android/internal/app/FakeResolverComparatorModel.java61
8 files changed, 383 insertions, 154 deletions
diff --git a/core/java/com/android/internal/app/AbstractResolverComparator.java b/core/java/com/android/internal/app/AbstractResolverComparator.java
index 42fc7bd6e6fc..975954035c17 100644
--- a/core/java/com/android/internal/app/AbstractResolverComparator.java
+++ b/core/java/com/android/internal/app/AbstractResolverComparator.java
@@ -228,12 +228,6 @@ public abstract class AbstractResolverComparator implements Comparator<ResolvedC
*/
abstract float getScore(ComponentName name);
- /**
- * Returns the list of top K component names which have highest
- * {@link #getScore(ComponentName)}
- */
- abstract List<ComponentName> getTopComponentNames(int topK);
-
/** Handles result message sent to mHandler. */
abstract void handleResultMessage(Message message);
diff --git a/core/java/com/android/internal/app/AppPredictionServiceResolverComparator.java b/core/java/com/android/internal/app/AppPredictionServiceResolverComparator.java
index bc9eff04636d..b19ac2fec640 100644
--- a/core/java/com/android/internal/app/AppPredictionServiceResolverComparator.java
+++ b/core/java/com/android/internal/app/AppPredictionServiceResolverComparator.java
@@ -18,6 +18,7 @@ package com.android.internal.app;
import static android.app.prediction.AppTargetEvent.ACTION_LAUNCH;
+import android.annotation.Nullable;
import android.app.prediction.AppPredictor;
import android.app.prediction.AppTarget;
import android.app.prediction.AppTargetEvent;
@@ -33,12 +34,11 @@ import android.util.Log;
import com.android.internal.app.ResolverActivity.ResolvedComponentInfo;
import java.util.ArrayList;
+import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.Map.Entry;
import java.util.concurrent.Executors;
-import java.util.stream.Collectors;
/**
* Uses an {@link AppPredictor} to sort Resolver targets. If the AppPredictionService appears to be
@@ -58,7 +58,9 @@ class AppPredictionServiceResolverComparator extends AbstractResolverComparator
private final String mReferrerPackage;
// If this is non-null (and this is not destroyed), it means APS is disabled and we should fall
// back to using the ResolverRankerService.
+ // TODO: responsibility for this fallback behavior can live outside of the AppPrediction client.
private ResolverRankerServiceResolverComparator mResolverRankerService;
+ private AppPredictionServiceComparatorModel mComparatorModel;
AppPredictionServiceResolverComparator(
Context context,
@@ -74,25 +76,12 @@ class AppPredictionServiceResolverComparator extends AbstractResolverComparator
mUser = user;
mReferrerPackage = referrerPackage;
setChooserActivityLogger(chooserActivityLogger);
+ mComparatorModel = buildUpdatedModel();
}
@Override
int compare(ResolveInfo lhs, ResolveInfo rhs) {
- if (mResolverRankerService != null) {
- return mResolverRankerService.compare(lhs, rhs);
- }
- Integer lhsRank = mTargetRanks.get(new ComponentName(lhs.activityInfo.packageName,
- lhs.activityInfo.name));
- Integer rhsRank = mTargetRanks.get(new ComponentName(rhs.activityInfo.packageName,
- rhs.activityInfo.name));
- if (lhsRank == null && rhsRank == null) {
- return 0;
- } else if (lhsRank == null) {
- return -1;
- } else if (rhsRank == null) {
- return 1;
- }
- return lhsRank - rhsRank;
+ return mComparatorModel.getComparator().compare(lhs, rhs);
}
@Override
@@ -121,6 +110,7 @@ class AppPredictionServiceResolverComparator extends AbstractResolverComparator
mContext, mIntent, mReferrerPackage,
() -> mHandler.sendEmptyMessage(RANKER_SERVICE_RESULT),
getChooserActivityLogger());
+ mComparatorModel = buildUpdatedModel();
mResolverRankerService.compute(targets);
} else {
Log.i(TAG, "AppPredictionService response received");
@@ -163,6 +153,7 @@ class AppPredictionServiceResolverComparator extends AbstractResolverComparator
mTargetRanks.put(componentName, i);
Log.i(TAG, "handleSortedAppTargets, sortedAppTargets #" + i + ": " + componentName);
}
+ mComparatorModel = buildUpdatedModel();
}
private boolean checkAppTargetRankValid(List<AppTarget> sortedAppTargets) {
@@ -176,43 +167,12 @@ class AppPredictionServiceResolverComparator extends AbstractResolverComparator
@Override
float getScore(ComponentName name) {
- if (mResolverRankerService != null) {
- return mResolverRankerService.getScore(name);
- }
- Integer rank = mTargetRanks.get(name);
- if (rank == null) {
- Log.w(TAG, "Score requested for unknown component. Did you call compute yet?");
- return 0f;
- }
- int consecutiveSumOfRanks = (mTargetRanks.size() - 1) * (mTargetRanks.size()) / 2;
- return 1.0f - (((float) rank) / consecutiveSumOfRanks);
- }
-
- @Override
- List<ComponentName> getTopComponentNames(int topK) {
- if (mResolverRankerService != null) {
- return mResolverRankerService.getTopComponentNames(topK);
- }
- return mTargetRanks.entrySet().stream()
- .sorted(Entry.comparingByValue())
- .limit(topK)
- .map(Entry::getKey)
- .collect(Collectors.toList());
+ return mComparatorModel.getScore(name);
}
@Override
void updateModel(ComponentName componentName) {
- if (mResolverRankerService != null) {
- mResolverRankerService.updateModel(componentName);
- return;
- }
- mAppPredictor.notifyAppTargetEvent(
- new AppTargetEvent.Builder(
- new AppTarget.Builder(
- new AppTargetId(componentName.toString()),
- componentName.getPackageName(), mUser)
- .setClassName(componentName.getClassName()).build(),
- ACTION_LAUNCH).build());
+ mComparatorModel.notifyOnTargetSelected(componentName);
}
@Override
@@ -220,6 +180,97 @@ class AppPredictionServiceResolverComparator extends AbstractResolverComparator
if (mResolverRankerService != null) {
mResolverRankerService.destroy();
mResolverRankerService = null;
+ mComparatorModel = buildUpdatedModel();
+ }
+ }
+
+ /**
+ * Re-construct an {@code AppPredictionServiceComparatorModel} to replace the current model
+ * instance (if any) using the up-to-date {@code AppPredictionServiceResolverComparator} ivar
+ * values.
+ *
+ * TODO: each time we replace the model instance, we're either updating the model to use
+ * adjusted data (which is appropriate), or we're providing a (late) value for one of our ivars
+ * that wasn't available the last time the model was updated. For those latter cases, we should
+ * just avoid creating the model altogether until we have all the prerequisites we'll need. Then
+ * we can probably simplify the logic in {@code AppPredictionServiceComparatorModel} since we
+ * won't need to handle edge cases when the model data isn't fully prepared.
+ * (In some cases, these kinds of "updates" might interleave -- e.g., we might have finished
+ * initializing the first time and now want to adjust some data, but still need to wait for
+ * changes to propagate to the other ivars before rebuilding the model.)
+ */
+ private AppPredictionServiceComparatorModel buildUpdatedModel() {
+ return new AppPredictionServiceComparatorModel(
+ mAppPredictor, mResolverRankerService, mUser, mTargetRanks);
+ }
+
+ // TODO: Finish separating behaviors of AbstractResolverComparator, then (probably) make this a
+ // standalone class once clients are written in terms of ResolverComparatorModel.
+ static class AppPredictionServiceComparatorModel implements ResolverComparatorModel {
+ private final AppPredictor mAppPredictor;
+ private final ResolverRankerServiceResolverComparator mResolverRankerService;
+ private final UserHandle mUser;
+ private final Map<ComponentName, Integer> mTargetRanks; // Treat as immutable.
+
+ AppPredictionServiceComparatorModel(
+ AppPredictor appPredictor,
+ @Nullable ResolverRankerServiceResolverComparator resolverRankerService,
+ UserHandle user,
+ Map<ComponentName, Integer> targetRanks) {
+ mAppPredictor = appPredictor;
+ mResolverRankerService = resolverRankerService;
+ mUser = user;
+ mTargetRanks = targetRanks;
+ }
+
+ @Override
+ public Comparator<ResolveInfo> getComparator() {
+ return (lhs, rhs) -> {
+ if (mResolverRankerService != null) {
+ return mResolverRankerService.compare(lhs, rhs);
+ }
+ Integer lhsRank = mTargetRanks.get(new ComponentName(lhs.activityInfo.packageName,
+ lhs.activityInfo.name));
+ Integer rhsRank = mTargetRanks.get(new ComponentName(rhs.activityInfo.packageName,
+ rhs.activityInfo.name));
+ if (lhsRank == null && rhsRank == null) {
+ return 0;
+ } else if (lhsRank == null) {
+ return -1;
+ } else if (rhsRank == null) {
+ return 1;
+ }
+ return lhsRank - rhsRank;
+ };
+ }
+
+ @Override
+ public float getScore(ComponentName name) {
+ if (mResolverRankerService != null) {
+ return mResolverRankerService.getScore(name);
+ }
+ Integer rank = mTargetRanks.get(name);
+ if (rank == null) {
+ Log.w(TAG, "Score requested for unknown component. Did you call compute yet?");
+ return 0f;
+ }
+ int consecutiveSumOfRanks = (mTargetRanks.size() - 1) * (mTargetRanks.size()) / 2;
+ return 1.0f - (((float) rank) / consecutiveSumOfRanks);
+ }
+
+ @Override
+ public void notifyOnTargetSelected(ComponentName componentName) {
+ if (mResolverRankerService != null) {
+ mResolverRankerService.updateModel(componentName);
+ return;
+ }
+ mAppPredictor.notifyAppTargetEvent(
+ new AppTargetEvent.Builder(
+ new AppTarget.Builder(
+ new AppTargetId(componentName.toString()),
+ componentName.getPackageName(), mUser)
+ .setClassName(componentName.getClassName()).build(),
+ ACTION_LAUNCH).build());
}
}
}
diff --git a/core/java/com/android/internal/app/ResolverComparatorModel.java b/core/java/com/android/internal/app/ResolverComparatorModel.java
new file mode 100644
index 000000000000..3e8f64bf4ed3
--- /dev/null
+++ b/core/java/com/android/internal/app/ResolverComparatorModel.java
@@ -0,0 +1,57 @@
+/*
+ * Copyright 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.internal.app;
+
+import android.content.ComponentName;
+import android.content.pm.ResolveInfo;
+
+import java.util.Comparator;
+import java.util.List;
+
+/**
+ * A ranking model for resolver targets, providing ordering and (optionally) numerical scoring.
+ *
+ * As required by the {@link Comparator} contract, objects returned by {@code getComparator()} must
+ * apply a total ordering on its inputs consistent across all calls to {@code Comparator#compare()}.
+ * Other query methods and ranking feedback should refer to that same ordering, so implementors are
+ * generally advised to "lock in" an immutable snapshot of their model data when this object is
+ * initialized (preferring to replace the entire {@code ResolverComparatorModel} instance if the
+ * backing data needs to be updated in the future).
+ */
+interface ResolverComparatorModel {
+ /**
+ * Get a {@code Comparator} that can be used to sort {@code ResolveInfo} targets according to
+ * the model ranking.
+ */
+ Comparator<ResolveInfo> getComparator();
+
+ /**
+ * Get the numerical score, if any, that the model assigns to the component with the specified
+ * {@code name}. Scores range from zero to one, with one representing the highest possible
+ * likelihood that the user will select that component as the target. Implementations that don't
+ * assign numerical scores are <em>recommended</em> to return a value of 0 for all components.
+ */
+ float getScore(ComponentName name);
+
+ /**
+ * Notify the model that the user selected a target. (Models may log this information, use it as
+ * a feedback signal for their ranking, etc.) Because the data in this
+ * {@code ResolverComparatorModel} instance is immutable, clients will need to get an up-to-date
+ * instance in order to see any changes in the ranking that might result from this feedback.
+ */
+ void notifyOnTargetSelected(ComponentName componentName);
+}
diff --git a/core/java/com/android/internal/app/ResolverListAdapter.java b/core/java/com/android/internal/app/ResolverListAdapter.java
index fe66cad89fbb..351ac4587def 100644
--- a/core/java/com/android/internal/app/ResolverListAdapter.java
+++ b/core/java/com/android/internal/app/ResolverListAdapter.java
@@ -155,14 +155,6 @@ public class ResolverListAdapter extends BaseAdapter {
return mResolverListController.getScore(componentName);
}
- /**
- * Returns the list of top K component names which have highest
- * {@link #getScore(DisplayResolveInfo)}
- */
- public List<ComponentName> getTopComponentNames(int topK) {
- return mResolverListController.getTopComponentNames(topK);
- }
-
public void updateModel(ComponentName componentName) {
mResolverListController.updateModel(componentName);
}
diff --git a/core/java/com/android/internal/app/ResolverListController.java b/core/java/com/android/internal/app/ResolverListController.java
index 9a95e6411fa4..27573631b2ce 100644
--- a/core/java/com/android/internal/app/ResolverListController.java
+++ b/core/java/com/android/internal/app/ResolverListController.java
@@ -393,14 +393,6 @@ public class ResolverListController {
return mResolverComparator.getScore(componentName);
}
- /**
- * Returns the list of top K component names which have highest
- * {@link #getScore(DisplayResolveInfo)}
- */
- public List<ComponentName> getTopComponentNames(int topK) {
- return mResolverComparator.getTopComponentNames(topK);
- }
-
public void updateModel(ComponentName componentName) {
mResolverComparator.updateModel(componentName);
}
diff --git a/core/java/com/android/internal/app/ResolverRankerServiceResolverComparator.java b/core/java/com/android/internal/app/ResolverRankerServiceResolverComparator.java
index cb946c0dcf99..c5b21ac4da90 100644
--- a/core/java/com/android/internal/app/ResolverRankerServiceResolverComparator.java
+++ b/core/java/com/android/internal/app/ResolverRankerServiceResolverComparator.java
@@ -43,12 +43,12 @@ import com.android.internal.logging.nano.MetricsProto.MetricsEvent;
import java.text.Collator;
import java.util.ArrayList;
+import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
-import java.util.stream.Collectors;
/**
* Ranks and compares packages based on usage stats and uses the {@link ResolverRankerService}.
@@ -83,6 +83,7 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator
private ResolverRankerServiceConnection mConnection;
private Context mContext;
private CountDownLatch mConnectSignal;
+ private ResolverRankerServiceComparatorModel mComparatorModel;
public ResolverRankerServiceResolverComparator(Context context, Intent intent,
String referrerPackage, AfterCompute afterCompute,
@@ -99,6 +100,8 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator
mRankerServiceName = new ComponentName(mContext, this.getClass());
setCallBack(afterCompute);
setChooserActivityLogger(chooserActivityLogger);
+
+ mComparatorModel = buildUpdatedModel();
}
@Override
@@ -125,6 +128,7 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator
}
if (isUpdated) {
mRankerServiceName = mResolvedRankerName;
+ mComparatorModel = buildUpdatedModel();
}
} else {
Log.e(TAG, "Sizes of sent and received ResolverTargets diff.");
@@ -218,83 +222,25 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator
}
}
predictSelectProbabilities(mTargets);
+
+ mComparatorModel = buildUpdatedModel();
}
@Override
public int compare(ResolveInfo lhs, ResolveInfo rhs) {
- if (mStats != null) {
- final ResolverTarget lhsTarget = mTargetsDict.get(new ComponentName(
- lhs.activityInfo.packageName, lhs.activityInfo.name));
- final ResolverTarget rhsTarget = mTargetsDict.get(new ComponentName(
- rhs.activityInfo.packageName, rhs.activityInfo.name));
-
- if (lhsTarget != null && rhsTarget != null) {
- final int selectProbabilityDiff = Float.compare(
- rhsTarget.getSelectProbability(), lhsTarget.getSelectProbability());
-
- if (selectProbabilityDiff != 0) {
- return selectProbabilityDiff > 0 ? 1 : -1;
- }
- }
- }
-
- CharSequence sa = lhs.loadLabel(mPm);
- if (sa == null) sa = lhs.activityInfo.name;
- CharSequence sb = rhs.loadLabel(mPm);
- if (sb == null) sb = rhs.activityInfo.name;
-
- return mCollator.compare(sa.toString().trim(), sb.toString().trim());
+ return mComparatorModel.getComparator().compare(lhs, rhs);
}
@Override
public float getScore(ComponentName name) {
- final ResolverTarget target = mTargetsDict.get(name);
- if (target != null) {
- return target.getSelectProbability();
- }
- return 0;
- }
-
- @Override
- List<ComponentName> getTopComponentNames(int topK) {
- return mTargetsDict.entrySet().stream()
- .sorted((o1, o2) -> -Float.compare(getScore(o1.getKey()), getScore(o2.getKey())))
- .limit(topK)
- .map(Map.Entry::getKey)
- .collect(Collectors.toList());
+ return mComparatorModel.getScore(name);
}
// update ranking model when the connection to it is valid.
@Override
public void updateModel(ComponentName componentName) {
synchronized (mLock) {
- if (mRanker != null) {
- try {
- int selectedPos = new ArrayList<ComponentName>(mTargetsDict.keySet())
- .indexOf(componentName);
- if (selectedPos >= 0 && mTargets != null) {
- final float selectedProbability = getScore(componentName);
- int order = 0;
- for (ResolverTarget target : mTargets) {
- if (target.getSelectProbability() > selectedProbability) {
- order++;
- }
- }
- logMetrics(order);
- mRanker.train(mTargets, selectedPos);
- } else {
- if (DEBUG) {
- Log.d(TAG, "Selected a unknown component: " + componentName);
- }
- }
- } catch (RemoteException e) {
- Log.e(TAG, "Error in Train: " + e);
- }
- } else {
- if (DEBUG) {
- Log.d(TAG, "Ranker is null; skip updateModel.");
- }
- }
+ mComparatorModel.notifyOnTargetSelected(componentName);
}
}
@@ -313,19 +259,6 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator
}
}
- // records metrics for evaluation.
- private void logMetrics(int selectedPos) {
- if (mRankerServiceName != null) {
- MetricsLogger metricsLogger = new MetricsLogger();
- LogMaker log = new LogMaker(MetricsEvent.ACTION_TARGET_SELECTED);
- log.setComponentName(mRankerServiceName);
- int isCategoryUsed = (mAnnotations == null) ? 0 : 1;
- log.addTaggedData(MetricsEvent.FIELD_IS_CATEGORY_USED, isCategoryUsed);
- log.addTaggedData(MetricsEvent.FIELD_RANKED_POSITION, selectedPos);
- metricsLogger.write(log);
- }
- }
-
// connect to a ranking service.
private void initRanker(Context context) {
synchronized (mLock) {
@@ -426,6 +359,7 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator
}
synchronized (mLock) {
mRanker = IResolverRankerService.Stub.asInterface(service);
+ mComparatorModel = buildUpdatedModel();
mConnectSignal.countDown();
}
}
@@ -443,6 +377,7 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator
public void destroy() {
synchronized (mLock) {
mRanker = null;
+ mComparatorModel = buildUpdatedModel();
}
}
}
@@ -453,6 +388,7 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator
mTargetsDict.clear();
mTargets = null;
mRankerServiceName = new ComponentName(mContext, this.getClass());
+ mComparatorModel = buildUpdatedModel();
mResolvedRankerName = null;
initRanker(mContext);
}
@@ -508,4 +444,155 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator
}
return false;
}
+
+ /**
+ * Re-construct a {@code ResolverRankerServiceComparatorModel} to replace the current model
+ * instance (if any) using the up-to-date {@code ResolverRankerServiceResolverComparator} ivar
+ * values.
+ *
+ * TODO: each time we replace the model instance, we're either updating the model to use
+ * adjusted data (which is appropriate), or we're providing a (late) value for one of our ivars
+ * that wasn't available the last time the model was updated. For those latter cases, we should
+ * just avoid creating the model altogether until we have all the prerequisites we'll need. Then
+ * we can probably simplify the logic in {@code ResolverRankerServiceComparatorModel} since we
+ * won't need to handle edge cases when the model data isn't fully prepared.
+ * (In some cases, these kinds of "updates" might interleave -- e.g., we might have finished
+ * initializing the first time and now want to adjust some data, but still need to wait for
+ * changes to propagate to the other ivars before rebuilding the model.)
+ */
+ private ResolverRankerServiceComparatorModel buildUpdatedModel() {
+ // TODO: we don't currently guarantee that the underlying target list/map won't be mutated,
+ // so the ResolverComparatorModel may provide inconsistent results. We should make immutable
+ // copies of the data (waiting for any necessary remaining data before creating the model).
+ return new ResolverRankerServiceComparatorModel(
+ mStats,
+ mTargetsDict,
+ mTargets,
+ mCollator,
+ mRanker,
+ mRankerServiceName,
+ (mAnnotations != null),
+ mPm);
+ }
+
+ /**
+ * Implementation of a {@code ResolverComparatorModel} that provides the same ranking logic as
+ * the legacy {@code ResolverRankerServiceResolverComparator}, as a refactoring step toward
+ * removing the complex legacy API.
+ */
+ static class ResolverRankerServiceComparatorModel implements ResolverComparatorModel {
+ private final Map<String, UsageStats> mStats; // Treat as immutable.
+ private final Map<ComponentName, ResolverTarget> mTargetsDict; // Treat as immutable.
+ private final List<ResolverTarget> mTargets; // Treat as immutable.
+ private final Collator mCollator;
+ private final IResolverRankerService mRanker;
+ private final ComponentName mRankerServiceName;
+ private final boolean mAnnotationsUsed;
+ private final PackageManager mPm;
+
+ // TODO: it doesn't look like we should have to pass both targets and targetsDict, but it's
+ // not written in a way that makes it clear whether we can derive one from the other (at
+ // least in this constructor).
+ ResolverRankerServiceComparatorModel(
+ Map<String, UsageStats> stats,
+ Map<ComponentName, ResolverTarget> targetsDict,
+ List<ResolverTarget> targets,
+ Collator collator,
+ IResolverRankerService ranker,
+ ComponentName rankerServiceName,
+ boolean annotationsUsed,
+ PackageManager pm) {
+ mStats = stats;
+ mTargetsDict = targetsDict;
+ mTargets = targets;
+ mCollator = collator;
+ mRanker = ranker;
+ mRankerServiceName = rankerServiceName;
+ mAnnotationsUsed = annotationsUsed;
+ mPm = pm;
+ }
+
+ @Override
+ public Comparator<ResolveInfo> getComparator() {
+ // TODO: doCompute() doesn't seem to be concerned about null-checking mStats. Is that
+ // a bug there, or do we have a way of knowing it will be non-null under certain
+ // conditions?
+ return (lhs, rhs) -> {
+ if (mStats != null) {
+ final ResolverTarget lhsTarget = mTargetsDict.get(new ComponentName(
+ lhs.activityInfo.packageName, lhs.activityInfo.name));
+ final ResolverTarget rhsTarget = mTargetsDict.get(new ComponentName(
+ rhs.activityInfo.packageName, rhs.activityInfo.name));
+
+ if (lhsTarget != null && rhsTarget != null) {
+ final int selectProbabilityDiff = Float.compare(
+ rhsTarget.getSelectProbability(), lhsTarget.getSelectProbability());
+
+ if (selectProbabilityDiff != 0) {
+ return selectProbabilityDiff > 0 ? 1 : -1;
+ }
+ }
+ }
+
+ CharSequence sa = lhs.loadLabel(mPm);
+ if (sa == null) sa = lhs.activityInfo.name;
+ CharSequence sb = rhs.loadLabel(mPm);
+ if (sb == null) sb = rhs.activityInfo.name;
+
+ return mCollator.compare(sa.toString().trim(), sb.toString().trim());
+ };
+ }
+
+ @Override
+ public float getScore(ComponentName name) {
+ final ResolverTarget target = mTargetsDict.get(name);
+ if (target != null) {
+ return target.getSelectProbability();
+ }
+ return 0;
+ }
+
+ @Override
+ public void notifyOnTargetSelected(ComponentName componentName) {
+ if (mRanker != null) {
+ try {
+ int selectedPos = new ArrayList<ComponentName>(mTargetsDict.keySet())
+ .indexOf(componentName);
+ if (selectedPos >= 0 && mTargets != null) {
+ final float selectedProbability = getScore(componentName);
+ int order = 0;
+ for (ResolverTarget target : mTargets) {
+ if (target.getSelectProbability() > selectedProbability) {
+ order++;
+ }
+ }
+ logMetrics(order);
+ mRanker.train(mTargets, selectedPos);
+ } else {
+ if (DEBUG) {
+ Log.d(TAG, "Selected a unknown component: " + componentName);
+ }
+ }
+ } catch (RemoteException e) {
+ Log.e(TAG, "Error in Train: " + e);
+ }
+ } else {
+ if (DEBUG) {
+ Log.d(TAG, "Ranker is null; skip updateModel.");
+ }
+ }
+ }
+
+ /** Records metrics for evaluation. */
+ private void logMetrics(int selectedPos) {
+ if (mRankerServiceName != null) {
+ MetricsLogger metricsLogger = new MetricsLogger();
+ LogMaker log = new LogMaker(MetricsEvent.ACTION_TARGET_SELECTED);
+ log.setComponentName(mRankerServiceName);
+ log.addTaggedData(MetricsEvent.FIELD_IS_CATEGORY_USED, mAnnotationsUsed);
+ log.addTaggedData(MetricsEvent.FIELD_RANKED_POSITION, selectedPos);
+ metricsLogger.write(log);
+ }
+ }
+ }
}
diff --git a/core/tests/coretests/src/com/android/internal/app/AbstractResolverComparatorTest.java b/core/tests/coretests/src/com/android/internal/app/AbstractResolverComparatorTest.java
index 04b888623732..3e640c1bad39 100644
--- a/core/tests/coretests/src/com/android/internal/app/AbstractResolverComparatorTest.java
+++ b/core/tests/coretests/src/com/android/internal/app/AbstractResolverComparatorTest.java
@@ -115,11 +115,6 @@ public class AbstractResolverComparatorTest {
@Override
void handleResultMessage(Message message) {}
-
- @Override
- List<ComponentName> getTopComponentNames(int topK) {
- return null;
- }
};
return testComparator;
}
diff --git a/core/tests/coretests/src/com/android/internal/app/FakeResolverComparatorModel.java b/core/tests/coretests/src/com/android/internal/app/FakeResolverComparatorModel.java
new file mode 100644
index 000000000000..fbbe57c8e325
--- /dev/null
+++ b/core/tests/coretests/src/com/android/internal/app/FakeResolverComparatorModel.java
@@ -0,0 +1,61 @@
+/*
+ * Copyright 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.internal.app;
+
+import android.content.ComponentName;
+import android.content.pm.ResolveInfo;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+
+/**
+ * Basic {@link ResolverComparatorModel} implementation that sorts according to a pre-defined (or
+ * default) {@link java.util.Comparator}.
+ */
+public class FakeResolverComparatorModel implements ResolverComparatorModel {
+ private final Comparator<ResolveInfo> mComparator;
+
+ public static FakeResolverComparatorModel makeModelFromComparator(
+ Comparator<ResolveInfo> comparator) {
+ return new FakeResolverComparatorModel(comparator);
+ }
+
+ public static FakeResolverComparatorModel makeDefaultModel() {
+ return makeModelFromComparator(Comparator.comparing(ri -> ri.activityInfo.name));
+ }
+
+ @Override
+ public Comparator<ResolveInfo> getComparator() {
+ return mComparator;
+ }
+
+ @Override
+ public float getScore(ComponentName name) {
+ return 0.0f; // Models are not required to provide numerical scores.
+ }
+
+ @Override
+ public void notifyOnTargetSelected(ComponentName componentName) {
+ System.out.println(
+ "User selected " + componentName + " under model " + System.identityHashCode(this));
+ }
+
+ private FakeResolverComparatorModel(Comparator<ResolveInfo> comparator) {
+ mComparator = comparator;
+ }
+} \ No newline at end of file