Add multi-dex support to hiddenapi

Update hiddenapi so that it is able to update several DexFiles,
within a single physical file.

All the dex files will be copied first, and then hiddenapi data
is appended to the end of the file.

Bug: 266950186
Test: The output for single-dex is exactly identical (full AOSP build).
Change-Id: I51643fe0097d98c862e25adbc65c9024cad9a131
diff --git a/tools/hiddenapi/hiddenapi.cc b/tools/hiddenapi/hiddenapi.cc
index 644ec72..c4e7d40 100644
--- a/tools/hiddenapi/hiddenapi.cc
+++ b/tools/hiddenapi/hiddenapi.cc
@@ -16,14 +16,15 @@
 
 #include <fstream>
 #include <iostream>
+#include <iterator>
 #include <map>
 #include <set>
 #include <string>
 #include <string_view>
+#include <vector>
 
 #include "android-base/stringprintf.h"
 #include "android-base/strings.h"
-
 #include "base/bit_utils.h"
 #include "base/hiddenapi_flags.h"
 #include "base/mem_map.h"
@@ -34,6 +35,7 @@
 #include "dex/art_dex_file_loader.h"
 #include "dex/class_accessor-inl.h"
 #include "dex/dex_file-inl.h"
+#include "dex/dex_file_structs.h"
 
 namespace art {
 namespace hiddenapi {
@@ -248,6 +250,13 @@
     OpenDexFiles(dex_paths, open_writable, ignore_empty);
   }
 
+  template <typename Fn>
+  void ForEachDexClass(const DexFile* dex_file, Fn fn) {
+    for (ClassAccessor accessor : dex_file->GetClasses()) {
+      fn(DexClass(accessor));
+    }
+  }
+
   template<typename Fn>
   void ForEachDexClass(Fn fn) {
     for (auto& dex_file : dex_files_) {
@@ -669,215 +678,127 @@
 // Edits a dex file, inserting a new HiddenapiClassData section.
 class DexFileEditor final {
  public:
-  DexFileEditor(const DexFile& old_dex, const std::vector<uint8_t>& hiddenapi_class_data)
-      : old_dex_(old_dex),
-        hiddenapi_class_data_(hiddenapi_class_data),
-        loaded_dex_header_(nullptr),
-        loaded_dex_maplist_(nullptr) {}
-
-  // Copies dex file into a backing data vector, appends the given HiddenapiClassData
-  // and updates the MapList.
-  void Encode() {
+  // Add dex file to copy to output (possibly several files for multi-dex).
+  void Add(const DexFile* dex, const std::vector<uint8_t>&& hiddenapi_data) {
     // We do not support non-standard dex encodings, e.g. compact dex.
-    CHECK(old_dex_.IsStandardDexFile());
-
-    // If there are no data to append, copy the old dex file and return.
-    if (hiddenapi_class_data_.empty()) {
-      AllocateMemory(old_dex_.Size());
-      Append(old_dex_.Begin(), old_dex_.Size(), /* update_header= */ false);
-      return;
-    }
-
-    // Find the old MapList, find its size.
-    const dex::MapList* old_map = old_dex_.GetMapList();
-    CHECK_LT(old_map->size_, std::numeric_limits<uint32_t>::max());
-
-    // Compute the size of the new dex file. We append the HiddenapiClassData,
-    // one MapItem and possibly some padding to align the new MapList.
-    CHECK(IsAligned<kMapListAlignment>(old_dex_.Size()))
-        << "End of input dex file is not 4-byte aligned, possibly because its MapList is not "
-        << "at the end of the file.";
-    size_t size_delta =
-        RoundUp(hiddenapi_class_data_.size(), kMapListAlignment) + sizeof(dex::MapItem);
-    size_t new_size = old_dex_.Size() + size_delta;
-    AllocateMemory(new_size);
-
-    // Copy the old dex file into the backing data vector. Load the copied
-    // dex file to obtain pointers to its header and MapList.
-    Append(old_dex_.Begin(), old_dex_.Size(), /* update_header= */ false);
-    ReloadDex(/* verify= */ false);
-
-    // Truncate the new dex file before the old MapList. This assumes that
-    // the MapList is the last entry in the dex file. This is currently true
-    // for our tooling.
-    // TODO: Implement the general case by zero-ing the old MapList (turning
-    // it into padding.
-    RemoveOldMapList();
-
-    // Append HiddenapiClassData.
-    size_t payload_offset = AppendHiddenapiClassData();
-
-    // Wrute new MapList with an entry for HiddenapiClassData.
-    CreateMapListWithNewItem(payload_offset);
-
-    // Check that the pre-computed size matches the actual size.
-    CHECK_EQ(offset_, new_size);
-
-    // Reload to all data structures.
-    ReloadDex(/* verify= */ false);
-
-    // Update the dex checksum.
-    UpdateChecksum();
-
-    // Run DexFileVerifier on the new dex file as a CHECK.
-    ReloadDex(/* verify= */ true);
+    CHECK(dex->IsStandardDexFile());
+    inputs_.emplace_back(dex, std::move(hiddenapi_data));
   }
 
   // Writes the edited dex file into a file.
   void WriteTo(const std::string& path) {
-    CHECK(!data_.empty());
+    std::vector<uint8_t> output;
+
+    // Copy the old dex files into the backing data vector.
+    size_t truncated_size = 0;
+    std::vector<size_t> header_offset;
+    for (size_t i = 0; i < inputs_.size(); i++) {
+      const DexFile* dex = inputs_[i].first;
+      header_offset.push_back(output.size());
+      std::copy(
+          dex->Begin(), dex->Begin() + dex->GetHeader().file_size_, std::back_inserter(output));
+
+      // Clear the old map list (make it into padding).
+      const dex::MapList* map = dex->GetMapList();
+      size_t map_off = dex->GetHeader().map_off_;
+      size_t map_size = sizeof(map->size_) + map->size_ * sizeof(map->list_[0]);
+      CHECK_LE(map_off, output.size()) << "Map list past the end of file";
+      CHECK_EQ(map_size, output.size() - map_off) << "Map list expected at the end of file";
+      std::fill_n(output.data() + map_off, map_size, 0);
+      truncated_size = output.size() - map_size;
+    }
+    output.resize(truncated_size);  // Truncate last map list.
+
+    // Append the hidden api data into the backing data vector.
+    std::vector<size_t> hiddenapi_offset;
+    for (size_t i = 0; i < inputs_.size(); i++) {
+      const std::vector<uint8_t>& hiddenapi_data = inputs_[i].second;
+      output.resize(RoundUp(output.size(), kHiddenapiClassDataAlignment));  // Align.
+      hiddenapi_offset.push_back(output.size());
+      std::copy(hiddenapi_data.begin(), hiddenapi_data.end(), std::back_inserter(output));
+    }
+
+    // Update the dex headers and map lists.
+    for (size_t i = 0; i < inputs_.size(); i++) {
+      output.resize(RoundUp(output.size(), kMapListAlignment));  // Align.
+
+      const DexFile* dex = inputs_[i].first;
+      const dex::MapList* map = dex->GetMapList();
+      std::vector<dex::MapItem> items(map->list_, map->list_ + map->size_);
+
+      // Check the header entry.
+      CHECK(!items.empty());
+      CHECK_EQ(items[0].type_, DexFile::kDexTypeHeaderItem);
+      CHECK_EQ(items[0].offset_, header_offset[i]);
+
+      // Check and remove the old map list entry (it does not have to be last).
+      auto is_map_list = [](auto it) { return it.type_ == DexFile::kDexTypeMapList; };
+      auto it = std::find_if(items.begin(), items.end(), is_map_list);
+      CHECK(it != items.end());
+      CHECK_EQ(it->offset_, dex->GetHeader().map_off_);
+      items.erase(it);
+
+      // Write new map list.
+      if (!inputs_[i].second.empty()) {
+        uint32_t payload_offset = hiddenapi_offset[i];
+        items.push_back(dex::MapItem{DexFile::kDexTypeHiddenapiClassData, 0, 1u, payload_offset});
+      }
+      uint32_t map_offset = output.size();
+      items.push_back(dex::MapItem{DexFile::kDexTypeMapList, 0, 1u, map_offset});
+      uint32_t item_count = items.size();
+      Append(&output, &item_count, 1);
+      Append(&output, items.data(), items.size());
+
+      // Update header.
+      uint8_t* begin = output.data() + header_offset[i];
+      auto* header = reinterpret_cast<DexFile::Header*>(begin);
+      header->map_off_ = map_offset;
+      if (i + 1 < inputs_.size()) {
+        CHECK_EQ(header->file_size_, header_offset[i + 1] - header_offset[i]);
+      } else {
+        // Extend last dex file until the end of the file.
+        header->data_size_ = output.size() - header->data_off_;
+        header->file_size_ = output.size() - header_offset[i];
+      }
+      header->checksum_ = DexFile::CalculateChecksum(begin, header->file_size_);
+      // TODO: We should also update the SHA1 signature.
+    }
+
+    // Write the output file.
+    CHECK(!output.empty());
     std::ofstream ofs(path.c_str(), std::ofstream::out | std::ofstream::binary);
-    ofs.write(reinterpret_cast<const char*>(data_.data()), data_.size());
+    ofs.write(reinterpret_cast<const char*>(output.data()), output.size());
     ofs.flush();
     CHECK(ofs.good());
     ofs.close();
+
+    ReloadDex(path.c_str());
   }
 
  private:
   static constexpr size_t kMapListAlignment = 4u;
   static constexpr size_t kHiddenapiClassDataAlignment = 4u;
 
-  void ReloadDex(bool verify) {
+  void ReloadDex(const char* filename) {
     std::string error_msg;
-    DexFileLoader loader;
-    loaded_dex_ = loader.Open(
-        data_.data(),
-        data_.size(),
-        "test_location",
-        old_dex_.GetLocationChecksum(),
-        /* oat_dex_file= */ nullptr,
-        /* verify= */ verify,
-        /* verify_checksum= */ verify,
-        &error_msg);
-    if (loaded_dex_.get() == nullptr) {
-      LOG(FATAL) << "Failed to load edited dex file: " << error_msg;
-      UNREACHABLE();
-    }
-
-    // Load the location of header and map list before we start editing the file.
-    loaded_dex_header_ = const_cast<DexFile::Header*>(&loaded_dex_->GetHeader());
-    loaded_dex_maplist_ = const_cast<dex::MapList*>(loaded_dex_->GetMapList());
+    ArtDexFileLoader loader;
+    std::vector<std::unique_ptr<const DexFile>> dex_files;
+    bool ok = loader.Open(filename,
+                          filename,
+                          /*verify*/ true,
+                          /*verify_checksum*/ true,
+                          &error_msg,
+                          &dex_files);
+    CHECK(ok) << "Failed to load edited dex file: " << error_msg;
   }
 
-  DexFile::Header& GetHeader() const {
-    CHECK(loaded_dex_header_ != nullptr);
-    return *loaded_dex_header_;
+  template <typename T>
+  void Append(std::vector<uint8_t>* output, const T* src, size_t len) {
+    const uint8_t* ptr = reinterpret_cast<const uint8_t*>(src);
+    std::copy(ptr, ptr + len * sizeof(T), std::back_inserter(*output));
   }
 
-  dex::MapList& GetMapList() const {
-    CHECK(loaded_dex_maplist_ != nullptr);
-    return *loaded_dex_maplist_;
-  }
-
-  void AllocateMemory(size_t total_size) {
-    data_.clear();
-    data_.resize(total_size);
-    CHECK(IsAligned<kMapListAlignment>(data_.data()));
-    CHECK(IsAligned<kHiddenapiClassDataAlignment>(data_.data()));
-    offset_ = 0;
-  }
-
-  uint8_t* GetCurrentDataPtr() {
-    return data_.data() + offset_;
-  }
-
-  void UpdateDataSize(off_t delta, bool update_header) {
-    offset_ += delta;
-    if (update_header) {
-      DexFile::Header& header = GetHeader();
-      header.file_size_ += delta;
-      header.data_size_ += delta;
-    }
-  }
-
-  template<typename T>
-  T* Append(const T* src, size_t len, bool update_header = true) {
-    CHECK_LE(offset_ + len, data_.size());
-    uint8_t* dst = GetCurrentDataPtr();
-    memcpy(dst, src, len);
-    UpdateDataSize(len, update_header);
-    return reinterpret_cast<T*>(dst);
-  }
-
-  void InsertPadding(size_t alignment) {
-    size_t len = RoundUp(offset_, alignment) - offset_;
-    std::vector<uint8_t> padding(len, 0);
-    Append(padding.data(), padding.size());
-  }
-
-  void RemoveOldMapList() {
-    size_t map_size = GetMapList().Size();
-    uint8_t* map_start = reinterpret_cast<uint8_t*>(&GetMapList());
-    CHECK_EQ(map_start + map_size, GetCurrentDataPtr()) << "MapList not at the end of dex file";
-    UpdateDataSize(-static_cast<off_t>(map_size), /* update_header= */ true);
-    CHECK_EQ(map_start, GetCurrentDataPtr());
-    loaded_dex_maplist_ = nullptr;  // do not use this map list any more
-  }
-
-  void CreateMapListWithNewItem(size_t payload_offset) {
-    InsertPadding(/* alignment= */ kMapListAlignment);
-
-    size_t new_map_offset = offset_;
-    dex::MapList* map = Append(old_dex_.GetMapList(), old_dex_.GetMapList()->Size());
-
-    // Check last map entry is a pointer to itself.
-    dex::MapItem& old_item = map->list_[map->size_ - 1];
-    CHECK(old_item.type_ == DexFile::kDexTypeMapList);
-    CHECK_EQ(old_item.size_, 1u);
-    CHECK_EQ(old_item.offset_, GetHeader().map_off_);
-
-    // Create a new MapItem entry with new MapList details.
-    dex::MapItem new_item;
-    new_item.type_ = old_item.type_;
-    new_item.unused_ = 0u;  // initialize to ensure dex output is deterministic (b/119308882)
-    new_item.size_ = old_item.size_;
-    new_item.offset_ = new_map_offset;
-
-    // Update pointer in the header.
-    GetHeader().map_off_ = new_map_offset;
-
-    // Append a new MapItem and return its pointer.
-    map->size_++;
-    Append(&new_item, sizeof(dex::MapItem));
-
-    // Change penultimate entry to point to metadata.
-    old_item.type_ = DexFile::kDexTypeHiddenapiClassData;
-    old_item.size_ = 1u;  // there is only one section
-    old_item.offset_ = payload_offset;
-  }
-
-  size_t AppendHiddenapiClassData() {
-    size_t payload_offset = offset_;
-    CHECK_EQ(kMapListAlignment, kHiddenapiClassDataAlignment);
-    CHECK(IsAligned<kHiddenapiClassDataAlignment>(payload_offset))
-        << "Should not need to align the section, previous data was already aligned";
-    Append(hiddenapi_class_data_.data(), hiddenapi_class_data_.size());
-    return payload_offset;
-  }
-
-  void UpdateChecksum() {
-    GetHeader().checksum_ = loaded_dex_->CalculateChecksum();
-  }
-
-  const DexFile& old_dex_;
-  const std::vector<uint8_t>& hiddenapi_class_data_;
-
-  std::vector<uint8_t> data_;
-  size_t offset_;
-
-  std::unique_ptr<const DexFile> loaded_dex_;
-  DexFile::Header* loaded_dex_header_;
-  dex::MapList* loaded_dex_maplist_;
+  std::vector<std::pair<const DexFile*, const std::vector<uint8_t>>> inputs_;
 };
 
 class HiddenApi final {
@@ -994,45 +915,40 @@
       ClassPath boot_classpath({ input_path },
                                /* open_writable= */ false,
                                /* ignore_empty= */ false);
-      std::vector<const DexFile*> input_dex_files = boot_classpath.GetDexFiles();
-      CHECK_EQ(input_dex_files.size(), 1u);
-      const DexFile& input_dex = *input_dex_files[0];
-
-      HiddenapiClassDataBuilder builder(input_dex);
-      boot_classpath.ForEachDexClass([&](const DexClass& boot_class) {
-        builder.BeginClassDef(boot_class.GetClassDefIndex());
-        if (boot_class.GetData() != nullptr) {
-          auto fn_shared = [&](const DexMember& boot_member) {
-            auto signature = boot_member.GetApiEntry();
-            auto it = api_list.find(signature);
-            bool api_list_found = (it != api_list.end());
-            CHECK(!force_assign_all_ || api_list_found)
-                << "Could not find hiddenapi flags for dex entry: " << signature;
-            if (api_list_found && it->second.GetIntValue() > max_hiddenapi_level_.GetIntValue()) {
-              ApiList without_domain(it->second.GetIntValue());
-              LOG(ERROR) << "Hidden api flag " << without_domain
-                         << " for member " << signature
-                         << " in " << input_path
-                         << " exceeds maximum allowable flag "
-                         << max_hiddenapi_level_;
-              max_hiddenapi_level_error = true;
-            } else {
-              builder.WriteFlags(api_list_found ? it->second : ApiList::Sdk());
-            }
-          };
-          auto fn_field = [&](const ClassAccessor::Field& boot_field) {
-            fn_shared(DexMember(boot_class, boot_field));
-          };
-          auto fn_method = [&](const ClassAccessor::Method& boot_method) {
-            fn_shared(DexMember(boot_class, boot_method));
-          };
-          boot_class.VisitFieldsAndMethods(fn_field, fn_field, fn_method, fn_method);
-        }
-        builder.EndClassDef(boot_class.GetClassDefIndex());
-      });
-
-      DexFileEditor dex_editor(input_dex, builder.GetData());
-      dex_editor.Encode();
+      DexFileEditor dex_editor;
+      for (const DexFile* input_dex : boot_classpath.GetDexFiles()) {
+        HiddenapiClassDataBuilder builder(*input_dex);
+        boot_classpath.ForEachDexClass(input_dex, [&](const DexClass& boot_class) {
+          builder.BeginClassDef(boot_class.GetClassDefIndex());
+          if (boot_class.GetData() != nullptr) {
+            auto fn_shared = [&](const DexMember& boot_member) {
+              auto signature = boot_member.GetApiEntry();
+              auto it = api_list.find(signature);
+              bool api_list_found = (it != api_list.end());
+              CHECK(!force_assign_all_ || api_list_found)
+                  << "Could not find hiddenapi flags for dex entry: " << signature;
+              if (api_list_found && it->second.GetIntValue() > max_hiddenapi_level_.GetIntValue()) {
+                ApiList without_domain(it->second.GetIntValue());
+                LOG(ERROR) << "Hidden api flag " << without_domain << " for member " << signature
+                           << " in " << input_path << " exceeds maximum allowable flag "
+                           << max_hiddenapi_level_;
+                max_hiddenapi_level_error = true;
+              } else {
+                builder.WriteFlags(api_list_found ? it->second : ApiList::Sdk());
+              }
+            };
+            auto fn_field = [&](const ClassAccessor::Field& boot_field) {
+              fn_shared(DexMember(boot_class, boot_field));
+            };
+            auto fn_method = [&](const ClassAccessor::Method& boot_method) {
+              fn_shared(DexMember(boot_class, boot_method));
+            };
+            boot_class.VisitFieldsAndMethods(fn_field, fn_field, fn_method, fn_method);
+          }
+          builder.EndClassDef(boot_class.GetClassDefIndex());
+        });
+        dex_editor.Add(input_dex, std::move(builder.GetData()));
+      }
       dex_editor.WriteTo(output_path);
     }