Improve startBackNavigation stability

  - Use the focused window instead of the topApp window
    - Instead we now rely on WindowManagerService to get
     the focused window.
    - SystemUI does not have ActivityRecord so we can't rely
     on the top window of the Task to find the correct
      window on which the callback will be called.

  - Introduce a Builder for BackNavigationInfo
    - This reduces the number of variable needed outside the synchonized
    block.
    - It also reduces the number of early return of BackNavigationInfo
    instances

  - Adding log messages to help further debug the method.

Test: BackNavigationControllerTests
Test: Manual dismiss of SystemUi dialog in QS
Bug: 216604581
Fixes: 221458292
Change-Id: I9ba2c7f89956f34d6338824502c210b3e58dc076

Introduce builder for BackNavigationInfo

Change-Id: I14b4a4b3abc8f417998b7b32831cb3d5c4faa491
diff --git a/core/java/android/view/ViewRootImpl.java b/core/java/android/view/ViewRootImpl.java
index 172cd03..17e3914 100644
--- a/core/java/android/view/ViewRootImpl.java
+++ b/core/java/android/view/ViewRootImpl.java
@@ -147,6 +147,7 @@
 import android.os.Trace;
 import android.os.UserHandle;
 import android.sysprop.DisplayProperties;
+import android.text.TextUtils;
 import android.util.AndroidRuntimeException;
 import android.util.DisplayMetrics;
 import android.util.EventLog;
@@ -10885,6 +10886,12 @@
      * {@link OnBackInvokedCallback} to be called to the server.
      */
     private void registerBackCallbackOnWindow() {
+        if (OnBackInvokedDispatcher.DEBUG) {
+            Log.d(OnBackInvokedDispatcher.TAG, TextUtils.formatSimple(
+                    "ViewRootImpl.registerBackCallbackOnWindow. Callback:%s Package:%s "
+                            + "IWindow:%s Session:%s",
+                    mOnBackInvokedDispatcher, mBasePackageName, mWindow, mWindowSession));
+        }
         mOnBackInvokedDispatcher.attachToWindow(mWindowSession, mWindow);
     }
 
diff --git a/core/java/android/window/BackNavigationInfo.java b/core/java/android/window/BackNavigationInfo.java
index 6653758..0ab6db5 100644
--- a/core/java/android/window/BackNavigationInfo.java
+++ b/core/java/android/window/BackNavigationInfo.java
@@ -81,7 +81,9 @@
             TYPE_DIALOG_CLOSE,
             TYPE_RETURN_TO_HOME,
             TYPE_CROSS_ACTIVITY,
-            TYPE_CROSS_TASK})
+            TYPE_CROSS_TASK,
+            TYPE_CALLBACK
+    })
     @interface BackTargetType {
     }
 
@@ -121,8 +123,8 @@
             @Nullable SurfaceControl screenshotSurface,
             @Nullable HardwareBuffer screenshotBuffer,
             @Nullable WindowConfiguration taskWindowConfiguration,
-            @NonNull RemoteCallback onBackNavigationDone,
-            @NonNull IOnBackInvokedCallback onBackInvokedCallback) {
+            @Nullable RemoteCallback onBackNavigationDone,
+            @Nullable IOnBackInvokedCallback onBackInvokedCallback) {
         mType = type;
         mDepartingAnimationTarget = departingAnimationTarget;
         mScreenshotSurface = screenshotSurface;
@@ -278,7 +280,98 @@
                 return "TYPE_CROSS_ACTIVITY";
             case TYPE_CROSS_TASK:
                 return "TYPE_CROSS_TASK";
+            case TYPE_CALLBACK:
+                return "TYPE_CALLBACK";
         }
         return String.valueOf(type);
     }
+
+    /**
+     * @hide
+     */
+    @SuppressWarnings("UnusedReturnValue") // Builder pattern
+    public static class Builder {
+
+        private int mType = TYPE_UNDEFINED;
+        @Nullable
+        private RemoteAnimationTarget mDepartingAnimationTarget = null;
+        @Nullable
+        private SurfaceControl mScreenshotSurface = null;
+        @Nullable
+        private HardwareBuffer mScreenshotBuffer = null;
+        @Nullable
+        private WindowConfiguration mTaskWindowConfiguration = null;
+        @Nullable
+        private RemoteCallback mOnBackNavigationDone = null;
+        @Nullable
+        private IOnBackInvokedCallback mOnBackInvokedCallback = null;
+
+        /**
+         * @see BackNavigationInfo#getType()
+         */
+        public Builder setType(@BackTargetType int type) {
+            mType = type;
+            return this;
+        }
+
+        /**
+         * @see BackNavigationInfo#getDepartingAnimationTarget
+         */
+        public Builder setDepartingAnimationTarget(
+                @Nullable RemoteAnimationTarget departingAnimationTarget) {
+            mDepartingAnimationTarget = departingAnimationTarget;
+            return this;
+        }
+
+        /**
+         * @see BackNavigationInfo#getScreenshotSurface
+         */
+        public Builder setScreenshotSurface(@Nullable SurfaceControl screenshotSurface) {
+            mScreenshotSurface = screenshotSurface;
+            return this;
+        }
+
+        /**
+         * @see BackNavigationInfo#getScreenshotHardwareBuffer()
+         */
+        public Builder setScreenshotBuffer(@Nullable HardwareBuffer screenshotBuffer) {
+            mScreenshotBuffer = screenshotBuffer;
+            return this;
+        }
+
+        /**
+         * @see BackNavigationInfo#getTaskWindowConfiguration
+         */
+        public Builder setTaskWindowConfiguration(
+                @Nullable WindowConfiguration taskWindowConfiguration) {
+            mTaskWindowConfiguration = taskWindowConfiguration;
+            return this;
+        }
+
+        /**
+         * @see BackNavigationInfo#onBackNavigationFinished(boolean)
+         */
+        public Builder setOnBackNavigationDone(@Nullable RemoteCallback onBackNavigationDone) {
+            mOnBackNavigationDone = onBackNavigationDone;
+            return this;
+        }
+
+        /**
+         * @see BackNavigationInfo#getOnBackInvokedCallback
+         */
+        public Builder setOnBackInvokedCallback(
+                @Nullable IOnBackInvokedCallback onBackInvokedCallback) {
+            mOnBackInvokedCallback = onBackInvokedCallback;
+            return this;
+        }
+
+        /**
+         * Builds and returns an instance of {@link BackNavigationInfo}
+         */
+        public BackNavigationInfo build() {
+            return new BackNavigationInfo(mType, mDepartingAnimationTarget, mScreenshotSurface,
+                    mScreenshotBuffer, mTaskWindowConfiguration, mOnBackNavigationDone,
+                    mOnBackInvokedCallback);
+        }
+    }
 }
diff --git a/core/java/android/window/ProxyOnBackInvokedDispatcher.java b/core/java/android/window/ProxyOnBackInvokedDispatcher.java
index 2b2f5e9..eb77631 100644
--- a/core/java/android/window/ProxyOnBackInvokedDispatcher.java
+++ b/core/java/android/window/ProxyOnBackInvokedDispatcher.java
@@ -73,7 +73,7 @@
     public void unregisterOnBackInvokedCallback(
             @NonNull OnBackInvokedCallback callback) {
         if (DEBUG) {
-            Log.v(TAG, String.format("Pending unregister %s. Actual=%s", callback,
+            Log.v(TAG, String.format("Proxy unregister %s. Actual=%s", callback,
                     mActualDispatcherOwner));
         }
         synchronized (mLock) {
@@ -109,8 +109,8 @@
         OnBackInvokedDispatcher dispatcher =
                 mActualDispatcherOwner.getOnBackInvokedDispatcher();
         if (DEBUG) {
-            Log.v(TAG, String.format("Pending transferring %d callbacks to %s", mCallbacks.size(),
-                    dispatcher));
+            Log.v(TAG, String.format("Proxy: transferring %d pending callbacks to %s",
+                    mCallbacks.size(), dispatcher));
         }
         for (Pair<OnBackInvokedCallback, Integer> callbackPair : mCallbacks) {
             int priority = callbackPair.second;
@@ -144,7 +144,7 @@
      */
     public void reset() {
         if (DEBUG) {
-            Log.v(TAG, "Pending reset callbacks");
+            Log.v(TAG, "Proxy: reset callbacks");
         }
         synchronized (mLock) {
             mCallbacks.clear();
@@ -165,7 +165,7 @@
     public void setActualDispatcherOwner(
             @Nullable OnBackInvokedDispatcherOwner actualDispatcherOwner) {
         if (DEBUG) {
-            Log.v(TAG, String.format("Pending setActual %s. Current %s",
+            Log.v(TAG, String.format("Proxy setActual %s. Current %s",
                             actualDispatcherOwner, mActualDispatcherOwner));
         }
         synchronized (mLock) {
diff --git a/data/etc/services.core.protolog.json b/data/etc/services.core.protolog.json
index 99cb40a..6852c06 100644
--- a/data/etc/services.core.protolog.json
+++ b/data/etc/services.core.protolog.json
@@ -397,6 +397,12 @@
       "group": "WM_DEBUG_WINDOW_TRANSITIONS",
       "at": "com\/android\/server\/wm\/Transition.java"
     },
+    "-1717147904": {
+      "message": "Current focused window is embeddedWindow. Dispatch KEYCODE_BACK.",
+      "level": "DEBUG",
+      "group": "WM_DEBUG_BACK_PREVIEW",
+      "at": "com\/android\/server\/wm\/BackNavigationController.java"
+    },
     "-1715268616": {
       "message": "Last window, removing starting window %s",
       "level": "VERBOSE",
@@ -1087,6 +1093,12 @@
       "group": "WM_DEBUG_STATES",
       "at": "com\/android\/server\/wm\/ActivityRecord.java"
     },
+    "-1010850753": {
+      "message": "No focused window, defaulting to top task's window",
+      "level": "WARN",
+      "group": "WM_DEBUG_BACK_PREVIEW",
+      "at": "com\/android\/server\/wm\/BackNavigationController.java"
+    },
     "-1009117329": {
       "message": "isFetchingAppTransitionSpecs=true",
       "level": "VERBOSE",
@@ -1105,6 +1117,12 @@
       "group": "WM_DEBUG_STATES",
       "at": "com\/android\/server\/wm\/ActivityRecord.java"
     },
+    "-997565097": {
+      "message": "Focused window found using getFocusedWindowToken",
+      "level": "DEBUG",
+      "group": "WM_DEBUG_BACK_PREVIEW",
+      "at": "com\/android\/server\/wm\/BackNavigationController.java"
+    },
     "-993378225": {
       "message": "finishDrawingLocked: mDrawState=COMMIT_DRAW_PENDING %s in %s",
       "level": "VERBOSE",
@@ -1327,6 +1345,12 @@
       "group": "WM_DEBUG_FOCUS_LIGHT",
       "at": "com\/android\/server\/wm\/ActivityRecord.java"
     },
+    "-767349300": {
+      "message": "%s: Setting back callback %s. Client IWindow %s",
+      "level": "DEBUG",
+      "group": "WM_DEBUG_BACK_PREVIEW",
+      "at": "com\/android\/server\/wm\/WindowState.java"
+    },
     "-766059044": {
       "message": "Display id=%d selected orientation %s (%d), got rotation %s (%d)",
       "level": "VERBOSE",
@@ -1669,12 +1693,6 @@
       "group": "WM_DEBUG_STATES",
       "at": "com\/android\/server\/wm\/RootWindowContainer.java"
     },
-    "-432881038": {
-      "message": "startBackNavigation task=%s, topRunningActivity=%s, applicationBackCallback=%s, systemBackCallback=%s",
-      "level": "DEBUG",
-      "group": "WM_DEBUG_BACK_PREVIEW",
-      "at": "com\/android\/server\/wm\/BackNavigationController.java"
-    },
     "-415865166": {
       "message": "findFocusedWindow: Found new focus @ %s",
       "level": "VERBOSE",
@@ -1885,12 +1903,6 @@
       "group": "WM_DEBUG_SYNC_ENGINE",
       "at": "com\/android\/server\/wm\/BLASTSyncEngine.java"
     },
-    "-228813488": {
-      "message": "%s: Setting back callback %s",
-      "level": "DEBUG",
-      "group": "WM_DEBUG_BACK_PREVIEW",
-      "at": "com\/android\/server\/wm\/WindowState.java"
-    },
     "-208825711": {
       "message": "shouldWaitAnimatingExit: isWallpaperTarget: %s",
       "level": "DEBUG",
@@ -2359,6 +2371,12 @@
       "group": "WM_DEBUG_REMOTE_ANIMATIONS",
       "at": "com\/android\/server\/wm\/RemoteAnimationController.java"
     },
+    "250620778": {
+      "message": "startBackNavigation task=%s, topRunningActivity=%s, applicationBackCallback=%s, systemBackCallback=%s, currentFocus=%s",
+      "level": "DEBUG",
+      "group": "WM_DEBUG_BACK_PREVIEW",
+      "at": "com\/android\/server\/wm\/BackNavigationController.java"
+    },
     "251812577": {
       "message": "Register display organizer=%s uid=%d",
       "level": "VERBOSE",
@@ -2677,6 +2695,12 @@
       "group": "WM_SHOW_TRANSACTIONS",
       "at": "com\/android\/server\/wm\/WindowContainerThumbnail.java"
     },
+    "531891870": {
+      "message": "Previous Destination is Activity:%s Task:%s removedContainer:%s, backType=%s",
+      "level": "DEBUG",
+      "group": "WM_DEBUG_BACK_PREVIEW",
+      "at": "com\/android\/server\/wm\/BackNavigationController.java"
+    },
     "535103992": {
       "message": "Wallpaper may change!  Adjusting",
       "level": "VERBOSE",
@@ -2911,6 +2935,12 @@
       "group": "WM_DEBUG_LOCKTASK",
       "at": "com\/android\/server\/wm\/ActivityTaskManagerService.java"
     },
+    "716528224": {
+      "message": "Focused window found using wmService.getFocusedWindowLocked()",
+      "level": "DEBUG",
+      "group": "WM_DEBUG_BACK_PREVIEW",
+      "at": "com\/android\/server\/wm\/BackNavigationController.java"
+    },
     "726205185": {
       "message": "Moving to DESTROYED: %s (destroy skipped)",
       "level": "VERBOSE",
diff --git a/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/back/BackAnimationControllerTest.java b/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/back/BackAnimationControllerTest.java
index 3e7ee25..05230a9 100644
--- a/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/back/BackAnimationControllerTest.java
+++ b/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/back/BackAnimationControllerTest.java
@@ -44,6 +44,7 @@
 import com.android.wm.shell.common.ShellExecutor;
 
 import org.junit.Before;
+import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
@@ -109,6 +110,7 @@
     }
 
     @Test
+    @Ignore("b/207481538")
     public void crossActivity_screenshotAttachedAndVisible() {
         SurfaceControl screenshotSurface = new SurfaceControl();
         HardwareBuffer hardwareBuffer = mock(HardwareBuffer.class);
diff --git a/services/core/java/com/android/server/wm/ActivityTaskManagerService.java b/services/core/java/com/android/server/wm/ActivityTaskManagerService.java
index ad6f354..7746dfd 100644
--- a/services/core/java/com/android/server/wm/ActivityTaskManagerService.java
+++ b/services/core/java/com/android/server/wm/ActivityTaskManagerService.java
@@ -1851,7 +1851,7 @@
         if (mBackNavigationController == null) {
             return null;
         }
-        return mBackNavigationController.startBackNavigation(getTopDisplayFocusedRootTask());
+        return mBackNavigationController.startBackNavigation(mWindowManager);
     }
 
     /**
diff --git a/services/core/java/com/android/server/wm/BackNavigationController.java b/services/core/java/com/android/server/wm/BackNavigationController.java
index dbc0141..ef0b737 100644
--- a/services/core/java/com/android/server/wm/BackNavigationController.java
+++ b/services/core/java/com/android/server/wm/BackNavigationController.java
@@ -26,6 +26,7 @@
 import android.graphics.Rect;
 import android.hardware.HardwareBuffer;
 import android.os.Bundle;
+import android.os.IBinder;
 import android.os.RemoteCallback;
 import android.os.RemoteException;
 import android.os.SystemProperties;
@@ -38,6 +39,7 @@
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.protolog.common.ProtoLog;
+import com.android.server.LocalServices;
 
 /**
  * Controller to handle actions related to the back gesture on the server side.
@@ -73,24 +75,24 @@
      * Set up the necessary leashes and build a {@link BackNavigationInfo} instance for an upcoming
      * back gesture animation.
      *
-     * @param task the currently focused {@link Task}.
      * @return a {@link BackNavigationInfo} instance containing the required leashes and metadata
-     * for the animation.
+     * for the animation, or null if we don't know how to animate the current window and need to
+     * fallback on dispatching the key event.
      */
     @Nullable
-    BackNavigationInfo startBackNavigation(@NonNull Task task) {
-        return startBackNavigation(task, null);
+    BackNavigationInfo startBackNavigation(@NonNull WindowManagerService wmService) {
+        return startBackNavigation(wmService, null);
     }
 
     /**
      * @param tx, a transaction to be used for the attaching the animation leash.
      *            This is used in tests. If null, the object will be initialized with a new {@link
-     *            android.view.SurfaceControl.Transaction}
-     * @see #startBackNavigation(Task)
+     *            SurfaceControl.Transaction}
+     * @see #startBackNavigation(WindowManagerService)
      */
     @VisibleForTesting
     @Nullable
-    BackNavigationInfo startBackNavigation(@NonNull Task task,
+    BackNavigationInfo startBackNavigation(WindowManagerService wmService,
             @Nullable SurfaceControl.Transaction tx) {
 
         if (tx == null) {
@@ -98,88 +100,125 @@
         }
 
         int backType = BackNavigationInfo.TYPE_UNDEFINED;
-        Task prevTask = task;
+        Task prevTask = null;
         ActivityRecord prev;
-        WindowContainer<?> removedWindowContainer;
-        ActivityRecord activityRecord;
+        WindowContainer<?> removedWindowContainer = null;
+        ActivityRecord activityRecord = null;
         ActivityRecord prevTaskTopActivity = null;
-        SurfaceControl animationLeashParent;
-        WindowConfiguration taskWindowConfiguration;
+        Task task = null;
+        SurfaceControl animationLeashParent = null;
         HardwareBuffer screenshotBuffer = null;
-        SurfaceControl screenshotSurface;
+        RemoteAnimationTarget topAppTarget = null;
         int prevTaskId;
         int prevUserId;
-        RemoteAnimationTarget topAppTarget;
-        SurfaceControl animLeash;
-        IOnBackInvokedCallback applicationCallback = null;
-        IOnBackInvokedCallback systemCallback = null;
 
-        synchronized (task.mWmService.mGlobalLock) {
+        BackNavigationInfo.Builder infoBuilder = new BackNavigationInfo.Builder();
+        synchronized (wmService.mGlobalLock) {
+            WindowState window;
+            WindowConfiguration taskWindowConfiguration;
+            WindowManagerInternal windowManagerInternal =
+                    LocalServices.getService(WindowManagerInternal.class);
+            IBinder focusedWindowToken = windowManagerInternal.getFocusedWindowToken();
 
-            // TODO Temp workaround for Sysui until b/221071505 is fixed
-            WindowState window = task.mWmService.getFocusedWindowLocked();
+            window = wmService.windowForClientLocked(null, focusedWindowToken,
+                    false /* throwOnError */);
+
             if (window == null) {
-                activityRecord = task.topRunningActivity();
-                removedWindowContainer = activityRecord;
-                taskWindowConfiguration = task.getTaskInfo().configuration.windowConfiguration;
-                window = task.getWindow(WindowState::isFocused);
-            } else {
-                activityRecord = window.mActivityRecord;
-                removedWindowContainer = activityRecord;
-                taskWindowConfiguration = window.getWindowConfiguration();
-            }
-            if (window != null) {
-                applicationCallback = window.getApplicationOnBackInvokedCallback();
-                systemCallback = window.getSystemOnBackInvokedCallback();
-            }
-            if (applicationCallback == null && systemCallback == null) {
-                // Return null when either there's no window, or apps have just initialized and
-                // have not finished registering callbacks.
-                return null;
-            }
-
-            ProtoLog.d(WM_DEBUG_BACK_PREVIEW, "startBackNavigation task=%s, "
-                            + "topRunningActivity=%s, applicationBackCallback=%s, "
-                            + "systemBackCallback=%s",
-                    task, activityRecord, applicationCallback, systemCallback);
-
-            // TODO Temp workaround for Sysui until b/221071505 is fixed
-            if (activityRecord == null && applicationCallback != null) {
-                return new BackNavigationInfo(BackNavigationInfo.TYPE_CALLBACK,
-                        null /* topWindowLeash */, null /* screenshotSurface */,
-                        null /* screenshotBuffer */, null /* taskWindowConfiguration */,
-                        null /* onBackNavigationDone */,
-                        applicationCallback /* onBackInvokedCallback */);
-            }
-
-            // For IME and Home, either a callback is registered, or we do nothing. In both cases,
-            // we don't need to pass the leashes below.
-            if (activityRecord == null || task.getDisplayContent().getImeContainer().isVisible()
-                    || activityRecord.isActivityTypeHome()) {
-                if (applicationCallback != null) {
-                    return new BackNavigationInfo(BackNavigationInfo.TYPE_CALLBACK,
-                            null /* topWindowLeash */, null /* screenshotSurface */,
-                            null /* screenshotBuffer */, null /* taskWindowConfiguration */,
-                            null /* onBackNavigationDone */,
-                            applicationCallback /* onBackInvokedCallback */);
-                } else {
+                EmbeddedWindowController.EmbeddedWindow embeddedWindow =
+                        wmService.mEmbeddedWindowController.getByFocusToken(focusedWindowToken);
+                if (embeddedWindow != null) {
+                    ProtoLog.d(WM_DEBUG_BACK_PREVIEW,
+                            "Current focused window is embeddedWindow. Dispatch KEYCODE_BACK.");
                     return null;
                 }
             }
 
-            prev = task.getActivity(
-                    (r) -> !r.finishing && r.getTask() == task && !r.isTopRunningActivity());
+            // Lets first gather the states of things
+            //  - What is our current window ?
+            //  - Does it has an Activity and a Task ?
+            // TODO Temp workaround for Sysui until b/221071505 is fixed
+            if (window != null) {
+                ProtoLog.d(WM_DEBUG_BACK_PREVIEW,
+                        "Focused window found using getFocusedWindowToken");
+            }
 
-            if (applicationCallback != null) {
-                return new BackNavigationInfo(BackNavigationInfo.TYPE_CALLBACK,
-                        null /* topWindowLeash */, null /* screenshotSurface */,
-                        null /* screenshotBuffer */, null /* taskWindowConfiguration */,
-                        null /* onBackNavigationDone */,
-                        applicationCallback /* onBackInvokedCallback */);
+            if (window == null) {
+                window = wmService.getFocusedWindowLocked();
+                ProtoLog.d(WM_DEBUG_BACK_PREVIEW,
+                        "Focused window found using wmService.getFocusedWindowLocked()");
+            }
+
+            if (window == null) {
+                // We don't have any focused window, fallback ont the top task of the focused
+                // display.
+                ProtoLog.w(WM_DEBUG_BACK_PREVIEW,
+                        "No focused window, defaulting to top task's window");
+                task = wmService.mAtmService.getTopDisplayFocusedRootTask();
+                window = task.getWindow(WindowState::isFocused);
+            }
+
+            // Now let's find if this window has a callback from the client side.
+            IOnBackInvokedCallback applicationCallback = null;
+            IOnBackInvokedCallback systemCallback = null;
+            if (window != null) {
+                activityRecord = window.mActivityRecord;
+                task = window.getTask();
+                applicationCallback = window.getApplicationOnBackInvokedCallback();
+                if (applicationCallback != null) {
+                    backType = BackNavigationInfo.TYPE_CALLBACK;
+                    infoBuilder.setOnBackInvokedCallback(applicationCallback);
+                } else {
+                    systemCallback = window.getSystemOnBackInvokedCallback();
+                    infoBuilder.setOnBackInvokedCallback(systemCallback);
+                }
+            }
+
+            ProtoLog.d(WM_DEBUG_BACK_PREVIEW, "startBackNavigation task=%s, "
+                            + "topRunningActivity=%s, applicationBackCallback=%s, "
+                            + "systemBackCallback=%s, currentFocus=%s",
+                    task, activityRecord, applicationCallback, systemCallback, window);
+
+            if (window == null) {
+                Slog.e(TAG, "Window is null, returning null.");
+                return null;
+            }
+
+            if (systemCallback == null && applicationCallback == null) {
+                Slog.e(TAG, "No callback registered, returning null.");
+                return null;
+            }
+
+            // If we don't need to set up the animation, we return early. This is the case when
+            // - We have an application callback.
+            // - We don't have any ActivityRecord or Task to animate.
+            // - The IME is opened, and we just need to close it.
+            // - The home activity is the focused activity.
+            if (backType == BackNavigationInfo.TYPE_CALLBACK
+                    || activityRecord == null
+                    || task == null
+                    || task.getDisplayContent().getImeContainer().isVisible()
+                    || activityRecord.isActivityTypeHome()) {
+                return infoBuilder
+                        .setType(backType)
+                        .build();
+            }
+
+            // We don't have an application callback, let's find the destination of the back gesture
+            Task finalTask = task;
+            prev = task.getActivity(
+                    (r) -> !r.finishing && r.getTask() == finalTask && !r.isTopRunningActivity());
+            if (window.getParent().getChildCount() > 1 && window.getParent().getChildAt(0)
+                    != window) {
+                // Are we the top window of our parent? If not, we are a window on top of the
+                // activity, we won't close the activity.
+                backType = BackNavigationInfo.TYPE_DIALOG_CLOSE;
+                removedWindowContainer = window;
             } else if (prev != null) {
+                // We have another Activity in the same task to go to
                 backType = BackNavigationInfo.TYPE_CROSS_ACTIVITY;
+                removedWindowContainer = activityRecord;
             } else if (task.returnsToHomeRootTask()) {
-                prevTask = null;
+                // Our Task should bring back to home
                 removedWindowContainer = task;
                 backType = BackNavigationInfo.TYPE_RETURN_TO_HOME;
             } else if (activityRecord.isRootOfTask()) {
@@ -195,12 +234,42 @@
                     backType = BackNavigationInfo.TYPE_CROSS_TASK;
                 }
             }
+            infoBuilder.setType(backType);
 
             prevTaskId = prevTask != null ? prevTask.mTaskId : 0;
             prevUserId = prevTask != null ? prevTask.mUserId : 0;
 
-            ProtoLog.d(WM_DEBUG_BACK_PREVIEW, "Previous Activity is %s. "
-                    + "Back type is %s", prev != null ? prev.mActivityComponent : null, backType);
+            ProtoLog.d(WM_DEBUG_BACK_PREVIEW, "Previous Destination is Activity:%s Task:%s "
+                            + "removedContainer:%s, backType=%s",
+                    prev != null ? prev.mActivityComponent : null,
+                    prevTask != null ? prevTask.getName() : null,
+                    removedWindowContainer,
+                    BackNavigationInfo.typeToString(backType));
+
+            // For now, we only animate when going home.
+            boolean prepareAnimation = backType == BackNavigationInfo.TYPE_RETURN_TO_HOME
+                    // Only create a new leash if no leash has been created.
+                    // Otherwise return null for animation target to avoid conflict.
+                    && !removedWindowContainer.hasCommittedReparentToAnimationLeash();
+
+            if (prepareAnimation) {
+                taskWindowConfiguration = task.getTaskInfo().configuration.windowConfiguration;
+
+                infoBuilder.setTaskWindowConfiguration(taskWindowConfiguration);
+                // Prepare a leash to animate the current top window
+                // TODO(b/220934562): Use surface animator to better manage animation conflicts.
+                SurfaceControl animLeash = removedWindowContainer.makeAnimationLeash()
+                        .setName("BackPreview Leash for " + removedWindowContainer)
+                        .setHidden(false)
+                        .setBLASTLayer()
+                        .build();
+                removedWindowContainer.reparentSurfaceControl(tx, animLeash);
+                animationLeashParent = removedWindowContainer.getAnimationLeashParent();
+                topAppTarget = createRemoteAnimationTargetLocked(removedWindowContainer,
+                        activityRecord,
+                        task, animLeash);
+                infoBuilder.setDepartingAnimationTarget(topAppTarget);
+            }
 
             //TODO(207481538) Remove once the infrastructure to support per-activity screenshot is
             // implemented. For now we simply have the mBackScreenshots hash map that dumbly
@@ -209,101 +278,85 @@
                 screenshotBuffer = getActivitySnapshot(task, prev.mActivityComponent);
             }
 
-            // Only create a new leash if no leash has been created.
-            // Otherwise return null for animation target to avoid conflict.
-            if (removedWindowContainer.hasCommittedReparentToAnimationLeash()) {
+            if (backType == BackNavigationInfo.TYPE_RETURN_TO_HOME && isAnimationEnabled()) {
+                task.mBackGestureStarted = true;
+                // Make launcher show from behind by marking its top activity as visible and
+                // launch-behind to bump its visibility for the duration of the back gesture.
+                prevTaskTopActivity = prevTask.getTopNonFinishingActivity();
+                if (prevTaskTopActivity != null) {
+                    if (!prevTaskTopActivity.mVisibleRequested) {
+                        prevTaskTopActivity.setVisibility(true);
+                    }
+                    prevTaskTopActivity.mLaunchTaskBehind = true;
+                    ProtoLog.d(WM_DEBUG_BACK_PREVIEW,
+                            "Setting Activity.mLauncherTaskBehind to true. Activity=%s",
+                            prevTaskTopActivity);
+                    prevTaskTopActivity.mRootWindowContainer.ensureActivitiesVisible(
+                            null /* starting */, 0 /* configChanges */,
+                            false /* preserveWindows */);
+                }
+            }
+        } // Release wm Lock
+
+        // Find a screenshot of the previous activity if we actually have an animation
+        if (topAppTarget != null && needsScreenshot(backType) && prevTask != null
+                && screenshotBuffer == null) {
+            SurfaceControl.Builder builder = new SurfaceControl.Builder()
+                    .setName("BackPreview Screenshot for " + prev)
+                    .setParent(animationLeashParent)
+                    .setHidden(false)
+                    .setBLASTLayer();
+            infoBuilder.setScreenshotSurface(builder.build());
+            screenshotBuffer = getTaskSnapshot(prevTaskId, prevUserId);
+            infoBuilder.setScreenshotBuffer(screenshotBuffer);
+
+
+            // The Animation leash needs to be above the screenshot surface, but the animation leash
+            // needs to be added before to be in the synchronized block.
+            tx.setLayer(topAppTarget.leash, 1);
+            tx.apply();
+
+
+            WindowContainer<?> finalRemovedWindowContainer = removedWindowContainer;
+            try {
+                activityRecord.token.linkToDeath(
+                        () -> resetSurfaces(finalRemovedWindowContainer), 0);
+            } catch (RemoteException e) {
+                Slog.e(TAG, "Failed to link to death", e);
+                resetSurfaces(removedWindowContainer);
                 return null;
             }
-            // Prepare a leash to animate the current top window
-            // TODO(b/220934562): Use surface animator to better manage animation conflicts.
-            animLeash = removedWindowContainer.makeAnimationLeash()
-                    .setName("BackPreview Leash for " + removedWindowContainer)
-                    .setHidden(false)
-                    .setBLASTLayer()
-                    .build();
-            removedWindowContainer.reparentSurfaceControl(tx, animLeash);
-            animationLeashParent = removedWindowContainer.getAnimationLeashParent();
-            topAppTarget = new RemoteAnimationTarget(
-                    task.mTaskId,
-                    RemoteAnimationTarget.MODE_CLOSING,
-                    animLeash,
-                    false /* isTransluscent */,
-                    new Rect() /* clipRect */,
-                    new Rect() /* contentInsets */,
-                    activityRecord.getPrefixOrderIndex(),
-                    new Point(0, 0) /* position */,
-                    new Rect() /* localBounds */,
-                    new Rect() /* screenSpaceBounds */,
-                    removedWindowContainer.getWindowConfiguration(),
-                    true /* isNotInRecent */,
-                    null,
-                    null,
-                    task.getTaskInfo(),
-                    false,
-                    activityRecord.windowType);
+
+            RemoteCallback onBackNavigationDone = new RemoteCallback(
+                    result -> resetSurfaces(finalRemovedWindowContainer
+                    ));
+            infoBuilder.setOnBackNavigationDone(onBackNavigationDone);
         }
+        return infoBuilder.build();
+    }
 
-        screenshotSurface = new SurfaceControl.Builder()
-                .setName("BackPreview Screenshot for " + prev)
-                .setParent(animationLeashParent)
-                .setHidden(false)
-                .setBLASTLayer()
-                .build();
-        if (backType == BackNavigationInfo.TYPE_RETURN_TO_HOME && isAnimationEnabled()) {
-            task.mBackGestureStarted = true;
-            // Make launcher show from behind by marking its top activity as visible and
-            // launch-behind to bump its visibility for the duration of the back gesture.
-            prevTaskTopActivity = prevTask.getTopNonFinishingActivity();
-            if (prevTaskTopActivity != null) {
-                if (!prevTaskTopActivity.mVisibleRequested) {
-                    prevTaskTopActivity.setVisibility(true);
-                }
-                prevTaskTopActivity.mLaunchTaskBehind = true;
-                ProtoLog.d(WM_DEBUG_BACK_PREVIEW,
-                        "Setting Activity.mLauncherTaskBehind to true. Activity=%s",
-                        prevTaskTopActivity);
-                prevTaskTopActivity.mRootWindowContainer.ensureActivitiesVisible(
-                        null /* starting */, 0 /* configChanges */,
-                        false /* preserveWindows */);
-            }
-        }
-
-        // Find a screenshot of the previous activity
-
-        if (needsScreenshot(backType) && prevTask != null) {
-            if (screenshotBuffer == null) {
-                screenshotBuffer = getTaskSnapshot(prevTaskId, prevUserId);
-            }
-        }
-
-        // The Animation leash needs to be above the screenshot surface, but the animation leash
-        // needs to be added before to be in the synchronized block.
-        tx.setLayer(topAppTarget.leash, 1);
-        tx.apply();
-
-        WindowContainer<?> finalRemovedWindowContainer = removedWindowContainer;
-        try {
-            activityRecord.token.linkToDeath(() -> resetSurfaces(finalRemovedWindowContainer), 0);
-        } catch (RemoteException e) {
-            Slog.e(TAG, "Failed to link to death", e);
-            resetSurfaces(removedWindowContainer);
-            return null;
-        }
-
-        int finalBackType = backType;
-        final IOnBackInvokedCallback callback =
-                applicationCallback != null ? applicationCallback : systemCallback;
-        ActivityRecord finalPrevTaskTopActivity = prevTaskTopActivity;
-        RemoteCallback onBackNavigationDone = new RemoteCallback(result -> onBackNavigationDone(
-                result, finalRemovedWindowContainer, finalBackType, task,
-                finalPrevTaskTopActivity));
-        return new BackNavigationInfo(backType,
-                topAppTarget,
-                screenshotSurface,
-                screenshotBuffer,
-                taskWindowConfiguration,
-                onBackNavigationDone,
-                callback);
+    @NonNull
+    private static RemoteAnimationTarget createRemoteAnimationTargetLocked(
+            WindowContainer<?> removedWindowContainer,
+            ActivityRecord activityRecord, Task task, SurfaceControl animLeash) {
+        return new RemoteAnimationTarget(
+                task.mTaskId,
+                RemoteAnimationTarget.MODE_CLOSING,
+                animLeash,
+                false /* isTransluscent */,
+                new Rect() /* clipRect */,
+                new Rect() /* contentInsets */,
+                activityRecord.getPrefixOrderIndex(),
+                new Point(0, 0) /* position */,
+                new Rect() /* localBounds */,
+                new Rect() /* screenSpaceBounds */,
+                removedWindowContainer.getWindowConfiguration(),
+                true /* isNotInRecent */,
+                null,
+                null,
+                task.getTaskInfo(),
+                false,
+                activityRecord.windowType);
     }
 
     private void onBackNavigationDone(
@@ -360,6 +413,9 @@
     }
 
     private boolean needsScreenshot(int backType) {
+        if (!isScreenshotEnabled()) {
+            return false;
+        }
         switch (backType) {
             case BackNavigationInfo.TYPE_RETURN_TO_HOME:
             case BackNavigationInfo.TYPE_DIALOG_CLOSE:
diff --git a/services/core/java/com/android/server/wm/Session.java b/services/core/java/com/android/server/wm/Session.java
index 9ad25ac8..c4b7c56 100644
--- a/services/core/java/com/android/server/wm/Session.java
+++ b/services/core/java/com/android/server/wm/Session.java
@@ -917,10 +917,11 @@
             IOnBackInvokedCallback onBackInvokedCallback,
             @OnBackInvokedDispatcher.Priority int priority) throws RemoteException {
         synchronized (mService.mGlobalLock) {
-            WindowState windowState = mService.windowForClientLocked(this, window, false);
+            WindowState windowState = mService.windowForClientLocked(this, window, true);
             if (windowState == null) {
                 Slog.e(TAG_WM,
-                        "setOnBackInvokedCallback(): Can't find window state for window:" + window);
+                        "setOnBackInvokedCallback(): Can't find window state for package:"
+                                + mPackageName);
             } else {
                 windowState.setOnBackInvokedCallback(onBackInvokedCallback, priority);
             }
diff --git a/services/core/java/com/android/server/wm/WindowState.java b/services/core/java/com/android/server/wm/WindowState.java
index 0ca1058..532abce 100644
--- a/services/core/java/com/android/server/wm/WindowState.java
+++ b/services/core/java/com/android/server/wm/WindowState.java
@@ -1125,8 +1125,8 @@
      */
     void setOnBackInvokedCallback(
             @Nullable IOnBackInvokedCallback onBackInvokedCallback, int priority) {
-        ProtoLog.d(WM_DEBUG_BACK_PREVIEW, "%s: Setting back callback %s",
-                this, onBackInvokedCallback);
+        ProtoLog.d(WM_DEBUG_BACK_PREVIEW, "%s: Setting back callback %s. Client IWindow %s",
+                this, onBackInvokedCallback, mClient);
         if (priority >= 0) {
             mApplicationOnBackInvokedCallback = onBackInvokedCallback;
             mSystemOnBackInvokedCallback = null;
diff --git a/services/tests/wmtests/src/com/android/server/wm/BackNavigationControllerTests.java b/services/tests/wmtests/src/com/android/server/wm/BackNavigationControllerTests.java
index 92550a3..f44de1e 100644
--- a/services/tests/wmtests/src/com/android/server/wm/BackNavigationControllerTests.java
+++ b/services/tests/wmtests/src/com/android/server/wm/BackNavigationControllerTests.java
@@ -21,21 +21,27 @@
 import static android.window.BackNavigationInfo.typeToString;
 
 import static com.google.common.truth.Truth.assertThat;
+import static com.google.common.truth.Truth.assertWithMessage;
 
 import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
 import android.annotation.NonNull;
+import android.annotation.Nullable;
 import android.hardware.HardwareBuffer;
 import android.platform.test.annotations.Presubmit;
+import android.view.WindowManager;
 import android.window.BackEvent;
 import android.window.BackNavigationInfo;
 import android.window.IOnBackInvokedCallback;
 import android.window.OnBackInvokedDispatcher;
 import android.window.TaskSnapshot;
 
+import com.android.server.LocalServices;
+
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -45,22 +51,27 @@
 public class BackNavigationControllerTests extends WindowTestsBase {
 
     private BackNavigationController mBackNavigationController;
-    private IOnBackInvokedCallback mOnBackInvokedCallback;
+    private WindowManagerInternal mWindowManagerInternal;
 
     @Before
     public void setUp() throws Exception {
         mBackNavigationController = new BackNavigationController();
-        mOnBackInvokedCallback = createBackCallback();
+        LocalServices.removeServiceForTest(WindowManagerInternal.class);
+        mWindowManagerInternal = mock(WindowManagerInternal.class);
+        LocalServices.addService(WindowManagerInternal.class, mWindowManagerInternal);
+        TaskSnapshotController taskSnapshotController = createMockTaskSnapshotController();
+        mBackNavigationController.setTaskSnapshotController(taskSnapshotController);
     }
 
     @Test
-    public void backTypeHomeWhenBackToLauncher() {
-        Task task = createTopTaskWithActivity();
-        registerSystemOnBackInvokedCallback();
+    public void backNavInfo_HomeWhenBackToLauncher() {
+        IOnBackInvokedCallback callback = withSystemCallback(createTopTaskWithActivity());
 
-        BackNavigationInfo backNavigationInfo =
-                mBackNavigationController.startBackNavigation(task, new StubTransaction());
-        assertThat(backNavigationInfo).isNotNull();
+        BackNavigationInfo backNavigationInfo = startBackNavigation();
+        assertWithMessage("BackNavigationInfo").that(backNavigationInfo).isNotNull();
+        assertThat(backNavigationInfo.getDepartingAnimationTarget()).isNotNull();
+        assertThat(backNavigationInfo.getTaskWindowConfiguration()).isNotNull();
+        assertThat(backNavigationInfo.getOnBackInvokedCallback()).isEqualTo(callback);
         assertThat(typeToString(backNavigationInfo.getType()))
                 .isEqualTo(typeToString(BackNavigationInfo.TYPE_RETURN_TO_HOME));
     }
@@ -69,12 +80,9 @@
     public void backTypeCrossTaskWhenBackToPreviousTask() {
         Task taskA = createTask(mDefaultDisplay);
         createActivityRecord(taskA);
-        Task task = createTopTaskWithActivity();
-        registerSystemOnBackInvokedCallback();
-
-        BackNavigationInfo backNavigationInfo =
-                mBackNavigationController.startBackNavigation(task, new StubTransaction());
-        assertThat(backNavigationInfo).isNotNull();
+        withSystemCallback(createTopTaskWithActivity());
+        BackNavigationInfo backNavigationInfo = startBackNavigation();
+        assertWithMessage("BackNavigationInfo").that(backNavigationInfo).isNotNull();
         assertThat(typeToString(backNavigationInfo.getType()))
                 .isEqualTo(typeToString(BackNavigationInfo.TYPE_CROSS_TASK));
     }
@@ -82,48 +90,49 @@
     @Test
     public void backTypeCrossActivityWhenBackToPreviousActivity() {
         Task task = createTopTaskWithActivity();
-        mAtm.setFocusedTask(task.mTaskId,
-                createAppWindow(task, FIRST_APPLICATION_WINDOW, "window").mActivityRecord);
-        registerSystemOnBackInvokedCallback();
-
-        BackNavigationInfo backNavigationInfo =
-                mBackNavigationController.startBackNavigation(task, new StubTransaction());
-        assertThat(backNavigationInfo).isNotNull();
+        WindowState window = createAppWindow(task, FIRST_APPLICATION_WINDOW, "window");
+        addToWindowMap(window, true);
+        IOnBackInvokedCallback callback = createOnBackInvokedCallback();
+        window.setOnBackInvokedCallback(callback, OnBackInvokedDispatcher.PRIORITY_SYSTEM);
+        BackNavigationInfo backNavigationInfo = startBackNavigation();
+        assertWithMessage("BackNavigationInfo").that(backNavigationInfo).isNotNull();
         assertThat(typeToString(backNavigationInfo.getType()))
                 .isEqualTo(typeToString(BackNavigationInfo.TYPE_CROSS_ACTIVITY));
+        assertWithMessage("Activity callback").that(
+                backNavigationInfo.getOnBackInvokedCallback()).isEqualTo(callback);
+
+        // Until b/207481538 is implemented, this should be null
+        assertThat(backNavigationInfo.getScreenshotSurface()).isNull();
+        assertThat(backNavigationInfo.getScreenshotHardwareBuffer()).isNull();
     }
 
-    /**
-     * Checks that we are able to fill all the field of the {@link BackNavigationInfo} object.
-     */
     @Test
-    public void backNavInfoFullyPopulated() {
-        Task task = createTopTaskWithActivity();
-        createAppWindow(task, FIRST_APPLICATION_WINDOW, "window");
-        registerSystemOnBackInvokedCallback();
+    public void backInfoWithNullWindow() {
+        BackNavigationInfo backNavigationInfo = startBackNavigation();
+        assertThat(backNavigationInfo).isNull();
+    }
 
-        // We need a mock screenshot so
-        TaskSnapshotController taskSnapshotController = createMockTaskSnapshotController();
+    @Test
+    public void backInfoWindowWithNoActivity() {
+        WindowState window = createWindow(null, WindowManager.LayoutParams.TYPE_WALLPAPER,
+                "Wallpaper");
+        addToWindowMap(window, true);
 
-        mBackNavigationController.setTaskSnapshotController(taskSnapshotController);
+        IOnBackInvokedCallback callback = createOnBackInvokedCallback();
+        window.setOnBackInvokedCallback(callback, OnBackInvokedDispatcher.PRIORITY_DEFAULT);
 
-        BackNavigationInfo backNavigationInfo =
-                mBackNavigationController.startBackNavigation(task, new StubTransaction());
-        assertThat(backNavigationInfo).isNotNull();
-        assertThat(backNavigationInfo.getDepartingAnimationTarget()).isNotNull();
-        assertThat(backNavigationInfo.getScreenshotSurface()).isNotNull();
-        assertThat(backNavigationInfo.getScreenshotHardwareBuffer()).isNotNull();
-        assertThat(backNavigationInfo.getTaskWindowConfiguration()).isNotNull();
+        BackNavigationInfo backNavigationInfo = startBackNavigation();
+        assertWithMessage("BackNavigationInfo").that(backNavigationInfo).isNotNull();
+        assertThat(backNavigationInfo.getType()).isEqualTo(BackNavigationInfo.TYPE_CALLBACK);
+        assertThat(backNavigationInfo.getOnBackInvokedCallback()).isEqualTo(callback);
     }
 
     @Test
     public void preparesForBackToHome() {
         Task task = createTopTaskWithActivity();
-        ActivityRecord activity = task.getTopActivity(false, false);
-        registerSystemOnBackInvokedCallback();
+        withSystemCallback(task);
 
-        BackNavigationInfo backNavigationInfo =
-                mBackNavigationController.startBackNavigation(task, new StubTransaction());
+        BackNavigationInfo backNavigationInfo = startBackNavigation();
         assertThat(typeToString(backNavigationInfo.getType()))
                 .isEqualTo(typeToString(BackNavigationInfo.TYPE_RETURN_TO_HOME));
     }
@@ -131,13 +140,52 @@
     @Test
     public void backTypeCallback() {
         Task task = createTopTaskWithActivity();
-        ActivityRecord activity = task.getTopActivity(false, false);
-        registerApplicationOnBackInvokedCallback();
+        IOnBackInvokedCallback appCallback = withAppCallback(task);
 
-        BackNavigationInfo backNavigationInfo =
-                mBackNavigationController.startBackNavigation(task, new StubTransaction());
+        BackNavigationInfo backNavigationInfo = startBackNavigation();
         assertThat(typeToString(backNavigationInfo.getType()))
                 .isEqualTo(typeToString(BackNavigationInfo.TYPE_CALLBACK));
+        assertThat(backNavigationInfo.getOnBackInvokedCallback()).isEqualTo(appCallback);
+    }
+
+    private IOnBackInvokedCallback withSystemCallback(Task task) {
+        IOnBackInvokedCallback callback = createOnBackInvokedCallback();
+        task.getTopMostActivity().getTopChild().setOnBackInvokedCallback(callback,
+                OnBackInvokedDispatcher.PRIORITY_SYSTEM);
+        return callback;
+    }
+
+    private IOnBackInvokedCallback withAppCallback(Task task) {
+        IOnBackInvokedCallback callback = createOnBackInvokedCallback();
+        task.getTopMostActivity().getTopChild().setOnBackInvokedCallback(callback,
+                OnBackInvokedDispatcher.PRIORITY_DEFAULT);
+        return callback;
+    }
+
+    @Nullable
+    private BackNavigationInfo startBackNavigation() {
+        return mBackNavigationController.startBackNavigation(mWm, new StubTransaction());
+    }
+
+    @NonNull
+    private IOnBackInvokedCallback createOnBackInvokedCallback() {
+        return new IOnBackInvokedCallback.Stub() {
+            @Override
+            public void onBackStarted() {
+            }
+
+            @Override
+            public void onBackProgressed(BackEvent backEvent) {
+            }
+
+            @Override
+            public void onBackCancelled() {
+            }
+
+            @Override
+            public void onBackInvoked() {
+            }
+        };
     }
 
     @NonNull
@@ -157,35 +205,18 @@
         // enable OnBackInvokedCallbacks
         record.info.applicationInfo.privateFlagsExt |=
                 PRIVATE_FLAG_EXT_ENABLE_ON_BACK_INVOKED_CALLBACK;
-        createWindow(null, FIRST_APPLICATION_WINDOW, record, "window");
+        WindowState window = createWindow(null, FIRST_APPLICATION_WINDOW, record, "window");
         when(record.mSurfaceControl.isValid()).thenReturn(true);
         mAtm.setFocusedTask(task.mTaskId, record);
+        addToWindowMap(window, true);
         return task;
     }
 
-    private void registerSystemOnBackInvokedCallback() {
-        mWm.getFocusedWindowLocked().setOnBackInvokedCallback(
-                mOnBackInvokedCallback, OnBackInvokedDispatcher.PRIORITY_SYSTEM);
-    }
-
-    private void registerApplicationOnBackInvokedCallback() {
-        mWm.getFocusedWindowLocked().setOnBackInvokedCallback(
-                mOnBackInvokedCallback, OnBackInvokedDispatcher.PRIORITY_DEFAULT);
-    }
-
-    private IOnBackInvokedCallback createBackCallback() {
-        return new IOnBackInvokedCallback.Stub() {
-            @Override
-            public void onBackStarted() { }
-
-            @Override
-            public void onBackProgressed(BackEvent backEvent) { }
-
-            @Override
-            public void onBackCancelled() { }
-
-            @Override
-            public void onBackInvoked() { }
-        };
+    private void addToWindowMap(WindowState window, boolean focus) {
+        mWm.mWindowMap.put(window.mClient.asBinder(), window);
+        if (focus) {
+            doReturn(window.getWindowInfo().token)
+                    .when(mWindowManagerInternal).getFocusedWindowToken();
+        }
     }
 }