pal: validate stream handle in pal server wrapper

This validates if pal client passes valid stream
handle to server, which is to avoid arbitrary
dereference of handle pointers.

Change-Id: I419f803e232c994f2d9865ba40f75601f6e41f0c
(cherry picked from commit e076162fec4e20948bdae29c091392b1ac93e08c)
diff --git a/ipc/HwBinders/pal_ipc_server/inc/pal_server_wrapper.h b/ipc/HwBinders/pal_ipc_server/inc/pal_server_wrapper.h
index bf73018..ecab955 100644
--- a/ipc/HwBinders/pal_ipc_server/inc/pal_server_wrapper.h
+++ b/ipc/HwBinders/pal_ipc_server/inc/pal_server_wrapper.h
@@ -27,7 +27,7 @@
  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  *
  * Changes from Qualcomm Innovation Center are provided under the following license:
- * Copyright (c) 2022 Qualcomm Innovation Center, Inc. All rights reserved.
+ * Copyright (c) 2022,2024 Qualcomm Innovation Center, Inc. All rights reserved.
  * SPDX-License-Identifier: BSD-3-Clause-Clear
  */
 
@@ -129,6 +129,7 @@
 
 struct PAL : public IPAL /*, public android::hardware::hidl_death_recipient*/{
     public:
+    std::mutex mClientLock;
     PAL()
     {
         sInstance = this;
@@ -220,6 +221,7 @@
     static PAL* sInstance;
     int find_dup_fd_from_input_fd(const uint64_t streamHandle, int input_fd, int *dup_fd);
     void add_input_and_dup_fd(const uint64_t streamHandle, int input_fd, int dup_fd);
+    bool isValidstreamHandle(const uint64_t streamHandle);
 };
 
 class PalClientDeathRecipient : public android::hardware::hidl_death_recipient
diff --git a/ipc/HwBinders/pal_ipc_server/src/pal_server_wrapper.cpp b/ipc/HwBinders/pal_ipc_server/src/pal_server_wrapper.cpp
index ba2ee12..276aa01 100644
--- a/ipc/HwBinders/pal_ipc_server/src/pal_server_wrapper.cpp
+++ b/ipc/HwBinders/pal_ipc_server/src/pal_server_wrapper.cpp
@@ -27,7 +27,7 @@
  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  *
  * Changes from Qualcomm Innovation Center are provided under the following license:
- * Copyright (c) 2022-2023 Qualcomm Innovation Center, Inc. All rights reserved.
+ * Copyright (c) 2022-2024 Qualcomm Innovation Center, Inc. All rights reserved.
  * SPDX-License-Identifier: BSD-3-Clause-Clear
  */
 
@@ -101,6 +101,7 @@
     std::lock_guard<std::mutex> guard(mLock);
     ALOGD("%s : client died pid : %d", __func__, cookie);
     int pid = (int) cookie;
+    std::lock_guard<std::mutex> lock(mPalInstance->mClientLock);
     auto &clients = mPalInstance->mPalClients;
     for (auto itr = clients.begin(); itr != clients.end(); itr++) {
         auto client = *itr;
@@ -132,6 +133,7 @@
 void PAL::add_input_and_dup_fd(const uint64_t streamHandle, int input_fd, int dup_fd)
 {
     std::vector<std::pair<int, int>>::iterator it;
+    std::lock_guard<std::mutex> guard(mClientLock);
     for (auto& s: mPalClients) {
         std::lock_guard<std::mutex> lock(s->mActiveSessionsLock);
         for (int i = 0; i < s->mActiveSessions.size(); i++) {
@@ -244,6 +246,7 @@
            ALOGE("%s: No PAL instance running", __func__);
            return false;
         }
+        std::lock_guard<std::mutex> guard(PAL::getInstance()->mClientLock);
         for (auto& s: PAL::getInstance()->mPalClients) {
             std::lock_guard<std::mutex> lock(s->mActiveSessionsLock);
             for (int idx = 0; idx < s->mActiveSessions.size(); idx++) {
@@ -277,6 +280,7 @@
          * Find the original fd that was passed by client based on what
          * input and dup fd list and send that back.
          */
+        PAL::getInstance()->mClientLock.lock();
         for (auto& s: PAL::getInstance()->mPalClients) {
             std::lock_guard<std::mutex> lock(s->mActiveSessionsLock);
             for (int idx = 0; idx < s->mActiveSessions.size(); idx++) {
@@ -300,6 +304,7 @@
                 }
             }
         }
+        PAL::getInstance()->mClientLock.unlock();
 
         rwDonePayloadHidl.resize(sizeof(pal_callback_buffer));
         rwDonePayload = (PalCallbackBuffer *)rwDonePayloadHidl.data();
@@ -445,6 +450,32 @@
    print_media_config(&attr->out_media_config);
 }
 
+bool PAL::isValidstreamHandle(const uint64_t streamHandle) {
+    int pid = ::android::hardware::IPCThreadState::self()->getCallingPid();
+
+    std::lock_guard<std::mutex> guard(mClientLock);
+    for (auto itr = mPalClients.begin(); itr != mPalClients.end(); ) {
+        auto client = *itr;
+        if (client->pid == pid) {
+            std::lock_guard<std::mutex> lock(client->mActiveSessionsLock);
+            auto sItr = client->mActiveSessions.begin();
+            for (; sItr != client->mActiveSessions.end(); sItr++) {
+                if (sItr->session_handle == streamHandle) {
+                    return true;
+                }
+            }
+            ALOGE("%s: streamHandle: %pK for pid %d not found",
+                    __func__, streamHandle, pid);
+            return false;
+        }
+        itr++;
+    }
+
+    ALOGE("%s: client info for pid %d not found",
+            __func__, pid);
+    return false;
+}
+
 Return<void> PAL::ipc_pal_stream_open(const hidl_vec<PalStreamAttributes>& attr_hidl,
                             uint32_t noOfDevices,
                             const hidl_vec<PalDevice>& devs_hidl,
@@ -546,6 +577,7 @@
                           callback, (uint64_t)sr_clbk_data.get(), &stream_handle);
 
     if (!ret) {
+        std::lock_guard<std::mutex> guard(mClientLock);
         for(auto& client: mPalClients) {
             if (client->pid == pid) {
                 /*Another session from the same client*/
@@ -601,8 +633,13 @@
 Return<int32_t> PAL::ipc_pal_stream_close(const uint64_t streamHandle)
 {
     int pid = ::android::hardware::IPCThreadState::self()->getCallingPid();
-    Return<int32_t> status = pal_stream_close((pal_stream_handle_t *)streamHandle);
 
+    if (!isValidstreamHandle(streamHandle)) {
+        ALOGE("%s: Invalid streamHandle: %pK", __func__, streamHandle);
+        return -EINVAL;
+    }
+
+    mClientLock.lock();
     for (auto itr = mPalClients.begin(); itr != mPalClients.end(); ) {
         auto client = *itr;
         if (client->pid == pid) {
@@ -635,38 +672,77 @@
             break;
         }
     }
+    mClientLock.unlock();
+
+    Return<int32_t> status = pal_stream_close((pal_stream_handle_t *)streamHandle);
+
     return status;
 }
 
 Return<int32_t> PAL::ipc_pal_stream_start(const uint64_t streamHandle) {
+    if (!isValidstreamHandle(streamHandle)) {
+        ALOGE("%s: Invalid streamHandle: %pK", __func__, streamHandle);
+        return -EINVAL;
+    }
 
     return pal_stream_start((pal_stream_handle_t *)streamHandle);
 }
 
 Return<int32_t> PAL::ipc_pal_stream_stop(const uint64_t streamHandle) {
+    if (!isValidstreamHandle(streamHandle)) {
+        ALOGE("%s: Invalid streamHandle: %pK", __func__, streamHandle);
+        return -EINVAL;
+    }
+
     return pal_stream_stop((pal_stream_handle_t *)streamHandle);
 }
 
 Return<int32_t> PAL::ipc_pal_stream_pause(const uint64_t streamHandle) {
+    if (!isValidstreamHandle(streamHandle)) {
+        ALOGE("%s: Invalid streamHandle: %pK", __func__, streamHandle);
+        return -EINVAL;
+    }
+
     return pal_stream_pause((pal_stream_handle_t *)streamHandle);
 }
 
 Return<int32_t> PAL::ipc_pal_stream_drain(uint64_t streamHandle, PalDrainType type)
 {
     pal_drain_type_t drain_type = (pal_drain_type_t) type;
+
+    if (!isValidstreamHandle(streamHandle)) {
+        ALOGE("%s: Invalid streamHandle: %pK", __func__, streamHandle);
+        return -EINVAL;
+    }
+
     return pal_stream_drain((pal_stream_handle_t *)streamHandle,
                              drain_type);
 }
 
 Return<int32_t> PAL::ipc_pal_stream_flush(const uint64_t streamHandle) {
+    if (!isValidstreamHandle(streamHandle)) {
+        ALOGE("%s: Invalid streamHandle: %pK", __func__, streamHandle);
+        return -EINVAL;
+    }
+
     return pal_stream_flush((pal_stream_handle_t *)streamHandle);
 }
 
 Return<int32_t> PAL::ipc_pal_stream_suspend(const uint64_t streamHandle) {
+    if (!isValidstreamHandle(streamHandle)) {
+        ALOGE("%s: Invalid streamHandle: %pK", __func__, streamHandle);
+        return -EINVAL;
+    }
+
     return pal_stream_suspend((pal_stream_handle_t *)streamHandle);
 }
 
 Return<int32_t> PAL::ipc_pal_stream_resume(const uint64_t streamHandle) {
+    if (!isValidstreamHandle(streamHandle)) {
+        ALOGE("%s: Invalid streamHandle: %pK", __func__, streamHandle);
+        return -EINVAL;
+    }
+
     return pal_stream_resume((pal_stream_handle_t *)streamHandle);
 }
 
@@ -679,6 +755,11 @@
     pal_buffer_config_t out_buf_cfg, in_buf_cfg;
     PalBufferConfig in_buff_config_ret, out_buff_config_ret;
 
+    if (!isValidstreamHandle(streamHandle)) {
+        ALOGE("%s: Invalid streamHandle: %pK", __func__, streamHandle);
+        return Void();
+    }
+
     in_buf_cfg.buf_count = in_buff_config.buf_count;
     in_buf_cfg.buf_size = in_buff_config.buf_size;
     if (in_buff_config.max_metadata_size) {
@@ -731,6 +812,11 @@
                                           const hidl_vec<PalBuffer>& buff_hidl) {
     struct pal_buffer buf = {0};
 
+    if (!isValidstreamHandle(streamHandle)) {
+        ALOGE("%s: Invalid streamHandle: %pK", __func__, streamHandle);
+        return -EINVAL;
+    }
+
     buf.size = buff_hidl.data()->size;
     std::vector<uint8_t> dataBuffer;
     if (buff_hidl.data()->buffer.size() == buf.size) {
@@ -749,6 +835,7 @@
     std::vector<uint8_t> bufMetadata(buf.metadata_size, 0);
     buf.metadata = bufMetadata.data();
     auto stream_media_config = std::make_shared<pal_media_config>();
+    mClientLock.lock();
     for (auto& s: PAL::getInstance()->mPalClients) {
         std::lock_guard<std::mutex> lock(s->mActiveSessionsLock);
         for (auto session : s->mActiveSessions) {
@@ -759,6 +846,7 @@
             }
         }
     }
+    mClientLock.unlock();
     auto metadataParser = std::make_unique<MetadataParser>();
     metadataParser->fillMetaData(buf.metadata, buf.frame_index, buf.size,
                                  stream_media_config.get());
@@ -787,6 +875,11 @@
     struct pal_buffer buf = {0};
     hidl_vec<PalBuffer> outBuff_hidl;
 
+    if (!isValidstreamHandle(streamHandle)) {
+        ALOGE("%s: Invalid streamHandle: %pK", __func__, streamHandle);
+        return Void();
+    }
+
     buf.size = inBuff_hidl.data()->size;
     std::vector<uint8_t> dataBuffer(buf.size, 0);
     buf.buffer = dataBuffer.data();
@@ -831,6 +924,12 @@
         ALOGE("Invalid payload size");
         return -EINVAL;
     }
+
+    if (!isValidstreamHandle(streamHandle)) {
+        ALOGE("%s: Invalid streamHandle: %pK", __func__, streamHandle);
+        return -EINVAL;
+    }
+
     param_payload = (pal_param_payload *)calloc (1,
                                     sizeof(pal_param_payload) + paramPayload.data()->size);
     if (!param_payload) {
@@ -852,6 +951,12 @@
     int32_t ret = 0;
     pal_param_payload *param_payload;
     hidl_vec<PalParamPayload> paramPayload;
+
+    if (!isValidstreamHandle(streamHandle)) {
+        ALOGE("%s: Invalid streamHandle: %pK", __func__, streamHandle);
+        return Void();
+    }
+
     ret = pal_stream_get_param((pal_stream_handle_t *)streamHandle, paramId, &param_payload);
     if (ret == 0) {
         paramPayload.resize(sizeof(PalParamPayload));
@@ -880,10 +985,16 @@
     int cnt = 0;
     int32_t ret = -ENOMEM;
 
+    if (!isValidstreamHandle(streamHandle)) {
+        ALOGE("%s: Invalid streamHandle: %pK", __func__, streamHandle);
+        return -EINVAL;
+    }
+
     if (noOfDevices > devs_hidl.size()) {
         ALOGE("Invalid noOfDevices");
         return -EINVAL;
     }
+
     if (devs_hidl.size()) {
         devices = (struct pal_device *)calloc (1,
                                     sizeof(struct pal_device) * noOfDevices);
@@ -924,6 +1035,12 @@
     struct pal_volume_data *volume = nullptr;
     uint32_t noOfVolPairs = vol.data()->noOfVolPairs;
     int32_t ret = -ENOMEM;
+
+    if (!isValidstreamHandle(streamHandle)) {
+        ALOGE("%s: Invalid streamHandle: %pK", __func__, streamHandle);
+        return -EINVAL;
+    }
+
     if (1 != vol.size()) {
         ALOGE("Invalid vol pairs");
         return -EINVAL;
@@ -964,6 +1081,11 @@
 Return<int32_t> PAL::ipc_pal_stream_set_mute(const uint64_t streamHandle,
                                     bool state)
 {
+    if (!isValidstreamHandle(streamHandle)) {
+        ALOGE("%s: Invalid streamHandle: %pK", __func__, streamHandle);
+        return -EINVAL;
+    }
+
     return pal_stream_set_mute((pal_stream_handle_t *)streamHandle, state);
 }
 
@@ -982,6 +1104,12 @@
 {
     struct pal_session_time stime;
     int32_t ret = 0;
+
+    if (!isValidstreamHandle(streamHandle)) {
+        ALOGE("%s: Invalid streamHandle: %pK", __func__, streamHandle);
+        return Void();
+    }
+
     hidl_vec<PalSessionTime> sessTime_hidl;
     sessTime_hidl.resize(sizeof(struct pal_session_time));
     ret = pal_get_timestamp((pal_stream_handle_t *)streamHandle, &stime);
@@ -994,6 +1122,11 @@
                                           const PalAudioEffect effect,
                                           bool enable)
 {
+    if (!isValidstreamHandle(streamHandle)) {
+        ALOGE("%s: Invalid streamHandle: %pK", __func__, streamHandle);
+        return -EINVAL;
+    }
+
     return pal_add_remove_effect((pal_stream_handle_t *)streamHandle,
                                    (pal_audio_effect_t) effect, enable);
 }
@@ -1053,6 +1186,12 @@
     int32_t ret = 0;
     struct pal_mmap_buffer info;
     hidl_vec<PalMmapBuffer> mMapBuffer_hidl;
+
+    if (!isValidstreamHandle(streamHandle)) {
+        ALOGE("%s: Invalid streamHandle: %pK", __func__, streamHandle);
+        return Void();
+    }
+
     mMapBuffer_hidl.resize(sizeof(struct pal_mmap_buffer));
     ret = pal_stream_create_mmap_buffer((pal_stream_handle_t *)streamHandle, min_size_frames, &info);
     mMapBuffer_hidl.data()->buffer = (uint64_t)info.buffer;
@@ -1070,6 +1209,12 @@
     int32_t ret = 0;
     struct pal_mmap_position mmap_position;
     hidl_vec<PalMmapPosition> mmap_position_hidl;
+
+    if (!isValidstreamHandle(streamHandle)) {
+        ALOGE("%s: Invalid streamHandle: %pK", __func__, streamHandle);
+        return Void();
+    }
+
     mmap_position_hidl.resize(sizeof(struct pal_mmap_position));
     ret = pal_stream_get_mmap_position((pal_stream_handle_t *)streamHandle, &mmap_position);
     memcpy(mmap_position_hidl.data(), &mmap_position, sizeof(struct pal_mmap_position));
@@ -1099,6 +1244,11 @@
     size_t sz = size;
     hidl_vec<uint8_t> payloadRet;
 
+    if (!isValidstreamHandle(streamHandle)) {
+        ALOGE("%s: Invalid streamHandle: %pK", __func__, streamHandle);
+        return Void();
+    }
+
     if (size > 0) {
         payload = (uint8_t *)calloc(1, size);
         if (!payload) {