| /* |
| * Copyright (C) 2019 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #define LOG_TAG "neuralnetworks_hidl_hal_test" |
| |
| #include "VtsHalNeuralnetworks.h" |
| |
| #include "1.2/Callbacks.h" |
| #include "ExecutionBurstController.h" |
| #include "ExecutionBurstServer.h" |
| #include "GeneratedTestHarness.h" |
| #include "TestHarness.h" |
| #include "Utils.h" |
| |
| #include <android-base/logging.h> |
| |
| namespace android::hardware::neuralnetworks::V1_2::vts::functional { |
| |
| using nn::ExecutionBurstController; |
| using nn::RequestChannelSender; |
| using nn::ResultChannelReceiver; |
| using V1_0::ErrorStatus; |
| using V1_0::Request; |
| using ExecutionBurstCallback = ExecutionBurstController::ExecutionBurstCallback; |
| |
| // This constant value represents the length of an FMQ that is large enough to |
| // return a result from a burst execution for all of the generated test cases. |
| constexpr size_t kExecutionBurstChannelLength = 1024; |
| |
| // This constant value represents a length of an FMQ that is not large enough |
| // to return a result from a burst execution for some of the generated test |
| // cases. |
| constexpr size_t kExecutionBurstChannelSmallLength = 8; |
| |
| ///////////////////////// UTILITY FUNCTIONS ///////////////////////// |
| |
| static bool badTiming(Timing timing) { |
| return timing.timeOnDevice == UINT64_MAX && timing.timeInDriver == UINT64_MAX; |
| } |
| |
| static void createBurst(const sp<IPreparedModel>& preparedModel, const sp<IBurstCallback>& callback, |
| std::unique_ptr<RequestChannelSender>* sender, |
| std::unique_ptr<ResultChannelReceiver>* receiver, |
| sp<IBurstContext>* context, |
| size_t resultChannelLength = kExecutionBurstChannelLength) { |
| ASSERT_NE(nullptr, preparedModel.get()); |
| ASSERT_NE(nullptr, sender); |
| ASSERT_NE(nullptr, receiver); |
| ASSERT_NE(nullptr, context); |
| |
| // create FMQ objects |
| auto [fmqRequestChannel, fmqRequestDescriptor] = |
| RequestChannelSender::create(kExecutionBurstChannelLength, /*blocking=*/true); |
| auto [fmqResultChannel, fmqResultDescriptor] = |
| ResultChannelReceiver::create(resultChannelLength, /*blocking=*/true); |
| ASSERT_NE(nullptr, fmqRequestChannel.get()); |
| ASSERT_NE(nullptr, fmqResultChannel.get()); |
| ASSERT_NE(nullptr, fmqRequestDescriptor); |
| ASSERT_NE(nullptr, fmqResultDescriptor); |
| |
| // configure burst |
| ErrorStatus errorStatus; |
| sp<IBurstContext> burstContext; |
| const Return<void> ret = preparedModel->configureExecutionBurst( |
| callback, *fmqRequestDescriptor, *fmqResultDescriptor, |
| [&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) { |
| errorStatus = status; |
| burstContext = context; |
| }); |
| ASSERT_TRUE(ret.isOk()); |
| ASSERT_EQ(ErrorStatus::NONE, errorStatus); |
| ASSERT_NE(nullptr, burstContext.get()); |
| |
| // return values |
| *sender = std::move(fmqRequestChannel); |
| *receiver = std::move(fmqResultChannel); |
| *context = burstContext; |
| } |
| |
| static void createBurstWithResultChannelLength( |
| const sp<IPreparedModel>& preparedModel, size_t resultChannelLength, |
| std::shared_ptr<ExecutionBurstController>* controller) { |
| ASSERT_NE(nullptr, preparedModel.get()); |
| ASSERT_NE(nullptr, controller); |
| |
| // create FMQ objects |
| std::unique_ptr<RequestChannelSender> sender; |
| std::unique_ptr<ResultChannelReceiver> receiver; |
| sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback(); |
| sp<IBurstContext> context; |
| ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context, |
| resultChannelLength)); |
| ASSERT_NE(nullptr, sender.get()); |
| ASSERT_NE(nullptr, receiver.get()); |
| ASSERT_NE(nullptr, context.get()); |
| |
| // return values |
| *controller = std::make_shared<ExecutionBurstController>(std::move(sender), std::move(receiver), |
| context, callback); |
| } |
| |
| // Primary validation function. This function will take a valid serialized |
| // request, apply a mutation to it to invalidate the serialized request, then |
| // pass it to interface calls that use the serialized request. Note that the |
| // serialized request here is passed by value, and any mutation to the |
| // serialized request does not leave this function. |
| static void validate(RequestChannelSender* sender, ResultChannelReceiver* receiver, |
| const std::string& message, std::vector<FmqRequestDatum> serialized, |
| const std::function<void(std::vector<FmqRequestDatum>*)>& mutation) { |
| mutation(&serialized); |
| |
| // skip if packet is too large to send |
| if (serialized.size() > kExecutionBurstChannelLength) { |
| return; |
| } |
| |
| SCOPED_TRACE(message); |
| |
| // send invalid packet |
| ASSERT_TRUE(sender->sendPacket(serialized)); |
| |
| // receive error |
| auto results = receiver->getBlocking(); |
| ASSERT_TRUE(results.has_value()); |
| const auto [status, outputShapes, timing] = std::move(*results); |
| EXPECT_NE(ErrorStatus::NONE, status); |
| EXPECT_EQ(0u, outputShapes.size()); |
| EXPECT_TRUE(badTiming(timing)); |
| } |
| |
| // For validation, valid packet entries are mutated to invalid packet entries, |
| // or invalid packet entries are inserted into valid packets. This function |
| // creates pre-set invalid packet entries for convenience. |
| static std::vector<FmqRequestDatum> createBadRequestPacketEntries() { |
| const FmqRequestDatum::PacketInformation packetInformation = { |
| /*.packetSize=*/10, /*.numberOfInputOperands=*/10, /*.numberOfOutputOperands=*/10, |
| /*.numberOfPools=*/10}; |
| const FmqRequestDatum::OperandInformation operandInformation = { |
| /*.hasNoValue=*/false, /*.location=*/{}, /*.numberOfDimensions=*/10}; |
| const int32_t invalidPoolIdentifier = std::numeric_limits<int32_t>::max(); |
| std::vector<FmqRequestDatum> bad(7); |
| bad[0].packetInformation(packetInformation); |
| bad[1].inputOperandInformation(operandInformation); |
| bad[2].inputOperandDimensionValue(0); |
| bad[3].outputOperandInformation(operandInformation); |
| bad[4].outputOperandDimensionValue(0); |
| bad[5].poolIdentifier(invalidPoolIdentifier); |
| bad[6].measureTiming(MeasureTiming::YES); |
| return bad; |
| } |
| |
| // For validation, valid packet entries are mutated to invalid packet entries, |
| // or invalid packet entries are inserted into valid packets. This function |
| // retrieves pre-set invalid packet entries for convenience. This function |
| // caches these data so they can be reused on subsequent validation checks. |
| static const std::vector<FmqRequestDatum>& getBadRequestPacketEntries() { |
| static const std::vector<FmqRequestDatum> bad = createBadRequestPacketEntries(); |
| return bad; |
| } |
| |
| ///////////////////////// REMOVE DATUM //////////////////////////////////// |
| |
| static void removeDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver, |
| const std::vector<FmqRequestDatum>& serialized) { |
| for (size_t index = 0; index < serialized.size(); ++index) { |
| const std::string message = "removeDatum: removed datum at index " + std::to_string(index); |
| validate(sender, receiver, message, serialized, |
| [index](std::vector<FmqRequestDatum>* serialized) { |
| serialized->erase(serialized->begin() + index); |
| }); |
| } |
| } |
| |
| ///////////////////////// ADD DATUM //////////////////////////////////// |
| |
| static void addDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver, |
| const std::vector<FmqRequestDatum>& serialized) { |
| const std::vector<FmqRequestDatum>& extra = getBadRequestPacketEntries(); |
| for (size_t index = 0; index <= serialized.size(); ++index) { |
| for (size_t type = 0; type < extra.size(); ++type) { |
| const std::string message = "addDatum: added datum type " + std::to_string(type) + |
| " at index " + std::to_string(index); |
| validate(sender, receiver, message, serialized, |
| [index, type, &extra](std::vector<FmqRequestDatum>* serialized) { |
| serialized->insert(serialized->begin() + index, extra[type]); |
| }); |
| } |
| } |
| } |
| |
| ///////////////////////// MUTATE DATUM //////////////////////////////////// |
| |
| static bool interestingCase(const FmqRequestDatum& lhs, const FmqRequestDatum& rhs) { |
| using Discriminator = FmqRequestDatum::hidl_discriminator; |
| |
| const bool differentValues = (lhs != rhs); |
| const bool sameDiscriminator = (lhs.getDiscriminator() == rhs.getDiscriminator()); |
| const auto discriminator = rhs.getDiscriminator(); |
| const bool isDimensionValue = (discriminator == Discriminator::inputOperandDimensionValue || |
| discriminator == Discriminator::outputOperandDimensionValue); |
| |
| return differentValues && !(sameDiscriminator && isDimensionValue); |
| } |
| |
| static void mutateDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver, |
| const std::vector<FmqRequestDatum>& serialized) { |
| const std::vector<FmqRequestDatum>& change = getBadRequestPacketEntries(); |
| for (size_t index = 0; index < serialized.size(); ++index) { |
| for (size_t type = 0; type < change.size(); ++type) { |
| if (interestingCase(serialized[index], change[type])) { |
| const std::string message = "mutateDatum: changed datum at index " + |
| std::to_string(index) + " to datum type " + |
| std::to_string(type); |
| validate(sender, receiver, message, serialized, |
| [index, type, &change](std::vector<FmqRequestDatum>* serialized) { |
| (*serialized)[index] = change[type]; |
| }); |
| } |
| } |
| } |
| } |
| |
| ///////////////////////// BURST VALIATION TESTS //////////////////////////////////// |
| |
| static void validateBurstSerialization(const sp<IPreparedModel>& preparedModel, |
| const Request& request) { |
| // create burst |
| std::unique_ptr<RequestChannelSender> sender; |
| std::unique_ptr<ResultChannelReceiver> receiver; |
| sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback(); |
| sp<IBurstContext> context; |
| ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context)); |
| ASSERT_NE(nullptr, sender.get()); |
| ASSERT_NE(nullptr, receiver.get()); |
| ASSERT_NE(nullptr, context.get()); |
| |
| // load memory into callback slots |
| std::vector<intptr_t> keys; |
| keys.reserve(request.pools.size()); |
| std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys), |
| [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); }); |
| const std::vector<int32_t> slots = callback->getSlots(request.pools, keys); |
| |
| // ensure slot std::numeric_limits<int32_t>::max() doesn't exist (for |
| // subsequent slot validation testing) |
| ASSERT_TRUE(std::all_of(slots.begin(), slots.end(), [](int32_t slot) { |
| return slot != std::numeric_limits<int32_t>::max(); |
| })); |
| |
| // serialize the request |
| const auto serialized = ::android::nn::serialize(request, MeasureTiming::YES, slots); |
| |
| // validations |
| removeDatumTest(sender.get(), receiver.get(), serialized); |
| addDatumTest(sender.get(), receiver.get(), serialized); |
| mutateDatumTest(sender.get(), receiver.get(), serialized); |
| } |
| |
| // This test validates that when the Result message size exceeds length of the |
| // result FMQ, the service instance gracefully fails and returns an error. |
| static void validateBurstFmqLength(const sp<IPreparedModel>& preparedModel, |
| const Request& request) { |
| // create regular burst |
| std::shared_ptr<ExecutionBurstController> controllerRegular; |
| ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength( |
| preparedModel, kExecutionBurstChannelLength, &controllerRegular)); |
| ASSERT_NE(nullptr, controllerRegular.get()); |
| |
| // create burst with small output channel |
| std::shared_ptr<ExecutionBurstController> controllerSmall; |
| ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength( |
| preparedModel, kExecutionBurstChannelSmallLength, &controllerSmall)); |
| ASSERT_NE(nullptr, controllerSmall.get()); |
| |
| // load memory into callback slots |
| std::vector<intptr_t> keys(request.pools.size()); |
| for (size_t i = 0; i < keys.size(); ++i) { |
| keys[i] = reinterpret_cast<intptr_t>(&request.pools[i]); |
| } |
| |
| // collect serialized result by running regular burst |
| const auto [statusRegular, outputShapesRegular, timingRegular] = |
| controllerRegular->compute(request, MeasureTiming::NO, keys); |
| |
| // skip test if regular burst output isn't useful for testing a failure |
| // caused by having too small of a length for the result FMQ |
| const std::vector<FmqResultDatum> serialized = |
| ::android::nn::serialize(statusRegular, outputShapesRegular, timingRegular); |
| if (statusRegular != ErrorStatus::NONE || |
| serialized.size() <= kExecutionBurstChannelSmallLength) { |
| return; |
| } |
| |
| // by this point, execution should fail because the result channel isn't |
| // large enough to return the serialized result |
| const auto [statusSmall, outputShapesSmall, timingSmall] = |
| controllerSmall->compute(request, MeasureTiming::NO, keys); |
| EXPECT_NE(ErrorStatus::NONE, statusSmall); |
| EXPECT_EQ(0u, outputShapesSmall.size()); |
| EXPECT_TRUE(badTiming(timingSmall)); |
| } |
| |
| ///////////////////////////// ENTRY POINT ////////////////////////////////// |
| |
| void ValidationTest::validateBurst(const sp<IPreparedModel>& preparedModel, |
| const Request& request) { |
| ASSERT_NO_FATAL_FAILURE(validateBurstSerialization(preparedModel, request)); |
| ASSERT_NO_FATAL_FAILURE(validateBurstFmqLength(preparedModel, request)); |
| } |
| |
| } // namespace android::hardware::neuralnetworks::V1_2::vts::functional |