MTP: Add support for restricting PTP to only certain subdirectories of the storage

Bug: 5527220

Change-Id: If68e7481617ecb62abd24e2d89e6b7dfdf95ba2b
Signed-off-by: Mike Lockwood <lockwood@google.com>
diff --git a/media/java/android/mtp/MtpDatabase.java b/media/java/android/mtp/MtpDatabase.java
index 98617d2..19db1c0 100755
--- a/media/java/android/mtp/MtpDatabase.java
+++ b/media/java/android/mtp/MtpDatabase.java
@@ -51,7 +51,15 @@
     private final IContentProvider mMediaProvider;
     private final String mVolumeName;
     private final Uri mObjectsUri;
-    private final String mMediaStoragePath; // path to primary storage
+    // path to primary storage
+    private final String mMediaStoragePath;
+    // if not null, restrict all queries to these subdirectories
+    private final String[] mSubDirectories;
+    // where clause for restricting queries to files in mSubDirectories
+    private String mSubDirectoriesWhere;
+    // where arguments for restricting queries to files in mSubDirectories
+    private String[] mSubDirectoriesWhereArgs;
+
     private final HashMap<String, MtpStorage> mStorageMap = new HashMap<String, MtpStorage>();
 
     // cached property groups for single properties
@@ -112,7 +120,8 @@
         System.loadLibrary("media_jni");
     }
 
-    public MtpDatabase(Context context, String volumeName, String storagePath) {
+    public MtpDatabase(Context context, String volumeName, String storagePath,
+            String[] subDirectories) {
         native_setup();
 
         mContext = context;
@@ -122,6 +131,31 @@
         mObjectsUri = Files.getMtpObjectsUri(volumeName);
         mMediaScanner = new MediaScanner(context);
 
+        mSubDirectories = subDirectories;
+        if (subDirectories != null) {
+            // Compute "where" string for restricting queries to subdirectories
+            StringBuilder builder = new StringBuilder();
+            builder.append("(");
+            int count = subDirectories.length;
+            for (int i = 0; i < count; i++) {
+                builder.append(Files.FileColumns.DATA + "=? OR "
+                        + Files.FileColumns.DATA + " LIKE ?");
+                if (i != count - 1) {
+                    builder.append(" OR ");
+                }
+            }
+            builder.append(")");
+            mSubDirectoriesWhere = builder.toString();
+
+            // Compute "where" arguments for restricting queries to subdirectories
+            mSubDirectoriesWhereArgs = new String[count * 2];
+            for (int i = 0, j = 0; i < count; i++) {
+                String path = subDirectories[i];
+                mSubDirectoriesWhereArgs[j++] = path;
+                mSubDirectoriesWhereArgs[j++] = path + "/%";
+            }
+        }
+
         // Set locale to MediaScanner.
         Locale locale = context.getResources().getConfiguration().locale;
         if (locale != null) {
@@ -190,9 +224,44 @@
         }
     }
 
+    // check to see if the path is contained in one of our storage subdirectories
+    // returns true if we have no special subdirectories
+    private boolean inStorageSubDirectory(String path) {
+        if (mSubDirectories == null) return true;
+        if (path == null) return false;
+
+        boolean allowed = false;
+        int pathLength = path.length();
+        for (int i = 0; i < mSubDirectories.length && !allowed; i++) {
+            String subdir = mSubDirectories[i];
+            int subdirLength = subdir.length();
+            if (subdirLength < pathLength &&
+                    path.charAt(subdirLength) == '/' &&
+                    path.startsWith(subdir)) {
+                allowed = true;
+            }
+        }
+        return allowed;
+    }
+
+    // check to see if the path matches one of our storage subdirectories
+    // returns true if we have no special subdirectories
+    private boolean isStorageSubDirectory(String path) {
+    if (mSubDirectories == null) return false;
+        for (int i = 0; i < mSubDirectories.length; i++) {
+            if (path.equals(mSubDirectories[i])) {
+                return true;
+            }
+        }
+        return false;
+    }
+
     private int beginSendObject(String path, int format, int parent,
                          int storageId, long size, long modified) {
-        // first make sure the object does not exist
+        // if mSubDirectories is not null, do not allow copying files to any other locations
+        if (!inStorageSubDirectory(path)) return -1;
+
+        // make sure the object does not exist
         if (path != null) {
             Cursor c = null;
             try {
@@ -269,33 +338,40 @@
     }
 
     private Cursor createObjectQuery(int storageID, int format, int parent) throws RemoteException {
+        String where;
+        String[] whereArgs;
+
         if (storageID == 0xFFFFFFFF) {
             // query all stores
             if (format == 0) {
                 // query all formats
                 if (parent == 0) {
                     // query all objects
-                    return mMediaProvider.query(mObjectsUri, ID_PROJECTION, null, null, null);
+                    where = null;
+                    whereArgs = null;
+                } else {
+                    if (parent == 0xFFFFFFFF) {
+                        // all objects in root of store
+                        parent = 0;
+                    }
+                    where = PARENT_WHERE;
+                    whereArgs = new String[] { Integer.toString(parent) };
                 }
-                if (parent == 0xFFFFFFFF) {
-                    // all objects in root of store
-                    parent = 0;
-                }
-                return mMediaProvider.query(mObjectsUri, ID_PROJECTION, PARENT_WHERE,
-                        new String[] { Integer.toString(parent) }, null);
             } else {
                 // query specific format
                 if (parent == 0) {
                     // query all objects
-                    return mMediaProvider.query(mObjectsUri, ID_PROJECTION, FORMAT_WHERE,
-                            new String[] { Integer.toString(format) }, null);
+                    where = FORMAT_WHERE;
+                    whereArgs = new String[] { Integer.toString(format) };
+                } else {
+                    if (parent == 0xFFFFFFFF) {
+                        // all objects in root of store
+                        parent = 0;
+                    }
+                    where = FORMAT_PARENT_WHERE;
+                    whereArgs = new String[] { Integer.toString(format),
+                                               Integer.toString(parent) };
                 }
-                if (parent == 0xFFFFFFFF) {
-                    // all objects in root of store
-                    parent = 0;
-                }
-                return mMediaProvider.query(mObjectsUri, ID_PROJECTION, FORMAT_PARENT_WHERE,
-                        new String[] { Integer.toString(format), Integer.toString(parent) }, null);
             }
         } else {
             // query specific store
@@ -303,35 +379,61 @@
                 // query all formats
                 if (parent == 0) {
                     // query all objects
-                    return mMediaProvider.query(mObjectsUri, ID_PROJECTION, STORAGE_WHERE,
-                            new String[] { Integer.toString(storageID) }, null);
+                    where = STORAGE_WHERE;
+                    whereArgs = new String[] { Integer.toString(storageID) };
+                } else {
+                    if (parent == 0xFFFFFFFF) {
+                        // all objects in root of store
+                        parent = 0;
+                    }
+                    where = STORAGE_PARENT_WHERE;
+                    whereArgs = new String[] { Integer.toString(storageID),
+                                               Integer.toString(parent) };
                 }
-                if (parent == 0xFFFFFFFF) {
-                    // all objects in root of store
-                    parent = 0;
-                }
-                return mMediaProvider.query(mObjectsUri, ID_PROJECTION, STORAGE_PARENT_WHERE,
-                        new String[] { Integer.toString(storageID), Integer.toString(parent) },
-                        null);
             } else {
                 // query specific format
                 if (parent == 0) {
                     // query all objects
-                    return mMediaProvider.query(mObjectsUri, ID_PROJECTION, STORAGE_FORMAT_WHERE,
-                            new String[] {  Integer.toString(storageID), Integer.toString(format) },
-                            null);
+                    where = STORAGE_FORMAT_WHERE;
+                    whereArgs = new String[] {  Integer.toString(storageID),
+                                                Integer.toString(format) };
+                } else {
+                    if (parent == 0xFFFFFFFF) {
+                        // all objects in root of store
+                        parent = 0;
+                    }
+                    where = STORAGE_FORMAT_PARENT_WHERE;
+                    whereArgs = new String[] { Integer.toString(storageID),
+                                               Integer.toString(format),
+                                               Integer.toString(parent) };
                 }
-                if (parent == 0xFFFFFFFF) {
-                    // all objects in root of store
-                    parent = 0;
-                }
-                return mMediaProvider.query(mObjectsUri, ID_PROJECTION, STORAGE_FORMAT_PARENT_WHERE,
-                        new String[] { Integer.toString(storageID),
-                                       Integer.toString(format),
-                                       Integer.toString(parent) },
-                        null);
             }
         }
+
+        // if we are restricting queries to mSubDirectories, we need to add the restriction
+        // onto our "where" arguments
+        if (mSubDirectoriesWhere != null) {
+            if (where == null) {
+                where = mSubDirectoriesWhere;
+                whereArgs = mSubDirectoriesWhereArgs;
+            } else {
+                where = where + " AND " + mSubDirectoriesWhere;
+
+                // create new array to hold whereArgs and mSubDirectoriesWhereArgs
+                String[] newWhereArgs =
+                        new String[whereArgs.length + mSubDirectoriesWhereArgs.length];
+                int i, j;
+                for (i = 0; i < whereArgs.length; i++) {
+                    newWhereArgs[i] = whereArgs[i];
+                }
+                for (j = 0; j < mSubDirectoriesWhereArgs.length; i++, j++) {
+                    newWhereArgs[i] = mSubDirectoriesWhereArgs[j];
+                }
+                whereArgs = newWhereArgs;
+            }
+        }
+
+        return mMediaProvider.query(mObjectsUri, ID_PROJECTION, where, whereArgs, null);
     }
 
     private int[] getObjectList(int storageID, int format, int parent) {
@@ -613,6 +715,11 @@
             return MtpConstants.RESPONSE_INVALID_OBJECT_HANDLE;
         }
 
+        // do not allow renaming any of the special subdirectories
+        if (isStorageSubDirectory(path)) {
+            return MtpConstants.RESPONSE_OBJECT_WRITE_PROTECTED;
+        }
+
         // now rename the file.  make sure this succeeds before updating database
         File oldFile = new File(path);
         int lastSlash = path.lastIndexOf('/');
@@ -794,6 +901,11 @@
                 return MtpConstants.RESPONSE_GENERAL_ERROR;
             }
 
+            // do not allow deleting any of the special subdirectories
+            if (isStorageSubDirectory(path)) {
+                return MtpConstants.RESPONSE_OBJECT_WRITE_PROTECTED;
+            }
+
             if (format == MtpConstants.FORMAT_ASSOCIATION) {
                 // recursive case - delete all children first
                 Uri uri = Files.getMtpObjectsUri(mVolumeName);
diff --git a/media/mtp/MtpServer.cpp b/media/mtp/MtpServer.cpp
index 51eb97f..1334e6c 100644
--- a/media/mtp/MtpServer.cpp
+++ b/media/mtp/MtpServer.cpp
@@ -1053,11 +1053,14 @@
     int result = mDatabase->getObjectFilePath(handle, filePath, fileLength, format);
     if (result == MTP_RESPONSE_OK) {
         ALOGV("deleting %s", (const char *)filePath);
-        deletePath((const char *)filePath);
-        return mDatabase->deleteFile(handle);
-    } else {
-        return result;
+        result = mDatabase->deleteFile(handle);
+        // Don't delete the actual files unless the database deletion is allowed
+        if (result == MTP_RESPONSE_OK) {
+            deletePath((const char *)filePath);
+        }
     }
+
+    return result;
 }
 
 MtpResponseCode MtpServer::doGetObjectPropDesc() {