Merge "[KA02.5] Use binder thread and executor to invoke callback" am: f9ae70a41c am: 1ed9d716fc
am: d3b8a5c332

Change-Id: I881c1ab09187ab23facc03bb71cc38a7978e442c
diff --git a/Android.bp b/Android.bp
index 5f81191..f4e8b63 100644
--- a/Android.bp
+++ b/Android.bp
@@ -225,6 +225,7 @@
         "core/java/android/net/INetworkScoreService.aidl",
         "core/java/android/net/INetworkStatsService.aidl",
         "core/java/android/net/INetworkStatsSession.aidl",
+        "core/java/android/net/ISocketKeepaliveCallback.aidl",
         "core/java/android/net/ITestNetworkManager.aidl",
         "core/java/android/net/ITetheringEventCallback.aidl",
         "core/java/android/net/ITetheringStatsProvider.aidl",
diff --git a/core/java/android/net/ConnectivityManager.java b/core/java/android/net/ConnectivityManager.java
index 2a357ff..d08379f 100644
--- a/core/java/android/net/ConnectivityManager.java
+++ b/core/java/android/net/ConnectivityManager.java
@@ -38,7 +38,6 @@
 import android.os.Build.VERSION_CODES;
 import android.os.Bundle;
 import android.os.Handler;
-import android.os.HandlerThread;
 import android.os.IBinder;
 import android.os.INetworkActivityListener;
 import android.os.INetworkManagementService;
@@ -75,6 +74,9 @@
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.Executor;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.RejectedExecutionException;
 
 /**
  * Class that answers queries about the state of network connectivity. It also
@@ -1813,23 +1815,26 @@
         public static final int MIN_INTERVAL = 10;
 
         private final Network mNetwork;
-        private final PacketKeepaliveCallback mCallback;
-        private final Looper mLooper;
-        private final Messenger mMessenger;
+        private final ISocketKeepaliveCallback mCallback;
+        private final ExecutorService mExecutor;
 
         private volatile Integer mSlot;
 
-        void stopLooper() {
-            mLooper.quit();
-        }
-
         @UnsupportedAppUsage
         public void stop() {
             try {
-                mService.stopKeepalive(mNetwork, mSlot);
-            } catch (RemoteException e) {
-                Log.e(TAG, "Error stopping packet keepalive: ", e);
-                stopLooper();
+                mExecutor.execute(() -> {
+                    try {
+                        if (mSlot != null) {
+                            mService.stopKeepalive(mNetwork, mSlot);
+                        }
+                    } catch (RemoteException e) {
+                        Log.e(TAG, "Error stopping packet keepalive: ", e);
+                        throw e.rethrowFromSystemServer();
+                    }
+                });
+            } catch (RejectedExecutionException e) {
+                // The internal executor has already stopped due to previous event.
             }
         }
 
@@ -1837,40 +1842,43 @@
             Preconditions.checkNotNull(network, "network cannot be null");
             Preconditions.checkNotNull(callback, "callback cannot be null");
             mNetwork = network;
-            mCallback = callback;
-            HandlerThread thread = new HandlerThread(TAG);
-            thread.start();
-            mLooper = thread.getLooper();
-            mMessenger = new Messenger(new Handler(mLooper) {
+            mExecutor = Executors.newSingleThreadExecutor();
+            mCallback = new ISocketKeepaliveCallback.Stub() {
                 @Override
-                public void handleMessage(Message message) {
-                    switch (message.what) {
-                        case NetworkAgent.EVENT_SOCKET_KEEPALIVE:
-                            int error = message.arg2;
-                            try {
-                                if (error == SUCCESS) {
-                                    if (mSlot == null) {
-                                        mSlot = message.arg1;
-                                        mCallback.onStarted();
-                                    } else {
-                                        mSlot = null;
-                                        stopLooper();
-                                        mCallback.onStopped();
-                                    }
-                                } else {
-                                    stopLooper();
-                                    mCallback.onError(error);
-                                }
-                            } catch (Exception e) {
-                                Log.e(TAG, "Exception in keepalive callback(" + error + ")", e);
-                            }
-                            break;
-                        default:
-                            Log.e(TAG, "Unhandled message " + Integer.toHexString(message.what));
-                            break;
-                    }
+                public void onStarted(int slot) {
+                    Binder.withCleanCallingIdentity(() ->
+                            mExecutor.execute(() -> {
+                                mSlot = slot;
+                                callback.onStarted();
+                            }));
                 }
-            });
+
+                @Override
+                public void onStopped() {
+                    Binder.withCleanCallingIdentity(() ->
+                            mExecutor.execute(() -> {
+                                mSlot = null;
+                                callback.onStopped();
+                            }));
+                    mExecutor.shutdown();
+                }
+
+                @Override
+                public void onError(int error) {
+                    Binder.withCleanCallingIdentity(() ->
+                            mExecutor.execute(() -> {
+                                mSlot = null;
+                                callback.onError(error);
+                            }));
+                    mExecutor.shutdown();
+                }
+
+                @Override
+                public void onDataReceived() {
+                    // PacketKeepalive is only used for Nat-T keepalive and as such does not invoke
+                    // this callback when data is received.
+                }
+            };
         }
     }
 
@@ -1887,12 +1895,11 @@
             InetAddress srcAddr, int srcPort, InetAddress dstAddr) {
         final PacketKeepalive k = new PacketKeepalive(network, callback);
         try {
-            mService.startNattKeepalive(network, intervalSeconds, k.mMessenger, new Binder(),
+            mService.startNattKeepalive(network, intervalSeconds, k.mCallback,
                     srcAddr.getHostAddress(), srcPort, dstAddr.getHostAddress());
         } catch (RemoteException e) {
             Log.e(TAG, "Error starting packet keepalive: ", e);
-            k.stopLooper();
-            return null;
+            throw e.rethrowFromSystemServer();
         }
         return k;
     }
diff --git a/core/java/android/net/IConnectivityManager.aidl b/core/java/android/net/IConnectivityManager.aidl
index f1e4f64..24e6a85 100644
--- a/core/java/android/net/IConnectivityManager.aidl
+++ b/core/java/android/net/IConnectivityManager.aidl
@@ -27,6 +27,7 @@
 import android.net.NetworkQuotaInfo;
 import android.net.NetworkRequest;
 import android.net.NetworkState;
+import android.net.ISocketKeepaliveCallback;
 import android.net.ProxyInfo;
 import android.os.Bundle;
 import android.os.IBinder;
@@ -194,15 +195,15 @@
 
     void factoryReset();
 
-    void startNattKeepalive(in Network network, int intervalSeconds, in Messenger messenger,
-            in IBinder binder, String srcAddr, int srcPort, String dstAddr);
+    void startNattKeepalive(in Network network, int intervalSeconds,
+            in ISocketKeepaliveCallback cb, String srcAddr, int srcPort, String dstAddr);
 
     void startNattKeepaliveWithFd(in Network network, in FileDescriptor fd, int resourceId,
-            int intervalSeconds, in Messenger messenger, in IBinder binder, String srcAddr,
+            int intervalSeconds, in ISocketKeepaliveCallback cb, String srcAddr,
             String dstAddr);
 
     void startTcpKeepalive(in Network network, in FileDescriptor fd, int intervalSeconds,
-            in Messenger messenger, in IBinder binder);
+            in ISocketKeepaliveCallback cb);
 
     void stopKeepalive(in Network network, int slot);
 
diff --git a/core/java/android/net/ISocketKeepaliveCallback.aidl b/core/java/android/net/ISocketKeepaliveCallback.aidl
new file mode 100644
index 0000000..020fbca
--- /dev/null
+++ b/core/java/android/net/ISocketKeepaliveCallback.aidl
@@ -0,0 +1,34 @@
+/**
+ * Copyright (c) 2019, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net;
+
+/**
+ * Callback to provide status changes of keepalive offload.
+ *
+ * @hide
+ */
+oneway interface ISocketKeepaliveCallback
+{
+    /** The keepalive was successfully started. */
+    void onStarted(int slot);
+    /** The keepalive was successfully stopped. */
+    void onStopped();
+    /** The keepalive was stopped because of an error. */
+    void onError(int error);
+    /** The keepalive on a TCP socket was stopped because the socket received data. */
+    void onDataReceived();
+}
diff --git a/core/java/android/net/NattSocketKeepalive.java b/core/java/android/net/NattSocketKeepalive.java
index 88631ae..84da294 100644
--- a/core/java/android/net/NattSocketKeepalive.java
+++ b/core/java/android/net/NattSocketKeepalive.java
@@ -17,7 +17,6 @@
 package android.net;
 
 import android.annotation.NonNull;
-import android.os.Binder;
 import android.os.RemoteException;
 import android.util.Log;
 
@@ -52,24 +51,30 @@
 
     @Override
     void startImpl(int intervalSec) {
-        try {
-            mService.startNattKeepaliveWithFd(mNetwork, mFd, mResourceId, intervalSec, mMessenger,
-                    new Binder(), mSource.getHostAddress(), mDestination.getHostAddress());
-        } catch (RemoteException e) {
-            Log.e(TAG, "Error starting packet keepalive: ", e);
-            stopLooper();
-        }
+        mExecutor.execute(() -> {
+            try {
+                mService.startNattKeepaliveWithFd(mNetwork, mFd, mResourceId, intervalSec,
+                        mCallback,
+                        mSource.getHostAddress(), mDestination.getHostAddress());
+            } catch (RemoteException e) {
+                Log.e(TAG, "Error starting socket keepalive: ", e);
+                throw e.rethrowFromSystemServer();
+            }
+        });
     }
 
     @Override
     void stopImpl() {
-        try {
-            if (mSlot != null) {
-                mService.stopKeepalive(mNetwork, mSlot);
+        mExecutor.execute(() -> {
+            try {
+                if (mSlot != null) {
+                    mService.stopKeepalive(mNetwork, mSlot);
+                }
+            } catch (RemoteException e) {
+                Log.e(TAG, "Error stopping socket keepalive: ", e);
+                throw e.rethrowFromSystemServer();
             }
-        } catch (RemoteException e) {
-            Log.e(TAG, "Error stopping packet keepalive: ", e);
-            stopLooper();
-        }
+        });
+
     }
 }
diff --git a/core/java/android/net/SocketKeepalive.java b/core/java/android/net/SocketKeepalive.java
index 07728be..0e768df 100644
--- a/core/java/android/net/SocketKeepalive.java
+++ b/core/java/android/net/SocketKeepalive.java
@@ -20,13 +20,8 @@
 import android.annotation.IntRange;
 import android.annotation.NonNull;
 import android.annotation.Nullable;
-import android.os.Handler;
-import android.os.HandlerThread;
-import android.os.Looper;
-import android.os.Message;
-import android.os.Messenger;
-import android.os.Process;
-import android.util.Log;
+import android.os.Binder;
+import android.os.RemoteException;
 
 import java.lang.annotation.Retention;
 import java.lang.annotation.RetentionPolicy;
@@ -152,10 +147,9 @@
 
     @NonNull final IConnectivityManager mService;
     @NonNull final Network mNetwork;
-    @NonNull private final Executor mExecutor;
-    @NonNull private final SocketKeepalive.Callback mCallback;
-    @NonNull private final Looper mLooper;
-    @NonNull final Messenger mMessenger;
+    @NonNull final Executor mExecutor;
+    @NonNull final ISocketKeepaliveCallback mCallback;
+    // TODO: remove slot since mCallback could be used to identify which keepalive to stop.
     @Nullable Integer mSlot;
 
     SocketKeepalive(@NonNull IConnectivityManager service, @NonNull Network network,
@@ -163,53 +157,53 @@
         mService = service;
         mNetwork = network;
         mExecutor = executor;
-        mCallback = callback;
-        // TODO: 1. Use other thread modeling instead of create one thread for every instance to
-        //          reduce the memory cost.
-        //       2. support restart.
-        //       3. Fix race condition which caused by rapidly start and stop.
-        HandlerThread thread = new HandlerThread(TAG, Process.THREAD_PRIORITY_BACKGROUND
-                + Process.THREAD_PRIORITY_LESS_FAVORABLE);
-        thread.start();
-        mLooper = thread.getLooper();
-        mMessenger = new Messenger(new Handler(mLooper) {
+        mCallback = new ISocketKeepaliveCallback.Stub() {
             @Override
-            public void handleMessage(Message message) {
-                switch (message.what) {
-                    case NetworkAgent.EVENT_SOCKET_KEEPALIVE:
-                        final int status = message.arg2;
-                        try {
-                            if (status == SUCCESS) {
-                                if (mSlot == null) {
-                                    mSlot = message.arg1;
-                                    mExecutor.execute(() -> mCallback.onStarted());
-                                } else {
-                                    mSlot = null;
-                                    stopLooper();
-                                    mExecutor.execute(() -> mCallback.onStopped());
-                                }
-                            } else if (status == DATA_RECEIVED) {
-                                stopLooper();
-                                mExecutor.execute(() -> mCallback.onDataReceived());
-                            } else {
-                                stopLooper();
-                                mExecutor.execute(() -> mCallback.onError(status));
-                            }
-                        } catch (Exception e) {
-                            Log.e(TAG, "Exception in keepalive callback(" + status + ")", e);
-                        }
-                        break;
-                    default:
-                        Log.e(TAG, "Unhandled message " + Integer.toHexString(message.what));
-                        break;
-                }
+            public void onStarted(int slot) {
+                Binder.withCleanCallingIdentity(() ->
+                        mExecutor.execute(() -> {
+                            mSlot = slot;
+                            callback.onStarted();
+                        }));
             }
-        });
+
+            @Override
+            public void onStopped() {
+                Binder.withCleanCallingIdentity(() ->
+                        executor.execute(() -> {
+                            mSlot = null;
+                            callback.onStopped();
+                        }));
+            }
+
+            @Override
+            public void onError(int error) {
+                Binder.withCleanCallingIdentity(() ->
+                        executor.execute(() -> {
+                            mSlot = null;
+                            callback.onError(error);
+                        }));
+            }
+
+            @Override
+            public void onDataReceived() {
+                Binder.withCleanCallingIdentity(() ->
+                        executor.execute(() -> {
+                            mSlot = null;
+                            callback.onDataReceived();
+                        }));
+            }
+        };
     }
 
     /**
      * Request that keepalive be started with the given {@code intervalSec}. See
-     * {@link SocketKeepalive}.
+     * {@link SocketKeepalive}. If the remote binder dies, or the binder call throws an exception
+     * when invoking start or stop of the {@link SocketKeepalive}, a {@link RemoteException} will be
+     * thrown into the {@code executor}. This is typically not important to catch because the remote
+     * party is the system, so if it is not in shape to communicate through binder the system is
+     * probably going down anyway. If the caller cares regardless, it can use a custom
+     * {@link Executor} to catch the {@link RemoteException}.
      *
      * @param intervalSec The target interval in seconds between keepalive packet transmissions.
      *                    The interval should be between 10 seconds and 3600 seconds, otherwise
@@ -222,12 +216,6 @@
 
     abstract void startImpl(int intervalSec);
 
-    /** @hide */
-    protected void stopLooper() {
-        // TODO: remove this after changing thread modeling.
-        mLooper.quit();
-    }
-
     /**
      * Requests that keepalive be stopped. The application must wait for {@link Callback#onStopped}
      * before using the object. See {@link SocketKeepalive}.
@@ -245,7 +233,6 @@
     @Override
     public final void close() {
         stop();
-        stopLooper();
     }
 
     /**
@@ -259,7 +246,8 @@
         public void onStopped() {}
         /** An error occurred. */
         public void onError(@ErrorCode int error) {}
-        /** The keepalive on a TCP socket was stopped because the socket received data. */
+        /** The keepalive on a TCP socket was stopped because the socket received data. This is
+         * never called for UDP sockets. */
         public void onDataReceived() {}
     }
 }
diff --git a/core/java/android/net/TcpSocketKeepalive.java b/core/java/android/net/TcpSocketKeepalive.java
index f691a0d..26cc8ff 100644
--- a/core/java/android/net/TcpSocketKeepalive.java
+++ b/core/java/android/net/TcpSocketKeepalive.java
@@ -17,7 +17,6 @@
 package android.net;
 
 import android.annotation.NonNull;
-import android.os.Binder;
 import android.os.RemoteException;
 import android.util.Log;
 
@@ -56,24 +55,28 @@
      */
     @Override
     void startImpl(int intervalSec) {
-        try {
-            final FileDescriptor fd = mSocket.getFileDescriptor$();
-            mService.startTcpKeepalive(mNetwork, fd, intervalSec, mMessenger, new Binder());
-        } catch (RemoteException e) {
-            Log.e(TAG, "Error starting packet keepalive: ", e);
-            stopLooper();
-        }
+        mExecutor.execute(() -> {
+            try {
+                final FileDescriptor fd = mSocket.getFileDescriptor$();
+                mService.startTcpKeepalive(mNetwork, fd, intervalSec, mCallback);
+            } catch (RemoteException e) {
+                Log.e(TAG, "Error starting packet keepalive: ", e);
+                throw e.rethrowFromSystemServer();
+            }
+        });
     }
 
     @Override
     void stopImpl() {
-        try {
-            if (mSlot != null) {
-                mService.stopKeepalive(mNetwork, mSlot);
+        mExecutor.execute(() -> {
+            try {
+                if (mSlot != null) {
+                    mService.stopKeepalive(mNetwork, mSlot);
+                }
+            } catch (RemoteException e) {
+                Log.e(TAG, "Error stopping packet keepalive: ", e);
+                throw e.rethrowFromSystemServer();
             }
-        } catch (RemoteException e) {
-            Log.e(TAG, "Error stopping packet keepalive: ", e);
-            stopLooper();
-        }
+        });
     }
 }
diff --git a/services/core/java/com/android/server/ConnectivityService.java b/services/core/java/com/android/server/ConnectivityService.java
index 3ed2948..343cee1 100644
--- a/services/core/java/com/android/server/ConnectivityService.java
+++ b/services/core/java/com/android/server/ConnectivityService.java
@@ -73,6 +73,7 @@
 import android.net.INetworkPolicyListener;
 import android.net.INetworkPolicyManager;
 import android.net.INetworkStatsService;
+import android.net.ISocketKeepaliveCallback;
 import android.net.ITetheringEventCallback;
 import android.net.InetAddresses;
 import android.net.IpPrefix;
@@ -6699,32 +6700,32 @@
     }
 
     @Override
-    public void startNattKeepalive(Network network, int intervalSeconds, Messenger messenger,
-            IBinder binder, String srcAddr, int srcPort, String dstAddr) {
+    public void startNattKeepalive(Network network, int intervalSeconds,
+            ISocketKeepaliveCallback cb, String srcAddr, int srcPort, String dstAddr) {
         enforceKeepalivePermission();
         mKeepaliveTracker.startNattKeepalive(
                 getNetworkAgentInfoForNetwork(network),
-                intervalSeconds, messenger, binder,
+                intervalSeconds, cb,
                 srcAddr, srcPort, dstAddr, NattSocketKeepalive.NATT_PORT);
     }
 
     @Override
     public void startNattKeepaliveWithFd(Network network, FileDescriptor fd, int resourceId,
-            int intervalSeconds, Messenger messenger, IBinder binder, String srcAddr,
+            int intervalSeconds, ISocketKeepaliveCallback cb, String srcAddr,
             String dstAddr) {
         enforceKeepalivePermission();
         mKeepaliveTracker.startNattKeepalive(
                 getNetworkAgentInfoForNetwork(network), fd, resourceId,
-                intervalSeconds, messenger, binder,
+                intervalSeconds, cb,
                 srcAddr, dstAddr, NattSocketKeepalive.NATT_PORT);
     }
 
     @Override
     public void startTcpKeepalive(Network network, FileDescriptor fd, int intervalSeconds,
-            Messenger messenger, IBinder binder) {
+            ISocketKeepaliveCallback cb) {
         enforceKeepalivePermission();
         mKeepaliveTracker.startTcpKeepalive(
-                getNetworkAgentInfoForNetwork(network), fd, intervalSeconds, messenger, binder);
+                getNetworkAgentInfoForNetwork(network), fd, intervalSeconds, cb);
     }
 
     @Override
diff --git a/services/core/java/com/android/server/connectivity/KeepaliveTracker.java b/services/core/java/com/android/server/connectivity/KeepaliveTracker.java
index cc4c173..35d6860 100644
--- a/services/core/java/com/android/server/connectivity/KeepaliveTracker.java
+++ b/services/core/java/com/android/server/connectivity/KeepaliveTracker.java
@@ -21,8 +21,8 @@
 import static android.net.NetworkAgent.CMD_REMOVE_KEEPALIVE_PACKET_FILTER;
 import static android.net.NetworkAgent.CMD_START_SOCKET_KEEPALIVE;
 import static android.net.NetworkAgent.CMD_STOP_SOCKET_KEEPALIVE;
-import static android.net.NetworkAgent.EVENT_SOCKET_KEEPALIVE;
 import static android.net.SocketKeepalive.BINDER_DIED;
+import static android.net.SocketKeepalive.DATA_RECEIVED;
 import static android.net.SocketKeepalive.ERROR_INVALID_INTERVAL;
 import static android.net.SocketKeepalive.ERROR_INVALID_IP_ADDRESS;
 import static android.net.SocketKeepalive.ERROR_INVALID_NETWORK;
@@ -34,6 +34,7 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.net.ISocketKeepaliveCallback;
 import android.net.KeepalivePacketData;
 import android.net.NattKeepalivePacketData;
 import android.net.NetworkAgent;
@@ -47,7 +48,6 @@
 import android.os.Handler;
 import android.os.IBinder;
 import android.os.Message;
-import android.os.Messenger;
 import android.os.Process;
 import android.os.RemoteException;
 import android.system.ErrnoException;
@@ -99,8 +99,7 @@
      */
     class KeepaliveInfo implements IBinder.DeathRecipient {
         // Bookkeeping data.
-        private final Messenger mMessenger;
-        private final IBinder mBinder;
+        private final ISocketKeepaliveCallback mCallback;
         private final int mUid;
         private final int mPid;
         private final NetworkAgentInfo mNai;
@@ -124,15 +123,13 @@
         private static final int STARTED = 3;
         private int mStartedState = NOT_STARTED;
 
-        KeepaliveInfo(@NonNull Messenger messenger,
-                @NonNull IBinder binder,
+        KeepaliveInfo(@NonNull ISocketKeepaliveCallback callback,
                 @NonNull NetworkAgentInfo nai,
                 @NonNull KeepalivePacketData packet,
                 int interval,
                 int type,
                 @NonNull FileDescriptor fd) {
-            mMessenger = messenger;
-            mBinder = binder;
+            mCallback = callback;
             mPid = Binder.getCallingPid();
             mUid = Binder.getCallingUid();
 
@@ -143,7 +140,7 @@
             mFd = fd;
 
             try {
-                mBinder.linkToDeath(this, 0);
+                mCallback.asBinder().linkToDeath(this, 0);
             } catch (RemoteException e) {
                 binderDied();
             }
@@ -176,22 +173,14 @@
                     + " ]";
         }
 
-        /** Sends a message back to the application via its SocketKeepalive.Callback. */
-        void notifyMessenger(int slot, int err) {
-            if (DBG) {
-                Log.d(TAG, "notify keepalive " + mSlot + " on " + mNai.network + " for " + err);
-            }
-            KeepaliveTracker.this.notifyMessenger(mMessenger, slot, err);
-        }
-
         /** Called when the application process is killed. */
         public void binderDied() {
             stop(BINDER_DIED);
         }
 
         void unlinkDeathRecipient() {
-            if (mBinder != null) {
-                mBinder.unlinkToDeath(this, 0);
+            if (mCallback != null) {
+                mCallback.asBinder().unlinkToDeath(this, 0);
             }
         }
 
@@ -283,9 +272,23 @@
                     Log.wtf(TAG, "Stopping keepalive with unknown type: " + mType);
                 }
             }
-            // TODO: at the moment we unconditionally return failure here. In cases where the
-            // NetworkAgent is alive, should we ask it to reply, so it can return failure?
-            notifyMessenger(mSlot, reason);
+
+            if (reason == SUCCESS) {
+                try {
+                    mCallback.onStopped();
+                } catch (RemoteException e) {
+                    Log.w(TAG, "Discarded onStop callback: " + reason);
+                }
+            } else if (reason == DATA_RECEIVED) {
+                try {
+                    mCallback.onDataReceived();
+                } catch (RemoteException e) {
+                    Log.w(TAG, "Discarded onDataReceived callback: " + reason);
+                }
+            } else {
+                notifyErrorCallback(mCallback, reason);
+            }
+
             unlinkDeathRecipient();
         }
 
@@ -294,16 +297,12 @@
         }
     }
 
-    void notifyMessenger(Messenger messenger, int slot, int err) {
-        Message message = Message.obtain();
-        message.what = EVENT_SOCKET_KEEPALIVE;
-        message.arg1 = slot;
-        message.arg2 = err;
-        message.obj = null;
+    void notifyErrorCallback(ISocketKeepaliveCallback cb, int error) {
+        if (DBG) Log.w(TAG, "Sending onError(" + error + ") callback");
         try {
-            messenger.send(message);
+            cb.onError(error);
         } catch (RemoteException e) {
-            // Process died?
+            Log.w(TAG, "Discarded onError(" + error + ") callback");
         }
     }
 
@@ -414,7 +413,11 @@
             // Keepalive successfully started.
             if (DBG) Log.d(TAG, "Started keepalive " + slot + " on " + nai.name());
             ki.mStartedState = KeepaliveInfo.STARTED;
-            ki.notifyMessenger(slot, reason);
+            try {
+                ki.mCallback.onStarted(slot);
+            } catch (RemoteException e) {
+                Log.w(TAG, "Discarded onStarted(" + slot + ") callback");
+            }
         } else {
             // Keepalive successfully stopped, or error.
             ki.mStartedState = KeepaliveInfo.NOT_STARTED;
@@ -436,14 +439,13 @@
      **/
     public void startNattKeepalive(@Nullable NetworkAgentInfo nai,
             int intervalSeconds,
-            @NonNull Messenger messenger,
-            @NonNull IBinder binder,
+            @NonNull ISocketKeepaliveCallback cb,
             @NonNull String srcAddrString,
             int srcPort,
             @NonNull String dstAddrString,
             int dstPort) {
         if (nai == null) {
-            notifyMessenger(messenger, NO_KEEPALIVE, ERROR_INVALID_NETWORK);
+            notifyErrorCallback(cb, ERROR_INVALID_NETWORK);
             return;
         }
 
@@ -452,7 +454,7 @@
             srcAddress = NetworkUtils.numericToInetAddress(srcAddrString);
             dstAddress = NetworkUtils.numericToInetAddress(dstAddrString);
         } catch (IllegalArgumentException e) {
-            notifyMessenger(messenger, NO_KEEPALIVE, ERROR_INVALID_IP_ADDRESS);
+            notifyErrorCallback(cb, ERROR_INVALID_IP_ADDRESS);
             return;
         }
 
@@ -461,11 +463,12 @@
             packet = NattKeepalivePacketData.nattKeepalivePacket(
                     srcAddress, srcPort, dstAddress, NATT_PORT);
         } catch (InvalidPacketException e) {
-            notifyMessenger(messenger, NO_KEEPALIVE, e.error);
+            notifyErrorCallback(cb, e.error);
             return;
         }
-        KeepaliveInfo ki = new KeepaliveInfo(messenger, binder, nai, packet, intervalSeconds,
+        KeepaliveInfo ki = new KeepaliveInfo(cb, nai, packet, intervalSeconds,
                 KeepaliveInfo.TYPE_NATT, null);
+        Log.d(TAG, "Created keepalive: " + ki.toString());
         mConnectivityServiceHandler.obtainMessage(
                 NetworkAgent.CMD_START_SOCKET_KEEPALIVE, ki).sendToTarget();
     }
@@ -483,10 +486,9 @@
     public void startTcpKeepalive(@Nullable NetworkAgentInfo nai,
             @NonNull FileDescriptor fd,
             int intervalSeconds,
-            @NonNull Messenger messenger,
-            @NonNull IBinder binder) {
+            @NonNull ISocketKeepaliveCallback cb) {
         if (nai == null) {
-            notifyMessenger(messenger, NO_KEEPALIVE, ERROR_INVALID_NETWORK);
+            notifyErrorCallback(cb, ERROR_INVALID_NETWORK);
             return;
         }
 
@@ -500,10 +502,10 @@
             } catch (ErrnoException e1) {
                 Log.e(TAG, "Couldn't move fd out of repair mode after failure to start keepalive");
             }
-            notifyMessenger(messenger, NO_KEEPALIVE, e.error);
+            notifyErrorCallback(cb, e.error);
             return;
         }
-        KeepaliveInfo ki = new KeepaliveInfo(messenger, binder, nai, packet, intervalSeconds,
+        KeepaliveInfo ki = new KeepaliveInfo(cb, nai, packet, intervalSeconds,
                 KeepaliveInfo.TYPE_TCP, fd);
         Log.d(TAG, "Created keepalive: " + ki.toString());
         mConnectivityServiceHandler.obtainMessage(CMD_START_SOCKET_KEEPALIVE, ki).sendToTarget();
@@ -520,14 +522,14 @@
             @Nullable FileDescriptor fd,
             int resourceId,
             int intervalSeconds,
-            @NonNull Messenger messenger,
-            @NonNull IBinder binder,
+            @NonNull ISocketKeepaliveCallback cb,
             @NonNull String srcAddrString,
             @NonNull String dstAddrString,
             int dstPort) {
         // Ensure that the socket is created by IpSecService.
         if (!isNattKeepaliveSocketValid(fd, resourceId)) {
-            notifyMessenger(messenger, NO_KEEPALIVE, ERROR_INVALID_SOCKET);
+            notifyErrorCallback(cb, ERROR_INVALID_SOCKET);
+            return;
         }
 
         // Get src port to adopt old API.
@@ -536,11 +538,12 @@
             final SocketAddress srcSockAddr = Os.getsockname(fd);
             srcPort = ((InetSocketAddress) srcSockAddr).getPort();
         } catch (ErrnoException e) {
-            notifyMessenger(messenger, NO_KEEPALIVE, ERROR_INVALID_SOCKET);
+            notifyErrorCallback(cb, ERROR_INVALID_SOCKET);
+            return;
         }
 
         // Forward request to old API.
-        startNattKeepalive(nai, intervalSeconds, messenger, binder, srcAddrString, srcPort,
+        startNattKeepalive(nai, intervalSeconds, cb, srcAddrString, srcPort,
                 dstAddrString, dstPort);
     }
 
diff --git a/tests/net/java/com/android/internal/util/TestUtils.java b/tests/net/java/com/android/internal/util/TestUtils.java
index 7e5a1d3..57cc172 100644
--- a/tests/net/java/com/android/internal/util/TestUtils.java
+++ b/tests/net/java/com/android/internal/util/TestUtils.java
@@ -19,6 +19,7 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.fail;
 
+import android.annotation.NonNull;
 import android.os.ConditionVariable;
 import android.os.Handler;
 import android.os.HandlerThread;
@@ -26,6 +27,8 @@
 import android.os.Parcel;
 import android.os.Parcelable;
 
+import java.util.concurrent.Executor;
+
 public final class TestUtils {
     private TestUtils() { }
 
@@ -54,6 +57,17 @@
         }
     }
 
+    /**
+     * Block until the given Serial Executor becomes idle, or until timeoutMs has passed.
+     */
+    public static void waitForIdleSerialExecutor(@NonNull Executor executor, long timeoutMs) {
+        final ConditionVariable cv = new ConditionVariable();
+        executor.execute(() -> cv.open());
+        if (!cv.block(timeoutMs)) {
+            fail(executor.toString() + " did not become idle after " + timeoutMs + " ms");
+        }
+    }
+
     // TODO : fetch the creator through reflection or something instead of passing it
     public static <T extends Parcelable, C extends Parcelable.Creator<T>>
             void assertParcelingIsLossless(T source, C creator) {
diff --git a/tests/net/java/com/android/server/ConnectivityServiceTest.java b/tests/net/java/com/android/server/ConnectivityServiceTest.java
index dc11bf3..0633322 100644
--- a/tests/net/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/net/java/com/android/server/ConnectivityServiceTest.java
@@ -64,6 +64,7 @@
 
 import static com.android.internal.util.TestUtils.waitForIdleHandler;
 import static com.android.internal.util.TestUtils.waitForIdleLooper;
+import static com.android.internal.util.TestUtils.waitForIdleSerialExecutor;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -88,6 +89,7 @@
 import static org.mockito.Mockito.verifyNoMoreInteractions;
 import static org.mockito.Mockito.when;
 
+import android.annotation.NonNull;
 import android.app.NotificationManager;
 import android.app.PendingIntent;
 import android.content.BroadcastReceiver;
@@ -3762,7 +3764,7 @@
             }
         }
 
-        private LinkedBlockingQueue<CallbackValue> mCallbacks = new LinkedBlockingQueue<>();
+        private final LinkedBlockingQueue<CallbackValue> mCallbacks = new LinkedBlockingQueue<>();
 
         @Override
         public void onStarted() {
@@ -3837,6 +3839,11 @@
         }
 
-        private LinkedBlockingQueue<CallbackValue> mCallbacks = new LinkedBlockingQueue<>();
+        private final LinkedBlockingQueue<CallbackValue> mCallbacks = new LinkedBlockingQueue<>();
+        private final Executor mExecutor;
+
+        TestSocketKeepaliveCallback(@NonNull Executor executor) {
+            mExecutor = executor;
+        }
 
         @Override
         public void onStarted() {
@@ -3874,6 +3881,12 @@
         public void expectError(int error) {
             expectCallback(new CallbackValue(CallbackType.ON_ERROR, error));
         }
+
+        public void assertNoCallback() {
+            waitForIdleSerialExecutor(mExecutor, TIMEOUT_MS);
+            CallbackValue cv = mCallbacks.peek();
+            assertNull("Unexpected callback: " + cv, cv);
+        }
     }
 
     private Network connectKeepaliveNetwork(LinkProperties lp) {
@@ -3980,19 +3993,6 @@
         myNet = connectKeepaliveNetwork(lp);
         mWiFiNetworkAgent.setStartKeepaliveError(PacketKeepalive.SUCCESS);
 
-        // Check things work as expected when the keepalive is stopped and the network disconnects.
-        ka = mCm.startNattKeepalive(myNet, validKaInterval, callback, myIPv4, 12345, dstIPv4);
-        callback.expectStarted();
-        ka.stop();
-        mWiFiNetworkAgent.disconnect();
-        waitFor(mWiFiNetworkAgent.getDisconnectedCV());
-        waitForIdle();
-        callback.expectStopped();
-
-        // Reconnect.
-        myNet = connectKeepaliveNetwork(lp);
-        mWiFiNetworkAgent.setStartKeepaliveError(PacketKeepalive.SUCCESS);
-
         // Check that keepalive slots start from 1 and increment. The first one gets slot 1.
         mWiFiNetworkAgent.setExpectedKeepaliveSlot(1);
         ka = mCm.startNattKeepalive(myNet, validKaInterval, callback, myIPv4, 12345, dstIPv4);
@@ -4068,7 +4068,7 @@
         Network notMyNet = new Network(61234);
         Network myNet = connectKeepaliveNetwork(lp);
 
-        TestSocketKeepaliveCallback callback = new TestSocketKeepaliveCallback();
+        TestSocketKeepaliveCallback callback = new TestSocketKeepaliveCallback(executor);
         SocketKeepalive ka;
 
         // Attempt to start keepalives with invalid parameters and check for errors.
@@ -4111,6 +4111,22 @@
         ka.stop();
         callback.expectStopped();
 
+        // Check that keepalive could be restarted.
+        ka.start(validKaInterval);
+        callback.expectStarted();
+        ka.stop();
+        callback.expectStopped();
+
+        // Check that keepalive can be restarted without waiting for callback.
+        ka.start(validKaInterval);
+        callback.expectStarted();
+        ka.stop();
+        ka.start(validKaInterval);
+        callback.expectStopped();
+        callback.expectStarted();
+        ka.stop();
+        callback.expectStopped();
+
         // Check that deleting the IP address stops the keepalive.
         LinkProperties bogusLp = new LinkProperties(lp);
         ka = mCm.createSocketKeepalive(myNet, testSocket, myIPv4, dstIPv4, executor, callback);
@@ -4135,20 +4151,7 @@
         final Network myNetAlias = myNet;
         assertNull(mCm.getNetworkCapabilities(myNetAlias));
         ka.stop();
-
-        // Reconnect.
-        myNet = connectKeepaliveNetwork(lp);
-        mWiFiNetworkAgent.setStartKeepaliveError(SocketKeepalive.SUCCESS);
-
-        // Check things work as expected when the keepalive is stopped and the network disconnects.
-        ka = mCm.createSocketKeepalive(myNet, testSocket, myIPv4, dstIPv4, executor, callback);
-        ka.start(validKaInterval);
-        callback.expectStarted();
-        ka.stop();
-        mWiFiNetworkAgent.disconnect();
-        waitFor(mWiFiNetworkAgent.getDisconnectedCV());
-        waitForIdle();
-        callback.expectStopped();
+        callback.assertNoCallback();
 
         // Reconnect.
         myNet = connectKeepaliveNetwork(lp);
@@ -4163,7 +4166,7 @@
         // The second one gets slot 2.
         mWiFiNetworkAgent.setExpectedKeepaliveSlot(2);
         final UdpEncapsulationSocket testSocket2 = mIpSec.openUdpEncapsulationSocket(6789);
-        TestSocketKeepaliveCallback callback2 = new TestSocketKeepaliveCallback();
+        TestSocketKeepaliveCallback callback2 = new TestSocketKeepaliveCallback(executor);
         SocketKeepalive ka2 =
                 mCm.createSocketKeepalive(myNet, testSocket2, myIPv4, dstIPv4, executor, callback2);
         ka2.start(validKaInterval);
@@ -4216,7 +4219,7 @@
         final Socket testSocketV4 = new Socket();
         final Socket testSocketV6 = new Socket();
 
-        TestSocketKeepaliveCallback callback = new TestSocketKeepaliveCallback();
+        TestSocketKeepaliveCallback callback = new TestSocketKeepaliveCallback(executor);
         SocketKeepalive ka;
 
         // Attempt to start Tcp keepalives with invalid parameters and check for errors.