diff options
| author | 2019-02-12 20:18:39 +0000 | |
|---|---|---|
| committer | 2019-02-12 20:18:39 +0000 | |
| commit | 33adf0df1d962303d0a5c3236acd9ac7ed72dae1 (patch) | |
| tree | db25ca6467345bd99e9a089b09a372687bcf1289 | |
| parent | ae78bea482dc37dd6cf15e6d6401cd59e085165e (diff) | |
| parent | 40d5ce0b54e54f524f978cb6abfa0eb404247971 (diff) | |
Merge "A few missing pieces for logging in ExtServices"
2 files changed, 201 insertions, 92 deletions
diff --git a/packages/ExtServices/src/android/ext/services/notification/SmartActionsHelper.java b/packages/ExtServices/src/android/ext/services/notification/SmartActionsHelper.java index f372fe55dfb0..24fa87a7f39b 100644 --- a/packages/ExtServices/src/android/ext/services/notification/SmartActionsHelper.java +++ b/packages/ExtServices/src/android/ext/services/notification/SmartActionsHelper.java @@ -15,11 +15,10 @@ */ package android.ext.services.notification; -import android.annotation.NonNull; -import android.annotation.Nullable; import android.app.Notification; import android.app.Person; import android.app.RemoteAction; +import android.app.RemoteInput; import android.content.Context; import android.graphics.drawable.Icon; import android.os.Bundle; @@ -27,7 +26,9 @@ import android.os.Parcelable; import android.os.Process; import android.service.notification.NotificationAssistantService; import android.text.TextUtils; +import android.util.ArrayMap; import android.util.LruCache; +import android.util.Pair; import android.view.textclassifier.ConversationAction; import android.view.textclassifier.ConversationActions; import android.view.textclassifier.TextClassificationContext; @@ -35,6 +36,8 @@ import android.view.textclassifier.TextClassificationManager; import android.view.textclassifier.TextClassifier; import android.view.textclassifier.TextClassifierEvent; +import com.android.internal.util.ArrayUtils; + import java.time.Instant; import java.time.ZoneOffset; import java.time.ZonedDateTime; @@ -43,11 +46,13 @@ import java.util.ArrayList; import java.util.Collections; import java.util.Deque; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; public class SmartActionsHelper { private static final String KEY_ACTION_TYPE = "action_type"; + private static final String KEY_ACTION_SCORE = "action_score"; // If a notification has any of these flags set, it's inelgibile for actions being added. private static final int FLAG_MASK_INELGIBILE_FOR_ACTIONS = Notification.FLAG_ONGOING_EVENT @@ -58,75 +63,136 @@ public class SmartActionsHelper { private static final List<String> HINTS = Collections.singletonList(ConversationActions.Request.HINT_FOR_NOTIFICATION); + private static final ConversationActions EMPTY_CONVERSATION_ACTIONS = + new ConversationActions(Collections.emptyList(), null); private Context mContext; - @Nullable private TextClassifier mTextClassifier; - @NonNull private AssistantSettings mSettings; - private LruCache<String, String> mNotificationKeyToResultIdCache = - new LruCache<>(MAX_RESULT_ID_TO_CACHE); + private LruCache<String, Session> mSessionCache = new LruCache<>(MAX_RESULT_ID_TO_CACHE); SmartActionsHelper(Context context, AssistantSettings settings) { mContext = context; TextClassificationManager textClassificationManager = mContext.getSystemService(TextClassificationManager.class); - if (textClassificationManager != null) { - mTextClassifier = textClassificationManager.getTextClassifier(); - } + mTextClassifier = textClassificationManager.getTextClassifier(); mSettings = settings; } - SmartSuggestions suggest(@NonNull NotificationEntry entry) { + SmartSuggestions suggest(NotificationEntry entry) { // Whenever suggest() is called on a notification, its previous session is ended. - mNotificationKeyToResultIdCache.remove(entry.getSbn().getKey()); + mSessionCache.remove(entry.getSbn().getKey()); boolean eligibleForReplyAdjustment = mSettings.mGenerateReplies && isEligibleForReplyAdjustment(entry); boolean eligibleForActionAdjustment = mSettings.mGenerateActions && isEligibleForActionAdjustment(entry); - List<ConversationAction> conversationActions = + ConversationActions conversationActionsResult = suggestConversationActions( entry, eligibleForReplyAdjustment, eligibleForActionAdjustment); - ArrayList<CharSequence> replies = conversationActions.stream() - .map(ConversationAction::getTextReply) - .filter(textReply -> !TextUtils.isEmpty(textReply)) - .collect(Collectors.toCollection(ArrayList::new)); + String resultId = conversationActionsResult.getId(); + List<ConversationAction> conversationActions = + conversationActionsResult.getConversationActions(); + + ArrayList<CharSequence> replies = new ArrayList<>(); + Map<CharSequence, Float> repliesScore = new ArrayMap<>(); + for (ConversationAction conversationAction : conversationActions) { + CharSequence textReply = conversationAction.getTextReply(); + if (TextUtils.isEmpty(textReply)) { + continue; + } + replies.add(textReply); + repliesScore.put(textReply, conversationAction.getConfidenceScore()); + } ArrayList<Notification.Action> actions = conversationActions.stream() .filter(conversationAction -> conversationAction.getAction() != null) - .map(action -> createNotificationAction(action.getAction(), action.getType())) + .map(action -> createNotificationAction( + action.getAction(), action.getType(), action.getConfidenceScore())) .collect(Collectors.toCollection(ArrayList::new)); + + // Start a new session for logging if necessary. + if (!TextUtils.isEmpty(resultId) + && !conversationActions.isEmpty() + && suggestionsMightBeUsedInNotification( + entry, !actions.isEmpty(), !replies.isEmpty())) { + mSessionCache.put(entry.getSbn().getKey(), new Session(resultId, repliesScore)); + } + return new SmartSuggestions(replies, actions); } /** + * Returns whether the suggestion might be used in the notifications in SysUI. + * <p> + * Currently, NAS has no idea if suggestions will actually be used in the notification, and thus + * this function tries to make a heuristic. This function tries to optimize the precision, + * that means when it is unsure, it will return false. The objective is to avoid false positive, + * which could pollute the log and CTR as we are logging click rate of suggestions that could + * be never visible to users. On the other hand, it is fine to have false negative because + * it would be just like sampling. + */ + private boolean suggestionsMightBeUsedInNotification( + NotificationEntry notificationEntry, boolean hasSmartAction, boolean hasSmartReply) { + Notification notification = notificationEntry.getNotification(); + boolean hasAppGeneratedContextualActions = !notification.getContextualActions().isEmpty(); + + Pair<RemoteInput, Notification.Action> freeformRemoteInputAndAction = + notification.findRemoteInputActionPair(/* requiresFreeform */ true); + boolean hasAppGeneratedReplies = false; + boolean allowGeneratedReplies = false; + if (freeformRemoteInputAndAction != null) { + RemoteInput freeformRemoteInput = freeformRemoteInputAndAction.first; + Notification.Action actionWithFreeformRemoteInput = freeformRemoteInputAndAction.second; + hasAppGeneratedReplies = !ArrayUtils.isEmpty(freeformRemoteInput.getChoices()); + allowGeneratedReplies = actionWithFreeformRemoteInput.getAllowGeneratedReplies(); + } + + if (hasAppGeneratedReplies || hasAppGeneratedContextualActions) { + return false; + } + return hasSmartAction && notification.getAllowSystemGeneratedContextualActions() + || hasSmartReply && allowGeneratedReplies; + } + + private void reportActionsGenerated( + String resultId, List<ConversationAction> conversationActions) { + if (TextUtils.isEmpty(resultId)) { + return; + } + TextClassifierEvent textClassifierEvent = + createTextClassifierEventBuilder( + TextClassifierEvent.TYPE_ACTIONS_GENERATED, resultId) + .setEntityTypes(conversationActions.stream() + .map(ConversationAction::getType) + .toArray(String[]::new)) + .build(); + mTextClassifier.onTextClassifierEvent(textClassifierEvent); + } + + /** * Adds action adjustments based on the notification contents. */ - @NonNull - private List<ConversationAction> suggestConversationActions( - @NonNull NotificationEntry entry, + private ConversationActions suggestConversationActions( + NotificationEntry entry, boolean includeReplies, boolean includeActions) { if (!includeReplies && !includeActions) { - return Collections.emptyList(); - } - if (mTextClassifier == null) { - return Collections.emptyList(); + return EMPTY_CONVERSATION_ACTIONS; } List<ConversationActions.Message> messages = extractMessages(entry.getNotification()); if (messages.isEmpty()) { - return Collections.emptyList(); + return EMPTY_CONVERSATION_ACTIONS; } // Do not generate smart actions if the last message is from the local user. ConversationActions.Message lastMessage = messages.get(messages.size() - 1); if (arePersonsEqual( ConversationActions.Message.PERSON_USER_SELF, lastMessage.getAuthor())) { - return Collections.emptyList(); + return EMPTY_CONVERSATION_ACTIONS; } TextClassifier.EntityConfig.Builder typeConfigBuilder = @@ -146,25 +212,20 @@ public class SmartActionsHelper { .setHints(HINTS) .setTypeConfig(typeConfigBuilder.build()) .build(); - - ConversationActions conversationActionsResult = + ConversationActions conversationActions = mTextClassifier.suggestConversationActions(request); - - String resultId = conversationActionsResult.getId(); - if (!TextUtils.isEmpty(resultId) - && !conversationActionsResult.getConversationActions().isEmpty()) { - mNotificationKeyToResultIdCache.put(entry.getSbn().getKey(), resultId); - } - return conversationActionsResult.getConversationActions(); + reportActionsGenerated( + conversationActions.getId(), conversationActions.getConversationActions()); + return conversationActions; } - void onNotificationExpansionChanged(@NonNull NotificationEntry entry, boolean isUserAction, + void onNotificationExpansionChanged(NotificationEntry entry, boolean isUserAction, boolean isExpanded) { if (!isExpanded) { return; } - String resultId = mNotificationKeyToResultIdCache.get(entry.getSbn().getKey()); - if (resultId == null) { + Session session = mSessionCache.get(entry.getSbn().getKey()); + if (session == null) { return; } // Only report if this is the first time the user sees these suggestions. @@ -173,56 +234,50 @@ public class SmartActionsHelper { } entry.setShowActionEventLogged(); TextClassifierEvent textClassifierEvent = - createTextClassifierEventBuilder(TextClassifierEvent.TYPE_ACTIONS_SHOWN, - resultId) + createTextClassifierEventBuilder( + TextClassifierEvent.TYPE_ACTIONS_SHOWN, session.resultId) .build(); // TODO: If possible, report which replies / actions are actually seen by user. mTextClassifier.onTextClassifierEvent(textClassifierEvent); } - void onNotificationDirectReplied(@NonNull String key) { - if (mTextClassifier == null) { - return; - } - String resultId = mNotificationKeyToResultIdCache.get(key); - if (resultId == null) { + void onNotificationDirectReplied(String key) { + Session session = mSessionCache.get(key); + if (session == null) { return; } TextClassifierEvent textClassifierEvent = - createTextClassifierEventBuilder(TextClassifierEvent.TYPE_MANUAL_REPLY, resultId) + createTextClassifierEventBuilder( + TextClassifierEvent.TYPE_MANUAL_REPLY, session.resultId) .build(); mTextClassifier.onTextClassifierEvent(textClassifierEvent); } - void onSuggestedReplySent(@NonNull String key, @NonNull CharSequence reply, + void onSuggestedReplySent(String key, CharSequence reply, @NotificationAssistantService.Source int source) { - if (mTextClassifier == null) { - return; - } if (source != NotificationAssistantService.SOURCE_FROM_ASSISTANT) { return; } - String resultId = mNotificationKeyToResultIdCache.get(key); - if (resultId == null) { + Session session = mSessionCache.get(key); + if (session == null) { return; } TextClassifierEvent textClassifierEvent = - createTextClassifierEventBuilder(TextClassifierEvent.TYPE_SMART_ACTION, resultId) + createTextClassifierEventBuilder( + TextClassifierEvent.TYPE_SMART_ACTION, session.resultId) .setEntityTypes(ConversationAction.TYPE_TEXT_REPLY) + .setScore(session.repliesScores.getOrDefault(reply, 0f)) .build(); mTextClassifier.onTextClassifierEvent(textClassifierEvent); } - void onActionClicked(@NonNull String key, @NonNull Notification.Action action, + void onActionClicked(String key, Notification.Action action, @NotificationAssistantService.Source int source) { - if (mTextClassifier == null) { - return; - } if (source != NotificationAssistantService.SOURCE_FROM_ASSISTANT) { return; } - String resultId = mNotificationKeyToResultIdCache.get(key); - if (resultId == null) { + Session session = mSessionCache.get(key); + if (session == null) { return; } String actionType = action.getExtras().getString(KEY_ACTION_TYPE); @@ -230,28 +285,32 @@ public class SmartActionsHelper { return; } TextClassifierEvent textClassifierEvent = - createTextClassifierEventBuilder(TextClassifierEvent.TYPE_SMART_ACTION, resultId) + createTextClassifierEventBuilder( + TextClassifierEvent.TYPE_SMART_ACTION, session.resultId) .setEntityTypes(actionType) .build(); mTextClassifier.onTextClassifierEvent(textClassifierEvent); } private Notification.Action createNotificationAction( - RemoteAction remoteAction, String actionType) { + RemoteAction remoteAction, String actionType, float score) { Icon icon = remoteAction.shouldShowIcon() ? remoteAction.getIcon() : Icon.createWithResource(mContext, com.android.internal.R.drawable.ic_action_open); + Bundle extras = new Bundle(); + extras.putString(KEY_ACTION_TYPE, actionType); + extras.putFloat(KEY_ACTION_SCORE, score); return new Notification.Action.Builder( icon, remoteAction.getTitle(), remoteAction.getActionIntent()) .setContextual(true) - .addExtras(Bundle.forPair(KEY_ACTION_TYPE, actionType)) + .addExtras(extras) .build(); } private TextClassifierEvent.Builder createTextClassifierEventBuilder( - int eventType, @NonNull String resultId) { + int eventType, String resultId) { return new TextClassifierEvent.Builder( TextClassifierEvent.CATEGORY_CONVERSATION_ACTIONS, eventType) .setEventTime(System.currentTimeMillis()) @@ -269,7 +328,7 @@ public class SmartActionsHelper { * to fundamental phone functionality where any error would result in a very negative user * experience. */ - private boolean isEligibleForActionAdjustment(@NonNull NotificationEntry entry) { + private boolean isEligibleForActionAdjustment(NotificationEntry entry) { Notification notification = entry.getNotification(); String pkg = entry.getSbn().getPackageName(); if (!Process.myUserHandle().equals(entry.getSbn().getUser())) { @@ -285,7 +344,7 @@ public class SmartActionsHelper { return entry.isMessaging(); } - private boolean isEligibleForReplyAdjustment(@NonNull NotificationEntry entry) { + private boolean isEligibleForReplyAdjustment(NotificationEntry entry) { if (!Process.myUserHandle().equals(entry.getSbn().getUser())) { return false; } @@ -306,8 +365,7 @@ public class SmartActionsHelper { } /** Returns the text most salient for action extraction in a notification. */ - @Nullable - private List<ConversationActions.Message> extractMessages(@NonNull Notification notification) { + private List<ConversationActions.Message> extractMessages(Notification notification) { Parcelable[] messages = notification.extras.getParcelableArray(Notification.EXTRA_MESSAGES); if (messages == null || messages.length == 0) { return Collections.singletonList(new ConversationActions.Message.Builder( @@ -343,7 +401,7 @@ public class SmartActionsHelper { return new ArrayList<>(extractMessages); } - private static boolean arePersonsEqual(@NonNull Person left, @NonNull Person right) { + private static boolean arePersonsEqual(Person left, Person right) { return Objects.equals(left.getKey(), right.getKey()) && Objects.equals(left.getName(), right.getName()) && Objects.equals(left.getUri(), right.getUri()); @@ -359,4 +417,14 @@ public class SmartActionsHelper { this.actions = actions; } } + + private static class Session { + public final String resultId; + public final Map<CharSequence, Float> repliesScores; + + Session(String resultId, Map<CharSequence, Float> repliesScores) { + this.resultId = resultId; + this.repliesScores = repliesScores; + } + } } diff --git a/packages/ExtServices/tests/src/android/ext/services/notification/SmartActionsHelperTest.java b/packages/ExtServices/tests/src/android/ext/services/notification/SmartActionsHelperTest.java index 74c20fc09df2..d0b6d0061166 100644 --- a/packages/ExtServices/tests/src/android/ext/services/notification/SmartActionsHelperTest.java +++ b/packages/ExtServices/tests/src/android/ext/services/notification/SmartActionsHelperTest.java @@ -19,7 +19,9 @@ import static com.google.common.truth.Truth.assertAbout; import static com.google.common.truth.Truth.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -53,8 +55,8 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatcher; import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import java.time.Instant; @@ -71,9 +73,12 @@ import javax.annotation.Nullable; public class SmartActionsHelperTest { private static final String NOTIFICATION_KEY = "key"; private static final String RESULT_ID = "id"; + private static final float SCORE = 0.7f; + private static final CharSequence SMART_REPLY = "Home"; private static final ConversationAction REPLY_ACTION = new ConversationAction.Builder(ConversationAction.TYPE_TEXT_REPLY) - .setTextReply("Home") + .setTextReply(SMART_REPLY) + .setConfidenceScore(SCORE) .build(); private static final String MESSAGE = "Where are you?"; @@ -197,8 +202,16 @@ public class SmartActionsHelperTest { List<ConversationActions.Message> messages = runSuggestAndCaptureRequest().getConversation(); + assertThat(messages).hasSize(1); MessageSubject.assertThat(messages.get(0)).hasText(MESSAGE); + ArgumentCaptor<TextClassifierEvent> argumentCaptor = + ArgumentCaptor.forClass(TextClassifierEvent.class); + verify(mTextClassifier).onTextClassifierEvent(argumentCaptor.capture()); + TextClassifierEvent textClassifierEvent = argumentCaptor.getValue(); + assertTextClassifierEvent(textClassifierEvent, TextClassifierEvent.TYPE_ACTIONS_GENERATED); + assertThat(textClassifierEvent.getEntityTypes()).asList() + .containsExactly(ConversationAction.TYPE_TEXT_REPLY); } @Test @@ -249,6 +262,14 @@ public class SmartActionsHelperTest { MessageSubject.assertThat(fourthMessage).hasPerson(userB); MessageSubject.assertThat(fourthMessage) .hasReferenceTime(createZonedDateTimeFromMsUtc(4000)); + + ArgumentCaptor<TextClassifierEvent> argumentCaptor = + ArgumentCaptor.forClass(TextClassifierEvent.class); + verify(mTextClassifier).onTextClassifierEvent(argumentCaptor.capture()); + TextClassifierEvent textClassifierEvent = argumentCaptor.getValue(); + assertTextClassifierEvent(textClassifierEvent, TextClassifierEvent.TYPE_ACTIONS_GENERATED); + assertThat(textClassifierEvent.getEntityTypes()).asList() + .containsExactly(ConversationAction.TYPE_TEXT_REPLY); } @Test @@ -299,13 +320,15 @@ public class SmartActionsHelperTest { mSmartActionsHelper.suggest(createNotificationEntry()); mSmartActionsHelper.onSuggestedReplySent( - NOTIFICATION_KEY, MESSAGE, NotificationAssistantService.SOURCE_FROM_ASSISTANT); + NOTIFICATION_KEY, SMART_REPLY, NotificationAssistantService.SOURCE_FROM_ASSISTANT); ArgumentCaptor<TextClassifierEvent> argumentCaptor = ArgumentCaptor.forClass(TextClassifierEvent.class); - verify(mTextClassifier).onTextClassifierEvent(argumentCaptor.capture()); - TextClassifierEvent textClassifierEvent = argumentCaptor.getValue(); - assertTextClassifierEvent(textClassifierEvent, TextClassifierEvent.TYPE_SMART_ACTION); + verify(mTextClassifier, times(2)).onTextClassifierEvent(argumentCaptor.capture()); + List<TextClassifierEvent> events = argumentCaptor.getAllValues(); + assertTextClassifierEvent(events.get(0), TextClassifierEvent.TYPE_ACTIONS_GENERATED); + assertTextClassifierEvent(events.get(1), TextClassifierEvent.TYPE_SMART_ACTION); + assertThat(events.get(1).getScore()).isEqualTo(SCORE); } @Test @@ -317,24 +340,22 @@ public class SmartActionsHelperTest { mSmartActionsHelper.onSuggestedReplySent( "something_else", MESSAGE, NotificationAssistantService.SOURCE_FROM_ASSISTANT); - verify(mTextClassifier, never()) - .onTextClassifierEvent(Mockito.any(TextClassifierEvent.class)); + verify(mTextClassifier, never()).onTextClassifierEvent( + argThat(new TextClassifierEventMatcher(TextClassifierEvent.TYPE_SMART_ACTION))); } @Test public void testOnSuggestedReplySent_missingResultId() { when(mTextClassifier.suggestConversationActions(any(ConversationActions.Request.class))) - .thenReturn(new ConversationActions(Collections.emptyList(), null)); - + .thenReturn(new ConversationActions(Collections.singletonList(REPLY_ACTION), null)); Notification notification = createMessageNotification(); when(mStatusBarNotification.getNotification()).thenReturn(notification); mSmartActionsHelper.suggest(createNotificationEntry()); mSmartActionsHelper.onSuggestedReplySent( - "something_else", MESSAGE, NotificationAssistantService.SOURCE_FROM_ASSISTANT); + NOTIFICATION_KEY, SMART_REPLY, NotificationAssistantService.SOURCE_FROM_ASSISTANT); - verify(mTextClassifier, never()) - .onTextClassifierEvent(Mockito.any(TextClassifierEvent.class)); + verify(mTextClassifier, never()).onTextClassifierEvent(any(TextClassifierEvent.class)); } @Test @@ -347,9 +368,10 @@ public class SmartActionsHelperTest { ArgumentCaptor<TextClassifierEvent> argumentCaptor = ArgumentCaptor.forClass(TextClassifierEvent.class); - verify(mTextClassifier).onTextClassifierEvent(argumentCaptor.capture()); - TextClassifierEvent textClassifierEvent = argumentCaptor.getValue(); - assertTextClassifierEvent(textClassifierEvent, TextClassifierEvent.TYPE_MANUAL_REPLY); + verify(mTextClassifier, times(2)).onTextClassifierEvent(argumentCaptor.capture()); + List<TextClassifierEvent> events = argumentCaptor.getAllValues(); + assertTextClassifierEvent(events.get(0), TextClassifierEvent.TYPE_ACTIONS_GENERATED); + assertTextClassifierEvent(events.get(1), TextClassifierEvent.TYPE_MANUAL_REPLY); } @Test @@ -362,9 +384,10 @@ public class SmartActionsHelperTest { ArgumentCaptor<TextClassifierEvent> argumentCaptor = ArgumentCaptor.forClass(TextClassifierEvent.class); - verify(mTextClassifier).onTextClassifierEvent(argumentCaptor.capture()); - TextClassifierEvent textClassifierEvent = argumentCaptor.getValue(); - assertTextClassifierEvent(textClassifierEvent, TextClassifierEvent.TYPE_ACTIONS_SHOWN); + verify(mTextClassifier, times(2)).onTextClassifierEvent(argumentCaptor.capture()); + List<TextClassifierEvent> events = argumentCaptor.getAllValues(); + assertTextClassifierEvent(events.get(0), TextClassifierEvent.TYPE_ACTIONS_GENERATED); + assertTextClassifierEvent(events.get(1), TextClassifierEvent.TYPE_ACTIONS_SHOWN); } @Test @@ -376,7 +399,7 @@ public class SmartActionsHelperTest { mSmartActionsHelper.onNotificationExpansionChanged(createNotificationEntry(), false, false); verify(mTextClassifier, never()).onTextClassifierEvent( - Mockito.any(TextClassifierEvent.class)); + argThat(new TextClassifierEventMatcher(TextClassifierEvent.TYPE_ACTIONS_SHOWN))); } @Test @@ -389,9 +412,10 @@ public class SmartActionsHelperTest { ArgumentCaptor<TextClassifierEvent> argumentCaptor = ArgumentCaptor.forClass(TextClassifierEvent.class); - verify(mTextClassifier).onTextClassifierEvent(argumentCaptor.capture()); - TextClassifierEvent textClassifierEvent = argumentCaptor.getValue(); - assertTextClassifierEvent(textClassifierEvent, TextClassifierEvent.TYPE_ACTIONS_SHOWN); + verify(mTextClassifier, times(2)).onTextClassifierEvent(argumentCaptor.capture()); + List<TextClassifierEvent> events = argumentCaptor.getAllValues(); + assertTextClassifierEvent(events.get(0), TextClassifierEvent.TYPE_ACTIONS_GENERATED); + assertTextClassifierEvent(events.get(1), TextClassifierEvent.TYPE_ACTIONS_SHOWN); } private ZonedDateTime createZonedDateTimeFromMsUtc(long msUtc) { @@ -490,4 +514,21 @@ public class SmartActionsHelperTest { return assertAbout(FACTORY).that(message); } } + + private final class TextClassifierEventMatcher implements ArgumentMatcher<TextClassifierEvent> { + + private int mType; + + private TextClassifierEventMatcher(int type) { + mType = type; + } + + @Override + public boolean matches(TextClassifierEvent textClassifierEvent) { + if (textClassifierEvent == null) { + return false; + } + return mType == textClassifierEvent.getEventType(); + } + } } |