Avoid NPE when quickly toggling USB debugging state

If you try to disable USB debugging before the socket
to listen is opened in the thread, it will end up
with an NPE.

Do some locking around socket creation and closing
to avoid this.

Bug: 18708503
Change-Id: Iac43e4806fff1e411772b1ba1a070d8a7c776fcb
diff --git a/services/usb/java/com/android/server/usb/UsbDebuggingManager.java b/services/usb/java/com/android/server/usb/UsbDebuggingManager.java
index e489279..6fcd1eb 100644
--- a/services/usb/java/com/android/server/usb/UsbDebuggingManager.java
+++ b/services/usb/java/com/android/server/usb/UsbDebuggingManager.java
@@ -35,6 +35,7 @@
 import android.os.UserHandle;
 import android.util.Slog;
 import android.util.Base64;
+
 import com.android.server.FgThread;
 
 import java.lang.Thread;
@@ -48,7 +49,7 @@
 import java.security.MessageDigest;
 import java.util.Arrays;
 
-public class UsbDebuggingManager implements Runnable {
+public class UsbDebuggingManager {
     private static final String TAG = "UsbDebuggingManager";
     private static final boolean DEBUG = false;
 
@@ -59,86 +60,135 @@
 
     private final Context mContext;
     private final Handler mHandler;
-    private Thread mThread;
+    private UsbDebuggingThread mThread;
     private boolean mAdbEnabled = false;
     private String mFingerprints;
-    private LocalSocket mSocket = null;
-    private OutputStream mOutputStream = null;
 
     public UsbDebuggingManager(Context context) {
         mHandler = new UsbDebuggingHandler(FgThread.get().getLooper());
         mContext = context;
     }
 
-    private void listenToSocket() throws IOException {
-        try {
-            byte[] buffer = new byte[BUFFER_SIZE];
-            LocalSocketAddress address = new LocalSocketAddress(ADBD_SOCKET,
-                                         LocalSocketAddress.Namespace.RESERVED);
-            InputStream inputStream = null;
+    class UsbDebuggingThread extends Thread {
+        private boolean mStopped;
+        private LocalSocket mSocket;
+        private OutputStream mOutputStream;
+        private InputStream mInputStream;
 
-            mSocket = new LocalSocket();
-            mSocket.connect(address);
+        UsbDebuggingThread() {
+            super(TAG);
+        }
 
-            mOutputStream = mSocket.getOutputStream();
-            inputStream = mSocket.getInputStream();
-
+        @Override
+        public void run() {
+            if (DEBUG) Slog.d(TAG, "Entering thread");
             while (true) {
-                int count = inputStream.read(buffer);
-                if (count < 0) {
-                    break;
+                synchronized (this) {
+                    if (mStopped) {
+                        if (DEBUG) Slog.d(TAG, "Exiting thread");
+                        return;
+                    }
+                    try {
+                        openSocketLocked();
+                    } catch (Exception e) {
+                        /* Don't loop too fast if adbd dies, before init restarts it */
+                        SystemClock.sleep(1000);
+                    }
                 }
-
-                if (buffer[0] == 'P' && buffer[1] == 'K') {
-                    String key = new String(Arrays.copyOfRange(buffer, 2, count));
-                    Slog.d(TAG, "Received public key: " + key);
-                    Message msg = mHandler.obtainMessage(UsbDebuggingHandler.MESSAGE_ADB_CONFIRM);
-                    msg.obj = key;
-                    mHandler.sendMessage(msg);
-                }
-                else {
-                    Slog.e(TAG, "Wrong message: " + (new String(Arrays.copyOfRange(buffer, 0, 2))));
-                    break;
+                try {
+                    listenToSocket();
+                } catch (IOException e) {
+                    /* Don't loop too fast if adbd dies, before init restarts it */
+                    SystemClock.sleep(1000);
                 }
             }
-        } finally {
-            closeSocket();
         }
-    }
 
-    @Override
-    public void run() {
-        while (mAdbEnabled) {
+        private void openSocketLocked() throws IOException {
             try {
-                listenToSocket();
-            } catch (Exception e) {
-                /* Don't loop too fast if adbd dies, before init restarts it */
-                SystemClock.sleep(1000);
+                LocalSocketAddress address = new LocalSocketAddress(ADBD_SOCKET,
+                        LocalSocketAddress.Namespace.RESERVED);
+                mInputStream = null;
+
+                if (DEBUG) Slog.d(TAG, "Creating socket");
+                mSocket = new LocalSocket();
+                mSocket.connect(address);
+
+                mOutputStream = mSocket.getOutputStream();
+                mInputStream = mSocket.getInputStream();
+            } catch (IOException ioe) {
+                closeSocketLocked();
+                throw ioe;
             }
         }
-    }
 
-    private void closeSocket() {
-        try {
-            mOutputStream.close();
-        } catch (IOException e) {
-            Slog.e(TAG, "Failed closing output stream: " + e);
-        }
-
-        try {
-            mSocket.close();
-        } catch (IOException ex) {
-            Slog.e(TAG, "Failed closing socket: " + ex);
-        }
-    }
-
-    private void sendResponse(String msg) {
-        if (mOutputStream != null) {
+        private void listenToSocket() throws IOException {
             try {
-                mOutputStream.write(msg.getBytes());
+                byte[] buffer = new byte[BUFFER_SIZE];
+                while (true) {
+                    int count = mInputStream.read(buffer);
+                    if (count < 0) {
+                        break;
+                    }
+
+                    if (buffer[0] == 'P' && buffer[1] == 'K') {
+                        String key = new String(Arrays.copyOfRange(buffer, 2, count));
+                        Slog.d(TAG, "Received public key: " + key);
+                        Message msg = mHandler.obtainMessage(UsbDebuggingHandler.MESSAGE_ADB_CONFIRM);
+                        msg.obj = key;
+                        mHandler.sendMessage(msg);
+                    } else {
+                        Slog.e(TAG, "Wrong message: "
+                                + (new String(Arrays.copyOfRange(buffer, 0, 2))));
+                        break;
+                    }
+                }
+            } finally {
+                synchronized (this) {
+                    closeSocketLocked();
+                }
             }
-            catch (IOException ex) {
-                Slog.e(TAG, "Failed to write response:", ex);
+        }
+
+        private void closeSocketLocked() {
+            if (DEBUG) Slog.d(TAG, "Closing socket");
+            try {
+                if (mOutputStream != null) {
+                    mOutputStream.close();
+                    mOutputStream = null;
+                }
+            } catch (IOException e) {
+                Slog.e(TAG, "Failed closing output stream: " + e);
+            }
+
+            try {
+                if (mSocket != null) {
+                    mSocket.close();
+                    mSocket = null;
+                }
+            } catch (IOException ex) {
+                Slog.e(TAG, "Failed closing socket: " + ex);
+            }
+        }
+
+        /** Call to stop listening on the socket and exit the thread. */
+        void stopListening() {
+            synchronized (this) {
+                mStopped = true;
+                closeSocketLocked();
+            }
+        }
+
+        void sendResponse(String msg) {
+            synchronized (this) {
+                if (!mStopped && mOutputStream != null) {
+                    try {
+                        mOutputStream.write(msg.getBytes());
+                    }
+                    catch (IOException ex) {
+                        Slog.e(TAG, "Failed to write response:", ex);
+                    }
+                }
             }
         }
     }
@@ -163,7 +213,7 @@
 
                     mAdbEnabled = true;
 
-                    mThread = new Thread(UsbDebuggingManager.this, TAG);
+                    mThread = new UsbDebuggingThread();
                     mThread.start();
 
                     break;
@@ -173,16 +223,12 @@
                         break;
 
                     mAdbEnabled = false;
-                    closeSocket();
 
-                    try {
-                        mThread.join();
-                    } catch (Exception ex) {
+                    if (mThread != null) {
+                        mThread.stopListening();
+                        mThread = null;
                     }
 
-                    mThread = null;
-                    mOutputStream = null;
-                    mSocket = null;
                     break;
 
                 case MESSAGE_ADB_ALLOW: {
@@ -199,25 +245,33 @@
                         writeKey(key);
                     }
 
-                    sendResponse("OK");
+                    if (mThread != null) {
+                        mThread.sendResponse("OK");
+                    }
                     break;
                 }
 
                 case MESSAGE_ADB_DENY:
-                    sendResponse("NO");
+                    if (mThread != null) {
+                        mThread.sendResponse("NO");
+                    }
                     break;
 
                 case MESSAGE_ADB_CONFIRM: {
                     if ("trigger_restart_min_framework".equals(
                             SystemProperties.get("vold.decrypt"))) {
                         Slog.d(TAG, "Deferring adb confirmation until after vold decrypt");
-                        sendResponse("NO");
+                        if (mThread != null) {
+                            mThread.sendResponse("NO");
+                        }
                         break;
                     }
                     String key = (String)msg.obj;
                     String fingerprints = getFingerprints(key);
                     if ("".equals(fingerprints)) {
-                        sendResponse("NO");
+                        if (mThread != null) {
+                            mThread.sendResponse("NO");
+                        }
                         break;
                     }
                     mFingerprints = fingerprints;
@@ -387,7 +441,7 @@
 
     public void dump(FileDescriptor fd, PrintWriter pw) {
         pw.println("  USB Debugging State:");
-        pw.println("    Connected to adbd: " + (mOutputStream != null));
+        pw.println("    Connected to adbd: " + (mThread != null));
         pw.println("    Last key received: " + mFingerprints);
         pw.println("    User keys:");
         try {