Enable App Prediction for Work Profile

call to different instance of AppPredictionPerUserService based on
calling context.

Bug: 148230574
Test: Manual
Change-Id: I5987ed6a80155a8aed7c1985e8edd1ba98e51149
diff --git a/core/java/android/app/prediction/AppPredictionSessionId.java b/core/java/android/app/prediction/AppPredictionSessionId.java
index e5e06f8..876bafd 100644
--- a/core/java/android/app/prediction/AppPredictionSessionId.java
+++ b/core/java/android/app/prediction/AppPredictionSessionId.java
@@ -22,6 +22,8 @@
 import android.os.Parcel;
 import android.os.Parcelable;
 
+import java.util.Objects;
+
 /**
  * The id for an app prediction session. See {@link AppPredictor}.
  *
@@ -32,18 +34,28 @@
 public final class AppPredictionSessionId implements Parcelable {
 
     private final String mId;
+    private final int mUserId;
 
     /**
      * Creates a new id for a prediction session.
      *
      * @hide
      */
-    public AppPredictionSessionId(@NonNull String id) {
+    public AppPredictionSessionId(@NonNull final String id, final int userId) {
         mId = id;
+        mUserId = userId;
     }
 
     private AppPredictionSessionId(Parcel p) {
         mId = p.readString();
+        mUserId = p.readInt();
+    }
+
+    /**
+     * @hide
+     */
+    public int getUserId() {
+        return mUserId;
     }
 
     @Override
@@ -51,17 +63,17 @@
         if (!getClass().equals(o != null ? o.getClass() : null)) return false;
 
         AppPredictionSessionId other = (AppPredictionSessionId) o;
-        return mId.equals(other.mId);
+        return mId.equals(other.mId) && mUserId == other.mUserId;
     }
 
     @Override
     public @NonNull String toString() {
-        return mId;
+        return mId + "," + mUserId;
     }
 
     @Override
     public int hashCode() {
-        return mId.hashCode();
+        return Objects.hash(mId, mUserId);
     }
 
     @Override
@@ -72,6 +84,7 @@
     @Override
     public void writeToParcel(Parcel dest, int flags) {
         dest.writeString(mId);
+        dest.writeInt(mUserId);
     }
 
     public static final @android.annotation.NonNull Parcelable.Creator<AppPredictionSessionId> CREATOR =
diff --git a/core/java/android/app/prediction/AppPredictor.java b/core/java/android/app/prediction/AppPredictor.java
index cd635d6..f0eedf3 100644
--- a/core/java/android/app/prediction/AppPredictor.java
+++ b/core/java/android/app/prediction/AppPredictor.java
@@ -96,7 +96,7 @@
         IBinder b = ServiceManager.getService(Context.APP_PREDICTION_SERVICE);
         mPredictionManager = IPredictionManager.Stub.asInterface(b);
         mSessionId = new AppPredictionSessionId(
-                context.getPackageName() + ":" + UUID.randomUUID().toString());
+                context.getPackageName() + ":" + UUID.randomUUID().toString(), context.getUserId());
         try {
             mPredictionManager.createPredictionSession(predictionContext, mSessionId);
         } catch (RemoteException e) {
diff --git a/services/appprediction/java/com/android/server/appprediction/AppPredictionManagerService.java b/services/appprediction/java/com/android/server/appprediction/AppPredictionManagerService.java
index 5844f98..1c4db12 100644
--- a/services/appprediction/java/com/android/server/appprediction/AppPredictionManagerService.java
+++ b/services/appprediction/java/com/android/server/appprediction/AppPredictionManagerService.java
@@ -18,12 +18,14 @@
 
 import static android.Manifest.permission.MANAGE_APP_PREDICTIONS;
 import static android.Manifest.permission.PACKAGE_USAGE_STATS;
+import static android.app.ActivityManagerInternal.ALLOW_NON_FULL;
 import static android.content.Context.APP_PREDICTION_SERVICE;
 import static android.content.pm.PackageManager.PERMISSION_GRANTED;
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.annotation.UserIdInt;
+import android.app.ActivityManagerInternal;
 import android.app.prediction.AppPredictionContext;
 import android.app.prediction.AppPredictionSessionId;
 import android.app.prediction.AppTargetEvent;
@@ -34,7 +36,6 @@
 import android.os.Binder;
 import android.os.ResultReceiver;
 import android.os.ShellCallback;
-import android.os.UserHandle;
 import android.util.Slog;
 
 import com.android.server.LocalServices;
@@ -108,21 +109,21 @@
         @Override
         public void createPredictionSession(@NonNull AppPredictionContext context,
                 @NonNull AppPredictionSessionId sessionId) {
-            runForUserLocked("createPredictionSession",
+            runForUserLocked("createPredictionSession", sessionId,
                     (service) -> service.onCreatePredictionSessionLocked(context, sessionId));
         }
 
         @Override
         public void notifyAppTargetEvent(@NonNull AppPredictionSessionId sessionId,
                 @NonNull AppTargetEvent event) {
-            runForUserLocked("notifyAppTargetEvent",
+            runForUserLocked("notifyAppTargetEvent", sessionId,
                     (service) -> service.notifyAppTargetEventLocked(sessionId, event));
         }
 
         @Override
         public void notifyLaunchLocationShown(@NonNull AppPredictionSessionId sessionId,
                 @NonNull String launchLocation, @NonNull ParceledListSlice targetIds) {
-            runForUserLocked("notifyLaunchLocationShown", (service) ->
+            runForUserLocked("notifyLaunchLocationShown", sessionId, (service) ->
                     service.notifyLaunchLocationShownLocked(sessionId, launchLocation, targetIds));
         }
 
@@ -130,32 +131,32 @@
         public void sortAppTargets(@NonNull AppPredictionSessionId sessionId,
                 @NonNull ParceledListSlice targets,
                 IPredictionCallback callback) {
-            runForUserLocked("sortAppTargets",
+            runForUserLocked("sortAppTargets", sessionId,
                     (service) -> service.sortAppTargetsLocked(sessionId, targets, callback));
         }
 
         @Override
         public void registerPredictionUpdates(@NonNull AppPredictionSessionId sessionId,
                 @NonNull IPredictionCallback callback) {
-            runForUserLocked("registerPredictionUpdates",
+            runForUserLocked("registerPredictionUpdates", sessionId,
                     (service) -> service.registerPredictionUpdatesLocked(sessionId, callback));
         }
 
         public void unregisterPredictionUpdates(@NonNull AppPredictionSessionId sessionId,
                 @NonNull IPredictionCallback callback) {
-            runForUserLocked("unregisterPredictionUpdates",
+            runForUserLocked("unregisterPredictionUpdates", sessionId,
                     (service) -> service.unregisterPredictionUpdatesLocked(sessionId, callback));
         }
 
         @Override
         public void requestPredictionUpdate(@NonNull AppPredictionSessionId sessionId) {
-            runForUserLocked("requestPredictionUpdate",
+            runForUserLocked("requestPredictionUpdate", sessionId,
                     (service) -> service.requestPredictionUpdateLocked(sessionId));
         }
 
         @Override
         public void onDestroyPredictionSession(@NonNull AppPredictionSessionId sessionId) {
-            runForUserLocked("onDestroyPredictionSession",
+            runForUserLocked("onDestroyPredictionSession", sessionId,
                     (service) -> service.onDestroyPredictionSessionLocked(sessionId));
         }
 
@@ -167,9 +168,12 @@
                     .exec(this, in, out, err, args, callback, resultReceiver);
         }
 
-        private void runForUserLocked(@NonNull String func,
-                @NonNull Consumer<AppPredictionPerUserService> c) {
-            final int userId = UserHandle.getCallingUserId();
+        private void runForUserLocked(@NonNull final String func,
+                @NonNull final AppPredictionSessionId sessionId,
+                @NonNull final Consumer<AppPredictionPerUserService> c) {
+            ActivityManagerInternal am = LocalServices.getService(ActivityManagerInternal.class);
+            final int userId = am.handleIncomingUser(Binder.getCallingPid(), Binder.getCallingUid(),
+                    sessionId.getUserId(), false, ALLOW_NON_FULL, null, null);
 
             Context ctx = getContext();
             if (!(ctx.checkCallingPermission(PACKAGE_USAGE_STATS) == PERMISSION_GRANTED
diff --git a/services/tests/servicestests/src/com/android/server/people/PeopleServiceTest.java b/services/tests/servicestests/src/com/android/server/people/PeopleServiceTest.java
index 4ae374a..9213e1f 100644
--- a/services/tests/servicestests/src/com/android/server/people/PeopleServiceTest.java
+++ b/services/tests/servicestests/src/com/android/server/people/PeopleServiceTest.java
@@ -51,6 +51,7 @@
     private static final String APP_PREDICTION_SHARE_UI_SURFACE = "share";
     private static final int APP_PREDICTION_TARGET_COUNT = 4;
     private static final String TEST_PACKAGE_NAME = "com.example";
+    private static final int USER_ID = 0;
 
     private PeopleServiceInternal mServiceInternal;
     private PeopleService.LocalService mLocalService;
@@ -73,7 +74,7 @@
         mServiceInternal = LocalServices.getService(PeopleServiceInternal.class);
         mLocalService = (PeopleService.LocalService) mServiceInternal;
 
-        mSessionId = new AppPredictionSessionId("abc");
+        mSessionId = new AppPredictionSessionId("abc", USER_ID);
         mPredictionContext = new AppPredictionContext.Builder(mContext)
                 .setUiSurface(APP_PREDICTION_SHARE_UI_SURFACE)
                 .setPredictedTargetCount(APP_PREDICTION_TARGET_COUNT)