From aa5e36456a5720fc341f59ca3fbd913c10f501e7 Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Tue, 12 Sep 2023 07:37:45 -0700 Subject: [PATCH 001/121] [TRT EP] Fix multithreading bug of getting the corrupted trt engine instance (#17507) Revert to the old TRT EP behavior of securing the whole compute_function with a lock_guard. The current TRT EP only puts a lock_guard around one critical section inside compute_function, which is obviously wrong. The issue can happen when one thread is updating the engine in compute_function while another thread still accesses the stale/corrupted engine instance through code outside that critical section, for example `int total_bindings = trt_engine->getNbBindings()`. So making the whole compute_function the critical section should be okay (a minimal sketch of this locking pattern appears after the first hunk below). --- .../tensorrt/tensorrt_execution_provider.cc | 402 +++++++++--------- 1 file changed, 200 insertions(+), 202 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index e90417a6d14fc..96893f63b4540 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2433,6 +2433,11 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector(state); + + // The whole compute_function should be considered the critical section where multiple threads may update kernel function state, access one builder, create/serialize/save engine, + // save profile and serialize/save timing cache. Therefore, those operations should be synchronized across different threads when ORT is using multithreading. + // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); const std::unordered_map& input_indexes = (trt_state->input_info)[0]; const std::unordered_map& output_indexes = (trt_state->output_info)[0]; const std::unordered_map& output_types = (trt_state->output_info)[1]; @@ -2475,237 +2480,230 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector lock(*(trt_state->tensorrt_mu_ptr)); - - // Load serialized engine - if (trt_state->engine_cache_enable && trt_engine == nullptr) { - std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); - std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in); - if (engine_file && !trt_state->engine_decryption_enable && profile_file) { - // Deserialize profile - shape_ranges = DeserializeProfileV2(profile_file); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; - - // Prepare buffer - engine_file.seekg(0, std::ios::end); - size_t engine_size = engine_file.tellg(); - engine_file.seekg(0, std::ios::beg); - std::unique_ptr engine_buf{new char[engine_size]}; - engine_file.read((char*)engine_buf.get(), engine_size); - - // Deserialize engine - // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc - // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - trt_state->engine->reset(); - *(trt_state->engine) = std::unique_ptr( - trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); - if (*(trt_state->engine) == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); - } - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
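// A minimal, self-contained sketch of the locking pattern this patch restores:
// one mutex guards the entire compute function, so no thread can ever observe a
// half-rebuilt engine. Engine/State/compute below are simplified, hypothetical
// stand-ins for illustration, not the real TRT EP types.
#include <functional>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

struct Engine {
  int num_bindings = 2;
};

struct State {
  std::mutex mu;                   // plays the role of *trt_state->tensorrt_mu_ptr
  std::unique_ptr<Engine> engine;  // shared engine that may be rebuilt at runtime
};

void compute(State& s, bool needs_rebuild) {
  // The whole function body is the critical section, as in this patch.
  std::lock_guard<std::mutex> lock(s.mu);
  if (needs_rebuild || !s.engine) {
    s.engine = std::make_unique<Engine>();  // "regenerate engine"
  }
  // Before the fix, a read like this could run outside the lock and race with
  // another thread replacing the engine above.
  int total_bindings = s.engine->num_bindings;
  (void)total_bindings;
}

int main() {
  State s;
  std::vector<std::thread> workers;
  for (int i = 0; i < 4; ++i) {
    workers.emplace_back(compute, std::ref(s), i % 2 == 0);
  }
  for (auto& t : workers) t.join();
  return 0;
}
// Coarsening the lock trades some parallelism for correctness: per the TRT
// threading documentation referenced above, builder/engine mutation is not
// thread safe, so the serialized section must cover every access to the
// shared engine.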
- trt_engine = trt_state->engine->get(); - context_update = true; - } else if (trt_state->engine_decryption_enable && std::filesystem::exists(encrypted_engine_cache_path) && profile_file) { - shape_ranges = DeserializeProfileV2(profile_file); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; - // Decrypt engine - size_t engine_size = 0; - if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not get engine buffer size"); - } - std::unique_ptr engine_buf{new char[engine_size]}; - if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not call engine decryption function decrypt"); - } - // Deserialize engine - // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc - // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - trt_state->engine->reset(); - *(trt_state->engine) = std::unique_ptr(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); - if (*(trt_state->engine) == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path); - } - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; - trt_engine = trt_state->engine->get(); - context_update = true; + // Load serialized engine + if (trt_state->engine_cache_enable && trt_engine == nullptr) { + std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); + std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in); + if (engine_file && !trt_state->engine_decryption_enable && profile_file) { + // Deserialize profile + shape_ranges = DeserializeProfileV2(profile_file); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; + + // Prepare buffer + engine_file.seekg(0, std::ios::end); + size_t engine_size = engine_file.tellg(); + engine_file.seekg(0, std::ios::beg); + std::unique_ptr engine_buf{new char[engine_size]}; + engine_file.read((char*)engine_buf.get(), engine_size); + + // Deserialize engine + // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + trt_state->engine->reset(); + *(trt_state->engine) = std::unique_ptr( + trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); + if (*(trt_state->engine) == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); + } + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; + trt_engine = trt_state->engine->get(); + context_update = true; + } else if (trt_state->engine_decryption_enable && std::filesystem::exists(encrypted_engine_cache_path) && profile_file) { + shape_ranges = DeserializeProfileV2(profile_file); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; + // Decrypt engine + size_t engine_size = 0; + if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not get engine buffer size"); + } + std::unique_ptr engine_buf{new char[engine_size]}; + if 
(!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not call engine decryption function decrypt"); + } + // Deserialize engine + // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + trt_state->engine->reset(); + *(trt_state->engine) = std::unique_ptr(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); + if (*(trt_state->engine) == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path); } + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; + trt_engine = trt_state->engine->get(); + context_update = true; } + } - // Check and update shape ranges for dynamic shape inputs. - for (int i = 0, end = num_inputs; i < end; ++i) { - auto input = trt_state->network->get()->getInput(i); - const std::string& input_name = input->getName(); - input_names.insert(input_name); + // Check and update shape ranges for dynamic shape inputs. + for (int i = 0, end = num_inputs; i < end; ++i) { + auto input = trt_state->network->get()->getInput(i); + const std::string& input_name = input->getName(); + input_names.insert(input_name); - // If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile shape values have not yet resolved. - // TRT EP will help determine the min/max/opt profile values based on current input tensor value. - if (shape_ranges.find(input_name) != shape_ranges.end()) { - auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, tensor_shape_values, stream, &engine_update); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to parse input tensor and generate optimization profiles."); - } + // If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile shape values have not yet resolved. + // TRT EP will help determine the min/max/opt profile values based on current input tensor value. + if (shape_ranges.find(input_name) != shape_ranges.end()) { + auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, tensor_shape_values, stream, &engine_update); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to parse input tensor and generate optimization profiles."); } } + } - // Regenerate engine - if (engine_update) { - // Destroy the IExecutionContext objects before destroying an engine object, otherwise it will lead to undefined behavior. - if (GetPerThreadContext().IsTensorRTContextInMap(fused_node_name)) { - GetPerThreadContext().ResetTensorRTContext(fused_node_name); - } + // Regenerate engine + if (engine_update) { + // Destroy the IExecutionContext objects before destroying an engine object, otherwise it will lead to undefined behavior. 
+ if (GetPerThreadContext().IsTensorRTContextInMap(fused_node_name)) { + GetPerThreadContext().ResetTensorRTContext(fused_node_name); + } - trt_state->engine->reset(); - auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); - trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr)); - for (auto trt_profile : trt_profiles) { - trt_config->addOptimizationProfile(trt_profile); - } + trt_state->engine->reset(); + auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); + trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr)); + for (auto trt_profile : trt_profiles) { + trt_config->addOptimizationProfile(trt_profile); + } - // Set INT8 Per Tensor Dynamic range - if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) { - trt_config->setInt8Calibrator(nullptr); - if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range."); - } + // Set INT8 Per Tensor Dynamic range + if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) { + trt_config->setInt8Calibrator(nullptr); + if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range."); } + } - // Set precision - if (trt_state->fp16_enable && trt_state->int8_enable) { - trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); - } else if (trt_state->fp16_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); - } else if (trt_state->int8_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); - } + // Set precision + if (trt_state->fp16_enable && trt_state->int8_enable) { + trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); + } else if (trt_state->fp16_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); + } else if (trt_state->int8_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); + } - // Set DLA (DLA can only run with FP16 or INT8) - if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) { - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core; - trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); - trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); - trt_config->setDLACore(trt_state->dla_core); - } + // Set DLA (DLA can only run with FP16 or INT8) + if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core; + trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); + trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); + trt_config->setDLACore(trt_state->dla_core); + } - // enable sparse weights - if (trt_state->sparsity_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; - } + // enable sparse weights + if (trt_state->sparsity_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; + } - // enable builder heuristics - if (trt_state->build_heuristics_enable) { - 
trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled"; - } + // enable builder heuristics + if (trt_state->build_heuristics_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled"; + } #if NV_TENSORRT_MINOR > 5 && NV_TENSORRT_MAJOR >= 8 - // switch optimizaion level - if (trt_state->builder_optimization_level != 3) { - trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; - } + // switch optimizaion level + if (trt_state->builder_optimization_level != 3) { + trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; + } - // limit auxiliary streams - if (trt_state->auxiliary_streams >= 0) { - trt_config->setMaxAuxStreams(trt_state->auxiliary_streams); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << trt_state->auxiliary_streams; - } + // limit auxiliary streams + if (trt_state->auxiliary_streams >= 0) { + trt_config->setMaxAuxStreams(trt_state->auxiliary_streams); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << trt_state->auxiliary_streams; + } #else - if (trt_state->builder_optimization_level != 3) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; - } - if (trt_state->auxiliary_streams >= 0) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; - } + if (trt_state->builder_optimization_level != 3) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; + } + if (trt_state->auxiliary_streams >= 0) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; + } #endif - // limit used tactic sources - if (trt_state->filter_tactic_sources) { - nvinfer1::TacticSources tactics = trt_config->getTacticSources(); - tactics |= trt_state->tactic_sources; - trt_config->setTacticSources(tactics); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using bitmask " << tactics; + // limit used tactic sources + if (trt_state->filter_tactic_sources) { + nvinfer1::TacticSources tactics = trt_config->getTacticSources(); + tactics |= trt_state->tactic_sources; + trt_config->setTacticSources(tactics); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using bitmask " << tactics; + } + + // Load timing cache from file. Create a fresh cache if the file doesn't exist + std::unique_ptr timing_cache = nullptr; + if (trt_state->timing_cache_enable) { + std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); + timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), loaded_timing_cache.size())); + if (timing_cache == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not create timing cache: " + timing_cache_path); } - - // Load timing cache from file. 
Create a fresh cache if the file doesn't exist - std::unique_ptr timing_cache = nullptr; - if (trt_state->timing_cache_enable) { - std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); - timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), loaded_timing_cache.size())); - if (timing_cache == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not create timing cache: " + timing_cache_path); - } - trt_config->setTimingCache(*timing_cache, force_timing_cache_match_); - if (detailed_build_log_) { - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; - } + trt_config->setTimingCache(*timing_cache, force_timing_cache_match_); + if (detailed_build_log_) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; } + } - // Build engine - { - auto lock = GetApiLock(); - std::chrono::steady_clock::time_point engine_build_start; - if (detailed_build_log_) { - engine_build_start = std::chrono::steady_clock::now(); - } - *(trt_state->engine) = std::unique_ptr( - trt_builder->buildEngineWithConfig(*trt_state->network->get(), *trt_config)); - if (detailed_build_log_) { - auto engine_build_stop = std::chrono::steady_clock::now(); - LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; - } - } - if (*(trt_state->engine) == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); + // Build engine + { + auto lock = GetApiLock(); + std::chrono::steady_clock::time_point engine_build_start; + if (detailed_build_log_) { + engine_build_start = std::chrono::steady_clock::now(); } - trt_engine = trt_state->engine->get(); - if (trt_state->engine_cache_enable) { - // Serialize engine profile - SerializeProfileV2(profile_cache_path, shape_ranges); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; - - // Serialize engine - std::unique_ptr serializedModel(trt_engine->serialize()); - size_t engine_size = serializedModel->size(); - if (trt_state->engine_decryption_enable) { - // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first. - if (trt_state->engine_encryption != nullptr) { - if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), reinterpret_cast(serializedModel->data()), engine_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not call engine encryption function encrypt"); - } - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; - } else { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. 
No cache is written to disk"; + *(trt_state->engine) = std::unique_ptr( + trt_builder->buildEngineWithConfig(*trt_state->network->get(), *trt_config)); + if (detailed_build_log_) { + auto engine_build_stop = std::chrono::steady_clock::now(); + LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; + } + } + if (*(trt_state->engine) == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); + } + trt_engine = trt_state->engine->get(); + if (trt_state->engine_cache_enable) { + // Serialize engine profile + SerializeProfileV2(profile_cache_path, shape_ranges); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; + + // Serialize engine + std::unique_ptr serializedModel(trt_engine->serialize()); + size_t engine_size = serializedModel->size(); + if (trt_state->engine_decryption_enable) { + // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first. + if (trt_state->engine_encryption != nullptr) { + if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), reinterpret_cast(serializedModel->data()), engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not call engine encryption function encrypt"); } + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; } else { - std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); - file.write(reinterpret_cast(serializedModel->data()), engine_size); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path; + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. 
No cache is written to disk"; } + } else { + std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); + file.write(reinterpret_cast(serializedModel->data()), engine_size); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path; } + } - // serialize and save timing cache - if (trt_state->timing_cache_enable) { - auto timing_cache = trt_config->getTimingCache(); - std::unique_ptr timingCacheHostData{timing_cache->serialize()}; - if (timingCacheHostData == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not serialize timing cache: " + timing_cache_path); - } - saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); - if (detailed_build_log_) { - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; - } + // serialize and save timing cache + if (trt_state->timing_cache_enable) { + auto timing_cache = trt_config->getTimingCache(); + std::unique_ptr timingCacheHostData{timing_cache->serialize()}; + if (timingCacheHostData == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not serialize timing cache: " + timing_cache_path); + } + saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); + if (detailed_build_log_) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; } - context_update = true; } + context_update = true; } // Build execution context if either of the following conditions is true: From cf672c5887b4e5991b022df38fbb8d61f29fe420 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Tue, 12 Sep 2023 10:56:35 -0700 Subject: [PATCH 002/121] Use name of temporary provisioning profile. (#17459) The old provisioning profile no longer works. Switched to a temporary one that we can use until a new one is available. The temporary one has a different name. --- .../templates/stages/mac-ios-packaging-build-stage.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml index 2484facfae33e..81f17a26b16a6 100644 --- a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml @@ -116,7 +116,8 @@ stages: xcodeDeveloperDir: '/Applications/Xcode_${{ variables.xcodeVersion }}.app/Contents/Developer' signingOption: 'manual' signingIdentity: '$(APPLE_CERTIFICATE_SIGNING_IDENTITY)' - provisioningProfileName: 'iOS Team Provisioning Profile' + provisioningProfileName: 'temporary *' # temporary name, change it back to the original below later + #provisioningProfileName: 'iOS Team Provisioning Profile' args: '-derivedDataPath $(Build.BinariesDirectory)/app_center_test/ios_package_test/DerivedData' workingDirectory: '$(Build.BinariesDirectory)/app_center_test/ios_package_test/' useXcpretty: false # xcpretty can hide useful error output so we will disable it From 49511b548371b2af4d9c2d186204dd85dc0a1a7e Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 12 Sep 2023 11:38:36 -0700 Subject: [PATCH 003/121] Improve performance of prune_graph in onnx_model.py (#17502) During optimization of the SDXL UNet, prune_graph takes up to 5 minutes. The cause is that finding a node among all nodes is time-consuming. This optimization reduces the latency of prune_graph to 2 seconds. The new algorithm uses a hash table (key: the node's first output, value: the node) to speed up the lookup, as sketched below.
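To make the algorithm concrete before the diff, here is a small self-contained sketch of the same technique: a backward worklist traversal from the graph outputs, with reached nodes tracked in a hash table keyed by each node's first output so the keep/remove test is O(1) instead of a linear scan. It is written in C++ with hypothetical Node/Prune names purely for illustration; the actual change below is in Python.

#include <deque>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical minimal node: named inputs and outputs, like an ONNX NodeProto.
struct Node {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Keep only nodes reachable (walking backwards) from graph_outputs.
std::vector<Node> Prune(const std::vector<Node>& nodes,
                        const std::vector<std::string>& graph_outputs) {
  // Index producers by every output name, analogous to output_name_to_node.
  std::unordered_map<std::string, const Node*> producer;
  for (const auto& n : nodes)
    for (const auto& o : n.outputs) producer[o] = &n;

  // Worklist traversal from the outputs; `kept` is keyed by each node's
  // first output, which makes the membership test a hash lookup.
  std::unordered_map<std::string, const Node*> kept;
  std::deque<const Node*> worklist;
  for (const auto& o : graph_outputs)
    if (auto it = producer.find(o); it != producer.end()) worklist.push_back(it->second);
  while (!worklist.empty()) {
    const Node* n = worklist.front();
    worklist.pop_front();
    if (n->outputs.empty() || kept.count(n->outputs[0])) continue;
    kept[n->outputs[0]] = n;
    for (const auto& in : n->inputs)
      if (auto it = producer.find(in); it != producer.end()) worklist.push_back(it->second);
  }

  // Preserve the original node order while dropping unreached nodes.
  std::vector<Node> result;
  for (const auto& n : nodes)
    if (!n.outputs.empty() && kept.count(n.outputs[0])) result.push_back(n);
  return result;
}

Note one extra guard in the real implementation: a fused node may reuse the output name of a node being removed, so before keeping a node it double-checks op_type first and then full node equality.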
--- .../python/tools/transformers/onnx_model.py | 78 ++++++++++++------- 1 file changed, 52 insertions(+), 26 deletions(-) diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 8c836db7b9ef6..60be2d84b2bc8 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -816,51 +816,77 @@ def prune_graph(self, outputs=None, allow_remove_graph_inputs=True): """ if len(self.graphs()) > 1: + # TODO(tianleiwu): handle subgraph logger.debug("Skip prune_graph since graph has subgraph") return - if outputs is None: - outputs = [output.name for output in self.model.graph.output] + keep_outputs = [output.name for output in self.model.graph.output] if outputs is None else outputs output_name_to_node = self.output_name_to_node() - all_nodes = [] - for output in outputs: - if output in output_name_to_node: - last_node = output_name_to_node[output] - if last_node in all_nodes: - continue - nodes = self.get_parent_subgraph_nodes(last_node, []) - all_nodes.append(last_node) - all_nodes.extend(nodes) - nodes_to_remove = [node for node in self.model.graph.node if node not in all_nodes] + def get_first_output(node): + if node.output[0]: + return node.output[0] + return next(iter([o for o in node.output if o]), None) - self.remove_nodes(nodes_to_remove) + # Keep track of nodes to keep. The key is first output of node, and the value is the node. + output_to_node = {} - # remove outputs not in list - output_to_remove = [] - for output in self.model.graph.output: - if output.name not in outputs: - output_to_remove.append(output) - for output in output_to_remove: - self.model.graph.output.remove(output) + # Start from graph outputs, and find parent nodes recursively, and add nodes to the output_to_node dictionary. + dq = deque() + for output in keep_outputs: + if output in output_name_to_node: + dq.append(output_name_to_node[output]) + while len(dq) > 0: + node = dq.pop() + first_output = get_first_output(node) + if first_output and (first_output not in output_to_node): + output_to_node[first_output] = node + for name in node.input: + if len(name) > 0 and (name in output_name_to_node) and (name not in output_to_node): + dq.appendleft(output_name_to_node[name]) + + # Keep only those nodes in the output_to_node dictionary. + nodes_to_keep = [] + num_nodes_removed = 0 + for node in self.model.graph.node: + first_output = get_first_output(node) + kept_node = output_to_node[first_output] if first_output in output_to_node else None - # remove inputs not used by any node. + # Need double check the node since fused node might reuse output name of some nodes to be removed. + # It is slow to compare whole node, so we compare op_type first to avoid comparing node in most cases. + if kept_node and kept_node.op_type == node.op_type and kept_node == node: + nodes_to_keep.append(node) + else: + num_nodes_removed += 1 + self.model.graph.ClearField("node") + self.model.graph.node.extend(nodes_to_keep) + + # Remove graph outputs not in list + output_to_remove = [] + if outputs is not None: + for output in self.model.graph.output: + if output.name not in outputs: + output_to_remove.append(output) + for output in output_to_remove: + self.model.graph.output.remove(output) + + # Remove graph inputs not used by any node.
input_to_remove = [] if allow_remove_graph_inputs: input_name_to_nodes = self.input_name_to_nodes() input_to_remove = [input for input in self.model.graph.input if input.name not in input_name_to_nodes] - for input in input_to_remove: - self.model.graph.input.remove(input) + for name in input_to_remove: + self.model.graph.input.remove(name) - if input_to_remove or output_to_remove or nodes_to_remove: + if input_to_remove or output_to_remove or num_nodes_removed > 0: removed = [] if input_to_remove: removed.append(f"{len(input_to_remove)} inputs") if output_to_remove: removed.append(f"{len(output_to_remove)} outputs") - if nodes_to_remove: - removed.append(f"{len(nodes_to_remove)} nodes") + if num_nodes_removed > 0: + removed.append(f"{num_nodes_removed} nodes") logger.info("Removed %s", ", ".join(removed)) self.update_graph() From f923eec28b6b4237d6fc251042893deab335d97e Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 12 Sep 2023 16:59:13 -0700 Subject: [PATCH 004/121] [js/web] release session after use in npm test (#17470) ### Description Release the session after use in npm test. This is one of the prerequisites for supporting IO binding for WebGPU buffers in onnxruntime-web. List of prerequisite PRs: #17465 #17469 #17470 (this one) --- js/web/test/test-main.ts | 4 ++-- js/web/test/test-runner.ts | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/js/web/test/test-main.ts b/js/web/test/test-main.ts index e614cc8e67e71..49d0ac225be2f 100644 --- a/js/web/test/test-main.ts +++ b/js/web/test/test-main.ts @@ -110,9 +110,9 @@ for (const group of ORT_WEB_TEST_CONFIG.model) { test, ORT_WEB_TEST_CONFIG.profile, ORT_WEB_TEST_CONFIG.options.sessionOptions); }); - after('release session', () => { + after('release session', async () => { if (context) { - context.release(); + await context.release(); } }); diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 9802f00f7a866..46d80a9f56f35 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -210,11 +210,12 @@ export class ModelTestContext { Logger.verbose('TestRunner.Perf', '***Perf Data End'); } - release(): void { + async release(): Promise { if (this.profile) { this.session.endProfiling(); } this.logPerfData(); + await this.session.release(); } /** From 9b755dce9f6b39296a716028f52347de037daa2e Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Tue, 12 Sep 2023 17:40:49 -0700 Subject: [PATCH 005/121] Delete all Prefast tasks (#17522) ### Description Delete all Prefast tasks because the new VS 17.7 version crashes every time we run the task on our CI build servers. However, we cannot reproduce it locally, and this problem blocks us from installing security patches on our CI build machines. We will use [CodeQL](https://codeql.github.com/) instead. ### Motivation and Context Addresses some security alerts.
--- .github/workflows/sca.yml | 133 ------------------ .../azure-pipelines/post-merge-jobs.yml | 6 - .../azure-pipelines/templates/compliance.yml | 21 --- .../templates/jobs/win-ci-vs-2022-job.yml | 48 ------- .../templates/py-packaging-stage.yml | 18 --- .../azure-pipelines/templates/py-win-gpu.yml | 59 -------- .../azure-pipelines/templates/win-ci.yml | 19 --- .../azure-pipelines/win-ci-pipeline.yml | 10 -- .../azure-pipelines/win-gpu-ci-pipeline.yml | 4 - 9 files changed, 318 deletions(-) delete mode 100644 .github/workflows/sca.yml diff --git a/.github/workflows/sca.yml b/.github/workflows/sca.yml deleted file mode 100644 index 1416f5a4d33a9..0000000000000 --- a/.github/workflows/sca.yml +++ /dev/null @@ -1,133 +0,0 @@ -name: Windows_SCA -on: - push: - branches: - - main - - rel-* - pull_request: - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -env: - AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 - -jobs: - Onnxruntime-SCA-training-CUDA: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] - steps: - - uses: actions/checkout@v3 - with: - submodules: false - - uses: actions/setup-python@v4 - with: - python-version: '3.11.x' - architecture: 'x64' - - - uses: actions/setup-node@v3 - with: - node-version: 18 - - - name: Download cuda - run: azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v11.8" cuda_sdk - - - - name: Delete build folder - run: | - if (Test-Path D:\b) { Remove-Item -Recurse -Force D:\b } - &tools\ci_build\github\windows\install_third_party_deps.ps1 -cpu_arch x64 -install_prefix D:\b\Debug\installed -build_config Debug - - # The build machine doesn't have a GPU. So the value of CMAKE_CUDA_ARCHITECTURES doesn't matter. 
- - name: Build code - env: - CAExcludePath: 'C:\Program Files;D:\b;${{ github.workspace }}\cmake' - run: python tools\ci_build\build.py --windows_sdk_version 10.0.22621.0 --enable_training --build_java --compile_no_warning_as_error --config Debug --build_dir D:\b --skip_submodule_sync --build_csharp --update --build --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_pybind --cmake_extra_defines onnxruntime_USE_CUSTOM_STATIC_ANALYSIS_RULES=ON --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON --cmake_extra_defines onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE=ON --use_cuda --cuda_home=${{ github.workspace }}\cuda_sdk\v11.8 --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75 - - - name: Generate sarif - working-directory: D:\b - run: npx @microsoft/sarif-multitool merge *.sarif --recurse --output-directory=${{ github.workspace }}\output --output-file=MergeResult.sarif --merge-runs && dir ${{ github.workspace }}\output - - - name: Upload SARIF to GitHub - uses: github/codeql-action/upload-sarif@v2 - continue-on-error: true - with: - sarif_file: ${{ github.workspace }}\output\MergeResult.sarif - category: VS_SCA - - # No python - Onnxruntime-SCA-win32-WINML-x64: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] - steps: - - uses: actions/checkout@v3 - with: - submodules: false - - uses: actions/setup-python@v4 - with: - python-version: '3.11.x' - architecture: 'x64' - - - uses: actions/setup-node@v3 - with: - node-version: 18 - - - name: Delete build folder - run: | - if (Test-Path D:\b) { Remove-Item -Recurse -Force D:\b } - &tools\ci_build\github\windows\install_third_party_deps.ps1 -cpu_arch x64 -install_prefix D:\b\Debug\installed -build_config Debug - - # The build machine doesn't have a GPU. So the value of CMAKE_CUDA_ARCHITECTURES doesn't matter. 
- - name: Build code - env: - CAExcludePath: 'C:\Program Files;D:\b;${{ github.workspace }}\cmake' - run: python tools\ci_build\build.py --build_java --compile_no_warning_as_error --config Debug --build_dir D:\b --skip_submodule_sync --build_csharp --update --build --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --cmake_extra_defines onnxruntime_USE_CUSTOM_STATIC_ANALYSIS_RULES=ON --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON --cmake_extra_defines onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE=ON --ms_experimental --use_dml --use_winml --disable_rtti --enable_wcos --build_shared_lib - - - name: Generate sarif - working-directory: D:\b - run: npx @microsoft/sarif-multitool merge *.sarif --recurse --output-directory=${{ github.workspace }}\output --output-file=MergeResult.sarif --merge-runs && dir ${{ github.workspace }}\output - - - name: Upload SARIF to GitHub - uses: github/codeql-action/upload-sarif@v2 - continue-on-error: true - with: - sarif_file: ${{ github.workspace }}\output\MergeResult.sarif - category: VS_SCA_WIN32_WINML_X64 - - # No java, No python - Onnxruntime-SCA-win32-WINML-x86: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] - steps: - - uses: actions/checkout@v3 - with: - submodules: false - - uses: actions/setup-python@v4 - with: - python-version: '3.11.x' - architecture: 'x86' - - - uses: actions/setup-node@v3 - with: - node-version: 18 - - - name: Delete build folder - run: | - if (Test-Path D:\b) { Remove-Item -Recurse -Force D:\b } - &tools\ci_build\github\windows\install_third_party_deps.ps1 -cpu_arch x86 -install_prefix D:\b\Debug\installed -build_config Debug - - # The build machine doesn't have a GPU. So the value of CMAKE_CUDA_ARCHITECTURES doesn't matter. - - name: Build code - env: - CAExcludePath: 'C:\Program Files;D:\b;${{ github.workspace }}\cmake' - run: python tools\ci_build\build.py --compile_no_warning_as_error --config Debug --build_dir D:\b --skip_submodule_sync --build_csharp --update --build --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --cmake_extra_defines onnxruntime_USE_CUSTOM_STATIC_ANALYSIS_RULES=ON --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON --cmake_extra_defines onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE=ON --ms_experimental --use_dml --use_winml --disable_rtti --enable_wcos --build_shared_lib - - - name: Generate sarif - working-directory: D:\b - run: npx @microsoft/sarif-multitool merge *.sarif --recurse --output-directory=${{ github.workspace }}\output --output-file=MergeResult.sarif --merge-runs && dir ${{ github.workspace }}\output - - - name: Upload SARIF to GitHub - uses: github/codeql-action/upload-sarif@v2 - continue-on-error: true - with: - sarif_file: ${{ github.workspace }}\output\MergeResult.sarif - category: VS_SCA_WIN32_WINML_X86 diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml index 113b24f7579ac..61f9b37d4ce78 100644 --- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml @@ -74,7 +74,6 @@ stages: isX86: false job_name_suffix: x64_RelWithDebInfo RunOnnxRuntimeTests: true - RunStaticCodeAnalysis: false ORT_EP_NAME: CUDA WITH_CACHE: true MachinePool: onnxruntime-Win2022-GPU-MultiA10 @@ -95,7 +94,6 @@ stages: isX86: false job_name_suffix: x64_RelWithDebInfo RunOnnxRuntimeTests: true - RunStaticCodeAnalysis: false ORT_EP_NAME: TRT WITH_CACHE: true MachinePool: 
onnxruntime-Win2022-GPU-MultiA10 @@ -114,7 +112,6 @@ stages: isX86: false job_name_suffix: x64_mimalloc RunOnnxRuntimeTests: true - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false @@ -134,7 +131,6 @@ stages: isX86: false job_name_suffix: x64_no_memory_profiling RunOnnxRuntimeTests: false - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false @@ -154,7 +150,6 @@ stages: isX86: false job_name_suffix: x64_minimal_no_exception RunOnnxRuntimeTests: true - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false @@ -174,7 +169,6 @@ stages: isX86: false job_name_suffix: x64_debug_node_input_output RunOnnxRuntimeTests: true - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false diff --git a/tools/ci_build/github/azure-pipelines/templates/compliance.yml b/tools/ci_build/github/azure-pipelines/templates/compliance.yml index f4bce8c53605b..cc451425be42a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/compliance.yml +++ b/tools/ci_build/github/azure-pipelines/templates/compliance.yml @@ -18,27 +18,6 @@ steps: AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll;-:file|$(Build.ArtifactStagingDirectory)\**\DirectML.dll' continueOnError: true -- task: DeleteFiles@1 - displayName: 'Delete files from $(Build.BinariesDirectory)\RelWithDebInfo' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' - Contents: | - **/*.obj - **/*.pdb - **/*.dll - -# Manually set msBuildCommandline so that we can also set CAExcludePath -- task: SDLNativeRules@3 - displayName: 'Run the PREfast SDL Native Rules for MSBuild' - inputs: - userProvideBuildInfo: msBuildInfo - msBuildArchitecture: x64 - msBuildVersion: 17.0 - msBuildCommandline: '"C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Current\Bin\amd64\msbuild.exe" "$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln" /p:platform="${{parameters.msbuildPlatform}}" /p:configuration="RelWithDebInfo" /p:CAExcludePath="$(Build.BinariesDirectory);$(Build.SourcesDirectory)\cmake;C:\program files (x86)" /p:VisualStudioVersion="17.0" /m /p:PreferredToolArchitecture=x64' - excludedPaths: '$(Build.SourcesDirectory)\b#$(Build.SourcesDirectory)\cmake#C:\program files#C:\program files (x86)#C:\program files' - rulesetName: Custom - customRuleset: $(Build.SourcesDirectory)\cmake\Sdl.ruleset - - task: SdtReport@2 displayName: 'Create Security Analysis Report' inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml index 67a03beab9362..46f2ae7b97acc 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml @@ -34,11 +34,6 @@ parameters: type: boolean default: true -- name: RunStaticCodeAnalysis - displayName: Run Static Code Analysis - type: boolean - default: true - - name: ORT_EP_NAME type: string @@ -220,49 +215,6 @@ jobs: workingDirectory: '$(Build.BinariesDirectory)\${{ parameters.BuildConfig }}\${{ parameters.BuildConfig }}' displayName: 'Run tests' - - - ${{ if eq(parameters.RunStaticCodeAnalysis, true) }}: - - task: DeleteFiles@1 - displayName: 'Delete binaries files from $(Build.BinariesDirectory)\RelWithDebInfo' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' - Contents: | - **/*.obj - **/*.pdb - **/*.dll - 
- - # Manually set msBuildCommandline so that we can also set CAExcludePath - # build_dir must be a sub folder of $(Build.SourcesDirectory) - # TODO: move this step to a CPU-only machine to save GPU resources. - - task: SDLNativeRules@3 - displayName: 'Run the PREfast SDL Native Rules for MSBuild' - inputs: - msBuildArchitecture: amd64 - setupCommandlines: 'python $(Build.SourcesDirectory)\tools\ci_build\build.py --config RelWithDebInfo --build_dir $(Build.SourcesDirectory)\b --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_onnx_tests ${{ parameters.additionalBuildFlags }} --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON onnxruntime_ENABLE_LTO=OFF' - msBuildCommandline: '"C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Current\Bin\amd64\msbuild.exe" "$(Build.SourcesDirectory)\b\RelWithDebInfo\onnxruntime.sln" /p:RunCodeAnalysis=true /p:platform=${{ parameters.msbuildPlatform }} /p:configuration=RelWithDebInfo /p:VisualStudioVersion="17.0" /m /p:PreferredToolArchitecture=x64' - excludedPaths: '$(Build.SourcesDirectory)\b#$(Build.SourcesDirectory)\cmake#C:\program files#C:\program files (x86)#C:\program files' - rulesetName: Custom - customRuleset: $(Build.SourcesDirectory)\cmake\Sdl.ruleset - publishXML: true - - - task: SdtReport@2 - displayName: 'Create Security Analysis Report' - inputs: - SDLNativeRules: true - - - task: PublishSecurityAnalysisLogs@3 - displayName: 'Publish Security Analysis Logs' - continueOnError: true - - - task: PostAnalysis@2 - displayName: 'Guardian Break v2' - inputs: - GdnBreakGdnToolSDLNativeRulesSeverity: Note - GdnBreakGdnToolSDLNativeRules: true - - - - ${{ if eq(parameters.RunOnnxRuntimeTests, true) }}: - task: PublishTestResults@2 displayName: 'Publish unit test results' inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 8812d4ed91ae7..1305f5ae21725 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -246,24 +246,6 @@ stages: workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)' displayName: 'Run Python Tests' - #Skip it for 32 bits x86 build. Currently the scan tool has a bug: it doesn't allow me use 64 bits link.exe - #in 32 bits Win32 build. I tried all the settings but they all don't work. 
- - task: SDLNativeRules@3 - displayName: 'Run the PREfast SDL Native Rules for MSBuild' - condition: and (succeeded(), and(eq(variables['buildArch'], 'x64'), eq(variables['PythonVersion'], '3.8'))) - inputs: - msBuildArchitecture: amd64 - setupCommandlines: 'python $(Build.SourcesDirectory)\tools\ci_build\build.py --config Debug --build_dir $(Build.SourcesDirectory)\b --skip_submodule_sync --cmake_generator "Visual Studio 17 2022" --enable_pybind --enable_onnx_tests --parallel $(TelemetryOption) --update --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON onnxruntime_ENABLE_LTO=OFF' - msBuildCommandline: '"C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Current\Bin\amd64\msbuild.exe" "$(Build.SourcesDirectory)\b\Debug\onnxruntime.sln" /p:RunCodeAnalysis=true /p:platform="$(MsbuildPlatform)" /p:configuration=Debug /p:VisualStudioVersion="17.0" /m /p:PreferredToolArchitecture=x64' - excludedPaths: '$(Build.SourcesDirectory)\b#$(Build.SourcesDirectory)\cmake#C:\program files#C:\program files (x86)#C:\program files' - rulesetName: Custom - customRuleset: $(Build.SourcesDirectory)\cmake\Sdl.ruleset - - - task: SdtReport@2 - displayName: 'Create Security Analysis Report' - inputs: - SDLNativeRules: true - - task: TSAUpload@2 displayName: 'TSA upload' condition: and(and (succeeded(), and(eq(variables['buildArch'], 'x64'), eq(variables['PythonVersion'], '3.8'))), eq(variables['Build.SourceBranch'], 'refs/heads/main')) diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml index ef938a634554a..919749cac15b6 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml @@ -22,65 +22,6 @@ parameters: default: '' jobs: -- ${{ if eq(parameters.PYTHON_VERSION, '3.8') }}: - - job: Win_py_${{ parameters.EP_NAME }}_Wheels_StaticAnalysis - timeoutInMinutes: 240 - workspace: - clean: all - pool: onnxruntime-Win-CPU-2022 - steps: - - checkout: self - clean: true - submodules: none - - task: UsePythonVersion@0 - inputs: - versionSpec: 3.8 - addToPath: true - architecture: 'x64' - - task: onebranch.pipeline.tsaoptions@1 - displayName: 'OneBranch TSAOptions' - inputs: - tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' - appendSourceBranchName: false - - - template: download-deps.yml - - - template: jobs/set-winenv.yml - parameters: - EnvSetupScript: ${{ parameters.ENV_SETUP_SCRIPT }} - DownloadCUDA: true - - - task: PythonScript@0 - displayName: 'Update deps.txt' - inputs: - scriptPath: $(Build.SourcesDirectory)/tools/ci_build/replace_urls_in_deps.py - arguments: --new_dir $(Build.BinariesDirectory)/deps - workingDirectory: $(Build.BinariesDirectory) - - - task: SDLNativeRules@3 - displayName: 'Run the PREfast SDL Native Rules for MSBuild' - inputs: - msBuildArchitecture: amd64 - setupCommandlines: 'python $(Build.SourcesDirectory)\tools\ci_build\build.py --config Debug --build_dir $(Build.SourcesDirectory)\b --skip_submodule_sync --cmake_generator "Visual Studio 17 2022" --enable_pybind ${{ parameters.BUILD_PY_PARAMETERS }} ${{ parameters.EP_BUILD_FLAGS }} --update --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON onnxruntime_ENABLE_LTO=OFF' - msBuildCommandline: '"C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Current\Bin\amd64\msbuild.exe" "$(Build.SourcesDirectory)\b\Debug\onnxruntime.sln" /p:RunCodeAnalysis=true /p:platform=x64 /p:configuration=Debug 
/p:VisualStudioVersion="17.0" /m /p:PreferredToolArchitecture=x64' - excludedPaths: '$(Build.SourcesDirectory)\b#$(Build.SourcesDirectory)\cmake#C:\program files#C:\program files (x86)#C:\program files' - rulesetName: Custom - customRuleset: $(Build.SourcesDirectory)\cmake\Sdl.ruleset - publishXML: true - - - task: SdtReport@2 - displayName: 'Create Security Analysis Report' - inputs: - SDLNativeRules: true - - - task: TSAUpload@2 - displayName: 'TSA upload' - condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) - inputs: - GdnPublishTsaOnboard: false - GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' - - - job: Win_py_${{ parameters.EP_NAME }}_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }} timeoutInMinutes: 240 workspace: diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index f6da7bb857b7d..80d285f3fd3fb 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -263,25 +263,6 @@ stages: AnalyzeTargetGlob: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\**\*.dll' continueOnError: true - - task: DeleteFiles@1 - displayName: 'Delete files from $(Build.BinariesDirectory)\RelWithDebInfo' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' - Contents: | - **/*.obj - **/*.pdb - **/*.dll - - #Manually set msBuildCommandline so that we can also set CAExcludePath - - task: SDLNativeRules@3 - displayName: 'Run the PREfast SDL Native Rules for MSBuild' - condition: and (succeeded(), eq(variables['msbuildPlatform'], 'x64')) - inputs: - msBuildArchitecture: amd64 - setupCommandlines: 'python $(Build.SourcesDirectory)\tools\ci_build\build.py --config Debug --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }} --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON' - msBuildCommandline: '"C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Current\Bin\amd64\msbuild.exe" "$(Build.BinariesDirectory)\Debug\onnxruntime.sln" /p:platform="$(MsbuildPlatform)" /p:configuration=Debug /p:VisualStudioVersion="17.0" /m /p:PreferredToolArchitecture=x64' - excludedPaths: '$(Build.BinariesDirectory)#$(Build.SourcesDirectory)\cmake#C:\program files (x86)' - - task: PostAnalysis@2 inputs: GdnBreakAllTools: false diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml index b9b833a3155bf..2a5622faf2905 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml @@ -47,7 +47,6 @@ stages: isX86: false job_name_suffix: x64_debug RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false @@ -69,7 +68,6 @@ stages: isX86: false job_name_suffix: x64_release RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false @@ -89,7 +87,6 @@ stages: isX86: false job_name_suffix: x64_release RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: DNNL GenerateDocumentation: false @@ -111,7 +108,6 @@ stages: isX86: false 
job_name_suffix: x64_release RunOnnxRuntimeTests: true - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: XNNPACK GenerateDocumentation: false @@ -132,7 +128,6 @@ stages: job_name_suffix: x64_release_winml RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} # WinML has many warnings - RunStaticCodeAnalysis: false EnablePython: false isTraining: false ORT_EP_NAME: CPU @@ -153,7 +148,6 @@ stages: isX86: true job_name_suffix: x86_release RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false @@ -173,7 +167,6 @@ stages: isX86: false job_name_suffix: training_x64_debug RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false isTraining: true ORT_EP_NAME: CPU GenerateDocumentation: false @@ -193,7 +186,6 @@ stages: isX86: false job_name_suffix: training_x64_release RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: true isTraining: true ORT_EP_NAME: CPU GenerateDocumentation: false @@ -213,7 +205,6 @@ stages: isX86: false job_name_suffix: ort_training_apis_x64_release RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false EnablePython: false isTraining: true ORT_EP_NAME: CPU @@ -234,7 +225,6 @@ stages: isX86: false job_name_suffix: x64_release_azure RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false EnablePython: false isTraining: false ORT_EP_NAME: CPU diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml index 69e71c1266664..8796917afa37d 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml @@ -47,7 +47,6 @@ stages: isX86: false job_name_suffix: x64_RelWithDebInfo RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false ORT_EP_NAME: CUDA WITH_CACHE: true MachinePool: onnxruntime-Win2022-GPU-A10 @@ -65,7 +64,6 @@ stages: isX86: false job_name_suffix: x64_RelWithDebInfo RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false ORT_EP_NAME: CUDA WITH_CACHE: true # Some unit tests crash on A10 GPUs. So this job still needs to use T4. @@ -85,7 +83,6 @@ stages: isX86: false job_name_suffix: x64_RelWithDebInfo RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false ORT_EP_NAME: DML WITH_CACHE: true MachinePool: onnxruntime-Win2022-GPU-dml-A10 @@ -104,7 +101,6 @@ stages: isX86: false job_name_suffix: x64_RelWithDebInfo RunOnnxRuntimeTests: false - RunStaticCodeAnalysis: false GenerateDocumentation: true ORT_EP_NAME: CUDA # It doesn't really matter which EP is selected here since this stage is for documentation. 
WITH_CACHE: true From bbcf4b45dccb5ef7c2c8b516feb89505696ff602 Mon Sep 17 00:00:00 2001 From: "Nat Kershaw (MSFT)" Date: Tue, 12 Sep 2023 20:44:27 -0700 Subject: [PATCH 006/121] Upgrade doxygen to 1.9.8 (#17525) --- .github/workflows/publish-c-apidocs.yml | 6 +- docs/c_cxx/Doxyfile | 414 ++++++++++++++++++------ 2 files changed, 325 insertions(+), 95 deletions(-) diff --git a/.github/workflows/publish-c-apidocs.yml b/.github/workflows/publish-c-apidocs.yml index 2fbd8e521aeee..73e8194bf7a8d 100644 --- a/.github/workflows/publish-c-apidocs.yml +++ b/.github/workflows/publish-c-apidocs.yml @@ -30,13 +30,13 @@ jobs: sudo apt update sudo apt-get install libclang-dev sudo apt-get install libclang-cpp14 - wget https://www.doxygen.nl/files/doxygen-1.9.6.linux.bin.tar.gz - tar xvzf doxygen-1.9.6.linux.bin.tar.gz + wget https://www.doxygen.nl/files/doxygen-1.9.8.linux.bin.tar.gz + tar xvzf doxygen-1.9.8.linux.bin.tar.gz - name: Run doxygen run: | mkdir -p build/doxygen cd docs/c_cxx - ../../doxygen-1.9.6/bin/doxygen + ../../doxygen-1.9.8/bin/doxygen - name: Log source commit run: git rev-parse --short HEAD > build/doxygen/html/source-version.txt - name: Move C/C++ docs into site diff --git a/docs/c_cxx/Doxyfile b/docs/c_cxx/Doxyfile index 94b39d2045f69..aedb1fdcfee75 100644 --- a/docs/c_cxx/Doxyfile +++ b/docs/c_cxx/Doxyfile @@ -1,4 +1,4 @@ -# Doxyfile 1.9.2 +# Doxyfile 1.9.8 # This file describes the settings to be used by the documentation system # doxygen (www.doxygen.org) for a project. @@ -12,6 +12,16 @@ # For lists, items can also be appended using: # TAG += value [value, ...] # Values that contain spaces should be placed between quotes (\" \"). +# +# Note: +# +# Use doxygen to compare the used configuration file with the template +# configuration file: +# doxygen -x [configFile] +# Use doxygen to compare the used configuration file with the template +# configuration file without replacing the environment variables or CMake type +# replacement variables: +# doxygen -x_noenv [configFile] #--------------------------------------------------------------------------- # Project related configuration options @@ -60,16 +70,28 @@ PROJECT_LOGO = "../images/ONNX_Runtime_logo - Docs.png" OUTPUT_DIRECTORY = ../../build/doxygen -# If the CREATE_SUBDIRS tag is set to YES then doxygen will create 4096 sub- -# directories (in 2 levels) under the output directory of each output format and -# will distribute the generated files over these directories. Enabling this +# If the CREATE_SUBDIRS tag is set to YES then doxygen will create up to 4096 +# sub-directories (in 2 levels) under the output directory of each output format +# and will distribute the generated files over these directories. Enabling this # option can be useful when feeding doxygen a huge amount of source files, where # putting all generated files in the same directory would otherwise causes -# performance problems for the file system. +# performance problems for the file system. Adapt CREATE_SUBDIRS_LEVEL to +# control the number of sub-directories. # The default value is: NO. CREATE_SUBDIRS = NO +# Controls the number of sub-directories that will be created when +# CREATE_SUBDIRS tag is set to YES. Level 0 represents 16 directories, and every +# level increment doubles the number of directories, resulting in 4096 +# directories at level 8 which is the default and also the maximum value. The +# sub-directories are organized in 2 levels, the first level always has a fixed +# number of 16 directories. 
+# Minimum value: 0, maximum value: 8, default value: 8. +# This tag requires that the tag CREATE_SUBDIRS is set to YES. + +CREATE_SUBDIRS_LEVEL = 8 + # If the ALLOW_UNICODE_NAMES tag is set to YES, doxygen will allow non-ASCII # characters to appear in the names of generated files. If set to NO, non-ASCII # characters will be escaped, for example _xE3_x81_x84 will be used for Unicode @@ -81,14 +103,14 @@ ALLOW_UNICODE_NAMES = NO # The OUTPUT_LANGUAGE tag is used to specify the language in which all # documentation generated by doxygen is written. Doxygen will use this # information to generate all constant output in the proper language. -# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Catalan, Chinese, -# Chinese-Traditional, Croatian, Czech, Danish, Dutch, English (United States), -# Esperanto, Farsi (Persian), Finnish, French, German, Greek, Hungarian, -# Indonesian, Italian, Japanese, Japanese-en (Japanese with English messages), -# Korean, Korean-en (Korean with English messages), Latvian, Lithuanian, -# Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, Romanian, Russian, -# Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, Swedish, Turkish, -# Ukrainian and Vietnamese. +# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Bulgarian, +# Catalan, Chinese, Chinese-Traditional, Croatian, Czech, Danish, Dutch, English +# (United States), Esperanto, Farsi (Persian), Finnish, French, German, Greek, +# Hindi, Hungarian, Indonesian, Italian, Japanese, Japanese-en (Japanese with +# English messages), Korean, Korean-en (Korean with English messages), Latvian, +# Lithuanian, Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, +# Romanian, Russian, Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, +# Swedish, Turkish, Ukrainian and Vietnamese. # The default value is: English. OUTPUT_LANGUAGE = English @@ -341,6 +363,17 @@ MARKDOWN_SUPPORT = YES TOC_INCLUDE_HEADINGS = 5 +# The MARKDOWN_ID_STYLE tag can be used to specify the algorithm used to +# generate identifiers for the Markdown headings. Note: Every identifier is +# unique. +# Possible values are: DOXYGEN use a fixed 'autotoc_md' string followed by a +# sequence number starting at 0 and GITHUB use the lower case version of title +# with any whitespace replaced by '-' and punctuation characters removed. +# The default value is: DOXYGEN. +# This tag requires that the tag MARKDOWN_SUPPORT is set to YES. + +MARKDOWN_ID_STYLE = DOXYGEN + # When enabled doxygen tries to link words that correspond to documented # classes, or namespaces to their corresponding documentation. Such a link can # be prevented in individual cases by putting a % sign in front of the word or @@ -437,7 +470,7 @@ INLINE_SIMPLE_STRUCTS = NO # types are typedef'ed and only the typedef is referenced, never the tag name. # The default value is: NO. -TYPEDEF_HIDES_STRUCT = YES +TYPEDEF_HIDES_STRUCT = NO # The size of the symbol lookup cache can be set using LOOKUP_CACHE_SIZE. This # cache is used to resolve symbols given their name and scope. Since this can be @@ -452,7 +485,7 @@ TYPEDEF_HIDES_STRUCT = YES LOOKUP_CACHE_SIZE = 0 -# The NUM_PROC_THREADS specifies the number threads doxygen is allowed to use +# The NUM_PROC_THREADS specifies the number of threads doxygen is allowed to use # during processing. When set to 0 doxygen will based this on the number of # cores available in the system. 
You can set it explicitly to a value larger # than 0 to get more control over the balance between CPU load and processing @@ -465,6 +498,14 @@ LOOKUP_CACHE_SIZE = 0 NUM_PROC_THREADS = 1 +# If the TIMESTAMP tag is set different from NO then each generated page will +# contain the date or date and time when the page was generated. Setting this to +# NO can help when comparing the output of multiple runs. +# Possible values are: YES, NO, DATETIME and DATE. +# The default value is: NO. + +TIMESTAMP = NO + #--------------------------------------------------------------------------- # Build related configuration options #--------------------------------------------------------------------------- @@ -546,7 +587,8 @@ HIDE_UNDOC_MEMBERS = NO # If the HIDE_UNDOC_CLASSES tag is set to YES, doxygen will hide all # undocumented classes that are normally visible in the class hierarchy. If set # to NO, these classes will be included in the various overviews. This option -# has no effect if EXTRACT_ALL is enabled. +# will also hide undocumented C++ concepts if enabled. This option has no effect +# if EXTRACT_ALL is enabled. # The default value is: NO. HIDE_UNDOC_CLASSES = NO @@ -577,14 +619,15 @@ INTERNAL_DOCS = NO # filesystem is case sensitive (i.e. it supports files in the same directory # whose names only differ in casing), the option must be set to YES to properly # deal with such files in case they appear in the input. For filesystems that -# are not case sensitive the option should be be set to NO to properly deal with +# are not case sensitive the option should be set to NO to properly deal with # output files written for symbols that only differ in casing, such as for two # classes, one named CLASS and the other named Class, and to also support # references to files without having to specify the exact matching casing. On # Windows (including Cygwin) and MacOS, users should typically set this option # to NO, whereas on Linux or other Unix flavors it should typically be set to # YES. -# The default value is: system dependent. +# Possible values are: SYSTEM, NO and YES. +# The default value is: SYSTEM. CASE_SENSE_NAMES = NO @@ -836,11 +879,26 @@ WARN_IF_INCOMPLETE_DOC = YES WARN_NO_PARAMDOC = YES +# If WARN_IF_UNDOC_ENUM_VAL option is set to YES, doxygen will warn about +# undocumented enumeration values. If set to NO, doxygen will accept +# undocumented enumeration values. If EXTRACT_ALL is set to YES then this flag +# will automatically be disabled. +# The default value is: NO. + +WARN_IF_UNDOC_ENUM_VAL = NO + # If the WARN_AS_ERROR tag is set to YES then doxygen will immediately stop when # a warning is encountered. If the WARN_AS_ERROR tag is set to FAIL_ON_WARNINGS # then doxygen will continue running as if WARN_AS_ERROR tag is set to NO, but # at the end of the doxygen process doxygen will return with a non-zero status. -# Possible values are: NO, YES and FAIL_ON_WARNINGS. +# If the WARN_AS_ERROR tag is set to FAIL_ON_WARNINGS_PRINT then doxygen behaves +# like FAIL_ON_WARNINGS but in case no WARN_LOGFILE is defined doxygen will not +# write the warning messages in between other messages but write them at the end +# of a run, in case a WARN_LOGFILE is defined the warning messages will be +# besides being in the defined file also be shown at the end of a run, unless +# the WARN_LOGFILE is defined as - i.e. standard output (stdout) in that case +# the behavior will remain as with the setting FAIL_ON_WARNINGS. +# Possible values are: NO, YES, FAIL_ON_WARNINGS and FAIL_ON_WARNINGS_PRINT. 
# The default value is: NO. WARN_AS_ERROR = YES @@ -851,13 +909,27 @@ WARN_AS_ERROR = YES # and the warning text. Optionally the format may contain $version, which will # be replaced by the version of the file (if it could be obtained via # FILE_VERSION_FILTER) +# See also: WARN_LINE_FORMAT # The default value is: $file:$line: $text. WARN_FORMAT = "$file:$line: $text" +# In the $text part of the WARN_FORMAT command it is possible that a reference +# to a more specific place is given. To make it easier to jump to this place +# (outside of doxygen) the user can define a custom "cut" / "paste" string. +# Example: +# WARN_LINE_FORMAT = "'vi $file +$line'" +# See also: WARN_FORMAT +# The default value is: at line $line of file $file. + +WARN_LINE_FORMAT = "at line $line of file $file" + # The WARN_LOGFILE tag can be used to specify a file to which warning and error # messages should be written. If left blank the output is written to standard -# error (stderr). +# error (stderr). In case the file specified cannot be opened for writing the +# warning and error messages are written to standard error. When as file - is +# specified the warning and error messages are written to standard output +# (stdout). WARN_LOGFILE = @@ -881,10 +953,21 @@ INPUT = ../../include/onnxruntime/core/session/onnxruntime_c_ap # libiconv (or the iconv built into libc) for the transcoding. See the libiconv # documentation (see: # https://www.gnu.org/software/libiconv/) for the list of possible encodings. +# See also: INPUT_FILE_ENCODING # The default value is: UTF-8. INPUT_ENCODING = UTF-8 +# This tag can be used to specify the character encoding of the source files +# that doxygen parses The INPUT_FILE_ENCODING tag can be used to specify +# character encoding on a per file pattern basis. Doxygen will compare the file +# name with each pattern and apply the encoding instead of the default +# INPUT_ENCODING) if there is a match. The character encodings are a list of the +# form: pattern=encoding (like *.php=ISO-8859-1). See cfg_input_encoding +# "INPUT_ENCODING" for further information on supported encodings. + +INPUT_FILE_ENCODING = + # If the value of the INPUT tag contains directories, you can use the # FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and # *.h) to filter out the source-files in the directories. @@ -896,18 +979,21 @@ INPUT_ENCODING = UTF-8 # Note the list of default checked file patterns might differ from the list of # default file extension mappings. # -# If left blank the following patterns are tested:*.c, *.cc, *.cxx, *.cpp, -# *.c++, *.java, *.ii, *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h, -# *.hh, *.hxx, *.hpp, *.h++, *.l, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, -# *.inc, *.m, *.markdown, *.md, *.mm, *.dox (to be provided as doxygen C -# comment), *.py, *.pyw, *.f90, *.f95, *.f03, *.f08, *.f18, *.f, *.for, *.vhd, -# *.vhdl, *.ucf, *.qsf and *.ice. +# If left blank the following patterns are tested:*.c, *.cc, *.cxx, *.cxxm, +# *.cpp, *.cppm, *.c++, *.c++m, *.java, *.ii, *.ixx, *.ipp, *.i++, *.inl, *.idl, +# *.ddl, *.odl, *.h, *.hh, *.hxx, *.hpp, *.h++, *.ixx, *.l, *.cs, *.d, *.php, +# *.php4, *.php5, *.phtml, *.inc, *.m, *.markdown, *.md, *.mm, *.dox (to be +# provided as doxygen C comment), *.py, *.pyw, *.f90, *.f95, *.f03, *.f08, +# *.f18, *.f, *.for, *.vhd, *.vhdl, *.ucf, *.qsf and *.ice. 
FILE_PATTERNS = *.c \ *.cc \ *.cxx \ + *.cxxm \ *.cpp \ + *.cppm \ *.c++ \ + *.c++m \ *.java \ *.ii \ *.ixx \ @@ -922,6 +1008,8 @@ FILE_PATTERNS = *.c \ *.hxx \ *.hpp \ *.h++ \ + *.ixx \ + *.l \ *.cs \ *.d \ *.php \ @@ -984,10 +1072,7 @@ EXCLUDE_PATTERNS = # (namespaces, classes, functions, etc.) that should be excluded from the # output. The symbol name can be a fully qualified name, a word, or if the # wildcard * is used, a substring. Examples: ANamespace, AClass, -# AClass::ANamespace, ANamespace::*Test -# -# Note that the wildcards are matched against the file with absolute path, so to -# exclude all test directories use the pattern */test/* +# ANamespace::AClass, ANamespace::*Test EXCLUDE_SYMBOLS = @@ -1032,6 +1117,11 @@ IMAGE_PATH = # code is scanned, but not when the output code is generated. If lines are added # or removed, the anchors will not be placed correctly. # +# Note that doxygen will use the data processed and written to standard output +# for further processing, therefore nothing else, like debug statements or used +# commands (so in case of a Windows batch file always use @echo OFF), should be +# written to standard output. +# # Note that for custom extensions or not directly supported extensions you also # need to set EXTENSION_MAPPING for the extension otherwise the files are not # properly processed by doxygen. @@ -1073,6 +1163,15 @@ FILTER_SOURCE_PATTERNS = USE_MDFILE_AS_MAINPAGE = +# The Fortran standard specifies that for fixed formatted Fortran code all +# characters from position 72 are to be considered as comment. A common +# extension is to allow longer lines before the automatic comment starts. The +# setting FORTRAN_COMMENT_AFTER will also make it possible that longer lines can +# be processed before the automatic comment starts. +# Minimum value: 7, maximum value: 10000, default value: 72. + +FORTRAN_COMMENT_AFTER = 72 + #--------------------------------------------------------------------------- # Configuration options related to source browsing #--------------------------------------------------------------------------- @@ -1210,10 +1309,11 @@ CLANG_DATABASE_PATH = ALPHABETICAL_INDEX = YES -# In case all classes in a project start with a common prefix, all classes will -# be put under the same header in the alphabetical index. The IGNORE_PREFIX tag -# can be used to specify a prefix (or a list of prefixes) that should be ignored -# while generating the index headers. +# The IGNORE_PREFIX tag can be used to specify a prefix (or a list of prefixes) +# that should be ignored while generating the index headers. The IGNORE_PREFIX +# tag works for classes, function and member names. The entity will be placed in +# the alphabetical list under the first letter of the entity name that remains +# after removing the prefix. # This tag requires that the tag ALPHABETICAL_INDEX is set to YES. IGNORE_PREFIX = @@ -1292,7 +1392,12 @@ HTML_STYLESHEET = # Doxygen will copy the style sheet files to the output directory. # Note: The order of the extra style sheet files is of importance (e.g. the last # style sheet in the list overrules the setting of the previous ones in the -# list). For an example see the documentation. +# list). +# Note: Since the styling of scrollbars can currently not be overruled in +# Webkit/Chromium, the styling will be left out of the default doxygen.css if +# one or more extra stylesheets have been specified. So if scrollbar +# customization is desired it has to be added explicitly. For an example see the +# documentation. 
# This tag requires that the tag GENERATE_HTML is set to YES. HTML_EXTRA_STYLESHEET = @@ -1307,6 +1412,19 @@ HTML_EXTRA_STYLESHEET = HTML_EXTRA_FILES = +# The HTML_COLORSTYLE tag can be used to specify if the generated HTML output +# should be rendered with a dark or light theme. +# Possible values are: LIGHT always generate light mode output, DARK always +# generate dark mode output, AUTO_LIGHT automatically set the mode according to +# the user preference, use light mode if no preference is set (the default), +# AUTO_DARK automatically set the mode according to the user preference, use +# dark mode if no preference is set and TOGGLE allow to user to switch between +# light and dark mode via a button. +# The default value is: AUTO_LIGHT. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE = AUTO_LIGHT + # The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen # will adjust the colors in the style sheet and background images according to # this color. Hue is specified as an angle on a color-wheel, see @@ -1337,15 +1455,6 @@ HTML_COLORSTYLE_SAT = 100 HTML_COLORSTYLE_GAMMA = 80 -# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML -# page will contain the date and time when the page was generated. Setting this -# to YES can help to show when doxygen was last run and thus if the -# documentation is up to date. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_TIMESTAMP = NO - # If the HTML_DYNAMIC_MENUS tag is set to YES then the generated HTML # documentation will contain a main index with vertical navigation menus that # are dynamically created via JavaScript. If disabled, the navigation index will @@ -1365,6 +1474,13 @@ HTML_DYNAMIC_MENUS = YES HTML_DYNAMIC_SECTIONS = NO +# If the HTML_CODE_FOLDING tag is set to YES then classes and functions can be +# dynamically folded and expanded in the generated HTML source code. +# The default value is: YES. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_CODE_FOLDING = YES + # With HTML_INDEX_NUM_ENTRIES one can control the preferred number of entries # shown in the various tree structured indices initially; the user can expand # and collapse entries dynamically later on. Doxygen will expand the tree to @@ -1401,6 +1517,13 @@ GENERATE_DOCSET = NO DOCSET_FEEDNAME = "Doxygen generated docs" +# This tag determines the URL of the docset feed. A documentation feed provides +# an umbrella under which multiple documentation sets from a single provider +# (such as a company or product suite) can be grouped. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_FEEDURL = + # This tag specifies a string that should uniquely identify the documentation # set bundle. This should be a reverse domain-name style string, e.g. # com.mycompany.MyDocSet. Doxygen will append .docset to the name. @@ -1488,6 +1611,16 @@ BINARY_TOC = NO TOC_EXPAND = NO +# The SITEMAP_URL tag is used to specify the full URL of the place where the +# generated documentation will be placed on the server by the user during the +# deployment of the documentation. The generated sitemap is called sitemap.xml +# and placed on the directory specified by HTML_OUTPUT. In case no SITEMAP_URL +# is specified no sitemap is generated. For information about the sitemap +# protocol see https://www.sitemaps.org +# This tag requires that the tag GENERATE_HTML is set to YES. 
+ +SITEMAP_URL = + # If the GENERATE_QHP tag is set to YES and both QHP_NAMESPACE and # QHP_VIRTUAL_FOLDER are set, an additional index file will be generated that # can be used as input for Qt's qhelpgenerator to generate a Qt Compressed Help @@ -1605,7 +1738,7 @@ GENERATE_TREEVIEW = YES # area (value NO) or if it should extend to the full height of the window (value # YES). Setting this to YES gives a layout similar to # https://docs.readthedocs.io with more room for contents, but less room for the -# project logo, title, and description. If either GENERATOR_TREEVIEW or +# project logo, title, and description. If either GENERATE_TREEVIEW or # DISABLE_INDEX is set to NO, this option has no effect. # The default value is: NO. # This tag requires that the tag GENERATE_HTML is set to YES. @@ -1636,6 +1769,13 @@ TREEVIEW_WIDTH = 250 EXT_LINKS_IN_WINDOW = NO +# If the OBFUSCATE_EMAILS tag is set to YES, doxygen will obfuscate email +# addresses. +# The default value is: YES. +# This tag requires that the tag GENERATE_HTML is set to YES. + +OBFUSCATE_EMAILS = YES + # If the HTML_FORMULA_FORMAT option is set to svg, doxygen will use the pdf2svg # tool (see https://github.com/dawbarton/pdf2svg) or inkscape (see # https://inkscape.org) to generate formulas as SVG images instead of PNGs for @@ -1969,9 +2109,16 @@ PDF_HYPERLINKS = YES USE_PDFLATEX = YES -# If the LATEX_BATCHMODE tag is set to YES, doxygen will add the \batchmode -# command to the generated LaTeX files. This will instruct LaTeX to keep running -# if errors occur, instead of asking the user for help. +# The LATEX_BATCHMODE tag signals the behavior of LaTeX in case of an error. +# Possible values are: NO same as ERROR_STOP, YES same as BATCH, BATCH In batch +# mode nothing is printed on the terminal, errors are scrolled as if is +# hit at every error; missing files that TeX tries to input or request from +# keyboard input (\read on a not open input stream) cause the job to abort, +# NON_STOP In nonstop mode the diagnostic message will appear on the terminal, +# but there is no possibility of user interaction just like in batch mode, +# SCROLL In scroll mode, TeX will stop only for missing files to input or if +# keyboard input is necessary and ERROR_STOP In errorstop mode, TeX will stop at +# each error, asking for user intervention. # The default value is: NO. # This tag requires that the tag GENERATE_LATEX is set to YES. @@ -1992,14 +2139,6 @@ LATEX_HIDE_INDICES = NO LATEX_BIB_STYLE = plain -# If the LATEX_TIMESTAMP tag is set to YES then the footer of each generated -# page will contain the date and time when the page was generated. Setting this -# to NO can help when comparing the output of multiple runs. -# The default value is: NO. -# This tag requires that the tag GENERATE_LATEX is set to YES. - -LATEX_TIMESTAMP = NO - # The LATEX_EMOJI_DIRECTORY tag is used to specify the (relative or absolute) # path from which the emoji images will be read. If a relative path is entered, # it will be relative to the LATEX_OUTPUT directory. If left blank the @@ -2165,7 +2304,7 @@ DOCBOOK_OUTPUT = docbook #--------------------------------------------------------------------------- # If the GENERATE_AUTOGEN_DEF tag is set to YES, doxygen will generate an -# AutoGen Definitions (see http://autogen.sourceforge.net/) file that captures +# AutoGen Definitions (see https://autogen.sourceforge.net/) file that captures # the structure of the code including all documentation. Note that this feature # is still experimental and incomplete at the moment. 
# The default value is: NO. @@ -2176,6 +2315,28 @@ GENERATE_AUTOGEN_DEF = NO # Configuration options related to Sqlite3 output #--------------------------------------------------------------------------- +# If the GENERATE_SQLITE3 tag is set to YES doxygen will generate a Sqlite3 +# database with symbols found by doxygen stored in tables. +# The default value is: NO. + +GENERATE_SQLITE3 = NO + +# The SQLITE3_OUTPUT tag is used to specify where the Sqlite3 database will be +# put. If a relative path is entered the value of OUTPUT_DIRECTORY will be put +# in front of it. +# The default directory is: sqlite3. +# This tag requires that the tag GENERATE_SQLITE3 is set to YES. + +SQLITE3_OUTPUT = sqlite3 + +# The SQLITE3_OVERWRITE_DB tag is set to YES, the existing doxygen_sqlite3.db +# database file will be recreated with each doxygen run. If set to NO, doxygen +# will warn if an a database file is already found and not modify it. +# The default value is: YES. +# This tag requires that the tag GENERATE_SQLITE3 is set to YES. + +SQLITE3_RECREATE_DB = YES + #--------------------------------------------------------------------------- # Configuration options related to the Perl module output #--------------------------------------------------------------------------- @@ -2250,7 +2411,8 @@ SEARCH_INCLUDES = YES # The INCLUDE_PATH tag can be used to specify one or more directories that # contain include files that are not input files but should be processed by the -# preprocessor. +# preprocessor. Note that the INCLUDE_PATH is not recursive, so the setting of +# RECURSIVE has no effect here. # This tag requires that the tag SEARCH_INCLUDES is set to YES. # onnxruntime-training and onnxruntime core headers are in different directories. @@ -2324,15 +2486,15 @@ TAGFILES = GENERATE_TAGFILE = -# If the ALLEXTERNALS tag is set to YES, all external class will be listed in -# the class index. If set to NO, only the inherited external classes will be -# listed. +# If the ALLEXTERNALS tag is set to YES, all external classes and namespaces +# will be listed in the class and namespace index. If set to NO, only the +# inherited external classes will be listed. # The default value is: NO. ALLEXTERNALS = NO # If the EXTERNAL_GROUPS tag is set to YES, all external groups will be listed -# in the modules index. If set to NO, only the current project's groups will be +# in the topic index. If set to NO, only the current project's groups will be # listed. # The default value is: YES. @@ -2346,16 +2508,9 @@ EXTERNAL_GROUPS = YES EXTERNAL_PAGES = YES #--------------------------------------------------------------------------- -# Configuration options related to the dot tool +# Configuration options related to diagram generator tools #--------------------------------------------------------------------------- -# You can include diagrams made with dia in doxygen documentation. Doxygen will -# then run dia to produce the diagram and insert it in the documentation. The -# DIA_PATH tag allows you to specify the directory where the dia binary resides. -# If left empty dia is assumed to be found in the default search path. - -DIA_PATH = - # If set to YES the inheritance and collaboration graphs will hide inheritance # and usage relations if the target is undocumented or is not a class. # The default value is: YES. @@ -2364,7 +2519,7 @@ HIDE_UNDOC_RELATIONS = YES # If you set the HAVE_DOT tag to YES then doxygen will assume the dot tool is # available from the path. 
This tool is part of Graphviz (see: -# http://www.graphviz.org/), a graph visualization toolkit from AT&T and Lucent +# https://www.graphviz.org/), a graph visualization toolkit from AT&T and Lucent # Bell Labs. The other options in this section have no effect if this option is # set to NO # The default value is: NO. @@ -2381,32 +2536,73 @@ HAVE_DOT = NO DOT_NUM_THREADS = 0 -# By default doxygen will tell dot to use the default font as specified with -# DOT_FONTNAME. If you specify a different font using DOT_FONTNAME you can set -# the path where dot can find it using this tag. +# DOT_COMMON_ATTR is common attributes for nodes, edges and labels of +# subgraphs. When you want a differently looking font in the dot files that +# doxygen generates you can specify fontname, fontcolor and fontsize attributes. +# For details please see Node, +# Edge and Graph Attributes specification You need to make sure dot is able +# to find the font, which can be done by putting it in a standard location or by +# setting the DOTFONTPATH environment variable or by setting DOT_FONTPATH to the +# directory containing the font. Default graphviz fontsize is 14. +# The default value is: fontname=Helvetica,fontsize=10. +# This tag requires that the tag HAVE_DOT is set to YES. + +DOT_COMMON_ATTR = "fontname=Helvetica,fontsize=10" + +# DOT_EDGE_ATTR is concatenated with DOT_COMMON_ATTR. For elegant style you can +# add 'arrowhead=open, arrowtail=open, arrowsize=0.5'. Complete documentation about +# arrows shapes. +# The default value is: labelfontname=Helvetica,labelfontsize=10. +# This tag requires that the tag HAVE_DOT is set to YES. + +DOT_EDGE_ATTR = "labelfontname=Helvetica,labelfontsize=10" + +# DOT_NODE_ATTR is concatenated with DOT_COMMON_ATTR. For view without boxes +# around nodes set 'shape=plain' or 'shape=plaintext' Shapes specification +# The default value is: shape=box,height=0.2,width=0.4. +# This tag requires that the tag HAVE_DOT is set to YES. + +DOT_NODE_ATTR = "shape=box,height=0.2,width=0.4" + +# You can set the path where dot can find font specified with fontname in +# DOT_COMMON_ATTR and others dot attributes. # This tag requires that the tag HAVE_DOT is set to YES. DOT_FONTPATH = -# If the CLASS_GRAPH tag is set to YES then doxygen will generate a graph for -# each documented class showing the direct and indirect inheritance relations. -# Setting this tag to YES will force the CLASS_DIAGRAMS tag to NO. +# If the CLASS_GRAPH tag is set to YES or GRAPH or BUILTIN then doxygen will +# generate a graph for each documented class showing the direct and indirect +# inheritance relations. In case the CLASS_GRAPH tag is set to YES or GRAPH and +# HAVE_DOT is enabled as well, then dot will be used to draw the graph. In case +# the CLASS_GRAPH tag is set to YES and HAVE_DOT is disabled or if the +# CLASS_GRAPH tag is set to BUILTIN, then the built-in generator will be used. +# If the CLASS_GRAPH tag is set to TEXT the direct and indirect inheritance +# relations will be shown as texts / links. +# Possible values are: NO, YES, TEXT, GRAPH and BUILTIN. # The default value is: YES. -# This tag requires that the tag HAVE_DOT is set to YES. CLASS_GRAPH = YES # If the COLLABORATION_GRAPH tag is set to YES then doxygen will generate a # graph for each documented class showing the direct and indirect implementation # dependencies (inheritance, containment, and class references variables) of the -# class with other documented classes. +# class with other documented classes. 
Explicit enabling a collaboration graph, +# when COLLABORATION_GRAPH is set to NO, can be accomplished by means of the +# command \collaborationgraph. Disabling a collaboration graph can be +# accomplished by means of the command \hidecollaborationgraph. # The default value is: YES. # This tag requires that the tag HAVE_DOT is set to YES. COLLABORATION_GRAPH = YES # If the GROUP_GRAPHS tag is set to YES then doxygen will generate a graph for -# groups, showing the direct groups dependencies. +# groups, showing the direct groups dependencies. Explicit enabling a group +# dependency graph, when GROUP_GRAPHS is set to NO, can be accomplished by means +# of the command \groupgraph. Disabling a directory graph can be accomplished by +# means of the command \hidegroupgraph. See also the chapter Grouping in the +# manual. # The default value is: YES. # This tag requires that the tag HAVE_DOT is set to YES. @@ -2466,7 +2662,9 @@ TEMPLATE_RELATIONS = NO # If the INCLUDE_GRAPH, ENABLE_PREPROCESSING and SEARCH_INCLUDES tags are set to # YES then doxygen will generate a graph for each documented file showing the # direct and indirect include dependencies of the file with other documented -# files. +# files. Explicit enabling an include graph, when INCLUDE_GRAPH is is set to NO, +# can be accomplished by means of the command \includegraph. Disabling an +# include graph can be accomplished by means of the command \hideincludegraph. # The default value is: YES. # This tag requires that the tag HAVE_DOT is set to YES. @@ -2475,7 +2673,10 @@ INCLUDE_GRAPH = YES # If the INCLUDED_BY_GRAPH, ENABLE_PREPROCESSING and SEARCH_INCLUDES tags are # set to YES then doxygen will generate a graph for each documented file showing # the direct and indirect include dependencies of the file with other documented -# files. +# files. Explicit enabling an included by graph, when INCLUDED_BY_GRAPH is set +# to NO, can be accomplished by means of the command \includedbygraph. Disabling +# an included by graph can be accomplished by means of the command +# \hideincludedbygraph. # The default value is: YES. # This tag requires that the tag HAVE_DOT is set to YES. @@ -2515,16 +2716,26 @@ GRAPHICAL_HIERARCHY = YES # If the DIRECTORY_GRAPH tag is set to YES then doxygen will show the # dependencies a directory has on other directories in a graphical way. The # dependency relations are determined by the #include relations between the -# files in the directories. +# files in the directories. Explicit enabling a directory graph, when +# DIRECTORY_GRAPH is set to NO, can be accomplished by means of the command +# \directorygraph. Disabling a directory graph can be accomplished by means of +# the command \hidedirectorygraph. # The default value is: YES. # This tag requires that the tag HAVE_DOT is set to YES. DIRECTORY_GRAPH = YES +# The DIR_GRAPH_MAX_DEPTH tag can be used to limit the maximum number of levels +# of child directories generated in directory dependency graphs by dot. +# Minimum value: 1, maximum value: 25, default value: 1. +# This tag requires that the tag DIRECTORY_GRAPH is set to YES. + +DIR_GRAPH_MAX_DEPTH = 1 + # The DOT_IMAGE_FORMAT tag can be used to set the image format of the images # generated by dot. For an explanation of the image formats see the section # output formats in the documentation of the dot tool (Graphviz (see: -# http://www.graphviz.org/)). +# https://www.graphviz.org/)). 
# Note: If you choose svg you need to set HTML_FILE_EXTENSION to xhtml in order
# to make the SVG files visible in IE 9+ (other browsers do not have this
# requirement).
@@ -2561,11 +2772,12 @@ DOT_PATH =
 
 DOTFILE_DIRS =
 
-# The MSCFILE_DIRS tag can be used to specify one or more directories that
-# contain msc files that are included in the documentation (see the \mscfile
-# command).
+# You can include diagrams made with dia in doxygen documentation. Doxygen will
+# then run dia to produce the diagram and insert it in the documentation. The
+# DIA_PATH tag allows you to specify the directory where the dia binary resides.
+# If left empty dia is assumed to be found in the default search path.
 
-MSCFILE_DIRS =
+DIA_PATH =
 
 # The DIAFILE_DIRS tag can be used to specify one or more directories that
 # contain dia files that are included in the documentation (see the \diafile
@@ -2574,10 +2786,10 @@ MSCFILE_DIRS =
 DIAFILE_DIRS =
 
 # When using plantuml, the PLANTUML_JAR_PATH tag should be used to specify the
-# path where java can find the plantuml.jar file. If left blank, it is assumed
-# PlantUML is not used or called during a preprocessing step. Doxygen will
-# generate a warning when it encounters a \startuml command in this case and
-# will not generate output for the diagram.
+# path where java can find the plantuml.jar file or to the filename of jar file
+# to be used. If left blank, it is assumed PlantUML is not used or called during
+# a preprocessing step. Doxygen will generate a warning when it encounters a
+# \startuml command in this case and will not generate output for the diagram.
 
 PLANTUML_JAR_PATH =
@@ -2627,6 +2839,8 @@ DOT_MULTI_TARGETS = NO
 # If the GENERATE_LEGEND tag is set to YES doxygen will generate a legend page
 # explaining the meaning of the various boxes and arrows in the dot generated
 # graphs.
+# Note: This tag requires that UML_LOOK isn't set, i.e. the doxygen internal
+# graphical representation for inheritance and collaboration diagrams is used.
 # The default value is: YES.
 # This tag requires that the tag HAVE_DOT is set to YES.
@@ -2640,3 +2854,19 @@ GENERATE_LEGEND = YES
 # The default value is: YES.
 
 DOT_CLEANUP = YES
+
+# You can define message sequence charts within doxygen comments using the \msc
+# command. If the MSCGEN_TOOL tag is left empty (the default), then doxygen will
+# use a built-in version of mscgen tool to produce the charts. Alternatively,
+# the MSCGEN_TOOL tag can also specify the name an external tool. For instance,
+# specifying prog as the value, doxygen will call the tool as prog -T
+# <outfile_format> -o <outputfile> <inputfile>. The external tool should support
+# output file formats "png", "eps", "svg", and "ismap".
+
+MSCGEN_TOOL =
+
+# The MSCFILE_DIRS tag can be used to specify one or more directories that
+# contain msc files that are included in the documentation (see the \mscfile
+# command).
+
+MSCFILE_DIRS =

From b52127d22d008254f20de1956e698b8a6f8712d6 Mon Sep 17 00:00:00 2001
From: rui-ren
Date: Tue, 12 Sep 2023 22:32:20 -0700
Subject: [PATCH 007/121] update acpt image for the training ci nightly
 (#17521)

### Description
The name of the nightly ACPT image has been updated to
`ptebic.azurecr.io/internal/aifx/acpt/nightly-ubuntu-cuda-torch-dev`.

Because the previous image alias encoded `cu118`, `torch210dev`, and `py38`,
any version update would break the training nightly pipeline.

### Motivation and Context
Use a constant image alias to avoid pipeline failures.
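To make the failure mode concrete, here is an illustrative sketch (the `docker pull` commands are hypothetical usage; the two image names are the ones swapped in the diff below):

```bash
# Old alias: bakes the CUDA, Python, and torch versions into the name, so any
# upstream version bump leaves this tag stale and breaks the pipeline.
docker pull ptebic.azurecr.io/internal/azureml/aifx/nightly-ubuntu2004-cu118-py38-torch210dev

# New alias: version-agnostic, so it keeps resolving to the current nightly
# ACPT build as the underlying CUDA/torch/Python versions change.
docker pull ptebic.azurecr.io/internal/aifx/acpt/nightly-ubuntu-cuda-torch-dev
```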
---
 .../orttraining-linux-nightly-ortmodule-test-pipeline.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-nightly-ortmodule-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-nightly-ortmodule-test-pipeline.yml
index a41ca5f02467d..7824bf2203efe 100644
--- a/tools/ci_build/github/azure-pipelines/orttraining-linux-nightly-ortmodule-test-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/orttraining-linux-nightly-ortmodule-test-pipeline.yml
@@ -23,7 +23,7 @@ jobs:
         --rm \
         --volume $(Build.SourcesDirectory)/orttraining/orttraining/test/python:/onnxruntime_src \
         --volume $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_nightly:/requirements_torch_nightly \
-        ptebic.azurecr.io/internal/azureml/aifx/nightly-ubuntu2004-cu118-py38-torch210dev \
+        ptebic.azurecr.io/internal/aifx/acpt/nightly-ubuntu-cuda-torch-dev \
         bash -c "python3 -m pip install -r /requirements_torch_nightly/requirements.txt && python3 -m pytest -sv /onnxruntime_src/orttraining_test_ortmodule_api.py"
       displayName: 'Run ORTModule Tests'
       condition: succeededOrFailed()

From 24a3c740c06f3c3b2d97f2839ece55bbd8ea0dd9 Mon Sep 17 00:00:00 2001
From: Changming Sun
Date: Tue, 12 Sep 2023 22:39:31 -0700
Subject: [PATCH 008/121] Revert "[ROCm][MIGraphX] for googletest dep, set
 OVERRIDE_FIND_PACKAGE (#16715)" (#17523)

This reverts commit bb136f86c8a1d0bcbdc2a77cb16f1c26c9ebd817, then
re-implements it in a different way. I reverted the original change, then
added a version constraint to the find_package args. If it still picks up the
wrong gtest version after this change, you may disable `find_package` by
setting 'FETCHCONTENT_TRY_FIND_PACKAGE_MODE' to NEVER.

For example, the latest gtest version is 1.14.0. Suppose Google later releases
a new version of gtest that is incompatible with the ONNX Runtime source code
you have today, your dev environment already has that new version installed,
and you do not want to create a new clean build environment without the
package. In that case, you can add
`--cmake_extra_defines FETCHCONTENT_TRY_FIND_PACKAGE_MODE=NEVER` to your build
command to solve the problem.
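As a concrete illustration of the workaround described above, a build invocation might look like the following sketch (only the `--cmake_extra_defines FETCHCONTENT_TRY_FIND_PACKAGE_MODE=NEVER` define comes from this change's description; the other flags are assumed placeholders):

```bash
# Force FetchContent to always download googletest instead of trying
# find_package() first, bypassing any locally installed incompatible copy.
./build.sh --config Release --build_shared_lib --parallel \
  --cmake_extra_defines FETCHCONTENT_TRY_FIND_PACKAGE_MODE=NEVER
```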
---
 cmake/external/onnxruntime_external_deps.cmake | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake
index 2a100ac161b97..e1671bcf43ed9 100644
--- a/cmake/external/onnxruntime_external_deps.cmake
+++ b/cmake/external/onnxruntime_external_deps.cmake
@@ -43,8 +43,8 @@ if (onnxruntime_BUILD_UNIT_TESTS)
   FetchContent_Declare(
     googletest
     URL ${DEP_URL_googletest}
+    FIND_PACKAGE_ARGS 1.14.0...<2.0.0 NAMES GTest
     URL_HASH SHA1=${DEP_SHA1_googletest}
-    OVERRIDE_FIND_PACKAGE
   )
 endif()

From 41584b28278a25f0e6a269a47264bccad9888f87 Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Tue, 12 Sep 2023 23:52:08 -0700
Subject: [PATCH 009/121] [js/web] ensure ORT initialization to run only once
 (#17529)

### Description
ensure ORT initialization to run only once
---
 js/web/lib/wasm/session-handler.ts | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler.ts
index d35f295592685..d8c5ae7886fe4 100644
--- a/js/web/lib/wasm/session-handler.ts
+++ b/js/web/lib/wasm/session-handler.ts
@@ -9,6 +9,7 @@ import {SerializableModeldata} from './proxy-messages';
 import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run} from './proxy-wrapper';
 
 let runtimeInitialized: boolean;
+let runtimeInitializationPromise: Promise<void>|undefined;
 
 export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler {
   private sessionId: number;
@@ -29,7 +30,11 @@ export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler {
 
   async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise<void> {
     if (!runtimeInitialized) {
-      await initializeRuntime(env);
+      if (!runtimeInitializationPromise) {
+        runtimeInitializationPromise = initializeRuntime(env);
+      }
+      await runtimeInitializationPromise;
+      runtimeInitializationPromise = undefined;
       runtimeInitialized = true;
     }

From ec94b07f0a14fc0a805e767803d483059cbd58ff Mon Sep 17 00:00:00 2001
From: xhcao
Date: Wed, 13 Sep 2023 15:05:00 +0800
Subject: [PATCH 010/121] [JS/WebGPU] support Concat.int32 operator (#17003)

### Description

### Motivation and Context
---
 js/web/test/data/ops/concat_int32.jsonc      | 406 ++++++++++++++++++
 js/web/test/suite-test-list.jsonc            |  27 +-
 .../core/providers/js/operators/concat.cc    |  12 +-
 3 files changed, 428 insertions(+), 17 deletions(-)
 create mode 100644 js/web/test/data/ops/concat_int32.jsonc

diff --git a/js/web/test/data/ops/concat_int32.jsonc b/js/web/test/data/ops/concat_int32.jsonc
new file mode 100644
index 0000000000000..6e2ce18c6f7c5
--- /dev/null
+++ b/js/web/test/data/ops/concat_int32.jsonc
@@ -0,0 +1,406 @@
+[
+  {
+    "name": "Concat 2D axis=0",
+    "operator": "Concat",
+    "attributes": [{ "name": "axis", "data": 0, "type": "int" }],
+    "cases": [
+      {
+        "name": "[4,4]",
+        "inputs": [
+          {
+            "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
+            "dims": [4, 4],
+            "type": "int32"
+          },
+          {
+            "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
+            "dims": [4, 4],
+            "type": "int32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+              16
+            ],
+            "dims": [8, 4],
+            "type": "int32"
+          }
+        ]
+      },
+      {
+        "name": "[2,4]",
+        "inputs": [
+          {
+            "data": [1, 2, 5, 6, 3, 4, 7, 8],
+            "dims": [2, 4],
+            "type": "int32"
+          },
+ { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 2, 5, 6, 3, 4, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8], + "dims": [4, 4], + "type": "int32" + } + ] + }, + { + "name": "[2,3]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "int32" + }, + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6], + "dims": [4, 3], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Concat 2D axis=1", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 1, "type": "int" }], + "cases": [ + { + "name": "[4,4]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [4, 4], + "type": "int32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [4, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, + 16 + ], + "dims": [4, 8], + "type": "int32" + } + ] + }, + { + "name": "[2,4]", + "inputs": [ + { + "data": [1, 2, 5, 6, 3, 4, 7, 8], + "dims": [2, 4], + "type": "int32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 2, 5, 6, 1, 2, 3, 4, 3, 4, 7, 8, 5, 6, 7, 8], + "dims": [2, 8], + "type": "int32" + } + ] + }, + { + "name": "[2,3]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "int32" + }, + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6], + "dims": [2, 6], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Concat 3D axis=0", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 0, "type": "int" }], + "cases": [ + { + "name": "[2,2,4]", + "inputs": [ + { + "data": [1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16], + "dims": [2, 2, 4], + "type": "int32" + }, + { + "data": [1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16], + "dims": [2, 2, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, + 16 + ], + "dims": [4, 2, 4], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Concat 3D axis=1", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 1, "type": "int" }], + "cases": [ + { + "name": "[2,2,4]", + "inputs": [ + { + "data": [1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16], + "dims": [2, 2, 4], + "type": "int32" + }, + { + "data": [1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16], + "dims": [2, 2, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 9, 10, 13, 14, 11, 12, 15, + 16 + ], + "dims": [2, 4, 4], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Concat 3D axis=2", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 2, "type": "int" }], + "cases": [ + { + "name": "[2,2,4]", + "inputs": [ + { + "data": [1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16], + "dims": [2, 2, 4], + "type": "int32" + }, + { + "data": [1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16], + "dims": [2, 2, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 5, 6, 1, 2, 5, 6, 3, 4, 7, 8, 3, 4, 7, 8, 9, 10, 13, 14, 9, 10, 13, 14, 11, 12, 15, 16, 11, 12, 15, + 
16 + ], + "dims": [2, 2, 8], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Concat 4D axis=0", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 0, "type": "int" }], + "cases": [ + { + "name": "[2,2,2,4]", + "inputs": [ + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, + 28, 31, 32 + ], + "dims": [2, 2, 2, 4], + "type": "int32" + }, + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, + 28, 31, 32 + ], + "dims": [2, 2, 2, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, + 28, 31, 32, 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, + 29, 30, 27, 28, 31, 32 + ], + "dims": [4, 2, 2, 4], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Concat 4D axis=1", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 1, "type": "int" }], + "cases": [ + { + "name": "[2,2,2,4]", + "inputs": [ + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, + 28, 31, 32 + ], + "dims": [2, 2, 2, 4], + "type": "int32" + }, + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, + 28, 31, 32 + ], + "dims": [2, 2, 2, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, + 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, 28, 31, 32, 17, 18, 21, 22, 19, 20, 23, 24, 25, + 26, 29, 30, 27, 28, 31, 32 + ], + "dims": [2, 4, 2, 4], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Concat 4D axis=2", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 2, "type": "int" }], + "cases": [ + { + "name": "[2,2,2,4]", + "inputs": [ + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, + 28, 31, 32 + ], + "dims": [2, 2, 2, 4], + "type": "int32" + }, + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, + 28, 31, 32 + ], + "dims": [2, 2, 2, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 9, 10, 13, 14, 11, 12, 15, + 16, 17, 18, 21, 22, 19, 20, 23, 24, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, 28, 31, 32, 25, + 26, 29, 30, 27, 28, 31, 32 + ], + "dims": [2, 2, 4, 4], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Concat 4D axis=3", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 3, "type": "int" }], + "cases": [ + { + "name": "[2,2,2,4]", + "inputs": [ + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, + 28, 31, 32 + ], + "dims": [2, 2, 2, 4], + "type": "int32" + }, + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, + 28, 31, 32 + ], + "dims": [2, 2, 2, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 5, 6, 1, 2, 5, 6, 3, 4, 7, 8, 3, 4, 7, 8, 9, 10, 13, 14, 9, 10, 13, 14, 11, 12, 15, 16, 11, 12, 15, + 16, 17, 18, 21, 22, 17, 18, 21, 22, 19, 20, 23, 24, 19, 20, 23, 24, 25, 26, 29, 
30, 25, 26, 29, 30, 27,
+            28, 31, 32, 27, 28, 31, 32
+          ],
+          "dims": [2, 2, 2, 8],
+          "type": "int32"
+        }
+      ]
+    }
+  ]
+  }
+]
diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc
index f53da708b8f6f..f4249b24101e5 100644
--- a/js/web/test/suite-test-list.jsonc
+++ b/js/web/test/suite-test-list.jsonc
@@ -432,18 +432,18 @@
     // // "test_compress_1",
     // // "test_compress_default_axis",
     // // "test_compress_negative_axis",
-    // "test_concat_1d_axis_0",
-    // "test_concat_1d_axis_negative_1",
-    // "test_concat_2d_axis_0",
-    // "test_concat_2d_axis_1",
-    // "test_concat_2d_axis_negative_1",
-    // "test_concat_2d_axis_negative_2",
-    // "test_concat_3d_axis_0",
-    // "test_concat_3d_axis_1",
-    // "test_concat_3d_axis_2",
-    // "test_concat_3d_axis_negative_1",
-    // "test_concat_3d_axis_negative_2",
-    // "test_concat_3d_axis_negative_3",
+    "test_concat_1d_axis_0",
+    "test_concat_1d_axis_negative_1",
+    "test_concat_2d_axis_0",
+    "test_concat_2d_axis_1",
+    "test_concat_2d_axis_negative_1",
+    "test_concat_2d_axis_negative_2",
+    "test_concat_3d_axis_0",
+    "test_concat_3d_axis_1",
+    "test_concat_3d_axis_2",
+    "test_concat_3d_axis_negative_1",
+    "test_concat_3d_axis_negative_2",
+    "test_concat_3d_axis_negative_3",
     "test_conv_with_autopad_same",
     "test_conv_with_strides_and_asymmetric_padding",
     "test_conv_with_strides_no_padding",
@@ -1330,7 +1330,8 @@
     //"and.jsonc",
     "asin.jsonc",
     "ceil.jsonc",
-    //"concat.jsonc",
+    "concat.jsonc",
+    "concat_int32.jsonc",
     "cast.jsonc",
     "conv.jsonc",
     "cos.jsonc",
diff --git a/onnxruntime/core/providers/js/operators/concat.cc b/onnxruntime/core/providers/js/operators/concat.cc
index 7d50d78c82851..3a6a7e1cafd7a 100644
--- a/onnxruntime/core/providers/js/operators/concat.cc
+++ b/onnxruntime/core/providers/js/operators/concat.cc
@@ -12,7 +12,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
     1, 3,
     kJsExecutionProvider,
     (*KernelDefBuilder::Create())
-        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+        .TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
+                              DataTypeImpl::GetTensorType<int32_t>()}),
     Concat);
 
 ONNX_OPERATOR_VERSIONED_KERNEL_EX(
@@ -21,7 +22,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
     4, 10,
     kJsExecutionProvider,
     (*KernelDefBuilder::Create())
-        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+        .TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
+                              DataTypeImpl::GetTensorType<int32_t>()}),
     Concat);
 
 ONNX_OPERATOR_VERSIONED_KERNEL_EX(
@@ -30,7 +32,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
     11, 12,
     kJsExecutionProvider,
     (*KernelDefBuilder::Create())
-        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+        .TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
+                              DataTypeImpl::GetTensorType<int32_t>()}),
     Concat);
 
 ONNX_OPERATOR_KERNEL_EX(
@@ -39,7 +42,8 @@ ONNX_OPERATOR_KERNEL_EX(
     13,
     kJsExecutionProvider,
     (*KernelDefBuilder::Create())
-        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+        .TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
+                              DataTypeImpl::GetTensorType<int32_t>()}),
     Concat);
 
 }  // namespace js

From cdf3e9dba9718b4b461540894905e01357096b7d Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Wed, 13 Sep 2023 00:07:16 -0700
Subject: [PATCH 011/121] [js] update prepack script to use exact version
 (#17484)

### Description
Update the prepack script to use an exact version. The prepack script for
onnxruntime-node, onnxruntime-web and onnxruntime-react-native is used to
update the referenced version of the dependency "onnxruntime-common".
Previously "~" (tilde symbol) was used.
This may cause NPM to choose an older version (if the old version matches the
version requirement and was previously installed, and therefore hits the
cache). See also https://semver.npmjs.com/.

[This build](https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1134671&view=results)
was caused by this issue.
---
 js/node/script/prepack.ts          | 2 +-
 js/react_native/scripts/prepack.ts | 2 +-
 js/web/script/prepack.ts           | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/js/node/script/prepack.ts b/js/node/script/prepack.ts
index be86c5687bec0..4c5941d8dae12 100644
--- a/js/node/script/prepack.ts
+++ b/js/node/script/prepack.ts
@@ -11,7 +11,7 @@ function updatePackageJson() {
   const packageCommon = fs.readJSONSync(commonPackageJsonPath);
   const packageSelf = fs.readJSONSync(selfPackageJsonPath);
   const version = packageCommon.version;
-  packageSelf.dependencies['onnxruntime-common'] = `~${version}`;
+  packageSelf.dependencies['onnxruntime-common'] = `${version}`;
   fs.writeJSONSync(selfPackageJsonPath, packageSelf, {spaces: 2});
   console.log('=== finished updating package.json.');
 }
diff --git a/js/react_native/scripts/prepack.ts b/js/react_native/scripts/prepack.ts
index 15ae69722108c..2e43294165a83 100644
--- a/js/react_native/scripts/prepack.ts
+++ b/js/react_native/scripts/prepack.ts
@@ -18,7 +18,7 @@ function updatePackageJson() {
     delete packageSelf.dependencies['onnxruntime-common'];
   } else {
     const version = packageCommon.version;
-    packageSelf.dependencies['onnxruntime-common'] = `~${version}`;
+    packageSelf.dependencies['onnxruntime-common'] = `${version}`;
   }
   fs.writeJSONSync(selfPackageJsonPath, packageSelf, {spaces: 2});
   console.log('=== finished updating package.json.');
diff --git a/js/web/script/prepack.ts b/js/web/script/prepack.ts
index be86c5687bec0..4c5941d8dae12 100644
--- a/js/web/script/prepack.ts
+++ b/js/web/script/prepack.ts
@@ -11,7 +11,7 @@ function updatePackageJson() {
   const packageCommon = fs.readJSONSync(commonPackageJsonPath);
   const packageSelf = fs.readJSONSync(selfPackageJsonPath);
   const version = packageCommon.version;
-  packageSelf.dependencies['onnxruntime-common'] = `~${version}`;
+  packageSelf.dependencies['onnxruntime-common'] = `${version}`;
   fs.writeJSONSync(selfPackageJsonPath, packageSelf, {spaces: 2});
   console.log('=== finished updating package.json.');
 }

From c0a4fe777fcc1311bf1379651ca68dfde176d94d Mon Sep 17 00:00:00 2001
From: Yi Zhang
Date: Wed, 13 Sep 2023 15:21:28 +0800
Subject: [PATCH 012/121] Move Linux python test into docker (#17479)

### Description
Supplement to #17417.

### Motivation and Context
---
 .../azure-pipelines/linux-ci-pipeline.yml   | 103 ++++++------------
 tools/scripts/python_test.sh                |  28 +++++
 tools/scripts/symbolic_shape_infer_test.sh  |  13 +++
 3 files changed, 73 insertions(+), 71 deletions(-)
 create mode 100644 tools/scripts/python_test.sh
 create mode 100644 tools/scripts/symbolic_shape_infer_test.sh

diff --git a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml
index eb6b274f87d6b..21bc1c481b3e6 100644
--- a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml
@@ -141,77 +141,39 @@ stages:
             "
           displayName: 'Dotnet build C# sln and Test'
 
-      - task: CmdLine@2
-        displayName: 'Install python deps'
-        inputs:
-          script: |
-            set -e -x
-            python3 -m pip uninstall -y ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml ort-nightly-directml
onnx -qq - cp $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt $(Build.BinariesDirectory)/requirements.txt - # Test ORT with the latest ONNX release. - sed -i "s/git+http:\/\/github\.com\/onnx\/onnx.*/onnx/" $(Build.BinariesDirectory)/requirements.txt - python3 -m pip install -r $(Build.BinariesDirectory)/requirements.txt - mkdir $(Build.BinariesDirectory)/requirements_torch_cpu/ - cp $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt $(Build.BinariesDirectory)/requirements_torch_cpu/requirements.txt - python3 -m pip install -r $(Build.BinariesDirectory)/requirements_torch_cpu/requirements.txt - - - task: CmdLine@2 - displayName: 'Install Release python package' - inputs: - script: | - rm -rf $(Build.BinariesDirectory)/Release/onnxruntime $(Build.BinariesDirectory)/Release/pybind11 - python3 -m pip install $(Build.BinariesDirectory)/Release/dist/*.whl - - - task: PythonScript@0 - displayName: 'Run Release unit tests' - inputs: - scriptPath: $(Build.SourcesDirectory)/tools/ci_build/build.py - workingDirectory: $(Build.BinariesDirectory)/Release - arguments: >- - --build_dir $(Build.BinariesDirectory) - --cmake_generator Ninja - --config Release - --test - --skip_submodule_sync - --build_shared_lib - --parallel - --build_wheel - --enable_onnx_tests - --enable_transformers_tool_test - --ctest_path "" - - - task: CmdLine@2 - displayName: 'Install Debug python package' - inputs: - script: | - set -e -x - rm -rf $(Build.BinariesDirectory)/Debug/onnxruntime $(Build.BinariesDirectory)/Debug/pybind11 - python3 -m pip uninstall -y ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml ort-nightly-directml -qq - python3 -m pip install $(Build.BinariesDirectory)/Debug/dist/*.whl - - - task: PythonScript@0 - displayName: 'Run Debug unit tests' - inputs: - scriptPath: $(Build.SourcesDirectory)/tools/ci_build/build.py - workingDirectory: $(Build.BinariesDirectory)/Debug - arguments: >- - --build_dir $(Build.BinariesDirectory) - --cmake_generator Ninja - --config Debug - --test - --skip_submodule_sync - --build_shared_lib - --parallel - --build_wheel - --enable_onnx_tests - --enable_transformers_tool_test - --ctest_path "" + - bash: | + mkdir -p $HOME/.onnx + docker run --rm \ + --volume /data/onnx:/data/onnx:ro \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + --volume /data/models:/build/models:ro \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ + -e NIGHTLY_BUILD \ + -e BUILD_BUILDNUMBER \ + onnxruntimecpubuild \ + /bin/bash -c " + set -ex; \ + /bin/bash /onnxruntime_src/tools/scripts/python_test.sh /onnxruntime_src /build Release && \ + /bin/bash /onnxruntime_src/tools/scripts/symbolic_shape_infer_test.sh /build + " + displayName: 'Run Release tests and symbolic shape infer test' - - task: PythonScript@0 - displayName: 'Symbolic shape infer' - inputs: - scriptPath: $(Build.BinariesDirectory)/Release/onnxruntime_test_python_symbolic_shape_infer.py - workingDirectory: $(Build.BinariesDirectory)/Release + - bash: | + mkdir -p $HOME/.onnx + docker run --rm \ + --volume /data/onnx:/data/onnx:ro \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + --volume /data/models:/build/models:ro \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ + -e 
NIGHTLY_BUILD \ + -e BUILD_BUILDNUMBER \ + onnxruntimecpubuild \ + /bin/bash /onnxruntime_src/tools/scripts/python_test.sh /onnxruntime_src /build Debug + displayName: 'Run Debug tests' - task: PublishTestResults@2 displayName: 'Publish unit test results' @@ -221,7 +183,6 @@ stages: testRunTitle: 'Unit Test Run' condition: succeededOrFailed() - - stage: arm64_build dependsOn: [] jobs: diff --git a/tools/scripts/python_test.sh b/tools/scripts/python_test.sh new file mode 100644 index 0000000000000..bfdd4663feede --- /dev/null +++ b/tools/scripts/python_test.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +set -ex + +export src_dir=$1 +export build_dir=$2 +export config=$3 + +# it's for manylinux image +export PATH=/opt/python/cp38-cp38/bin:$PATH + +echo Install Python Deps +cp $src_dir/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt $build_dir/requirements.txt + +python3 -m pip install -r $build_dir/requirements.txt +mkdir -p $build_dir/requirements_torch_cpu/ +cp $src_dir/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt $build_dir/requirements_torch_cpu/requirements.txt +python3 -m pip install -r $build_dir/requirements_torch_cpu/requirements.txt +python3 -m pip list | grep onnx + +echo Install $config python package +rm -rf $build_dir/$config/onnxruntime $build_dir/$config/pybind11 +python3 -m pip install $build_dir/$config/dist/*.whl + +echo Run $config unit tests +pushd $build_dir/$config/ +python3 $src_dir/tools/ci_build/build.py --build_dir $build_dir --cmake_generator Ninja --config $config --test --skip_submodule_sync --build_shared_lib --parallel --build_wheel --enable_onnx_tests --enable_transformers_tool_test --ctest_path "" +popd diff --git a/tools/scripts/symbolic_shape_infer_test.sh b/tools/scripts/symbolic_shape_infer_test.sh new file mode 100644 index 0000000000000..d8d50c5e3fa91 --- /dev/null +++ b/tools/scripts/symbolic_shape_infer_test.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +set -ex + +export build_dir=$1 + +# it's for manylinux image +export PATH=/opt/python/cp38-cp38/bin:$PATH + +echo Run symbolic shape infer test +pushd $build_dir/Release/ +python3 /build/Release/onnxruntime_test_python_symbolic_shape_infer.py +popd From 54a092c42714caa0d236e0f67b07b0804dbb9490 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Wed, 13 Sep 2023 07:26:35 -0700 Subject: [PATCH 013/121] [DML EP] Complete python IO binding implementation (#17344) @fdwr This is the part 2 of the pybind work that was started earlier. 
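As a rough sketch of the user-facing flow this enables (a minimal example, assuming a DML-enabled
onnxruntime package; the model file `model.onnx` and the tensor names `X`/`Y` are hypothetical):

```python
import numpy as np
import onnxruntime as onnxrt

# Create an OrtValue backed by a DML (GPU) allocation, initialized from numpy data.
x = onnxrt.OrtValue.ortvalue_from_numpy(np.ones((3, 2), dtype=np.float32), "dml", 0)
# Allocate an uninitialized output buffer directly on the DML device.
y = onnxrt.OrtValue.ortvalue_from_shape_and_type([3, 2], np.float32, "dml", 0)

session = onnxrt.InferenceSession("model.onnx", providers=["DmlExecutionProvider"])
io_binding = session.io_binding()
io_binding.bind_ortvalue_input("X", x)
io_binding.bind_ortvalue_output("Y", y)
session.run_with_iobinding(io_binding)

print(y.numpy())  # numpy() copies the result back from the GPU to the host
```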
This adds the following features to the python IO binding implementation: - Use a bucketized allocator in order to reduce the number of resource allocations - Implement the following functions: `ortvalue_from_numpy`, `update_inplace`, `ortvalue_from_shape_and_type` and `numpy` - Modify the `onnxruntime_test_python_iobinding` tests to also run on DML --------- Co-authored-by: Jeff Bloomfield --- .../src/BucketizedBufferAllocator.cpp | 9 - .../src/BucketizedBufferAllocator.h | 8 +- .../providers/dml/dml_provider_factory.cc | 39 +- .../dml/dml_provider_factory_creator.h | 1 + .../python/onnxruntime_pybind_mlvalue.cc | 107 +++- .../python/onnxruntime_pybind_mlvalue.h | 6 + .../python/onnxruntime_pybind_ortvalue.cc | 31 +- .../python/onnxruntime_pybind_state.cc | 4 + .../onnxruntime_test_python_iobinding.py | 455 ++++++++++-------- tools/ci_build/build.py | 7 +- 10 files changed, 413 insertions(+), 254 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp index 5dbea41901b80..c24257071eda5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp @@ -212,15 +212,6 @@ namespace Dml ORT_THROW_HR(E_INVALIDARG); } const auto* allocInfo = static_cast(opaqueHandle); - - auto owner = allocInfo->GetOwner(); - //The owner can be null if the resource was wrapped via CreateGPUAllocationFromD3DResource - if (owner != nullptr && owner != this) - { - // This allocation doesn't belong to this allocator! - ORT_THROW_HR(E_INVALIDARG); - } - return allocInfo; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h index 4c24cb174f6ed..196fba5d7689d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h @@ -83,16 +83,16 @@ namespace Dml std::vector m_pool; size_t m_currentAllocationId = 0; uint64_t m_currentResourceId = 0; - - // Unless specifically requested, allocation sizes are not rounded to enable pooling - // until SetDefaultRoundingMode is called. This should be done at completion of session + + // Unless specifically requested, allocation sizes are not rounded to enable pooling + // until SetDefaultRoundingMode is called. This should be done at completion of session // initialization. 
AllocatorRoundingMode m_defaultRoundingMode = AllocatorRoundingMode::Disabled; std::shared_ptr m_context; std::unique_ptr m_subAllocator; - #if _DEBUG + #ifndef NDEBUG // Useful for debugging; keeps track of all allocations that haven't been freed yet std::map m_outstandingAllocationsById; #endif diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index a46f820c6207f..fde61e73c2124 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -128,21 +128,13 @@ Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateD3D12Devic return d3d12_device; } -std::shared_ptr DMLProviderFactoryCreator::Create(int device_id, bool skip_software_device_check) { - ComPtr d3d12_device = CreateD3D12Device(device_id, skip_software_device_check); - - D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {}; - cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; - cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT; - - ComPtr cmd_queue; - ORT_THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_GRAPHICS_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf()))); - +Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateDMLDevice(ID3D12Device* d3d12_device) +{ DML_CREATE_DEVICE_FLAGS flags = DML_CREATE_DEVICE_FLAG_NONE; // In debug builds, enable the DML debug layer if the D3D12 debug layer is also enabled #if _DEBUG && !_GAMING_XBOX - ComPtr debug_device; + Microsoft::WRL::ComPtr debug_device; (void)d3d12_device->QueryInterface(IID_PPV_ARGS(&debug_device)); // ignore failure const bool is_d3d12_debug_layer_enabled = (debug_device != nullptr); @@ -151,12 +143,27 @@ std::shared_ptr DMLProviderFactoryCreator::Create(int } #endif - ComPtr dml_device; - ORT_THROW_IF_FAILED(DMLCreateDevice1(d3d12_device.Get(), - flags, - DML_FEATURE_LEVEL_5_0, - IID_PPV_ARGS(&dml_device))); + Microsoft::WRL::ComPtr dml_device; + ORT_THROW_IF_FAILED(DMLCreateDevice1( + d3d12_device, + flags, + DML_FEATURE_LEVEL_5_0, + IID_PPV_ARGS(&dml_device))); + + return dml_device; +} + +std::shared_ptr DMLProviderFactoryCreator::Create(int device_id, bool skip_software_device_check) { + ComPtr d3d12_device = CreateD3D12Device(device_id, skip_software_device_check); + + D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {}; + cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; + cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT; + + ComPtr cmd_queue; + ORT_THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_GRAPHICS_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf()))); + auto dml_device = CreateDMLDevice(d3d12_device.Get()); return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get()); } diff --git a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h index b1c9bb3f6f679..574f4410fe3e3 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h +++ b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h @@ -16,5 +16,6 @@ struct DMLProviderFactoryCreator { static std::shared_ptr Create(int device_id); static std::shared_ptr Create(int device_id, bool skip_software_device_check); static Microsoft::WRL::ComPtr CreateD3D12Device(int device_id, bool skip_software_device_check); + static Microsoft::WRL::ComPtr CreateDMLDevice(ID3D12Device* d3d12_device); }; } // namespace onnxruntime diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc 
b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index 10c8a2de7c3df..f470e9f6b6ed1 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -26,7 +26,18 @@ #include "core/framework/provider_options_utils.h" #ifdef USE_DML -#include "core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h" +using Microsoft::WRL::ComPtr; + +#include +#include "core/providers/dml/DmlExecutionProvider/src/External/D3DX12/d3dx12.h" +#include "core/providers/dml/DmlExecutionProvider/src/ErrorHandling.h" +#include "core/providers/dml/DmlExecutionProvider/src/DescriptorPool.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.h" +#include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" +#include "core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h" +#include "core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.h" +#include "core/providers/dml/DmlExecutionProvider/src/ReadbackHeap.h" +#include "core/providers/dml/DmlExecutionProvider/src/AllocationInfo.h" #endif namespace onnxruntime { @@ -186,6 +197,11 @@ std::unique_ptr GetGPUDataTransfer() { #endif #ifdef USE_DML + +constexpr GUID execution_context_guid = {0x50fd773b, 0x4462, 0x4b28, {0x98, 0x9e, 0x8c, 0xa0, 0x54, 0x05, 0xbd, 0x4a}}; +constexpr GUID upload_heap_guid = {0x125235f9, 0xef41, 0x4043, {0xa4, 0x9d, 0xdd, 0xc9, 0x61, 0xe7, 0xdb, 0xee}}; +constexpr GUID dml_readback_heap_guid = {0x00d32df8, 0xea2d, 0x40bf, {0xa4, 0x47, 0x9c, 0xb4, 0xbc, 0xf1, 0x1d, 0x5e}}; + AllocatorPtr GetDmlAllocator(OrtDevice::DeviceId id) { // Current approach is not thread-safe, but there are some bigger infra pieces to put together in order to make // multi-threaded DML allocation work, including maintaining a per-thread DML allocator. 
@@ -196,13 +212,100 @@ AllocatorPtr GetDmlAllocator(OrtDevice::DeviceId id) {
   auto hit = id_to_allocator_map->find(id);
   if (hit == id_to_allocator_map->end()) {
-    auto dml_allocator = std::make_shared<DmlExternalBufferAllocator>(id);
+    constexpr uint32_t device_id = 0;
+    auto d3d12_device = onnxruntime::DMLProviderFactoryCreator::CreateD3D12Device(device_id, false);
+    auto dml_device = onnxruntime::DMLProviderFactoryCreator::CreateDMLDevice(d3d12_device.Get());
+
+    D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {};
+    cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
+    cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT;
+
+    ComPtr<ID3D12CommandQueue> cmd_queue;
+    ORT_THROW_IF_FAILED(
+        d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf())));
+
+    auto context = std::make_shared<Dml::ExecutionContext>(d3d12_device.Get(), dml_device.Get(), cmd_queue.Get());
+
+    // We leak the upload and readback heaps to keep them alive, just like the map
+    auto upload_heap = std::make_unique<Dml::PooledUploadHeap>(d3d12_device.Get(), context).release();
+    auto readback_heap = std::make_unique<Dml::ReadbackHeap>(d3d12_device.Get(), context).release();
+
+    auto dml_allocator = std::make_shared<Dml::BucketizedBufferAllocator>(
+        d3d12_device.Get(),
+        context,
+        CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT),
+        D3D12_HEAP_FLAG_NONE,
+        D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS,
+        D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
+        std::make_unique<Dml::DmlCommittedResourceAllocator>(d3d12_device.Get()));
+    dml_allocator->SetDefaultRoundingMode(AllocatorRoundingMode::Enabled);
+    context->SetAllocator(dml_allocator);
+
+    auto context_ptr = context.get();
+
+    ORT_THROW_IF_FAILED(d3d12_device->SetPrivateData(execution_context_guid, sizeof(context_ptr), &context_ptr));
+    ORT_THROW_IF_FAILED(d3d12_device->SetPrivateData(upload_heap_guid, sizeof(upload_heap), &upload_heap));
+    ORT_THROW_IF_FAILED(d3d12_device->SetPrivateData(dml_readback_heap_guid, sizeof(readback_heap), &readback_heap));
+
     hit = id_to_allocator_map->emplace(id, std::move(dml_allocator)).first;
   }
 
   return hit->second;
 }
 
+void CpuToDmlMemCpy(void* dst, const void* src, size_t num_bytes) {
+  const auto* allocInfo = static_cast<const Dml::AllocationInfo*>(dst);
+  ID3D12Resource* dst_data = allocInfo->GetResource();
+
+  ComPtr<ID3D12Device> d3d12_device;
+  ORT_THROW_IF_FAILED(dst_data->GetDevice(IID_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf())));
+
+  Dml::ExecutionContext* context = nullptr;
+  uint32_t context_size = gsl::narrow_cast<uint32_t>(sizeof(context));
+  ORT_THROW_IF_FAILED(d3d12_device->GetPrivateData(execution_context_guid, &context_size, &context));
+
+  Dml::PooledUploadHeap* upload_heap = nullptr;
+  uint32_t upload_heap_size = gsl::narrow_cast<uint32_t>(sizeof(upload_heap));
+  ORT_THROW_IF_FAILED(d3d12_device->GetPrivateData(upload_heap_guid, &upload_heap_size, &upload_heap));
+
+  upload_heap->BeginUploadToGpu(
+      dst_data, 0, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, gsl::make_span(static_cast<const std::byte*>(src), num_bytes));
+  context->Flush();
+
+  // We don't use the same command queue as the execution provider, so we need to sync to make sure that all data has
+  // been uploaded to the resource. This function is usually called before inference just to upload initial data to the
+  // GPU, so it shouldn't be a bottleneck.
+  context->GetCurrentCompletionEvent().WaitForSignal();
+}
+
+void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {
+  const auto* allocInfo = static_cast<const Dml::AllocationInfo*>(src);
+  ID3D12Resource* src_data = allocInfo->GetResource();
+
+  ComPtr<ID3D12Device> d3d12_device;
+  ORT_THROW_IF_FAILED(src_data->GetDevice(IID_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf())));
+
+  Dml::ExecutionContext* context = nullptr;
+  uint32_t context_size = gsl::narrow_cast<uint32_t>(sizeof(context));
+  ORT_THROW_IF_FAILED(d3d12_device->GetPrivateData(execution_context_guid, &context_size, &context));
+
+  Dml::ReadbackHeap* readback_heap = nullptr;
+  uint32_t readback_heap_size = gsl::narrow_cast<uint32_t>(sizeof(readback_heap));
+  ORT_THROW_IF_FAILED(d3d12_device->GetPrivateData(dml_readback_heap_guid, &readback_heap_size, &readback_heap));
+
+  // ReadbackFromGpu already syncs with the CPU and waits for the copy to be completed, so we don't need to sync after
+  // this call
+  readback_heap->ReadbackFromGpu(
+      gsl::make_span(static_cast<std::byte*>(dst), num_bytes), src_data, 0, D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
+}
+
+const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* GetDmlToHostMemCpyFunction() {
+  static std::unordered_map<OrtDevice::DeviceType, MemCpyFunc> map{
+      {OrtDevice::GPU, DmlToCpuMemCpy}};
+
+  return &map;
+}
+
 #endif
 
 #ifdef USE_CANN
diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.h b/onnxruntime/python/onnxruntime_pybind_mlvalue.h
index 4ac9c70468b19..e3f277bcb9c41 100644
--- a/onnxruntime/python/onnxruntime_pybind_mlvalue.h
+++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.h
@@ -77,6 +77,12 @@ std::unique_ptr<IDataTransfer> GetGPUDataTransfer();
 
 AllocatorPtr GetDmlAllocator(OrtDevice::DeviceId id);
 
+void CpuToDmlMemCpy(void* dst, const void* src, size_t num_bytes);
+
+void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes);
+
+const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* GetDmlToHostMemCpyFunction();
+
 #endif
 
 #ifdef USE_CANN
diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc
index f9d908e0ac518..dc4a4dcc13b7f 100644
--- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc
+++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc
@@ -63,7 +63,12 @@ void addOrtValueMethods(pybind11::module& m) {
         // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
         // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA
         CreateGenericMLValue(nullptr, GetRocmAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToRocmMemCpy);
-
+#elif USE_DML
+        // InputDeflist is null because OrtValue creation is not tied to a specific model
+        // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
+        // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML
+        CreateGenericMLValue(
+            nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy);
 #else
         throw std::runtime_error(
             "Can't allocate memory on the CUDA device using this package of OnnxRuntime. 
" @@ -126,6 +131,12 @@ void addOrtValueMethods(pybind11::module& m) { values_type, *(ml_value->GetMutable()), CpuToRocmMemCpy); +#elif USE_DML + onnxruntime::python::CopyDataToTensor( + py_values, + values_type, + *(ml_value->GetMutable()), + CpuToDmlMemCpy); #else throw std::runtime_error( "Unsupported GPU device: Cannot find the supported GPU device."); @@ -158,12 +169,18 @@ void addOrtValueMethods(pybind11::module& m) { throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); } allocator = GetCudaAllocator(device.Id()); -#elif USE_DML - allocator = GetDmlAllocator(device.Id()); #else throw std::runtime_error( "Can't allocate memory on the CUDA device using this package of OnnxRuntime. " "Please use the CUDA package of OnnxRuntime to use this feature."); +#endif + } else if (strcmp(GetDeviceName(device), DML) == 0) { +#if USE_DML + allocator = GetDmlAllocator(device.Id()); +#else + throw std::runtime_error( + "Can't allocate memory on the DirectML device using this package of OnnxRuntime. " + "Please use the DirectML package of OnnxRuntime to use this feature."); #endif } else { throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device"); @@ -290,11 +307,13 @@ void addOrtValueMethods(pybind11::module& m) { #ifdef USE_CUDA GetPyObjFromTensor(ml_value->Get(), obj, nullptr, GetCudaToHostMemCpyFunction()); #elif USE_ROCM - GetPyObjFromTensor(ml_value->Get(), obj, nullptr, GetRocmToHostMemCpyFunction()); + GetPyObjFromTensor(ml_value->Get(), obj, nullptr, GetRocmToHostMemCpyFunction()); #elif USE_CANN - GetPyObjFromTensor(ml_value->Get(), obj, nullptr, GetCannToHostMemCpyFunction()); + GetPyObjFromTensor(ml_value->Get(), obj, nullptr, GetCannToHostMemCpyFunction()); +#elif USE_DML + GetPyObjFromTensor(ml_value->Get(), obj, nullptr, GetDmlToHostMemCpyFunction()); #else - GetPyObjFromTensor(ml_value->Get(), obj, nullptr, nullptr); + GetPyObjFromTensor(ml_value->Get(), obj, nullptr, nullptr); #endif return obj; }) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 82d119894a5d8..907ea0ec41e23 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -237,7 +237,11 @@ const char* GetDeviceName(const OrtDevice& device) { case OrtDevice::CPU: return CPU; case OrtDevice::GPU: +#ifdef USE_DML + return DML; +#else return CUDA; +#endif case OrtDevice::FPGA: return "FPGA"; case OrtDevice::NPU: diff --git a/onnxruntime/test/python/onnxruntime_test_python_iobinding.py b/onnxruntime/test/python/onnxruntime_test_python_iobinding.py index 8009d97ba34ce..56417f13fbea4 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_iobinding.py +++ b/onnxruntime/test/python/onnxruntime_test_python_iobinding.py @@ -16,40 +16,43 @@ from onnxruntime.capi._pybind_state import OrtValue as C_OrtValue from onnxruntime.capi._pybind_state import OrtValueVector, SessionIOBinding +test_params = [ + ("cuda", "CUDAExecutionProvider", C_OrtDevice.cuda), + ("dml", "DmlExecutionProvider", C_OrtDevice.dml), +] + class TestIOBinding(unittest.TestCase): - def create_ortvalue_input_on_gpu(self): + def _create_ortvalue_input_on_gpu(self, device): return onnxrt.OrtValue.ortvalue_from_numpy( - np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32), "cuda", 0 + np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32), device, 0 ) - def create_ortvalue_alternate_input_on_gpu(self): + def 
_create_ortvalue_alternate_input_on_gpu(self, device): return onnxrt.OrtValue.ortvalue_from_numpy( np.array([[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]], dtype=np.float32), - "cuda", + device, 0, ) - def create_uninitialized_ortvalue_input_on_gpu(self): - return onnxrt.OrtValue.ortvalue_from_shape_and_type([3, 2], np.float32, "cuda", 0) + def _create_uninitialized_ortvalue_input_on_gpu(self, device): + return onnxrt.OrtValue.ortvalue_from_shape_and_type([3, 2], np.float32, device, 0) - def create_numpy_input(self): + def _create_numpy_input(self): return np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) - def create_expected_output(self): + def _create_expected_output(self): return np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) - def create_expected_output_alternate(self): + def _create_expected_output_alternate(self): return np.array([[2.0, 8.0], [18.0, 32.0], [50.0, 72.0]], dtype=np.float32) def test_bind_input_to_cpu_arr(self): - self.create_numpy_input() - session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) io_binding = session.io_binding() # Bind Numpy object (input) that's on CPU to wherever the model needs it - io_binding.bind_cpu_input("X", self.create_numpy_input()) + io_binding.bind_cpu_input("X", self._create_numpy_input()) # Bind output to CPU io_binding.bind_output("Y") @@ -57,254 +60,280 @@ def test_bind_input_to_cpu_arr(self): # Invoke Run session.run_with_iobinding(io_binding) - # Sync if different CUDA streams + # Sync if different streams io_binding.synchronize_outputs() - # Get outputs over to CPU (the outputs which were bound to CUDA will get copied over to the host here) + # Get outputs over to CPU (the outputs which were bound to the GPU will get copied over to the host here) ort_output = io_binding.copy_outputs_to_cpu()[0] # Validate results - self.assertTrue(np.array_equal(self.create_expected_output(), ort_output)) + self.assertTrue(np.array_equal(self._create_expected_output(), ort_output)) - @unittest.skip("Could not find an implementation for Identity(19) node with name ''") def test_bind_input_types(self): - opset = onnx_opset_version() - devices = [ - ( - C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0), - ["CPUExecutionProvider"], - ) - ] - if "CUDAExecutionProvider" in onnxrt.get_all_providers(): - devices.append( - ( - C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0), - ["CUDAExecutionProvider"], - ) - ) - - for device, provider in devices: - for dtype in [ - np.float32, - np.float64, - np.int32, - np.uint32, - np.int64, - np.uint64, - np.int16, - np.uint16, - np.int8, - np.uint8, - np.float16, - np.bool_, - ]: - with self.subTest(dtype=dtype, device=str(device)): - x = np.arange(8).reshape((-1, 2)).astype(dtype) - proto_dtype = NP_TYPE_TO_TENSOR_TYPE[x.dtype] - - X = helper.make_tensor_value_info("X", proto_dtype, [None, x.shape[1]]) # noqa: N806 - Y = helper.make_tensor_value_info("Y", proto_dtype, [None, x.shape[1]]) # noqa: N806 - - # inference - node_add = helper.make_node("Identity", ["X"], ["Y"]) - - # graph - graph_def = helper.make_graph([node_add], "lr", [X], [Y], []) - model_def = helper.make_model( - graph_def, - producer_name="dummy", - ir_version=7, - producer_version="0", - opset_imports=[helper.make_operatorsetid("", opset)], - ) - - sess = onnxrt.InferenceSession(model_def.SerializeToString(), providers=provider) - - bind = SessionIOBinding(sess._sess) - ort_value = C_OrtValue.ortvalue_from_numpy(x, device) - 
bind.bind_ortvalue_input("X", ort_value) - bind.bind_output("Y", device) - sess._sess.run_with_iobinding(bind, None) - ortvaluevector = bind.get_outputs() - self.assertIsInstance(ortvaluevector, OrtValueVector) - ortvalue = bind.get_outputs()[0] - y = ortvalue.numpy() - assert_almost_equal(x, y) - - bind = SessionIOBinding(sess._sess) - bind.bind_input("X", device, dtype, x.shape, ort_value.data_ptr()) - bind.bind_output("Y", device) - sess._sess.run_with_iobinding(bind, None) - ortvalue = bind.get_outputs()[0] - y = ortvalue.numpy() - assert_almost_equal(x, y) + for device, execution_provider, generate_device in test_params: + with self.subTest(execution_provider): + if execution_provider not in onnxrt.get_available_providers(): + self.skipTest(f"Skipping on {device.upper()}.") + + opset = onnx_opset_version() + devices = [ + ( + C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0), + ["CPUExecutionProvider"], + ), + ( + C_OrtDevice(generate_device(), C_OrtDevice.default_memory(), 0), + [execution_provider], + ), + ] + + for inner_device, provider in devices: + for dtype in [ + np.float32, + np.float64, + np.int32, + np.uint32, + np.int64, + np.uint64, + np.int16, + np.uint16, + np.int8, + np.uint8, + np.float16, + np.bool_, + ]: + with self.subTest(dtype=dtype, inner_device=str(inner_device)): + x = np.arange(8).reshape((-1, 2)).astype(dtype) + proto_dtype = NP_TYPE_TO_TENSOR_TYPE[x.dtype] + + X = helper.make_tensor_value_info("X", proto_dtype, [None, x.shape[1]]) # noqa: N806 + Y = helper.make_tensor_value_info("Y", proto_dtype, [None, x.shape[1]]) # noqa: N806 + + # inference + node_add = helper.make_node("Identity", ["X"], ["Y"]) + + # graph + graph_def = helper.make_graph([node_add], "lr", [X], [Y], []) + model_def = helper.make_model( + graph_def, + producer_name="dummy", + ir_version=7, + producer_version="0", + opset_imports=[helper.make_operatorsetid("", opset)], + ) + + sess = onnxrt.InferenceSession(model_def.SerializeToString(), providers=provider) + + bind = SessionIOBinding(sess._sess) + ort_value = C_OrtValue.ortvalue_from_numpy(x, inner_device) + bind.bind_ortvalue_input("X", ort_value) + bind.bind_output("Y", inner_device) + sess._sess.run_with_iobinding(bind, None) + ortvaluevector = bind.get_outputs() + self.assertIsInstance(ortvaluevector, OrtValueVector) + ortvalue = bind.get_outputs()[0] + y = ortvalue.numpy() + assert_almost_equal(x, y) + + bind = SessionIOBinding(sess._sess) + bind.bind_input("X", inner_device, dtype, x.shape, ort_value.data_ptr()) + bind.bind_output("Y", inner_device) + sess._sess.run_with_iobinding(bind, None) + ortvalue = bind.get_outputs()[0] + y = ortvalue.numpy() + assert_almost_equal(x, y) def test_bind_input_only(self): - input = self.create_ortvalue_input_on_gpu() + for device, execution_provider, _ in test_params: + with self.subTest(execution_provider): + if execution_provider not in onnxrt.get_available_providers(): + self.skipTest(f"Skipping on {device.upper()}.") + input = self._create_ortvalue_input_on_gpu(device) - session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) - io_binding = session.io_binding() + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) + io_binding = session.io_binding() - # Bind input to CUDA - io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr()) + # Bind input to the GPU + io_binding.bind_input("X", device, 0, np.float32, [3, 2], input.data_ptr()) - # Sync if different CUDA 
streams - io_binding.synchronize_inputs() + # Sync if different streams + io_binding.synchronize_inputs() - # Bind output to CPU - io_binding.bind_output("Y") + # Bind output to CPU + io_binding.bind_output("Y") - # Invoke Run - session.run_with_iobinding(io_binding) + # Invoke Run + session.run_with_iobinding(io_binding) - # Sync if different CUDA streams - io_binding.synchronize_outputs() + # Sync if different streams + io_binding.synchronize_outputs() - # Get outputs over to CPU (the outputs which were bound to CUDA will get copied over to the host here) - ort_output = io_binding.copy_outputs_to_cpu()[0] + # Get outputs over to CPU (the outputs which were bound to the GPU will get copied over to the host + # here) + ort_output = io_binding.copy_outputs_to_cpu()[0] - # Validate results - self.assertTrue(np.array_equal(self.create_expected_output(), ort_output)) + # Validate results + self.assertTrue(np.array_equal(self._create_expected_output(), ort_output)) def test_bind_input_and_preallocated_output(self): - input = self.create_ortvalue_input_on_gpu() + for device, execution_provider, _ in test_params: + with self.subTest(execution_provider): + if execution_provider not in onnxrt.get_available_providers(): + self.skipTest(f"Skipping on {device.upper()}.") - session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) - io_binding = session.io_binding() - - # Bind input to CUDA - io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr()) - - # Bind output to CUDA - output = self.create_uninitialized_ortvalue_input_on_gpu() - io_binding.bind_output("Y", "cuda", 0, np.float32, [3, 2], output.data_ptr()) - - # Sync if different CUDA streams - io_binding.synchronize_inputs() - - # Invoke Run - session.run_with_iobinding(io_binding) + input = self._create_ortvalue_input_on_gpu(device) - # Sync if different CUDA streams - io_binding.synchronize_outputs() + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) + io_binding = session.io_binding() - # Get outputs over to CPU (the outputs which were bound to CUDA will get copied over to the host here) - ort_output_vals = io_binding.copy_outputs_to_cpu()[0] - # Validate results - self.assertTrue(np.array_equal(self.create_expected_output(), ort_output_vals)) + # Bind input to the GPU + io_binding.bind_input("X", device, 0, np.float32, [3, 2], input.data_ptr()) - # Validate if ORT actually wrote to pre-allocated buffer by copying the Torch allocated buffer - # to the host and validating its contents - ort_output_vals_in_cpu = output.numpy() - # Validate results - self.assertTrue(np.array_equal(self.create_expected_output(), ort_output_vals_in_cpu)) + # Bind output to the GPU + output = self._create_uninitialized_ortvalue_input_on_gpu(device) + io_binding.bind_output("Y", device, 0, np.float32, [3, 2], output.data_ptr()) - def test_bind_input_and_non_preallocated_output(self): - session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) - io_binding = session.io_binding() + # Sync if different streams + io_binding.synchronize_inputs() - # Bind input to CUDA - io_binding.bind_input( - "X", - "cuda", - 0, - np.float32, - [3, 2], - self.create_ortvalue_input_on_gpu().data_ptr(), - ) + # Invoke Run + session.run_with_iobinding(io_binding) - # Bind output to CUDA - io_binding.bind_output("Y", "cuda") + # Sync if different streams + io_binding.synchronize_outputs() - # Sync if different CUDA streams - 
io_binding.synchronize_inputs() + # Get outputs over to CPU (the outputs which were bound to the GPU will get copied over to the host + # here) + ort_output_vals = io_binding.copy_outputs_to_cpu()[0] + # Validate results + self.assertTrue(np.array_equal(self._create_expected_output(), ort_output_vals)) - # Invoke Run - session.run_with_iobinding(io_binding) + # Validate if ORT actually wrote to pre-allocated buffer by copying the allocated buffer + # to the host and validating its contents + ort_output_vals_in_cpu = output.numpy() + # Validate results + self.assertTrue(np.array_equal(self._create_expected_output(), ort_output_vals_in_cpu)) - # Sync if different CUDA streams - io_binding.synchronize_outputs() + def test_bind_input_and_non_preallocated_output(self): + for device, execution_provider, _ in test_params: + with self.subTest(execution_provider): + if execution_provider not in onnxrt.get_available_providers(): + self.skipTest(f"Skipping on {device.upper()}.") + + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) + io_binding = session.io_binding() + + input = self._create_ortvalue_input_on_gpu(device) + + # Bind input to the GPU + io_binding.bind_input( + "X", + device, + 0, + np.float32, + [3, 2], + input.data_ptr(), + ) - # This call returns an OrtValue which has data allocated by ORT on CUDA - ort_outputs = io_binding.get_outputs() - self.assertEqual(len(ort_outputs), 1) - self.assertEqual(ort_outputs[0].device_name(), "cuda") - # Validate results (by copying results to CPU by creating a Numpy object) - self.assertTrue(np.array_equal(self.create_expected_output(), ort_outputs[0].numpy())) - - # We should be able to repeat the above process as many times as we want - try once more - ort_outputs = io_binding.get_outputs() - self.assertEqual(len(ort_outputs), 1) - self.assertEqual(ort_outputs[0].device_name(), "cuda") - # Validate results (by copying results to CPU by creating a Numpy object) - self.assertTrue(np.array_equal(self.create_expected_output(), ort_outputs[0].numpy())) - - # Change the bound input and validate the results in the same bound OrtValue - # Bind alternate input to CUDA - io_binding.bind_input( - "X", - "cuda", - 0, - np.float32, - [3, 2], - self.create_ortvalue_alternate_input_on_gpu().data_ptr(), - ) + # Bind output to the GPU + io_binding.bind_output("Y", device) + + # Sync if different streams + io_binding.synchronize_inputs() + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # This call returns an OrtValue which has data allocated by ORT on the GPU + ort_outputs = io_binding.get_outputs() + self.assertEqual(len(ort_outputs), 1) + self.assertEqual(ort_outputs[0].device_name(), device) + # Validate results (by copying results to CPU by creating a Numpy object) + self.assertTrue(np.array_equal(self._create_expected_output(), ort_outputs[0].numpy())) + + # We should be able to repeat the above process as many times as we want - try once more + ort_outputs = io_binding.get_outputs() + self.assertEqual(len(ort_outputs), 1) + self.assertEqual(ort_outputs[0].device_name(), device) + # Validate results (by copying results to CPU by creating a Numpy object) + self.assertTrue(np.array_equal(self._create_expected_output(), ort_outputs[0].numpy())) + + input = self._create_ortvalue_alternate_input_on_gpu(device) + + # Change the bound input and validate the results in the same bound OrtValue + # Bind alternate input to the GPU + 
io_binding.bind_input( + "X", + device, + 0, + np.float32, + [3, 2], + input.data_ptr(), + ) - # Sync if different CUDA streams - io_binding.synchronize_inputs() + # Sync if different streams + io_binding.synchronize_inputs() - # Invoke Run - session.run_with_iobinding(io_binding) + # Invoke Run + session.run_with_iobinding(io_binding) - # Sync if different CUDA streams - io_binding.synchronize_outputs() + # Sync if different streams + io_binding.synchronize_outputs() - # This call returns an OrtValue which has data allocated by ORT on CUDA - ort_outputs = io_binding.get_outputs() - self.assertEqual(len(ort_outputs), 1) - self.assertEqual(ort_outputs[0].device_name(), "cuda") - # Validate results (by copying results to CPU by creating a Numpy object) - self.assertTrue(np.array_equal(self.create_expected_output_alternate(), ort_outputs[0].numpy())) + # This call returns an OrtValue which has data allocated by ORT on the GPU + ort_outputs = io_binding.get_outputs() + self.assertEqual(len(ort_outputs), 1) + self.assertEqual(ort_outputs[0].device_name(), device) + # Validate results (by copying results to CPU by creating a Numpy object) + self.assertTrue(np.array_equal(self._create_expected_output_alternate(), ort_outputs[0].numpy())) def test_bind_input_and_bind_output_with_ortvalues(self): - session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) - io_binding = session.io_binding() + for device, execution_provider, _ in test_params: + with self.subTest(execution_provider): + if execution_provider not in onnxrt.get_available_providers(): + self.skipTest(f"Skipping on {device.upper()}.") - # Bind ortvalue as input - input_ortvalue = self.create_ortvalue_input_on_gpu() - io_binding.bind_ortvalue_input("X", input_ortvalue) + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) + io_binding = session.io_binding() - # Bind ortvalue as output - output_ortvalue = self.create_uninitialized_ortvalue_input_on_gpu() - io_binding.bind_ortvalue_output("Y", output_ortvalue) + # Bind ortvalue as input + input_ortvalue = self._create_ortvalue_input_on_gpu(device) + io_binding.bind_ortvalue_input("X", input_ortvalue) - # Sync if different CUDA streams - io_binding.synchronize_inputs() + # Bind ortvalue as output + output_ortvalue = self._create_uninitialized_ortvalue_input_on_gpu(device) + io_binding.bind_ortvalue_output("Y", output_ortvalue) - # Invoke Run - session.run_with_iobinding(io_binding) + # Sync if different streams + io_binding.synchronize_inputs() - # Sync if different CUDA streams - io_binding.synchronize_outputs() + # Invoke Run + session.run_with_iobinding(io_binding) - # Inspect contents of output_ortvalue and make sure that it has the right contents - self.assertTrue(np.array_equal(self.create_expected_output(), output_ortvalue.numpy())) + # Sync if different streams + io_binding.synchronize_outputs() - # Bind another ortvalue as input - input_ortvalue_2 = self.create_ortvalue_alternate_input_on_gpu() - io_binding.bind_ortvalue_input("X", input_ortvalue_2) + # Inspect contents of output_ortvalue and make sure that it has the right contents + self.assertTrue(np.array_equal(self._create_expected_output(), output_ortvalue.numpy())) - # Sync if different CUDA streams - io_binding.synchronize_inputs() + # Bind another ortvalue as input + input_ortvalue_2 = self._create_ortvalue_alternate_input_on_gpu(device) + io_binding.bind_ortvalue_input("X", input_ortvalue_2) - # Invoke Run - 
session.run_with_iobinding(io_binding)
+                # Sync if different streams
+                io_binding.synchronize_outputs()
 
-        # Inspect contents of output_ortvalue and make sure that it has the right contents
-        self.assertTrue(np.array_equal(self.create_expected_output_alternate(), output_ortvalue.numpy()))
+                # Inspect contents of output_ortvalue and make sure that it has the right contents
+                self.assertTrue(np.array_equal(self._create_expected_output_alternate(), output_ortvalue.numpy()))
 
 if __name__ == "__main__":
diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py
index 48129e15934dc..c4fb5499983cb 100644
--- a/tools/ci_build/build.py
+++ b/tools/ci_build/build.py
@@ -1840,13 +1840,12 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs):
                 [sys.executable, "onnxruntime_test_python_symbolic_shape_infer.py"], cwd=cwd, dll_path=dll_path
             )
 
-        # For CUDA enabled builds test IOBinding feature
-        if args.use_cuda:
-            # We need to have Torch installed to test the IOBinding feature
-            # which currently uses Torch's allocator to allocate GPU memory for testing
+        # For CUDA or DML enabled builds test IOBinding feature
+        if args.use_cuda or args.use_dml:
             log.info("Testing IOBinding feature")
             run_subprocess([sys.executable, "onnxruntime_test_python_iobinding.py"], cwd=cwd, dll_path=dll_path)
 
+        if args.use_cuda:
             log.info("Testing CUDA Graph feature")
             run_subprocess([sys.executable, "onnxruntime_test_python_cudagraph.py"], cwd=cwd, dll_path=dll_path)
 
From 5d3786206bbe8a84b9d3f9c62e310258e3570ff2 Mon Sep 17 00:00:00 2001
From: Changming Sun
Date: Wed, 13 Sep 2023 08:50:14 -0700
Subject: [PATCH 014/121] Fix ROCM's nightly build (#17518)

### Description
PR 15470 updated some C/C++ dependencies, and the change caused ROCM EP's nightly build to fail;
see issue https://github.com/ROCm-Developer-Tools/HIP/issues/2082 for background. The root cause is
that the HIP compiler has a special requirement: HIP's include dirs must be searched before the
operating system's include folder, /usr/include. HIP adds "-isystem" in front of "/usr/include",
and gcc and clang search the folders added with "-I" first, then the "-isystem" folders. This works
fine as long as we do not add "-I/usr/include" to the compile commands for *.cu files, but it
breaks when we have already installed an open-source library to /usr and want to use that prebuilt
library instead of the one in the current build dir.
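To make the ordering concrete, a minimal sketch (the paths and commands below are illustrative,
not copied from the actual build):

```bash
# HIP normally arranges for its headers to win, e.g. (simplified):
#   clang++ -isystem /opt/rocm/include foo.cu
# so the HIP include dirs are searched ahead of the default /usr/include.
#
# If CMake finds a prebuilt package under /usr and appends -I/usr/include:
#   clang++ -I/usr/include -isystem /opt/rocm/include foo.cu
# then /usr/include is searched first, because -I directories are searched
# before -isystem directories -- exactly the ordering HIP forbids.
```

Setting FETCHCONTENT_TRY_FIND_PACKAGE_MODE=NEVER makes CMake's FetchContent always build the
bundled dependencies instead of discovering packages installed under /usr, so "-I/usr/include"
never gets added to the *.cu compile commands.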
### Motivation and Context --- tools/ci_build/github/azure-pipelines/templates/rocm.yml | 2 +- tools/ci_build/github/linux/build_rocm_c_api_package.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/rocm.yml b/tools/ci_build/github/azure-pipelines/templates/rocm.yml index fe0f2c3791e72..cc2e8745e8946 100644 --- a/tools/ci_build/github/azure-pipelines/templates/rocm.yml +++ b/tools/ci_build/github/azure-pipelines/templates/rocm.yml @@ -91,7 +91,7 @@ jobs: --enable_training \ --cmake_extra_defines \ CMAKE_HIP_COMPILER=/opt/rocm/llvm/bin/clang++ \ - onnxruntime_BUILD_UNIT_TESTS=OFF \ + onnxruntime_BUILD_UNIT_TESTS=OFF FETCHCONTENT_TRY_FIND_PACKAGE_MODE=NEVER \ ${{ variables['EnableProfiling'] }} workingDirectory: $(Build.SourcesDirectory) displayName: 'Build onnxruntime (in container)' diff --git a/tools/ci_build/github/linux/build_rocm_c_api_package.sh b/tools/ci_build/github/linux/build_rocm_c_api_package.sh index 4d0af63893643..957f1f8a812a5 100755 --- a/tools/ci_build/github/linux/build_rocm_c_api_package.sh +++ b/tools/ci_build/github/linux/build_rocm_c_api_package.sh @@ -40,7 +40,7 @@ docker run --rm \ --use_rocm --rocm_version=$ROCM_VERSION --rocm_home $ROCM_HOME --nccl_home $ROCM_HOME \ --build_shared_lib \ --skip_submodule_sync \ - --skip_tests \ + --skip_tests --cmake_extra_defines FETCHCONTENT_TRY_FIND_PACKAGE_MODE=NEVER EXIT_CODE=$? From a2e75114cc9c38ab4c48fac4c884d53f2d0d17d4 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 13 Sep 2023 09:17:34 -0700 Subject: [PATCH 015/121] [js/web] add sessionOptions.freeDimensionOverrides (#17488) ### Description Allows to specify fixed size for dynamic input of a model. resolves #16707 Pending test --- js/common/lib/inference-session.ts | 7 +++++++ js/web/lib/wasm/binding/ort-wasm.d.ts | 1 + js/web/lib/wasm/session-options.ts | 15 +++++++++++++++ onnxruntime/wasm/api.cc | 6 ++++++ onnxruntime/wasm/api.h | 7 +++++++ 5 files changed, 36 insertions(+) diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index ec030084c9675..71a5912df2464 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -66,6 +66,13 @@ export declare namespace InferenceSession { */ interOpNumThreads?: number; + /** + * The free dimension override. + * + * This setting is available only in ONNXRuntime (Node.js binding and react-native) or WebAssembly backend + */ + freeDimensionOverrides?: {readonly [dimensionName: string]: number}; + /** * The optimization level. 
* diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 7f0430b7b28b9..59da1369e152e 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -54,6 +54,7 @@ export interface OrtWasmModule extends EmscriptenModule { enableProfiling: boolean, profileFilePrefix: number, logId: number, logSeverityLevel: number, logVerbosityLevel: number, optimizedModelFilePath: number): number; _OrtAppendExecutionProvider(sessionOptionsHandle: number, name: number): number; + _OrtAddFreeDimensionOverride(sessionOptionsHandle: number, name: number, dim: number): number; _OrtAddSessionConfigEntry(sessionOptionsHandle: number, configKey: number, configValue: number): number; _OrtReleaseSessionOptions(sessionOptionsHandle: number): void; diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 38caa9076e3c0..2659b471733f5 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -143,6 +143,21 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs); } + if (sessionOptions.freeDimensionOverrides) { + for (const [name, value] of Object.entries(sessionOptions.freeDimensionOverrides)) { + if (typeof name !== 'string') { + throw new Error(`free dimension override name must be a string: ${name}`); + } + if (typeof value !== 'number' || !Number.isInteger(value) || value < 0) { + throw new Error(`free dimension override value must be a non-negative integer: ${value}`); + } + const nameOffset = allocWasmString(name, allocs); + if (wasm._OrtAddFreeDimensionOverride(sessionOptionsHandle, nameOffset, value) !== 0) { + checkLastError(`Can't set a free dimension override: ${name} - ${value}.`); + } + } + } + if (sessionOptions.extra !== undefined) { iterateExtraOptions(sessionOptions.extra, '', new WeakSet>(), (key, value) => { const keyDataOffset = allocWasmString(key, allocs); diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 937d505015d3c..174edabbc91fe 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -155,6 +155,12 @@ int OrtAppendExecutionProvider(ort_session_options_handle_t session_options, con return CHECK_STATUS(SessionOptionsAppendExecutionProvider, session_options, name, nullptr, nullptr, 0); } +int OrtAddFreeDimensionOverride(ort_session_options_handle_t session_options, + const char* dim_param_name, + int dim_value) { + return CHECK_STATUS(AddFreeDimensionOverrideByName, session_options, dim_param_name, dim_value); +} + int OrtAddSessionConfigEntry(OrtSessionOptions* session_options, const char* config_key, const char* config_value) { diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index b9103414aae67..398c901e0e5ed 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -84,6 +84,13 @@ ort_session_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSessionOptions(size_t int EMSCRIPTEN_KEEPALIVE OrtAppendExecutionProvider(ort_session_options_handle_t session_options, const char* name); +/** + * add a free dimension override for one dimension of a session's input. + */ +int EMSCRIPTEN_KEEPALIVE OrtAddFreeDimensionOverride(ort_session_options_handle_t session_options, + const char* dim_param_name, + int dim_value); + /** * store configurations for a session. 
* @param session_options a handle to session options created by OrtCreateSessionOptions From 4e37c5d1f0c54d43ba2662f490d84cd0333c8813 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 13 Sep 2023 09:22:21 -0700 Subject: [PATCH 016/121] Bump actions/checkout from 3 to 4 (#17487) --- .github/workflows/cffconvert.yml | 2 +- .github/workflows/codeql.yml | 2 +- .github/workflows/gradle-wrapper-validation.yml | 2 +- .github/workflows/lint.yml | 6 +++--- .github/workflows/linux.yml | 2 +- .github/workflows/publish-c-apidocs.yml | 2 +- .github/workflows/publish-csharp-apidocs.yml | 2 +- .github/workflows/publish-java-apidocs.yml | 2 +- .github/workflows/publish-js-apidocs.yml | 2 +- .github/workflows/publish-objectivec-apidocs.yml | 2 +- .github/workflows/publish-python-apidocs.yml | 2 +- .github/workflows/windows.yml | 4 ++-- 12 files changed, 15 insertions(+), 15 deletions(-) diff --git a/.github/workflows/cffconvert.yml b/.github/workflows/cffconvert.yml index 6851c52d380ec..7144363717749 100644 --- a/.github/workflows/cffconvert.yml +++ b/.github/workflows/cffconvert.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out a copy of the repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Check whether the citation metadata from CITATION.cff is valid uses: citation-file-format/cffconvert-github-action@2.0.0 diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 2fe66013ebbbc..d3ecf44fe5733 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -33,7 +33,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github/workflows/gradle-wrapper-validation.yml index 07346b38b2151..03ea773a25130 100644 --- a/.github/workflows/gradle-wrapper-validation.yml +++ b/.github/workflows/gradle-wrapper-validation.yml @@ -10,5 +10,5 @@ jobs: name: "Validation" runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: gradle/wrapper-validation-action@v1 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 91f9a8ee3df40..432c789e943b5 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -12,7 +12,7 @@ jobs: name: Optional Lint runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: misspell # Check spellings as well uses: reviewdog/action-misspell@v1 with: @@ -34,7 +34,7 @@ jobs: name: Python format runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Python uses: actions/setup-python@v4 with: @@ -100,7 +100,7 @@ jobs: name: Lint JavaScript runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: reviewdog/action-eslint@v1 with: reporter: github-pr-check diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index dceb15b446a8a..7b314d845d9b4 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -14,7 +14,7 @@ jobs: Onnxruntime-TVM: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - uses: actions/setup-python@v4 diff --git a/.github/workflows/publish-c-apidocs.yml b/.github/workflows/publish-c-apidocs.yml index 73e8194bf7a8d..0a3e9ed2594c1 100644 --- 
a/.github/workflows/publish-c-apidocs.yml +++ b/.github/workflows/publish-c-apidocs.yml @@ -24,7 +24,7 @@ jobs: name: Generate C/C++ API docs runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install doxygen and dependencies run: | sudo apt update diff --git a/.github/workflows/publish-csharp-apidocs.yml b/.github/workflows/publish-csharp-apidocs.yml index 097d4a1cdff5e..9b9ca924bd008 100644 --- a/.github/workflows/publish-csharp-apidocs.yml +++ b/.github/workflows/publish-csharp-apidocs.yml @@ -24,7 +24,7 @@ jobs: env: DOCFXVERSION: 2.62.2 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup .NET uses: actions/setup-dotnet@v3 with: diff --git a/.github/workflows/publish-java-apidocs.yml b/.github/workflows/publish-java-apidocs.yml index cea350ba54de0..9ea9bda7e7c53 100644 --- a/.github/workflows/publish-java-apidocs.yml +++ b/.github/workflows/publish-java-apidocs.yml @@ -23,7 +23,7 @@ jobs: name: Generate Java docs runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up JDK 11 uses: actions/setup-java@v3 with: diff --git a/.github/workflows/publish-js-apidocs.yml b/.github/workflows/publish-js-apidocs.yml index 5668be77c98a4..ba8bfd718abfa 100644 --- a/.github/workflows/publish-js-apidocs.yml +++ b/.github/workflows/publish-js-apidocs.yml @@ -23,7 +23,7 @@ jobs: name: Generate JS API docs runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Node.js uses: actions/setup-node@v3 with: diff --git a/.github/workflows/publish-objectivec-apidocs.yml b/.github/workflows/publish-objectivec-apidocs.yml index b966793cc0d06..1b327eebfa8a8 100644 --- a/.github/workflows/publish-objectivec-apidocs.yml +++ b/.github/workflows/publish-objectivec-apidocs.yml @@ -23,7 +23,7 @@ jobs: name: Generate Objective-C API docs runs-on: macos-13 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install Jazzy run: | diff --git a/.github/workflows/publish-python-apidocs.yml b/.github/workflows/publish-python-apidocs.yml index 4ca1249fc1d8e..ab9d4781afb83 100644 --- a/.github/workflows/publish-python-apidocs.yml +++ b/.github/workflows/publish-python-apidocs.yml @@ -24,7 +24,7 @@ jobs: name: Generate Python API docs runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install tools run: | sudo apt-get update diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index 8cd62db77b744..ba24e7eebfb03 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -18,7 +18,7 @@ jobs: Windows-CUDA-12: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: false - uses: actions/setup-python@v4 @@ -46,7 +46,7 @@ jobs: Onnxruntime-TVM: runs-on: windows-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - uses: actions/setup-python@v4 From 7edff1c2bfca812ce352b8df8efc6c904c4292ca Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Wed, 13 Sep 2023 13:02:58 -0700 Subject: [PATCH 017/121] [DML EP] Add subgraph fusion support (#17504) ### Description ### Motivation and Context --- .../src/DmlGraphFusionTransformer.cpp | 37 +++++++++++++- .../src/DmlGraphFusionTransformer.h | 49 +++++++++++-------- .../src/GraphPartitioner.cpp | 18 +++---- .../src/GraphPartitioner.h | 5 +- 4 files changed, 76 insertions(+), 33 
deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp index a9d19a022d3e7..4813707cdf50c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp @@ -38,6 +38,16 @@ namespace Dml bool& modified, int graph_level, const onnxruntime::logging::Logger& logger) const + { + return ApplyImplHelper(graph, modified, graph_level, logger, {}); + } + + onnxruntime::common::Status DmlGraphFusionTransformer::ApplyImplHelper( + onnxruntime::Graph& graph, + bool& modified, + int graph_level, + const onnxruntime::logging::Logger& logger, + const std::unordered_map& implicitInputDefs) const { onnxruntime::ProviderType provider_type = onnxruntime::kDmlExecutionProvider; const gsl::not_null registry = m_providerImpl->GetKernelRegistry().get(); @@ -49,6 +59,30 @@ namespace Dml std::vector> compiledPartitionInfos; std::vector additionalSplittingNodes; + onnxruntime::GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (auto node_index : node_topology_list) + { + auto* node = graph.GetNode(node_index); + if (!node) + { + continue; // node was removed + } + + std::unordered_map subgraphImplicitInputDefs; + for (const onnxruntime::NodeArg* inputDef : node->ImplicitInputDefs()) + { + subgraphImplicitInputDefs[inputDef->Name()] = inputDef; + } + + for (auto& entry : node->GetAttributeNameToMutableSubgraphMap()) + { + auto& subgraph = *entry.second; + ORT_RETURN_IF_ERROR(ApplyImplHelper(subgraph, modified, graph_level + 1, logger, subgraphImplicitInputDefs)); + } + } + do { // Initializers needed by any graph partition @@ -62,7 +96,8 @@ namespace Dml m_providerImpl->GetSupportedDeviceDataTypeMask(), graphNodePropertyMap, requiredInitializerMap, - additionalSplittingNodes); + additionalSplittingNodes, + implicitInputDefs); // Reset the splitting nodes for the current iteration additionalSplittingNodes.clear(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h index b546f29f59719..19dab0c89943c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h @@ -2,32 +2,41 @@ // Licensed under the MIT License. 
#pragma once - +#include +#include #include "core/optimizer/graph_transformer.h" #include "core/framework/execution_providers.h" namespace Dml { - class ExecutionProviderImpl; +class ExecutionProviderImpl; + +class DmlGraphFusionTransformer : public onnxruntime::GraphTransformer +{ +public: + DmlGraphFusionTransformer( + const std::string& name, + const onnxruntime::IExecutionProvider* provider + ); + +public: + static inline const char* const DML_GRAPH_FUSION_NODE_NAME_PREFIX = "DmlFusedNode_"; + static inline const char* const DML_GRAPH_FUSION_NODE_DOMAIN = "DmlFusedNodeDomain"; - class DmlGraphFusionTransformer : public onnxruntime::GraphTransformer - { - public: - DmlGraphFusionTransformer( - const std::string& name, - const onnxruntime::IExecutionProvider* provider - ); +private: + onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, + bool& modified, + int graph_level, + const onnxruntime::logging::Logger& logger) const final; - public: - inline const static char* const DML_GRAPH_FUSION_NODE_NAME_PREFIX = "DmlFusedNode_"; - inline const static char* const DML_GRAPH_FUSION_NODE_DOMAIN = "DmlFusedNodeDomain"; + onnxruntime::common::Status ApplyImplHelper( + onnxruntime::Graph& graph, + bool& modified, + int graph_level, + const onnxruntime::logging::Logger& logger, + const std::unordered_map& implicitInputDefs) const; - private: - onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, - bool& modified, - int graph_level, - const onnxruntime::logging::Logger& logger) const final; - private: - const ExecutionProviderImpl* m_providerImpl = nullptr; - }; +private: + const ExecutionProviderImpl* m_providerImpl = nullptr; +}; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp index 2c8d4e4459f7f..18943878ccedc 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp @@ -345,13 +345,8 @@ namespace Dml // Whether any operator in the model contains a subgraph. This is true // if the graph being partitioned is itself within a subgraph, or contains // an operator with a subgraph. - bool ModelUsesSubgraph(const onnxruntime::GraphViewer& graph) + bool ContainsSubgraph(const onnxruntime::GraphViewer& graph) { - if (graph.IsSubgraph()) - { - return true; - } - const std::vector& toplogicalOrder = graph.GetNodesInTopologicalOrder(); for (size_t nodeIndex : toplogicalOrder) @@ -384,7 +379,8 @@ namespace Dml uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, - gsl::span additionalSplittingNodes) + gsl::span additionalSplittingNodes, + const std::unordered_map& implicitInputs) { // Nodes are uniquely identified by the name of their first output argument std::vector> partitions; @@ -419,7 +415,7 @@ namespace Dml } // Check whether this graph is a subgraph, or contains any node with a subgraph. - bool modelUsesSubgraph = ModelUsesSubgraph(graph); + bool containsSubgraph = ContainsSubgraph(graph); uint32_t splittingNodeIndex = 0; @@ -454,10 +450,10 @@ namespace Dml // Add a unique partition if graph node usage is not supported. // // Partitioning is disabled in models with subgraphs to work around issues with implicit inputs. - // The partitioning algorithm does not currently consider such inputs. 
Transfering shared initializers + // The partitioning algorithm does not currently consider such inputs. Transferring shared initializers // for partitions could also cause problems. Note, operators with subgraphs are currently not efficient // anyhow due to CPU/GPU copies. - if (modelUsesSubgraph || !isDmlGraphNode) + if (containsSubgraph || !isDmlGraphNode) { partitions.push_back(CreatePartitionAndFinalizeInputs(node, isDmlNode, false, nodeNameToPartitionMap)); continue; } @@ -505,7 +501,7 @@ namespace Dml firstNonFinalInputPartition->AddInput(arg->Name()); } - if (graphInputs.find(arg->Name()) != graphInputs.end()) + if (graphInputs.find(arg->Name()) != graphInputs.end() || implicitInputs.find(arg->Name()) != implicitInputs.end()) { firstNonFinalInputPartition->AddInput(arg->Name()); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h index 990ba00fc4672..37d577f647fb5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h @@ -3,6 +3,8 @@ #pragma once +#include +#include #include "core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h" namespace Dml @@ -48,5 +50,6 @@ namespace Dml uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, - gsl::span additionalSplittingNodes); + gsl::span additionalSplittingNodes, + const std::unordered_map& implicitInputs); } // namespace Dml From 03b56f7a73319e95e381a623325207665c6ac037 Mon Sep 17 00:00:00 2001 From: Arthur Islamov Date: Thu, 14 Sep 2023 00:11:17 +0400 Subject: [PATCH 018/121] [js/webgpu] FP16 extension registration (#17493) ### Description A first small change toward FP16 support.
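For context, this is roughly the feature-detection flow the change wires up, shown as a standalone sketch against the standard WebGPU API (not ORT code; the TypeScript types assume the `@webgpu/types` package, and the `probeF16` name is illustrative):

```ts
// Sketch: request the 'shader-f16' feature only when the adapter advertises it,
// then gate the WGSL `enable f16;` directive on the device actually having it.
async function probeF16(): Promise<void> {
  const adapter = await navigator.gpu.requestAdapter();
  if (!adapter) {
    throw new Error('WebGPU is not available');
  }
  const requiredFeatures: GPUFeatureName[] = [];
  if (adapter.features.has('shader-f16')) {
    requiredFeatures.push('shader-f16');
  }
  const device = await adapter.requestDevice({requiredFeatures});

  // A shader may only use f16 types if the device was created with the feature.
  const extension = device.features.has('shader-f16') ? 'enable f16;\n' : '';
  device.createShaderModule({code: `${extension}@compute @workgroup_size(1) fn main() {}`});
}
```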
--------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- js/web/lib/wasm/jsep/backend-webgpu.ts | 8 ++++++-- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 9 ++++++--- js/web/lib/wasm/jsep/webgpu/program-manager.ts | 7 +++++-- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 653957a9a3489..e6e78df2cfb23 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -110,6 +110,7 @@ export class WebGpuBackend { } this.env = env; + const requiredFeatures: GPUFeatureName[] = []; const deviceDescriptor: GPUDeviceDescriptor = { requiredLimits: { maxComputeWorkgroupStorageSize: adapter.limits.maxComputeWorkgroupStorageSize, @@ -121,13 +122,16 @@ export class WebGpuBackend { maxComputeWorkgroupSizeY: adapter.limits.maxComputeWorkgroupSizeY, maxComputeWorkgroupSizeZ: adapter.limits.maxComputeWorkgroupSizeZ, }, + requiredFeatures, }; // WebGPU Spec: Timestamp Queries Inside Passes // https://github.com/gpuweb/gpuweb/blob/main/proposals/timestamp-query-inside-passes.md if (adapter.features.has('timestamp-query-inside-passes')) { this.supportTimestampQuery = true; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - deviceDescriptor.requiredFeatures = ['timestamp-query-inside-passes' as any]; + requiredFeatures.push('timestamp-query-inside-passes' as GPUFeatureName); + } + if (adapter.features.has('shader-f16')) { + requiredFeatures.push('shader-f16'); } this.device = await adapter.requestDevice(deviceDescriptor); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index c96f4858db2ae..f3845e3110905 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -192,11 +192,14 @@ export interface IndicesHelper { } const getWgslMappedType = (type: number, components: 1|2|3|4): string|[string, string] => { + if (components === 3) { + throw new Error('vec3 has same alignment as vec4, use vec4 instead'); + } + // return type is [ storage type, runtime type ] or a single string for both switch (type) { - // TODO: enable after "shader-f16" WSGL extension release - // case DataType.float16: - // return components > 1 ? `vec${components}` : 'f16'; + case DataType.float16: + return components > 1 ? `vec${components}` : 'f16'; case DataType.float: return components > 1 ? `vec${components}` : 'f32'; case DataType.int32: diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index a02d2ebeebf78..cce61be3448cd 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -126,10 +126,13 @@ export class ProgramManager { } build(programInfo: ProgramInfo, normalizedDispatchGroupSize: [number, number, number]): Artifact { const device = this.backend.device; - + const extensions: string[] = []; + if (device.features.has('shader-f16')) { + extensions.push('enable f16;'); + } const shaderHelper = createShaderHelper(normalizedDispatchGroupSize); const userCode = programInfo.getShaderSource(shaderHelper); - const code = `${shaderHelper.additionalImplementations}\n${userCode}`; + const code = `${extensions.join('\n')}\n${shaderHelper.additionalImplementations}\n${userCode}`; const shaderModule = device.createShaderModule({code, label: programInfo.name}); LOG_DEBUG('verbose', () => `[WebGPU] shader code: ${code}`); From 32f5658abb5f2fc82c981b623d8e7969eb46e14f Mon Sep 17 00:00:00 2001 From: cao lei Date: Wed, 13 Sep 2023 21:47:43 -0700 Subject: [PATCH 019/121] remove gsl to make status.h independent from gsl (#17402) ### Description Make status.h independent from gsl. ### Motivation and Context For the upcoming external EP API feature (see the prototype https://github.com/microsoft/onnxruntime/pull/16718), we need to expose stream in the public header; however, stream depends on status.h, which in turn depends on gsl, so we are looking for a way to decouple stream from gsl. Per Changming's offline comment, prefast is disabled, so the GSL_SUPPRESS annotations currently have no effect; he will handle the warnings when prefast is enabled in the future. --- include/onnxruntime/core/common/status.h | 3 --- onnxruntime/core/framework/tuning_context.h | 1 + .../dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h | 1 + .../core/providers/nnapi/nnapi_builtin/builders/helper.h | 1 + winml/lib/Api.Ort/OnnxruntimeEngine.h | 1 + 5 files changed, 4 insertions(+), 3 deletions(-) diff --git a/include/onnxruntime/core/common/status.h b/include/onnxruntime/core/common/status.h index d6e1992944feb..8f171daabbb1e 100644 --- a/include/onnxruntime/core/common/status.h +++ b/include/onnxruntime/core/common/status.h @@ -19,7 +19,6 @@ limitations under the License. #ifdef _WIN32 #include #endif -#include "core/common/gsl.h" namespace onnxruntime { namespace common { @@ -121,10 +120,8 @@ class [[nodiscard]] Status { Status(StatusCategory category, int code); - GSL_SUPPRESS(r.11) Status(const Status& other) : state_((other.state_ == nullptr) ?
nullptr : new State(*other.state_)) {} - GSL_SUPPRESS(r.11) Status& operator=(const Status& other) { if (state_ != other.state_) { if (other.state_ == nullptr) { diff --git a/onnxruntime/core/framework/tuning_context.h b/onnxruntime/core/framework/tuning_context.h index b6569a21e4c91..aae70d85814bc 100644 --- a/onnxruntime/core/framework/tuning_context.h +++ b/onnxruntime/core/framework/tuning_context.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include "core/common/common.h" diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h index e9c63cc72a837..f94270cfadb8b 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h @@ -5,6 +5,7 @@ #include "core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h" #include "MLOperatorAuthorPrivate.h" +#include "core/common/gsl.h" #ifdef ORT_NO_EXCEPTIONS #define ML_CHECK_BOOL(x) ORT_THROW_HR_IF(E_INVALIDARG, !(x)) diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h index 421c55a2c91a8..766034b3decea 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h @@ -8,6 +8,7 @@ #include "core/common/inlined_containers.h" #include "core/graph/basic_types.h" #include "core/providers/nnapi/nnapi_builtin/nnapi_lib/NeuralNetworksTypes.h" +#include "core/common/gsl.h" // This is the minimal Android API Level required by ORT NNAPI EP to run // ORT running on any host system with Android API level less than this will fall back to CPU EP diff --git a/winml/lib/Api.Ort/OnnxruntimeEngine.h b/winml/lib/Api.Ort/OnnxruntimeEngine.h index 5974d46b82c4f..eae7dc37941c7 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEngine.h +++ b/winml/lib/Api.Ort/OnnxruntimeEngine.h @@ -3,6 +3,7 @@ #include "iengine.h" #include "UniqueOrtPtr.h" +#include "core/common/gsl.h" #include #include From ad369a1fadb9939b615c442f2a9f40ef5e7b6ec0 Mon Sep 17 00:00:00 2001 From: Hans Date: Thu, 14 Sep 2023 13:02:27 +0800 Subject: [PATCH 020/121] [js/rn] Support create boolean tensor (#17052) ### Description Some use cases need to create a boolean tensor. I've tested it with [this project](https://github.com/hans00/react-native-transformers-example). ### Motivation and Context Adds handling of `ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL`. It also depends on #15556, which does not seem to be included in the latest release (v1.15.1).
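For reference, this is roughly what the change enables on the JavaScript side (a sketch, not part of the PR; the model path and the input name `mask` are placeholders, and in the JS API a `bool` tensor is backed by a `Uint8Array` of 0/1 values):

```ts
import {InferenceSession, Tensor} from 'onnxruntime-react-native';

// Sketch: create a rank-1 boolean tensor (one byte per element) and feed it
// to a session whose graph takes a boolean input.
async function runWithBoolInput(modelPath: string): Promise<void> {
  const session = await InferenceSession.create(modelPath);
  const mask = new Tensor('bool', Uint8Array.from([0, 1]), [2]);
  const results = await session.run({mask});
  console.log(Object.keys(results));
}
```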
--- .../reactnative/TensorHelperTest.java | 28 +++++++++++++++++++ .../onnxruntime/reactnative/TensorHelper.java | 6 +++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java b/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java index 76fd608e4362b..72518488e6682 100644 --- a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java +++ b/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java @@ -238,6 +238,34 @@ public void createInputTensor_double() throws Exception { outputTensor.close(); } + @Test + public void createInputTensor_bool() throws Exception { + OnnxTensor outputTensor = OnnxTensor.createTensor(ortEnvironment, new boolean[] {false, true}); + + JavaOnlyMap inputTensorMap = new JavaOnlyMap(); + + JavaOnlyArray dims = new JavaOnlyArray(); + dims.pushInt(2); + inputTensorMap.putArray("dims", dims); + + inputTensorMap.putString("type", TensorHelper.JsTensorTypeBool); + + ByteBuffer dataByteBuffer = ByteBuffer.allocate(2); + dataByteBuffer.put((byte)0); + dataByteBuffer.put((byte)1); + inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array())); + + OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment); + + Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL); + Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL); + Assert.assertEquals(inputTensor.toString(), outputTensor.toString()); + Assert.assertArrayEquals(inputTensor.getByteBuffer().array(), outputTensor.getByteBuffer().array()); + + inputTensor.close(); + outputTensor.close(); + } + @Test public void createOutputTensor_bool() throws Exception { MockitoSession mockSession = mockitoSession().mockStatic(Arguments.class).startMocking(); diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java index d9c2e3bac5d9b..63cddace36640 100644 --- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java +++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java @@ -174,7 +174,11 @@ private static OnnxTensor createInputTensor(TensorInfo.OnnxTensorType tensorType tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims, OnnxJavaType.UINT8); break; } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { + ByteBuffer buffer = values; + tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims, OnnxJavaType.BOOL); + break; + } case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: From 5af6279440a0db698b0afbfc3de655400562831f Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 14 Sep 2023 07:36:01 -0700 Subject: [PATCH 021/121] Fix Android build (#17540) ### Description The new cpuinfo library doesn't use clog on Android.
Newer XNNPack versions have removed the dependency on clog, but the version we use still has it, so I cherry-picked the XNNPack change into our patch file. --- .../xnnpack/AddEmscriptenAndIosSupport.patch | 48 +++++++++++++++++-- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch b/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch index 7296f2f30f286..37bdbf9fb53f6 100644 --- a/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch +++ b/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch @@ -1,8 +1,8 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index d53c48aa1..4c987bd7a 100755 +index d53c48aa1..77c3cf983 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -105,7 +105,7 @@ ENDIF() +@@ -105,22 +105,12 @@ ENDIF() IF(NOT CMAKE_SYSTEM_NAME) MESSAGE(FATAL_ERROR "CMAKE_SYSTEM_NAME not defined") @@ -11,7 +11,22 @@ index d53c48aa1..4c987bd7a 100755 MESSAGE(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_NAME = ${CMAKE_SYSTEM_NAME}") ENDIF() # ---[ Download deps IF(NOT XNNPACK_USE_SYSTEM_LIBS) +- IF(NOT DEFINED CLOG_SOURCE_DIR) +- MESSAGE(STATUS "Downloading clog to ${CMAKE_BINARY_DIR}/clog-source (define CLOG_SOURCE_DIR to avoid it)") +- CONFIGURE_FILE(cmake/DownloadCLog.cmake "${CMAKE_BINARY_DIR}/clog-download/CMakeLists.txt") +- EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . +- WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/clog-download") +- EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build . +- WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/clog-download") +- SET(CLOG_SOURCE_DIR "${CMAKE_BINARY_DIR}/clog-source" CACHE STRING "clog source directory") +- ENDIF() +- + IF(NOT DEFINED CPUINFO_SOURCE_DIR) + MESSAGE(STATUS "Downloading cpuinfo to ${CMAKE_BINARY_DIR}/cpuinfo-source (define CPUINFO_SOURCE_DIR to avoid it)") + CONFIGURE_FILE(cmake/DownloadCpuinfo.cmake "${CMAKE_BINARY_DIR}/cpuinfo-download/CMakeLists.txt") +@@ -7108,6 +7098,10 @@ IF(MSVC) SET_PROPERTY(SOURCE ${ALL_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$>: /O2 >") SET_PROPERTY(SOURCE ${HOT_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$>: /O2 >") SET_PROPERTY(SOURCE ${COLD_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$>: /O1 >") @@ -22,3 +37,30 @@ index d53c48aa1..4c987bd7a 100755 ELSE() SET_PROPERTY(SOURCE ${ALL_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$>: -O2 >") SET_PROPERTY(SOURCE ${HOT_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$>: -O2 >") +@@ -7142,26 +7136,6 @@ IF(LIBM) + TARGET_LINK_LIBRARIES(indirection PRIVATE ${LIBM}) + ENDIF() + +-# ---[ Configure clog +-IF(NOT TARGET clog) +- IF(NOT XNNPACK_USE_SYSTEM_LIBS) +- SET(CLOG_BUILD_TESTS OFF CACHE BOOL "") +- SET(CLOG_RUNTIME_TYPE "${CPUINFO_RUNTIME_TYPE}" CACHE STRING "") +- ADD_SUBDIRECTORY( +- "${CLOG_SOURCE_DIR}/deps/clog" +- "${CMAKE_BINARY_DIR}/clog") +- # We build static version of clog but a dynamic library may indirectly depend on it +- SET_PROPERTY(TARGET clog PROPERTY POSITION_INDEPENDENT_CODE ON) +- ELSE() +- ADD_LIBRARY(clog STATIC IMPORTED) +- FIND_LIBRARY(CLOG_LIBRARY clog) +- IF(NOT CLOG_LIBRARY) +- MESSAGE(FATAL_ERROR "Cannot find clog") +- ENDIF() +- SET_PROPERTY(TARGET clog PROPERTY IMPORTED_LOCATION "${CLOG_LIBRARY}") +- ENDIF() +-ENDIF() +- + # ---[ Configure cpuinfo + IF(NOT TARGET cpuinfo) + IF(NOT XNNPACK_USE_SYSTEM_LIBS) From 7af2f68ef307a66a770a620ad2611e511bb2008f Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 14 Sep 2023 10:05:31 -0700 Subject: [PATCH 022/121]
[js/web] add a test flag to customize chromium flags (#17545) ### Description add a test flag to customize chromium flags. Usage: npm test -- \ --chromium-flags=<...> --- js/web/karma.conf.js | 55 +++++++------------ js/web/script/test-runner-cli-args.ts | 16 +++++- js/web/script/test-runner-cli.ts | 46 +++++++++------- .../azure-pipelines/templates/win-web-ci.yml | 2 +- 4 files changed, 62 insertions(+), 57 deletions(-) diff --git a/js/web/karma.conf.js b/js/web/karma.conf.js index 63c6f5bb045a7..35f782d1fdca3 100644 --- a/js/web/karma.conf.js +++ b/js/web/karma.conf.js @@ -3,10 +3,22 @@ 'use strict'; -const bundleMode = require('minimist')(process.argv)['bundle-mode'] || 'dev'; // 'dev'|'perf'|undefined; -const karmaPlugins = require('minimist')(process.argv)['karma-plugins'] || undefined; -const timeoutMocha = require('minimist')(process.argv)['timeout-mocha'] || 60000; -const forceLocalHost = !!require('minimist')(process.argv)['force-localhost']; +const args = require('minimist')(process.argv, {}); +const bundleMode = args['bundle-mode'] || 'dev'; // 'dev'|'perf'|undefined; +const karmaPlugins = args['karma-plugins'] || undefined; +const timeoutMocha = args['timeout-mocha'] || 60000; +const forceLocalHost = !!args['force-localhost']; + +// parse chromium flags +let chromiumFlags = args['chromium-flags']; +if (!chromiumFlags) { + chromiumFlags = []; +} else if (typeof chromiumFlags === 'string') { + chromiumFlags = [chromiumFlags]; +} else if (!Array.isArray(chromiumFlags)) { + throw new Error(`Invalid command line arg: --chromium-flags: ${chromiumFlags}`); +} + const commonFile = bundleMode === 'dev' ? '../common/dist/ort-common.js' : '../common/dist/ort-common.min.js' const mainFile = bundleMode === 'dev' ? 'test/ort.dev.js' : 'test/ort.perf.js'; @@ -91,37 +103,10 @@ module.exports = function(config) { listenAddress, customLaunchers: { // the following flags are used to make sure Edge on CI agents to initialize WebGPU correctly. 
- EdgeWebGpuTest: {base: 'Edge', flags: ['--ignore-gpu-blocklist', '--gpu-vendor-id=0x10de']}, - ChromeTest: {base: 'Chrome', flags: ['--enable-features=SharedArrayBuffer']}, - ChromeTestHeadless: {base: 'ChromeHeadless', flags: ['--enable-features=SharedArrayBuffer']}, - ChromeDebug: - {debug: true, base: 'Chrome', flags: ['--remote-debugging-port=9333', '--enable-features=SharedArrayBuffer']}, - ChromeCanaryTest: { - base: 'ChromeCanary', - flags: ['--enable-features=SharedArrayBuffer', '--enable-experimental-web-platform-features'] - }, - ChromeCanaryDebug: { - debug: true, - base: 'ChromeCanary', - flags: [ - '--remote-debugging-port=9333', '--enable-features=SharedArrayBuffer', - '--enable-experimental-web-platform-features' - ] - }, - ChromeWebGpuProfileTest: { - base: 'Chrome', - flags: - ['--window-size=1,1', '--enable-features=SharedArrayBuffer', '--disable-dawn-features=disallow_unsafe_apis'] - }, - ChromeWebGpuProfileDebug: { - debug: true, - base: 'Chrome', - flags: [ - '--remote-debugging-port=9333', - '--enable-features=SharedArrayBuffer', - '--disable-dawn-features=disallow_unsafe_apis', - ] - }, + EdgeTest: {base: 'Edge', flags: chromiumFlags}, + ChromeTest: {base: 'Chrome', flags: chromiumFlags}, + ChromeTestHeadless: {base: 'ChromeHeadless', flags: chromiumFlags}, + ChromeCanaryTest: {base: 'ChromeCanary', flags: chromiumFlags}, // // ==== BrowserStack browsers ==== // diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index e34529fa1037d..7b41850948149 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -80,6 +80,7 @@ Options: --no-sandbox This flag will be passed to Chrome. Sometimes Chrome need this flag to work together with Karma. + --chromium-flags=<...> This flag will be passed to Chrome and Edge browsers. Can be used multiple times. Examples: @@ -173,6 +174,7 @@ export interface TestRunnerCliArgs { webglOptions?: InferenceSession.WebGLExecutionProviderOption; globalEnvFlags?: Test.Options['globalEnvFlags']; noSandbox?: boolean; + chromiumFlags: string[]; } @@ -439,6 +441,17 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs // Option: --no-sandbox const noSandbox = !!args['no-sandbox']; + // parse chromium flags + let chromiumFlags = args['chromium-flags']; + if (!chromiumFlags) { + chromiumFlags = []; + } else if (typeof chromiumFlags === 'string') { + chromiumFlags = [chromiumFlags]; + } else if (!Array.isArray(chromiumFlags)) { + throw new Error(`Invalid command line arg: --chromium-flags: ${chromiumFlags}`); + } + + npmlog.verbose('TestRunnerCli.Init', ` Mode: ${mode}`); npmlog.verbose('TestRunnerCli.Init', ` Env: ${env}`); npmlog.verbose('TestRunnerCli.Init', ` Debug: ${debug}`); @@ -462,6 +475,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs webglOptions, wasmOptions, globalEnvFlags, - noSandbox + noSandbox, + chromiumFlags }; } diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index fa4312ee0aaf3..520ef62b2c719 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -475,10 +475,12 @@ async function main() { args.bundleMode === 'perf' ? 'perf' : args.debug ? 
'debug' : 'test', - webgpu, webnn, config.options.globalEnvFlags?.webgpu?.profilingMode === 'default'); + webgpu, webnn); const karmaArgs = ['karma', 'start', `--browsers ${browser}`]; + const chromiumFlags = ['--enable-features=SharedArrayBuffer', ...args.chromiumFlags]; if (args.debug) { karmaArgs.push('--log-level info --timeout-mocha 9999999'); + chromiumFlags.push('--remote-debugging-port=9333'); } else { karmaArgs.push('--single-run'); } @@ -488,7 +490,22 @@ async function main() { if (webgpu || webnn) { karmaArgs.push('--force-localhost'); } + if (webgpu) { + if (browser.includes('Canary')) { + chromiumFlags.push('--enable-dawn-features=allow_unsafe_apis,use_dxc'); + } else { + chromiumFlags.push('--enable-dawn-features=use_dxc'); + chromiumFlags.push('--disable-dawn-features=disallow_unsafe_apis'); + } + } + if (webnn) { + chromiumFlags.push('--enable-experimental-web-platform-features'); + } + if (config.options.globalEnvFlags?.webgpu?.profilingMode === 'default') { + chromiumFlags.push('--disable-dawn-features=disallow_unsafe_apis'); + } karmaArgs.push(`--bundle-mode=${args.bundleMode}`); + karmaArgs.push(...chromiumFlags.map(flag => `--chromium-flags=${flag}`)); if (browser.startsWith('Edge')) { // There are currently 2 Edge browser launchers: // - karma-edge-launcher: used to launch the old Edge browser @@ -580,12 +597,12 @@ async function main() { } function getBrowserNameFromEnv( - env: TestRunnerCliArgs['env'], mode: 'debug'|'perf'|'test', webgpu: boolean, webnn: boolean, profile: boolean) { + env: TestRunnerCliArgs['env'], mode: 'debug'|'perf'|'test', webgpu: boolean, webnn: boolean) { switch (env) { case 'chrome': - return selectChromeBrowser(mode, webgpu, webnn, profile); + return selectChromeBrowser(mode, webgpu, webnn); case 'edge': - return webgpu ? 'EdgeWebGpuTest' : 'Edge'; + return 'EdgeTest'; case 'firefox': return 'Firefox'; case 'electron': @@ -599,25 +616,14 @@ async function main() { } } - function selectChromeBrowser(mode: 'debug'|'perf'|'test', webgpu: boolean, webnn: boolean, profile: boolean) { - if (webgpu) { - switch (mode) { - case 'debug': - return profile ? 'ChromeWebGpuProfileDebug' : 'ChromeDebug'; - default: - return profile ? 
'ChromeWebGpuProfileTest' : 'ChromeTest'; - } - } else if (webnn) { - switch (mode) { - case 'debug': - return 'ChromeCanaryDebug'; - default: - return 'ChromeCanaryTest'; - } + function selectChromeBrowser(mode: 'debug'|'perf'|'test', webgpu: boolean, webnn: boolean) { + if (webnn) { + return 'ChromeCanaryTest'; + } else if (webgpu) { + return 'ChromeTest'; + } else { switch (mode) { case 'debug': - return 'ChromeDebug'; case 'perf': return 'ChromeTest'; default: diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index 713396dd64532..bad7448715936 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -161,7 +161,7 @@ jobs: displayName: 'Run ort-web tests (wasm,webgl,xnnpack backend)' condition: ne('${{ parameters.RunWebGpuTests }}', 'true') - script: | - npm test -- -e=edge -b=webgl,wasm,xnnpack,webgpu + npm test -- -e=edge -b=webgl,wasm,xnnpack,webgpu --chromium-flags=--ignore-gpu-blocklist --chromium-flags=--gpu-vendor-id=0x10de workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (ALL backends)' condition: ne('${{ parameters.RunWebGpuTests }}', 'false') From e11849e716d7d63d0ff2eee8e66b93406c1ef547 Mon Sep 17 00:00:00 2001 From: Rachel Guo <35738743+YUNQIUGUO@users.noreply.github.com> Date: Thu, 14 Sep 2023 10:58:25 -0700 Subject: [PATCH 023/121] Configure StringNormalizer `default_locale` for _APPLE_ system (#17339) ### Description As the title says: iOS uses a different syntax for specifying language/region codes (see https://developer.apple.com/documentation/xcode/choosing-localization-regions-and-scripts), so the current `default_locale` does not work on iOS.
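The platform difference can be demonstrated with a small standalone probe (a sketch, not ORT code; which names succeed depends on the platform's locale database):

```cpp
#include <iostream>
#include <locale>
#include <stdexcept>

// Sketch: std::locale throws std::runtime_error for a name the platform does
// not recognize. iOS accepts the hyphenated "en-US.UTF-8" form but not the
// POSIX-style "en_US.UTF-8", which is the case this change accounts for.
int main() {
  for (const char* name : {"en_US.UTF-8", "en-US.UTF-8"}) {
    try {
      std::locale probe{name};
      std::cout << name << ": supported\n";
    } catch (const std::runtime_error&) {
      std::cout << name << ": not supported on this platform\n";
    }
  }
  return 0;
}
```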
### Motivation and Context Issue: https://github.com/microsoft/onnxruntime/issues/17017 --------- Co-authored-by: rachguo Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- .../core/providers/cpu/nn/string_normalizer.cc | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/nn/string_normalizer.cc b/onnxruntime/core/providers/cpu/nn/string_normalizer.cc index 330b92d4a8f78..27407e999945a 100644 --- a/onnxruntime/core/providers/cpu/nn/string_normalizer.cc +++ b/onnxruntime/core/providers/cpu/nn/string_normalizer.cc @@ -201,7 +201,16 @@ class Utf8Converter { #endif -const std::string default_locale("en_US.UTF-8"); // All non-MS +#if defined(__APPLE__) +#include +#if TARGET_OS_IPHONE || TARGET_OS_SIMULATOR +const std::string default_locale("en-US.UTF-8"); +#else +const std::string default_locale("en_US.UTF-8"); // Other kinds of Apple Platforms including MacOS, etc +#endif +#else +const std::string default_locale("en_US.UTF-8"); // All non-MS and not Apple +#endif #endif // _MSC_VER From 198d4688490e8dfd447ba2cc7c99bb30c9b0ef92 Mon Sep 17 00:00:00 2001 From: xhcao Date: Fri, 15 Sep 2023 04:14:11 +0800 Subject: [PATCH 024/121] [WebGPU/JS] Added Pad operator support (#16928) ### Description ### Motivation and Context --- js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 2 + js/web/lib/wasm/jsep/webgpu/ops/pad.ts | 252 ++++++++++++++++++ js/web/test/suite-test-list.jsonc | 11 +- .../providers/js/js_execution_provider.cc | 12 + .../core/providers/js/operators/pad.cc | 72 +++++ onnxruntime/core/providers/js/operators/pad.h | 34 +++ 7 files changed, 379 insertions(+), 5 deletions(-) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/pad.ts create mode 100644 onnxruntime/core/providers/js/operators/pad.cc create mode 100644 onnxruntime/core/providers/js/operators/pad.h diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index de53f943bc9ef..71d98f5d73671 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -59,6 +59,7 @@ Do not modify directly.* | Mul | ai.onnx(7-12,13,14+) | | | Neg | ai.onnx(6-12,13+) | | | Not | ai.onnx(1+) | | +| Pad | ai.onnx(2-10,11-12,13-17,18,19+) | | | Pow | ai.onnx(7-11,12,13-14,15+) | | | Reciprocal | ai.onnx(6-12,13+) | | | ReduceL1 | ai.onnx(1-10,11-12,13-17,18+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 9c46b97694903..e92e6696d9a78 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -14,6 +14,7 @@ import {gemm, parseGemmAttributes} from './ops/gemm'; import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm'; import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm'; import {matMul} from './ops/matmul'; +import {pad, parsePadAttributes} from './ops/pad'; import * as pool from './ops/pool'; import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; import {parseResizeAttributes, resize} from './ops/resize'; @@ -80,6 +81,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Mul', [binaryOps.mul]], ['Neg', [unaryOps.neg]], ['Not', [unaryOps.not]], + ['Pad', [pad, parsePadAttributes]], ['Pow', [binaryOps.pow]], ['Reciprocal', [unaryOps.reciprocal]], ['ReduceMin', [reduceMin, parseReduceAttributes]], diff --git 
a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts new file mode 100644 index 0000000000000..d90296b5c5a46 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -0,0 +1,252 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; + +import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; + +export interface PadAttributes extends AttributeWithCacheKey { + // 0-constant, 1-reflect, 2-edge, 3-wrap + readonly mode: number; + readonly value: number; + readonly pads: number[]; +} + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (!inputs || inputs.length < 1) { + throw new Error('Too few inputs'); + } + if (inputs[0].dataType !== DataType.float) { + throw new Error('Input type must be float.'); + } + + if (inputs.length >= 2) { + let validPads = inputs[0].dims.length * 2 === inputs[1].dims[0]; + if (inputs.length === 4) { + validPads = inputs[3].dims[0] * 2 === inputs[1].dims[0]; + } + if (!validPads) { + throw new Error('The pads should be a 1D tensor of shape [2 * input_rank] or [2 * num_axes].'); + } + } +}; + +const getPadConstant = + (output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[], + inputStrides: readonly number[], pads: number[], dataType: string, constantValue: number): string => { + const inputRank = inputDims.length; + + let block = ''; + for (let i = inputRank - 1; i >= 0; --i) { + block += ` + k = i32(${output.indicesGet('indices', i)}) - ${pads[i]}; + if (k < 0) { + break; + } + if (k >= ${inputDims[i]}) { + break; + } + offset += k * ${inputStrides[i]}; + `; + } + + return ` + value = ${dataType}(${constantValue}); + for (var i = 0; i < 1; i++) { + var offset = 0; + var k = 0; + ${block} + value = x[offset]; + } + `; + }; + +const getPadReflect = + (output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[], + inputStrides: readonly number[], pads: number[]): string => { + const inputRank = inputDims.length; + + let block = ''; + for (let i = inputRank - 1; i >= 0; --i) { + block += ` + k = i32(${output.indicesGet('indices', i)}) - ${pads[i]}; + if (k < 0) { + k = -k; + } + { + let _2n_1 = ${2 * (inputDims[i] - 1)}; + k = k % _2n_1; + if(k >= ${inputDims[i]}) { + k = _2n_1 - k; + } + } + offset += k * ${inputStrides[i]}; + `; + } + + return ` + var offset = 0; + var k = 0; + ${block} + value = x[offset]; + `; + }; + +const getPadEdge = + (output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[], + inputStrides: readonly number[], pads: number[]): string => { + const inputRank = inputDims.length; + + let block = ''; + for (let i = inputRank - 1; i >= 0; --i) { + block += ` + k = i32(${output.indicesGet('indices', i)}) - ${pads[i]}; + if (k < 0) { + k = 0; + } + if (k >= ${inputDims[i]}) { + k = ${inputDims[i] - 1}; + } + offset += k * ${inputStrides[i]}; + `; + } + + return ` + var offset = 0; + var k = 0; + ${block} + value = x[offset]; + `; + }; + +const getPadWrap = + (output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[], + inputStrides: readonly number[], pads: number[]): string => { + const inputRank = 
inputDims.length; + + let block = ''; + for (let i = inputRank - 1; i >= 0; --i) { + block += ` + k = i32(${output.indicesGet('indices', i)}) - ${pads[i]}; + if (k < 0) { + k += ${inputDims[i]}; + } + if (k >= ${inputDims[i]}) { + k -= ${inputDims[i]}; + } + offset += k * ${inputStrides[i]}; + `; + } + + return ` + var offset = 0; + var k = 0; + ${block} + value = x[offset]; + `; + }; + +const getPadSnippet = + (output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[], + inputStrides: readonly number[], attributes: PadAttributes, dataType: string): string => { + switch (attributes.mode) { + case 0: + return getPadConstant( + output, outputDims, inputDims, inputStrides, attributes.pads, dataType, attributes.value); + case 1: + return getPadReflect(output, outputDims, inputDims, inputStrides, attributes.pads); + case 2: + return getPadEdge(output, outputDims, inputDims, inputStrides, attributes.pads); + case 3: + return getPadWrap(output, outputDims, inputDims, inputStrides, attributes.pads); + default: + throw new Error('Invalid mode'); + } + }; + +const generatePadCode = + (shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: PadAttributes, dataType: string): + string => { + const inputDims = inputs[0].dims; + const outputDims = ShapeUtil.padShape(inputDims.slice(), attributes.pads); + const outputSize = ShapeUtil.size(outputDims); + const inputStrides = ShapeUtil.computeStrides(inputDims); + + const output = outputVariable('output', inputs[0].dataType, outputDims); + const input = inputVariable('x', inputs[0].dataType, inputDims); + + const padSnippet = getPadSnippet(output, outputDims, inputDims, inputStrides, attributes, dataType); + const padCode = ` + ${shaderHelper.declareVariables(input, output)} + ${output.impl()} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + + let indices = ${output.offsetToIndices('global_idx')}; + + var value = ${dataType}(0); + ${padSnippet} + output[global_idx] = value; + }`; + return padCode; + }; + +const createPadProgramInfo = + (inputs: readonly TensorView[], metadata: ProgramMetadata, attributes: PadAttributes): ProgramInfo => { + const outputShape = ShapeUtil.padShape(inputs[0].dims.slice(), attributes.pads); + return { + ...metadata, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + getShaderSource: shaderHelper => generatePadCode(shaderHelper, inputs, attributes, 'f32'), + dispatchGroup: () => ({x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)}) + }; + }; + +const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes: PadAttributes): PadAttributes => { + if (inputs.length > 1) { + const bigInt64Pads = inputs[1].getBigInt64Array(); + const value = (inputs.length >= 3) ? 
inputs[2].getFloat32Array()[0] : 0.0; + + const inputRank = inputs[0].dims.length; + const updatePads = new Int32Array(2 * inputRank).fill(0); + if (inputs.length >= 4) { + const axes = inputs[3].getBigInt64Array(); + for (let i = 0; i < axes.length; i++) { + updatePads[Number(axes[i])] = Number(bigInt64Pads[i]); + updatePads[Number(axes[i]) + inputRank] = Number(bigInt64Pads[i + axes.length]); + } + } else { + bigInt64Pads.forEach((i, v) => updatePads[Number(i)] = (Number(v))); + } + + const pads: number[] = []; + updatePads.forEach(v => pads.push(v)); + + return createAttributeWithCacheKey({mode: attributes.mode, value, pads}); + } else { + return attributes; + } +}; + +const createPadProgramInfoLoader = (inputs: readonly TensorView[], attributes: PadAttributes): ProgramInfoLoader => { + const updatedAttributes = createPadAttributesFromInputs(inputs, attributes); + const metadata: + ProgramMetadata = {name: 'Pad', inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey}; + return {...metadata, get: () => createPadProgramInfo(inputs, metadata, updatedAttributes)}; +}; + +export const pad = (context: ComputeContext, attributes: PadAttributes): void => { + validateInputs(context.inputs); + context.compute(createPadProgramInfoLoader(context.inputs, attributes), {inputs: [0]}); +}; + +export const parsePadAttributes = (attributes: Record): PadAttributes => { + const mode = attributes.mode as number; + const value = attributes.value as number; + const pads = attributes.pads as number[]; + return createAttributeWithCacheKey({mode, value, pads}); +}; diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index f4249b24101e5..e580259071968 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -505,7 +505,7 @@ // // "test_dynamicquantizelinear_min_adjusted_expanded", // // "test_dynamicquantizelinear_min_adjusted", // // "test_dynamicquantizelinear", - // // "test_edge_pad", + "test_edge_pad", // "test_einsum_batch_diagonal", // "test_einsum_batch_matmul", // "test_einsum_inner_prod", @@ -965,7 +965,7 @@ "test_reduce_sum_square_keepdims_random", "test_reduce_sum_square_negative_axes_keepdims_example", "test_reduce_sum_square_negative_axes_keepdims_random", - // // "test_reflect_pad", + "test_reflect_pad", "test_relu", // "test_reshape_allowzero_reordered", "test_reshape_extended_dims", @@ -1308,7 +1308,8 @@ "test_unsqueeze_three_axes", "test_unsqueeze_two_axes", "test_unsqueeze_unsorted_axes", - "test_unsqueeze" + "test_unsqueeze", + "test_wrap_pad" // "test_upsample_nearest", // "test_where_example", // "test_where_long_example", @@ -1361,8 +1362,8 @@ "reduce-min.jsonc", "relu.jsonc", "gelu.jsonc", - //"pad.jsonc", - //"pad-big.jsonc", + "pad.jsonc", + "pad-big.jsonc", "pow.jsonc", "pow_int32.jsonc", "pow-big-number.jsonc", diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 6b548921cdc8c..9dccd7c47fbb6 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -321,6 +321,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Pad); +class 
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Pad); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Pad); + std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); @@ -577,6 +583,12 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/js/operators/pad.cc b/onnxruntime/core/providers/js/operators/pad.cc new file mode 100644 index 0000000000000..24ba85cbf6e0d --- /dev/null +++ b/onnxruntime/core/providers/js/operators/pad.cc @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" + +#include "pad.h" + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Pad, + kOnnxDomain, + 2, + 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), + Pad); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Pad, + kOnnxDomain, + 11, + 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPU, 1) + .InputMemoryType(OrtMemTypeCPU, 2) + .InputMemoryType(OrtMemTypeCPU, 3), + Pad); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Pad, + kOnnxDomain, + 13, + 17, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPU, 1) + .InputMemoryType(OrtMemTypeCPU, 2) + .InputMemoryType(OrtMemTypeCPU, 3), + Pad); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Pad, + kOnnxDomain, + 18, + 18, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPU, 1) + .InputMemoryType(OrtMemTypeCPU, 2) + .InputMemoryType(OrtMemTypeCPU, 3), + Pad); + +ONNX_OPERATOR_KERNEL_EX( + Pad, + kOnnxDomain, + 19, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPU, 1) + .InputMemoryType(OrtMemTypeCPU, 2) + .InputMemoryType(OrtMemTypeCPU, 3), + Pad); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/pad.h b/onnxruntime/core/providers/js/operators/pad.h new file mode 100644 index 0000000000000..19168f40b4722 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/pad.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" +#include "core/providers/cpu/tensor/padbase.h" + +namespace onnxruntime { +namespace js { + +class Pad : public JsKernel, public PadBase { + public: + explicit Pad(const OpKernelInfo& info) : JsKernel(info), PadBase(info) { + std::vector pads; + if (!is_dynamic_) { + pads.resize(pads_.size()); + for (size_t i = 0; i < pads_.size(); ++i) { + pads[i] = gsl::narrow_cast(pads_[i]); + } + } + + JSEP_INIT_KERNEL_ATTRIBUTE(Pad, ({"mode" : $1, + "value" : $2, + "pads" : $3 ? Array.from(HEAP32.subarray($4, $4 + $3)) : []}), + static_cast(mode_), + static_cast(value_), + gsl::narrow_cast(pads.size()), + reinterpret_cast((pads.size() > 0) ? 
pads.data() : nullptr) >> 2); + } +}; + +} // namespace js +} // namespace onnxruntime From 46fe08226feb95ce03b6a4773d59b9786f14a3df Mon Sep 17 00:00:00 2001 From: Hector Li Date: Thu, 14 Sep 2023 14:22:45 -0700 Subject: [PATCH 025/121] [QNN EP] Enable Pad op support for QNN EP (#17508) ### Description Enable Pad op support for QNN EP to support more models --- .../selectors_actions/qdq_selectors.cc | 25 ++ .../selectors_actions/qdq_selectors.h | 10 + .../selectors_actions/shared/utils.cc | 11 + .../qnn/builder/op_builder_factory.cc | 4 + .../qnn/builder/op_builder_factory.h | 2 + .../qnn/builder/opbuilder/base_op_builder.h | 4 +- .../builder/opbuilder/gather_op_builder.cc | 1 - .../qnn/builder/opbuilder/pad_op_builder.cc | 247 +++++++++++++ .../builder/opbuilder/simple_op_builder.cc | 4 +- .../providers/qnn/builder/qnn_model_wrapper.h | 1 + .../test/providers/qnn/pad_op_test.cpp | 346 ++++++++++++++++++ 11 files changed, 651 insertions(+), 4 deletions(-) create mode 100644 onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc create mode 100644 onnxruntime/test/providers/qnn/pad_op_test.cpp diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 565afcc67e7df..02a7fb733813c 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -330,6 +330,31 @@ bool WhereNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& dt_input_1 == dt_output; } +bool PadNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const { + // Pad can have 1 or 2 dq input, the optional input constant_value can be quantized or non-quantized. + // QNN supports data input quantized with constant_value input non-quantized. 
+ int num_dq_inputs = static_cast(dq_nodes.size()); + if (num_dq_inputs > 2) { + return false; + } + + if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes, num_dq_inputs)) { + return false; + } + + const int32_t dt_input_1 = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + const int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (dq_nodes.size() > 1) { + const int32_t dt_input_2 = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + return dt_input_1 == dt_input_2 && + dt_input_1 == dt_output; + } else { + return dt_input_1 == dt_output; + } +} + bool InstanceAndLayerNormalizationNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const std::vector& dq_nodes, diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index ab9ad45697dfa..58ebf81508962 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -110,6 +110,16 @@ class WhereNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; }; +class PadNodeGroupSelector : public NodeGroupSelector { + public: + PadNodeGroupSelector() = default; + + private: + bool Check(const GraphViewer& graph_viewer, const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const override; +}; + // 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not // The lack of a trailing Q isn't really a QDQ node group, so we default support for that to off. class MatMulNodeGroupSelector : public NodeGroupSelector { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index 7783d3b3f36b7..f1bdd7a99c329 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -123,6 +123,9 @@ static const OpVersionsAndSelector::OpVersionsMap GetLogicalComparisonOpVersions static const OpVersionsAndSelector::OpVersionsMap GetWhereOpVersionsMap() { return {{"Where", {}}}; } +static const OpVersionsAndSelector::OpVersionsMap GetPadOpVersionsMap() { + return {{"Pad", {}}}; +} /* Selector rules registration related */ void RegisterMiscSelectors(Selectors& qdq_selectors) { @@ -217,6 +220,13 @@ void RegisterWhereSelectors(Selectors& qdq_selectors) { std::move(selector)); } +void RegisterPadSelectors(Selectors& qdq_selectors) { + /* register selectors for Pad ops */ + std::unique_ptr selector = std::make_unique(); + qdq_selectors.RegisterSelector(GetPadOpVersionsMap(), + std::move(selector)); +} + void SelectorManager::CreateSelectors() { RegisterMiscSelectors(qdq_selectors_); RegisterDropDQSelectors(qdq_selectors_); @@ -231,6 +241,7 @@ void SelectorManager::CreateSelectors() { RegisterBatchNormalizationSelector(qdq_selectors_); RegisterLogicalComparisonSelectors(qdq_selectors_); RegisterWhereSelectors(qdq_selectors_); + RegisterPadSelectors(qdq_selectors_); } void SelectorManager::InitializeSelectorsMap() { diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc index 58ac3ad45a577..fc8c2efc7a80f 100644 --- 
a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc @@ -154,6 +154,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() { { CreateTransposeOpBuilder("Transpose", *this); } + + { + CreatePadOpBuilder("Pad", *this); + } } const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) { diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h index 36cf0e7ff5ac0..5d59f4343d773 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h @@ -88,5 +88,7 @@ void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); + } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h index 14d5e45799b81..0431d605bc843 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h @@ -162,7 +162,9 @@ class BaseOpBuilder : public IOpBuilder { {"BatchNormalization", QNN_OP_BATCHNORM}, {"LayerNormalization", QNN_OP_LAYER_NORM}, - {"LRN", QNN_OP_LRN}}; + {"LRN", QNN_OP_LRN}, + + {"Pad", QNN_OP_PAD}}; auto it = onnx_op_type_to_qnn_op_type.find(onnx_op_type); ORT_ENFORCE(it != onnx_op_type_to_qnn_op_type.end()); return it->second; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc index bd07c099b3cfe..e203667576447 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc @@ -13,7 +13,6 @@ namespace onnxruntime { namespace qnn { -// Operator which only need to hanle node inputs & outputs, no attributes or no need to handle attributes class GatherOpBuilder : public BaseOpBuilder { public: GatherOpBuilder() : BaseOpBuilder("GatherOpBuilder") {} diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc new file mode 100644 index 0000000000000..2dfdfffe5fa54 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc @@ -0,0 +1,247 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+#include "core/providers/common.h"
+#include "core/providers/shared/utils/utils.h"
+#include "core/providers/qnn/builder/qnn_model_wrapper.h"
+#include "core/providers/qnn/builder/op_builder_factory.h"
+#include "core/providers/cpu/tensor/slice_helper.h"
+#include "core/common/safeint.h"
+
+#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
+
+namespace onnxruntime {
+namespace qnn {
+class PadOpBuilder : public BaseOpBuilder {
+ public:
+  PadOpBuilder() : BaseOpBuilder("PadOpBuilder") {}
+  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PadOpBuilder);
+
+ protected:
+  Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
+                       const NodeUnit& node_unit,
+                       const logging::Logger& logger,
+                       std::vector<std::string>& input_names,
+                       bool do_op_validation) const override ORT_MUST_USE_RESULT;
+
+  Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
+                                     const NodeUnit& node_unit,
+                                     std::vector<std::string>&& input_names,
+                                     const logging::Logger& logger,
+                                     bool do_op_validation) const override ORT_MUST_USE_RESULT;
+};
+
+Status PadOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
+                                   const NodeUnit& node_unit,
+                                   const logging::Logger& logger,
+                                   std::vector<std::string>& input_names,
+                                   bool do_op_validation) const {
+  const auto& inputs = node_unit.Inputs();
+  // QNN Pad has a single tensor input. The pads and constant_value inputs must be initializers
+  // that are set as QNN node parameters, and the axes input is not supported.
+  if (do_op_validation) {
+    ORT_RETURN_IF(inputs.size() > 3, "QNN Pad doesn't support axes.");
+    ORT_RETURN_IF(inputs.size() < 2, "QNN Pad requires the pads input.");
+
+    std::vector<uint32_t> input_shape;
+    ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, input_shape), "Cannot get shape of input 0.");
+    ORT_RETURN_IF(input_shape.size() > 5, "QNN Pad doesn't support more than 5 dimensions.");
+
+    auto& pads_input_name = inputs[1].node_arg.Name();
+    ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(pads_input_name),
+                      "QNN doesn't support a dynamic pads input.");
+    if (node_unit.Inputs().size() > 2) {
+      auto& constant_value_input_name = inputs[2].node_arg.Name();
+      ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(constant_value_input_name),
+                        "QNN doesn't support a dynamic constant_value input.");
+    }
+  }
+
+  ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names));
+
+  return Status::OK();
+}
+
+template <typename T>
+float DequantizeValue(T value, int32_t offset, float scale) {
+  return static_cast<float>(static_cast<int32_t>(value) - offset) * scale;
+}
+
+Status ProcessConstantValue(QnnModelWrapper& qnn_model_wrapper,
+                            std::vector<std::string>& param_tensor_names,
+                            const NodeUnit& node_unit,
+                            const NodeUnitIODef& input) {
+  OnnxInputInfo input_info = {};
+  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(input, input_info));
+  std::vector<uint8_t> unpacked_tensor;
+  // Already confirmed that the constant_value input is an initializer in ProcessInputs().
+  ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_info.initializer_tensor, unpacked_tensor));
+  Qnn_Scalar_t constant_value_qnn_scalar = QNN_SCALAR_INIT;
+  // constant_value is quantized
+  if (input.quant_param.has_value()) {
+    // QNN prefers pad_constant_value quantized with the same quantization params as in[0], with data stored as a 32-bit signed integer.
+    // ONNX doesn't guarantee the same quantization params as in[0], so get back the float32 value and use the non-quantized data directly.
+    constant_value_qnn_scalar.dataType = QNN_DATATYPE_FLOAT_32;
+    float constant_value = 0;
+    switch (input_info.qnn_data_type) {
+      case QNN_DATATYPE_SFIXED_POINT_8: {
+        auto int8_span = ReinterpretAsSpan<const int8_t>(gsl::make_span(unpacked_tensor));
+        constant_value = DequantizeValue(int8_span.data()[0],
+                                         input_info.quant_param.scaleOffsetEncoding.offset,
+                                         input_info.quant_param.scaleOffsetEncoding.scale);
+        break;
+      }
+      case QNN_DATATYPE_SFIXED_POINT_16: {
+        auto int16_span = ReinterpretAsSpan<const int16_t>(gsl::make_span(unpacked_tensor));
+        constant_value = DequantizeValue(int16_span.data()[0],
+                                         input_info.quant_param.scaleOffsetEncoding.offset,
+                                         input_info.quant_param.scaleOffsetEncoding.scale);
+        break;
+      }
+      case QNN_DATATYPE_SFIXED_POINT_32: {
+        auto int32_span = ReinterpretAsSpan<const int32_t>(gsl::make_span(unpacked_tensor));
+        constant_value = DequantizeValue(int32_span.data()[0],
+                                         input_info.quant_param.scaleOffsetEncoding.offset,
+                                         input_info.quant_param.scaleOffsetEncoding.scale);
+        break;
+      }
+      case QNN_DATATYPE_UFIXED_POINT_8: {
+        constant_value = DequantizeValue(unpacked_tensor.data()[0],
+                                         input_info.quant_param.scaleOffsetEncoding.offset,
+                                         input_info.quant_param.scaleOffsetEncoding.scale);
+        break;
+      }
+      case QNN_DATATYPE_UFIXED_POINT_16: {
+        auto uint16_span = ReinterpretAsSpan<const uint16_t>(gsl::make_span(unpacked_tensor));
+        constant_value = DequantizeValue(uint16_span.data()[0],
+                                         input_info.quant_param.scaleOffsetEncoding.offset,
+                                         input_info.quant_param.scaleOffsetEncoding.scale);
+        break;
+      }
+      case QNN_DATATYPE_UFIXED_POINT_32: {
+        auto uint32_span = ReinterpretAsSpan<const uint32_t>(gsl::make_span(unpacked_tensor));
+        constant_value = DequantizeValue(uint32_span.data()[0],
+                                         input_info.quant_param.scaleOffsetEncoding.offset,
+                                         input_info.quant_param.scaleOffsetEncoding.scale);
+        break;
+      }
+      default:
+        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Pad constant_value.");
+    }
+    constant_value_qnn_scalar.floatValue = constant_value;
+  } else {  // constant_value is non-quantized
+    constant_value_qnn_scalar.dataType = input_info.qnn_data_type;
+    switch (input_info.qnn_data_type) {
+      case QNN_DATATYPE_UINT_8: {
+        constant_value_qnn_scalar.uint8Value = unpacked_tensor.data()[0];
+        break;
+      }
+      case QNN_DATATYPE_INT_8: {
+        auto int8_span = ReinterpretAsSpan<const int8_t>(gsl::make_span(unpacked_tensor));
+        constant_value_qnn_scalar.int8Value = int8_span.data()[0];
+        break;
+      }
+      case QNN_DATATYPE_INT_16: {
+        auto int16_span = ReinterpretAsSpan<const int16_t>(gsl::make_span(unpacked_tensor));
+        constant_value_qnn_scalar.int16Value = int16_span.data()[0];
+        break;
+      }
+      case QNN_DATATYPE_INT_32: {
+        auto int32_span = ReinterpretAsSpan<const int32_t>(gsl::make_span(unpacked_tensor));
+        constant_value_qnn_scalar.int32Value = int32_span.data()[0];
+        break;
+      }
+      case QNN_DATATYPE_INT_64: {
+        auto int64_span = ReinterpretAsSpan<const int64_t>(gsl::make_span(unpacked_tensor));
+        constant_value_qnn_scalar.int64Value = int64_span.data()[0];
+        break;
+      }
+      case QNN_DATATYPE_FLOAT_32: {
+        auto float_span = ReinterpretAsSpan<const float>(gsl::make_span(unpacked_tensor));
+        constant_value_qnn_scalar.floatValue = float_span.data()[0];
+        break;
+      }
+      default:
+        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported.");
+    }  // switch
+  }    // if-else
+
+  QnnParamWrapper constant_value_param(node_unit.Index(),
+                                       node_unit.Name(),
+                                       QNN_OP_PAD_PARAM_PAD_CONSTANT_VALUE,
+                                       constant_value_qnn_scalar);
+  param_tensor_names.push_back(constant_value_param.GetParamTensorName());
+  qnn_model_wrapper.AddParamWrapper(std::move(constant_value_param));
+
+  return Status::OK();
+}
+
+Status PadOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
+                                                 const NodeUnit& node_unit,
+                                                 std::vector<std::string>&& input_names,
+                                                 const logging::Logger& logger,
+                                                 bool do_op_validation) const {
+  std::vector<std::string> param_tensor_names;
+  // Process pads input
+  // Already confirmed that the pads input is an initializer in ProcessInputs().
+  const auto& inputs = node_unit.Inputs();
+  const auto& pads_input_name = inputs[1].node_arg.Name();
+
+  std::vector<uint8_t> unpacked_tensor;
+  const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(pads_input_name);
+  ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor));
+  // ONNX pads are int64, QNN uses uint32.
+  const int64_t* tensor_data = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
+  size_t tensor_byte_size = unpacked_tensor.size();
+  size_t size = tensor_byte_size / sizeof(int64_t);
+
+  std::vector<uint32_t> pad_amount;
+  std::transform(tensor_data, tensor_data + size, std::back_inserter(pad_amount),
+                 [](int64_t item) { return SafeInt<uint32_t>(item); });
+  // ONNX format is begin_0, begin_1, ..., end_0, end_1, ...
+  // QNN format is begin_0, end_0, begin_1, end_1, ...
+  ReArranagePads(pad_amount);
+
+  std::vector<uint32_t> pad_amount_dim{static_cast<uint32_t>(pad_amount.size() / 2), static_cast<uint32_t>(2)};
+  QnnParamWrapper multiples_param(node_unit.Index(), node_unit.Name(), QNN_OP_PAD_PARAM_PAD_AMOUNT, std::move(pad_amount_dim),
+                                  std::move(pad_amount));
+  param_tensor_names.push_back(multiples_param.GetParamTensorName());
+  qnn_model_wrapper.AddParamWrapper(std::move(multiples_param));
+
+  // Process optional input constant_value
+  if (node_unit.Inputs().size() > 2) {
+    ORT_RETURN_IF_ERROR(ProcessConstantValue(qnn_model_wrapper, param_tensor_names, node_unit, inputs[2]));
+  }  // constant_value
+
+  NodeAttrHelper node_helper(node_unit);
+  std::string mode = node_helper.Get("mode", "constant");
+  Qnn_Scalar_t mode_qnn_scalar = QNN_SCALAR_INIT;
+  mode_qnn_scalar.dataType = QNN_DATATYPE_UINT_32;
+  if ("constant" == mode) {
+    mode_qnn_scalar.uint32Value = QNN_OP_PAD_SCHEME_CONSTANT;
+  } else if ("reflect" == mode) {
+    mode_qnn_scalar.uint32Value = QNN_OP_PAD_SCHEME_MIRROR_REFLECT;
+  } else if ("edge" == mode) {
+    mode_qnn_scalar.uint32Value = QNN_OP_PAD_SCHEME_EDGE;
+  } else {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN Pad only supports modes constant, reflect and edge.");
+  }
+
+  QnnParamWrapper mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_PAD_PARAM_SCHEME, mode_qnn_scalar);
+  param_tensor_names.push_back(mode_param.GetParamTensorName());
+  qnn_model_wrapper.AddParamWrapper(std::move(mode_param));
+
+  ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit,
+                                     std::move(input_names),
+                                     std::move(param_tensor_names),
+                                     logger, do_op_validation, GetQnnOpType(node_unit.OpType())));
+
+  return Status::OK();
+}
+
+void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
+  op_registrations.AddOpBuilder(op_type, std::make_unique<PadOpBuilder>());
+}
+
+}  // namespace qnn
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc
index 8abb847b20b46..556a86bb1519b 100644
--- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc
+++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc
@@ -118,9 +118,9 @@ Status ProcessModeAttribute(QnnModelWrapper& qnn_model_wrapper,
   Qnn_Scalar_t mode_qnn_scalar = QNN_SCALAR_INIT;
   mode_qnn_scalar.dataType = QNN_DATATYPE_UINT_32;
   if ("DCR" == mode) {
-    mode_qnn_scalar.uint32Value = 0;
+    mode_qnn_scalar.uint32Value = QNN_OP_DEPTH_TO_SPACE_MODE_DCR;
   } else if ("CRD" == mode) {
("CRD" == mode) { - mode_qnn_scalar.uint32Value = 1; // CRD mode + mode_qnn_scalar.uint32Value = QNN_OP_DEPTH_TO_SPACE_MODE_CRD; // CRD mode } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "DepthToSpace mode only support DCR & CRD."); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index 1f54bda9107c7..22f8d3a0eaa64 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -117,6 +117,7 @@ class QnnModelWrapper { return input_index_map_.find(tensor_name) != input_index_map_.end(); } + // TODO(hecli) rename to GetTensorInfo Status GetOnnxInputInfo(const NodeUnitIODef& input, OnnxInputInfo& input_info) const; Status AddReshapeNode(const std::string& input_name, diff --git a/onnxruntime/test/providers/qnn/pad_op_test.cpp b/onnxruntime/test/providers/qnn/pad_op_test.cpp new file mode 100644 index 0000000000000..95961e423833a --- /dev/null +++ b/onnxruntime/test/providers/qnn/pad_op_test.cpp @@ -0,0 +1,346 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include + +#include "core/graph/node_attr_utils.h" +#include "test/optimizer/qdq_test_utils.h" +#include "test/providers/qnn/qnn_test_utils.h" + +#include "onnx/onnx_pb.h" + +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Returns a function that creates a graph with a single Pad operator. +static GetTestModelFn BuildPadTestCase(const TestInputDef& data_def, + const TestInputDef& pads_def, + const TestInputDef& constant_value_def, + const std::vector& attrs, + bool has_constant_value = true) { + return [data_def, pads_def, constant_value_def, attrs, has_constant_value](ModelTestBuilder& builder) { + NodeArg* data = MakeTestInput(builder, data_def); + NodeArg* pads = MakeTestInput(builder, pads_def); + std::vector inputs{data, pads}; + if (has_constant_value) { + NodeArg* constant_value = MakeTestInput(builder, constant_value_def); + inputs.push_back(constant_value); + } + NodeArg* output = builder.MakeOutput(); + Node& pad_node = builder.AddNode("Pad", inputs, {output}); + + for (const auto& attr : attrs) { + pad_node.AddAttributeProto(attr); + } + }; +} + +// Returns a function that creates a graph with a QDQ Pad operator. 
+template <typename QuantType>
+GetTestQDQModelFn<QuantType> BuildPadQDQTestCase(const TestInputDef<float>& data_def,
+                                                 const TestInputDef<int64_t>& pads_def,
+                                                 const TestInputDef<float>& constant_value_def,
+                                                 const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs,
+                                                 bool has_constant_value,
+                                                 bool constant_value_quantized) {
+  return [data_def, pads_def, constant_value_def, attrs, has_constant_value,
+          constant_value_quantized](ModelTestBuilder& builder,
+                                    std::vector<QuantParams<QuantType>>& output_qparams) {
+    std::vector<NodeArg*> inputs;
+    // data -> Q -> DQ ->
+    NodeArg* data = MakeTestInput(builder, data_def);
+    QuantParams<QuantType> data_qparams = GetTestInputQuantParams<QuantType>(data_def);
+    NodeArg* data_qdq = AddQDQNodePair<QuantType>(builder, data, data_qparams.scale, data_qparams.zero_point);
+    inputs.push_back(data_qdq);
+
+    // pads
+    NodeArg* pads = MakeTestInput(builder, pads_def);
+    inputs.push_back(pads);
+
+    // constant_value -- QNN supports both quantized and non-quantized
+    if (has_constant_value) {
+      if (constant_value_quantized) {
+        // constant_value -> Q -> DQ ->
+        NodeArg* constant_value = MakeTestInput(builder, constant_value_def);
+        QuantParams<QuantType> constant_value_qparams = GetTestInputQuantParams<QuantType>(constant_value_def);
+        NodeArg* constant_value_qdq = AddQDQNodePair<QuantType>(builder, constant_value,
+                                                                constant_value_qparams.scale,
+                                                                constant_value_qparams.zero_point);
+        inputs.push_back(constant_value_qdq);
+      } else {
+        NodeArg* constant_value = MakeTestInput(builder, constant_value_def);
+        inputs.push_back(constant_value);
+      }
+    }
+
+    NodeArg* output = builder.MakeIntermediate();
+    Node& pad_node = builder.AddNode("Pad", inputs, {output});
+
+    for (const auto& attr : attrs) {
+      pad_node.AddAttributeProto(attr);
+    }
+
+    // op_output -> Q -> DQ -> output
+    AddQDQNodePairWithOutputAsGraphOutput<QuantType>(builder, output, output_qparams[0].scale,
+                                                     output_qparams[0].zero_point);
+  };
+}
+
+// Runs a Pad model on the QNN CPU backend. Checks the graph node assignment, and that inference
+// outputs for QNN and CPU match.
+static void RunPadOpTest(const TestInputDef<float>& data_def,
+                         const TestInputDef<int64_t>& pads_def,
+                         const TestInputDef<float>& constant_value_def,
+                         const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs,
+                         ExpectedEPNodeAssignment expected_ep_assignment,
+                         bool has_constant_value = true,
+                         int opset = 18) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnCpu.dll";
+#else
+  provider_options["backend_path"] = "libQnnCpu.so";
+#endif
+
+  RunQnnModelTest(BuildPadTestCase(data_def, pads_def, constant_value_def, attrs, has_constant_value),
+                  provider_options,
+                  opset,
+                  expected_ep_assignment);
+}
+
+// Runs a QDQ Pad model on the QNN HTP backend. Checks the graph node assignment, and that inference
+// outputs for QNN and CPU match.
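+// Note: TestQDQModelAccuracy (declared in qnn_test_utils.h) runs the float model on the CPU EP
+// as the reference and the QDQ model on the QNN EP, then compares outputs within the given
+// tolerance.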
+template <typename QuantType>
+static void RunQDQPadOpTest(const TestInputDef<float>& data_def,
+                            const TestInputDef<int64_t>& pads_def,
+                            const TestInputDef<float>& constant_value_def,
+                            const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs,
+                            ExpectedEPNodeAssignment expected_ep_assignment,
+                            bool has_constant_value = true,
+                            bool constant_value_quantized = true,
+                            int opset = 18) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  TestQDQModelAccuracy(BuildPadTestCase(data_def, pads_def, constant_value_def, attrs),
+                       BuildPadQDQTestCase<QuantType>(data_def, pads_def, constant_value_def, attrs,
+                                                      has_constant_value, constant_value_quantized),
+                       provider_options,
+                       opset,
+                       expected_ep_assignment,
+                       1e-5f);
+}
+
+//
+// CPU tests:
+//
+
+// Pad 2d
+TEST_F(QnnCPUBackendTests, Pad2d) {
+  RunPadOpTest(TestInputDef<float>({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.6f}),
+               TestInputDef<int64_t>({4}, true, {0, 2, 0, 0}),
+               TestInputDef<float>({1}, true, {0.0f}),
+               {utils::MakeAttribute("mode", "constant")},
+               ExpectedEPNodeAssignment::All);
+}
+
+// Pad 2d, pads input not initializer
+TEST_F(QnnCPUBackendTests, Pad2dPadsNotIni) {
+  RunPadOpTest(TestInputDef<float>({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.6f}),
+               TestInputDef<int64_t>({4}, false, {0, 2, 0, 0}),
+               TestInputDef<float>({1}, true, {0.0f}),
+               {utils::MakeAttribute("mode", "constant")},
+               ExpectedEPNodeAssignment::None);
+}
+
+// Pad reflect mode
+// Expected: contains 12 values, where each value and its corresponding value in 16-byte object <0C-00 00-00 00-00 00-00 40-01 23-05 EC-01 00-00> are an almost-equal pair
+// Actual: 16-byte object <0C-00 00-00 00-00 00-00 40-01 12-05 EC-01 00-00>, where the value pair (1.2, 0) at index #1 don't match, which is -1.2 from 1.2
+TEST_F(QnnCPUBackendTests, DISABLED_PadModeReflect) {
+  bool has_constant_value = false;
+  RunPadOpTest(TestInputDef<float>({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.6f}),
+               TestInputDef<int64_t>({4}, true, {0, 2, 0, 0}),
+               TestInputDef<float>({1}, true, {0.0f}),
+               {utils::MakeAttribute("mode", "reflect")},
+               ExpectedEPNodeAssignment::All,
+               has_constant_value);
+}
+
+// Pad edge mode
+TEST_F(QnnCPUBackendTests, PadModeEdge) {
+  bool has_constant_value = false;
+  RunPadOpTest(TestInputDef<float>({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.6f}),
+               TestInputDef<int64_t>({4}, true, {0, 2, 0, 0}),
+               TestInputDef<float>({1}, true, {0.0f}),
+               {utils::MakeAttribute("mode", "edge")},
+               ExpectedEPNodeAssignment::All,
+               has_constant_value);
+}
+
+// Pad wrap mode not supported
+TEST_F(QnnCPUBackendTests, PadModeWrap) {
+  bool has_constant_value = false;
+  RunPadOpTest(TestInputDef<float>({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.6f}),
+               TestInputDef<int64_t>({4}, true, {0, 2, 0, 0}),
+               TestInputDef<float>({1}, true, {0.0f}),
+               {utils::MakeAttribute("mode", "wrap")},
+               ExpectedEPNodeAssignment::None, // not supported
+               has_constant_value);
+}
+
+// Pad 4d
+TEST_F(QnnCPUBackendTests, Pad4d) {
+  RunPadOpTest(TestInputDef<float>({1, 2, 2, 2}, false,
+                                   {1.0f, 1.0f,
+                                    1.0f, 1.0f,
+                                    1.0f, 1.0f,
+                                    1.0f, 1.0f}),
+               TestInputDef<int64_t>({8}, true, {0, 0, 0, 1, 0, 0, 0, 1}),
+               TestInputDef<float>({1}, true, {0.0f}),
+               {utils::MakeAttribute("mode", "constant")},
+               ExpectedEPNodeAssignment::All);
+}
+
+// Pad 5d supported
+TEST_F(QnnCPUBackendTests, Pad5d) {
+  RunPadOpTest(TestInputDef<float>({1, 2, 2, 2, 2}, false, GetFloatDataInRange(1.0f, 10.0f, 16)),
+               TestInputDef<int64_t>({10}, true, {0, 0, 0, 1, 0, 0, 0, 1, 0, 0}),
+               TestInputDef<float>({1}, true, {5.0f}),
+               {utils::MakeAttribute("mode", "constant")},
+               ExpectedEPNodeAssignment::All);
+}
+
+// Pad 6d not supported
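+// (QNN Pad accepts at most 5 dimensions, as validated in PadOpBuilder::ProcessInputs(), so
+// the Pad node is expected to be left unassigned.)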
+TEST_F(QnnCPUBackendTests, Pad6d) {
+  RunPadOpTest(TestInputDef<float>({1, 2, 2, 2, 2, 2}, false, GetFloatDataInRange(1.0f, 10.0f, 32)),
+               TestInputDef<int64_t>({12}, true, {0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0}),
+               TestInputDef<float>({1}, true, {0.0f}),
+               {utils::MakeAttribute("mode", "constant")},
+               ExpectedEPNodeAssignment::None);
+}
+
+#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
+//
+// HTP tests:
+//
+// QDQ Pad
+TEST_F(QnnHTPBackendTests, PadNoConstantValue) {
+  bool has_constant_value_input = false;
+  RunQDQPadOpTest<uint8_t>(TestInputDef<float>({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.6f}),
+                           TestInputDef<int64_t>({4}, true, {0, 2, 0, 0}),
+                           TestInputDef<float>({1}, true, {0.0f}),
+                           {utils::MakeAttribute("mode", "constant")},
+                           ExpectedEPNodeAssignment::All,
+                           has_constant_value_input);
+}
+
+TEST_F(QnnHTPBackendTests, PadHasConstantValueNonQuantized) {
+  bool has_constant_value_input = true;
+  bool constant_value_quantized = false;
+  RunQDQPadOpTest<uint8_t>(TestInputDef<float>({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.6f}),
+                           TestInputDef<int64_t>({4}, true, {0, 2, 0, 0}),
+                           TestInputDef<float>({1}, true, {0.0f}),
+                           {utils::MakeAttribute("mode", "constant")},
+                           ExpectedEPNodeAssignment::All,
+                           has_constant_value_input,
+                           constant_value_quantized);
+}
+
+TEST_F(QnnHTPBackendTests, PadHasConstantValueQuantized) {
+  bool has_constant_value_input = true;
+  bool constant_value_quantized = true;
+  RunQDQPadOpTest<uint8_t>(TestInputDef<float>({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.6f}),
+                           TestInputDef<int64_t>({4}, true, {0, 2, 0, 0}),
+                           TestInputDef<float>({1}, true, {0.0f}),
+                           {utils::MakeAttribute("mode", "constant")},
+                           ExpectedEPNodeAssignment::All,
+                           has_constant_value_input,
+                           constant_value_quantized);
+}
+
+// QNN graph execute error. Error code: 6031
+TEST_F(QnnHTPBackendTests, DISABLED_PadReflectMode) {
+  bool has_constant_value_input = false;
+  RunQDQPadOpTest<uint8_t>(TestInputDef<float>({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.6f}),
+                           TestInputDef<int64_t>({4}, true, {0, 2, 0, 0}),
+                           TestInputDef<float>({1}, true, {0.0f}),
+                           {utils::MakeAttribute("mode", "reflect")},
+                           ExpectedEPNodeAssignment::All,
+                           has_constant_value_input);
+}
+
+TEST_F(QnnHTPBackendTests, PadEdgeMode) {
+  bool has_constant_value_input = false;
+  RunQDQPadOpTest<uint8_t>(TestInputDef<float>({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.6f}),
+                           TestInputDef<int64_t>({4}, true, {0, 2, 0, 0}),
+                           TestInputDef<float>({1}, true, {0.0f}),
+                           {utils::MakeAttribute("mode", "edge")},
+                           ExpectedEPNodeAssignment::All,
+                           has_constant_value_input);
+}
+
+// wrap mode not supported
+TEST_F(QnnHTPBackendTests, PadWrapMode) {
+  bool has_constant_value_input = false;
+  RunQDQPadOpTest<uint8_t>(TestInputDef<float>({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.6f}),
+                           TestInputDef<int64_t>({4}, true, {0, 2, 0, 0}),
+                           TestInputDef<float>({1}, true, {0.0f}),
+                           {utils::MakeAttribute("mode", "wrap")},
+                           ExpectedEPNodeAssignment::None,
+                           has_constant_value_input);
+}
+
+TEST_F(QnnHTPBackendTests, Pad4d) {
+  RunQDQPadOpTest<uint8_t>(TestInputDef<float>({1, 2, 2, 2}, false,
+                                               {1.0f, 2.0f,
+                                                3.0f, 4.0f,
+                                                5.0f, 6.0f,
+                                                7.0f, 8.0f}),
+                           TestInputDef<int64_t>({8}, true, {0, 0, 0, 1, 0, 0, 0, 1}),
+                           TestInputDef<float>({1}, true, {5.0f}),
+                           {utils::MakeAttribute("mode", "constant")},
+                           ExpectedEPNodeAssignment::All);
+}
+
+// Inaccuracy detected for output 'output', element 0.
+// Output quant params: scale=0.035294119268655777, zero_point=0.
+// Expected val: 9
+// QNN QDQ val: 8.0117654800415039 (err 0.98823451995849609)
+// CPU QDQ val: 9 (err 0)
+// QNN limitation? pad_constant_value has to be within the range of input[0].
+// Here pad_constant_value = 9.0 > max(input[0]) = 8.0
+TEST_F(QnnHTPBackendTests, DISABLED_Pad4dOutOfRangePadConstantValue) {
+  RunQDQPadOpTest<uint8_t>(TestInputDef<float>({1, 2, 2, 2}, false,
+                                               {1.0f, 2.0f,
+                                                3.0f, 4.0f,
+                                                5.0f, 6.0f,
+                                                7.0f, 8.0f}),
+                           TestInputDef<int64_t>({8}, true, {0, 0, 0, 1, 0, 0, 0, 1}),
+                           TestInputDef<float>({1}, true, {9.0f}), // pad_constant_value out of input[0] range
+                           {utils::MakeAttribute("mode", "constant")},
+                           ExpectedEPNodeAssignment::All);
+}
+
+// Pad 5d supported, but Quantize & Dequantize doesn't support 5d
+TEST_F(QnnHTPBackendTests, DISABLED_Pad5d) {
+  RunQDQPadOpTest<uint8_t>(TestInputDef<float>({1, 2, 2, 2, 2}, false, GetFloatDataInRange(1.0f, 10.0f, 16)),
+                           TestInputDef<int64_t>({10}, true, {0, 0, 0, 1, 0, 0, 0, 1, 0, 0}),
+                           TestInputDef<float>({1}, true, {2.0f}),
+                           {utils::MakeAttribute("mode", "constant")},
+                           ExpectedEPNodeAssignment::All);
+}
+
+#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
+
+} // namespace test
+} // namespace onnxruntime
+
+#endif // !defined(ORT_MINIMAL_BUILD)
\ No newline at end of file

From 41d2ff622c49aa3628c04f6b64ed7f33c8d80f30 Mon Sep 17 00:00:00 2001
From: Jiajia Qin 
Date: Fri, 15 Sep 2023 08:03:18 +0800
Subject: [PATCH 026/121] [js/webgpu] Optimize InstanceNormalization (#17491)

### Description
In the previous implementation, a single thread loops over all H * W elements twice to calculate the `mean` and `squaredNorm` values, and the same thread also writes all H * W outputs. As a result it is very slow when H * W is large, and H * W is usually large in real models. For example, in the `candy-8` model, the shapes of [H, W] are [224,224], [112,112], [56,56] for the `InstanceNormalization` op. On my ADL, `[1,224,224,32]` consumes 17 ms. See below:
```
[profiling] kernel "23848328|[InstanceNormalization] 23848328" input[0]: [1,224,224,32] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,224,224,32] | float32, execution time: 17007914 ns
```

This PR uses workgroup memory to optimize the original algorithm. The advantage is that the 64 (workgroupSize) threads in one workgroup calculate the `mean` and `squaredNorm` values in parallel, and each thread only writes `H * W / workgroupSize` outputs, which greatly reduces the per-thread overhead. With this optimization, `[1,224,224,32]` drops to 3 ms, and the main remaining overhead is the two extra `transpose` kernels; `createInstanceNormProgramInfo` itself only needs `0.64` ms. See below:
```
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,224,224,32] | float32, output[0]: [1,32,224,224] | float32, execution time: 1543792 ns
program-manager.ts:115 [profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,32,224,224] | float32, execution time: 642652 ns
program-manager.ts:115 [profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, output[0]: [1,224,224,32] | float32, execution time: 991608 ns
```

This PR currently only applies the new algorithm to the NCHW format. For the NHWC format, one way is to transpose the input so that it can use the new algorithm, but the disadvantage is that two extra transposes are added. @dakenf also proposed another way to optimize NHWC; see the details [here](https://github.com/microsoft/onnxruntime/blob/d45a96616da9843b037210f2d48d6b4e5bdae5c6/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts).
I checked @dakenf's method.
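For readers skimming the diff, the heart of the change is a strided accumulation followed by a workgroup tree reduction. Below is a minimal WGSL sketch of that pattern (an illustration only, not the exact shader in this diff; `loadElement` is a hypothetical accessor standing in for reading `x[batch, channel, h]`):
```
const workgroupSize = 64u;
var<workgroup> workgroupShared : array<f32, 64>;

// Hypothetical accessor; the real shader indexes the x input tensor here.
fn loadElement(h : u32) -> f32 { return 1.0; }

fn channelSum(localIndex : u32, normSize : u32) -> f32 {
  // Step 1: each of the 64 threads accumulates a strided partial sum.
  var partial = 0.0;
  for (var h = localIndex; h < normSize; h += workgroupSize) {
    partial = partial + loadElement(h);
  }
  workgroupShared[localIndex] = partial;
  workgroupBarrier();
  // Step 2: tree reduction; after log2(64) barrier steps, index 0 holds the total.
  for (var currSize = workgroupSize >> 1u; currSize > 0u; currSize = currSize >> 1u) {
    if (localIndex < currSize) {
      workgroupShared[localIndex] = workgroupShared[localIndex] + workgroupShared[localIndex + currSize];
    }
    workgroupBarrier();
  }
  return workgroupShared[0];
}
```
The real shader runs this shape twice, once for the mean and once for the squared deviations, and broadcasts each result to all threads through a shared variable.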
The perf of @dakenf's method is similar to transpose + this optimized NCHW path, but on different GPUs one is a little better than the other or vice versa. So I prefer that this PR only covers the NCHW part, and @dakenf can submit his optimization for NHWC.
---
 js/web/lib/wasm/jsep/webgpu/ops/common.ts | 3 +-
 .../lib/wasm/jsep/webgpu/ops/instance-norm.ts | 116 ++++++++++--------
 js/web/test/data/ops/instance-norm.jsonc | 79 ++++++++++++
 js/web/test/suite-test-list.jsonc | 2 +
 4 files changed, 147 insertions(+), 53 deletions(-)
 create mode 100644 js/web/test/data/ops/instance-norm.jsonc

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
index f3845e3110905..c054da51a3098 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
@@ -592,7 +592,8 @@ class ShaderHelperImpl implements ShaderHelper {
     const workgroupSizeZ = typeof workgroupSize === 'number' ? 1 : workgroupSize[2];
 
     const is1DimensionDispatch = this.normalizedDispatchGroup[1] === 1 && this.normalizedDispatchGroup[2] === 1;
-    const paramList = is1DimensionDispatch ? '@builtin(global_invocation_id) global_id : vec3<u32>' :
+    const paramList = is1DimensionDispatch ? `@builtin(global_invocation_id) global_id : vec3<u32>,
+    @builtin(local_invocation_id) local_id : vec3<u32>` :
         `@builtin(local_invocation_index) local_index : u32,
     @builtin(workgroup_id) workgroup_id : vec3<u32>`;
     const globalIdxDefinition = is1DimensionDispatch ?
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
index f62c766aa9ed0..449073a133295 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
@@ -1,83 +1,97 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-import {DataType} from '../../../wasm-common';
 import {TensorView} from '../../tensor';
 import {ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
 
-import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
+import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';
 
 export interface InstanceNormAttributes extends AttributeWithCacheKey {
   epsilon: number;
   format: 'NHWC'|'NCHW';
 }
 
-const validateInputs = (inputs: readonly TensorView[]): void => {
-  if (!inputs || inputs.length !== 3) {
-    throw new Error('instanceNorm requires 3 inputs.');
-  }
-
-  if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) {
-    throw new Error('inputs should be float type');
-  }
-};
-
 const createInstanceNormProgramInfo =
    (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: InstanceNormAttributes): ProgramInfo => {
      const xShape = inputs[0].dims;
-      const scale = inputs[1];
-      const bias = inputs[2];
      const outputShape = xShape;
-      const outputSize = ShapeUtil.size(outputShape);
      const axis = 2;
      const normCount = ShapeUtil.sizeToDimension(xShape, axis);
      const normSize = ShapeUtil.sizeFromDimension(xShape, axis);
      const C = xShape[1];
-
-      const scaleSize = ShapeUtil.size(scale.dims);
-      const biasSize = bias ? ShapeUtil.size(bias.dims) : 0;
-      if (scaleSize !== normSize || (bias && biasSize !== normSize)) {
-        throw new Error(`Size of X.shape()[axis:] == ${normSize}.
-       Size of scale and bias (if provided) must match this.
- Got scale size of ${scaleSize} and bias size of ${biasSize}`);
-      }
-
-      const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
-
+      const x = inputVariable('x', inputs[0].dataType, [xShape[0], xShape[1], normSize]);
+      const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims);
+      const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims);
+      const output = outputVariable('output', inputs[0].dataType, [xShape[0], xShape[1], normSize]);
+      const variables = [x, scale, bias, output];
+      const dataType = x.type.value;
+      const workgroupSize = 64;
       const getShaderSource = (shaderHelper: ShaderHelper) => `
+  const C: u32 = ${C};
   const normSize: u32 = ${normSize};
-  const normSizeTyped: ${dataType} = ${normSize};
   const epsilon: f32 = ${attributes.epsilon};
+  var<workgroup> meanShared : ${dataType};
+  var<workgroup> squaredNormShared : ${dataType};
+  var<workgroup> workgroupShared : array<${dataType}, ${workgroupSize}>;
+  const workgroupSize = ${workgroupSize}u;
+  ${shaderHelper.declareVariables(...variables)}
+  ${shaderHelper.mainStart(workgroupSize)}
+    let norm = global_idx / workgroupSize;
+    let batch = norm / C;
+    let channel = norm % C;
+    let localIndex = local_id.x;
+
+    // initialize workgroup memory
+    var initial: ${dataType} = 0;
+    for (var h = localIndex; h < normSize; h += workgroupSize) {
+      initial = initial + ${x.get('batch', 'channel', 'h')};
+    }
+    workgroupShared[localIndex] = initial;
+    workgroupBarrier();
 
-  @group(0) @binding(0) var<storage, read> x : array<${dataType}>;
-  @group(0) @binding(1) var<storage, read> scale : array<${dataType}>;
-  @group(0) @binding(2) var<storage, read> bias : array<${dataType}>;
-  @group(0) @binding(3) var<storage, read_write> output : array<${dataType}>;
-
-  ${shaderHelper.mainStart()}
-    let offset = global_idx * normSize;
-    if (offset + normSize >= ${outputSize}) { return; }
-    var mean: ${dataType} = 0;
+    // Calculate the mean of the current channel's data.
+    for (var currSize = workgroupSize >> 1; currSize > 0; currSize = currSize >> 1) {
+      if (localIndex < currSize) {
+        workgroupShared[localIndex] = workgroupShared[localIndex] + workgroupShared[localIndex + currSize];
+      }
+      workgroupBarrier();
+    }
+    if (localIndex == 0) {
+      meanShared = workgroupShared[0] / ${dataType}(normSize);
+    }
+    workgroupBarrier();
 
-    for (var h: u32 = 0u; h < normSize; h++) {
-      mean = mean + x[h + offset];
+    // reinitialize workgroup memory.
+    initial = 0;
+    for (var h = localIndex; h < normSize; h += workgroupSize) {
+      let deviation = ${x.get('batch', 'channel', 'h')} - meanShared;
+      initial = initial + deviation * deviation;
     }
-    mean = mean / normSizeTyped;
+    workgroupShared[localIndex] = initial;
+    workgroupBarrier();
 
-    var squaredNorm: ${dataType} = 0;
-    for (var h: u32 = 0u; h < normSize; h++) {
-      let deviation: f32 = x[h + offset] - mean;
-      squaredNorm = squaredNorm + deviation * deviation;
+    // Calculate the sum of squared deviations of the current channel's data.
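+    // Tree reduction, as in the mean pass above: each step halves the number of active
+    // threads until workgroupShared[0] holds the sum for the whole workgroup.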
+ for (var currSize = workgroupSize >> 1; currSize > 0; currSize = currSize >> 1) { + if (localIndex < currSize) { + workgroupShared[localIndex] = workgroupShared[localIndex] + workgroupShared[localIndex + currSize]; + } + workgroupBarrier(); } - let invStdDev = 1 / sqrt(squaredNorm / normSizeTyped + epsilon); - let channelScale = invStdDev * scale[global_idx % C]; - let channelShift = bias[global_idx % C] - mean * channelScale; - for (var j: u32 = 0; j < normSize; j++) { - output[j + offset] = x[j + offset] * channelScale + channelShift; + if (localIndex == 0) { + squaredNormShared = workgroupShared[0]; + } + workgroupBarrier(); + + let invStdDev = 1 / sqrt(squaredNormShared / ${dataType}(normSize) + epsilon); + let channelScale = invStdDev * ${scale.getByOffset('channel')}; + let channelShift = ${bias.getByOffset('channel')} - meanShared * channelScale; + for (var h = localIndex; h < normSize; h += workgroupSize) { + let value = ${x.get('batch', 'channel', 'h')} * channelScale + channelShift; + ${output.set('batch', 'channel', 'h', 'value')}; } }`; return { @@ -86,7 +100,7 @@ const createInstanceNormProgramInfo = {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}, ], getShaderSource, - dispatchGroup: () => ({x: Math.ceil(normCount / 64 /* workgroup size */)}) + dispatchGroup: () => ({x: normCount}) }; }; @@ -118,7 +132,7 @@ const createInstanceNormNHWCProgramInfo = ${shaderHelper.mainStart()} let currentImageNumber = global_idx / C; let currentChannelNumber = global_idx % C; - + // offset is channel num * N let offset = currentImageNumber * imageSize; if (offset >= ${outputSize}) { return; } @@ -156,8 +170,6 @@ export const parseInstanceNormAttributes = (attributes: InstanceNormAttributes): createAttributeWithCacheKey({epsilon: attributes.epsilon, format: attributes.format}); export const instanceNorm = (context: ComputeContext, attributes: InstanceNormAttributes): void => { - validateInputs(context.inputs); - const metadata = { name: 'InstanceNormalization', inputTypes: [GpuDataType.default, GpuDataType.default, GpuDataType.default], diff --git a/js/web/test/data/ops/instance-norm.jsonc b/js/web/test/data/ops/instance-norm.jsonc new file mode 100644 index 0000000000000..6a4e6912405ee --- /dev/null +++ b/js/web/test/data/ops/instance-norm.jsonc @@ -0,0 +1,79 @@ +[ + { + "name": "Simple test with NHWC", + "operator": "InstanceNormalization", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "cases": [ + { + "name": "Simple test", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5, 4], + "dims": [1, 4, 2, 2], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [4], + "type": "float32" + }, + { + "data": [4, 5, 6, 7], + "dims": [4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.6583645343780518, 3.552788257598877, 4.447211742401123, 5.341635704040527, 2.3167295455932617, + 4.105576515197754, 5.8944244384765625, 7.683271408081055, 6, 10.242595672607422, 6, 1.7574005126953125, + 12.36654281616211, 8.788846969604492, 5.211153030395508, 1.633458137512207 + ], + "dims": [1, 4, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Simple test with NCHW", + "operator": "InstanceNormalization", + "opset": { "domain": "", "version": 17 }, + "cases": [ + { + "name": "Simple test", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5, 4], + "dims": [1, 4, 2, 2], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [4], + "type": "float32" + }, 
+ { + "data": [4, 5, 6, 7], + "dims": [4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.6583645343780518, 3.552788257598877, 4.447211742401123, 5.341635704040527, 2.3167295455932617, + 4.105576515197754, 5.8944244384765625, 7.683271408081055, 6, 10.242595672607422, 6, 1.7574005126953125, + 12.36654281616211, 8.788846969604492, 5.211153030395508, 1.633458137512207 + ], + "dims": [1, 4, 2, 2], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index e580259071968..94592884ccad6 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -257,6 +257,7 @@ "greater.jsonc", //"identity.jsonc", "image-scaler.jsonc", + "instance-norm.jsonc", "less.jsonc", "log.jsonc", "matmul.jsonc", @@ -1347,6 +1348,7 @@ "gemm.jsonc", "global-average-pool.jsonc", "greater.jsonc", + "instance-norm.jsonc", "less.jsonc", "log.jsonc", "matmul.jsonc", From 5ed5f13920b66098cf3bfe7da85f463e6ff64bac Mon Sep 17 00:00:00 2001 From: Kaz Nishimura Date: Fri, 15 Sep 2023 09:47:45 +0900 Subject: [PATCH 027/121] [DML EP] Add missing member initializer DmlGraphNodeCreateInfo::nodeCount (#17505) ### Description This adds a missing member initialization. ### Motivation and Context It caused an access violation in `Dml::GraphDescBuilder::BuildGraphDesc`. --- .../dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index 232a022d869f4..04381b6ce355c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -80,7 +80,7 @@ namespace Windows::AI::MachineLearning::Adapter // Either nodesAsOperatorDesc or nodesAsIDMLOperator can have non-zero size. struct DmlGraphNodeCreateInfo { - uint32_t nodeCount; + uint32_t nodeCount = 0; std::vector> nodesAsOperatorDesc; std::vector> nodesAsIDMLOperator; std::vector inputEdges; From 3a1e48dd5ad2a9ef95b6bc137441a4df6b05e5d0 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 14 Sep 2023 18:15:29 -0700 Subject: [PATCH 028/121] update BERT notebook with ORT 1.16 (#17524) - Update BERT notebook with onnxruntime-gpu 1.16 - Add example of packing mode - Run results in RTX 4090 GPU --- .../PyTorch_Bert-Squad_OnnxRuntime_GPU.ipynb | 1433 +++++++++++------ 1 file changed, 956 insertions(+), 477 deletions(-) diff --git a/onnxruntime/python/tools/transformers/notebooks/PyTorch_Bert-Squad_OnnxRuntime_GPU.ipynb b/onnxruntime/python/tools/transformers/notebooks/PyTorch_Bert-Squad_OnnxRuntime_GPU.ipynb index 74b81fc7c867f..43c31e1ea45ac 100644 --- a/onnxruntime/python/tools/transformers/notebooks/PyTorch_Bert-Squad_OnnxRuntime_GPU.ipynb +++ b/onnxruntime/python/tools/transformers/notebooks/PyTorch_Bert-Squad_OnnxRuntime_GPU.ipynb @@ -33,19 +33,20 @@ "\n", "#### GPU Environment Setup using AnaConda\n", "\n", - "First, we install [AnaConda](https://www.anaconda.com/distribution/) in a target machine and open an AnaConda prompt window when it is done. Then run the following commands to create a conda environment. This notebook is tested with PyTorch 1.5.0 and OnnxRuntime 1.3.0.\n", + "First, we install [AnaConda](https://www.anaconda.com/distribution/) in a target machine and open an AnaConda prompt window when it is done. 
Then run the following commands to create a conda environment. This notebook is tested with PyTorch 2.0.1 and OnnxRuntime 1.16.0.\n",
    "\n",
    "```console\n",
    "conda create -n gpu_env python=3.10\n",
    "conda activate gpu_env\n",
    "pip install jupyterlab\n",
    "conda install ipykernel\n",
    "conda install -c conda-forge ipywidgets\n",
    "ipython kernel install --user --name gpu_env\n",
    "jupyter-lab\n",
    "```\n",
    "Finally, launch JupyterLab, and you can choose gpu_env as the kernel to run this notebook.\n",
    "\n",
    "Onnxruntime-gpu needs specific versions of CUDA and cuDNN. You can find the requirements [here](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements). Remember to add the directories to the PATH environment variable (see [CUDA and cuDNN Path](#CUDA-and-cuDNN-Path) below)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "if sys.platform in ['linux', 'win32']: # Linux or Windows\n",
    "    !{sys.executable} -m pip install torch --index-url https://download.pytorch.org/whl/cu118 -q\n",
    "    !{sys.executable} -m pip install onnxruntime-gpu onnx transformers psutil pandas py-cpuinfo py3nvml coloredlogs wget netron sympy protobuf==3.20.3 -q\n",
    "else: # Mac\n",
    "    print(\"CUDA is not available on MacOS\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CUDA and cuDNN Path\n",
    "onnxruntime-gpu depends on [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn). The required CUDA version can be found [here](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements). If you import torch before onnxruntime, onnxruntime might use the CUDA and cuDNN DLLs that were loaded by PyTorch."
] }, { @@ -79,10 +81,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "pytorch: 1.9.0+cu111\n", - "onnxruntime: 1.8.1\n", - "onnx: 1.9.0\n", - "transformers: 4.8.2\n" + "pytorch: 2.0.1+cu118\n", + "onnxruntime: 1.16.0\n", + "onnx: 1.14.1\n", + "transformers: 4.33.1\n" ] } ], @@ -191,9 +193,12 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 48/48 [00:03<00:00, 14.24it/s]\n", - "convert squad examples to features: 100%|██████████| 1000/1000 [00:08<00:00, 112.67it/s]\n", - "add example index and unique id: 100%|██████████| 1000/1000 [00:00<00:00, 836518.55it/s]\n" + "Some weights of the model checkpoint at bert-large-uncased-whole-word-masking-finetuned-squad were not used when initializing BertForQuestionAnswering: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n", + "- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 48/48 [00:02<00:00, 16.27it/s]\n", + "convert squad examples to features: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:03<00:00, 256.11it/s]\n", + "add example index and unique id: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00\n", " \n", " 0\n", - " 23.72\n", - " 23.72\n", - " 23.87\n", - " 23.99\n", - " 24.11\n", - " 24.37\n", - " 42.15\n", - " 4\n", + " 3.19\n", + " 3.16\n", + " 3.21\n", + " 3.27\n", + " 3.35\n", + " 3.52\n", + " 313.24\n", + " 1\n", " \n", " \n", " 1\n", - " 24.24\n", - " 24.24\n", - " 24.42\n", - " 24.60\n", - " 24.76\n", - " 25.23\n", - " 41.25\n", - " 3\n", + " 3.20\n", + " 3.17\n", + " 3.22\n", + " 3.25\n", + " 3.34\n", + " 3.50\n", + " 312.80\n", + " 8\n", " \n", " \n", " 2\n", - " 24.36\n", - " 24.36\n", - " 24.47\n", - " 24.69\n", - " 25.01\n", - " 26.52\n", - " 41.05\n", - " 2\n", + " 3.20\n", + " 3.15\n", + " 3.25\n", + " 3.29\n", + " 3.36\n", + " 3.58\n", + " 312.51\n", + " 15\n", " \n", " \n", " 3\n", - " 24.39\n", - " 24.37\n", - " 24.47\n", - " 24.65\n", - " 24.73\n", - " 25.12\n", - " 41.01\n", - " 1\n", + " 3.20\n", + " 3.18\n", + " 3.21\n", + " 3.26\n", + " 3.35\n", + " 3.53\n", + " 312.49\n", + " 14\n", + " \n", + " \n", + " 4\n", + " 3.20\n", + " 3.16\n", + " 3.25\n", + " 3.29\n", + " 3.40\n", + " 3.56\n", + " 312.24\n", + " 13\n", + " \n", + " \n", + " 5\n", + " 3.20\n", + " 3.19\n", + " 3.22\n", + " 3.27\n", + " 3.35\n", + " 3.48\n", + " 312.20\n", + " 12\n", + " \n", + " \n", + " 6\n", + " 3.21\n", + " 3.18\n", + " 3.23\n", + " 3.28\n", + " 3.37\n", + " 3.51\n", + " 311.73\n", + " 24\n", + " \n", + " \n", + " 7\n", + " 3.21\n", + " 3.19\n", + " 3.23\n", + " 3.27\n", + " 3.34\n", + " 3.52\n", + " 311.57\n", + " 9\n", + " \n", + " \n", + " 8\n", + " 3.21\n", + " 3.18\n", + " 3.26\n", + " 3.31\n", + " 3.36\n", + " 3.54\n", + " 311.15\n", + " 32\n", + " \n", 
+ " \n", + " 9\n", + " 3.21\n", + " 3.17\n", + " 3.24\n", + " 3.28\n", + " 3.34\n", + " 3.52\n", + " 311.10\n", + " 5\n", + " \n", + " \n", + " 10\n", + " 3.21\n", + " 3.19\n", + " 3.25\n", + " 3.29\n", + " 3.33\n", + " 3.54\n", + " 311.10\n", + " 2\n", + " \n", + " \n", + " 11\n", + " 3.22\n", + " 3.19\n", + " 3.25\n", + " 3.29\n", + " 3.36\n", + " 3.51\n", + " 310.93\n", + " 10\n", + " \n", + " \n", + " 12\n", + " 3.22\n", + " 3.19\n", + " 3.26\n", + " 3.29\n", + " 3.40\n", + " 3.55\n", + " 310.30\n", + " 3\n", + " \n", + " \n", + " 13\n", + " 3.23\n", + " 3.19\n", + " 3.26\n", + " 3.32\n", + " 3.42\n", + " 3.58\n", + " 310.02\n", + " 11\n", + " \n", + " \n", + " 14\n", + " 3.23\n", + " 3.19\n", + " 3.26\n", + " 3.30\n", + " 3.36\n", + " 3.54\n", + " 310.02\n", + " 4\n", + " \n", + " \n", + " 15\n", + " 3.23\n", + " 3.20\n", + " 3.23\n", + " 3.27\n", + " 3.35\n", + " 3.60\n", + " 309.53\n", + " 7\n", + " \n", + " \n", + " 16\n", + " 3.23\n", + " 3.19\n", + " 3.22\n", + " 3.26\n", + " 3.33\n", + " 3.68\n", + " 309.27\n", + " 6\n", " \n", " \n", "\n", "" ], "text/plain": [ - " Latency(ms) Latency_P50 Latency_P75 Latency_P90 Latency_P95 \\\n", - "0 23.72 23.72 23.87 23.99 24.11 \n", - "1 24.24 24.24 24.42 24.60 24.76 \n", - "2 24.36 24.36 24.47 24.69 25.01 \n", - "3 24.39 24.37 24.47 24.65 24.73 \n", + " Latency(ms) Latency_P50 Latency_P75 Latency_P90 Latency_P95 \\\n", + "0 3.19 3.16 3.21 3.27 3.35 \n", + "1 3.20 3.17 3.22 3.25 3.34 \n", + "2 3.20 3.15 3.25 3.29 3.36 \n", + "3 3.20 3.18 3.21 3.26 3.35 \n", + "4 3.20 3.16 3.25 3.29 3.40 \n", + "5 3.20 3.19 3.22 3.27 3.35 \n", + "6 3.21 3.18 3.23 3.28 3.37 \n", + "7 3.21 3.19 3.23 3.27 3.34 \n", + "8 3.21 3.18 3.26 3.31 3.36 \n", + "9 3.21 3.17 3.24 3.28 3.34 \n", + "10 3.21 3.19 3.25 3.29 3.33 \n", + "11 3.22 3.19 3.25 3.29 3.36 \n", + "12 3.22 3.19 3.26 3.29 3.40 \n", + "13 3.23 3.19 3.26 3.32 3.42 \n", + "14 3.23 3.19 3.26 3.30 3.36 \n", + "15 3.23 3.20 3.23 3.27 3.35 \n", + "16 3.23 3.19 3.22 3.26 3.33 \n", "\n", - " Latency_P99 Throughput(QPS) intra_op_num_threads \n", - "0 24.37 42.15 4 \n", - "1 25.23 41.25 3 \n", - "2 26.52 41.05 2 \n", - "3 25.12 41.01 1 " + " Latency_P99 Throughput(QPS) intra_op_num_threads \n", + "0 3.52 313.24 1 \n", + "1 3.50 312.80 8 \n", + "2 3.58 312.51 15 \n", + "3 3.53 312.49 14 \n", + "4 3.56 312.24 13 \n", + "5 3.48 312.20 12 \n", + "6 3.51 311.73 24 \n", + "7 3.52 311.57 9 \n", + "8 3.54 311.15 32 \n", + "9 3.52 311.10 5 \n", + "10 3.54 311.10 2 \n", + "11 3.51 310.93 10 \n", + "12 3.55 310.30 3 \n", + "13 3.58 310.02 11 \n", + "14 3.54 310.02 4 \n", + "15 3.60 309.53 7 \n", + "16 3.68 309.27 6 " ] }, - "execution_count": 18, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "import os\n", - "import glob \n", - "import pandas\n", - "latest_result_file = max(glob.glob(\"./onnx/perf_results_GPU_B1_S128_*.txt\"), key=os.path.getmtime)\n", - "result_data = pandas.read_table(latest_result_file)\n", - "print(\"Float32 model perf results from\", latest_result_file)\n", - "# Remove some columns that have same values for all rows.\n", - "columns_to_remove = ['model', 'graph_optimization_level', 'batch_size', 'sequence_length', 'test_cases', 'test_times', 'use_gpu']\n", - "result_data.drop(columns_to_remove, axis=1, inplace=True)\n", - "result_data" + "def load_last_perf_test_result():\n", + " import os\n", + " import glob \n", + " import pandas\n", + " latest_result_file = max(glob.glob(\"./onnx/perf_results_*.txt\"), key=os.path.getmtime)\n", + " result_data = 
pandas.read_table(latest_result_file)\n",
    "    print(\"Perf results from\", latest_result_file)\n",
    "    # Do not show columns that have same values for all rows.\n",
    "    columns_to_remove = ['model', 'graph_optimization_level', 'batch_size', 'sequence_length', 'test_cases', 'test_times', 'use_gpu', 'use_io_binding', 'average_sequence_length', 'random_sequence_length']\n",
    "    result_data.drop(columns_to_remove, axis=1, inplace=True)\n",
    "    return result_data\n",
    "    \n",
    "thread_results = load_last_perf_test_result()\n",
    "thread_results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From the above results, we can see that latency is very close across different settings of intra_op_num_threads.\n",
    "\n",
    "### Model Results Comparison Tool\n",
    "\n",
    "If your BERT model has three inputs, a script compare_bert_results.py can be used to do a quick verification. The tool will generate some fake input data, and compare results from both the original and optimized models. If the outputs are all close, it is safe to use the optimized model.\n",
    "\n",
    "Example of comparing the models before and after optimization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100% passed for 100 random inputs given thresholds (rtol=0.01, atol=0.01).\n",
      "maximum absolute difference=0.05149984359741211\n"
     ]
    }
   ],
   "source": [
    "USE_GPU = '--use_gpu' if use_gpu else ''\n",
    "!{sys.executable} -m onnxruntime.transformers.compare_bert_results --baseline_model $export_model_path --optimized_model $optimized_fp32_model_path --batch_size 1 --sequence_length 128 --samples 100 --rtol 0.01 --atol 0.01 $USE_GPU"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Mixed Precision Conversion\n",
    "\n",
    "The optimizer.py script has an option **--float16** to convert the model to store weights in float16. After the conversion, the model can run faster on GPUs with Tensor Cores, such as V100 or T4.\n",
    "\n",
    "Let's run the tools to measure performance on an Nvidia RTX 4090. The results show a significant performance improvement: latency is about 3.2 ms for the float32 model and about 1.8 ms for the float16 model."
]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[0;93m2023-09-12 12:57:54.5508228 [W:onnxruntime:, session_state.cc:1162 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.\u001b[m\n",
      "\u001b[0;93m2023-09-12 12:57:54.5511008 [W:onnxruntime:, session_state.cc:1164 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.\u001b[m\n",
      " apply: Fused LayerNormalization: 49\n",
      " apply: Fused Gelu: 24\n",
      " apply: Fused SkipLayerNormalization: 48\n",
      " apply: Fused Attention: 24\n",
      " prune_graph: Removed 5 nodes\n",
      " apply: Fused EmbedLayerNormalization(with mask): 1\n",
      " prune_graph: Removed 10 nodes\n",
      " apply: Fused BiasGelu: 24\n",
      " apply: Fused SkipLayerNormalization(add bias): 48\n",
      " optimize: opset version: 11\n",
      "get_fused_operator_statistics: Optimized operators:{'EmbedLayerNormalization': 1, 'Attention': 24, 'MultiHeadAttention': 0, 'Gelu': 0, 'FastGelu': 0, 'BiasGelu': 24, 'GemmFastGelu': 0, 'LayerNormalization': 0, 'SkipLayerNormalization': 48, 'QOrderedAttention': 0, 'QOrderedGelu': 0, 'QOrderedLayerNormalization': 0, 'QOrderedMatMul': 0}\n",
      " main: The model has been fully optimized.\n",
      " save_model_to_file: Sort graphs in topological order\n",
      " save_model_to_file: Model saved to ./onnx/bert-base-cased-squad_opt_gpu_fp16.onnx\n"
     ]
    }
   ],
   "source": [
    "optimized_fp16_model_path = './onnx/bert-base-cased-squad_opt_{}_fp16.onnx'.format('gpu' if use_gpu else 'cpu')\n",
    "!{sys.executable} -m onnxruntime.transformers.optimizer --input $export_model_path --output $optimized_fp16_model_path --float16 $USE_GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=32,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n",
      "Average latency = 1.77 ms, Throughput = 566.45 QPS\n",
      "Running test: 
model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=24,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n", + "Average latency = 1.74 ms, Throughput = 574.96 QPS\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=15,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n", + "Average latency = 1.74 ms, Throughput = 574.28 QPS\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=14,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n", + "Average latency = 1.74 ms, Throughput = 575.17 QPS\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=13,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n", + "Average latency = 1.76 ms, Throughput = 569.77 QPS\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=12,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n", + "Average latency = 1.79 ms, Throughput = 559.84 QPS\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=11,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n", + "Average latency = 1.77 ms, Throughput = 566.09 QPS\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=10,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n", + "Average latency = 1.77 ms, Throughput = 563.97 QPS\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=9,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n", + "Average latency = 1.77 ms, Throughput = 565.70 QPS\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=8,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n", + "Average latency = 1.77 ms, Throughput = 565.50 QPS\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=7,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n", + "Average latency = 1.77 ms, Throughput = 566.38 QPS\n", + "Running test: 
model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=6,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n", + "Average latency = 1.75 ms, Throughput = 572.89 QPS\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=5,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n", + "Average latency = 1.76 ms, Throughput = 568.67 QPS\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=4,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n", + "Average latency = 1.78 ms, Throughput = 561.98 QPS\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=3,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n", + "Average latency = 1.77 ms, Throughput = 566.14 QPS\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=2,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n", + "Average latency = 1.78 ms, Throughput = 563.25 QPS\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=1,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n", + "Average latency = 1.77 ms, Throughput = 565.09 QPS\n", + "test setting TestSetting(batch_size=1, sequence_length=128, test_cases=1000, test_times=1, use_gpu=True, use_io_binding=True, provider=None, intra_op_num_threads=None, seed=3, verbose=False, log_severity=2, average_sequence_length=128, random_sequence_length=False)\n", "Generating 1000 samples for batch_size=1 sequence_length=128\n", - "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=4,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True\n", - "Average latency = 6.78 ms, Throughput = 147.54 QPS\n", - "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=3,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True\n", - "Average latency = 6.76 ms, Throughput = 147.85 QPS\n", - "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=2,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True\n", - "Average latency = 6.79 ms, Throughput = 147.30 QPS\n", - "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=1,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True\n", - "Average latency = 6.81 ms, Throughput = 146.75 QPS\n", - "Test summary is saved to onnx/perf_results_GPU_B1_S128_20210714-002224.txt\n" + "Test summary is saved to onnx\\perf_results_GPU_B1_S128_20230912-130021.txt\n" ] } ], "source": 
[ - "GPU_OPTION = '--use_gpu' if use_gpu else ''\n", + "GPU_OPTION = '--use_gpu --use_io_binding' if use_gpu else ''\n", "!python -m onnxruntime.transformers.bert_perf_test --model $optimized_fp16_model_path --batch_size 1 --sequence_length 128 --samples 1000 --test_times 1 $GPU_OPTION" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Float32 model perf results from ./onnx/perf_results_GPU_B1_S128_20210714-002224.txt\n" + "Perf results from ./onnx\\perf_results_GPU_B1_S128_20230912-130021.txt\n" ] }, { @@ -1026,82 +1224,243 @@ " \n", " \n", " 0\n", - " 6.76\n", - " 6.79\n", - " 6.81\n", - " 6.90\n", - " 6.91\n", - " 7.00\n", - " 147.85\n", - " 3\n", + " 1.74\n", + " 1.72\n", + " 1.72\n", + " 1.75\n", + " 1.80\n", + " 2.17\n", + " 575.17\n", + " 14\n", " \n", " \n", " 1\n", - " 6.78\n", - " 6.70\n", - " 6.79\n", - " 6.87\n", - " 6.90\n", - " 7.63\n", - " 147.54\n", - " 4\n", + " 1.74\n", + " 1.73\n", + " 1.73\n", + " 1.75\n", + " 1.76\n", + " 2.14\n", + " 574.96\n", + " 24\n", " \n", " \n", " 2\n", - " 6.79\n", - " 6.79\n", - " 6.81\n", - " 6.89\n", - " 6.91\n", - " 7.19\n", - " 147.30\n", - " 2\n", + " 1.74\n", + " 1.72\n", + " 1.73\n", + " 1.76\n", + " 1.79\n", + " 2.16\n", + " 574.28\n", + " 15\n", " \n", " \n", " 3\n", - " 6.81\n", - " 6.80\n", - " 6.89\n", - " 6.91\n", - " 6.97\n", - " 7.20\n", - " 146.75\n", + " 1.75\n", + " 1.72\n", + " 1.72\n", + " 1.76\n", + " 2.02\n", + " 2.15\n", + " 572.89\n", + " 6\n", + " \n", + " \n", + " 4\n", + " 1.76\n", + " 1.74\n", + " 1.74\n", + " 1.76\n", + " 1.81\n", + " 2.14\n", + " 569.77\n", + " 13\n", + " \n", + " \n", + " 5\n", + " 1.76\n", + " 1.72\n", + " 1.73\n", + " 1.80\n", + " 2.08\n", + " 2.15\n", + " 568.67\n", + " 5\n", + " \n", + " \n", + " 6\n", + " 1.77\n", + " 1.73\n", + " 1.74\n", + " 1.81\n", + " 2.12\n", + " 2.19\n", + " 566.45\n", + " 32\n", + " \n", + " \n", + " 7\n", + " 1.77\n", + " 1.74\n", + " 1.74\n", + " 1.77\n", + " 2.06\n", + " 2.17\n", + " 566.38\n", + " 7\n", + " \n", + " \n", + " 8\n", + " 1.77\n", + " 1.73\n", + " 1.74\n", + " 1.81\n", + " 2.10\n", + " 2.18\n", + " 566.14\n", + " 3\n", + " \n", + " \n", + " 9\n", + " 1.77\n", + " 1.73\n", + " 1.74\n", + " 1.82\n", + " 2.07\n", + " 2.17\n", + " 566.09\n", + " 11\n", + " \n", + " \n", + " 10\n", + " 1.77\n", + " 1.74\n", + " 1.75\n", + " 1.78\n", + " 2.02\n", + " 2.13\n", + " 565.70\n", + " 9\n", + " \n", + " \n", + " 11\n", + " 1.77\n", + " 1.73\n", + " 1.74\n", + " 1.93\n", + " 2.06\n", + " 2.16\n", + " 565.50\n", + " 8\n", + " \n", + " \n", + " 12\n", + " 1.77\n", + " 1.73\n", + " 1.74\n", + " 1.81\n", + " 2.11\n", + " 2.20\n", + " 565.09\n", " 1\n", " \n", + " \n", + " 13\n", + " 1.77\n", + " 1.74\n", + " 1.75\n", + " 1.85\n", + " 2.06\n", + " 2.15\n", + " 563.97\n", + " 10\n", + " \n", + " \n", + " 14\n", + " 1.78\n", + " 1.73\n", + " 1.74\n", + " 1.93\n", + " 2.13\n", + " 2.19\n", + " 563.25\n", + " 2\n", + " \n", + " \n", + " 15\n", + " 1.78\n", + " 1.74\n", + " 1.75\n", + " 1.88\n", + " 2.10\n", + " 2.19\n", + " 561.98\n", + " 4\n", + " \n", + " \n", + " 16\n", + " 1.79\n", + " 1.75\n", + " 1.76\n", + " 1.99\n", + " 2.08\n", + " 2.16\n", + " 559.84\n", + " 12\n", + " \n", " \n", "\n", "" ], "text/plain": [ - " Latency(ms) Latency_P50 Latency_P75 Latency_P90 Latency_P95 \\\n", - "0 6.76 6.79 6.81 6.90 6.91 \n", - "1 6.78 6.70 6.79 6.87 6.90 \n", - "2 6.79 6.79 6.81 6.89 6.91 \n", - "3 6.81 6.80 6.89 6.91 6.97 \n", + " Latency(ms) Latency_P50 
Latency_P75 Latency_P90 Latency_P95 \\\n", + "0 1.74 1.72 1.72 1.75 1.80 \n", + "1 1.74 1.73 1.73 1.75 1.76 \n", + "2 1.74 1.72 1.73 1.76 1.79 \n", + "3 1.75 1.72 1.72 1.76 2.02 \n", + "4 1.76 1.74 1.74 1.76 1.81 \n", + "5 1.76 1.72 1.73 1.80 2.08 \n", + "6 1.77 1.73 1.74 1.81 2.12 \n", + "7 1.77 1.74 1.74 1.77 2.06 \n", + "8 1.77 1.73 1.74 1.81 2.10 \n", + "9 1.77 1.73 1.74 1.82 2.07 \n", + "10 1.77 1.74 1.75 1.78 2.02 \n", + "11 1.77 1.73 1.74 1.93 2.06 \n", + "12 1.77 1.73 1.74 1.81 2.11 \n", + "13 1.77 1.74 1.75 1.85 2.06 \n", + "14 1.78 1.73 1.74 1.93 2.13 \n", + "15 1.78 1.74 1.75 1.88 2.10 \n", + "16 1.79 1.75 1.76 1.99 2.08 \n", "\n", - " Latency_P99 Throughput(QPS) intra_op_num_threads \n", - "0 7.00 147.85 3 \n", - "1 7.63 147.54 4 \n", - "2 7.19 147.30 2 \n", - "3 7.20 146.75 1 " + " Latency_P99 Throughput(QPS) intra_op_num_threads \n", + "0 2.17 575.17 14 \n", + "1 2.14 574.96 24 \n", + "2 2.16 574.28 15 \n", + "3 2.15 572.89 6 \n", + "4 2.14 569.77 13 \n", + "5 2.15 568.67 5 \n", + "6 2.19 566.45 32 \n", + "7 2.17 566.38 7 \n", + "8 2.18 566.14 3 \n", + "9 2.17 566.09 11 \n", + "10 2.13 565.70 9 \n", + "11 2.16 565.50 8 \n", + "12 2.20 565.09 1 \n", + "13 2.15 563.97 10 \n", + "14 2.19 563.25 2 \n", + "15 2.19 561.98 4 \n", + "16 2.16 559.84 12 " ] }, - "execution_count": 22, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "import os\n", - "import glob \n", - "import pandas\n", - "latest_result_file = max(glob.glob(\"./onnx/perf_results_GPU_B1_S128_*.txt\"), key=os.path.getmtime)\n", - "result_data = pandas.read_table(latest_result_file)\n", - "print(\"Float32 model perf results from\", latest_result_file)\n", - "# Remove some columns that have same values for all rows.\n", - "columns_to_remove = ['model', 'graph_optimization_level', 'batch_size', 'sequence_length', 'test_cases', 'test_times', 'use_gpu']\n", - "result_data.drop(columns_to_remove, axis=1, inplace=True)\n", - "result_data" + "fp32_result = load_last_perf_test_result()\n", + "fp32_result" ] }, { @@ -1117,59 +1476,265 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "test setting TestSetting(batch_size=32, sequence_length=128, test_cases=1000, test_times=1, use_gpu=True, intra_op_num_threads=3, seed=3, verbose=False)\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=8,batch_size=32,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n", + "Average latency = 20.41 ms, Throughput = 1567.65 QPS\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=8,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n", + "Average latency = 1.73 ms, Throughput = 576.74 QPS\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=8,batch_size=2,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n", + "Average latency = 2.18 ms, Throughput = 917.92 QPS\n", + "Running test: 
model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=8,batch_size=4,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n", + "Average latency = 3.25 ms, Throughput = 1229.91 QPS\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=8,batch_size=8,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n", + "Average latency = 5.38 ms, Throughput = 1486.89 QPS\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=8,batch_size=16,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=128,random_sequence_length=False\n", + "Average latency = 9.90 ms, Throughput = 1616.79 QPS\n", + "test setting TestSetting(batch_size=32, sequence_length=128, test_cases=1000, test_times=1, use_gpu=True, use_io_binding=True, provider=None, intra_op_num_threads=8, seed=3, verbose=False, log_severity=2, average_sequence_length=128, random_sequence_length=False)\n", "Generating 1000 samples for batch_size=32 sequence_length=128\n", - "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=3,batch_size=32,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True\n", - "Average latency = 168.40 ms, Throughput = 190.02 QPS\n", - "test setting TestSetting(batch_size=1, sequence_length=128, test_cases=1000, test_times=1, use_gpu=True, intra_op_num_threads=3, seed=3, verbose=False)\n", + "test setting TestSetting(batch_size=1, sequence_length=128, test_cases=1000, test_times=1, use_gpu=True, use_io_binding=True, provider=None, intra_op_num_threads=8, seed=3, verbose=False, log_severity=2, average_sequence_length=128, random_sequence_length=False)\n", "Generating 1000 samples for batch_size=1 sequence_length=128\n", - "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=3,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True\n", - "Average latency = 7.14 ms, Throughput = 140.00 QPS\n", - "test setting TestSetting(batch_size=2, sequence_length=128, test_cases=1000, test_times=1, use_gpu=True, intra_op_num_threads=3, seed=3, verbose=False)\n", + "test setting TestSetting(batch_size=2, sequence_length=128, test_cases=1000, test_times=1, use_gpu=True, use_io_binding=True, provider=None, intra_op_num_threads=8, seed=3, verbose=False, log_severity=2, average_sequence_length=128, random_sequence_length=False)\n", "Generating 1000 samples for batch_size=2 sequence_length=128\n", - "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=3,batch_size=2,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True\n", - "Average latency = 11.27 ms, Throughput = 177.41 QPS\n", - "test setting TestSetting(batch_size=4, sequence_length=128, test_cases=1000, test_times=1, use_gpu=True, intra_op_num_threads=3, seed=3, verbose=False)\n", + "test setting TestSetting(batch_size=4, sequence_length=128, test_cases=1000, test_times=1, use_gpu=True, use_io_binding=True, provider=None, intra_op_num_threads=8, seed=3, verbose=False, log_severity=2, average_sequence_length=128, random_sequence_length=False)\n", "Generating 1000 
samples for batch_size=4 sequence_length=128\n", - "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=3,batch_size=4,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True\n", - "Average latency = 21.15 ms, Throughput = 189.09 QPS\n", - "test setting TestSetting(batch_size=8, sequence_length=128, test_cases=1000, test_times=1, use_gpu=True, intra_op_num_threads=3, seed=3, verbose=False)\n", + "test setting TestSetting(batch_size=8, sequence_length=128, test_cases=1000, test_times=1, use_gpu=True, use_io_binding=True, provider=None, intra_op_num_threads=8, seed=3, verbose=False, log_severity=2, average_sequence_length=128, random_sequence_length=False)\n", "Generating 1000 samples for batch_size=8 sequence_length=128\n", - "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=3,batch_size=8,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True\n", - "Average latency = 42.27 ms, Throughput = 189.27 QPS\n", - "test setting TestSetting(batch_size=16, sequence_length=128, test_cases=1000, test_times=1, use_gpu=True, intra_op_num_threads=3, seed=3, verbose=False)\n", + "test setting TestSetting(batch_size=16, sequence_length=128, test_cases=1000, test_times=1, use_gpu=True, use_io_binding=True, provider=None, intra_op_num_threads=8, seed=3, verbose=False, log_severity=2, average_sequence_length=128, random_sequence_length=False)\n", "Generating 1000 samples for batch_size=16 sequence_length=128\n", - "Running test: model=bert-base-cased-squad_opt_gpu_fp16.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=3,batch_size=16,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True\n", - "Average latency = 83.77 ms, Throughput = 191.01 QPS\n", - "Test summary is saved to onnx/perf_results_GPU_B1-2-4-8-16-32_S128_20210714-002816.txt\n" + "Test summary is saved to onnx\\perf_results_GPU_B1-2-4-8-16-32_S128_20230912-130248.txt\n" ] } ], "source": [ - "GPU_OPTION = '--use_gpu' if use_gpu else ''\n", - "THREAD_SETTING = '--intra_op_num_threads 3'\n", + "THREAD_SETTING = '--intra_op_num_threads 8'\n", "!{sys.executable} -m onnxruntime.transformers.bert_perf_test --model $optimized_fp16_model_path --batch_size 1 2 4 8 16 32 --sequence_length 128 --samples 1000 --test_times 1 $THREAD_SETTING $GPU_OPTION" ] }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Perf results from ./onnx\\perf_results_GPU_B1-2-4-8-16-32_S128_20230912-130248.txt\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Latency(ms)Latency_P50Latency_P75Latency_P90Latency_P95Latency_P99Throughput(QPS)intra_op_num_threads
01.731.721.731.731.792.04576.748
12.182.162.162.182.292.76917.928
23.253.253.263.283.293.431229.918
35.385.385.395.425.445.601486.898
49.909.899.949.9710.0010.061616.798
520.4120.4120.4720.5220.5520.681567.658
\n", + "
" + ], + "text/plain": [ + " Latency(ms) Latency_P50 Latency_P75 Latency_P90 Latency_P95 \\\n", + "0 1.73 1.72 1.73 1.73 1.79 \n", + "1 2.18 2.16 2.16 2.18 2.29 \n", + "2 3.25 3.25 3.26 3.28 3.29 \n", + "3 5.38 5.38 5.39 5.42 5.44 \n", + "4 9.90 9.89 9.94 9.97 10.00 \n", + "5 20.41 20.41 20.47 20.52 20.55 \n", + "\n", + " Latency_P99 Throughput(QPS) intra_op_num_threads \n", + "0 2.04 576.74 8 \n", + "1 2.76 917.92 8 \n", + "2 3.43 1229.91 8 \n", + "3 5.60 1486.89 8 \n", + "4 10.06 1616.79 8 \n", + "5 20.68 1567.65 8 " + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fp16_result = load_last_perf_test_result()\n", + "fp16_result" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Packing Mode (Effective Transformer)\n", + "\n", + "When padding ratio is high, it is helpful to use packing mode, also known as [effective transformer](https://github.com/bytedance/effective_transformer).\n", + "This feature requires onnxruntime-gpu verison 1.16 or later. \n", + "\n", + "In below example, average sequence length after removing paddings is 32, the sequence length with paddings is 128. We can see 3x throughput with packing mode (QPS increased from 1617 to 5652)." + ] + }, { "cell_type": "code", "execution_count": 24, - "metadata": { - "scrolled": false - }, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "_replace_attention_with_packing_attention: Converted 24 Attention nodes to PackedAttention.\n", + " save_model_to_file: Sort graphs in topological order\n", + " save: Delete the existing onnx file: ./onnx/bert-base-cased-squad_opt_gpu_fp16_packed.onnx\n", + " save: Delete the existing external data file: ./onnx/bert-base-cased-squad_opt_gpu_fp16_packed.onnx.data\n", + " save_model_to_file: Model saved to ./onnx/bert-base-cased-squad_opt_gpu_fp16_packed.onnx\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running test: model=bert-base-cased-squad_opt_gpu_fp16_packed.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=8,batch_size=32,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=32,random_sequence_length=False\n", + "Average latency = 5.66 ms, Throughput = 5652.40 QPS\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16_packed.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=8,batch_size=1,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=32,random_sequence_length=False\n", + "Average latency = 1.70 ms, Throughput = 586.97 QPS\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16_packed.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=8,batch_size=2,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=32,random_sequence_length=False\n", + "Average latency = 1.79 ms, Throughput = 1114.37 QPS\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16_packed.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=8,batch_size=4,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=32,random_sequence_length=False\n", + "Average latency = 1.77 ms, Throughput = 2262.31 QPS\n", + "Running test: 
model=bert-base-cased-squad_opt_gpu_fp16_packed.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=8,batch_size=8,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=32,random_sequence_length=False\n", + "Average latency = 2.18 ms, Throughput = 3666.45 QPS\n", + "Running test: model=bert-base-cased-squad_opt_gpu_fp16_packed.onnx,graph_optimization_level=ENABLE_ALL,intra_op_num_threads=8,batch_size=16,sequence_length=128,test_cases=1000,test_times=1,use_gpu=True,use_io_binding=True,average_sequence_length=32,random_sequence_length=False\n", + "Average latency = 3.31 ms, Throughput = 4829.58 QPS\n", + "test setting TestSetting(batch_size=32, sequence_length=128, test_cases=1000, test_times=1, use_gpu=True, use_io_binding=True, provider=None, intra_op_num_threads=8, seed=3, verbose=False, log_severity=2, average_sequence_length=32, random_sequence_length=False)\n", + "Generating 1000 samples for batch_size=32 sequence_length=128\n", + "test setting TestSetting(batch_size=1, sequence_length=128, test_cases=1000, test_times=1, use_gpu=True, use_io_binding=True, provider=None, intra_op_num_threads=8, seed=3, verbose=False, log_severity=2, average_sequence_length=32, random_sequence_length=False)\n", + "Generating 1000 samples for batch_size=1 sequence_length=128\n", + "test setting TestSetting(batch_size=2, sequence_length=128, test_cases=1000, test_times=1, use_gpu=True, use_io_binding=True, provider=None, intra_op_num_threads=8, seed=3, verbose=False, log_severity=2, average_sequence_length=32, random_sequence_length=False)\n", + "Generating 1000 samples for batch_size=2 sequence_length=128\n", + "test setting TestSetting(batch_size=4, sequence_length=128, test_cases=1000, test_times=1, use_gpu=True, use_io_binding=True, provider=None, intra_op_num_threads=8, seed=3, verbose=False, log_severity=2, average_sequence_length=32, random_sequence_length=False)\n", + "Generating 1000 samples for batch_size=4 sequence_length=128\n", + "test setting TestSetting(batch_size=8, sequence_length=128, test_cases=1000, test_times=1, use_gpu=True, use_io_binding=True, provider=None, intra_op_num_threads=8, seed=3, verbose=False, log_severity=2, average_sequence_length=32, random_sequence_length=False)\n", + "Generating 1000 samples for batch_size=8 sequence_length=128\n", + "test setting TestSetting(batch_size=16, sequence_length=128, test_cases=1000, test_times=1, use_gpu=True, use_io_binding=True, provider=None, intra_op_num_threads=8, seed=3, verbose=False, log_severity=2, average_sequence_length=32, random_sequence_length=False)\n", + "Generating 1000 samples for batch_size=16 sequence_length=128\n", + "Test summary is saved to onnx\\perf_results_GPU_B1-2-4-8-16-32_S128_20230912-130354.txt\n" + ] + } + ], + "source": [ + "assert use_gpu, \"Require GPU for packing mode\"\n", + "packed_fp16_model_path = './onnx/bert-base-cased-squad_opt_gpu_fp16_packed.onnx'\n", + "!{sys.executable} -m onnxruntime.transformers.convert_to_packing_mode --input $optimized_fp16_model_path --output $packed_fp16_model_path --use_external_data_format\n", + "!{sys.executable} -m onnxruntime.transformers.bert_perf_test --model $packed_fp16_model_path --batch_size 1 2 4 8 16 32 --sequence_length 128 --average_sequence_length 32 --samples 1000 --test_times 1 $THREAD_SETTING $GPU_OPTION " + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Float16 model summary from 
./onnx/perf_results_GPU_B1-2-4-8-16-32_S128_20210714-002816.txt\n" + "Perf results from ./onnx\\perf_results_GPU_B1-2-4-8-16-32_S128_20230912-130354.txt\n" ] }, { @@ -1200,75 +1765,75 @@ " Latency_P95\n", " Latency_P99\n", " Throughput(QPS)\n", - " batch_size\n", + " intra_op_num_threads\n", " \n", " \n", " \n", " \n", " 0\n", - " 7.14\n", - " 7.10\n", - " 7.13\n", - " 7.25\n", - " 7.35\n", - " 10.99\n", - " 140.00\n", - " 1\n", + " 1.70\n", + " 1.63\n", + " 1.65\n", + " 2.13\n", + " 2.20\n", + " 2.32\n", + " 586.97\n", + " 8\n", " \n", " \n", " 1\n", - " 11.27\n", - " 11.23\n", - " 11.28\n", - " 11.53\n", - " 11.57\n", - " 12.05\n", - " 177.41\n", - " 2\n", + " 1.77\n", + " 1.74\n", + " 1.76\n", + " 1.82\n", + " 1.93\n", + " 2.17\n", + " 2262.31\n", + " 8\n", " \n", " \n", " 2\n", - " 21.15\n", - " 21.13\n", - " 21.25\n", - " 21.44\n", - " 21.59\n", - " 22.07\n", - " 189.09\n", - " 4\n", + " 1.79\n", + " 1.73\n", + " 1.74\n", + " 2.12\n", + " 2.18\n", + " 2.32\n", + " 1114.37\n", + " 8\n", " \n", " \n", " 3\n", - " 42.27\n", - " 42.26\n", - " 42.68\n", - " 42.95\n", - " 43.11\n", - " 45.11\n", - " 189.27\n", + " 2.18\n", + " 2.16\n", + " 2.17\n", + " 2.22\n", + " 2.30\n", + " 2.64\n", + " 3666.45\n", " 8\n", " \n", " \n", " 4\n", - " 83.77\n", - " 83.84\n", - " 84.29\n", - " 84.94\n", - " 85.35\n", - " 86.34\n", - " 191.01\n", - " 16\n", + " 3.31\n", + " 3.31\n", + " 3.32\n", + " 3.35\n", + " 3.39\n", + " 3.51\n", + " 4829.58\n", + " 8\n", " \n", " \n", " 5\n", - " 168.40\n", - " 169.62\n", - " 170.78\n", - " 171.94\n", - " 172.82\n", - " 174.28\n", - " 190.02\n", - " 32\n", + " 5.66\n", + " 5.66\n", + " 5.68\n", + " 5.71\n", + " 5.74\n", + " 5.91\n", + " 5652.40\n", + " 8\n", " \n", " \n", "\n", @@ -1276,38 +1841,30 @@ ], "text/plain": [ " Latency(ms) Latency_P50 Latency_P75 Latency_P90 Latency_P95 \\\n", - "0 7.14 7.10 7.13 7.25 7.35 \n", - "1 11.27 11.23 11.28 11.53 11.57 \n", - "2 21.15 21.13 21.25 21.44 21.59 \n", - "3 42.27 42.26 42.68 42.95 43.11 \n", - "4 83.77 83.84 84.29 84.94 85.35 \n", - "5 168.40 169.62 170.78 171.94 172.82 \n", + "0 1.70 1.63 1.65 2.13 2.20 \n", + "1 1.77 1.74 1.76 1.82 1.93 \n", + "2 1.79 1.73 1.74 2.12 2.18 \n", + "3 2.18 2.16 2.17 2.22 2.30 \n", + "4 3.31 3.31 3.32 3.35 3.39 \n", + "5 5.66 5.66 5.68 5.71 5.74 \n", "\n", - " Latency_P99 Throughput(QPS) batch_size \n", - "0 10.99 140.00 1 \n", - "1 12.05 177.41 2 \n", - "2 22.07 189.09 4 \n", - "3 45.11 189.27 8 \n", - "4 86.34 191.01 16 \n", - "5 174.28 190.02 32 " + " Latency_P99 Throughput(QPS) intra_op_num_threads \n", + "0 2.32 586.97 8 \n", + "1 2.17 2262.31 8 \n", + "2 2.32 1114.37 8 \n", + "3 2.64 3666.45 8 \n", + "4 3.51 4829.58 8 \n", + "5 5.91 5652.40 8 " ] }, - "execution_count": 24, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "import os\n", - "import glob \n", - "import pandas\n", - "latest_result_file = max(glob.glob(\"./onnx/perf_results_*.txt\"), key=os.path.getmtime)\n", - "result_data = pandas.read_table(latest_result_file)\n", - "print(\"Float16 model summary from\", latest_result_file)\n", - "columns_to_remove = ['model', 'graph_optimization_level', 'test_cases', 'test_times', 'use_gpu', 'sequence_length']\n", - "columns_to_remove.extend(['intra_op_num_threads'])\n", - "result_data.drop(columns_to_remove, axis=1, inplace=True)\n", - "result_data" + "packing_result = load_last_perf_test_result()\n", + "packing_result" ] }, { @@ -1327,7 +1884,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 26, "metadata": { 
"scrolled": true }, @@ -1336,131 +1893,53 @@ "name": "stdout", "output_type": "stream", "text": [ - "{\r\n", - " \"gpu\": {\r\n", - " \"driver_version\": \"450.51.05\",\r\n", - " \"devices\": [\r\n", - " {\r\n", - " \"memory_total\": 15843721216,\r\n", - " \"memory_available\": 9313189888,\r\n", - " \"name\": \"Tesla T4\"\r\n", - " }\r\n", - " ]\r\n", - " },\r\n", - " \"cpu\": {\r\n", - " \"brand\": \"AMD EPYC 7V12 64-Core Processor\",\r\n", - " \"cores\": 4,\r\n", - " \"logical_cores\": 4,\r\n", - " \"hz\": [\r\n", - " 2445417000,\r\n", - " 0\r\n", - " ],\r\n", - " \"l2_cache\": 524288,\r\n", - " \"flags\": [\r\n", - " \"3dnowext\",\r\n", - " \"3dnowprefetch\",\r\n", - " \"abm\",\r\n", - " \"adx\",\r\n", - " \"aes\",\r\n", - " \"apic\",\r\n", - " \"arat\",\r\n", - " \"avx\",\r\n", - " \"avx2\",\r\n", - " \"bmi1\",\r\n", - " \"bmi2\",\r\n", - " \"clflush\",\r\n", - " \"clflushopt\",\r\n", - " \"clwb\",\r\n", - " \"cmov\",\r\n", - " \"cmp_legacy\",\r\n", - " \"cpuid\",\r\n", - " \"cr8_legacy\",\r\n", - " \"cx16\",\r\n", - " \"cx8\",\r\n", - " \"de\",\r\n", - " \"extd_apicid\",\r\n", - " \"f16c\",\r\n", - " \"fma\",\r\n", - " \"fpu\",\r\n", - " \"fsgsbase\",\r\n", - " \"fxsr\",\r\n", - " \"fxsr_opt\",\r\n", - " \"ht\",\r\n", - " \"hypervisor\",\r\n", - " \"lahf_lm\",\r\n", - " \"lm\",\r\n", - " \"mca\",\r\n", - " \"mce\",\r\n", - " \"misalignsse\",\r\n", - " \"mmx\",\r\n", - " \"mmxext\",\r\n", - " \"movbe\",\r\n", - " \"msr\",\r\n", - " \"mtrr\",\r\n", - " \"nopl\",\r\n", - " \"nx\",\r\n", - " \"osvw\",\r\n", - " \"osxsave\",\r\n", - " \"pae\",\r\n", - " \"pat\",\r\n", - " \"pclmulqdq\",\r\n", - " \"pdpe1gb\",\r\n", - " \"pge\",\r\n", - " \"pni\",\r\n", - " \"popcnt\",\r\n", - " \"pse\",\r\n", - " \"pse36\",\r\n", - " \"rdpid\",\r\n", - " \"rdrand\",\r\n", - " \"rdrnd\",\r\n", - " \"rdseed\",\r\n", - " \"rdtscp\",\r\n", - " \"rep_good\",\r\n", - " \"sep\",\r\n", - " \"sha\",\r\n", - " \"sha_ni\",\r\n", - " \"smap\",\r\n", - " \"smep\",\r\n", - " \"ssbd\",\r\n", - " \"sse\",\r\n", - " \"sse2\",\r\n", - " \"sse4_1\",\r\n", - " \"sse4_2\",\r\n", - " \"sse4a\",\r\n", - " \"ssse3\",\r\n", - " \"syscall\",\r\n", - " \"topoext\",\r\n", - " \"tsc\",\r\n", - " \"umip\",\r\n", - " \"vme\",\r\n", - " \"vmmcall\",\r\n", - " \"xgetbv1\",\r\n", - " \"xsave\",\r\n", - " \"xsavec\",\r\n", - " \"xsaveerptr\",\r\n", - " \"xsaveopt\",\r\n", - " \"xsaves\"\r\n", - " ],\r\n", - " \"processor\": \"x86_64\"\r\n", - " },\r\n", - " \"memory\": {\r\n", - " \"total\": 29450223616,\r\n", - " \"available\": 22402334720\r\n", - " },\r\n", - " \"python\": \"3.6.13.final.0 (64 bit)\",\r\n", - " \"os\": \"Linux-5.4.0-1046-azure-x86_64-with-debian-buster-sid\",\r\n", - " \"onnxruntime\": {\r\n", - " \"version\": \"1.8.1\",\r\n", - " \"support_gpu\": true\r\n", - " },\r\n", - " \"onnxruntime_tools\": null,\r\n", - " \"pytorch\": {\r\n", - " \"version\": \"1.9.0+cu111\",\r\n", - " \"support_gpu\": true,\r\n", - " \"cuda\": \"11.1\"\r\n", - " },\r\n", - " \"tensorflow\": null\r\n", - "}\r\n" + "{\n", + " \"gpu\": {\n", + " \"driver_version\": \"537.13\",\n", + " \"devices\": [\n", + " {\n", + " \"memory_total\": 25757220864,\n", + " \"memory_available\": 18009264128,\n", + " \"name\": \"NVIDIA GeForce RTX 4090\"\n", + " }\n", + " ]\n", + " },\n", + " \"cpu\": {\n", + " \"brand\": \"13th Gen Intel(R) Core(TM) i9-13900\",\n", + " \"cores\": 24,\n", + " \"logical_cores\": 32,\n", + " \"hz\": \"2000000000,0\",\n", + " \"l2_cache\": 33554432,\n", + " \"flags\": 
\"3dnow,3dnowprefetch,abm,acpi,adx,aes,apic,avx,avx2,bmi1,bmi2,clflush,clflushopt,clwb,cmov,cx16,cx8,de,dts,erms,est,f16c,fma,fpu,fxsr,gfni,ht,hypervisor,ia64,intel_pt,invpcid,lahf_lm,mca,mce,mmx,monitor,movbe,msr,mtrr,osxsave,pae,pat,pbe,pcid,pclmulqdq,pdcm,pge,pni,popcnt,pse,pse36,rdpid,rdrnd,rdseed,sep,serial,sha,smap,smep,ss,sse,sse2,sse4_1,sse4_2,ssse3,tm,tm2,tsc,tscdeadline,umip,vaes,vme,vpclmulqdq,x2apic,xsave,xtpr\",\n", + " \"processor\": \"Intel64 Family 6 Model 183 Stepping 1, GenuineIntel\"\n", + " },\n", + " \"memory\": {\n", + " \"total\": 33992912896,\n", + " \"available\": 17272422400\n", + " },\n", + " \"os\": \"Windows-10-10.0.22621-SP0\",\n", + " \"python\": \"3.10.13.final.0 (64 bit)\",\n", + " \"packages\": {\n", + " \"flatbuffers\": \"23.5.26\",\n", + " \"numpy\": \"1.25.2\",\n", + " \"onnx\": \"1.14.1\",\n", + " \"onnxruntime-gpu\": \"1.16.0\",\n", + " \"protobuf\": \"3.20.3\",\n", + " \"sympy\": \"1.12\",\n", + " \"torch\": \"2.0.1+cu118\",\n", + " \"transformers\": \"4.33.1\"\n", + " },\n", + " \"onnxruntime\": {\n", + " \"version\": \"1.16.0\",\n", + " \"support_gpu\": true\n", + " },\n", + " \"pytorch\": {\n", + " \"version\": \"2.0.1+cu118\",\n", + " \"support_gpu\": true,\n", + " \"cuda\": \"11.8\"\n", + " },\n", + " \"tensorflow\": null\n", + "}\n" ] } ], @@ -1485,9 +1964,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.13" + "version": "3.10.13" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } From a2fba28f6cdc6a4ea946a9d56b4b4d19241ebe27 Mon Sep 17 00:00:00 2001 From: "Nat Kershaw (MSFT)" Date: Thu, 14 Sep 2023 20:43:24 -0700 Subject: [PATCH 029/121] Remove extraneous javascript includes (#17558) --- docs/c_cxx/doxygen-header.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/c_cxx/doxygen-header.html b/docs/c_cxx/doxygen-header.html index 364f76f7f0580..6d95bf57ff98f 100644 --- a/docs/c_cxx/doxygen-header.html +++ b/docs/c_cxx/doxygen-header.html @@ -16,7 +16,7 @@ - + $treeview $search $mathjax From 9aafbe3febd569da2613230ce45b1ce1054488d5 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 14 Sep 2023 21:14:44 -0700 Subject: [PATCH 030/121] [js/web] revise TensorView (#17473) ### Description This change: - removes the unused `Tensor` types declared in /js/web/lib/wasm/jsep/tensor.ts - removes duplicated util functions in /js/web/lib/wasm/jsep/tensor.ts - renames /js/web/lib/wasm/jsep/**tensor.ts** to /js/web/lib/wasm/jsep/**tensor-view.ts** and update corresponding references. It was kind of confusing that we have multiple `Tensor` types defined in different places also we have multiple `tensor.ts` source files. This is one of the prerequisites for supporting IO binding for WebGPU buffer in onnxruntime-web. 
list of prerequisites PRs: https://github.com/microsoft/onnxruntime/pull/17465 https://github.com/microsoft/onnxruntime/pull/17469 https://github.com/microsoft/onnxruntime/pull/17470 https://github.com/microsoft/onnxruntime/pull/17472 https://github.com/microsoft/onnxruntime/pull/17473 (this one) --- js/web/lib/wasm/jsep/backend-webgpu.ts | 2 +- js/web/lib/wasm/jsep/init.ts | 2 +- js/web/lib/wasm/jsep/tensor-view.ts | 39 ++++++ js/web/lib/wasm/jsep/tensor.ts | 115 ------------------ .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 2 +- .../ops/3rd-party/conv_backprop_webgpu.ts | 2 +- .../ops/3rd-party/matmul_packed_webgpu.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 2 +- .../lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 2 +- .../wasm/jsep/webgpu/ops/conv-transpose.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/conv.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/conv2d-mm.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/einsum.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/expand.ts | 2 +- .../wasm/jsep/webgpu/ops/gather-elements.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/gather.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/gemm.ts | 2 +- .../lib/wasm/jsep/webgpu/ops/instance-norm.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/matmul.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/pad.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/pool.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/reduce.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/resize.ts | 2 +- .../wasm/jsep/webgpu/ops/skip-layer-norm.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/slice.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/softmax.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/split.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/tile.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/transpose.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts | 2 +- .../lib/wasm/jsep/webgpu/program-manager.ts | 2 +- js/web/lib/wasm/jsep/webgpu/types.ts | 3 +- 35 files changed, 72 insertions(+), 149 deletions(-) create mode 100644 js/web/lib/wasm/jsep/tensor-view.ts delete mode 100644 js/web/lib/wasm/jsep/tensor.ts diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index e6e78df2cfb23..5e77a0343b4ee 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -4,7 +4,7 @@ import {Env} from 'onnxruntime-common'; import {configureLogger, LOG_DEBUG} from './log'; -import {TensorView} from './tensor'; +import {TensorView} from './tensor-view'; import {createGpuDataManager, GpuDataManager} from './webgpu/gpu-data-manager'; import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules'; import {ProgramManager} from './webgpu/program-manager'; diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 24ff79cfad3ee..78316cbe1c825 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -8,7 +8,7 @@ import {DataType, getTensorElementSize} from '../wasm-common'; import {WebGpuBackend} from './backend-webgpu'; import {LOG_DEBUG} from './log'; -import {TensorView} from './tensor'; +import {TensorView} from './tensor-view'; import {ShapeUtil} from './util'; import {ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo, ProgramInfoLoader} from './webgpu/types'; diff --git a/js/web/lib/wasm/jsep/tensor-view.ts b/js/web/lib/wasm/jsep/tensor-view.ts new file mode 100644 index 0000000000000..69b9287f6de29 --- /dev/null +++ 
b/js/web/lib/wasm/jsep/tensor-view.ts @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {Tensor} from 'onnxruntime-common'; + +import {tensorTypeToTypedArrayConstructor} from '../wasm-common'; + +export const createView = (dataBuffer: ArrayBuffer, type: Tensor.Type): Int32Array|Uint32Array|BigInt64Array| + BigUint64Array|Uint8Array|Float32Array|Float64Array|Int8Array|Int16Array|Uint16Array => + new (tensorTypeToTypedArrayConstructor(type))(dataBuffer); + +/** + * a TensorView does not own the data. + */ +export interface TensorView { + readonly data: number; + readonly dataType: number; + readonly dims: readonly number[]; + + /** + * get a Float32Array data view of the tensor data. tensor data must be on CPU. + */ + getFloat32Array(): Float32Array; + + /** + * get a BigInt64Array data view of the tensor data. tensor data must be on CPU. + */ + getBigInt64Array(): BigInt64Array; + + /** + * get a Int32Array data view of the tensor data. tensor data must be on CPU. + */ + getInt32Array(): Int32Array; + + /** + * create a new tensor view with the same data but different dimensions. + */ + reshape(newDims: readonly number[]): TensorView; +} diff --git a/js/web/lib/wasm/jsep/tensor.ts b/js/web/lib/wasm/jsep/tensor.ts deleted file mode 100644 index abe61e07fc0a8..0000000000000 --- a/js/web/lib/wasm/jsep/tensor.ts +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -export declare namespace Tensor { - export interface DataTypeMap { - bool: Uint8Array; - float32: Float32Array; - float64: Float64Array; - string: string[]; - int8: Int8Array; - uint8: Uint8Array; - int16: Int16Array; - uint16: Uint16Array; - int32: Int32Array; - uint32: Uint32Array; - int64: BigInt64Array; - uint64: BigUint64Array; - } - - export type DataType = keyof DataTypeMap; - - export type StringType = Tensor.DataTypeMap['string']; - export type BooleanType = Tensor.DataTypeMap['bool']; - export type IntegerType = Tensor.DataTypeMap['int8']|Tensor.DataTypeMap['uint8']|Tensor.DataTypeMap['int16']| - Tensor.DataTypeMap['uint16']|Tensor.DataTypeMap['int32']|Tensor.DataTypeMap['uint32']| - Tensor.DataTypeMap['int64']|Tensor.DataTypeMap['uint64']; - export type FloatType = Tensor.DataTypeMap['float32']|Tensor.DataTypeMap['float64']; - export type NumberType = BooleanType|IntegerType|FloatType; - - export type Id = number; -} - -export const sizeof = (type: Tensor.DataType): number => { - switch (type) { - case 'bool': - case 'int8': - case 'uint8': - return 1; - case 'int16': - case 'uint16': - return 2; - case 'int32': - case 'uint32': - case 'float32': - return 4; - case 'int64': - case 'uint64': - case 'float64': - return 8; - default: - throw new Error(`cannot calculate sizeof() on type ${type}`); - } -}; - -const dataviewConstructor = (type: Tensor.DataType) => { - switch (type) { - case 'bool': - case 'uint8': - return Uint8Array; - case 'int8': - return Int8Array; - case 'int16': - return Int16Array; - case 'uint16': - return Uint16Array; - case 'int32': - return Int32Array; - case 'uint32': - return Uint32Array; - case 'int64': - return BigInt64Array; - case 'uint64': - return BigUint64Array; - case 'float32': - return Float32Array; - case 'float64': - return Float64Array; - default: - // should never run to here - throw new Error('unspecified error'); - } -}; - -export const createView = (dataBuffer: ArrayBuffer, type: Tensor.DataType): 
Int32Array|Uint32Array|BigInt64Array| - BigUint64Array|Uint8Array|Float32Array|Float64Array|Int8Array|Int16Array|Uint16Array => - new (dataviewConstructor(type))(dataBuffer); - -/** - * a TensorView does not own the data. - */ -export interface TensorView { - readonly data: number; - readonly dataType: number; - readonly dims: readonly number[]; - - /** - * get a Float32Array data view of the tensor data. tensor data must be on CPU. - */ - getFloat32Array(): Float32Array; - - /** - * get a BigInt64Array data view of the tensor data. tensor data must be on CPU. - */ - getBigInt64Array(): BigInt64Array; - - /** - * get a Int32Array data view of the tensor data. tensor data must be on CPU. - */ - getInt32Array(): Int32Array; - - /** - * create a new tensor view with the same data but different dimensions. - */ - reshape(newDims: readonly number[]): TensorView; -} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 02507ad802b36..08b1d1f30b233 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -20,7 +20,7 @@ // modified to fit the needs of the project import {LOG_DEBUG} from '../../../log'; -import {TensorView} from '../../../tensor'; +import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; import {ConvAttributes} from '../conv'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index 82fe3d5b6af43..ec6df438129fb 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -18,7 +18,7 @@ // sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_webgpu.ts import {LOG_DEBUG} from '../../../log'; -import {TensorView} from '../../../tensor'; +import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; import {inputVariable, outputVariable, ShaderHelper} from '../common'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index ab4f608451101..8d43dbb378a69 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -19,7 +19,7 @@ // // modified to fit the needs of the project -import {TensorView} from '../../../tensor'; +import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from '../common'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts index 12a13d9d8e0a0..412e61a3cc0f9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts @@ -6,7 +6,7 @@ // a optimized codepath for this. 
import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor'; +import {TensorView} from '../../tensor-view'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index b004ca37a2ea8..13d3a91bb339e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -2,7 +2,7 @@ // Licensed under the MIT License. import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor'; +import {TensorView} from '../../tensor-view'; import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 9b294803d3787..279632c190ded 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {TensorView} from '../../tensor'; +import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 8a794ce16a0b5..1b7b7e0b29a25 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {TensorView} from '../../tensor'; +import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index acdfd7e40f258..e7d1ddf771650 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -2,7 +2,7 @@ // Licensed under the MIT License. import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor'; +import {TensorView} from '../../tensor-view'; import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index 3a83b1c2de6c1..95a64e5787841 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -2,7 +2,7 @@ // Licensed under the MIT License. 
import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor'; +import {TensorView} from '../../tensor-view'; import {PoolConvUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext} from '../types'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv2d-mm.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv2d-mm.ts index 0abece9559630..21c0b97042fbb 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv2d-mm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv2d-mm.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {TensorView} from '../../tensor'; +import {TensorView} from '../../tensor-view'; import {GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types'; import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts index f0196f37c3153..fc9ebf004ad25 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {TensorView} from '../../tensor'; +import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts index 2d845775f1c62..824ce682c0c4b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {TensorView} from '../../tensor'; +import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts index 57c5fccfd8c26..a7d355bc13704 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {TensorView} from '../../tensor'; +import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index a915a4bbd969c..0db060dbec54a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -2,7 +2,7 @@ // Licensed under the MIT License. 
import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor'; +import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts index 3ce963b54f3ee..1a36d4a7545d6 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {TensorView} from '../../tensor'; +import {TensorView} from '../../tensor-view'; import {GemmUtil, ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 449073a133295..5a148bda0a9f7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {TensorView} from '../../tensor'; +import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index 8a9927b25a52e..d6a79e9460c3f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -2,7 +2,7 @@ // Licensed under the MIT License. import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor'; +import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index e4dae00db6305..837ac8410f291 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -2,7 +2,7 @@ // Licensed under the MIT License. import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor'; +import {TensorView} from '../../tensor-view'; import {BroadcastUtil} from '../../util'; import {ComputeContext, GpuDataType, ProgramInfoLoader} from '../types'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts index d90296b5c5a46..c2f89fd2845df 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -2,7 +2,7 @@ // Licensed under the MIT License. 
 import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
 import {ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts
index 79071d32443d6..8c8c12fc54ddb 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts
@@ -2,7 +2,7 @@
 // Licensed under the MIT License.
 
 import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
 import {PoolConvUtil, ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts
index cb592c838dd97..0b8d03ea73b6b 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts
@@ -2,7 +2,7 @@
 // Licensed under the MIT License.
 
 import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
 import {ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts
index 1d0b8229a76f7..8b9dbbf57ac75 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts
@@ -2,7 +2,7 @@
 // Licensed under the MIT License.
 
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
 import {ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
index 4b845bcf2121b..7bfdd73b8af18 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
@@ -2,7 +2,7 @@
 // Licensed under the MIT License.
 
 import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
 import {ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts
index 4211e526898e6..257b9ebc1fdac 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts
@@ -2,7 +2,7 @@
 // Licensed under the MIT License.
 
 import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
 import {ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata, TensorInfo} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
index e2443b24410a5..495a4bcea4f47 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
@@ -5,7 +5,7 @@
 // performance limitations when the reduced axis is long. Need to add
 // a optimized codepath for this.
 
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
 import {ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, GpuDataType, ProgramInfo} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts
index 9a150d21ea02e..3367091bbac23 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts
@@ -1,7 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
 import {ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata, TensorInfo} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts
index 99d9668757caa..109c29bfc8a80 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts
@@ -2,7 +2,7 @@
 // Licensed under the MIT License.
 
 import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
 import {ShapeUtil} from '../../util';
 import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
index 9243b0e4af6b6..38dcaeab54c54 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
@@ -1,7 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
 import {ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, GpuDataType, ProgramInfo} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
index ef63d1177768c..7e52954734216 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
@@ -2,7 +2,7 @@
 // Licensed under the MIT License.
 
 import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor';
+import {TensorView} from '../../tensor-view';
 import {MAX_CLIP, MIN_CLIP, ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts
index cce61be3448cd..cf2687e4c7382 100644
--- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts
+++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts
@@ -4,7 +4,7 @@
 import {tensorDataTypeEnumToString} from '../../wasm-common';
 import {WebGpuBackend} from '../backend-webgpu';
 import {LOG_DEBUG} from '../log';
-import {TensorView} from '../tensor';
+import {TensorView} from '../tensor-view';
 import {createShaderHelper} from './ops/common';
 import {Artifact, GpuData, ProgramInfo} from './types';
diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts
index ddbb9afc275f2..78f80b89774e2 100644
--- a/js/web/lib/wasm/jsep/webgpu/types.ts
+++ b/js/web/lib/wasm/jsep/webgpu/types.ts
@@ -1,7 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-import {Tensor, TensorView} from '../tensor';
+import {TensorView} from '../tensor-view';
 
 import {ShaderHelper} from './ops/common';
 
@@ -19,7 +19,6 @@ export interface GpuData {
 }
 
 export interface TensorInfo {
-  id?: Tensor.Id;
   dims: readonly number[];
   dataType: number;
   gpuDataType: GpuDataType;

From 94f2ed6bbd95d0561eaad7c7a42ee3f56eb9f29e Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Thu, 14 Sep 2023 21:15:10 -0700
Subject: [PATCH 031/121] run_CIs_for_external_pr.py: update required pipelines (#17557)

### Description
Add required pipeline "Windows x64 QNN CI Pipeline" to script
"run_CIs_for_external_pr.py"

---
 tools/python/run_CIs_for_external_pr.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tools/python/run_CIs_for_external_pr.py b/tools/python/run_CIs_for_external_pr.py
index beee4efc74c30..dcc6a92d84ef2 100644
--- a/tools/python/run_CIs_for_external_pr.py
+++ b/tools/python/run_CIs_for_external_pr.py
@@ -71,6 +71,7 @@ def main():
     pipelines = [
         # windows
         "Windows ARM64 QNN CI Pipeline",
+        "Windows x64 QNN CI Pipeline",
         "Windows CPU CI Pipeline",
         "Windows GPU CI Pipeline",
         "Windows GPU TensorRT CI Pipeline",

From 4d931edd78370b18805bd908e92c28e50994f6af Mon Sep 17 00:00:00 2001
From: Changming Sun
Date: Fri, 15 Sep 2023 08:20:47 -0700
Subject: [PATCH 032/121] Update tensorrt_dependencies in setup.py (#17562)

### Description
The files should not have the minor version number. The names were added in #17365 by mistake.

### Motivation and Context
We did not successfully exclude them.
---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 13731eb4e76bb..7e6ab93194b0d 100644
--- a/setup.py
+++ b/setup.py
@@ -214,7 +214,7 @@ def run(self):
                 "libhsa-runtime64.so.1",
             ]

-            tensorrt_dependencies = ["libnvinfer.so.8.6", "libnvinfer_plugin.so.8.6", "libnvonnxparser.so.8.6"]
+            tensorrt_dependencies = ["libnvinfer.so.8", "libnvinfer_plugin.so.8", "libnvonnxparser.so.8"]

             dest = "onnxruntime/capi/libonnxruntime_providers_openvino.so"
             if path.isfile(dest):

From a5302fec93cd91d976e3798e38873927a2ddd8a9 Mon Sep 17 00:00:00 2001
From: zesongw
Date: Fri, 15 Sep 2023 23:29:48 +0800
Subject: [PATCH 033/121] [WebNN EP] Fix bug for PRelu on CPU backend. (#17543)

### Description
The WebNN CPU backend expects the slope of PRelu to be a static value. For now, we will not support this case.

### Motivation and Context
Fall back in this case to pass the CI.

---
 .../webnn/builders/impl/binary_op_builder.cc | 22 +++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc
index 8e588aad15f4c..5adaf80543279 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc
@@ -18,6 +18,10 @@ class BinaryOpBuilder : public BaseOpBuilder {
  private:
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
+
+  // Operator support related.
+  bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
+                         const WebnnDeviceType device_type, const logging::Logger& logger) const override;
 };
 
 // Add operator related.
@@ -50,6 +54,24 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
   return Status::OK();
 }
 
+bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
+                                        const Node& node,
+                                        const WebnnDeviceType device_type,
+                                        const logging::Logger& logger) const {
+  const auto& input_defs = node.InputDefs();
+  const auto& op_type = node.OpType();
+
+  // XNNPACK prelu operator expects slope to be a static value.
+  // https://github.com/google/XNNPACK/issues/4692
+  // TODO: Remove this check after it is solved.
+  if (op_type == "PRelu" && !Contains(initializers, input_defs[1]->Name()) && device_type == WebnnDeviceType::CPU) {
+    LOGS(logger, VERBOSE) << "The second input (slope) for PRelu must be a constant initializer for WebNN CPU backend.";
+    return false;
+  }
+
+  return true;
+}
+
 void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
   if (op_registrations.op_builder_map.count(op_type) > 0)
     return;

From 377f959c69e9f213cd4a8c71a5e80162a412989a Mon Sep 17 00:00:00 2001
From: Yi Zhang
Date: Fri, 15 Sep 2023 23:35:55 +0800
Subject: [PATCH 034/121] Run Final_Jar_Testing_Linux_GPU in docker (#17533)

### Description
1. Create a package test image based on [RedHat UBI](https://www.redhat.com/en/blog/introducing-red-hat-universal-base-image).
2. Install TensorRT 8.6.1.6 in RedHat. (Ref. https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html#maclearn-net-repo-install-rpm)
3. Run Final_Jar_Testing_Linux_GPU in docker (base image: nvidia/cuda:11.8.0-cudnn8-devel-ubi8).

### Motivation and Context
[AB#18470](https://aiinfra.visualstudio.com/6a833879-cd9b-44a4-a9de-adc2d818f13c/_workitems/edit/18470)

### Verification
https://dev.azure.com/aiinfra/Lotus/_build/results?buildId=354004&view=logs&j=8939b564-1402-57b5-92dc-510eba75e069&t=8939b564-1402-57b5-92dc-510eba75e069

---
 .../c-api-noopenmp-packaging-pipelines.yml    | 35 ++++++++++-----
 .../flex-downloadPipelineArtifact.yml         |  2 +-
 ...ckerfile.package_ubi8_cuda11_8_tensorrt8_6 | 45 +++++++++++++++++++
 .../linux/docker/scripts/install_java.sh      | 12 +++++
 4 files changed, 82 insertions(+), 12 deletions(-)
 create mode 100644 tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6
 create mode 100755 tools/ci_build/github/linux/docker/scripts/install_java.sh

diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml
index 09b2a0697447e..fdd8c09333737 100644
--- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml
+++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml
@@ -424,19 +424,32 @@ stages:
       - checkout: self
         submodules: false
       - template: templates/set-version-number-variables-step.yml
-      - task: DownloadPipelineArtifact@2
-        displayName: 'Download Final Jar'
-        inputs:
-          buildType: 'current'
-          artifactName: 'onnxruntime-java-gpu'
-          targetPath: '$(Build.BinariesDirectory)/final-jar'
-      - task: Bash@3
+      - template: templates/flex-downloadPipelineArtifact.yml
+        parameters:
+          StepName: 'Download Final Jar'
+          ArtifactName: onnxruntime-java-gpu
+          TargetPath: '$(Build.BinariesDirectory)/final-jar'
+          SpecificArtifact: ${{ parameters.specificArtifact }}
+          BuildId: ${{ parameters.BuildId }}
+
+      - template: templates/get-docker-image-steps.yml
+        parameters:
+          Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6
+          Context: tools/ci_build/github/linux/docker/
+          DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )"
+          Repository: onnxruntimeubi8packagestest
+          UpdateDepsTxt: false
+
+      - bash: |
+          docker run --rm \
+            --gpus all \
+            --volume $(Build.SourcesDirectory):/onnxruntime_src \
+            --volume $(Build.BinariesDirectory):/build \
+            --volume /data/models:/build/models:ro \
+            onnxruntimeubi8packagestest \
+            /bin/bash /onnxruntime_src/tools/ci_build/github/linux/java_linux_final_test.sh -r /build -v $(OnnxRuntimeVersion)
         displayName: 'Test'
-        inputs:
-          targetType: filePath
-          filePath: 'tools/ci_build/github/linux/java_linux_final_test.sh'
-          arguments: '-r $(Build.BinariesDirectory) -v $(OnnxRuntimeVersion)'
 
       - template: templates/component-governance-component-detection-steps.yml
         parameters:
diff --git a/tools/ci_build/github/azure-pipelines/templates/flex-downloadPipelineArtifact.yml b/tools/ci_build/github/azure-pipelines/templates/flex-downloadPipelineArtifact.yml
index 0f4e0553d05bf..a83451a1b33d9 100644
--- a/tools/ci_build/github/azure-pipelines/templates/flex-downloadPipelineArtifact.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/flex-downloadPipelineArtifact.yml
@@ -18,7 +18,7 @@ parameters:
 
 steps:
   - task: DownloadPipelineArtifact@2
-    displayName: ${{ parameters.StepName }}}
+    displayName: ${{ parameters.StepName }}
     inputs:
       artifactName: ${{ parameters.ArtifactName}}
      targetPath: '${{ parameters.TargetPath }}'
diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6
new file mode 100644
index 0000000000000..cdf504c8e3b03
--- /dev/null
+++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6
@@ -0,0 +1,45 @@
+# --------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------
+# Dockerfile to Test ONNX Runtime on UBI8 with CUDA 11.8 and TensorRT 8.6
+
+# Build base image with required system packages
+FROM nvidia/cuda:11.8.0-cudnn8-devel-ubi8 AS base
+
+ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH}
+
+RUN dnf install -y bash wget &&\
+    dnf clean dbcache
+
+# Install python3
+RUN dnf install -y \
+    python3.8 \
+    python38-pip \
+    python38-wheel &&\
+    cd /usr/local/bin &&\
+    ln -s /usr/bin/python3 python3.8 &&\
+    ln -s /usr/bin/pip3 pip3.8;
+
+RUN pip3 install --upgrade pip
+RUN pip3 install setuptools>=41.0.0
+
+# Install TensorRT
+RUN dnf install -y libnvinfer8 libnvonnxparsers8 libnvparsers8 libnvinfer-plugin8 libnvinfer-lean8 libnvinfer-vc-plugin8 libnvinfer-dispatch8
+RUN v="8.6.1.6-1+cuda11.8" &&\
+    dnf downgrade -y libnvinfer8-${v} libnvinfer8-${v} libnvonnxparsers8-${v} libnvparsers8-${v} libnvinfer-plugin8-${v} libnvinfer-lean8-${v} libnvinfer-vc-plugin8-${v} libnvinfer-dispatch8-${v} &&\
+    dnf install -y dnf-plugin-versionlock &&\
+    dnf versionlock libnvinfer8 libnvonnxparsers8 libnvparsers8 libnvinfer-plugin8 libnvinfer-lean8 libnvinfer-vc-plugin8 libnvinfer-dispatch8
+RUN dnf clean dbcache
+
+
+ADD scripts /tmp/scripts
+RUN cd /tmp/scripts && /tmp/scripts/install_dotnet.sh && /tmp/scripts/install_java.sh && rm -rf /tmp/scripts
+
+# Build final image from base.
+FROM base as final
+ARG BUILD_USER=onnxruntimedev
+ARG BUILD_UID=1000
+RUN adduser --uid $BUILD_UID $BUILD_USER
+WORKDIR /home/$BUILD_USER
+USER $BUILD_USER
diff --git a/tools/ci_build/github/linux/docker/scripts/install_java.sh b/tools/ci_build/github/linux/docker/scripts/install_java.sh
new file mode 100755
index 0000000000000..d11e29f693b8b
--- /dev/null
+++ b/tools/ci_build/github/linux/docker/scripts/install_java.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+set -e -x
+
+if [ -f /etc/redhat-release ]; then
+  dnf install -y java-11-openjdk-devel \
+    && dnf clean dbcache
+elif [ -f /etc/os-release ]; then
+  apt-get update && apt-get install -y openjdk-11-jdk
+else
+  echo "Unsupported OS"
+  exit 1
+fi

From af80542e65334979d3e37c1cb6c88068c5c65f3b Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Fri, 15 Sep 2023 10:17:20 -0700
Subject: [PATCH 035/121] Update optimize_pipeline for SDXL (#17536)

- [x] Optimize SDXL models exported by optimum.
- [x] Enable the script to run locally instead of only as an installed module.
- [x] Detect external data files in the original model, and save in the same format by default.
- [x] Add tests

### Example
```
pip install optimum transformers diffusers onnx onnxruntime-gpu>=1.16
optimum-cli export onnx --model stabilityai/stable-diffusion-xl-base-1.0 --task stable-diffusion-xl ./sd_xl_base_onnx
python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i ./sd_xl_base_onnx -o ./sd_xl_base_fp16 --float16
```

### Known issues
(1) The VAE decoder cannot be converted to float16; otherwise, there will be black images in the output.
(2) To use the float16 models, a minor change in optimum is needed to convert the inputs for the VAE decoder from float16 to float32, since we keep the VAE decoder in float32.
The change is to append a line like the following after [this line](https://github.com/huggingface/optimum/blob/afd2b5a36663bebd1f501486acee065c728947bc/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py#L483):
```
latents = latents.astype(np.float32)
```

---
 .../models/stable_diffusion/README.md        |  11 +-
 .../stable_diffusion/optimize_pipeline.py    | 161 +++++++++++-------
 .../models/stable_diffusion/requirements.txt |   8 +-
 .../test_optimizer_stable_diffusion.py       | 115 +++++++++++++
 4 files changed, 226 insertions(+), 69 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md
index facbd3bf6944e..7ffefdd05f215 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md
@@ -154,6 +154,12 @@ curl https://raw.githubusercontent.com/huggingface/diffusers/v0.15.1/scripts/con
 python convert_sd_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path ./sd_v1_5/fp32
 ```
 
+For SDXL, use optimum to export the model:
+```
+pip install optimum diffusers onnx onnxruntime-gpu
+optimum-cli export onnx --model stabilityai/stable-diffusion-xl-base-1.0 --task stable-diffusion-xl ./sd_xl_base_onnx
+```
+
 ### Optimize ONNX Pipeline
 
 Example to optimize the exported float32 ONNX models, and save to float16 models:
@@ -161,7 +167,10 @@ Example to optimize the exported float32 ONNX models, and save to float16 models
 python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i ./sd_v1_5/fp32 -o ./sd_v1_5/fp16 --float16
 ```
 
-If you installed ONNX Runtime v1.14, some optimizations (packed QKV and BiasAdd) will be disabled automatically since they are not available in v1.14.
+For the SDXL model, it is recommended to use a machine with 32 GB or more memory to optimize.
+```
+python optimize_pipeline.py -i ./sd_xl_base_onnx -o ./sd_xl_base_fp16 --float16
+```
 
 ### Run Benchmark

diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py
index 22fee4bfeab29..4512c971ac27c 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py
@@ -5,19 +5,15 @@
 #
 # This script converts stable diffusion onnx models from float to half (mixed) precision for GPU inference.
 #
-# Before running this script, follow README.md to setup python environment and convert stable diffusion checkpoint to float32 onnx models.
+# Before running this script, follow README.md to setup python environment and convert stable diffusion checkpoint
+# to float32 onnx models.
 #
-# For example, the float32 ONNX pipeline is saved to ./sd-v1-5 directory, you can optimize and convert it to float16 like the following:
+# For example, the float32 ONNX pipeline is saved to ./sd-v1-5 directory, you can optimize and convert it to float16
+# like the following:
 #    python optimize_pipeline.py -i ./sd-v1-5 -o ./sd-v1-5-fp16 --float16
 #
-# Note that the optimizations are carried out for CUDA Execution Provider at first, other EPs may not have the support for the fused opeartors.
-# In this case, the users should disable the operator fusion manually to workaround.
-#
-# Stable diffusion 2.1 model will get black images using float16 Attention. A walkaround is to force Attention to run in float32 like the following:
-#    python optimize_pipeline.py -i ./sd-v2-1 -o ./sd-v2-1-fp16 --float16 --force_fp32_ops unet:Attention
-#
-# If you are using nightly package (or built from source), you can force MultiHeadAttention to run in float32:
-#    python optimize_pipeline.py -i ./sd-v2-1 -o ./sd-v2-1-fp16 --float16 --force_fp32_ops unet:MultiHeadAttention
+# Note that the optimizations are carried out for CUDA Execution Provider at first, other EPs may not have the support
+# for the fused operators. The users could disable the operator fusion manually to work around it.
 
 import argparse
 import logging
@@ -25,8 +21,9 @@
 import shutil
 import tempfile
 from pathlib import Path
-from typing import List
+from typing import List, Optional
 
+import __init__  # noqa: F401. Workaround to run this script directly.
 import coloredlogs
 import onnx
 from fusion_options import FusionOptions
@@ -41,11 +38,19 @@
 logger = logging.getLogger(__name__)
 
 
-def optimize_sd_pipeline(
+def has_external_data(onnx_model_path):
+    original_model = onnx.load_model(str(onnx_model_path), load_external_data=False)
+    for initializer in original_model.graph.initializer:
+        if initializer.HasField("data_location") and initializer.data_location == onnx.TensorProto.EXTERNAL:
+            return True
+    return False
+
+
+def _optimize_sd_pipeline(
     source_dir: Path,
     target_dir: Path,
     overwrite: bool,
-    use_external_data_format: bool,
+    use_external_data_format: Optional[bool],
     float16: bool,
     force_fp32_ops: List[str],
    enable_runtime_optimization: bool,
@@ -57,7 +62,7 @@
         source_dir (Path): Root of input directory of stable diffusion onnx pipeline with float32 models.
         target_dir (Path): Root of output directory of stable diffusion onnx pipeline with optimized models.
         overwrite (bool): Overwrite files if exists.
-        use_external_data_format (bool): save onnx model to two files: one for onnx graph, another for weights
+        use_external_data_format (Optional[bool]): use external data format.
         float16 (bool): use half precision
         force_fp32_ops(List[str]): operators that are forced to run in float32.
         enable_runtime_optimization(bool): run graph optimization using Onnx Runtime.
@@ -71,6 +76,7 @@
         "vae_encoder": "vae",
         "vae_decoder": "vae",
         "text_encoder": "clip",
+        "text_encoder_2": "clip",
         "safety_checker": "unet",
     }
 
@@ -85,9 +91,12 @@
         "vae_encoder": [],
         "vae_decoder": [],
         "text_encoder": [],
+        "text_encoder_2": [],
         "safety_checker": [],
     }
 
+    is_xl = (source_dir / "text_encoder_2").exists()
+
     if force_fp32_ops:
         for fp32_operator in force_fp32_ops:
             parts = fp32_operator.split(":")
@@ -100,26 +109,21 @@
     for name, model_type in model_type_mapping.items():
         onnx_model_path = source_dir / name / "model.onnx"
         if not os.path.exists(onnx_model_path):
-            message = f"input onnx model does not exist: {onnx_model_path}."
-            if name not in ["safety_checker"]:
-                raise RuntimeError(message)
+            if name != "safety_checker":
+                logger.info("input onnx model does not exist: %s", onnx_model_path)
+            # some models are optional so we do not raise an error here.
             continue
 
         # Prepare output directory
         optimized_model_path = target_dir / name / "model.onnx"
         output_dir = optimized_model_path.parent
-        if optimized_model_path.exists():
-            if not overwrite:
-                raise RuntimeError(f"output onnx model path existed: {optimized_model_path}")
-
-        if output_dir.exists():
-            shutil.rmtree(output_dir)
         output_dir.mkdir(parents=True, exist_ok=True)
 
+        if use_external_data_format is None:
+            use_external_data_format = has_external_data(onnx_model_path)
+
         # Graph fusion before fp16 conversion, otherwise they cannot be fused later.
-        # Right now, onnxruntime does not save >2GB model so we use script to optimize unet instead.
         logger.info(f"Optimize {onnx_model_path}...")
 
         args.model_type = model_type
@@ -143,12 +147,15 @@
         )
 
         if float16:
-            logger.info("Convert %s to float16 ...", name)
-            op_block_list = ["RandomNormalLike"]
-            m.convert_float_to_float16(
-                keep_io_types=False,
-                op_block_list=op_block_list + force_fp32_operators[name],
-            )
+            # For SD-XL, using FP16 in the VAE decoder causes NaN and black images, so we keep it in FP32.
+            if is_xl and name == "vae_decoder":
+                logger.info("Skip converting %s to float16 to avoid NaN", name)
+            else:
+                logger.info("Convert %s to float16 ...", name)
+                m.convert_float_to_float16(
+                    keep_io_types=False,
+                    op_block_list=force_fp32_operators[name],
+                )
 
         if enable_runtime_optimization:
             # Use this step to see the final graph that executed by Onnx Runtime.
@@ -174,35 +181,24 @@
     logger.info("*" * 20)
 
 
-def copy_extra_directory(source_dir: Path, target_dir: Path, overwrite: bool):
+def _copy_extra_directory(source_dir: Path, target_dir: Path):
     """Copy extra directory that does not have onnx model
 
     Args:
         source_dir (Path): source directory
         target_dir (Path): target directory
-        overwrite (bool): overwrite if exists
 
     Raises:
         RuntimeError: source path does not exist
-        RuntimeError: output path exists but overwrite is false.
     """
-    extra_dirs = ["scheduler", "tokenizer", "feature_extractor"]
+    extra_dirs = ["scheduler", "tokenizer", "tokenizer_2", "feature_extractor"]
 
     for name in extra_dirs:
         source_path = source_dir / name
         if not os.path.exists(source_path):
-            message = f"source path does not exist: {source_path}"
-            if name not in ["feature_extractor"]:
-                raise RuntimeError(message)
             continue
 
         target_path = target_dir / name
-        if target_path.exists():
-            if not overwrite:
-                raise RuntimeError(f"output path existed: {target_path}")
-            shutil.rmtree(target_path)
-
         shutil.copytree(source_path, target_path)
         logger.info("%s => %s", source_path, target_path)
@@ -213,15 +209,54 @@ def copy_extra_directory(source_dir: Path, target_dir: Path, overwrite: bool):
             raise RuntimeError(f"source path does not exist: {source_path}")
 
         target_path = target_dir / name
-        if target_path.exists():
-            if not overwrite:
-                raise RuntimeError(f"output path existed: {target_path}")
-            os.remove(target_path)
-
         shutil.copyfile(source_path, target_path)
         logger.info("%s => %s", source_path, target_path)
 
+    # Some directories are optional
+    onnx_model_dirs = ["text_encoder", "text_encoder_2", "unet", "vae_encoder", "vae_decoder", "safety_checker"]
+    for onnx_model_dir in onnx_model_dirs:
+        source_path = source_dir / onnx_model_dir / "config.json"
+        target_path = target_dir / onnx_model_dir / "config.json"
+        if source_path.exists():
+            target_path.parent.mkdir(parents=True, exist_ok=True)
+            shutil.copyfile(source_path, target_path)
+            logger.info("%s => %s", source_path, target_path)
+
 
-def parse_arguments():
+def optimize_stable_diffusion_pipeline(
+    input_dir: str,
+    output_dir: str,
+    overwrite: bool,
+    use_external_data_format: Optional[bool],
+    float16: bool,
+    enable_runtime_optimization: bool,
+    args,
+):
+    if os.path.exists(output_dir):
+        if args.overwrite:
+            shutil.rmtree(output_dir, ignore_errors=True)
+        else:
+            raise RuntimeError(f"output directory exists: {output_dir}. Add --overwrite to empty the directory.")
+
+    source_dir = Path(input_dir)
+    target_dir = Path(output_dir)
+    target_dir.mkdir(parents=True, exist_ok=True)
+
+    _copy_extra_directory(source_dir, target_dir)
+
+    _optimize_sd_pipeline(
+        source_dir,
+        target_dir,
+        overwrite,
+        use_external_data_format,
+        float16,
+        args.force_fp32_ops,
+        enable_runtime_optimization,
+        args,
+    )
+
+
+def parse_arguments(argv: Optional[List[str]] = None):
     """Parse arguments
 
     Returns:
@@ -265,7 +300,8 @@
         "--inspect",
         required=False,
         action="store_true",
-        help="Inspect the optimized graph from Onnx Runtime for debugging purpose. This option has no impact on model performance.",
+        help="Save the optimized graph from Onnx Runtime. "
+        "This option has no impact on inference performance except it might reduce session creation time.",
     )
     parser.set_defaults(inspect=False)
 
@@ -283,32 +319,25 @@
         required=False,
         action="store_true",
         help="Onnx model larger than 2GB need to use external data format. "
-        "Save onnx model to two files: one for onnx graph, another for large weights.",
+        "If specified, save each onnx model to two files: one for onnx graph, another for weights. "
+        "If not specified, use same format as original model by default.",
     )
-    parser.set_defaults(use_external_data_format=False)
+    parser.set_defaults(use_external_data_format=None)
 
     FusionOptions.add_arguments(parser)
 
-    args = parser.parse_args()
+    args = parser.parse_args(argv)
     return args
 
 
-def main():
-    coloredlogs.install(fmt="%(funcName)20s: %(message)s")
-    args = parse_arguments()
+def main(argv: Optional[List[str]] = None):
+    args = parse_arguments(argv)
     logger.info("Arguments: %s", str(args))
 
-    copy_extra_directory(Path(args.input), Path(args.output), args.overwrite)
-
-    optimize_sd_pipeline(
-        Path(args.input),
-        Path(args.output),
-        args.overwrite,
-        args.use_external_data_format,
-        args.float16,
-        args.force_fp32_ops,
-        args.inspect,
-        args,
+    optimize_stable_diffusion_pipeline(
+        args.input, args.output, args.overwrite, args.use_external_data_format, args.float16, args.inspect, args
     )
 
 
 if __name__ == "__main__":
+    coloredlogs.install(fmt="%(funcName)20s: %(message)s")
     main()
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt
index 68947a1618a76..d4e6c9fa07695 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt
@@ -1,5 +1,5 @@
-diffusers>=0.15.1
-transformers>=4.26.0
+diffusers>=0.19.3
+transformers>=4.31.0
 numpy>=1.24.1
 accelerate
 onnx>=1.13.0
@@ -8,3 +8,7 @@ packaging
 protobuf==3.20.3
 psutil
 sympy
+# The following are for SDXL
+optimum>=1.11.1
+safetensors
+invisible_watermark
diff --git a/onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py b/onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py
index cde6b56a6648c..dca250f39fae2 100644
--- a/onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py
+++ b/onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py
@@ -7,6 +7,7 @@
 import shutil
 import unittest
 
+import numpy as np
 import pytest
 from parity_utilities import find_transformers_source
 
@@ -19,6 +20,12 @@
 from onnxruntime.transformers.fusion_options import FusionOptions
 from onnxruntime.transformers.optimizer import optimize_model
 
+if find_transformers_source(["models", "stable_diffusion"]):
+    from optimize_pipeline import main as optimize_stable_diffusion
+else:
+    from onnxruntime.transformers.models.stable_diffusion.optimize_pipeline import main as optimize_stable_diffusion
+
+
 TINY_MODELS = {
     "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch",
     "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl",
@@ -151,6 +158,114 @@ def test_clip_sdxl(self):
             },
         )
 
+    @pytest.mark.slow
+    def test_optimize_sdxl_fp32(self):
+        save_directory = "tiny-random-stable-diffusion-xl"
+        if os.path.exists(save_directory):
+            shutil.rmtree(save_directory, ignore_errors=True)
+
+        model_type = "stable-diffusion-xl"
+        model_name = TINY_MODELS[model_type]
+
+        from optimum.onnxruntime import ORTStableDiffusionXLPipeline
+
+        baseline = ORTStableDiffusionXLPipeline.from_pretrained(model_name, export=True)
+        if not os.path.exists(save_directory):
+            baseline.save_pretrained(save_directory)
+
+        batch_size, num_images_per_prompt, height, width = 2, 2, 64, 64
+        latents = baseline.prepare_latents(
+            batch_size * num_images_per_prompt,
+            baseline.unet.config["in_channels"],
+            height,
+            width,
+            dtype=np.float32,
+            generator=np.random.RandomState(0),
+        )
+
+        optimized_directory = "tiny-random-stable-diffusion-xl-optimized"
+        argv = [
+            "--input",
+            save_directory,
+            "--output",
+            optimized_directory,
+            "--disable_group_norm",
+            "--disable_bias_splitgelu",
+            "--overwrite",
+        ]
+        optimize_stable_diffusion(argv)
+
+        treatment = ORTStableDiffusionXLPipeline.from_pretrained(optimized_directory, provider="CUDAExecutionProvider")
+        inputs = {
+            "prompt": ["starry night by van gogh"] * batch_size,
+            "num_inference_steps": 3,
+            "num_images_per_prompt": num_images_per_prompt,
+            "height": height,
+            "width": width,
+            "guidance_rescale": 0.1,
+            "output_type": "np",
+        }
+
+        ort_outputs_1 = baseline(latents=latents, **inputs)
+        ort_outputs_2 = treatment(latents=latents, **inputs)
+        self.assertTrue(np.allclose(ort_outputs_1.images[0], ort_outputs_2.images[0], atol=1e-3))
+
+    @pytest.mark.slow
+    def test_optimize_sdxl_fp16(self):
+        """This tests the optimized fp16 pipeline; the result is deterministic for a given seed."""
+        save_directory = "tiny-random-stable-diffusion-xl"
+        if os.path.exists(save_directory):
+            shutil.rmtree(save_directory, ignore_errors=True)
+
+        model_type = "stable-diffusion-xl"
+        model_name = TINY_MODELS[model_type]
+
+        from optimum.onnxruntime import ORTStableDiffusionXLPipeline
+
+        baseline = ORTStableDiffusionXLPipeline.from_pretrained(model_name, export=True)
+        if not os.path.exists(save_directory):
+            baseline.save_pretrained(save_directory)
+
+        optimized_directory = "tiny-random-stable-diffusion-xl-optimized-fp16"
+        argv = [
+            "--input",
+            save_directory,
+            "--output",
+            optimized_directory,
+            "--disable_group_norm",
+            "--disable_bias_splitgelu",
+            "--float16",
+            "--overwrite",
+        ]
+        optimize_stable_diffusion(argv)
+
+        fp16_pipeline = ORTStableDiffusionXLPipeline.from_pretrained(
+            optimized_directory, provider="CUDAExecutionProvider"
+        )
+        batch_size, num_images_per_prompt, height, width = 1, 1, 64, 64
+        inputs = {
+            "prompt": ["starry night by van gogh"] * batch_size,
+            "num_inference_steps": 3,
+            "num_images_per_prompt": num_images_per_prompt,
+            "height": height,
+            "width": width,
+            "guidance_rescale": 0.1,
+            "output_type": "latent",
+        }
+
+        seed = 123
+        np.random.seed(seed)
+        ort_outputs_1 = fp16_pipeline(**inputs)
+
+        np.random.seed(seed)
+        ort_outputs_2 = fp16_pipeline(**inputs)
+
+        np.random.seed(seed)
+        ort_outputs_3 = fp16_pipeline(**inputs)
+
+        self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_2.images[0]))
+        self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0]))
+
 
 if __name__ == "__main__":
     unittest.main()

From adb0be45d38e83609e9941917d233bb43404ff36 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Fri, 15 Sep 2023 10:57:29 -0700
Subject: [PATCH 036/121] Refactoring of attention cuda kernel: move prepare qkv and concat_past_to_present (#17559)

To avoid a huge cu file and make the code more readable:
- Move PrepareQKV to a separate cu file (attention_prepare_qkv.cu)
- Move ConcatPastToPresent to attention_concat.cu
- Add default values for AttentionData members
- Add a data structure QkvData to track the Q, K and V pointers and the QKV format (a sketch follows below).
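For reference, here is a minimal sketch of what that helper could look like. This is not the exact upstream definition: the field names `q`, `k`, `v`, `after_v` and `format` are taken from their uses in the diff below, while the defaults and member ordering are assumptions.

```cpp
// Sketch only: approximates the QkvData<T> structure introduced by this patch.
// AttentionQkvFormat is the existing enum used elsewhere in the attention code;
// only the members referenced in the diff below are shown here.
template <typename T>
struct QkvData {
  T* q = nullptr;        // query buffer in the current layout
  T* k = nullptr;        // key buffer (may be redirected to present_key later)
  T* v = nullptr;        // value buffer (may be redirected to present_value later)
  T* after_v = nullptr;  // first address after the V buffer, reused as scratch space
  AttentionQkvFormat format = AttentionQkvFormat::Q_K_V_BSNH;
};
```

Passing one such struct around lets PrepareQkv and ConcatPastToPresent update the pointers and the format in a single place, instead of threading separate `q`, `k`, `v` and `qkv_format` arguments through every helper.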
--- cmake/onnxruntime_rocm_hipify.cmake | 1 + .../contrib_ops/cuda/bert/attention.cc | 34 +- .../contrib_ops/cuda/bert/attention_concat.cu | 160 +++-- .../contrib_ops/cuda/bert/attention_impl.cu | 584 ++---------------- .../contrib_ops/cuda/bert/attention_impl.h | 98 +-- .../cuda/bert/attention_prepare_qkv.cu | 492 +++++++++++++++ .../cuda/bert/multihead_attention.cc | 3 - .../quantization/attention_quantization.cc | 31 +- 8 files changed, 729 insertions(+), 674 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 25202f82f468d..b483e6de81a47 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -10,6 +10,7 @@ set(contrib_ops_excluded_files "bert/attention_impl.cu" "bert/attention_softmax.h" "bert/attention_softmax.cu" + "bert/attention_prepare_qkv.cu" "bert/decoder_masked_multihead_attention.h" "bert/decoder_masked_multihead_attention.cc" "bert/decoder_masked_self_attention.h" diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index a79ad96b94d91..f0385ea5abdfb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -249,30 +249,28 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; AttentionData data; data.gemm_buffer = reinterpret_cast(gemm_buffer.get()); - data.bias = nullptr == bias ? nullptr : reinterpret_cast(bias->Data()); - data.query = nullptr; - data.key = nullptr; - data.value = nullptr; - data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data(); - data.mask_index_dims = (nullptr == mask_index) ? gsl::span() : mask_index->Shape().GetDims(); - data.past = (nullptr == past) ? nullptr : reinterpret_cast(past->Data()); - data.past_key = nullptr; - data.past_value = nullptr; - data.relative_position_bias = (nullptr == relative_position_bias) - ? nullptr - : reinterpret_cast(relative_position_bias->Data()); + if (nullptr != bias) { + data.bias = reinterpret_cast(bias->Data()); + } + if (nullptr != mask_index) { + data.mask_index = mask_index->Data(); + data.mask_index_dims = mask_index->Shape().GetDims(); + } + if (nullptr != past) { + data.past = reinterpret_cast(past->Data()); + } + if (nullptr != relative_position_bias) { + data.relative_position_bias = reinterpret_cast(relative_position_bias->Data()); + } data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); - data.present = (nullptr == present) ? 
nullptr : reinterpret_cast(present->MutableData()); - data.present_key = nullptr; - data.present_value = nullptr; + if (nullptr != present) { + data.present = reinterpret_cast(present->MutableData()); + } data.fused_runner = reinterpret_cast(fused_runner); - data.fused_cross_attention_kernel = nullptr; data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; - data.cumulated_sequence_length_q_cache = nullptr; - data.cumulated_sequence_length_kv_cache = nullptr; return QkvToContext(device_prop, cublas, context->GetComputeStream(), parameters, data); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu b/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu index 5d9cfcc69773a..8378ee2691c6b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu @@ -93,16 +93,16 @@ __global__ void ConcatTensorToTensorLarge(const int tensor_add_sequence_length, } Status LaunchConcatTensorToTensor(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const int matrix_num, - const float* tensor_in, - const float* tensor_add, - float* tensor_out) { + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const int matrix_num, + const float* tensor_in, + const float* tensor_add, + float* tensor_out) { const dim3 grid(all_sequence_length, batch_size, matrix_num); if (0 == (head_size & 1)) { const int H = head_size / 2; @@ -137,16 +137,16 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream, } Status LaunchConcatTensorToTensor(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const int matrix_num, - const half* tensor_in, - const half* tensor_add, - half* tensor_out) { + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const int matrix_num, + const half* tensor_in, + const half* tensor_add, + half* tensor_out) { const dim3 grid(all_sequence_length, batch_size, matrix_num); if (0 == (head_size % 4)) { const int H = head_size / 4; @@ -197,15 +197,15 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream, } Status LaunchConcatPastToPresent(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const float* past, - const float* k_v, - float* present) { + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const float* past, + const float* k_v, + float* present) { return LaunchConcatTensorToTensor( stream, all_sequence_length, @@ -221,15 +221,15 @@ Status LaunchConcatPastToPresent(cudaStream_t stream, } Status LaunchConcatPastToPresent(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const half* past, - const half* k_v, - half* present) { + const int all_sequence_length, + const int 
sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const half* past, + const half* k_v, + half* present) { return LaunchConcatTensorToTensor( stream, all_sequence_length, @@ -244,6 +244,90 @@ Status LaunchConcatPastToPresent(cudaStream_t stream, present); } +#ifndef USE_ROCM // exclude from hipify +template +Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, + int sequence_length, int total_sequence_length, bool pass_past_in_kv, + cudaStream_t stream, + int max_threads_per_block, + AttentionData& data, + QkvData& qkv) { + // Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length. + // past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH) + // past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH) + // When there is past state, the head size for Q/K/V shall be same: H == H_v. + + if (nullptr != data.present) { + assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH || qkv.format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); + ORT_RETURN_IF_ERROR( + LaunchConcatPastToPresent( + stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, data.past, qkv.k, data.present)); + + // Update pointers to present_k and present_v. + qkv.k = data.present; + qkv.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size; + } + + if (nullptr != data.past_key || nullptr != data.present_key) { + if (nullptr != data.past_key && nullptr == data.present_key) { + qkv.k = const_cast(data.past_key); + qkv.v = const_cast(data.past_value); + } else if (nullptr == data.past_key && nullptr != data.present_key) { + if (qkv.format == AttentionQkvFormat::Q_K_V_BNSH) { + qkv.k = data.present_key; + qkv.v = data.present_value; + } else { + assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH); + qkv.k = data.temp_k_workspace; + qkv.v = data.temp_v_workspace; + } + } else if (pass_past_in_kv) { + // past_key and past_value are used directly as key and value in attention computations + qkv.k = const_cast(data.past_key); + qkv.v = const_cast(data.past_value); + + // This path has a memory copy from past_key and past_value to present_key and present_value + // Avoid this path since the memory copy is unnecessary because past_key == present_key and + // past_value == present_value + int64_t k_size = (int64_t)batch_size * num_heads * total_sequence_length * qk_head_size; + int64_t v_size = (int64_t)batch_size * num_heads * total_sequence_length * v_head_size; + cudaMemcpyAsync(data.present_key, data.past_key, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync(data.present_value, data.past_value, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); + } else { + ORT_RETURN_IF_ERROR( + LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, + batch_size, qk_head_size, num_heads, + max_threads_per_block, 1, data.past_key, qkv.k, data.present_key)); + ORT_RETURN_IF_ERROR( + LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, + batch_size, v_head_size, num_heads, + max_threads_per_block, 1, data.past_value, qkv.v, data.present_value)); + // Update pointers to present_k and present_v. 
+ qkv.k = data.present_key; + qkv.v = data.present_value; + } + } + + return CUDA_CALL(cudaGetLastError()); +} + +// Template Instantiation +template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, + int sequence_length, int total_sequence_length, bool pass_past_in_kv, + cudaStream_t stream, + int max_threads_per_block, + AttentionData& data, + QkvData& qkv); + +template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, + int sequence_length, int total_sequence_length, bool pass_past_in_kv, + cudaStream_t stream, + int max_threads_per_block, + AttentionData& data, + QkvData& qkv); +#endif + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index ae7696eb9fe0f..366d8fee1473b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -26,16 +26,11 @@ limitations under the License. // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include -#include -#include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" -#include "contrib_ops/cuda/bert/attention_impl.h" #include "contrib_ops/cuda/bert/attention_softmax.h" #include "contrib_ops/cuda/bert/transformer_common.h" -#include "contrib_ops/cuda/bert/add_bias_transpose.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h" #include "contrib_ops/cpu/bert/attention_base.h" @@ -43,6 +38,7 @@ limitations under the License. #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +#include "contrib_ops/cuda/bert/attention_impl.h" using namespace onnxruntime::cuda; using namespace onnxruntime::contrib::attention_softmax_cuda; @@ -286,446 +282,6 @@ template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, const half* qkv_buffer, half* present); -template -Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - AttentionQkvFormat& qkv_format) { - const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - const bool past_present_share_buffer = parameters.past_present_share_buffer; - void* fused_runner = data.fused_runner; - bool use_flash_or_efficient_attention = data.use_flash_attention || data.use_memory_efficient_attention; - - T* qkv = data.workspace; - - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); - - if (data.bias == nullptr) { - assert(nullptr == fused_runner); - // For quantized attention, bias has been added so only need transpose here. - // gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH - assert(qk_head_size == v_head_size); - int matrix_to_trans = (past_present_share_buffer ? 
1 : 3); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.gemm_buffer, qkv, 3)); - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; - } else { - // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2) - // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3) - // For unfused kernel, transpose to 3xBxNxSxH (format 1) - // For fused causal kernel, use format 1 since we need have K and V to update present state, - // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel. - const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1)); - qkv_format = use_fused_kernel - ? AttentionQkvFormat::QKV_BSN3H - : (use_flash_or_efficient_attention - ? AttentionQkvFormat::Q_K_V_BSNH - : (use_fused_causal - ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH - : AttentionQkvFormat::Q_K_V_BNSH)); - - // For fused causal, we will update gemm_buffer with bias directly. - T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr; - - int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3); - // format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v - // format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H) - LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias, - 3, parameters.do_rotary, parameters.past_sequence_length); - } - return Status::OK(); -} - -// For MultiHeadAttention with past state -template -Status PrepareQkv_MHA_WithPast(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { - const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - - DUMP_TENSOR_INIT(); - - if (data.bias == nullptr) { - // Below logic does not support fused attention with past without bias - // When there is past state, the format shall be BxNxSxH, so we disable fused attention when there is past. 
- - // cross attention with past state - if (data.past_key != nullptr && data.present_key == nullptr) { - assert(data.past_value != nullptr); - assert(data.query != nullptr); - assert(data.key == nullptr); - assert(data.value == nullptr); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); - } - // cross attention with present state or self attention with present state - else if (data.past_key == nullptr && data.present_key != nullptr) { - assert(data.past_value == nullptr); - assert(data.present_value != nullptr); - assert(data.query != nullptr); - assert(data.key != nullptr); - assert(data.value != nullptr); - - // TODO: supporting packed qkv for self attention may benefit performance - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); - - // TODO: supporting packed kv for cross attention may benefit performance - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.key, data.present_key)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.value, data.present_value)); - } - // self attention with past and present state - else { - assert(data.past_key != nullptr); - assert(data.past_value != nullptr); - assert(data.present_key != nullptr); - assert(data.present_value != nullptr); - assert(data.query != nullptr); - assert(data.key != nullptr); - assert(data.value != nullptr); - // TODO: supporting packed qkv for self attention may benefit performance - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.key, k)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.value, v)); - } - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; - } -#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION - // When past_key/past_value are inputted directly as key/value and there is no present_key/present_value - else if ((data.use_memory_efficient_attention || data.use_flash_attention) && - data.past_key != nullptr && - data.past_value != nullptr && - parameters.pass_past_in_kv) { - // Transpose past_key and past_value to use memory efficient attention - - // past_key (BxNxSxH) => temp_k_workspace (BxSxNxH) - ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.past_key, data.temp_k_workspace)); - // past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v) - ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.past_value, data.temp_v_workspace)); - - // query => q, temp_k_workspace => k, temp_v_workspace => v - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, 
kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - - data.past_key = nullptr; - data.past_value = nullptr; - } - // When there is no past_key/past_value and there is present_key/present_value - // (e.g. get initial kv to use as past_kv in the next iteration) - else if ((data.use_memory_efficient_attention || data.use_flash_attention) && - data.present_key != nullptr && - data.present_value != nullptr) { - // Use memory efficient attention kernel - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.key, data.value, q, data.temp_k_workspace, data.temp_v_workspace); - - // temp_k_workspace (BxSxNxH) => present_k (BxNxSxH) - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.temp_k_workspace, data.present_key)); - - // temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v) - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.temp_v_workspace, data.present_value)); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", data.temp_v_workspace, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } -#endif - else { - // Use unfused kernel for Q, use unfused kernel for K and V if needed - constexpr int format = 0; - // Query (BxSxNxH) => Q (BxNxSxH) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, q, - true, -1); - - if (!parameters.pass_past_in_kv) { - T* k_dest = (data.past_key == nullptr && data.present_key != nullptr) ? data.present_key : k; - T* v_dest = (data.past_value == nullptr && data.present_value != nullptr) ? 
data.present_value : v; - - // Key (BxLxNxH) => K (BxNxLxH) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, data.bias + num_heads * qk_head_size, k_dest, - true, -1); - - // Value (BxLxNxH_v) => V (BxNxLxH_v) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, v_head_size, - data.value, data.bias + 2 * num_heads * qk_head_size, v_dest, - true, -1); - - DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k(BNSH)", k_dest, batch_size, num_heads, kv_sequence_length, qk_head_size); - DUMP_TENSOR_D("v(BNSH)", v_dest, batch_size, num_heads, kv_sequence_length, v_head_size); - } - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; - } - return Status::OK(); -} - -// For MultiHeadAttention without past state, with packed QKV inputs -template -Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { - const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - void* fused_runner = data.fused_runner; - - T* qkv = data.workspace; - - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - - assert(data.bias == nullptr); - assert(qk_head_size == v_head_size); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size); - - if (data.use_memory_efficient_attention || data.use_flash_attention) { - // unpack qkv to BSNH. Note that there is no bias so we need not output query to q. - constexpr int format = 4; - T* qkv_add_bias = nullptr; - LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, qkv, - true, v_head_size, qkv_add_bias, 3); - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } else { - if (!use_fused_kernel) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, NOT_IMPLEMENTED, - "packed QKV format is not implemented for current GPU. Please disable it in fusion options."); - } - - qkv_format = AttentionQkvFormat::QKV_BSN3H; - } - return Status::OK(); -} - -// For MultiHeadAttention without past state, with packed KV inputs -template -Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { - const int batch_size = parameters.batch_size; - const int kv_sequence_length = parameters.kv_sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - - // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint. - // CheckInputs verified this constraint. 
- assert(data.bias == nullptr); - assert(qk_head_size == v_head_size); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size); - - if (data.use_memory_efficient_attention || data.use_flash_attention) { - // unpack kv to BSNH. Note that there is no bias so we need not output query to q. - constexpr int format = 4; - T* qkv_add_bias = nullptr; - const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size); - LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, kv_bias, k, - true, v_head_size, qkv_add_bias, 2); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } else { - if (data.fused_cross_attention_kernel == nullptr) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, NOT_IMPLEMENTED, - "packed KV format is not implemented for current GPU. Please disable packed kv in fusion options."); - } - - qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; - } - return Status::OK(); -} - -// For MultiHeadAttention without past state, with Q, K and V inputs -template -Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { - const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - void* fused_runner = data.fused_runner; - - T* qkv = data.workspace; - - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); - - // gemm_buffer == nullptr and not packed - assert(data.query != nullptr && data.key != nullptr && data.value != nullptr); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("query", data.query, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("key", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("value", data.value, batch_size, kv_sequence_length, num_heads, v_head_size); - -#if DUMP_TENSOR_LEVEL > 1 - if (data.bias != nullptr) { - DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size); - DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); - DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); - } -#endif - - if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) { - DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, - num_heads, sequence_length, kv_sequence_length); - } - - if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) { - DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1); - } - - if (data.fused_cross_attention_kernel != nullptr) { - assert(qk_head_size == v_head_size); - - // For fused cross attention, besides adding bias, K and V needed to be packed: - // K (BxSxNxH), V (BxSxNxH) => BxSxNx2xH - LaunchAddBiasTransposeTrt( - stream, max_threads_per_block, - batch_size, sequence_length, - 
num_heads, qk_head_size, - data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length); - - qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; - } -#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION - else if (data.use_memory_efficient_attention || data.use_flash_attention) { - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.key, data.value, q, k, v); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } -#endif - else if (use_fused_kernel) { - assert(qk_head_size == v_head_size); - - // Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H) - LaunchAddBiasTransposeTrt( - stream, max_threads_per_block, - batch_size, sequence_length, - num_heads, qk_head_size, - data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length); - DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size); - - qkv_format = AttentionQkvFormat::QKV_BSN3H; - } else { // unfused kernel - ORT_ENFORCE(!use_fused_causal, "MultiHeadAttention has not enabled fused causal"); - - // Query (BxSxNxH) => Q (BxNxSxH) - constexpr int format = 0; - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, q, - true, -1); - - // Key (BxLxNxH) => K (BxNxLxH) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k, - true, -1); - - // Value (BxLxNxH_v) => K (BxNxLxH_v) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, v_head_size, - data.value, nullptr == data.bias ? 
nullptr : data.bias + 2 * num_heads * qk_head_size, v, - true, -1); - - DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size); - DUMP_TENSOR_D("v(BNSH)", v, batch_size, num_heads, kv_sequence_length, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; - } - return Status::OK(); -} - -template -Status PrepareQkv(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { - if (nullptr != data.gemm_buffer) { // Attention operator - ORT_RETURN_IF_ERROR(PrepareQkv_Attention(parameters, data, stream, max_threads_per_block, qkv_format)); - } else if (data.past_key != nullptr || data.present_key != nullptr) { // mha operator with past/present state - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format)); - } else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format)); - } else if (data.value == nullptr) { // multihead attention operator, no past, packed kv - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format)); - } else { // multihead attention operator, no past, separated Q/K/V inputs - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format)); - } - - CUDA_RETURN_IF_ERROR(cudaGetLastError()); - return Status::OK(); -} - template Status QkvToContext( const cudaDeviceProp& device_prop, @@ -755,92 +311,22 @@ Status QkvToContext( const int batches = batch_size * num_heads; - T* qkv = nullptr; - T* q = nullptr; - T* k = nullptr; - T* v = nullptr; - T* scratch1 = data.workspace; - if (data.has_qkv_workspace) { - const int size_per_batch_q = sequence_length * qk_head_size; - const int size_per_batch_k = kv_sequence_length * qk_head_size; - const int size_per_batch_v = kv_sequence_length * v_head_size; - const size_t elements_q = static_cast(batches) * static_cast(size_per_batch_q); - const size_t elements_k = static_cast(batches) * static_cast(size_per_batch_k); - const size_t elements_v = static_cast(batches) * static_cast(size_per_batch_v); - qkv = data.workspace; - q = qkv; - k = q + elements_q; - v = k + elements_k; - scratch1 = v + elements_v; - } - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); - AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(PrepareQkv(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format)); + QkvData qkv; + ORT_RETURN_IF_ERROR(PrepareQkv(parameters, data, stream, max_threads_per_block, qkv)); + T* scratch1 = data.has_qkv_workspace ? qkv.after_v : data.workspace; int present_size_per_batch_k = 0; int present_size_per_batch_v = 0; if (!past_present_share_buffer) { - // Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length. - // past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH) - // past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH) - // When there is past state, the head size for Q/K/V shall be same: H == H_v. 
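The removed comment block above describes the concatenation that the new ConcatPastToPresent call encapsulates. As a host-side reference for one of the two tensors (the same applies to V), assuming row-major BNSH storage; the helper names are our own sketch, not ORT code:

```cpp
#include <cstddef>

// Index into a row-major [B, N, S, H] (BNSH) tensor.
inline size_t BnshIndex(size_t b, size_t n, size_t s, size_t h,
                        size_t N, size_t S, size_t H) {
  return ((b * N + n) * S + s) * H + h;
}

// past (BxNxPxH) + chunk (BxNxLxH) => present (BxNxTxH), with T = P + L.
template <typename T>
void ConcatPastToPresentReference(const T* past, const T* chunk, T* present,
                                  size_t B, size_t N, size_t P, size_t L, size_t H) {
  const size_t T_total = P + L;
  for (size_t b = 0; b < B; ++b)
    for (size_t n = 0; n < N; ++n)
      for (size_t s = 0; s < T_total; ++s)
        for (size_t h = 0; h < H; ++h)
          present[BnshIndex(b, n, s, h, N, T_total, H)] =
              (s < P) ? past[BnshIndex(b, n, s, h, N, P, H)]
                      : chunk[BnshIndex(b, n, s - P, h, N, L, H)];
}
```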
present_size_per_batch_k = total_sequence_length * qk_head_size; present_size_per_batch_v = total_sequence_length * v_head_size; + ORT_RETURN_IF_ERROR(ConcatPastToPresent(batch_size, num_heads, qk_head_size, v_head_size, + sequence_length, total_sequence_length, parameters.pass_past_in_kv, + stream, max_threads_per_block, data, qkv)); - if (nullptr != data.present) { - assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH || qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); - ORT_RETURN_IF_ERROR( - LaunchConcatPastToPresent( - stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, data.past, k, data.present)); - - // Update pointers to present_k and present_v. - k = data.present; - v = data.present + batches * present_size_per_batch_k; - } - - if (nullptr != data.past_key || nullptr != data.present_key) { - if (nullptr != data.past_key && nullptr == data.present_key) { - k = const_cast(data.past_key); - v = const_cast(data.past_value); - } else if (nullptr == data.past_key && nullptr != data.present_key) { - if (qkv_format == AttentionQkvFormat::Q_K_V_BNSH) { - k = data.present_key; - v = data.present_value; - } else { - assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH); - k = data.temp_k_workspace; - v = data.temp_v_workspace; - } - } else if (parameters.pass_past_in_kv) { - // past_key and past_value are used directly as key and value in attention computations - k = const_cast(data.past_key); - v = const_cast(data.past_value); - - // This path has a memory copy from past_key and past_value to present_key and present_value - // Avoid this path since the memory copy is unnecessary because past_key == present_key and - // past_value == present_value - int64_t k_size = (int64_t)batch_size * num_heads * parameters.total_sequence_length * qk_head_size; - int64_t v_size = (int64_t)batch_size * num_heads * parameters.total_sequence_length * v_head_size; - cudaMemcpyAsync(data.present_key, data.past_key, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - cudaMemcpyAsync(data.present_value, data.past_value, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - } else { - ORT_RETURN_IF_ERROR( - LaunchConcatTensorToTensor(stream, parameters.total_sequence_length, sequence_length, - batch_size, qk_head_size, num_heads, - max_threads_per_block, 1, data.past_key, k, data.present_key)); - ORT_RETURN_IF_ERROR( - LaunchConcatTensorToTensor(stream, parameters.total_sequence_length, sequence_length, - batch_size, v_head_size, num_heads, - max_threads_per_block, 1, data.past_value, v, data.present_value)); - // Update pointers to present_k and present_v. - k = data.present_key; - v = data.present_value; - } - } } else { // past_present_share_buffer assert(qk_head_size == v_head_size); assert(data.fused_cross_attention_kernel == nullptr); @@ -870,15 +356,15 @@ Status QkvToContext( present_size_per_batch_k = parameters.max_sequence_length * qk_head_size; present_size_per_batch_v = present_size_per_batch_k; - k = data.present; - v = data.present + batches * present_size_per_batch_k; + qkv.k = data.present; + qkv.v = data.present + batches * present_size_per_batch_k; } // Q, K and V are ready now DUMP_TENSOR_INIT(); if (data.fused_cross_attention_kernel != nullptr) { - assert(qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H); + assert(qkv.format == AttentionQkvFormat::Q_KV_BSNH_BSN2H); // We only enable fused cross attention when there is no key padding mask. 
// Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. @@ -902,8 +388,8 @@ Status QkvToContext( reinterpret_cast(data.fused_cross_attention_kernel); // When there is no bias, we can directly use q and packed kv from inputs. - void const* query = q; - void const* packed_kv = k; + void const* query = qkv.q; + void const* packed_kv = qkv.k; if (data.value == nullptr && data.bias == nullptr) { query = data.query; packed_kv = data.key; @@ -951,10 +437,10 @@ Status QkvToContext( fused_fp16_runner->setup(S, B); if (use_fused_kernel) { - assert(qkv_format == AttentionQkvFormat::QKV_BSN3H); + assert(qkv.format == AttentionQkvFormat::QKV_BSN3H); // When there is no bias, we can directly use packed qkv from inputs. - void const* packed_qkv = qkv; + void const* packed_qkv = qkv.q; if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) { packed_qkv = data.query; } @@ -962,7 +448,7 @@ Status QkvToContext( fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream); DUMP_TENSOR("fused output", data.output, batch_size, sequence_length, num_heads, v_head_size); } else { - assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); + assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream); DUMP_TENSOR("fused causal output", data.output, batch_size, sequence_length, num_heads, v_head_size); } @@ -975,22 +461,22 @@ Status QkvToContext( #if USE_FLASH_ATTENTION if (data.use_flash_attention) { - assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH); assert(nullptr == data.mask_index); assert(nullptr == data.relative_position_bias); assert(parameters.head_size == parameters.v_head_size); - void* query = reinterpret_cast(q); - void* key = reinterpret_cast(k); - void* value = reinterpret_cast(v); + void* query = reinterpret_cast(qkv.q); + void* key = reinterpret_cast(qkv.k); + void* value = reinterpret_cast(qkv.v); // For packed KV, we can use query input directly. if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr && data.bias == nullptr) { query = reinterpret_cast(const_cast(data.query)); } DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, parameters.total_sequence_length, num_heads, v_head_size); + DUMP_TENSOR_D("k(BSNH)", qkv.k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", qkv.v, batch_size, parameters.total_sequence_length, num_heads, v_head_size); constexpr bool is_causal = false; ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( @@ -1008,11 +494,11 @@ Status QkvToContext( if (data.use_memory_efficient_attention) { // We only enable fused cross attention when there is no key padding mask. // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. - assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH); - const void* query = q; - const void* key = k; - const void* value = v; + const void* query = qkv.q; + const void* key = qkv.k; + const void* value = qkv.v; // For packed KV, we can use query input directly. 
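For readers who have not seen it spelled out, the packed KV referred to here is a single [B, S, N, 2, H] buffer (the Q_KV_BSNH_BSN2H format), with K in slot 0 and V in slot 1 of the size-2 axis. A one-line indexing helper, our illustration rather than an ORT API, makes the layout concrete; because K and V interleave per token, a bias-free packed input can presumably be consumed by the fused kernel without any workspace copy:

```cpp
#include <cstddef>

// Address element (b, s, n, h) of K (kv == 0) or V (kv == 1) in a packed
// [B, S, N, 2, H] buffer.
inline size_t PackedKvIndex(size_t b, size_t s, size_t n, size_t kv, size_t h,
                            size_t S, size_t N, size_t H) {
  return (((b * S + s) * N + n) * 2 + kv) * H + h;
}
```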
if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) { assert(data.bias == nullptr); @@ -1020,8 +506,8 @@ Status QkvToContext( } DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, parameters.total_sequence_length, num_heads, v_head_size); + DUMP_TENSOR_D("k(BSNH)", qkv.k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", qkv.v, batch_size, parameters.total_sequence_length, num_heads, v_head_size); MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; @@ -1061,7 +547,7 @@ Status QkvToContext( #endif // The following are unfused attention. - assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH); + assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH); const int* mask_index = data.mask_index; gsl::span& mask_index_dims = data.mask_index_dims; @@ -1082,12 +568,12 @@ Status QkvToContext( CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_T, CUBLAS_OP_N, total_sequence_length, sequence_length, qk_head_size, - &alpha, k, qk_head_size, present_size_per_batch_k, - q, qk_head_size, sequence_length * qk_head_size, + &alpha, qkv.k, qk_head_size, present_size_per_batch_k, + qkv.q, qk_head_size, sequence_length * qk_head_size, &zero, scratch1, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop)); - DUMP_TENSOR_D("Q", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("K", k, batch_size, num_heads, qk_head_size, sequence_length); + DUMP_TENSOR_D("Q", qkv.q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("K", qkv.k, batch_size, num_heads, qk_head_size, sequence_length); DUMP_TENSOR_D("QK", scratch1, batch_size, num_heads, sequence_length, total_sequence_length); const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, @@ -1126,14 +612,14 @@ Status QkvToContext( } DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length); - DUMP_TENSOR_D("V", v, batch_size, num_heads, sequence_length, v_head_size); + DUMP_TENSOR_D("V", qkv.v, batch_size, num_heads, sequence_length, v_head_size); // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v - T* temp_output = qkv; + T* temp_output = qkv.q; CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, v_head_size, sequence_length, total_sequence_length, - &one, v, v_head_size, present_size_per_batch_v, + &one, qkv.v, v_head_size, present_size_per_batch_v, scratch2, total_sequence_length, sequence_length * total_sequence_length, &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop)); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index af7373dd9fa1b..c361a47c364d3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -2,11 +2,12 @@ // Licensed under the MIT License. 
#pragma once -#include "core/providers/cuda/shared_inc/cuda_utils.h" + #include #include -#include "contrib_ops/cpu/bert/attention_common.h" +#include "core/common/gsl.h" #include "core/framework/allocator.h" +#include "contrib_ops/cpu/bert/attention_common.h" namespace onnxruntime { namespace contrib { @@ -49,39 +50,56 @@ size_t GetAttentionWorkspaceSize( template struct AttentionData { - T* gemm_buffer; - const T* bias; + T* gemm_buffer = nullptr; + const T* bias = nullptr; - const T* query; - const T* key; - const T* value; - const int* mask_index; + const T* query = nullptr; + const T* key = nullptr; + const T* value = nullptr; + const int* mask_index = nullptr; gsl::span mask_index_dims; - const T* past; - const T* past_key; - const T* past_value; - const T* relative_position_bias; + const T* past = nullptr; + const T* past_key = nullptr; + const T* past_value = nullptr; + const T* relative_position_bias = nullptr; - bool has_qkv_workspace; - T* workspace; - T* temp_k_workspace; - T* temp_v_workspace; + bool has_qkv_workspace = false; + T* workspace = nullptr; + T* temp_k_workspace = nullptr; + T* temp_v_workspace = nullptr; - T* output; - T* present; - T* present_key; - T* present_value; + T* output = nullptr; + T* present = nullptr; + T* present_key = nullptr; + T* present_value = nullptr; - void* fused_runner; - const void* fused_cross_attention_kernel; + void* fused_runner = nullptr; + const void* fused_cross_attention_kernel = nullptr; - bool use_flash_attention; - bool use_memory_efficient_attention; + bool use_flash_attention = false; + bool use_memory_efficient_attention = false; - mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache; - mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache; + mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache = nullptr; + mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache = nullptr; }; +// Intermediate data pointers available after PrepareQKV +template +struct QkvData { + T* q = nullptr; + T* k = nullptr; + T* v = nullptr; + T* after_v = nullptr; // pointer right after v + AttentionQkvFormat format = AttentionQkvFormat::Q_K_V_BSNH; +}; + +template +Status PrepareQkv(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + QkvData& qkv_data); + template Status QkvToContext( const cudaDeviceProp& device_prop, @@ -161,27 +179,13 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream, const half* tensor_add, half* tensor_out); -Status LaunchConcatPastToPresent(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const float* past, - const float* k_v, - float* present); - -Status LaunchConcatPastToPresent(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const half* past, - const half* k_v, - half* present); +template +Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, + int sequence_length, int total_sequence_length, bool pass_past_in_kv, + cudaStream_t stream, + int max_threads_per_block, + AttentionData& data, + QkvData& qkv); template Status LaunchStridedCopy(cudaStream_t stream, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu 
b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu new file mode 100644 index 0000000000000..cd4137ab11de9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -0,0 +1,492 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "contrib_ops/cuda/bert/attention_impl.h" +#include "contrib_ops/cuda/bert/add_bias_transpose.h" +#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + const bool past_present_share_buffer = parameters.past_present_share_buffer; + void* fused_runner = data.fused_runner; + bool use_flash_or_efficient_attention = data.use_flash_attention || data.use_memory_efficient_attention; + + T* qkv = data.workspace; + + bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); + bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); + + if (data.bias == nullptr) { + assert(nullptr == fused_runner); + // For quantized attention, bias has been added so only need transpose here. + // gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH + assert(qk_head_size == v_head_size); + int matrix_to_trans = (past_present_share_buffer ? 1 : 3); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.gemm_buffer, qkv, 3)); + qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } else { + // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2) + // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3) + // For unfused kernel, transpose to 3xBxNxSxH (format 1) + // For fused causal kernel, use format 1 since we need have K and V to update present state, + // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel. + const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1)); + qkv_format = use_fused_kernel + ? AttentionQkvFormat::QKV_BSN3H + : (use_flash_or_efficient_attention + ? AttentionQkvFormat::Q_K_V_BSNH + : (use_fused_causal + ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH + : AttentionQkvFormat::Q_K_V_BNSH)); + + // For fused causal, we will update gemm_buffer with bias directly. + T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr; + + int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 
1 : 3); + // format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v + // format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H) + LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias, + 3, parameters.do_rotary, parameters.past_sequence_length); + } + return Status::OK(); +} + +// For MultiHeadAttention with past state +template +Status PrepareQkv_MHA_WithPast(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + DUMP_TENSOR_INIT(); + + if (data.bias == nullptr) { + // Below logic does not support fused attention with past without bias + // When there is past state, the format shall be BxNxSxH, so we disable fused attention when there is past. + + // cross attention with past state + if (data.past_key != nullptr && data.present_key == nullptr) { + assert(data.past_value != nullptr); + assert(data.query != nullptr); + assert(data.key == nullptr); + assert(data.value == nullptr); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.query, q)); + } + // cross attention with present state or self attention with present state + else if (data.past_key == nullptr && data.present_key != nullptr) { + assert(data.past_value == nullptr); + assert(data.present_value != nullptr); + assert(data.query != nullptr); + assert(data.key != nullptr); + assert(data.value != nullptr); + + // TODO: supporting packed qkv for self attention may benefit performance + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.query, q)); + + // TODO: supporting packed kv for cross attention may benefit performance + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.key, data.present_key)); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, + max_threads_per_block, false, data.value, data.present_value)); + } + // self attention with past and present state + else { + assert(data.past_key != nullptr); + assert(data.past_value != nullptr); + assert(data.present_key != nullptr); + assert(data.present_value != nullptr); + assert(data.query != nullptr); + assert(data.key != nullptr); + assert(data.value != nullptr); + // TODO: supporting packed qkv for self attention may benefit performance + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.query, q)); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.key, k)); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, + max_threads_per_block, false, data.value, v)); + } + qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } +#if USE_MEMORY_EFFICIENT_ATTENTION || 
USE_FLASH_ATTENTION + // When past_key/past_value are inputted directly as key/value and there is no present_key/present_value + else if ((data.use_memory_efficient_attention || data.use_flash_attention) && + data.past_key != nullptr && + data.past_value != nullptr && + parameters.pass_past_in_kv) { + // Transpose past_key and past_value to use memory efficient attention + + // past_key (BxNxSxH) => temp_k_workspace (BxSxNxH) + ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.past_key, data.temp_k_workspace)); + // past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v) + ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.past_value, data.temp_v_workspace)); + + // query => q, temp_k_workspace => k, temp_v_workspace => v + LaunchAddBias(stream, max_threads_per_block, + batch_size, sequence_length, kv_sequence_length, + num_heads, qk_head_size, v_head_size, + data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v); + + DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + + data.past_key = nullptr; + data.past_value = nullptr; + } + // When there is no past_key/past_value and there is present_key/present_value + // (e.g. get initial kv to use as past_kv in the next iteration) + else if ((data.use_memory_efficient_attention || data.use_flash_attention) && + data.present_key != nullptr && + data.present_value != nullptr) { + // Use memory efficient attention kernel + LaunchAddBias(stream, max_threads_per_block, + batch_size, sequence_length, kv_sequence_length, + num_heads, qk_head_size, v_head_size, + data.bias, data.query, data.key, data.value, q, data.temp_k_workspace, data.temp_v_workspace); + + // temp_k_workspace (BxSxNxH) => present_k (BxNxSxH) + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.temp_k_workspace, data.present_key)); + + // temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v) + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, + max_threads_per_block, false, data.temp_v_workspace, data.present_value)); + + DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", data.temp_v_workspace, batch_size, kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } +#endif + else { + // Use unfused kernel for Q, use unfused kernel for K and V if needed + constexpr int format = 0; + // Query (BxSxNxH) => Q (BxNxSxH) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, q, + true, -1); + + if (!parameters.pass_past_in_kv) { + T* k_dest = (data.past_key == nullptr && data.present_key != nullptr) ? data.present_key : k; + T* v_dest = (data.past_value == nullptr && data.present_value != nullptr) ? 
data.present_value : v; + + // Key (BxLxNxH) => K (BxNxLxH) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, data.bias + num_heads * qk_head_size, k_dest, + true, -1); + + // Value (BxLxNxH_v) => V (BxNxLxH_v) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, v_head_size, + data.value, data.bias + 2 * num_heads * qk_head_size, v_dest, + true, -1); + + DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k(BNSH)", k_dest, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("v(BNSH)", v_dest, batch_size, num_heads, kv_sequence_length, v_head_size); + } + qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } + return Status::OK(); +} + +// For MultiHeadAttention without past state, with packed QKV inputs +template +Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + void* fused_runner = data.fused_runner; + + T* qkv = data.workspace; + + bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); + + assert(data.bias == nullptr); + assert(qk_head_size == v_head_size); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size); + + if (data.use_memory_efficient_attention || data.use_flash_attention) { + // unpack qkv to BSNH. Note that there is no bias so we need not output query to q. + constexpr int format = 4; + T* qkv_add_bias = nullptr; + LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, qkv, + true, v_head_size, qkv_add_bias, 3); + DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", k, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size, sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else { + if (!use_fused_kernel) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, NOT_IMPLEMENTED, + "packed QKV format is not implemented for current GPU. Please disable it in fusion options."); + } + + qkv_format = AttentionQkvFormat::QKV_BSN3H; + } + return Status::OK(); +} + +// For MultiHeadAttention without past state, with packed KV inputs +template +Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint. + // CheckInputs verified this constraint. 
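The TODO above amounts to the following transform, sketched here on the host under our own naming (an actual fix would be a CUDA kernel): split the packed [B, L, N, 2, H] KV buffer into K and V in BNSH order, so that the unfused kernel could consume packed KV inputs on GPUs where the fused cross-attention kernel is unavailable.

```cpp
#include <cstddef>

// packed (BxLxNx2xH) => k (BxNxLxH), v (BxNxLxH)
template <typename T>
void UnpackKvToBnsh(const T* packed, T* k, T* v,
                    size_t B, size_t L, size_t N, size_t H) {
  for (size_t b = 0; b < B; ++b)
    for (size_t l = 0; l < L; ++l)
      for (size_t n = 0; n < N; ++n)
        for (size_t h = 0; h < H; ++h) {
          const size_t in = (((b * L + l) * N + n) * 2) * H + h;  // slot 0 (K)
          const size_t out = ((b * N + n) * L + l) * H + h;
          k[out] = packed[in];
          v[out] = packed[in + H];  // slot 1 (V)
        }
}
```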
+ assert(data.bias == nullptr); + assert(qk_head_size == v_head_size); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size); + + if (data.use_memory_efficient_attention || data.use_flash_attention) { + // unpack kv to BSNH. Note that there is no bias so we need not output query to q. + constexpr int format = 4; + T* qkv_add_bias = nullptr; + const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size); + LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, kv_bias, k, + true, v_head_size, qkv_add_bias, 2); + DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else { + if (data.fused_cross_attention_kernel == nullptr) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, NOT_IMPLEMENTED, + "packed KV format is not implemented for current GPU. Please disable packed kv in fusion options."); + } + + qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + } + return Status::OK(); +} + +// For MultiHeadAttention without past state, with Q, K and V inputs +template +Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + void* fused_runner = data.fused_runner; + + T* qkv = data.workspace; + + bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); + bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); + + // gemm_buffer == nullptr and not packed + assert(data.query != nullptr && data.key != nullptr && data.value != nullptr); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("query", data.query, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("key", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("value", data.value, batch_size, kv_sequence_length, num_heads, v_head_size); + +#if DUMP_TENSOR_LEVEL > 1 + if (data.bias != nullptr) { + DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size); + DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); + DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); + } +#endif + + if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) { + DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, + num_heads, sequence_length, kv_sequence_length); + } + + if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) { + DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1); + } + + if (data.fused_cross_attention_kernel != nullptr) { + assert(qk_head_size == v_head_size); + + // For fused cross attention, besides adding bias, K and V needed to be packed: + // K (BxSxNxH), V (BxSxNxH) => BxSxNx2xH + LaunchAddBiasTransposeTrt( + stream, max_threads_per_block, + batch_size, sequence_length, + 
num_heads, qk_head_size,
+ data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length);
+
+ qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
+ }
+#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
+ else if (data.use_memory_efficient_attention || data.use_flash_attention) {
+ LaunchAddBias(stream, max_threads_per_block,
+ batch_size, sequence_length, kv_sequence_length,
+ num_heads, qk_head_size, v_head_size,
+ data.bias, data.query, data.key, data.value, q, k, v);
+
+ DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
+ qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
+ }
+#endif
+ else if (use_fused_kernel) {
+ assert(qk_head_size == v_head_size);
+
+ // Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H)
+ LaunchAddBiasTransposeTrt(
+ stream, max_threads_per_block,
+ batch_size, sequence_length,
+ num_heads, qk_head_size,
+ data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length);
+ DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size);
+
+ qkv_format = AttentionQkvFormat::QKV_BSN3H;
+ } else { // unfused kernel
+ ORT_ENFORCE(!use_fused_causal, "MultiHeadAttention has not enabled fused causal");
+
+ // Query (BxSxNxH) => Q (BxNxSxH)
+ constexpr int format = 0;
+ LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
+ batch_size, sequence_length, num_heads, qk_head_size,
+ data.query, data.bias, q,
+ true, -1);
+
+ // Key (BxLxNxH) => K (BxNxLxH)
+ LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, qk_head_size,
+ data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k,
+ true, -1);
+
+ // Value (BxLxNxH_v) => V (BxNxLxH_v)
+ LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, v_head_size,
+ data.value, nullptr == data.bias ?
nullptr : data.bias + 2 * num_heads * qk_head_size, v, + true, -1); + + DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("v(BNSH)", v, batch_size, num_heads, kv_sequence_length, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } + return Status::OK(); +} + +template +Status PrepareQkv(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + QkvData& qkv) { + if (data.has_qkv_workspace) { + const int size_per_batch_q = parameters.sequence_length * parameters.head_size; + const int size_per_batch_k = parameters.kv_sequence_length * parameters.head_size; + const int size_per_batch_v = parameters.kv_sequence_length * parameters.v_head_size; + const int batches = parameters.batch_size * parameters.num_heads; + const size_t elements_q = static_cast(batches) * static_cast(size_per_batch_q); + const size_t elements_k = static_cast(batches) * static_cast(size_per_batch_k); + const size_t elements_v = static_cast(batches) * static_cast(size_per_batch_v); + qkv.q = data.workspace; + qkv.k = data.workspace + elements_q; + qkv.v = qkv.k + elements_k; + qkv.after_v = qkv.v + elements_v; + } + + if (nullptr != data.gemm_buffer) { // Attention operator + ORT_RETURN_IF_ERROR(PrepareQkv_Attention(parameters, data, stream, max_threads_per_block, + qkv.format)); + } else if (data.past_key != nullptr || data.present_key != nullptr) { // mha operator with past/present state + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block, + qkv.q, qkv.k, qkv.v, qkv.format)); + } else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block, + qkv.q, qkv.k, qkv.v, qkv.format)); + } else if (data.value == nullptr) { // multihead attention operator, no past, packed kv + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block, + qkv.q, qkv.k, qkv.v, qkv.format)); + } else { // multihead attention operator, no past, separated Q/K/V inputs + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block, + qkv.q, qkv.k, qkv.v, qkv.format)); + } + + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + return Status::OK(); +} + +// Template Instantiation +template Status PrepareQkv( + contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + QkvData& qkv); + +template Status PrepareQkv( + contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + QkvData& qkv); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 8f1252f863ef6..25f3f59165e43 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -263,14 +263,12 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; AttentionData data; - data.gemm_buffer = nullptr; data.bias = (nullptr == bias) ? 
nullptr : reinterpret_cast(bias->Data()); data.query = reinterpret_cast(query->Data()); data.key = (nullptr == key || parameters.pass_past_in_kv) ? nullptr : reinterpret_cast(key->Data()); data.value = (nullptr == value || parameters.pass_past_in_kv) ? nullptr : reinterpret_cast(value->Data()); data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data(); data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span() : key_padding_mask->Shape().GetDims(); - data.past = nullptr; data.past_key = pass_key_value_as_past ? reinterpret_cast(key->Data()) : (nullptr == past_key) ? nullptr : reinterpret_cast(past_key->Data()); @@ -283,7 +281,6 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.temp_k_workspace = use_temp_k_v_workspace ? reinterpret_cast(temp_k_work_space.get()) : nullptr; data.temp_v_workspace = use_temp_k_v_workspace ? reinterpret_cast(temp_v_work_space.get()) : nullptr; data.output = reinterpret_cast(output->MutableData()); - data.present = nullptr; data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast(present_value->MutableData()); data.fused_runner = reinterpret_cast(fused_runner); diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index b0556512de0b7..705f2d49fe2bf 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -195,28 +195,21 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; AttentionData data; data.gemm_buffer = reinterpret_cast(gemm_buffer.get()); - data.bias = nullptr; // bias has been added - data.query = nullptr; - data.key = nullptr; - data.value = nullptr; - data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data(); - data.mask_index_dims = (nullptr == mask_index) ? gsl::span() : mask_index->Shape().GetDims(); - data.past = (nullptr == past_tensor) ? nullptr : reinterpret_cast(past_tensor->Data()); - data.past_key = nullptr; - data.past_value = nullptr; - data.relative_position_bias = nullptr; // add_qk is not supported in quantized attention + if (nullptr != mask_index) { + data.mask_index = mask_index->Data(); + data.mask_index_dims = mask_index->Shape().GetDims(); + } + + if (nullptr != past_tensor) { + data.past = reinterpret_cast(past_tensor->Data()); + } + data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); - data.present = (nullptr == present) ? 
nullptr : reinterpret_cast<CudaT*>(present->MutableData());
- data.present_key = nullptr;
- data.present_value = nullptr;
- data.fused_runner = fused_runner;
- data.fused_cross_attention_kernel = nullptr;
- data.use_flash_attention = use_flash_attention;
- data.use_memory_efficient_attention = use_memory_efficient_attention;
- data.cumulated_sequence_length_q_cache = nullptr;
- data.cumulated_sequence_length_kv_cache = nullptr;
+ if (nullptr != present) {
+ data.present = reinterpret_cast<CudaT*>(present->MutableData());
+ }

 return QkvToContext<CudaT>(GetDeviceProp(), cublas, context->GetComputeStream(), parameters, data);
 }

From fdb132643d8e07d383d6cd9af1c19f547fe1718f Mon Sep 17 00:00:00 2001
From: Dmitri Smirnov
Date: Fri, 15 Sep 2023 12:13:37 -0700
Subject: [PATCH 037/121] Remove redundant Resolve() after each inlined function (#17556)

### Description
Remove the `Resolve()` call on the entire graph after each function is inlined. We retain `Resolve()` after each inlining iteration.

### Motivation and Context
Poor performance when inlining the model during session initialization.

Original model, before the `Resolve()` removal: FunctionTest.Profiling (**65953 ms**)
After the `Resolve()` removal: FunctionTest.Profiling (**2911 ms**)
RelWithDebInfo, pre-inlined model: FunctionTest.Profiling (**9851 ms**), presumably because Level1 optimizers run on it; the non-inlined model consists of functions, on which Level1 optimizers have no effect.
---
 include/onnxruntime/core/graph/graph.h | 1 +
 onnxruntime/core/graph/graph.cc | 2 --
 .../orttraining/test/training_ops/function_op_test_utils.cc | 2 +-
 3 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h
index 19caa69d94ccf..f153e88909b8d 100644
--- a/include/onnxruntime/core/graph/graph.h
+++ b/include/onnxruntime/core/graph/graph.h
@@ -1135,6 +1135,7 @@ class Graph {
 /** Directly insert the nodes in the function Node provided into this Graph.
+ The Graph needs to be Resolve()d after this call.
 @param node Node with Node::Type of Node::Type::Fused
 @returns Status indicating success or providing an error message.
*/
diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index d4164681f2bba..383c1d689d3c3 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -4145,8 +4145,6 @@ Status Graph::InlineFunction(Node& callnode) {
 // std::cout << "Graph after inlining\n\n" << *this << std::endl << std::flush;

- ORT_RETURN_IF_ERROR(this->Resolve());
-
 return Status::OK();
 }

diff --git a/orttraining/orttraining/test/training_ops/function_op_test_utils.cc b/orttraining/orttraining/test/training_ops/function_op_test_utils.cc
index 5eed4765abfd7..9504ba2c1e69a 100644
--- a/orttraining/orttraining/test/training_ops/function_op_test_utils.cc
+++ b/orttraining/orttraining/test/training_ops/function_op_test_utils.cc
@@ -25,8 +25,8 @@ void OpFunctionTester::RunFunctionBodyGraphOnCPU(TwoDArray& results) {
 auto& node = *graph.Nodes().begin();
 ASSERT_EQ(node.OpType(), op);

- // Inline function will call Resolve itself
 ASSERT_STATUS_OK(graph.InlineFunction(node));
+ ASSERT_STATUS_OK(graph.Resolve());

 // Hookup the inputs and outputs
 std::unordered_map<std::string, OrtValue> feeds;

From 155887593db4e95d3b808ec605e9b089daab4ef0 Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Fri, 15 Sep 2023 13:55:25 -0700
Subject: [PATCH 038/121] [js/web] update npm test to load test cases only for required backends (#17555)

### Description
Update the npm test runner to load test cases only for the required backends. There is no need to load the test case list for backends that we are not testing.
---
 js/web/script/test-runner-cli.ts | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts
index 520ef62b2c719..a75321d45f1ef 100644
--- a/js/web/script/test-runner-cli.ts
+++ b/js/web/script/test-runner-cli.ts
@@ -84,8 +84,10 @@ async function main() {
 .flat();

 for (const backend of DEFAULT_BACKENDS) {
- nodeTests.set(backend, loadNodeTests(backend, allNodeTestsFolders));
- opTests.set(backend, loadOpTests(backend));
+ if (args.backends.indexOf(backend) !== -1) {
+ nodeTests.set(backend, loadNodeTests(backend, allNodeTestsFolders));
+ opTests.set(backend, loadOpTests(backend));
+ }
 }
 }

From efd416b71f0925108ff7cb0e83e99beabe7c05cd Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Fri, 15 Sep 2023 14:40:22 -0700
Subject: [PATCH 039/121] [js/web] update test to explicitly fail for webnn without proxy (#17554)

### Description
Update the test runner to explicitly fail for webnn without the proxy. I am making this change because, when I test webnn together with other backends, it silently enables the proxy. I want the test runner to behave with fewer implicit flag resets. If the proxy is not enabled, the webnn test should fail.

@Honry please let me know if other places (e.g. CI scripts) should also change.
---
 js/web/script/test-runner-cli-args.ts | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts
index 7b41850948149..f90f568879146 100644
--- a/js/web/script/test-runner-cli-args.ts
+++ b/js/web/script/test-runner-cli-args.ts
@@ -382,8 +382,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs
 const globalEnvFlags = parseGlobalEnvFlags(args);

 if (backend.includes('webnn') && !globalEnvFlags.wasm!.proxy) {
- // Backend webnn is restricted in the dedicated worker.
- globalEnvFlags.wasm!.proxy = true;
+ throw new Error('Backend webnn requires flag "wasm-enable-proxy" to be set to true.');
 }

 // Options:
From 705f8a371886828dcd2a380d10ba3a6549a60b9b Mon Sep 17 00:00:00 2001
From: Yifan Li <109183385+yf711@users.noreply.github.com>
Date: Fri, 15 Sep 2023 15:16:11 -0700
Subject: [PATCH 040/121] [TensorRT EP] Fallback to CUDA EP if it's explicitly assigned (#17535)

### Description
* The TensorRT EP can fall back to the CUDA EP if the CUDA EP is explicitly assigned.
* The MIGraphX EP can fall back to the ROCM EP if the ROCM EP is explicitly assigned.

Test cases:

| When user specifies providers= | self._fallback_providers= |
| ------------------------------------------------------------ | ------------------------------------------------- |
| ["TensorrtExecutionProvider", "CUDAExecutionProvider"] | ["CUDAExecutionProvider", "CPUExecutionProvider"] |
| ["TensorrtExecutionProvider",("CUDAExecutionProvider", cuda_options)] | ["CUDAExecutionProvider", "CPUExecutionProvider"] |
| ["TensorrtExecutionProvider"] | ["CPUExecutionProvider"] |
| [("TensorrtExecutionProvider", trt_options)] | ["CPUExecutionProvider"] |
| [("TensorrtExecutionProvider", trt_options), ("CUDAExecutionProvider", cuda_options)] | ["CUDAExecutionProvider", "CPUExecutionProvider"] |
| ["TensorrtExecutionProvider", "CPUExecutionProvider"] | ["CPUExecutionProvider"] |

### Motivation and Context
Applies the review comments from https://github.com/microsoft/onnxruntime/issues/17394 and unifies the fallback logic for [MIGraphX, ROCM].
---
 .../onnxruntime_inference_collection.py | 23 ++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py
index b73fcbbff5456..4124822adef1f 100644
--- a/onnxruntime/python/onnxruntime_inference_collection.py
+++ b/onnxruntime/python/onnxruntime_inference_collection.py
@@ -420,8 +420,10 @@ def __init__(
 except (ValueError, RuntimeError) as e:
 if self._enable_fallback:
 try:
+ print("*************** EP Error ***************")
 print(f"EP Error {e} when using {providers}")
 print(f"Falling back to {self._fallback_providers} and retrying.")
+ print("****************************************")
 self._create_inference_session(self._fallback_providers, None)
 # Fallback only once.
 self.disable_fallback()
@@ -434,11 +436,26 @@ def __init__(
 def _create_inference_session(self, providers, provider_options, disabled_optimizers=None):
 available_providers = C.get_available_providers()

- # Tensorrt can fall back to CUDA. All others fall back to CPU.
+ # Tensorrt can fall back to CUDA if it's explicitly assigned. All others fall back to CPU.
 if "TensorrtExecutionProvider" in available_providers:
- self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+ if any(
+ provider == "CUDAExecutionProvider"
+ or (isinstance(provider, tuple) and provider[0] == "CUDAExecutionProvider")
+ for provider in providers
+ ):
+ self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+ else:
+ self._fallback_providers = ["CPUExecutionProvider"]
+ # MIGraphX can fall back to ROCM if it's explicitly assigned. All others fall back to CPU.
elif "MIGraphXExecutionProvider" in available_providers: - self._fallback_providers = ["ROCMExecutionProvider", "CPUExecutionProvider"] + if any( + provider == "ROCMExecutionProvider" + or (isinstance(provider, tuple) and provider[0] == "ROCMExecutionProvider") + for provider in providers + ): + self._fallback_providers = ["ROCMExecutionProvider", "CPUExecutionProvider"] + else: + self._fallback_providers = ["CPUExecutionProvider"] else: self._fallback_providers = ["CPUExecutionProvider"] From c96923732141bf3823f1b5acb666c407f07fa9c2 Mon Sep 17 00:00:00 2001 From: simonjub <78098752+simonjub@users.noreply.github.com> Date: Sun, 17 Sep 2023 15:19:32 -0400 Subject: [PATCH 041/121] [TRT EP] Fix ProviderOptions functions (#17567) ### Description When trying to use the TRT EP option trt_extra_plugin_lib_paths I noticed that my custom op library was not being loaded by the EP. After some digging I found that code was missing to update this option when UpdateTensorRTProviderOptions() is used to set it. At the same time I noticed that char arrays were allocated in that function and wondered where they are de-allocated. When I found it was done in ReleaseTensorRTProviderOptions(), I noticed that a few de-allocations were missing. ### Motivation and Context This PR fixes the problems described above. --- .../tensorrt/tensorrt_provider_factory.cc | 14 ++++++++++++++ onnxruntime/core/session/provider_bridge_ort.cc | 11 ++++++++--- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index 90aeeb64c9d24..18ec113734b97 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -202,6 +202,20 @@ struct Tensorrt_Provider : Provider { trt_options.trt_tactic_sources = (const char*)dest; } + str_size = internal_options.extra_plugin_lib_paths.size(); + if (str_size == 0) { + trt_options.trt_extra_plugin_lib_paths = nullptr; + } else { + dest = new char[str_size + 1]; +#ifdef _MSC_VER + strncpy_s(dest, str_size + 1, internal_options.extra_plugin_lib_paths.c_str(), str_size); +#else + strncpy(dest, internal_options.extra_plugin_lib_paths.c_str(), str_size); +#endif + dest[str_size] = '\0'; + trt_options.trt_extra_plugin_lib_paths = (const char*)dest; + } + str_size = internal_options.profile_min_shapes.size(); if (str_size == 0) { trt_options.trt_profile_min_shapes = nullptr; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 8f0a5aeaa3975..bf7a3bbd9d380 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1930,9 +1930,14 @@ ORT_API_STATUS_IMPL(OrtApis::GetTensorRTProviderOptionsByName, ORT_API(void, OrtApis::ReleaseTensorRTProviderOptions, _Frees_ptr_opt_ OrtTensorRTProviderOptionsV2* ptr) { #ifdef USE_TENSORRT if (ptr != nullptr) { - delete ptr->trt_int8_calibration_table_name; - delete ptr->trt_engine_cache_path; - delete ptr->trt_engine_decryption_lib_path; + delete[] ptr->trt_int8_calibration_table_name; + delete[] ptr->trt_engine_cache_path; + delete[] ptr->trt_engine_decryption_lib_path; + delete[] ptr->trt_tactic_sources; + delete[] ptr->trt_extra_plugin_lib_paths; + delete[] ptr->trt_profile_min_shapes; + delete[] ptr->trt_profile_max_shapes; + delete[] ptr->trt_profile_opt_shapes; } std::unique_ptr p(ptr); From 
af14ae8050aa77d40031e7794fb2bf38bf30acef Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Mon, 18 Sep 2023 13:34:39 +0800 Subject: [PATCH 042/121] [ROCm] Update whisper benchmark script (#17391) - update whisper benchmark for ROCm EP. --- .../tools/transformers/benchmark_helper.py | 172 ++++++++++++------ .../transformers/models/whisper/benchmark.py | 60 ++++-- .../models/whisper/benchmark_all.py | 57 +++--- 3 files changed, 198 insertions(+), 91 deletions(-) diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index f4d3f2fa1c317..67d3c95922a87 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -8,7 +8,10 @@ import logging import os import random +import sys +import time import timeit +from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from datetime import datetime from enum import Enum @@ -439,68 +442,127 @@ def get_gpu_info() -> Optional[List[Dict[str, Any]]]: return None -def measure_memory(is_gpu, func): - class MemoryMonitor: - def __init__(self, keep_measuring=True): - self.keep_measuring = keep_measuring +class MemoryMonitor(ABC): + def __init__(self, keep_measuring=True): + self.keep_measuring = keep_measuring - def measure_cpu_usage(self): - import psutil + def measure_cpu_usage(self): + import psutil - max_usage = 0 + max_usage = 0 + while True: + max_usage = max(max_usage, psutil.Process(os.getpid()).memory_info().rss / 1024**2) + sleep(0.005) # 5ms + if not self.keep_measuring: + break + return max_usage + + @abstractmethod + def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]: + raise NotImplementedError() + + +class CudaMemoryMonitor(MemoryMonitor): + def __init__(self, keep_measuring=True): + super().__init__(keep_measuring) + + def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]: + from py3nvml.py3nvml import ( + NVMLError, + nvmlDeviceGetCount, + nvmlDeviceGetHandleByIndex, + nvmlDeviceGetMemoryInfo, + nvmlDeviceGetName, + nvmlInit, + nvmlShutdown, + ) + + max_gpu_usage = [] + gpu_name = [] + try: + nvmlInit() + device_count = nvmlDeviceGetCount() + if not isinstance(device_count, int): + logger.error(f"nvmlDeviceGetCount result is not integer: {device_count}") + return None + + max_gpu_usage = [0 for i in range(device_count)] + gpu_name = [nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(device_count)] while True: - max_usage = max(max_usage, psutil.Process(os.getpid()).memory_info().rss / 1024**2) + for i in range(device_count): + info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i)) + if isinstance(info, str): + logger.error(f"nvmlDeviceGetMemoryInfo returns str: {info}") + return None + max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2) sleep(0.005) # 5ms if not self.keep_measuring: break - return max_usage - - def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]: - from py3nvml.py3nvml import ( - NVMLError, - nvmlDeviceGetCount, - nvmlDeviceGetHandleByIndex, - nvmlDeviceGetMemoryInfo, - nvmlDeviceGetName, - nvmlInit, - nvmlShutdown, - ) + nvmlShutdown() + return [ + { + "device_id": i, + "name": gpu_name[i], + "max_used_MB": max_gpu_usage[i], + } + for i in range(device_count) + ] + except NVMLError as error: + logger.error("Error fetching GPU information using nvml: %s", error) + return None - max_gpu_usage = [] - gpu_name = [] - try: - nvmlInit() - device_count 
= nvmlDeviceGetCount() - if not isinstance(device_count, int): - logger.error(f"nvmlDeviceGetCount result is not integer: {device_count}") - return None - - max_gpu_usage = [0 for i in range(device_count)] - gpu_name = [nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(device_count)] - while True: - for i in range(device_count): - info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i)) - if isinstance(info, str): - logger.error(f"nvmlDeviceGetMemoryInfo returns str: {info}") - return None - max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2) - sleep(0.005) # 5ms - if not self.keep_measuring: - break - nvmlShutdown() - return [ - { - "device_id": i, - "name": gpu_name[i], - "max_used_MB": max_gpu_usage[i], - } - for i in range(device_count) - ] - except NVMLError as error: - logger.error("Error fetching GPU information using nvml: %s", error) - return None - monitor = MemoryMonitor(False) +class RocmMemoryMonitor(MemoryMonitor): + def __init__(self, keep_measuring=True): + super().__init__(keep_measuring) + rocm_smi_path = "/opt/rocm/libexec/rocm_smi" + if os.path.exists(rocm_smi_path): + if rocm_smi_path not in sys.path: + sys.path.append(rocm_smi_path) + try: + import rocm_smi + + self.rocm_smi = rocm_smi + self.rocm_smi.initializeRsmi() + except ImportError: + self.rocm_smi = None + + def get_used_memory(self, dev): + if self.rocm_smi is None: + return -1 + return self.rocm_smi.getMemInfo(dev, "VRAM")[0] / 1024 / 1024 + + def measure_gpu_usage(self): + if self.rocm_smi is None: + return None + + device_count = len(self.rocm_smi.listDevices()) if self.rocm_smi is not None else 0 + max_gpu_usage = [0 for i in range(device_count)] + gpu_name = [f"GPU{i}" for i in range(device_count)] + while True: + for i in range(device_count): + max_gpu_usage[i] = max(max_gpu_usage[i], self.get_used_memory(i)) + time.sleep(0.005) # 2ms + if not self.keep_measuring: + break + return [ + { + "device_id": i, + "name": gpu_name[i], + "max_used_MB": max_gpu_usage[i], + } + for i in range(device_count) + ] + + +def measure_memory(is_gpu, func, monitor_type="cuda"): + memory_monitor_type = None + if monitor_type == "rocm": + memory_monitor_type = RocmMemoryMonitor + else: + memory_monitor_type = CudaMemoryMonitor + + monitor = memory_monitor_type(False) if is_gpu: memory_before_test = monitor.measure_gpu_usage() @@ -508,7 +570,7 @@ def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]: return None with ThreadPoolExecutor() as executor: - monitor = MemoryMonitor() + monitor = memory_monitor_type() mem_thread = executor.submit(monitor.measure_gpu_usage) try: fn_thread = executor.submit(func) diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py index 07995f0a38e26..283528bea7465 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py @@ -11,7 +11,7 @@ import psutil import torch import whisper -from benchmark_helper import setup_logger +from benchmark_helper import measure_memory, setup_logger from onnxruntime_extensions import get_library_path from optimum.onnxruntime import ORTModelForSpeechSeq2Seq from torch.profiler import ProfilerActivity, profile, record_function @@ -19,7 +19,6 @@ from transformers import AutoModelForSpeechSeq2Seq, WhisperConfig, WhisperProcessor import onnxruntime as ort -from onnxruntime.transformers.benchmark_helper import measure_memory logger = 
logging.getLogger(__name__) @@ -123,6 +122,9 @@ def get_model(args: argparse.Namespace): if args.verbose: sess_options.log_verbosity_level = 1 sess_options.log_severity_level = 1 + if args.tune: + ort.set_default_logger_severity(0) + ort.set_default_logger_verbosity(0) else: raise Exception(f"Cannot recognize {args.benchmark_type}") @@ -159,6 +161,9 @@ def get_model(args: argparse.Namespace): def time_fn(args, fn, inputs): + warmup_inputs = inputs[0] if type(inputs) is tuple else inputs + benchmark_inputs = inputs[1] if type(inputs) is tuple else inputs + # Warm up warmup_range = ( range(args.warmup_runs) @@ -167,11 +172,11 @@ def time_fn(args, fn, inputs): ) if args.verbose: - outputs = fn(inputs) + outputs = fn(warmup_inputs) logger.info(outputs) for _ in warmup_range: - fn(inputs) + fn(warmup_inputs) # Benchmark if args.device != "cpu": @@ -184,7 +189,7 @@ def time_fn(args, fn, inputs): else trange(args.num_runs, file=sys.stdout, desc="Benchmark") ) for _ in bench_range: - fn(inputs) + fn(benchmark_inputs) if args.device != "cpu": torch.cuda.synchronize() @@ -244,7 +249,7 @@ def measure_fn(args, fn, inputs): # Measure memory usage gc.collect() torch.cuda.empty_cache() - measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs)) + measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs), monitor_type=args.monitor_type) # Flush output so memory usage is printed sys.stdout.flush() @@ -255,7 +260,7 @@ def run_hf_inference(args, inputs, model): def get_pred_ids(inputs): # Inference pass with predicted token ids generation predicted_ids = model.generate(**inputs) - return predicted_ids, [""] + return predicted_ids def gen_and_dec(inputs): # Inference pass with generation and decoding @@ -315,7 +320,7 @@ def gen_and_dec(inputs): def run_ort_inference(args, inputs, model): - def prepare_ort_inputs(inputs): + def prepare_ort_inputs(inputs, warmup=False): # Check that all model inputs will be provided model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs())) user_inputs = set(inputs.keys()) @@ -324,6 +329,9 @@ def prepare_ort_inputs(inputs): logger.error(f"The following model inputs are missing: {missing_inputs}") raise Exception("There are missing inputs to the model. 
Please add them and try again.") + if warmup and args.tune: + inputs["min_length"] = inputs["max_length"] + # Remove unnecessary inputs from model inputs unnecessary_inputs = user_inputs - model_inputs if len(unnecessary_inputs): @@ -352,6 +360,13 @@ def without_io_binding(inputs): outputs = model.run(None, inputs) return outputs + def handle_output(output): + if args.eos_token_id in output: + first_end = np.where(output == args.eos_token_id)[0][0] + return output[: first_end + 1] + + return output + generate_fn = with_io_binding if args.device != "cpu" else without_io_binding ort_inputs = prepare_ort_inputs(inputs) @@ -367,7 +382,12 @@ def without_io_binding(inputs): # ORT evaluation logger.info("\nEvaluating ONNX Runtime...") - time_fn(args, generate_fn, ort_inputs) + ort_evaluate_inputs = ort_inputs + if args.tune: + ort_warmup_inputs = prepare_ort_inputs(inputs, warmup=True) + ort_evaluate_inputs = (ort_warmup_inputs, ort_inputs) + + time_fn(args, generate_fn, ort_evaluate_inputs) ort_outputs = generate_fn(ort_inputs) if args.device != "cpu": ort_outputs = ort_outputs.copy_outputs_to_cpu() @@ -378,7 +398,10 @@ def without_io_binding(inputs): logger.info(f"Transcription: {ort_outputs[0][0]}") else: # convert_to_onnx model produces generated ids - logger.info(f"Generated token length: {len(ort_outputs[0][0])} tokens") + actual_output = handle_output(ort_outputs[0][0]) + logger.info(f"Generated token length: {len(actual_output)} tokens") + transcription = args.processor.batch_decode(ort_outputs[0], skip_special_tokens=True)[0] + logger.info(f"Transcription: {transcription}") measure_fn(args, generate_fn, ort_inputs) @@ -483,6 +506,12 @@ def parse_args(): parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display") parser.add_argument("--verbose", default=False, action="store_true") parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files") + parser.add_argument( + "--tune", + default=False, + action="store_true", + help="Only used by ROCm EP, enable TunableOp tuning to select fastest kernel", + ) args = parser.parse_args() @@ -490,13 +519,21 @@ def parse_args(): np.random.seed(args.seed) torch.manual_seed(args.seed) + args.monitor_type = args.device # Set runtime properties if "ort" in args.benchmark_type: args.execution_provider = f"{args.device.upper()}ExecutionProvider" if args.execution_provider == "CUDAExecutionProvider": args.execution_provider = (args.execution_provider, {"device_id": args.device_id}) elif args.execution_provider == "ROCMExecutionProvider": - args.execution_provider = (args.execution_provider, {"device_id": args.device_id}) + args.execution_provider = ( + args.execution_provider, + { + "device_id": args.device_id, + "tunable_op_enable": 1, + "tunable_op_tuning_enable": 1 if args.tune else 0, + }, + ) args.device = "cuda" # Check that model paths have been specified for any benchmarking with ORT @@ -527,6 +564,7 @@ def main(): setattr(args, "target_device", target_device) # noqa: B010 setattr(args, "use_fp16", use_fp16) # noqa: B010 setattr(args, "has_audio_stream", False) # noqa: B010 + setattr(args, "eos_token_id", config.eos_token_id) # noqa: B010 logger.info(f"Forced decoder prompt ids: {args.decoder_input_ids}") diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py index f12723f1af2df..08d7befec3cfd 100644 --- 
a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py @@ -109,6 +109,8 @@ def get_args(): help="Number of mins to attempt the benchmark before moving on", ) + parser.add_argument("--tune", default=False, action="store_true") + args = parser.parse_args() setattr(args, "model_size", args.model_name.split("/")[-1].replace(".", "-")) # noqa: B010 @@ -292,6 +294,7 @@ def main(): ort_decoder_input_ids_cmd = ( ["--decoder-input-ids", str(ort_forced_decoder_ids)] if args.language and args.task else [] ) + ort_tune_cmd = ["--tune"] if args.tune else [] all_results = [] for audio_file in os.listdir(args.audio_path): @@ -395,31 +398,35 @@ def main(): # Benchmark ONNX Runtime if args.ort_model_path: - benchmark_cmd = [ # noqa: RUF005 - "python3", - "-m", - "models.whisper.benchmark", - "--audio-path", - audio_path, - "--benchmark-type", - "ort", - "--ort-model-path", - args.ort_model_path, - "--model-name", - args.model_name, - "--precision", - args.precision, - "--device", - args.device, - "--device-id", - str(args.device_id), - "--warmup-runs", - str(args.warmup_runs), - "--num-runs", - str(args.num_runs), - "--log-folder", - args.log_folder, - ] + ort_decoder_input_ids_cmd + benchmark_cmd = ( + [ # noqa: RUF005 + "python3", + "-m", + "models.whisper.benchmark", + "--audio-path", + audio_path, + "--benchmark-type", + "ort", + "--ort-model-path", + args.ort_model_path, + "--model-name", + args.model_name, + "--precision", + args.precision, + "--device", + args.device, + "--device-id", + str(args.device_id), + "--warmup-runs", + str(args.warmup_runs), + "--num-runs", + str(args.num_runs), + "--log-folder", + args.log_folder, + ] + + ort_decoder_input_ids_cmd + + ort_tune_cmd + ) logger.info("Benchmark ONNX Runtime") results = benchmark(args, benchmark_cmd, "onnxruntime", audio_file, duration) all_results.extend(results) From dea425e7c140a7216727421c434a1c5a3fb1f54d Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Mon, 18 Sep 2023 09:43:34 -0700 Subject: [PATCH 043/121] [QNN/CPU EP] Add 16-bit Quantize/Dequantize contrib ops (#17015) ### Description - Adds 16-bit integer support to: - Quantization kernel implementations: Intel, Neon, and Power intrinsics - DequantizeLinear and QuantizeLinear contrib ops - QNN EP Quantize and Dequantize operators - Python quantization scripts - Disables QDQ fusions for most 16-bit QDQ node groups (need to add 16-bit support to QLinear* ops) - Retains support for dropping QDQ nodes from Split, Gather, Reshape, Transpose, Squeeze, and Unsqueeze node groups. Sample python code to generate QDQ model with 16-bit activations and 8-bit weights: ```python quantize_static( input_model_path, output_model_path, data_reader, quant_format=args.quant_format, per_channel=args.per_channel, activation_type=QuantType.QUInt16, weight_type=QuantType.QUInt8, extra_options={"DedicatedQDQPair": True, "ForceQuantizeNoInputCheck": True, "UseQDQContribOps": True}, ) ``` Note that enabling the `UseQDQContribOps` extra option is not strictly necessary. If the 16bit types are used without enabling `UseQDQContribOps`, the QDQ ops domains are overridden to 'com.microsoft', and a warning is printed to stdout. 
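To double-check which domain the Q/DQ nodes actually ended up in, a small inspection script can be run over the quantized model. The sketch below is illustrative only and assumes the output file is named `model_qdq.onnx` (a hypothetical path):

```python
# Illustrative sketch: print the domain of each Q/DQ node in a quantized model.
# Assumes "model_qdq.onnx" (hypothetical path) was produced by quantize_static above.
import onnx

model = onnx.load("model_qdq.onnx")
for node in model.graph.node:
    if node.op_type in ("QuantizeLinear", "DequantizeLinear"):
        # An empty domain string denotes the standard ONNX domain.
        print(node.op_type, node.domain or "ai.onnx")
```

With 16-bit activation types, every printed domain should be `com.microsoft`, whether `UseQDQContribOps` was set explicitly or the override happened automatically as described above.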
### Automated Tests

MLAS/CPU EP:
- [x] 16-bit QuantizeLinear computation
- [x] 16-bit DequantizeLinear computation

Optimizer:
- [x] Transpose QDQ fusion
- [x] Gather QDQ fusion
- [x] Reshape QDQ fusion
- [x] Squeeze QDQ fusion
- [x] Unsqueeze QDQ fusion
- [x] Split drop QDQ
- [x] DoubleQDQPairsRemover
- [x] Transpose optimization
- [x] EnsureUniqueDQForNodeUnit
- [x] Common subexpression elimination (DQ not removed)
- [x] Constant folding

QNN EP:
- [x] Conv 16-bit activations, 8-bit weights
- [x] MatMul 16-bit activations, 8-bit weights
- [x] Unary 16-bit QDQ ops
- [x] Binary 16-bit QDQ ops

Quantization tool:
- [x] Test creation of 16-bit QDQ model

### Motivation and Context
Support mixed precision (8bit weights, 16bit activations) models.

---------

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
---
 docs/ContribOperators.md                      |  13 +-
 docs/OperatorKernels.md                       |   4 +-
 .../contrib_ops/cpu/cpu_contrib_kernels.cc    |   8 +
 .../cpu/quantization/quantize_ops.cc          |  56 --
 .../graph/contrib_ops/quantization_defs.cc    |  16 +-
 onnxruntime/core/mlas/lib/mlasi.h             |  24 +
 onnxruntime/core/mlas/lib/platform.cpp        |   4 +
 .../core/mlas/lib/power/QuantizePower.cpp     |  39 +-
 onnxruntime/core/mlas/lib/quantize.cpp        | 321 ++++++++-
 .../optimizer/double_qdq_pairs_remover.cc     | 283 ++++----
 .../core/optimizer/double_qdq_pairs_remover.h |  32 +-
 .../qdq_selector_action_transformer.cc        |  37 +-
 .../selectors_actions/qdq_selectors.cc        | 101 ++-
 .../selectors_actions/qdq_selectors.h         |  97 ++-
 .../onnx_transpose_optimization.cc            |   2 +-
 .../cpu/quantization/quantize_linear.cc       | 133 +++-
 .../cpu/quantization/quantize_linear.h        |  45 --
 .../builder/opbuilder/simple_op_builder.cc    | 124 +++-
 .../qnn/builder/qnn_model_wrapper.cc          |  10 +
 .../tools/quantization/onnx_quantizer.py      |   2 +-
 .../tools/quantization/qdq_quantizer.py       |  20 +
 .../python/tools/quantization/quant_utils.py  |  63 +-
 .../python/tools/quantization/quantize.py     |  10 +
 .../test/contrib_ops/quantize_ops_test.cc     |  81 ++-
 .../mlas/unittest/test_quantizelinear.cpp     |  40 +-
 .../ensure_unique_dq_for_node_unit_test.cc    |  21 +-
 .../test/optimizer/graph_transform_test.cc    | 282 ++++----
 .../optimizer/graph_transform_test_builder.h  |  12 +
 .../test/optimizer/qdq_transformer_test.cc    | 433 ++++++++----
 .../optimizer/transpose_optimizer_test.cc     | 442 ++++++------
 onnxruntime/test/providers/qnn/conv_test.cc   | 653 ++++++++++++------
 .../test/providers/qnn/matmul_test.cpp        | 102 ++-
 .../test/providers/qnn/qnn_basic_test.cc      |   2 +-
 .../test/providers/qnn/qnn_test_utils.cc      |   8 +-
 .../test/providers/qnn/qnn_test_utils.h       |  18 +-
 .../test/providers/qnn/reduce_op_test.cc      |   2 +-
 .../test/providers/qnn/simple_op_htp_test.cc  | 306 +++++++-
 .../test/python/quantization/test_qdq.py      |  47 +-
 .../transform/convert_qdq_ops_to_ms_domain.py | 155 ++++-
 ...olding_dequantizelinear.qdq16_contrib.onnx | Bin 0 -> 3938 bytes
 ..._node_unit.graph_output.qdq16_contrib.onnx | Bin 0 -> 1433 bytes
 ...t_folding_qdq_node_unit.qdq16_contrib.onnx | Bin 0 -> 2061 bytes
 ...consumer_dq_nodes.fixed.qdq16_contrib.onnx | Bin 0 -> 272066 bytes
 43 files changed, 2837 insertions(+), 1211 deletions(-)
 delete mode 100644 onnxruntime/contrib_ops/cpu/quantization/quantize_ops.cc
 delete mode 100644 onnxruntime/core/providers/cpu/quantization/quantize_linear.h
 create mode 100644 onnxruntime/test/testdata/transform/fusion/constant_folding_dequantizelinear.qdq16_contrib.onnx
 create mode 100644 onnxruntime/test/testdata/transform/fusion/constant_folding_qdq_node_unit.graph_output.qdq16_contrib.onnx
 create mode 100644 onnxruntime/test/testdata/transform/fusion/constant_folding_qdq_node_unit.qdq16_contrib.onnx
 create mode 100644 onnxruntime/test/testdata/transform/qdq_with_multi_consumer_dq_nodes.fixed.qdq16_contrib.onnx

diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 5bd1a89c0dea1..95dc8c3cde46c 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -1351,8 +1351,8 @@ This version of the operator has been available since version 1 of the 'com.micr
 #### Type Constraints
 <dl>
-<dt><tt>T1</tt> : tensor(int8), tensor(uint8), tensor(int32)</dt>
-<dd>Constrain 'x' and 'x_zero_point' to 8-bit integer tensors or 32-bit signed integer tensors.</dd>
+<dt><tt>T1</tt> : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16), tensor(int32)</dt>
+<dd>Constrain 'x' and 'x_zero_point' to 8-bit integer tensors, 16-bit integer tensors, or 32-bit signed integer tensors.</dd>
 <dt><tt>T2</tt> : tensor(float16), tensor(float)</dt>
 <dd>Constrain 'y', 'x_scale' to float tensors.</dd>
@@ -4194,8 +4194,9 @@ This version of the operator has been available since version 1 of the 'com.micr
 ### **com.microsoft.QuantizeLinear**
 The linear quantization operator. It consumes a full precision data, a scale, a zero point to compute the low precision / quantized tensor.
-The quantization formula is y = saturate ((x / y_scale) + y_zero_point).For saturation, it saturates to [0, 255] if it's uint8, or [-128, 127] if it's int8.
-For (x / y_scale), it's rounding to nearest ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details.
+The quantization formula is y = saturate ((x / y_scale) + y_zero_point). For saturation, it saturates to [0, 255] if it's uint8, [-128, 127] if it's int8,
+[0, 65,535] if it's uint16, and [-32,768, 32,767] if it's int16. For (x / y_scale), it's rounding to nearest ties to even.
+Refer to https://en.wikipedia.org/wiki/Rounding for details.
 Scale and zero point must have same shape. They must be either scalar (per tensor) or 1-D tensor (per 'axis').

 #### Version
@@ -4232,8 +4233,8 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dt><tt>T1</tt> : tensor(float16), tensor(float)</dt>
 <dd>Constrain 'x', 'y_scale' to float tensors.</dd>
-<dt><tt>T2</tt> : tensor(int8), tensor(uint8)</dt>
-<dd>Constrain 'y_zero_point' and 'y' to 8-bit integer tensors.</dd>
+<dt><tt>T2</tt> : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16)</dt>
+<dd>Constrain 'y_zero_point' and 'y' to 8-bit and 16-bit integer tensors.</dd>
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index d46f3ed9bd262..33c187a28b62e 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -439,7 +439,7 @@ Do not modify directly.* |CDist|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float)| |ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |CropAndResize|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*in* crop_size:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int32)| -|DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int32), tensor(int8), tensor(uint8)
**T2** = tensor(float)| +|DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint8)
**T2** = tensor(float)| |DynamicQuantizeLSTM|*in* X:**T**
*in* W:**T2**
*in* R:**T2**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* W_scale:**T**
*in* W_zero_point:**T2**
*in* R_scale:**T**
*in* R_zero_point:**T2**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(float)
**T1** = tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |DynamicQuantizeMatMul|*in* A:**T1**
*in* B:**T2**
*in* b_scale:**T1**
*in* b_zero_point:**T2**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |EmbedLayerNormalization|*in* input_ids:**T1**
*in* segment_ids:**T1**
*in* word_embedding:**T**
*in* position_embedding:**T**
*in* segment_embedding:**T**
*in* gamma:**T**
*in* beta:**T**
*in* mask:**T1**
*in* position_ids:**T1**
*out* output:**T**
*out* mask_index:**T1**
*out* embedding_sum:**T**|1+|**T** = tensor(float)| @@ -472,7 +472,7 @@ Do not modify directly.* |QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearSoftmax|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* x_zero_point:**T**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearWhere|*in* condition:**B**
*in* X:**T**
*in* x_scale:**TF**
*in* x_zero_point:**T**
*in* Y:**T**
*in* y_scale:**TF**
*in* y_zero_point:**T**
*in* z_scale:**TF**
*in* z_zero_point:**T**
*out* Z:**T**|1+|**T** = tensor(int8), tensor(uint8)| -|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| +|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)| |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)| |SampleOp|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 660c8bd9e0624..0ec5088808656 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -56,9 +56,13 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLine class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearAveragePool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int32_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearLeakyRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QLinearLeakyRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearSigmoid); @@ -191,9 +195,13 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/quantization/quantize_ops.cc b/onnxruntime/contrib_ops/cpu/quantization/quantize_ops.cc deleted file mode 100644 index 28a304bfc7f0e..0000000000000 --- a/onnxruntime/contrib_ops/cpu/quantization/quantize_ops.cc +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/providers/cpu/quantization/quantize_linear.h" -#include "core/providers/common.h" - -namespace onnxruntime { -namespace contrib { - -ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( - DequantizeLinear, - 1, - uint8_t, - KernelDefBuilder() - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), - DequantizeLinear); - -ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( - DequantizeLinear, - 1, - int8_t, - KernelDefBuilder() - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), - DequantizeLinear); - -ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( - DequantizeLinear, - 1, - int32_t, - KernelDefBuilder() - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), - DequantizeLinear); - -ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( - QuantizeLinear, - 1, - uint8_t, - KernelDefBuilder() - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), - QuantizeLinear); - -ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( - QuantizeLinear, - 1, - int8_t, - KernelDefBuilder() - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), - QuantizeLinear); - -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index aa2ad9f1ff6b1..4313fae767fe5 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -136,8 +136,9 @@ Performs element-wise binary {name} on 8 bit data types (with Numpy-style broadc static const char* QuantizeLinear_ver1_doc = R"DOC( The linear quantization operator. It consumes a full precision data, a scale, a zero point to compute the low precision / quantized tensor. -The quantization formula is y = saturate ((x / y_scale) + y_zero_point).For saturation, it saturates to [0, 255] if it's uint8, or [-128, 127] if it's int8. -For (x / y_scale), it's rounding to nearest ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details. +The quantization formula is y = saturate ((x / y_scale) + y_zero_point). For saturation, it saturates to [0, 255] if it's uint8, [-128, 127] if it's int8, +[0, 65,535] if it's uint16, and [-32,768, 32,767] if it's int16. For (x / y_scale), it's rounding to nearest ties to even. +Refer to https://en.wikipedia.org/wiki/Rounding for details. Scale and zero point must have same shape. They must be either scalar (per tensor) or 1-D tensor (per 'axis').)DOC"; ONNX_MS_OPERATOR_SET_SCHEMA( @@ -161,8 +162,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T2", OpSchema::Optional) .Output(0, "y", "N-D quantized output tensor. 
It has same shape as input 'x'.", "T2") .TypeConstraint("T1", {"tensor(float16)", "tensor(float)"}, "Constrain 'x', 'y_scale' to float tensors.") - .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, - "Constrain 'y_zero_point' and 'y' to 8-bit integer tensors.") + .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)", "tensor(int16)", "tensor(uint16)"}, + "Constrain 'y_zero_point' and 'y' to 8-bit and 16-bit integer tensors.") .SetDoc(QuantizeLinear_ver1_doc) .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { if (ctx.getNumInputs() == 3 && ctx.getInputType(2) != nullptr) { @@ -202,9 +203,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(DequantizeLinear, 1, "T1", OpSchema::Optional) .Output(0, "y", "N-D full precision output tensor. It has same shape as input 'x'.", "T2") - .TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)", "tensor(int32)"}, - "Constrain 'x' and 'x_zero_point' to 8-bit integer tensors or 32-bit " - "signed integer tensors.") + .TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)", "tensor(int16)", + "tensor(uint16)", "tensor(int32)"}, + "Constrain 'x' and 'x_zero_point' to 8-bit integer tensors, " + "16-bit integer tensors, or 32-bit signed integer tensors.") .TypeConstraint("T2", {"tensor(float16)", "tensor(float)"}, "Constrain 'y', 'x_scale' to float tensors.") .SetDoc(DequantizeLinear_ver1_doc) diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index f517be185b3fa..b6ac4a1ca1d6c 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -633,6 +633,24 @@ void int8_t ZeroPoint ); +typedef +void +(MLASCALL MLAS_QUANTIZE_LINEAR_U16_KERNEL)( + const float* Input, + uint16_t* Output, + size_t N, + float Scale, + uint16_t ZeroPoint); + +typedef +void +(MLASCALL MLAS_QUANTIZE_LINEAR_S16_KERNEL)( + const float* Input, + int16_t* Output, + size_t N, + float Scale, + int16_t ZeroPoint); + template struct MLAS_QUANT_KERNEL { @@ -749,6 +767,8 @@ extern "C" { MLAS_QLINEAR_BINARY_OP_U8_KERNEL MlasQLinearAddU8Kernel; MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8Kernel; MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8Kernel; + MLAS_QUANTIZE_LINEAR_S16_KERNEL MlasQuantizeLinearS16Kernel; + MLAS_QUANTIZE_LINEAR_U16_KERNEL MlasQuantizeLinearU16Kernel; #if defined(MLAS_TARGET_AMD64) MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernelFma3; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelFma3; @@ -959,6 +979,8 @@ struct MLAS_PLATFORM { const MLAS_GEMM_QUANT_DISPATCH* GemmU8X8Dispatch; MLAS_QUANTIZE_LINEAR_S8_KERNEL* QuantizeLinearS8Kernel; MLAS_QUANTIZE_LINEAR_U8_KERNEL* QuantizeLinearU8Kernel; + MLAS_QUANTIZE_LINEAR_S16_KERNEL* QuantizeLinearS16Kernel; + MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel; #endif #if defined(MLAS_TARGET_AMD64) MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1Routine; @@ -986,6 +1008,8 @@ struct MLAS_PLATFORM { MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL* ReduceMinimumMaximumF32Kernel; MLAS_QUANTIZE_LINEAR_S8_KERNEL* QuantizeLinearS8Kernel; MLAS_QUANTIZE_LINEAR_U8_KERNEL* QuantizeLinearU8Kernel; + MLAS_QUANTIZE_LINEAR_S16_KERNEL* QuantizeLinearS16Kernel; + MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel; uint32_t NchwcBlockSize; uint32_t PreferredBufferAlignment; int32_t MaximumThreadCount; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 86b7450a7c4e5..7e2b117d6f249 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -230,6 +230,8 @@ Return Value: 
this->QLinearAddU8Kernel = MlasQLinearAddU8Kernel; this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8Kernel; this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8Kernel; + this->QuantizeLinearS16Kernel = MlasQuantizeLinearS16Kernel; + this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; this->NchwcBlockSize = 8; this->PreferredBufferAlignment = MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT; @@ -475,6 +477,8 @@ Return Value: this->GemmDoubleKernel = MlasDgemmKernel; this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8Kernel; this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8Kernel; + this->QuantizeLinearS16Kernel = MlasQuantizeLinearS16Kernel; + this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; #if defined(__linux__) unsigned long hwcap2 = getauxval(AT_HWCAP2); diff --git a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp index 0d38288c6d42c..830a3a6a492db 100644 --- a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp +++ b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp @@ -1,3 +1,4 @@ +#include #include "mlasi.h" #include @@ -82,8 +83,15 @@ Return Value: auto ShortVector0 = vec_pack(IntegerVector0, IntegerVector1); auto ShortVector1 = vec_pack(IntegerVector2, IntegerVector3); - auto CharVector = vec_pack(ShortVector0, ShortVector1); - vec_xst(CharVector, 0, (int8_t *) Output); + + if constexpr (std::is_same_v || std::is_same_v) { + auto CharVector = vec_pack(ShortVector0, ShortVector1); + vec_xst(CharVector, 0, Output); + } else { + static_assert(std::is_same_v || std::is_same_v); + vec_xst(ShortVector0, 0, Output); + vec_xst(ShortVector1, 0, &Output[8]); + } Output += 16; Input += 16; @@ -124,3 +132,30 @@ MlasQuantizeLinearS8Kernel( { MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); } + +void +MLASCALL +MlasQuantizeLinearU16Kernel( + const float* Input, + uint16_t* Output, + size_t N, + float Scale, + uint16_t ZeroPoint + ) +{ + MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearS16Kernel( + const float* Input, + int16_t* Output, + size_t N, + float Scale, + int16_t ZeroPoint + ) +{ + MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); +} + diff --git a/onnxruntime/core/mlas/lib/quantize.cpp b/onnxruntime/core/mlas/lib/quantize.cpp index c6e8af38c0020..133ad79594c55 100644 --- a/onnxruntime/core/mlas/lib/quantize.cpp +++ b/onnxruntime/core/mlas/lib/quantize.cpp @@ -21,6 +21,7 @@ Module Name: #include "mlasi.h" #if defined(MLAS_NEON64_INTRINSICS) || defined(MLAS_SSE2_INTRINSICS) +#include // // QuantizeLinear implementation using NEON or SSE2 intrinsics. @@ -79,6 +80,20 @@ MlasQuantizeLinearPackBytes( MLAS_INT32X4 IntegerVector ); +template +void +MlasQuantizeLinearStore4PackedValues( + MLAS_INT32X4 IntegerVector, + OutputType* Output + ); + +template +void +MlasQuantizeLinearStoreSingleValue( + MLAS_INT32X4 IntegerVector, + OutputType* Output + ); + #if defined(MLAS_NEON64_INTRINSICS) template @@ -100,6 +115,104 @@ MlasQuantizeLinearPackBytes( return vreinterpretq_s32_u8(ByteVector); } +template<> +MLAS_INT32X4 +MlasQuantizeLinearPackBytes( + MLAS_INT32X4 IntegerVector + ) +{ + // + // Swizzle the least significant u16 from each int32_t element to the + // bottom eight bytes of the vector register. 
+ // + + uint16x8_t WordVector = vreinterpretq_u16_s32(IntegerVector); + WordVector = vuzp1q_u16(WordVector, WordVector); + return vreinterpretq_s32_u16(WordVector); +} + +template<> +MLAS_INT32X4 +MlasQuantizeLinearPackBytes( + MLAS_INT32X4 IntegerVector + ) +{ + // + // Swizzle the least significant u16 from each int32_t element to the + // bottom eight bytes of the vector register. + // + + int16x8_t WordVector = vreinterpretq_s16_s32(IntegerVector); + WordVector = vuzp1q_s16(WordVector, WordVector); + return vreinterpretq_s32_s16(WordVector); +} + +template +MLAS_FORCEINLINE +void +MlasQuantizeLinearStore4PackedValues( + MLAS_INT32X4 IntegerVector, + OutputType* Output + ) +{ + // Copies the lower 4 packed elements of the vector into memory (Output). + + if constexpr (std::is_same_v || std::is_same_v) { + vst1q_lane_s32(reinterpret_cast(Output), IntegerVector, 0); + } else { + static_assert(std::is_same_v || std::is_same_v); + vst1q_lane_s64(reinterpret_cast(Output), vreinterpretq_s64_s32(IntegerVector), 0); + } +} + +template <> +MLAS_FORCEINLINE +void +MlasQuantizeLinearStoreSingleValue( + MLAS_INT32X4 IntegerVector, + uint8_t* Output + ) +{ + // Copies the lower 8-bit element of the vector into memory (Output). + vst1q_lane_u8(Output, vreinterpretq_u8_s32(IntegerVector), 0); +} + +template <> +MLAS_FORCEINLINE +void +MlasQuantizeLinearStoreSingleValue( + MLAS_INT32X4 IntegerVector, + int8_t* Output + ) +{ + // Copies the lower 8-bit element of the vector into memory (Output). + vst1q_lane_s8(Output, vreinterpretq_s8_s32(IntegerVector), 0); +} + +template <> +MLAS_FORCEINLINE +void +MlasQuantizeLinearStoreSingleValue( + MLAS_INT32X4 IntegerVector, + uint16_t* Output + ) +{ + // Copies the lower 16-bit element of the vector into memory (Output). + vst1q_lane_u16(Output, vreinterpretq_u16_s32(IntegerVector), 0); +} + +template <> +MLAS_FORCEINLINE +void +MlasQuantizeLinearStoreSingleValue( + MLAS_INT32X4 IntegerVector, + int16_t* Output + ) +{ + // Copies the lower 16-bit element of the vector into memory (Output). + vst1q_lane_s16(Output, vreinterpretq_s16_s32(IntegerVector), 0); +} + #else template<> @@ -128,6 +241,86 @@ MlasQuantizeLinearPackBytes( return IntegerVector; } +template<> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasQuantizeLinearPackBytes( + MLAS_INT32X4 IntegerVector + ) +{ +#if defined(MLAS_SSE41_INTRINSICS) + IntegerVector = _mm_packus_epi32(IntegerVector, IntegerVector); // 16-bit values packed in lower 8 bytes. +#else + // Cannot use _mm_packus_epi32 because that was not available until SSE4.1. + // Instead, emulate by sign-extending the first 16-bits of each packed 32-bit element. + // Afterwards, can use _mm_packs_epi32, which is available on SSE2. + // See: https://stackoverflow.com/a/11028244 + + IntegerVector = _mm_slli_epi32(IntegerVector, 16); + IntegerVector = _mm_srai_epi32(IntegerVector, 16); // Sign-extend: undo left shift with right arithmetic shift + IntegerVector = _mm_packs_epi32(IntegerVector, IntegerVector); // 16-bit values packed in lower 8 bytes. +#endif // defined(MLAS_SSE41_INTRINSICS) + + return IntegerVector; +} + +template<> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasQuantizeLinearPackBytes( + MLAS_INT32X4 IntegerVector + ) +{ + IntegerVector = _mm_packs_epi32(IntegerVector, IntegerVector); // 16-bit values packed in lower 8 bytes. 
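+    // The values were already clamped to the int16 range by MlasQuantizeLinearVector,
+    // so the signed saturation in _mm_packs_epi32 should leave every lane unchanged.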
+ + return IntegerVector; +} + +template +MLAS_FORCEINLINE +void +MlasQuantizeLinearStore4PackedValues( + MLAS_INT32X4 IntegerVector, + OutputType* Output + ) +{ + // Copies the lower 4 packed elements of the vector into memory (Output). + + if constexpr (std::is_same_v || std::is_same_v) { + *(reinterpret_cast(Output)) = _mm_cvtsi128_si32(IntegerVector); + } else { + static_assert(std::is_same_v || std::is_same_v); + +#if defined(MLAS_TARGET_IX86) + // x86 does not support _mm_cvtsi128_si64, so use _mm_maskmoveu_si128 instead. + constexpr uint32_t bytes_high_bit = 0x80808080; + const __m128i first_8_bytes_mask = _mm_set_epi32(0, 0, bytes_high_bit, bytes_high_bit); + _mm_maskmoveu_si128(IntegerVector, first_8_bytes_mask, reinterpret_cast(Output)); +#else + *(reinterpret_cast(Output)) = _mm_cvtsi128_si64(IntegerVector); +#endif // defined(MLAS_TARGET_IX86) + } +} + +template +MLAS_FORCEINLINE +void +MlasQuantizeLinearStoreSingleValue( + MLAS_INT32X4 IntegerVector, + OutputType* Output + ) +{ + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v); + + // Copies the lower element of the vector into memory (Output). + // Expects that the 32-bit element in lane 0 is already within the valid numerical + // range of the OutputType. + *Output = static_cast(_mm_cvtsi128_si32(IntegerVector)); +} + #endif template @@ -180,12 +373,7 @@ Return Value: MinimumValueVector, MaximumValueVector, ZeroPointVector); IntegerVector = MlasQuantizeLinearPackBytes(IntegerVector); - -#if defined(MLAS_NEON64_INTRINSICS) - vst1q_lane_s32((int32_t*)Output, IntegerVector, 0); -#else - *((int32_t*)Output) = _mm_cvtsi128_si32(IntegerVector); -#endif + MlasQuantizeLinearStore4PackedValues(IntegerVector, Output); Input += 4; Output += 4; @@ -202,11 +390,7 @@ Return Value: auto IntegerVector = MlasQuantizeLinearVector(FloatVector, ScaleVector, MinimumValueVector, MaximumValueVector, ZeroPointVector); -#if defined(MLAS_NEON64_INTRINSICS) - vst1q_lane_u8((uint8_t*)Output + n, vreinterpretq_u8_s32(IntegerVector), 0); -#else - *((uint8_t*)Output + n) = (uint8_t)_mm_cvtsi128_si32(IntegerVector); -#endif + MlasQuantizeLinearStoreSingleValue(IntegerVector, &Output[n]); } } @@ -236,6 +420,32 @@ MlasQuantizeLinearU8Kernel( MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); } +void +MLASCALL +MlasQuantizeLinearU16Kernel( + const float* Input, + uint16_t* Output, + size_t N, + float Scale, + uint16_t ZeroPoint +) +{ + MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearS16Kernel( + const float* Input, + int16_t* Output, + size_t N, + float Scale, + int16_t ZeroPoint +) +{ + MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); +} + template<> void MLASCALL @@ -274,6 +484,44 @@ MlasQuantizeLinear( Input, Output, N, Scale, ZeroPoint); } +template<> +void +MLASCALL +MlasQuantizeLinear( + const float* Input, + uint16_t* Output, + size_t N, + float Scale, + uint16_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().QuantizeLinearU16Kernel( +#else + MlasQuantizeLinearU16Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasQuantizeLinear( + const float* Input, + int16_t* Output, + size_t N, + float Scale, + int16_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().QuantizeLinearS16Kernel( +#else + MlasQuantizeLinearS16Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} + #else #if defined(MLAS_TARGET_POWER) @@ -306,6 +554,34 @@ MlasQuantizeLinear( 
GetMlasPlatform().QuantizeLinearU8Kernel(Input, Output, N, Scale, ZeroPoint); } +template<> +void +MLASCALL +MlasQuantizeLinear( + const float* Input, + int16_t* Output, + size_t N, + float Scale, + int16_t ZeroPoint + ) +{ + GetMlasPlatform().QuantizeLinearS16Kernel(Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasQuantizeLinear( + const float* Input, + uint16_t* Output, + size_t N, + float Scale, + uint16_t ZeroPoint + ) +{ + GetMlasPlatform().QuantizeLinearU16Kernel(Input, Output, N, Scale, ZeroPoint); +} + #endif // @@ -381,6 +657,29 @@ MlasQuantizeLinear( float Scale, uint8_t ZeroPoint ); + +template +void +MLASCALL +MlasQuantizeLinear( + const float* Input, + int16_t* Output, + size_t N, + float Scale, + int16_t ZeroPoint + ); + +template +void +MLASCALL +MlasQuantizeLinear( + const float* Input, + uint16_t* Output, + size_t N, + float Scale, + uint16_t ZeroPoint + ); + #endif #endif diff --git a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc index b67f6d6ec0794..624679e7b1b4b 100644 --- a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc +++ b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc @@ -1,131 +1,37 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include "core/optimizer/double_qdq_pairs_remover.h" +#include #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" +#include "core/optimizer/qdq_transformer/qdq_util.h" namespace onnxruntime { -Status DoubleQDQPairsRemover::ApplyImpl( - Graph& graph, - bool& modified, - int /*graph_level*/, - const logging::Logger& /*logger*/) const { - const GraphViewer graph_viewer(graph); - const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); - - for (const auto& self_index : node_topology_list) { - NodeIndex parent_index = 0; - NodeIndex child_index = 0; - NodeIndex grandchild_index = 0; - if (IsNodeRemovable(graph, self_index, parent_index, child_index, grandchild_index)) { - graph.RemoveEdge(parent_index, self_index, 0, 0); - graph.RemoveEdge(self_index, child_index, 0, 0); - graph.RemoveEdge(child_index, grandchild_index, 0, 0); - graph_utils::ReplaceNodeInput(*graph.GetNode(grandchild_index), 0, *graph.GetNode(self_index)->MutableInputDefs()[0]); - graph.AddEdge(parent_index, grandchild_index, 0, 0); - graph.RemoveNode(child_index); - graph.RemoveNode(self_index); - modified = true; - } - } - return Status::OK(); -} - -bool DoubleQDQPairsRemover::IsNodeRemovable( - Graph& graph, - const NodeIndex& self_index, - NodeIndex& parent_index, - NodeIndex& child_index, - NodeIndex& grandchild_index) { - // Check if the self is a DQ, and have one parent and one child, and cannot be a graph output - Node* self = graph.GetNode(self_index); - if (self == nullptr || - self->OpType() != "DequantizeLinear" || - self->GetInputEdgesCount() != 1 || - self->GetOutputEdgesCount() != 1 || - self->InputDefs().size() != InputIndex::TOTAL_COUNT || - graph.NodeProducesGraphOutput(*self)) { - return false; - } - - // Type is either "tensor(uint8)" or "tensor(int8)" - const auto& self_zp_type = *self->InputDefs()[InputIndex::ZERO_POINT_ID]->Type(); - // child should be a Q, and have only one child, have the same type as self, and cannot be a graph output - child_index = self->OutputEdgesBegin()->GetNode().Index(); - const Node* child = graph.GetNode(child_index); - if (child == nullptr || - child->OpType() != "QuantizeLinear" || - child->GetOutputEdgesCount() != 1 || - 
child->InputDefs().size() != InputIndex::TOTAL_COUNT || - *child->InputDefs()[InputIndex::ZERO_POINT_ID]->Type() != self_zp_type || - graph.NodeProducesGraphOutput(*child)) { - return false; - } - - // parent should be a Q, and have only one output, and cannot be a graph output - parent_index = self->InputEdgesBegin()->GetNode().Index(); - Node* parent = graph.GetNode(parent_index); - if (parent == nullptr || - parent->GetOutputEdgesCount() != 1 || - parent->OpType() != "QuantizeLinear" || - graph.NodeProducesGraphOutput(*parent)) { - return false; - } - - // grandchild should be a DQ - grandchild_index = child->OutputEdgesBegin()->GetNode().Index(); - Node* grandchild = graph.GetNode(grandchild_index); - if (grandchild == nullptr || - grandchild->OpType() != "DequantizeLinear") { - return false; - } - const auto get_constant_initializer = [&graph](const std::string& initializer_name) { - return graph.GetConstantInitializer(initializer_name, true); - }; - if (!QDQ::IsQDQPairSupported(*parent, *self, get_constant_initializer, graph.ModelPath()) || - !QDQ::IsQDQPairSupported(*child, *grandchild, get_constant_initializer, graph.ModelPath())) { - return false; - } - bool skip_reset = false; - float new_scale = 0.0f; - if (self_zp_type == "tensor(uint8)") { - uint8_t new_zero_point = 0; - if (!FindNewZeroPointAndScale(graph, *self, *child, new_scale, new_zero_point, skip_reset)) { - return false; - } - if (skip_reset) { - return true; - } - ApplyNewInputValue(graph, *grandchild, InputIndex::SCALE_ID, new_scale); - ApplyNewInputValue(graph, *parent, InputIndex::SCALE_ID, new_scale); - ApplyNewInputValue(graph, *grandchild, InputIndex::ZERO_POINT_ID, new_zero_point); - ApplyNewInputValue(graph, *parent, InputIndex::ZERO_POINT_ID, new_zero_point); - } else { - int8_t new_zero_point = 0; - if (!FindNewZeroPointAndScale(graph, *self, *child, new_scale, new_zero_point, skip_reset)) { - return false; - } - if (skip_reset) { - return true; - } - ApplyNewInputValue(graph, *grandchild, InputIndex::SCALE_ID, new_scale); - ApplyNewInputValue(graph, *parent, InputIndex::SCALE_ID, new_scale); - ApplyNewInputValue(graph, *grandchild, InputIndex::ZERO_POINT_ID, new_zero_point); - ApplyNewInputValue(graph, *parent, InputIndex::ZERO_POINT_ID, new_zero_point); - } - return true; +// Applies a new zero point or scale as the input for a Q/DQ node. +template +static void ApplyNewInputValue(Graph& graph, Node& node, QDQ::InputIndex index, T value) { + const auto* input_tensor = graph_utils::GetConstantInitializer(graph, node.InputDefs()[index]->Name()); + Initializer input_init{*input_tensor, graph.ModelPath()}; + ONNX_NAMESPACE::TensorProto new_input_tensor(*input_tensor); + input_init.data()[0] = value; + input_init.ToProto(new_input_tensor); + auto new_name = graph.GenerateNodeArgName("DoubleQDQRemoved_" + node.InputDefs()[index]->Name()); + new_input_tensor.set_name(new_name); + NodeArg& new_input = graph_utils::AddInitializer(graph, new_input_tensor); + graph_utils::ReplaceNodeInput(node, index, new_input); } +// Returns a new zero point and scale value for the given Q/DQ nodes. 
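+// For example, in the sequence Q1 -> DQ1 -> Q2 -> DQ2 the inner DQ1/Q2 pair is removed,
+// so the surviving Q1/DQ2 pair needs a zero point and scale that keep the reduced
+// graph numerically equivalent to the original four-node sequence.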
template -bool DoubleQDQPairsRemover::FindNewZeroPointAndScale(const Graph& graph, const Node& node1, const Node& node2, - float& new_scale, T& new_zero_point, bool& skip_reset) { +static bool FindNewZeroPointAndScale(const Graph& graph, const Node& node1, const Node& node2, + float& new_scale, T& new_zero_point, bool& skip_reset) { // scale & zero point share same initializer, no need to reset the value - const std::string& node1_scale_name = node1.InputDefs()[InputIndex::SCALE_ID]->Name(); - const std::string& node2_scale_name = node2.InputDefs()[InputIndex::SCALE_ID]->Name(); - const std::string& node1_zp_name = node1.InputDefs()[InputIndex::ZERO_POINT_ID]->Name(); - const std::string& node2_zp_name = node2.InputDefs()[InputIndex::ZERO_POINT_ID]->Name(); + const std::string& node1_scale_name = node1.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name(); + const std::string& node2_scale_name = node2.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name(); + const std::string& node1_zp_name = node1.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name(); + const std::string& node2_zp_name = node2.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name(); skip_reset = false; if (node1_scale_name == node2_scale_name && node1_zp_name == node2_zp_name) { skip_reset = true; @@ -175,16 +81,141 @@ bool DoubleQDQPairsRemover::FindNewZeroPointAndScale(const Graph& graph, const N return true; } -template -void DoubleQDQPairsRemover::ApplyNewInputValue(Graph& graph, Node& node, const InputIndex& index, T value) { - const auto* input_tensor = graph_utils::GetConstantInitializer(graph, node.InputDefs()[index]->Name()); - Initializer input_init{*input_tensor, graph.ModelPath()}; - TensorProto new_input_tensor(*input_tensor); - input_init.data()[0] = value; - input_init.ToProto(new_input_tensor); - auto new_name = graph.GenerateNodeArgName("DoubleQDQRemoved_" + node.InputDefs()[index]->Name()); - new_input_tensor.set_name(new_name); - NodeArg& new_input = graph_utils::AddInitializer(graph, new_input_tensor); - graph_utils::ReplaceNodeInput(node, index, new_input); +// Recomputes the zero point and scale of the outer Q/DQ nodes (i.e., Q1 and DQ2). This is necessary because +// the original two QDQ pairs may have different zero-points and scales. Ex: Q1 -> DQ1 -> Q2 -> DQ2, where +// the first pair has (zp1, scale1) and the second pair has (zp2, scale2). +// After removing the middle two nodes, the zero point and scale of the final (outer) ops must be recomputed +// for correctness. +template +static bool RecomputeOuterQDQZeroPointAndScale(Graph& graph, Node& q1, const Node& dq1, const Node& q2, Node& dq2) { + bool skip_reset = false; + float new_scale = 0.0f; + ZeroPointType new_zero_point = 0; + if (!FindNewZeroPointAndScale(graph, dq1, q2, new_scale, new_zero_point, skip_reset)) { + return false; + } + if (skip_reset) { + return true; + } + ApplyNewInputValue(graph, dq2, QDQ::InputIndex::SCALE_ID, new_scale); + ApplyNewInputValue(graph, q1, QDQ::InputIndex::SCALE_ID, new_scale); + ApplyNewInputValue(graph, dq2, QDQ::InputIndex::ZERO_POINT_ID, new_zero_point); + ApplyNewInputValue(graph, q1, QDQ::InputIndex::ZERO_POINT_ID, new_zero_point); + + return true; +} + +// Checks if the provided node index (dq1_index) is a part of a valid double QDQ pair sequence +// (i.e., Q1 -> DQ1 -> Q2 -> DQ2) that can be reduced to the outer Q/DQ nodes (i.e., Q1 -> DQ2). 
+// If so, the zero point and scale of the outer Q/DQ nodes are recomputed and the node indices of the other nodes +// in the sequence (i.e., Q1, Q2, and DQ2) are returned via output parameters. +static bool IsReducibleDoubleQDQSequence(Graph& graph, NodeIndex& q1_index, NodeIndex dq1_index, + NodeIndex& q2_index, NodeIndex& dq2_index) { + // Ensure that dq1 is a DQ operator, has one parent and one child, and is not a graph output + Node* dq1 = graph.GetNode(dq1_index); + if (dq1 == nullptr || + dq1->OpType() != "DequantizeLinear" || + dq1->GetInputEdgesCount() != 1 || + dq1->GetOutputEdgesCount() != 1 || + graph.NodeProducesGraphOutput(*dq1)) { + return false; + } + + // Ensure that q2 is a Q operator, has only one child, and is not a graph output + q2_index = dq1->OutputEdgesBegin()->GetNode().Index(); + const Node* q2 = graph.GetNode(q2_index); + if (q2 == nullptr || + q2->OpType() != "QuantizeLinear" || + q2->GetOutputEdgesCount() != 1 || + graph.NodeProducesGraphOutput(*q2)) { + return false; + } + + // Ensure that q1 is a Q operator, has only one output, and is not a graph output + q1_index = dq1->InputEdgesBegin()->GetNode().Index(); + Node* q1 = graph.GetNode(q1_index); + if (q1 == nullptr || + q1->GetOutputEdgesCount() != 1 || + q1->OpType() != "QuantizeLinear" || + graph.NodeProducesGraphOutput(*q1)) { + return false; + } + + // Ensure the dq2 is a DQ operator. + dq2_index = q2->OutputEdgesBegin()->GetNode().Index(); + Node* dq2 = graph.GetNode(dq2_index); + if (dq2 == nullptr || + dq2->OpType() != "DequantizeLinear") { + return false; + } + + const auto get_constant_initializer = [&graph](const std::string& initializer_name) { + return graph.GetConstantInitializer(initializer_name, true); + }; + + // Each QDQ pair (i.e., q1 -> dq1, q2 -> dq2) has to meet the following additional requirements: + // - Scalar/constant zero-point and scale. + // - The DQ and Q ops within a pair must have the same scale and zero-point. + // However, each pair is allowed to have different scales and zero-points. + // + // TODO: IsQDQPairSupported() requires an explicit zero-point input, but technically a default + // value of 0 could be fine. + if (!QDQ::IsQDQPairSupported(*q1, *dq1, get_constant_initializer, graph.ModelPath()) || + !QDQ::IsQDQPairSupported(*q2, *dq2, get_constant_initializer, graph.ModelPath())) { + return false; + } + + const auto& dq1_input_defs = dq1->InputDefs(); + const ONNX_NAMESPACE::TensorProto* dq1_zp_tensor_proto = graph.GetConstantInitializer( + dq1_input_defs[QDQ::InputIndex::ZERO_POINT_ID]->Name(), true); + + assert(dq1_zp_tensor_proto != nullptr); // IsQDQPairSupported should have checked that this exists. 
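+  // The checks below dispatch on the zero-point element type; only uint8, int8,
+  // uint16, and int16 zero points are supported, and any other type is rejected.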
+ + auto dq1_zp_type = dq1_zp_tensor_proto->data_type(); + + if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { + return RecomputeOuterQDQZeroPointAndScale(graph, *q1, *dq1, *q2, *dq2); + } + + if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_INT8) { + return RecomputeOuterQDQZeroPointAndScale(graph, *q1, *dq1, *q2, *dq2); + } + + if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16) { + return RecomputeOuterQDQZeroPointAndScale(graph, *q1, *dq1, *q2, *dq2); + } + + if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_INT16) { + return RecomputeOuterQDQZeroPointAndScale(graph, *q1, *dq1, *q2, *dq2); + } + + return false; // Unsupported zero-point type +} + +Status DoubleQDQPairsRemover::ApplyImpl( + Graph& graph, + bool& modified, + int /*graph_level*/, + const logging::Logger& /*logger*/) const { + const GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (const auto& dq1_index : node_topology_list) { + NodeIndex q1_index = 0; + NodeIndex q2_index = 0; + NodeIndex dq2_index = 0; + if (IsReducibleDoubleQDQSequence(graph, q1_index, dq1_index, q2_index, dq2_index)) { + graph.RemoveEdge(q1_index, dq1_index, 0, 0); + graph.RemoveEdge(dq1_index, q2_index, 0, 0); + graph.RemoveEdge(q2_index, dq2_index, 0, 0); + graph_utils::ReplaceNodeInput(*graph.GetNode(dq2_index), 0, *graph.GetNode(dq1_index)->MutableInputDefs()[0]); + graph.AddEdge(q1_index, dq2_index, 0, 0); + graph.RemoveNode(q2_index); + graph.RemoveNode(dq1_index); + modified = true; + } + } + return Status::OK(); } + } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/double_qdq_pairs_remover.h b/onnxruntime/core/optimizer/double_qdq_pairs_remover.h index c016f7181b7fe..1833b007674fd 100644 --- a/onnxruntime/core/optimizer/double_qdq_pairs_remover.h +++ b/onnxruntime/core/optimizer/double_qdq_pairs_remover.h @@ -3,19 +3,16 @@ #pragma once -#include "core/common/common.h" #include "core/optimizer/graph_transformer.h" -#include "core/optimizer/qdq_transformer/qdq_util.h" namespace onnxruntime { -using ONNX_NAMESPACE::TensorProto; -using ONNX_NAMESPACE::TensorProto_DataType; -using QDQ::InputIndex; - /** * @Class DoubleQDQPairsRemover * @brief Remove one pair of Q-DQ from Double Q-DQ pairs. + * Specifically, this transformer converts the sequence Q1 -> DQ1 -> Q2 -> DQ2, where the first pair has (zp1, scale1) + * and the second pair has (zp2, scale2), into the sequence Q1 -> DQ2 by removing the middle two nodes. The zero-point + * and scale of the final QDQ pair is recomputed to preserve equality to the original sequence. 
*/ class DoubleQDQPairsRemover : public GraphTransformer { public: @@ -27,28 +24,5 @@ class DoubleQDQPairsRemover : public GraphTransformer { bool& modified, int graph_level, const logging::Logger& logger) const override; - - static bool IsNodeRemovable( - Graph& graph, - const NodeIndex& self_index, - NodeIndex& parent_index, - NodeIndex& child_index, - NodeIndex& grandchild_index); - - template - static bool FindNewZeroPointAndScale( - const Graph& graph, - const Node& node1, - const Node& node2, - float& new_scale, - T& new_zero_point, - bool& skip_reset); - - template - static void ApplyNewInputValue( - Graph& graph, - Node& node, - const InputIndex& index, - T value); }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index d7039cb4b7cfc..0e383c3031ca6 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" +#include #include "core/mlas/inc/mlas.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h" @@ -32,7 +33,8 @@ void SplitQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { // create rules for ops that don't change the data void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) { // 3 nodes. DQ, target, Q. Merge into target and remove DQ and Q. - const std::string action_name{"drop"}; + const std::string drop_action_name{"drop"}; + const std::string drop_action_no_int16_name{"drop_no_int16_support"}; NTO::NodeLocation dq{NTO::NodeType::kInput, 0}; NTO::NodeLocation q{NTO::NodeType::kOutput, 0}; @@ -42,22 +44,33 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) { MoveToSlot(dq, ArgType::kInput, 0, ArgType::kInput, 0), MoveToSlot(q, ArgType::kOutput, 0, ArgType::kOutput, 0)}; - std::unique_ptr action = std::make_unique(std::move(moves)); + std::unique_ptr drop_action_no_int16 = std::make_unique( + std::vector(moves)); // Copy before std::move(moves) + std::unique_ptr drop_action = std::make_unique(std::move(moves)); #if !defined(ORT_MINIMAL_BUILD) - std::unique_ptr selector = std::make_unique(); - qdq_selector_action_registry.RegisterSelectorAndAction(action_name, + // Use a separate selector + action that disallows 16-bit types for MaxPool and Resize. + // int16 MaxPool is not supported by the ONNX specification. + // int16 Resize is not supported by the ORT implementation (although allowed by ONNX). 
+  std::unique_ptr<NodeSelector> selector_disallow_16bit = std::make_unique<QDQ::DropQDQNodesSelector>(false);
+  qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_no_int16_name,
+                                                         {{"MaxPool", {12}},
+                                                          {"Resize", {}}},
+                                                         std::move(selector_disallow_16bit),
+                                                         std::move(drop_action_no_int16));
+
+  std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::DropQDQNodesSelector>(true);
+  qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_name,
                                                          {{"Gather", {}},
                                                           {"Reshape", {}},
                                                           {"Transpose", {}},
-                                                          {"MaxPool", {12}},
-                                                          {"Resize", {}},
                                                           {"Squeeze", {}},
                                                           {"Unsqueeze", {}}},
                                                          std::move(selector),
-                                                         std::move(action));
+                                                         std::move(drop_action));
 #else
-  qdq_selector_action_registry.RegisterAction(action_name, std::move(action));
+  qdq_selector_action_registry.RegisterAction(drop_action_no_int16_name, std::move(drop_action_no_int16));
+  qdq_selector_action_registry.RegisterAction(drop_action_name, std::move(drop_action));
 #endif
 }
@@ -74,6 +87,7 @@ void DropDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) {
   std::unique_ptr<Action> action = std::make_unique<MergeIntoTarget>(std::move(moves));
 
 #if !defined(ORT_MINIMAL_BUILD)
+  // TODO: Enable 16-bit types in selector when ArgMax supports 16-bit integer input tensors.
   std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::DropDQNodesSelector>();
   qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
                                                          {{"ArgMax", {}}},
@@ -91,6 +105,7 @@ void UnaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
   std::unique_ptr<Action> action = std::make_unique<UnaryReplaceWithQLinear>(kMSDomain);
 
 #if !defined(ORT_MINIMAL_BUILD)
+  // TODO: Enable 16-bit types in selector when unary QLinear* ops support 16-bit.
   std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::UnarySelector>();
   qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
                                                          {{"AveragePool", {}},
@@ -112,6 +127,7 @@ void BinaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
   std::unique_ptr<Action> action = std::make_unique<BinaryReplaceWithQLinear>(kMSDomain);
 
 #if !defined(ORT_MINIMAL_BUILD)
+  // TODO: Enable 16-bit types in selector when binary QLinear* ops support 16-bit.
   std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::BinarySelector>();
   qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
                                                          {{"Add", {}},
@@ -131,6 +147,7 @@ void VariadicOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
   std::unique_ptr<Action> action = std::make_unique<VariadicReplaceWithQLinear>(kMSDomain);
 
 #if !defined(ORT_MINIMAL_BUILD)
+  // TODO: Enable 16-bit types in selector when QLinearConcat supports 16-bit.
   std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::InputVariadicSelector>();
 
   qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
@@ -152,6 +169,7 @@ void ConvQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool is_
   std::unique_ptr<Action> action = std::make_unique<ConvReplaceWithQLinear>();
 
 #if !defined(ORT_MINIMAL_BUILD)
+  // TODO: Enable 16-bit types in selector when QLinearConv supports 16-bit.
   std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::ConvSelector>(is_int8_allowed);
 
   qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
@@ -174,6 +192,7 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i
   std::unique_ptr<Action> action = std::make_unique<MatMulReplaceWithQLinear>();
 
 #if !defined(ORT_MINIMAL_BUILD)
+  // TODO: Enable 16-bit types in selector when QLinearMatMul and MatMulInteger support 16-bit.
   std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::MatMulSelector>(is_int8_allowed);
   qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
                                                          {{"MatMul", {}}},
@@ -195,6 +214,7 @@ void GemmQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
   std::unique_ptr<Action> action = std::make_unique<GemmReplaceWithQuantizedGemm>();
 
 #if !defined(ORT_MINIMAL_BUILD)
+  // TODO: Enable 16-bit types in selector when QGemm supports 16-bit.
   std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::GemmSelector>();
   qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
                                                          {{"Gemm", {}}},
@@ -215,6 +235,7 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
   std::unique_ptr<Action> action = std::make_unique<WhereReplaceWithQLinear>();
 
 #if !defined(ORT_MINIMAL_BUILD)
+  // TODO: Enable 16-bit types in selector when QLinearWhere supports 16-bit.
   std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::WhereSelector>();
   qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
                                                          {{"Where", {}}},
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
index 02a7fb733813c..16c7bd5fce960 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
@@ -14,6 +14,12 @@ namespace onnxruntime {
 namespace QDQ {
 namespace {
+
+constexpr bool Is16BitIntType(int32_t data_type) {
+  return (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16) ||
+         (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16);
+}
+
 // adjust for an optional input/output that has an entry but does not exist
 int NumActualValues(const Node& node, bool input) {
   const auto& defs = input ? node.InputDefs() : node.OutputDefs();
@@ -110,6 +116,17 @@ bool DropQDQNodeGroupSelector::Check(const GraphViewer& graph_viewer,
     return false;
   }
 
+  int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
+  int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
+
+  if (dt_input != dt_output) {
+    return false;
+  }
+
+  if (!allow_16bit_ && Is16BitIntType(dt_input)) {
+    return false;
+  }
+
   const Node& dq_node = *dq_nodes.front();
   const Node& q_node = *q_nodes.front();
 
@@ -124,7 +141,7 @@ bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer,
                                     const Node& node,
                                     const std::vector<const Node*>& dq_nodes,
                                     const std::vector<const Node*>& q_nodes) const {
-  int num_dq_inputs = NumActualValues(node, true);
+  constexpr int num_dq_inputs = 1;
   if (num_dq_inputs != gsl::narrow_cast<int>(dq_nodes.size())) {
     return false;
   }
@@ -136,6 +153,12 @@ bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer,
   (void)q_nodes;
 
   const Node& dq_node = *dq_nodes.front();
+  const int32_t dt_input = dq_node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
+
+  // 16-bit int types must be explicitly allowed.
+  if (!allow_16bit_ && Is16BitIntType(dt_input)) {
+    return false;
+  }
 
   auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) {
     return graph_viewer.GetConstantInitializer(initializer_name, true);
@@ -154,7 +177,16 @@ bool UnaryNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node&
 
   int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
   int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
-  return dt_input == dt_output;
+  if (dt_input != dt_output) {
+    return false;
+  }
+
+  // 16-bit int types must be explicitly allowed.
+ if (!allow_16bit_ && Is16BitIntType(dt_input)) { + return false; + } + + return true; } bool BinaryNodeGroupSelector::Check(const GraphViewer& graph_viewer, @@ -168,8 +200,18 @@ bool BinaryNodeGroupSelector::Check(const GraphViewer& graph_viewer, int32_t dt_input_1 = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); int32_t dt_input_2 = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - return dt_input_1 == dt_input_2 && - dt_input_1 == dt_output; + + // All input and output types must match. + if (dt_input_1 != dt_input_2 || dt_input_1 != dt_output) { + return false; + } + + // 16-bit int types must be explicitly allowed. + if (!allow_16bit_ && Is16BitIntType(dt_input_1)) { + return false; + } + + return true; } bool VariadicNodeGroupSelector::Check(const GraphViewer& graph_viewer, @@ -194,7 +236,17 @@ bool VariadicNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } } - return dt_input == dt_output; + + if (dt_input != dt_output) { + return false; + } + + // 16-bit int types must be explicitly allowed. + if (!allow_16bit_ && Is16BitIntType(dt_input)) { + return false; + } + + return true; } void InputVariadicSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const { @@ -227,12 +279,19 @@ bool ConvNodeGroupSelector::Check(const GraphViewer& graph_viewer, } } - if (dq_nodes.size() < 3) { // no bias - return true; + if (dq_nodes.size() == 3) { // has bias + int32_t dt_bias = dq_nodes[2]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (dt_bias != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) { + return false; + } } - int32_t dt_bias = dq_nodes[2]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - return dt_bias == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32; + // 16-bit int types must be explicitly allowed. + if (!allow_16bit_ && (Is16BitIntType(dt_input) || Is16BitIntType(dt_weight))) { + return false; + } + + return true; } void ConvSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const { @@ -256,6 +315,11 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, } } + // 16-bit int types must be explicitly allowed. + if (!allow_16bit_ && (Is16BitIntType(dt_input) || Is16BitIntType(dt_weight))) { + return false; + } + // potential match for QLinearMatMul or MatMulIntegerToFloat bool qlinear = !q_nodes.empty(); @@ -299,6 +363,11 @@ bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer, } } + // 16-bit int types must be explicitly allowed. + if (!allow_16bit_ && (Is16BitIntType(dt_A) || Is16BitIntType(dt_B))) { + return false; + } + if (dq_nodes.size() < 3) { // no bias return true; } @@ -326,8 +395,18 @@ bool WhereNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& const int32_t dt_input_1 = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); const int32_t dt_input_2 = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); const int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - return dt_input_1 == dt_input_2 && - dt_input_1 == dt_output; + + // All input and output types must match. + if (dt_input_1 != dt_input_2 || dt_input_1 != dt_output) { + return false; + } + + // 16-bit int types must be explicitly allowed. 
+  if (!allow_16bit_ && Is16BitIntType(dt_input_1)) {
+    return false;
+  }
+
+  return true;
 }
 
 bool PadNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node,
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
index 58ebf81508962..d8fefdd8dc3d9 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
@@ -52,45 +52,75 @@ class NodeGroupSelector {
 // Single DQ -> node that does not change data -> Q.
 // Zero point and scale are constant scalars and must match
 class DropQDQNodeGroupSelector : public NodeGroupSelector {
+ public:
+  explicit DropQDQNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {}
+
+ private:
   bool Check(const GraphViewer& graph_viewer, const Node& node,
              const std::vector<const Node*>& dq_nodes,
             const std::vector<const Node*>& q_nodes) const override;
+
+  bool allow_16bit_;
 };
 
 // Single DQ -> node.
 class DropDQNodeGroupSelector : public NodeGroupSelector {
+ public:
+  explicit DropDQNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {}
+
+ private:
   bool Check(const GraphViewer& graph_viewer, const Node& node,
             const std::vector<const Node*>& dq_nodes,
             const std::vector<const Node*>& q_nodes) const override;
+
+  bool allow_16bit_;
 };
 
 // single input. default is to only support uint8.
 class UnaryNodeGroupSelector : public NodeGroupSelector {
+ public:
+  explicit UnaryNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {}
+
+ private:
   bool Check(const GraphViewer& graph_viewer, const Node& node,
             const std::vector<const Node*>& dq_nodes,
             const std::vector<const Node*>& q_nodes) const override;
+
+  bool allow_16bit_;
 };
 
 // 2 DQ nodes providing input -> node -> Q
 class BinaryNodeGroupSelector : public NodeGroupSelector {
+ public:
+  explicit BinaryNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {}
+
+ private:
   bool Check(const GraphViewer& graph_viewer, const Node& node,
             const std::vector<const Node*>& dq_nodes,
            const std::vector<const Node*>& q_nodes) const override;
+
+  bool allow_16bit_;
 };
 
 // Variadic DQ nodes -> node -> Q
 class VariadicNodeGroupSelector : public NodeGroupSelector {
+ public:
+  explicit VariadicNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {}
+
 private:
   bool Check(const GraphViewer& graph_viewer, const Node& node,
             const std::vector<const Node*>& dq_nodes,
            const std::vector<const Node*>& q_nodes) const override;
+
+  bool allow_16bit_;
 };
 
 // DQ nodes for X, W and optionally B -> node -> Q
 class ConvNodeGroupSelector : public NodeGroupSelector {
  public:
   // default to 'true'
-  ConvNodeGroupSelector(bool int8_allowed = true) : int8_allowed_(int8_allowed) {}
+  ConvNodeGroupSelector(bool int8_allowed = true, bool allow_16bit = true)
+      : int8_allowed_(int8_allowed), allow_16bit_(allow_16bit) {}
 
  private:
   bool Check(const GraphViewer& graph_viewer, const Node& node,
@@ -98,16 +128,20 @@ class ConvNodeGroupSelector : public NodeGroupSelector {
              const std::vector<const Node*>& q_nodes) const override;
 
   bool int8_allowed_;
+  bool allow_16bit_;
 };
 
 class WhereNodeGroupSelector : public NodeGroupSelector {
  public:
-  WhereNodeGroupSelector() = default;
+  explicit WhereNodeGroupSelector(bool allow_16bit = true)
+      : allow_16bit_(allow_16bit) {}
 
 private:
   bool Check(const GraphViewer& graph_viewer, const Node& node,
             const std::vector<const Node*>& dq_nodes,
            const std::vector<const Node*>& q_nodes) const override;
+
+  bool allow_16bit_;
 };
 
 class PadNodeGroupSelector : public NodeGroupSelector {
@@ -125,9 +159,11 @@ class PadNodeGroupSelector : public NodeGroupSelector {
 
 class MatMulNodeGroupSelector : public NodeGroupSelector {
  public:
   MatMulNodeGroupSelector(bool int8_allowed = true,
-                          bool matmulintegertofloat_allowed = false)
+                          bool matmulintegertofloat_allowed = false,
+                          bool allow_16bit = true)
       : int8_allowed_(int8_allowed),
-        matmulintegertofloat_allowed_(matmulintegertofloat_allowed) {
+        matmulintegertofloat_allowed_(matmulintegertofloat_allowed),
+        allow_16bit_(allow_16bit) {
   }
 
  private:
@@ -136,15 +172,21 @@ class MatMulNodeGroupSelector : public NodeGroupSelector {
              const std::vector<const Node*>& q_nodes) const override;
   bool int8_allowed_;
   bool matmulintegertofloat_allowed_;
+  bool allow_16bit_;
 };
 
 // Input: DQ nodes for A, B and optional C
 // Output: optional Q node for Y
 class GemmNodeGroupSelector : public NodeGroupSelector {
+ public:
+  explicit GemmNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {}
+
+ private:
   bool Check(const GraphViewer& graph_viewer, const Node& node,
              const std::vector<const Node*>& dq_nodes,
              const std::vector<const Node*>& q_nodes) const override;
+
+  bool allow_16bit_;
 };
 
 // Input: DQ nodes for input, scale, and B
@@ -207,28 +249,33 @@ class BaseSelector : public NodeSelector {
 
 class DropQDQNodesSelector : public BaseSelector {
  public:
-  DropQDQNodesSelector() : BaseSelector(std::make_unique<DropQDQNodeGroupSelector>()) {}
+  explicit DropQDQNodesSelector(bool allow_16bit = false)
+      : BaseSelector(std::make_unique<DropQDQNodeGroupSelector>(allow_16bit)) {}
 };
 
 class DropDQNodesSelector : public BaseSelector {
  public:
-  DropDQNodesSelector() : BaseSelector(std::make_unique<DropDQNodeGroupSelector>()) {}
+  explicit DropDQNodesSelector(bool allow_16bit = false)
+      : BaseSelector(std::make_unique<DropDQNodeGroupSelector>(allow_16bit)) {}
 };
 
 class UnarySelector : public BaseSelector {
 public:
-  UnarySelector() : BaseSelector(std::make_unique<UnaryNodeGroupSelector>()) {}
+  explicit UnarySelector(bool allow_16bit = false)
+      : BaseSelector(std::make_unique<UnaryNodeGroupSelector>(allow_16bit)) {}
 };
 
 class BinarySelector : public BaseSelector {
 public:
-  BinarySelector() : BaseSelector(std::make_unique<BinaryNodeGroupSelector>()) {}
+  explicit BinarySelector(bool allow_16bit = false)
+      : BaseSelector(std::make_unique<BinaryNodeGroupSelector>(allow_16bit)) {}
 };
 
 // Variadic DQ nodes -> node -> Q
 class InputVariadicSelector : public BaseSelector {
 public:
-  InputVariadicSelector() : BaseSelector(std::make_unique<VariadicNodeGroupSelector>()) {}
+  explicit InputVariadicSelector(bool allow_16bit = false)
+      : BaseSelector(std::make_unique<VariadicNodeGroupSelector>(allow_16bit)) {}
 
   void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
 };
@@ -244,46 +291,36 @@ class OutputVariadicSelector : public BaseSelector {
 
 // DQ nodes for X, W and optionally B -> node -> Q
 class ConvSelector : public BaseSelector {
 public:
-  ConvSelector(bool int8_allowed = false) : BaseSelector(std::make_unique<ConvNodeGroupSelector>(int8_allowed)) {}
+  ConvSelector(bool int8_allowed = false, bool allow_16bit = false)
+      : BaseSelector(std::make_unique<ConvNodeGroupSelector>(int8_allowed, allow_16bit)) {}
 
   void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
 };
+
 class WhereSelector : public BaseSelector {
 public:
-  WhereSelector() : BaseSelector(std::make_unique<WhereNodeGroupSelector>()) {}
+  explicit WhereSelector(bool allow_16bit = false)
+      : BaseSelector(std::make_unique<WhereNodeGroupSelector>(allow_16bit)) {}
 };
+
 // 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not
 class MatMulSelector : public BaseSelector {
 public:
-  MatMulSelector(bool int8_allowed)
-      : BaseSelector(std::make_unique<MatMulNodeGroupSelector>(int8_allowed, /*matmulintegertofloat_allowed*/ true)) {}
+  MatMulSelector(bool int8_allowed, bool allow_16bit = false)
+      : BaseSelector(std::make_unique<MatMulNodeGroupSelector>(int8_allowed,
+                                                               /*matmulintegertofloat_allowed*/ true,
+                                                               allow_16bit)) {}
 };
 
 // Input: DQ nodes for A, B and optional C
 // Output: optional Q node for Y
 class GemmSelector : public BaseSelector {
 public:
-  GemmSelector()
-      : BaseSelector(std::make_unique<GemmNodeGroupSelector>()) {}
+  explicit GemmSelector(bool allow_16bit = false)
+      : BaseSelector(std::make_unique<GemmNodeGroupSelector>(allow_16bit)) {}
 
   void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
 };
 
-// Input: DQ nodes for input, scale, and B (bias)
-// Output: Q node for output
-class InstanceNormalizationSelector : public BaseSelector {
- public:
-  InstanceNormalizationSelector()
-      : BaseSelector(std::make_unique<InstanceAndBatchNormalizationNodeGroupSelector>()) {}
-};
-
-// DQ nodes for X, W and optionally B, (mean, var not required) -> node -> Q
-class BatchNormalizationSelector : public BaseSelector {
- public:
-  BatchNormalizationSelector(bool int8_allowed = false)
-      : BaseSelector(std::make_unique<BatchNormalizationNodeGroupSelector>(int8_allowed)) {}
-};
-
 }  // namespace QDQ
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc
index 3723ee6032582..2c11bf144999e 100644
--- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc
+++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc
@@ -1195,7 +1195,7 @@ bool TransposeQuantizeDequantizeAxis(const api::GraphRef& graph, const std::vect
 static bool HandleQuantizeDequantizeAxis(const api::GraphRef& graph, const std::vector<int64_t>& perm,
                                          api::NodeRef& node, int64_t opset) {
   if (opset < 13) {
-    // no `axis` value until opset 13
+    // no `axis` attribute until opset 13
     return true;
   }
 
diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc
index 67a9a5991939a..a0d75e8cc0e69 100644
--- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc
+++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc
@@ -5,13 +5,47 @@
 #include "core/framework/element_type_lists.h"
 #include "core/framework/float8.h"
 #include "core/framework/float16.h"
-#include "core/providers/cpu/quantization/quantize_linear.h"
+#include "core/framework/op_kernel.h"
 #include "core/providers/common.h"
 #include "core/mlas/inc/mlas.h"
 #include "core/util/qmath.h"
 
 namespace onnxruntime {
 
+template <typename T>
+class DequantizeLinear final : public OpKernel {
+ public:
+  explicit DequantizeLinear(const OpKernelInfo& info) : OpKernel(info) {
+    if (!info.GetAttr("axis", &axis_).IsOK()) {
+      axis_ = 1;
+    }
+  }
+
+  Status Compute(OpKernelContext* context) const override;
+
+ private:
+  int64_t axis_;
+};
+
+template <typename T>
+class QuantizeLinear final : public OpKernel {
+ public:
+  explicit QuantizeLinear(const OpKernelInfo& info) : OpKernel(info) {
+    if (!info.GetAttr("axis", &axis_).IsOK()) {
+      axis_ = 1;
+    }
+    if (!info.GetAttr("saturate", &saturate_).IsOK()) {
+      saturate_ = 1;
+    }
+  }
+
+  Status Compute(OpKernelContext* context) const override;
+
+ private:
+  int64_t axis_;
+  int64_t saturate_;
+};
+
 static void PrepareForQDQ(const TensorShape& input_shape,
                           const Tensor& scale,
                           const Tensor* zero_point_ptr,
@@ -86,6 +120,59 @@ REGISTER_DEQUANTIZELINEAR_VERSIONED(int8_t)
 REGISTER_DEQUANTIZELINEAR_VERSIONED(uint8_t)
 REGISTER_DEQUANTIZELINEAR_VERSIONED(int32_t)
 
+#if !defined(DISABLE_CONTRIB_OPS)
+namespace contrib {
+
+// Register alternate MS domain versions of the DequantizeLinear kernel.
+// The MS domain versions additionally support 16-bit integer quantization types.
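+// Illustrative note: with these registrations in place, a node with op type 'DequantizeLinear' in the
+// 'com.microsoft' domain and a uint16 input resolves to the DequantizeLinear<uint16_t> kernel below,
+// while the ONNX-domain registrations above remain limited to the types the ONNX spec allows.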
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    DequantizeLinear,
+    1,
+    uint8_t,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<uint8_t>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()),
+    DequantizeLinear<uint8_t>);
+
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    DequantizeLinear,
+    1,
+    int8_t,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<int8_t>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()),
+    DequantizeLinear<int8_t>);
+
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    DequantizeLinear,
+    1,
+    uint16_t,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<uint16_t>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()),
+    DequantizeLinear<uint16_t>);
+
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    DequantizeLinear,
+    1,
+    int16_t,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<int16_t>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()),
+    DequantizeLinear<int16_t>);
+
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    DequantizeLinear,
+    1,
+    int32_t,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<int32_t>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()),
+    DequantizeLinear<int32_t>);
+
+}  // namespace contrib
+#endif  // !defined(DISABLE_CONTRIB_OPS)
+
 template <typename T, typename OutT>
 struct DequantizeLinearApply {
   void op(int64_t N, int64_t broadcast_dim, int64_t block_size, const T* input, const OutT* scale, OutT* output, const T* zero_point) {
@@ -220,6 +307,49 @@ REGISTER_QUANTIZELINEAR(Float8E5M2FNUZ)
 
 REGISTER_QUANTIZELINEAR_VERSIONED(int8_t)
 REGISTER_QUANTIZELINEAR_VERSIONED(uint8_t)
 
+#if !defined(DISABLE_CONTRIB_OPS)
+namespace contrib {
+
+// Register alternate MS domain versions of the QuantizeLinear kernel.
+// The MS domain versions additionally support 16-bit integer quantization types.
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    QuantizeLinear,
+    1,
+    uint8_t,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
+    QuantizeLinear<uint8_t>);
+
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    QuantizeLinear,
+    1,
+    int8_t,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<int8_t>()),
+    QuantizeLinear<int8_t>);
+
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    QuantizeLinear,
+    1,
+    uint16_t,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<uint16_t>()),
+    QuantizeLinear<uint16_t>);
+
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    QuantizeLinear,
+    1,
+    int16_t,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<int16_t>()),
+    QuantizeLinear<int16_t>);
+}  // namespace contrib
+#endif  // !defined(DISABLE_CONTRIB_OPS)
+
 template <typename InputType, typename OutputType>
 void ParQuantizeLinear(const InputType* Input,
                        OutputType* Output,
@@ -279,5 +409,4 @@ Status QuantizeLinear<T>::Compute(OpKernelContext* ctx) const {
 
   return Status::OK();
 }
-
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.h b/onnxruntime/core/providers/cpu/quantization/quantize_linear.h
deleted file mode 100644
index 60e9d09665ab2..0000000000000
--- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.h
+++ /dev/null
@@ -1,45 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-
-#include "core/common/common.h"
-#include "core/framework/op_kernel.h"
-#include "core/util/math_cpuonly.h"
-
-namespace onnxruntime {
-
-template <typename T>
-class DequantizeLinear final : public OpKernel {
- public:
-  DequantizeLinear(const OpKernelInfo& info) : OpKernel(info) {
-    if (!info.GetAttr("axis", &axis_).IsOK()) {
-      axis_ = 1;
-    }
-  }
-
-  Status Compute(OpKernelContext* context) const override;
-
- private:
-  int64_t axis_;
-};
-
-template <typename T>
-class QuantizeLinear final : public OpKernel {
- public:
-  QuantizeLinear(const OpKernelInfo& info) : OpKernel(info) {
-    if (!info.GetAttr("axis", &axis_).IsOK()) {
-      axis_ = 1;
-    }
-    if (!info.GetAttr("saturate", &saturate_).IsOK()) {
-      saturate_ = 1;
-    }
-  }
-
-  Status Compute(OpKernelContext* context) const override;
-
- private:
-  int64_t axis_;
-  int64_t saturate_;
-};
-}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc
index 556a86bb1519b..8081033c35618 100644
--- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc
+++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc
@@ -30,6 +30,12 @@ class SimpleOpBuilder : public BaseOpBuilder {
 
  private:
   Status ExplicitOpCheck(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const;
+  Status ProcessSigmoidOrTanhOutput(QnnModelWrapper& qnn_model_wrapper,
+                                    const NodeUnit& node_unit,
+                                    std::vector<std::string>&& input_names,
+                                    std::vector<std::string>&& param_tensor_names,
+                                    const logging::Logger& logger,
+                                    bool do_op_validation) const ORT_MUST_USE_RESULT;
 
   static constexpr std::array gridsample_supported_modes = {"bilinear", "nearest"};
   static constexpr std::array gridsample_supported_padding_modes = {"zeros", "border", "reflection"};
@@ -279,10 +285,120 @@ Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w
     ORT_RETURN_IF_ERROR(ProcessGridSampleAttributes(qnn_model_wrapper, node_unit, param_tensor_names));
   }
 
-  ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit,
-                                     std::move(input_names),
-                                     std::move(param_tensor_names),
-                                     logger, do_op_validation, GetQnnOpType(op_type)));
+  if (op_type == "Sigmoid" || op_type == "Tanh") {
+    // QNN requires 16-bit QDQ Sigmoid and Tanh to use specific output scale and zero-point values
+    // regardless of floating-point range.
+    return ProcessSigmoidOrTanhOutput(qnn_model_wrapper,
+                                      node_unit,
+                                      std::move(input_names),
+                                      std::move(param_tensor_names),
+                                      logger, do_op_validation);
+  }
+
+  return ProcessOutputs(qnn_model_wrapper, node_unit,
+                        std::move(input_names),
+                        std::move(param_tensor_names),
+                        logger, do_op_validation, GetQnnOpType(op_type));
+}
+
+/**
+ * Overrides offset and scale quantization parameters for operators (e.g., Sigmoid or Tanh) that require
+ * specific values. Returns true if the quantization parameters were overridden.
+ *
+ * \param op_type The ONNX operator type.
+ * \param qnn_data_type The QNN tensor data type.
+ * \param quant_params Output scale/offset parameter that may be overridden.
+ * \return True if the offset and scale were overridden.
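+ *
+ * Example (values taken from the switch statement below): a uint16 QDQ Sigmoid output is forced to
+ * offset 0 and scale 1/65536, which covers Sigmoid's fixed [0, 1) output range regardless of the
+ * range observed during calibration.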
+ */
+static bool OverrideQuantParams(const std::string& op_type, Qnn_DataType_t qnn_data_type,
+                                Qnn_ScaleOffset_t& quant_params) {
+  const int32_t orig_offset = quant_params.offset;
+  const float orig_scale = quant_params.scale;
+
+  if (op_type == "Sigmoid") {
+    switch (qnn_data_type) {
+      case QNN_DATATYPE_UFIXED_POINT_16:
+        quant_params.offset = 0;
+        quant_params.scale = 1.0f / 65536.0f;
+        break;
+      case QNN_DATATYPE_SFIXED_POINT_16:
+        quant_params.offset = 0;
+        quant_params.scale = 1.0f / 32768.0f;
+        break;
+      default:
+        break;  // Do nothing.
+    }
+  }
+
+  if (op_type == "Tanh") {
+    switch (qnn_data_type) {
+      case QNN_DATATYPE_UFIXED_POINT_16:
+        quant_params.offset = -32768;
+        quant_params.scale = 1.0f / 32768.0f;
+        break;
+      case QNN_DATATYPE_SFIXED_POINT_16:
+        quant_params.offset = 0;
+        quant_params.scale = 1.0f / 32768.0f;
+        break;
+      default:
+        break;  // Do nothing.
+    }
+  }
+
+  return quant_params.offset != orig_offset || quant_params.scale != orig_scale;
+}
+
+/**
+ * Processes the output for Sigmoid or Tanh operators and creates the corresponding QNN operator.
+ * These operator types are handled separately because QNN requires 16-bit QDQ Sigmoid and Tanh operators to use
+ * specific scale and zero-point values regardless of floating-point range.
+ *
+ * \param qnn_model_wrapper The QNN model wrapper object.
+ * \param node_unit The QDQ node unit for the Sigmoid or Tanh node.
+ * \param input_names List of input names.
+ * \param param_tensor_names List of param tensor names.
+ * \param logger Logger used to report information.
+ * \param do_op_validation True if the new QNN node should be validated.
+ */
+Status SimpleOpBuilder::ProcessSigmoidOrTanhOutput(QnnModelWrapper& qnn_model_wrapper,
+                                                   const NodeUnit& node_unit,
+                                                   std::vector<std::string>&& input_names,
+                                                   std::vector<std::string>&& param_tensor_names,
+                                                   const logging::Logger& logger,
+                                                   bool do_op_validation) const {
+  const std::string& op_type = node_unit.OpType();
+  const auto& output = node_unit.Outputs()[0];
+  const std::string& output_name = output.node_arg.Name();
+
+  OnnxInputInfo output_info = {};
+
+  // TODO(adrianlizarraga): Rename GetOnnxInputInfo() since it can be used for outputs as well.
+  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(output, output_info));
+
+  if (output_info.quant_param.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) {
+    if (OverrideQuantParams(op_type, output_info.qnn_data_type, output_info.quant_param.scaleOffsetEncoding)) {
+      const int32_t offset = output_info.quant_param.scaleOffsetEncoding.offset;
+      const float scale = output_info.quant_param.scaleOffsetEncoding.scale;
+
+      LOGS(logger, VERBOSE) << "QNN requires that 16-bit quantized " << op_type << " operators use offset/scale values "
+                            << "of <" << offset << ", " << scale << ">. QNN EP will override the original values.";
+    }
+  }
+
+  Qnn_TensorType_t tensor_type = qnn_model_wrapper.IsGraphOutput(output_name) ?
QNN_TENSOR_TYPE_APP_READ
+                                                                             : QNN_TENSOR_TYPE_NATIVE;
+  QnnTensorWrapper output_tensorwrapper(output_name, tensor_type, output_info.qnn_data_type, output_info.quant_param,
+                                        std::move(output_info.shape));
+  ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor.");
+  ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(GetNodeName(node_unit),
+                                                    QNN_OP_PACKAGE_NAME_QTI_AISW,
+                                                    GetQnnOpType(op_type),
+                                                    std::move(input_names),
+                                                    {output_name},
+                                                    std::move(param_tensor_names),
+                                                    do_op_validation),
+                    "Failed to add node.");
+  return Status::OK();
 }
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc
index eebe75d839b12..9d339387b0a43 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc
@@ -301,6 +301,16 @@ bool QnnModelWrapper::ProcessOffset(const std::string& offset_name,
       offset_value = 0 - (uint8_span.data()[0]);
       break;
     }
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
+      auto uint16_span = ReinterpretAsSpan<const uint16_t>(gsl::make_span(unpacked_tensor));
+      offset_value = -static_cast<int32_t>(uint16_span.data()[0]);
+      break;
+    }
+    case ONNX_NAMESPACE::TensorProto_DataType_INT16: {
+      auto int16_span = ReinterpretAsSpan<const int16_t>(gsl::make_span(unpacked_tensor));
+      offset_value = -static_cast<int32_t>(int16_span.data()[0]);
+      break;
+    }
     case ONNX_NAMESPACE::TensorProto_DataType_INT32: {
       auto int32_span = ReinterpretAsSpan<const int32_t>(gsl::make_span(unpacked_tensor));
       offset_value = -(int32_span.data()[0]);
diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py
index 924d4c72b6390..2d1e418f9d2b4 100644
--- a/onnxruntime/python/tools/quantization/onnx_quantizer.py
+++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py
@@ -104,7 +104,7 @@ def __init__(
         )
         self.q_matmul_const_b_only = "MatMulConstBOnly" in self.extra_options and self.extra_options["MatMulConstBOnly"]
         self.is_weight_symmetric = (
-            weight_qType in (QuantType.QInt8, QuantType.QFLOAT8E4M3FN)
+            weight_qType in (QuantType.QInt8, QuantType.QInt16, QuantType.QFLOAT8E4M3FN)
             if "WeightSymmetric" not in self.extra_options
             else self.extra_options["WeightSymmetric"]
         )
diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py
index f87a9d8228bac..e595b580b20df 100644
--- a/onnxruntime/python/tools/quantization/qdq_quantizer.py
+++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py
@@ -25,6 +25,7 @@
     add_quant_output_suffix,
     add_quant_suffix,
     find_by_name,
+    ms_domain,
 )
 
 from .registry import CreateQDQQuantizer
@@ -119,6 +120,20 @@ def __init__(
             else extra_options["QDQOpTypePerChannelSupportToAxis"]
         )
 
+        self.qdq_op_domain = ms_domain if extra_options.get("UseQDQContribOps", False) else None
+
+        # The ONNX spec does not yet support 16-bit Q/DQ ops. So, must override the Q/DQ op domain to 'com.microsoft'
+        # if the activation or weight types are 16-bit integers.
+        # TODO: Remove this override (and use only the 'UseQDQContribOps' option) if/when ONNX adds 16-bit support.
+        int16_types = (TensorProto.UINT16, TensorProto.INT16)
+        if not self.qdq_op_domain and (self.activation_qType in int16_types or self.weight_qType in int16_types):
+            logging.warning(
+                "ONNX QuantizeLinear and DequantizeLinear operators do not support 16-bit integer quantization types.
" + f"The domain of QuantizeLinear and DequantizeLinear operators will be set to '{ms_domain}' to " + "enable support." + ) + self.qdq_op_domain = ms_domain + def _is_tensor_quantizable(self, tensor_name): """ Check if tensor can be quantized @@ -249,6 +264,7 @@ def _create_qdq_nodes( [q_output], quant_node_name, axis=axis, + domain=self.qdq_op_domain, ) dequant_node = onnx.helper.make_node( DEQUANT_OP_NAME, @@ -256,6 +272,7 @@ def _create_qdq_nodes( [dq_output], dequant_node_name, axis=axis, + domain=self.qdq_op_domain, ) self.model.add_nodes([qlinear_node, dequant_node]) @@ -300,6 +317,7 @@ def _add_qdq_pair_for_initializer(self, weight_proto, tensor_type, axis=None): [weight_dequant_output], add_dequant_suffix(weight_name), axis=axis, + domain=self.qdq_op_domain, ) self.model.add_node(dequant_node) @@ -443,6 +461,7 @@ def _quantize_bias_tensors(self): [bias_name], node_name, axis=quant_value.axis, + domain=self.qdq_op_domain, ) else: dequant_node = onnx.helper.make_node( @@ -450,6 +469,7 @@ def _quantize_bias_tensors(self): inputs, [bias_name], node_name, + domain=self.qdq_op_domain, ) else: raise RuntimeError(f"Unexpected operator type {quant_value.node_type!r}.") diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index 4d5bcca29618f..74e54c3f1fa37 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -72,6 +72,8 @@ class QuantType(Enum): QInt8 = 0 QUInt8 = 1 QFLOAT8E4M3FN = 2 + QInt16 = 3 + QUInt16 = 4 def __str__(self): return self.name @@ -89,6 +91,10 @@ def tensor_type(self): return TensorProto.INT8 if self == QuantType.QUInt8: return TensorProto.UINT8 + if self == QuantType.QUInt16: + return TensorProto.UINT16 + if self == QuantType.QInt16: + return TensorProto.INT16 if self == QuantType.QFLOAT8E4M3FN: return TensorProto.FLOAT8E4M3FN raise ValueError(f"Unexpected value qtype={self!r}.") @@ -112,12 +118,35 @@ def from_string(format): ONNX_TYPE_TO_NP_TYPE = { onnx_proto.TensorProto.INT8: numpy.dtype("int8"), onnx_proto.TensorProto.UINT8: numpy.dtype("uint8"), + onnx_proto.TensorProto.INT16: numpy.dtype("int16"), + onnx_proto.TensorProto.UINT16: numpy.dtype("uint16"), onnx_proto.TensorProto.FLOAT8E4M3FN: float8e4m3fn, } +ONNX_INT_TYPE_RANGE = { + onnx_proto.TensorProto.UINT8: (0, 255), + onnx_proto.TensorProto.INT8: (-128, 127), + onnx_proto.TensorProto.UINT16: (0, 65535), + onnx_proto.TensorProto.INT16: (-32768, 32767), +} + +ONNX_INT_TYPE_SYMMETRIC_RANGE = { + onnx_proto.TensorProto.INT8: (-127, 127), + onnx_proto.TensorProto.INT16: (-32767, 32767), +} + +ONNX_INT_TYPE_REDUCED_RANGE = { + onnx_proto.TensorProto.UINT8: (0, 127), + onnx_proto.TensorProto.INT8: (-64, 64), + onnx_proto.TensorProto.UINT16: (0, 32767), + onnx_proto.TensorProto.INT16: (-16384, 16384), +} + def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None): - assert qType in ONNX_TYPE_TO_NP_TYPE, f"Unexpected data type {qType} requested. Only INT8 and UINT8 are supported." + assert ( + qType in ONNX_TYPE_TO_NP_TYPE + ), f"Unexpected data type {qType} requested. Only INT8, UINT8, INT16, and UINT16 are supported." 
if qType in ( onnx_proto.TensorProto.FLOAT8E4M3FN, onnx_proto.TensorProto.FLOAT8E4M3FNUZ, @@ -146,8 +175,10 @@ def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None): return ref.run(None, {"X": arr.astype(numpy.float32), "scale": scale.astype(numpy.float32)})[0] else: dtype = ONNX_TYPE_TO_NP_TYPE[qType] - cliplow = max(0 if dtype == numpy.uint8 else -127, -127 if low is None else low) - cliphigh = min(255 if dtype == numpy.uint8 else 127, 255 if high is None else high) + (qmin, qmax) = get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=True) + + cliplow = max(qmin, low) if low is not None else qmin + cliphigh = min(qmax, high) if high is not None else qmax arr_fp32 = numpy.asarray((arr.astype(numpy.float32) / scale).round() + zero_point) numpy.clip(arr_fp32, cliplow, cliphigh, out=arr_fp32) return arr_fp32.astype(dtype) @@ -267,7 +298,7 @@ def quantize_data(data, qType, symmetric, reduce_range=False): ) return rmin, rmax, zero_point, scale, quantized_data - if qType in (TensorProto.INT8, TensorProto.UINT8): + if qType in (TensorProto.INT8, TensorProto.UINT8, TensorProto.INT16, TensorProto.UINT16): if len(data): qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range, symmetric=symmetric) zero_point, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric) @@ -283,18 +314,22 @@ def get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=False): # noqa :parameter qType: onnx.onnx_pb.TensorProto.UINT8 or onnx.onnx_pb.TensorProto.UINT8 :return: qmin, qmax """ - if qType == onnx_proto.TensorProto.UINT8: - (qmin, qmax) = (0, 127) if reduce_range else (0, 255) - elif qType == onnx_proto.TensorProto.INT8: - if symmetric: - (qmin, qmax) = (-64, 64) if reduce_range else (-127, 127) - else: - (qmin, qmax) = (-64, 64) if reduce_range else (-128, 127) - elif qType == onnx_proto.TensorProto.FLOAT8E4M3FN: + if qType == onnx_proto.TensorProto.FLOAT8E4M3FN: raise NotImplementedError("This function is not implemented for float 8 as not needed.") + + qrange = None + + if reduce_range: + qrange = ONNX_INT_TYPE_REDUCED_RANGE.get(qType) + elif symmetric and qType in ONNX_INT_TYPE_SYMMETRIC_RANGE: + qrange = ONNX_INT_TYPE_SYMMETRIC_RANGE[qType] else: - raise ValueError(f"Unexpected data type {qType} requested. Only INT8 and UINT8 are supported.") - return qmin, qmax + qrange = ONNX_INT_TYPE_RANGE.get(qType) + + if not qrange: + raise ValueError(f"Unexpected data type {qType} requested. Only INT8, UINT8, INT16, and UINT16 are supported.") + + return qrange def get_qrange_for_qType(qType, reduce_range=False, symmetric=False): # noqa: N802 diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index 6b1646aec9679..706047fe32400 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -240,6 +240,11 @@ def check_static_quant_arguments(quant_format: QuantFormat, activation_type: Qua f"weight_type={weight_type}!=QuantType.QFLOAT8E4M3FN" ) + q16_types = [QuantType.QInt16, QuantType.QUInt16] + + if (activation_type in q16_types or weight_type in q16_types) and quant_format != QuantFormat.QDQ: + raise ValueError("Only QuantFormat.QDQ supports 16-bit quantization types.") + if activation_type == QuantType.QInt8 and weight_type == QuantType.QInt8 and quant_format != QuantFormat.QDQ: logging.warning( "Please use QuantFormat.QDQ for activation type QInt8 and weight type QInt8. 
" @@ -356,6 +361,11 @@ def quantize_static( SmoothQuantFolding = True/False : Default is True. It only works if SmoothQuant is True. If enabled, inserted Mul ops during SmoothQuant will be folded into the previous op if the previous op is foldable. + UseQDQContribOps = True/False : + Default is False. If enabled, the inserted QuantizeLinear and DequantizeLinear ops will have the + `com.microsoft` domain, which forces use of ONNX Runtime's QuantizeLinear and DequantizeLinear + contrib op implementations. The contrib op implementations may support features not standardized + into the ONNX specification (e.g., 16-bit quantization types). """ if activation_type == QuantType.QFLOAT8E4M3FN or weight_type == QuantType.QFLOAT8E4M3FN: if calibrate_method != CalibrationMethod.Distribution: diff --git a/onnxruntime/test/contrib_ops/quantize_ops_test.cc b/onnxruntime/test/contrib_ops/quantize_ops_test.cc index af29f972a64cf..64a97ed4f945b 100644 --- a/onnxruntime/test/contrib_ops/quantize_ops_test.cc +++ b/onnxruntime/test/contrib_ops/quantize_ops_test.cc @@ -4,6 +4,7 @@ #include "gtest/gtest.h" #include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" namespace onnxruntime { namespace test { @@ -40,7 +41,31 @@ TEST(DequantizeLinearOpTest, DequantizeLinear_per_tensor_float_int8) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -// Scalar zero & scale with int32 +// Test int16 com.microsoft.DequantizeLinear (per tensor) +TEST(DequantizeLinearOpTest, DequantizeLinear_per_tensor_float_int16_cpu) { + OpTester test("DequantizeLinear", 1, onnxruntime::kMSDomain); + std::vector dims{4}; + test.AddInput("x", dims, {-300, -30, -1025, 1270}); + test.AddInput("scale", {}, {2.0f}, true); + test.AddInput("zero_point", {}, {-1024}, true); + test.AddOutput("y", dims, {1448.0f, 1988.0f, -2.0f, 4588.0f}); + // Disable Tensorrt EP due to error: unsupported data type + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// Test uint16 com.microsoft.DequantizeLinear (per tensor) +TEST(DequantizeLinearOpTest, DequantizeLinear_per_tensor_float_uint16_cpu) { + OpTester test("DequantizeLinear", 1, onnxruntime::kMSDomain); + std::vector dims{4}; + test.AddInput("x", dims, {30000, 31000, 32768, 33000}); + test.AddInput("scale", {}, {2.0f}, true); + test.AddInput("zero_point", {}, {32767}, true); + test.AddOutput("y", dims, {-5534.0f, -3534.0f, 2.0f, 466.0f}); + // Disable Tensorrt EP due to error: unsupported data type + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// Test int32 DequantizeLinear with scalar zero-point & scale. 
TEST(DequantizeLinearOpTest, DequantizeLinear_per_tensor_float_int32_cpu) {
   OpTester test("DequantizeLinear", 1, onnxruntime::kMSDomain);
   std::vector<int64_t> dims{4};
@@ -256,6 +281,60 @@ TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_tensor_float_int8) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }
 
+// Test uint16 com.microsoft.QuantizeLinear (per tensor)
+TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_tensor_float_uint16) {
+  OpTester test("QuantizeLinear", 1, onnxruntime::kMSDomain);
+  std::vector<int64_t> dims{12};
+  test.AddInput<float>("x", dims, {
+                                      0.f, -128.f, 3.f, -3.f,  // rounding half to even
+                                      2.9f, -2.9f,             // round < .5
+                                      3.1f, -3.1f,             // round > .5
+                                      65536.f, -65534.f,       // critical point
+                                      70000.f, -70000.f        // saturate case
+                                  });
+  test.AddInput<float>("scale", {}, {2.0f}, true);
+  test.AddInput<uint16_t>("zero_point", {}, {32767}, true);
+  test.AddOutput<uint16_t>("y", dims,
+                           {32767, 32703,
+                            32769, 32765,
+                            32768, 32766,
+                            32769, 32765,
+                            65535, 0,
+                            65535, 0});
+
+  // Disable Tensorrt EP due to error: unsupported data type
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+}
+
+// Test int16 com.microsoft.QuantizeLinear (per tensor)
+TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_tensor_float_int16) {
+  OpTester test("QuantizeLinear", 1, onnxruntime::kMSDomain);
+  std::vector<int64_t> dims{16};
+  test.AddInput<float>("x", dims, {
+                                      0.f, -514.f, 3.f, -3.f,  // rounding half to even
+                                      2.9f, -2.9f,             // round < .5
+                                      3.1f, -3.1f,             // round > .5
+                                      65022.f, -66046.f,       // critical point
+                                      65023.f, -66047.f,       // critical point
+                                      65024.f, -66048.f,       // critical point
+                                      70000.f, -70000.f        // saturate case
+                                  });
+  test.AddInput<float>("scale", {}, {2.0f}, true);
+  test.AddInput<int16_t>("zero_point", {}, {256}, true);
+  test.AddOutput<int16_t>("y", dims,
+                          {256, -1,
+                           258, 254,
+                           257, 255,
+                           258, 254,
+                           32767, -32767,
+                           32767, -32768,
+                           32767, -32768,
+                           32767, -32768});
+
+  // Disable Tensorrt EP due to error: unsupported data type
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+}
+
 #ifdef USE_CUDA
 TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_tensor_half_uint8) {
   OpTester test("QuantizeLinear", 1, onnxruntime::kMSDomain);
diff --git a/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp b/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp
index 55d1a2f4f3608..2832598fef1a9 100644
--- a/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp
+++ b/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp
@@ -3,26 +3,26 @@
 
 #include "test_util.h"
 
-template <typename xint8_t>
+template <typename QuantInt>
 class MlasQuantizeLinearTest : public MlasTestBase {
  private:
   MatrixGuardBuffer<float> BufferInput;
-  MatrixGuardBuffer<xint8_t> BufferOutput;
-  MatrixGuardBuffer<xint8_t> BufferOutputReference;
+  MatrixGuardBuffer<QuantInt> BufferOutput;
+  MatrixGuardBuffer<QuantInt> BufferOutputReference;
 
-  void GenerateReference(const float* Input, xint8_t* OutputReference, size_t N, float Scale, xint8_t ZeroPoint) {
+  void GenerateReference(const float* Input, QuantInt* OutputReference, size_t N, float Scale, QuantInt ZeroPoint) {
     for (size_t n = 0; n < N; n++) {
       float FloatValue = std::nearbyintf(Input[n] / Scale) + float(ZeroPoint);
-      FloatValue = std::max(FloatValue, float(std::numeric_limits<xint8_t>::min()));
-      FloatValue = std::min(FloatValue, float(std::numeric_limits<xint8_t>::max()));
-      OutputReference[n] = (xint8_t)FloatValue;
+      FloatValue = std::max(FloatValue, static_cast<float>(std::numeric_limits<QuantInt>::min()));
+      FloatValue = std::min(FloatValue, static_cast<float>(std::numeric_limits<QuantInt>::max()));
+      OutputReference[n] = static_cast<QuantInt>(FloatValue);
     }
   }
 
   void Test(size_t
N) {
     float* Input = BufferInput.GetBuffer(N);
-    xint8_t* Output = BufferOutput.GetBuffer(N);
-    xint8_t* OutputReference = BufferOutputReference.GetBuffer(N);
+    QuantInt* Output = BufferOutput.GetBuffer(N);
+    QuantInt* OutputReference = BufferOutputReference.GetBuffer(N);
 
     std::default_random_engine generator(static_cast<unsigned>(N));
 
@@ -34,8 +34,9 @@ class MlasQuantizeLinearTest : public MlasTestBase {
 
     float Scale = (MaximumValue - MinimumValue) / 512.f;
 
-    std::uniform_int_distribution<int32_t> zp_distribution(std::numeric_limits<xint8_t>::min(), std::numeric_limits<xint8_t>::max());
-    xint8_t ZeroPoint = static_cast<xint8_t>(zp_distribution(generator));
+    std::uniform_int_distribution<int32_t> zp_distribution(std::numeric_limits<QuantInt>::min(),
+                                                           std::numeric_limits<QuantInt>::max());
+    QuantInt ZeroPoint = static_cast<QuantInt>(zp_distribution(generator));
 
     std::uniform_real_distribution<float> distribution(MinimumValue, MaximumValue);
     for (size_t n = 0; n < N; n++) {
@@ -52,8 +53,15 @@ class MlasQuantizeLinearTest : public MlasTestBase {
 
  public:
   static const char* GetTestSuiteName() {
-    static const std::string suite_name(std::is_signed<xint8_t>::value ? "QuantizeLinearS8" : "QuantizeLinearU8");
-    return suite_name.c_str();
+    if constexpr (std::is_same_v<QuantInt, int8_t>) {
+      return "QuantizeLinearS8";
+    } else if (std::is_same_v<QuantInt, uint8_t>) {
+      return "QuantizeLinearU8";
+    } else if (std::is_same_v<QuantInt, int16_t>) {
+      return "QuantizeLinearS16";
+    } else {
+      return "QuantizeLinearU16";
+    }
   }
 
   void ExecuteShort(void) override {
@@ -67,12 +75,18 @@ template <>
 MlasQuantizeLinearTest<int8_t>* MlasTestFixture<MlasQuantizeLinearTest<int8_t>>::mlas_tester(nullptr);
 template <>
 MlasQuantizeLinearTest<uint8_t>* MlasTestFixture<MlasQuantizeLinearTest<uint8_t>>::mlas_tester(nullptr);
+template <>
+MlasQuantizeLinearTest<int16_t>* MlasTestFixture<MlasQuantizeLinearTest<int16_t>>::mlas_tester(nullptr);
+template <>
+MlasQuantizeLinearTest<uint16_t>* MlasTestFixture<MlasQuantizeLinearTest<uint16_t>>::mlas_tester(nullptr);
 
 static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
   size_t count = 0;
   if (is_short_execute) {
     count += MlasDirectShortExecuteTests<MlasQuantizeLinearTest<int8_t>>::RegisterShortExecute();
     count += MlasDirectShortExecuteTests<MlasQuantizeLinearTest<uint8_t>>::RegisterShortExecute();
+    count += MlasDirectShortExecuteTests<MlasQuantizeLinearTest<int16_t>>::RegisterShortExecute();
+    count += MlasDirectShortExecuteTests<MlasQuantizeLinearTest<uint16_t>>::RegisterShortExecute();
   }
   return count;
 });
diff --git a/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc b/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc
index d0ce4898a472c..feff607703341 100644
--- a/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc
+++ b/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc
@@ -20,15 +20,17 @@ struct GraphConfig {
   bool has_subgraph_consumer{false};
 };
 
-auto GetGraphBuilder(const GraphConfig& config, bool use_ms_domain_qdq_ops) {
+template <typename QuantType>
+std::function<void(ModelTestBuilder&)> GetGraphBuilder(const GraphConfig& config, bool use_ms_domain_qdq_ops) {
   return [config, use_ms_domain_qdq_ops](ModelTestBuilder& builder) {
     const auto input_shape = std::vector<int64_t>{1, 2, 4};
     constexpr float scale = 0.5f;
-    constexpr uint8_t zero_point = 0;
+    constexpr QuantType zero_point = 0;
 
-    auto* dq_input = builder.MakeInput<uint8_t>(input_shape, uint8_t{0}, uint8_t{255});
+    auto* dq_input = builder.MakeInput<QuantType>(input_shape, std::numeric_limits<QuantType>::min(),
+                                                  std::numeric_limits<QuantType>::max());
     auto* dq_output = config.has_graph_output ?
builder.MakeOutput() : builder.MakeIntermediate();
-    builder.AddDequantizeLinearNode(dq_input, scale, zero_point, dq_output, use_ms_domain_qdq_ops);
+    builder.AddDequantizeLinearNode<QuantType>(dq_input, scale, zero_point, dq_output, use_ms_domain_qdq_ops);
 
     for (size_t i = 0; i < config.num_explicit_consumer_nodes; ++i) {
       // use Concat for the explicit consumer node as it supports a variadic number of inputs
@@ -71,10 +73,12 @@ auto GetGraphBuilder(const GraphConfig& config, bool use_ms_domain_qdq_ops) {
 }
 
 void RunEnsureUniqueDQForNodeUnitTest(const GraphConfig& config, int expected_dq_count) {
-  auto run_tests = [config, expected_dq_count](bool use_ms_domain_qdq_ops) {
+  auto run_tests = [config, expected_dq_count](bool use_ms_domain_qdq_ops, bool use_16bit_qdq_ops) {
     constexpr int opset_version = 12;
     const char* dequantize_linear_key = use_ms_domain_qdq_ops ? "com.microsoft.DequantizeLinear" : "DequantizeLinear";
-    std::function<void(ModelTestBuilder&)> graph_builder_fn = GetGraphBuilder(config, use_ms_domain_qdq_ops);
+    std::function<void(ModelTestBuilder&)> graph_builder_fn = use_16bit_qdq_ops
+                                                                  ? GetGraphBuilder<uint16_t>(config, use_ms_domain_qdq_ops)
+                                                                  : GetGraphBuilder<uint8_t>(config, use_ms_domain_qdq_ops);
 
     {
       SCOPED_TRACE("test with standalone transformer");
@@ -117,9 +121,10 @@ void RunEnsureUniqueDQForNodeUnitTest(const GraphConfig& config, int expected_dq
     }
   };
 
-  run_tests(false);
+  run_tests(false, false);
 #if !defined(DISABLE_CONTRIB_OPS)
-  run_tests(true);  // Use contrib QDQ ops.
+  run_tests(true, false);  // Use contrib QDQ ops.
+  run_tests(true, true);   // Use 16-bit contrib QDQ ops.
 #endif
 }
 
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index 553fcca92aa78..dce1f2d40e8b9 100755
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -83,6 +83,7 @@
 #include "test/util/include/test_utils.h"
 #include "core/optimizer/pre_shape_node_elimination.h"
 #include "core/optimizer/double_qdq_pairs_remover.h"
+#include "core/optimizer/qdq_transformer/qdq_util.h"
 #ifdef ENABLE_TRAINING
 #include "orttraining/core/optimizer/bitmask_dropout_replacement.h"
 #endif
@@ -155,44 +156,43 @@ TEST_F(GraphTransformationTests, IdentityWithSharedNodeArgNotEliminated) {
   ASSERT_TRUE(op_to_count["Add"] == 1);
 }
 
+// Runs a model to ensure that common subexpression elimination does not eliminate
+// DequantizeLinear nodes.
 TEST_F(GraphTransformationTests, DequantizeLinearNodeNotEliminated) {
-  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "qdq_with_multi_consumer_dq_nodes.fixed.onnx";
-  std::shared_ptr<Model> model;
-  ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
-  Graph& graph = model->MainGraph();
-  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
-  ASSERT_EQ(op_to_count["DequantizeLinear"], 25);
+  auto test_case = [](const ORTCHAR_T* model_uri,
+                      bool use_contrib_qdq,
+                      const logging::Logger& logger) {
+    const char* dq_key = use_contrib_qdq ?
"com.microsoft.DequantizeLinear" : "DequantizeLinear"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count[dq_key], 25); - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), - TransformerLevel::Level1)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), + TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, logger)); - // CommonSubexpressionElimination should skip the DequantizeLinear nodes - op_to_count = CountOpsInGraph(graph); - ASSERT_EQ(op_to_count["DequantizeLinear"], 25); -} + // CommonSubexpressionElimination should skip the DequantizeLinear nodes + op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count[dq_key], 25); + }; + test_case(MODEL_FOLDER "qdq_with_multi_consumer_dq_nodes.fixed.onnx", + false, // use_contrib_qdq + *logger_); #if !defined(DISABLE_CONTRIB_OPS) -// Test that com.microsoft.DequantizeLinear is not eliminated in CommonSubexpressionElimination -TEST_F(GraphTransformationTests, MsDomainDequantizeLinearNodeNotEliminated) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "qdq_with_multi_consumer_dq_nodes.fixed.qdq_contrib.onnx"; - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); - Graph& graph = model->MainGraph(); - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_EQ(op_to_count["com.microsoft.DequantizeLinear"], 25); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), - TransformerLevel::Level1)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); - - // CommonSubexpressionElimination should skip the DequantizeLinear nodes - op_to_count = CountOpsInGraph(graph); - ASSERT_EQ(op_to_count["com.microsoft.DequantizeLinear"], 25); -} + // Test with 8-bit com.microsoft.DequantizeLinear + test_case(MODEL_FOLDER "qdq_with_multi_consumer_dq_nodes.fixed.qdq_contrib.onnx", + true, // use_contrib_qdq + *logger_); + // Test with 16-bit com.microsoft.DequantizeLinear + test_case(MODEL_FOLDER "qdq_with_multi_consumer_dq_nodes.fixed.qdq16_contrib.onnx", + true, // use_contrib_qdq + *logger_); #endif // !defined(DISABLE_CONTRIB_OPS) +} TEST_F(GraphTransformationTests, IdentityInputIsGraphOutputNotEliminated) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "scan9_sum.onnx"; @@ -836,158 +836,120 @@ static void VerifyConstantFoldingWithDequantizeLinear(const std::unordered_map model; - ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); - Graph& graph = model->MainGraph(); - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["QuantizeLinear"] == 1); - ASSERT_TRUE(op_to_count["DequantizeLinear"] == 3); - ASSERT_TRUE(op_to_count["Conv"] == 1); - - std::unordered_map expected_op_counts = {{"QuantizeLinear", 1}, - {"DequantizeLinear", 3}, - {"Conv", 1}}; - - SessionOptions session_options; - // Check DequantizeLinear aren't constant folded for default setting. 
-  VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_);
-
-  // set kOrtSessionOptionsDisableQuantQDQ to enable it explicitly
-  ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "0"));
-  VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_);
+  auto test_case = [](const ORTCHAR_T* model_uri,
+                      bool use_contrib_qdq,
+                      const logging::Logger& logger) {
+    const char* q_key = use_contrib_qdq ? "com.microsoft.QuantizeLinear" : "QuantizeLinear";
+    const char* dq_key = use_contrib_qdq ? "com.microsoft.DequantizeLinear" : "DequantizeLinear";
 
-  // set SessionOptionsEnableQuantQDQ to disable it
-  expected_op_counts["DequantizeLinear"] = 1;
-  ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1"));
-  VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_);
-}
+    std::shared_ptr<Model> model;
+    ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, logger));
+    Graph& graph = model->MainGraph();
+    std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+    ASSERT_TRUE(op_to_count[q_key] == 1);
+    ASSERT_TRUE(op_to_count[dq_key] == 3);
+    ASSERT_TRUE(op_to_count["Conv"] == 1);
 
-#if !defined(DISABLE_CONTRIB_OPS)
-// Test constant folding with a com.microsoft.DequantizeLinear node
-TEST_F(GraphTransformationTests, ConstantFoldingWithMsDomainDequantizeLinear) {
-  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/constant_folding_dequantizelinear.qdq_contrib.onnx";
-  std::shared_ptr<Model> model;
-  ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
-  Graph& graph = model->MainGraph();
-  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
-  ASSERT_EQ(op_to_count["com.microsoft.QuantizeLinear"], 1);
-  ASSERT_EQ(op_to_count["com.microsoft.DequantizeLinear"], 3);
-  ASSERT_EQ(op_to_count["Conv"], 1);
+    std::unordered_map<std::string, int> expected_op_counts = {{q_key, 1},
+                                                               {dq_key, 3},
+                                                               {"Conv", 1}};
 
-  std::unordered_map<std::string, int> expected_op_counts = {{"com.microsoft.QuantizeLinear", 1},
-                                                             {"com.microsoft.DequantizeLinear", 3},
-                                                             {"Conv", 1}};
+    SessionOptions session_options;
+    // Check that DequantizeLinear nodes aren't constant folded with default settings.
- VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); + // set kOrtSessionOptionsDisableQuantQDQ to enable it explicitly + ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "0")); + VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, logger); - // set kOrtSessionOptionsDisableQuantQDQ to enable it explicitly - ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "0")); - VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); + // set SessionOptionsEnableQuantQDQ to disable it + expected_op_counts[dq_key] = 1; + ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1")); + VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, logger); + }; - // set SessionOptionsEnableQuantQDQ to disable it - expected_op_counts["com.microsoft.DequantizeLinear"] = 1; - ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1")); - VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); -} + test_case(MODEL_FOLDER "fusion/constant_folding_dequantizelinear.onnx", + false, *logger_); +#if !defined(DISABLE_CONTRIB_OPS) + // Test with 8-bit contrib QDQ ops + test_case(MODEL_FOLDER "fusion/constant_folding_dequantizelinear.qdq_contrib.onnx", + true, *logger_); + // Test with 16-bit contrib QDQ ops + test_case(MODEL_FOLDER "fusion/constant_folding_dequantizelinear.qdq16_contrib.onnx", + true, *logger_); #endif // !defined(DISABLE_CONTRIB_OPS) +} // model with 2 QDQ node units that can be constant folded as they are simple DQ -> Node -> Q where DQ and Node have // single consumer and do not produce graph outputs. Node is deterministic. // there are also other DQ nodes that should be ignored. TEST_F(GraphTransformationTests, ConstantFoldingQDQNodeUnit) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.onnx"; - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); - Graph& graph = model->MainGraph(); - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["QuantizeLinear"] == 3); - ASSERT_TRUE(op_to_count["DequantizeLinear"] == 4); - ASSERT_TRUE(op_to_count["Unsqueeze"] == 1); - ASSERT_TRUE(op_to_count["Transpose"] == 1); + auto test_case = [](const ORTCHAR_T* model_uri, bool use_contrib_qdq, const logging::Logger& logger) { + const char* q_key = use_contrib_qdq ? "com.microsoft.QuantizeLinear" : "QuantizeLinear"; + const char* dq_key = use_contrib_qdq ? 
"com.microsoft.DequantizeLinear" : "DequantizeLinear"; - SessionOptions session_options; - - // 2 QDQ node units should be constant folded and go away - std::unordered_map expected_op_counts = {{"QuantizeLinear", 1}, - {"DequantizeLinear", 2}, - {"Transpose", 0}, - {"Unsqueeze", 0}}; - - VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); -} + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count[q_key] == 3); + ASSERT_TRUE(op_to_count[dq_key] == 4); + ASSERT_TRUE(op_to_count["Unsqueeze"] == 1); + ASSERT_TRUE(op_to_count["Transpose"] == 1); -#if !defined(DISABLE_CONTRIB_OPS) -// model with 2 (com.microsoft) QDQ node units that can be constant folded as they are simple DQ -> Node -> Q where -// DQ and Node have single consumer and do not produce graph outputs. Node is deterministic. -// there are also other DQ nodes that should be ignored. -TEST_F(GraphTransformationTests, ConstantFoldingMsDomainQDQNodeUnit) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.qdq_contrib.onnx"; - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); - Graph& graph = model->MainGraph(); - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_EQ(op_to_count["com.microsoft.QuantizeLinear"], 3); - ASSERT_EQ(op_to_count["com.microsoft.DequantizeLinear"], 4); - ASSERT_EQ(op_to_count["Unsqueeze"], 1); - ASSERT_EQ(op_to_count["Transpose"], 1); + SessionOptions session_options; - SessionOptions session_options; + // 2 QDQ node units should be constant folded and go away + std::unordered_map expected_op_counts = {{q_key, 1}, + {dq_key, 2}, + {"Transpose", 0}, + {"Unsqueeze", 0}}; - // 2 QDQ node units should be constant folded and go away - std::unordered_map expected_op_counts = {{"com.microsoft.QuantizeLinear", 1}, - {"com.microsoft.DequantizeLinear", 2}, - {"Transpose", 0}, - {"Unsqueeze", 0}}; + VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, logger); + }; - VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); -} + test_case(MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.onnx", false, *logger_); +#if !defined(DISABLE_CONTRIB_OPS) + // Test with 8-bit com.microsoft.Q/DQ + test_case(MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.qdq_contrib.onnx", true, *logger_); + // Test with 16-bit com.microsoft.Q/DQ + test_case(MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.qdq16_contrib.onnx", true, *logger_); #endif // !defined(DISABLE_CONTRIB_OPS) +} // Simple QDQ Node Unit but shouldn't be constant folded as the node in the middle produces a graph output TEST_F(GraphTransformationTests, ConstantFoldingQDQNodeUnitGraphOutput) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.graph_output.onnx"; - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); - Graph& graph = model->MainGraph(); - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["QuantizeLinear"] == 2); - ASSERT_TRUE(op_to_count["DequantizeLinear"] == 3); - ASSERT_TRUE(op_to_count["Unsqueeze"] == 1); + auto test_case = [](const ORTCHAR_T* model_uri, bool use_contrib_qdq, const logging::Logger& logger) { + const char* q_key = use_contrib_qdq ? 
"com.microsoft.QuantizeLinear" : "QuantizeLinear"; + const char* dq_key = use_contrib_qdq ? "com.microsoft.DequantizeLinear" : "DequantizeLinear"; - std::unordered_map expected_op_counts = {{"QuantizeLinear", 2}, - {"DequantizeLinear", 3}, - {"Unsqueeze", 1}}; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count[q_key] == 2); + ASSERT_TRUE(op_to_count[dq_key] == 3); + ASSERT_TRUE(op_to_count["Unsqueeze"] == 1); - SessionOptions session_options; - VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); -} + std::unordered_map expected_op_counts = {{q_key, 2}, + {dq_key, 3}, + {"Unsqueeze", 1}}; -#if !defined(DISABLE_CONTRIB_OPS) -// Simple (com.microsoft) QDQ Node Unit but shouldn't be constant folded as the node in the middle produces a -// graph output -TEST_F(GraphTransformationTests, ConstantFoldingMsDomainQDQNodeUnitGraphOutput) { - constexpr const ORTCHAR_T* model_uri = - MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.graph_output.qdq_contrib.onnx"; - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); - Graph& graph = model->MainGraph(); - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_EQ(op_to_count["com.microsoft.QuantizeLinear"], 2); - ASSERT_EQ(op_to_count["com.microsoft.DequantizeLinear"], 3); - ASSERT_EQ(op_to_count["Unsqueeze"], 1); + SessionOptions session_options; + VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, logger); + }; - std::unordered_map expected_op_counts = {{"com.microsoft.QuantizeLinear", 2}, - {"com.microsoft.DequantizeLinear", 3}, - {"Unsqueeze", 1}}; + test_case(MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.graph_output.onnx", false, *logger_); +#if !defined(DISABLE_CONTRIB_OPS) + // Test with 8-bit contrib QDQ ops + test_case(MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.graph_output.qdq_contrib.onnx", true, *logger_); - SessionOptions session_options; - VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); -} + // Test with 16-bit contrib QDQ ops + test_case(MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.graph_output.qdq16_contrib.onnx", true, *logger_); #endif // !defined(DISABLE_CONTRIB_OPS) +} TEST_F(GraphTransformationTests, ConstantFolding_RemoveDanglingInputNodesToConstantFoldedNode) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/constant_folding_remove_dangling_inputs.onnx"; @@ -3898,12 +3860,12 @@ TEST_F(GraphTransformationTests, DoublQDQRemover_RemoveDupQDQ) { std::string zp_name_after_reshape_node; for (auto& node : graph.Nodes()) { if (node.Name() == "dq_2") { - dq_scale_name_before_reshape_node = node.InputDefs()[InputIndex::SCALE_ID]->Name(); - zp_name_before_reshape_node = node.InputDefs()[InputIndex::ZERO_POINT_ID]->Name(); + dq_scale_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name(); + zp_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name(); } if (node.Name() == "q_3") { - dq_scale_name_after_reshape_node = node.InputDefs()[InputIndex::SCALE_ID]->Name(); - zp_name_after_reshape_node = node.InputDefs()[InputIndex::ZERO_POINT_ID]->Name(); + dq_scale_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name(); + zp_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name(); } } 
EXPECT_EQ(dq_scale_name_before_reshape_node.empty(), false); diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.h b/onnxruntime/test/optimizer/graph_transform_test_builder.h index 743faee3ee2a5..63577131480c6 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.h +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.h @@ -39,9 +39,21 @@ namespace test { template struct IsTypeQuantLinearCompatible : utils::IsByteType {}; +template <> +struct IsTypeQuantLinearCompatible : std::true_type {}; + +template <> +struct IsTypeQuantLinearCompatible : std::true_type {}; + template struct IsTypeDequantLinearCompatible : utils::IsByteType {}; +template <> +struct IsTypeDequantLinearCompatible : std::true_type {}; + +template <> +struct IsTypeDequantLinearCompatible : std::true_type {}; + template <> struct IsTypeDequantLinearCompatible : std::true_type {}; diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 0dfeb599d0ae3..a438a61cb9b36 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -891,37 +891,139 @@ TEST(QDQTransformerTests, Gemm_S8S8U8) { QDQTransformerGemmTests(); } +// Runs a test case that checks if Q/DQ nodes are dropped from DQ -> Gather -> Q. +template +static void RunGatherDropQDQTestCase(const std::vector& input1_shape, + const std::vector& weights_shape, + bool use_contrib_qdq = false) { + auto build_test_case = [input1_shape, weights_shape, use_contrib_qdq](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, 0, weights_shape[0] - 1); + auto* output_arg = builder.MakeOutput(); + + // add Gather + auto* weight = builder.MakeInitializer(weights_shape, std::numeric_limits::min(), + std::numeric_limits::max()); + auto* dq_w_output = builder.MakeIntermediate(); + auto* gather_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(weight, .003f, 1, dq_w_output, use_contrib_qdq); + builder.AddNode("Gather", {dq_w_output, input1_arg}, {gather_output}); + + // add Q + builder.AddQuantizeLinearNode(gather_output, .003f, 1, output_arg, use_contrib_qdq); + }; + + auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["Gather"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2); +} + +// Checks that Q/DQ nodes are dropped from DQ -> Gather -> Q. Uses 8-bit and 16-bit Q/DQ ops. 
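// ---- Review aside (not part of the patch) ----------------------------------
// The refactored tests in this file count ops via a small helper pair,
// QDQOpKeys/GetQDQOpKeys. A minimal sketch of that helper as inferred from its
// usage here; the real definition lives in the shared test utilities and may
// differ in detail:
struct QDQOpKeys {
  const char* quantize_linear;
  const char* dequantize_linear;
};

inline QDQOpKeys GetQDQOpKeys(bool use_contrib_qdq) {
  // The contrib (com.microsoft) Q/DQ ops add 16-bit support; the ONNX-domain
  // ops exercised by these tests are 8-bit.
  return use_contrib_qdq
             ? QDQOpKeys{"com.microsoft.QuantizeLinear", "com.microsoft.DequantizeLinear"}
             : QDQOpKeys{"QuantizeLinear", "DequantizeLinear"};
}
// ---- End aside; the Gather test follows. -----------------------------------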
 TEST(QDQTransformerTests, Gather) {
-  auto test_case = [&](const std::vector<int64_t>& input1_shape, const std::vector<int64_t>& weights_shape,
-                       bool use_contrib_qdq = false) {
-    auto build_test_case = [&](ModelTestBuilder& builder) {
-      auto* input1_arg = builder.MakeInput<int64_t>(input1_shape, 0, weights_shape[0] - 1);
-      auto* output_arg = builder.MakeOutput();
+  RunGatherDropQDQTestCase<int8_t>({12, 37}, {24, 12});
+  RunGatherDropQDQTestCase<int8_t>({12, 37}, {24, 12}, true);   // Use com.microsoft QDQ ops
+  RunGatherDropQDQTestCase<int16_t>({12, 37}, {24, 12}, true);  // Use int16 com.microsoft QDQ ops
+}
 
-      // add Gather
-      auto* weight = builder.MakeInitializer<int8_t>(weights_shape, -128, 127);
-      auto* dq_w_output = builder.MakeIntermediate();
-      auto* gather_output = builder.MakeIntermediate();
-      builder.AddDequantizeLinearNode(weight, .003f, 1, dq_w_output, use_contrib_qdq);
-      builder.AddNode("Gather", {dq_w_output, input1_arg}, {gather_output});
+// Runs a test case that checks if Q/DQ nodes are dropped from DQ -> Reshape -> Q.
+template <typename QuantType>
+static void RunReshapeDropQDQTestCase(const std::vector<int64_t>& input_shape,
+                                      const std::vector<int64_t>& new_shape,
+                                      bool use_contrib_qdq = false) {
+  auto build_test_case = [input_shape, new_shape, use_contrib_qdq](ModelTestBuilder& builder) {
+    constexpr QuantType qmin = std::numeric_limits<QuantType>::min();
+    constexpr QuantType qmax = std::numeric_limits<QuantType>::max();
+
+    auto* input_arg = builder.MakeInput<QuantType>(input_shape, qmin, qmax);
+    auto* output_arg = builder.MakeOutput();
+    QuantType zero_point = 1 + (qmax + qmin) / 2;
+
+    // Add Reshape node
+    auto* new_shape_arg = builder.Make1DInitializer(new_shape);
+    auto* input_arg_dq = builder.MakeIntermediate();
+    auto* reshape_output = builder.MakeIntermediate();
+    builder.AddDequantizeLinearNode(input_arg, .003f, zero_point, input_arg_dq, use_contrib_qdq);
+    builder.AddNode("Reshape", {input_arg_dq, new_shape_arg}, {reshape_output});
+
+    // add Q
+    builder.AddQuantizeLinearNode(reshape_output, .003f, zero_point, output_arg, use_contrib_qdq);
+  };
 
-      // add Q
-      builder.AddQuantizeLinearNode(gather_output, .003f, 1, output_arg, use_contrib_qdq);
-    };
+  auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) {
+    auto op_to_count = CountOpsInGraph(session.GetGraph());
+    const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
+    EXPECT_EQ(op_to_count["Reshape"], 1);
+    EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
+    EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
+  };
 
-    auto check_graph = [&](InferenceSessionWrapper& session) {
-      auto op_to_count = CountOpsInGraph(session.GetGraph());
-      const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
-      EXPECT_EQ(op_to_count["Gather"], 1);
-      EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
-      EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
-    };
+  TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2);
+}
 
-    TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2);
+// Checks that Q/DQ nodes are dropped from DQ -> Reshape -> Q. Uses 8-bit and 16-bit Q/DQ ops.
+TEST(QDQTransformerTests, ReshapeDropQDQ) {
+  RunReshapeDropQDQTestCase<int8_t>({1, 3, 2, 2}, {1, 12});
+  RunReshapeDropQDQTestCase<int8_t>({1, 3, 2, 2}, {1, 12}, true);    // Use com.microsoft QDQ ops
+  RunReshapeDropQDQTestCase<int16_t>({1, 3, 2, 2}, {1, 12}, true);   // Use int16 com.microsoft QDQ ops
+  RunReshapeDropQDQTestCase<uint16_t>({1, 3, 2, 2}, {1, 12}, true);  // Use uint16 com.microsoft QDQ ops
+}
+
+// Runs a test case that checks if Q/DQ nodes are dropped from DQ -> (Un)Squeeze -> Q.
+template <typename QuantType>
+static void RunSqueezeUnsqueezeDropQDQTestCase(const std::string& squeeze_type,
+                                               const std::vector<int64_t>& input_shape,
+                                               const std::vector<int64_t>& axes,
+                                               bool use_contrib_qdq = false) {
+  auto build_test_case = [squeeze_type, input_shape, axes, use_contrib_qdq](ModelTestBuilder& builder) {
+    constexpr QuantType qmin = std::numeric_limits<QuantType>::min();
+    constexpr QuantType qmax = std::numeric_limits<QuantType>::max();
+
+    auto* input_arg = builder.MakeInput<QuantType>(input_shape, qmin, qmax);
+    auto* output_arg = builder.MakeOutput();
+    QuantType zero_point = 1 + (qmax + qmin) / 2;
+
+    // Add Squeeze node
+    auto* axes_arg = builder.Make1DInitializer(axes);
+    auto* input_arg_dq = builder.MakeIntermediate();
+    auto* xsqueeze_output = builder.MakeIntermediate();
+    builder.AddDequantizeLinearNode(input_arg, .003f, zero_point, input_arg_dq, use_contrib_qdq);
+    builder.AddNode(squeeze_type, {input_arg_dq, axes_arg}, {xsqueeze_output});
+
+    // add Q
+    builder.AddQuantizeLinearNode(xsqueeze_output, .003f, zero_point, output_arg, use_contrib_qdq);
   };
 
-  test_case({12, 37}, {24, 12});
-  test_case({12, 37}, {24, 12}, true);  // Use com.microsoft QDQ ops
+  auto check_graph = [squeeze_type, use_contrib_qdq](InferenceSessionWrapper& session) {
+    auto op_to_count = CountOpsInGraph(session.GetGraph());
+    const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
+    EXPECT_EQ(op_to_count[squeeze_type], 1);
+    EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
+    EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
+  };
+
+  TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2,
+                    13 /* opset_version */);
+}
+
+// Checks that Q/DQ nodes are dropped from DQ -> Squeeze -> Q. Uses 8-bit and 16-bit Q/DQ ops.
+TEST(QDQTransformerTests, SqueezeDropQDQ) {
+  RunSqueezeUnsqueezeDropQDQTestCase<int8_t>("Squeeze", {1, 3, 2, 2}, {0});
+  RunSqueezeUnsqueezeDropQDQTestCase<int8_t>("Squeeze", {1, 3, 2, 2}, {0}, true);    // Use MS domain QDQ ops
+  RunSqueezeUnsqueezeDropQDQTestCase<int16_t>("Squeeze", {1, 3, 2, 2}, {0}, true);   // Use int16 MS domain QDQ ops
+  RunSqueezeUnsqueezeDropQDQTestCase<uint16_t>("Squeeze", {1, 3, 2, 2}, {0}, true);  // Use uint16 MS domain QDQ ops
+}
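// ---- Review aside (not part of the patch) ----------------------------------
// The helpers above pick zero_point = 1 + (qmax + qmin) / 2, i.e. one past the
// midpoint of the quantized range, giving a nonzero zero point that exercises
// asymmetric quantization for both signed and unsigned types. What that
// evaluates to for the types under test (integer division truncates toward
// zero); a standalone, compilable check:

#include <cstdint>
#include <limits>

template <typename QuantType>
constexpr int MidpointPlusOne() {
  constexpr int qmin = std::numeric_limits<QuantType>::min();
  constexpr int qmax = std::numeric_limits<QuantType>::max();
  return 1 + (qmax + qmin) / 2;
}

static_assert(MidpointPlusOne<uint8_t>() == 128);     // 1 + (255 + 0) / 2
static_assert(MidpointPlusOne<int8_t>() == 1);        // 1 + (127 - 128) / 2
static_assert(MidpointPlusOne<uint16_t>() == 32768);  // 1 + (65535 + 0) / 2
static_assert(MidpointPlusOne<int16_t>() == 1);       // 1 + (32767 - 32768) / 2
// ---- End aside --------------------------------------------------------------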
+
+// Checks that Q/DQ nodes are dropped from DQ -> Unsqueeze -> Q. Uses 8-bit and 16-bit Q/DQ ops.
+TEST(QDQTransformerTests, UnsqueezeDropQDQ) {
+  RunSqueezeUnsqueezeDropQDQTestCase<int8_t>("Unsqueeze", {1, 3, 2, 2}, {0});
+  RunSqueezeUnsqueezeDropQDQTestCase<int8_t>("Unsqueeze", {1, 3, 2, 2}, {0}, true);    // Use MS domain QDQ ops
+  RunSqueezeUnsqueezeDropQDQTestCase<int16_t>("Unsqueeze", {1, 3, 2, 2}, {0}, true);   // Use int16 MS domain QDQ ops
+  RunSqueezeUnsqueezeDropQDQTestCase<uint16_t>("Unsqueeze", {1, 3, 2, 2}, {0}, true);  // Use uint16 MS domain QDQ ops
 }
 
 TEST(QDQTransformerTests, DoubleQDQ) {
@@ -1066,52 +1168,69 @@ TEST(QDQTransformerTests, DoubleQDQ) {
                    bad_float_point, good_float_point_2, true);  // Use com.microsoft QDQ ops
 }
 
-TEST(QDQTransformerTests, DoubleQDQ_Without_Last_Node_Being_Output) {
-  auto test_case = [&](int output_index, int expected_Q_count, int expected_DQ_count,
-                       bool use_contrib_qdq = false) {
-    auto graph = [&](InferenceSessionWrapper& session) {
-      auto op_to_count = CountOpsInGraph(session.GetGraph());
-      const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
-      EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], expected_Q_count);
-      EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], expected_DQ_count);
-    };
-    TransformerTester(
-        BuildDoubleQDQWithoutLastOutput(output_index, use_contrib_qdq),
-        graph,
-        TransformerLevel::Default,
-        TransformerLevel::Level1);
+template <typename QuantType>
+static void RunDoubleQDQWithoutLastNodeBeingOutput(int output_index, int expected_Q_count, int expected_DQ_count,
+                                                   bool use_contrib_qdq = false) {
+  auto graph = [&](InferenceSessionWrapper& session) {
+    auto op_to_count = CountOpsInGraph(session.GetGraph());
+    const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
+    EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], expected_Q_count);
+    EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], expected_DQ_count);
   };
+  TransformerTester(
+      BuildDoubleQDQWithoutLastOutput<QuantType>(output_index, use_contrib_qdq),
+      graph,
+      TransformerLevel::Default,
+      TransformerLevel::Level1);
+}
+
+TEST(QDQTransformerTests, DoubleQDQ_Without_Last_Node_Being_Output) {
   constexpr bool use_contrib_qdq = true;  // For readability.
-  test_case(0, 2, 2);
-  test_case(0, 2, 2, use_contrib_qdq);
-  test_case(1, 2, 3);  // EnsureUniqueDQForNodeUnit will duplicate first DQ, so expect one more (3)
-  test_case(1, 2, 3, use_contrib_qdq);  // EnsureUniqueDQForNodeUnit will duplicate first DQ, so expect one more (3)
-  test_case(2, 2, 2);
-  test_case(2, 2, 2, use_contrib_qdq);
-  test_case(3, 1, 1);
-  test_case(3, 1, 1, use_contrib_qdq);
+  RunDoubleQDQWithoutLastNodeBeingOutput<uint8_t>(0, 2, 2);
+  RunDoubleQDQWithoutLastNodeBeingOutput<uint8_t>(0, 2, 2, use_contrib_qdq);
+  RunDoubleQDQWithoutLastNodeBeingOutput<uint16_t>(0, 2, 2, use_contrib_qdq);
+  RunDoubleQDQWithoutLastNodeBeingOutput<int16_t>(0, 2, 2, use_contrib_qdq);
+
+  // EnsureUniqueDQForNodeUnit will duplicate the first DQ, so expect one more (3)
+  RunDoubleQDQWithoutLastNodeBeingOutput<uint8_t>(1, 2, 3);
+  RunDoubleQDQWithoutLastNodeBeingOutput<uint8_t>(1, 2, 3, use_contrib_qdq);
+  RunDoubleQDQWithoutLastNodeBeingOutput<uint16_t>(1, 2, 3, use_contrib_qdq);
+  RunDoubleQDQWithoutLastNodeBeingOutput<int16_t>(1, 2, 3, use_contrib_qdq);
+
+  RunDoubleQDQWithoutLastNodeBeingOutput<uint8_t>(2, 2, 2);
+  RunDoubleQDQWithoutLastNodeBeingOutput<uint8_t>(2, 2, 2, use_contrib_qdq);
+  RunDoubleQDQWithoutLastNodeBeingOutput<uint16_t>(2, 2, 2, use_contrib_qdq);
+
+  RunDoubleQDQWithoutLastNodeBeingOutput<uint8_t>(3, 1, 1);
+  RunDoubleQDQWithoutLastNodeBeingOutput<uint8_t>(3, 1, 1, use_contrib_qdq);
+  RunDoubleQDQWithoutLastNodeBeingOutput<uint16_t>(3, 1, 1, use_contrib_qdq);
+}
+
+// Runs a test that checks if DQ -> Split -> Q (many) is replaced with just Split.
+template +static void RunDropSplitQDQTestCase(const std::vector& input_shape, int64_t axis, + bool use_contrib_qdq = false) { + auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["Split"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + TransformerTester(BuildQDQSplitTestCase(input_shape, axis, use_contrib_qdq), + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + {12, 18, 19}); } -// Because split isn't one the supported ops, this will stay the same +// Test that DQ -> Split -> Q (many) is replaced with just Split for various quantization types. TEST(QDQTransformerTests, Split) { - auto test_case = [&](const std::vector& input_shape, const int64_t& axis, - bool use_contrib_qdq = false) { - auto check_graph = [&](InferenceSessionWrapper& session) { - auto op_to_count = CountOpsInGraph(session.GetGraph()); - const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); - EXPECT_EQ(op_to_count["Split"], 1); - EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); - EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); - }; - TransformerTester(BuildQDQSplitTestCase(input_shape, axis, use_contrib_qdq), - check_graph, - TransformerLevel::Level1, - TransformerLevel::Level2, - {12, 18, 19}); - }; - test_case({6, 18, 54}, 0); - test_case({6, 18, 54}, 0, true); // Use com.microsoft QDQ ops + RunDropSplitQDQTestCase({6, 18, 54}, 0); + RunDropSplitQDQTestCase({6, 18, 54}, 0, true); // Use com.microsoft int8 QDQ ops + RunDropSplitQDQTestCase({6, 18, 54}, 0, true); // Use com.microsoft int16 QDQ ops + RunDropSplitQDQTestCase({6, 18, 54}, 0, true); // Use com.microsoft uint16 QDQ ops } // Because split isn't one the supported ops, this will stay the same @@ -1174,59 +1293,66 @@ TEST(QDQTransformerTests, Where) { test_case({1}, {1}, {1}, true /*use_contrib_qdq*/); } -TEST(QDQTransformerTests, Transpose) { - auto test_case = [&](const std::vector& input_shape, const std::vector& perms, - bool use_contrib_qdq = false) { - auto check_graph = [&](InferenceSessionWrapper& session) { - auto op_to_count = CountOpsInGraph(session.GetGraph()); - const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); - EXPECT_EQ(op_to_count["Transpose"], 1); - EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); - EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); - }; - - TransformerTester(BuildQDQTransposeTestCase(input_shape, perms, use_contrib_qdq), - check_graph, - TransformerLevel::Level1, - TransformerLevel::Level2); +template +static void RunDropQDQTransposeTestCase(const std::vector& input_shape, const std::vector& perms, + bool use_contrib_qdq = false) { + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["Transpose"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); }; - test_case({2, 13, 12, 37}, {0, 3, 1, 2}); - test_case({2, 13, 12, 37}, {0, 3, 1, 2}, true /*use_contrib_qdq*/); + TransformerTester(BuildQDQTransposeTestCase(input_shape, perms, use_contrib_qdq), + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2); } -TEST(QDQTransformerTests, Transpose_No_Fusion) { - auto test_case = [&](const std::vector& input1_shape, const 
std::vector& perms, - bool use_contrib_qdq = false) { - auto build_test_case = [&](ModelTestBuilder& builder) { - auto* input1_arg = builder.MakeInput(input1_shape, -128, 127); - auto* output_arg = builder.MakeOutput(); - - // add DQ - auto* dq_output = builder.MakeIntermediate(); - builder.AddDequantizeLinearNode(input1_arg, .003f, 1, dq_output, use_contrib_qdq); - - // add Transpose - auto* transpose_output = builder.MakeOutput(); // transpose output is graph output - Node& transpose_node = builder.AddNode("Transpose", {dq_output}, {transpose_output}); - transpose_node.AddAttribute("perm", perms); - - // add Q - builder.AddQuantizeLinearNode(transpose_output, .003f, 1, output_arg, use_contrib_qdq); - }; - - auto check_graph = [&](InferenceSessionWrapper& session) { - auto op_to_count = CountOpsInGraph(session.GetGraph()); - const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); - EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 1); - EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); - }; +TEST(QDQTransformerTests, Transpose) { + RunDropQDQTransposeTestCase({2, 13, 12, 37}, {0, 3, 1, 2}); + RunDropQDQTransposeTestCase({2, 13, 12, 37}, {0, 3, 1, 2}, true /*use_contrib_qdq*/); + RunDropQDQTransposeTestCase({2, 13, 12, 37}, {0, 3, 1, 2}, true /*use_contrib_qdq*/); + RunDropQDQTransposeTestCase({2, 13, 12, 37}, {0, 3, 1, 2}, true /*use_contrib_qdq*/); +} + +template +static void RunQDQTransposeNoFusionTestCase(const std::vector& input1_shape, const std::vector& perms, + bool use_contrib_qdq = false) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, std::numeric_limits::min(), + std::numeric_limits::max()); + auto* output_arg = builder.MakeOutput(); + + // add DQ + auto* dq_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input1_arg, .003f, 1, dq_output, use_contrib_qdq); + + // add Transpose + auto* transpose_output = builder.MakeOutput(); // transpose output is graph output + Node& transpose_node = builder.AddNode("Transpose", {dq_output}, {transpose_output}); + transpose_node.AddAttribute("perm", perms); + + // add Q + builder.AddQuantizeLinearNode(transpose_output, .003f, 1, output_arg, use_contrib_qdq); + }; - TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2); + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 1); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); }; - test_case({2, 13, 12, 37}, {0, 3, 1, 2}); - test_case({2, 13, 12, 37}, {0, 3, 1, 2}, true /*use_contrib_qdq*/); + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2); +} + +TEST(QDQTransformerTests, Transpose_No_Fusion) { + RunQDQTransposeNoFusionTestCase({2, 13, 12, 37}, {0, 3, 1, 2}); + RunQDQTransposeNoFusionTestCase({2, 13, 12, 37}, {0, 3, 1, 2}, true /*use_contrib_qdq*/); + RunQDQTransposeNoFusionTestCase({2, 13, 12, 37}, {0, 3, 1, 2}, true /*use_contrib_qdq*/); + RunQDQTransposeNoFusionTestCase({2, 13, 12, 37}, {0, 3, 1, 2}, true /*use_contrib_qdq*/); } TEST(QDQTransformerTests, Resize) { @@ -1376,50 +1502,59 @@ TEST(QDQTransformerTests, ResizeReshapeSqueezeUnsqueeze) { test_case({1, 2, 26, 42}, {4}, true /*use_contrib_qdq*/); } -TEST(QDQTransformerTests, ArgMax) { - auto test_case = [&](const std::vector& input_shape, - int axis, - int keepdims, - 
int select_last_index, - bool use_contrib_qdq) { - auto build_test_case = [&](ModelTestBuilder& builder) { - auto* input_arg = builder.MakeInput(input_shape, - std::numeric_limits::min(), - std::numeric_limits::max()); - auto* output_arg = builder.MakeOutput(); +// Runs a test case that checks if the DQ node is dropped from DQ -> Op (e.g., ArgMax). +template +static void RunArgMaxDropDQTestCase(const std::vector& input_shape, + int axis, + int keepdims, + int select_last_index, + bool use_contrib_qdq, + bool expect_drop_dq = true) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + auto* output_arg = builder.MakeOutput(); + + // add DQ + auto* dq_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input_arg, .003f, 1, dq_output, use_contrib_qdq); + + // add ArgMax + Node& argmax_node = builder.AddNode("ArgMax", {dq_output}, {output_arg}); + argmax_node.AddAttribute("axis", static_cast(axis)); + argmax_node.AddAttribute("keepdims", static_cast(keepdims)); + argmax_node.AddAttribute("select_last_index", static_cast(select_last_index)); + }; - // add DQ - auto* dq_output = builder.MakeIntermediate(); - builder.AddDequantizeLinearNode(input_arg, .003f, 1, dq_output, use_contrib_qdq); + auto check_graph = [use_contrib_qdq, expect_drop_dq](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["ArgMax"], 1); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], expect_drop_dq ? 0 : 1); + }; - // add ArgMax - Node& argmax_node = builder.AddNode("ArgMax", {dq_output}, {output_arg}); - argmax_node.AddAttribute("axis", static_cast(axis)); - argmax_node.AddAttribute("keepdims", static_cast(keepdims)); - argmax_node.AddAttribute("select_last_index", static_cast(select_last_index)); - }; + TransformerTester(build_test_case, check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + /* opset_version */ 13); + TransformerTester(build_test_case, check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + /* opset_version */ 19); +} - auto check_graph = [&](InferenceSessionWrapper& session) { - auto op_to_count = CountOpsInGraph(session.GetGraph()); - const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); - EXPECT_EQ(op_to_count["ArgMax"], 1); - EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); - }; +// Checks that the DQ node is dropped from DQ -> ArgMax. Uses 8-bit and 16-bit Q/DQ ops. +TEST(QDQTransformerTests, ArgMax) { + RunArgMaxDropDQTestCase({2, 13, 12, 37}, 1, 0, 0, false); + RunArgMaxDropDQTestCase({2, 13, 12, 37}, 1, 0, 0, true /*use_contrib_qdq*/); - TransformerTester(build_test_case, check_graph, - TransformerLevel::Level1, - TransformerLevel::Level2, - /* opset_version */ 13); - TransformerTester(build_test_case, check_graph, - TransformerLevel::Level1, - TransformerLevel::Level2, - /* opset_version */ 19); - }; + // Should *not* drop DQ for 16-bit DQ -> ArgMax (because ORT does not support 16-bit input types for ArgMax). 
+ RunArgMaxDropDQTestCase({2, 13, 12, 37}, 1, 0, 0, true /*use_contrib_qdq*/, false /*expect_drop_dq*/); + RunArgMaxDropDQTestCase({2, 13, 12, 37}, 1, 0, 0, true /*use_contrib_qdq*/, false /*expect_drop_dq*/); - test_case({2, 13, 12, 37}, 1, 0, 0, false /*use_contrib_qdq*/); - test_case({2, 13, 12, 37}, 1, 0, 0, true /*use_contrib_qdq*/); - test_case({2, 13, 12, 37}, 0, 1, 0, false /*use_contrib_qdq*/); - test_case({2, 13, 12, 37}, 0, 0, 1, false /*use_contrib_qdq*/); + RunArgMaxDropDQTestCase({2, 13, 12, 37}, 0, 1, 0, false); + RunArgMaxDropDQTestCase({2, 13, 12, 37}, 0, 0, 1, false); } TEST(QDQTransformerTests, QLinearMatMul) { diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc index e5aa36fc379f4..1f4c499985ad0 100644 --- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc +++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include #include @@ -8,6 +9,7 @@ #include "gmock/gmock.h" #include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" #include "core/framework/op_node_proto_helper.h" #include "core/framework/utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -3501,150 +3503,116 @@ TEST(TransposeOptimizerTests, TestWhere) { /*opset_version*/ {15, 18}); } -TEST(TransposeOptimizerTests, TestQuantizeLinearScalar) { - auto test_case = [&](const std::string& q_domain = "") { - auto build_test_case_1 = [&](ModelTestBuilder& builder) { - auto* input0_arg = MakeInput(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, 0.0, 1.0); - auto* input1_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {2.3f}); - auto* input2_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {10}); - auto* transpose_1_out_0 = builder.MakeIntermediate(); - auto* quantizelinear_1_out_0 = builder.MakeIntermediate(); - auto* transpose_2_out_0 = builder.MakeOutput(); - - auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0}); - transpose_1.AddAttribute("perm", std::vector{0, 3, 1, 2}); - builder.AddNode("QuantizeLinear", {transpose_1_out_0, input1_arg, input2_arg}, {quantizelinear_1_out_0}, - q_domain); - auto& transpose_2 = builder.AddNode("Transpose", {quantizelinear_1_out_0}, {transpose_2_out_0}); - transpose_2.AddAttribute("perm", std::vector{0, 2, 3, 1}); - }; +// Utility function that runs TransformerTester for the graph Transpose -> QuantizeLinear -> Transpose. +// Expects the Tranpose nodes to cancel. 
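// (Review aside, not part of the patch: the check lambdas in the refactored
// helpers below assert EstimateTransposeCost(graph) == 0, i.e. that no
// Transpose work remains after optimization -- the two Transposes wrapping the
// Q/DQ node must be cancelled by the optimizer for every quantized type and
// op domain tested.)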
+template <typename QuantType>
+static void RunQuantizeLinearTestCase(const std::optional<std::vector<int64_t>>& zp_input_shape,
+                                      const std::vector<int64_t>& zp_value_shape,
+                                      std::optional<ONNX_NAMESPACE::AttributeProto> axis,
+                                      const std::string& q_domain = "") {
+  auto build_test_case = [&](ModelTestBuilder& builder) {
+    constexpr QuantType qmin = std::numeric_limits<QuantType>::min();
+    constexpr QuantType qmax = std::numeric_limits<QuantType>::max();
 
-    auto check_optimized_graph_1 = [&](InferenceSessionWrapper& session) {
-      int transpose_cost = EstimateTransposeCost(session.GetGraph());
-      EXPECT_EQ(transpose_cost, 0);
-    };
+    auto* input0_arg = MakeInput<float>(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, 0.0, 1.0);
+
+    NodeArg* scale_arg = nullptr;
+    NodeArg* zero_point_arg = nullptr;
+
+    if (zp_value_shape.empty()) {  // Per-tensor quantization
+      QuantType zp = (qmax + qmin) / 2;
+      scale_arg = MakeInput<float>(builder, zp_input_shape, zp_value_shape, {0.05f});
+      zero_point_arg = MakeInput<QuantType>(builder, zp_input_shape, zp_value_shape, {zp});
+    } else {  // Per-axis quantization
+      scale_arg = MakeInput<float>(builder, zp_input_shape, zp_value_shape, 0.0f, 1.0f);
+      zero_point_arg = MakeInput<QuantType>(builder, zp_input_shape, zp_value_shape, qmin, qmax);
+    }
+    auto* transpose_1_out_0 = builder.MakeIntermediate();
+    auto* quantizelinear_1_out_0 = builder.MakeIntermediate();
+    auto* transpose_2_out_0 = builder.MakeOutput();
+
+    auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0});
+    transpose_1.AddAttribute("perm", std::vector<int64_t>{0, 3, 1, 2});
+    auto& quantizelinear_1 = builder.AddNode("QuantizeLinear", {transpose_1_out_0, scale_arg, zero_point_arg},
+                                             {quantizelinear_1_out_0}, q_domain);
 
-    TransformerTester(build_test_case_1,
-                      check_optimized_graph_1,
-                      TransformerLevel::Default,
-                      TransformerLevel::Level1,
-                      /*opset_version*/ {15, 18});
+    if (axis.has_value()) {
+      quantizelinear_1.AddAttributeProto(*axis);
+    }
+
+    auto& transpose_2 = builder.AddNode("Transpose", {quantizelinear_1_out_0}, {transpose_2_out_0});
+    transpose_2.AddAttribute("perm", std::vector<int64_t>{0, 2, 3, 1});
+  };
+
+  auto check_optimized_graph = [](InferenceSessionWrapper& session) {
+    int transpose_cost = EstimateTransposeCost(session.GetGraph());
+    EXPECT_EQ(transpose_cost, 0);
   };
 
-  test_case();
+  TransformerTester(build_test_case,
+                    check_optimized_graph,
+                    TransformerLevel::Default,
+                    TransformerLevel::Level1,
+                    /*opset_version*/ {15, 18});
+}
+
+TEST(TransposeOptimizerTests, TestQuantizeLinearScalar) {
+  std::optional<std::vector<int64_t>> zp_input_shape = std::vector<int64_t>{};
+  std::vector<int64_t> zp_value_shape{};
+  std::optional<ONNX_NAMESPACE::AttributeProto> empty_axis;  // No axis value.
+
+  RunQuantizeLinearTestCase<uint8_t>(zp_input_shape, zp_value_shape, empty_axis, kOnnxDomain);
+
 #if !defined(DISABLE_CONTRIB_OPS)
-  test_case(kMSDomain);  // Use com.microsoft.QuantizeLinear
+  // Use com.microsoft.QuantizeLinear op.
+  RunQuantizeLinearTestCase<uint8_t>(zp_input_shape, zp_value_shape, empty_axis, kMSDomain);
+  RunQuantizeLinearTestCase<uint16_t>(zp_input_shape, zp_value_shape, empty_axis, kMSDomain);
+  RunQuantizeLinearTestCase<int16_t>(zp_input_shape, zp_value_shape, empty_axis, kMSDomain);
 #endif
 }
 
 TEST(TransposeOptimizerTests, TestQuantizeLinearScalarIgnoreAxis) {
-  auto test_case = [&](const std::string& q_domain = "") {
-    auto build_test_case_1 = [&](ModelTestBuilder& builder) {
-      auto* input0_arg = MakeInput<float>(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, 0.0, 1.0);
-      auto* input1_arg = MakeInput<float>(builder, {std::vector<int64_t>{}}, std::vector<int64_t>{}, {2.3f});
-      auto* input2_arg = MakeInput<uint8_t>(builder, {std::vector<int64_t>{}}, std::vector<int64_t>{}, {10});
-      auto* transpose_1_out_0 = builder.MakeIntermediate();
-      auto* quantizelinear_1_out_0 = builder.MakeIntermediate();
-      auto* transpose_2_out_0 = builder.MakeOutput();
-
-      auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0});
-      transpose_1.AddAttribute("perm", std::vector<int64_t>{0, 3, 1, 2});
-      auto& quantizelinear_1 = builder.AddNode("QuantizeLinear", {transpose_1_out_0, input1_arg, input2_arg},
-                                               {quantizelinear_1_out_0}, q_domain);
-      quantizelinear_1.AddAttribute("axis", (int64_t)10);
-      auto& transpose_2 = builder.AddNode("Transpose", {quantizelinear_1_out_0}, {transpose_2_out_0});
-      transpose_2.AddAttribute("perm", std::vector<int64_t>{0, 2, 3, 1});
-    };
-
-    auto check_optimized_graph_1 = [&](InferenceSessionWrapper& session) {
-      int transpose_cost = EstimateTransposeCost(session.GetGraph());
-      EXPECT_EQ(transpose_cost, 0);
-    };
+  std::optional<std::vector<int64_t>> zp_input_shape = std::vector<int64_t>{};
+  std::vector<int64_t> zp_value_shape{};
+  auto ignored_axis = utils::MakeAttribute("axis", static_cast<int64_t>(10));  // Should be ignored for per-tensor Q
 
-    TransformerTester(build_test_case_1,
-                      check_optimized_graph_1,
-                      TransformerLevel::Default,
-                      TransformerLevel::Level1,
-                      /*opset_version*/ {15, 18});
-  };
+  RunQuantizeLinearTestCase<uint8_t>(zp_input_shape, zp_value_shape, ignored_axis, kOnnxDomain);
 
-  test_case();
 #if !defined(DISABLE_CONTRIB_OPS)
-  test_case(kMSDomain);  // Use com.microsoft.QuantizeLinear
+  // Use com.microsoft.QuantizeLinear op.
+  RunQuantizeLinearTestCase<uint8_t>(zp_input_shape, zp_value_shape, ignored_axis, kMSDomain);
+  RunQuantizeLinearTestCase<uint16_t>(zp_input_shape, zp_value_shape, ignored_axis, kMSDomain);
+  RunQuantizeLinearTestCase<int16_t>(zp_input_shape, zp_value_shape, ignored_axis, kMSDomain);
 #endif
 }
 
 TEST(TransposeOptimizerTests, TestQuantizeLinearVector) {
-  auto test_case = [&](const std::string& q_domain = "") {
-    auto build_test_case_1 = [&](ModelTestBuilder& builder) {
-      auto* input0_arg = MakeInput<float>(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, 0.0, 1.0);
-      auto* input1_arg = MakeInput<float>(builder, {{-1}}, {2}, {2.3f, 2.4f});
-      auto* input2_arg = MakeInput<uint8_t>(builder, {{-1}}, {2}, {10, 12});
-      auto* transpose_1_out_0 = builder.MakeIntermediate();
-      auto* quantizelinear_1_out_0 = builder.MakeIntermediate();
-      auto* transpose_2_out_0 = builder.MakeOutput();
-
-      auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0});
-      transpose_1.AddAttribute("perm", std::vector<int64_t>{0, 3, 1, 2});
-      auto& quantizelinear_1 = builder.AddNode("QuantizeLinear", {transpose_1_out_0, input1_arg, input2_arg},
-                                               {quantizelinear_1_out_0}, q_domain);
-      quantizelinear_1.AddAttribute("axis", (int64_t)0);
-      auto& transpose_2 = builder.AddNode("Transpose", {quantizelinear_1_out_0}, {transpose_2_out_0});
-      transpose_2.AddAttribute("perm", std::vector<int64_t>{0, 2, 3, 1});
-    };
+  std::optional<std::vector<int64_t>> zp_input_shape = std::vector<int64_t>{-1};
+  std::vector<int64_t> zp_value_shape = {2};
+  auto axis = utils::MakeAttribute("axis", static_cast<int64_t>(0));
 
-    auto check_optimized_graph_1 = [&](InferenceSessionWrapper& session) {
-      int transpose_cost = EstimateTransposeCost(session.GetGraph());
-      EXPECT_EQ(transpose_cost, 0);
-    };
+  RunQuantizeLinearTestCase<uint8_t>(zp_input_shape, zp_value_shape, axis, kOnnxDomain);
 
-    TransformerTester(build_test_case_1,
-                      check_optimized_graph_1,
-                      TransformerLevel::Default,
-                      TransformerLevel::Level1,
-                      /*opset_version*/ {15, 18});
-  };
-
-  test_case();
 #if !defined(DISABLE_CONTRIB_OPS)
-  test_case(kMSDomain);  // Use com.microsoft.QuantizeLinear
+  // Use com.microsoft.QuantizeLinear op.
+  RunQuantizeLinearTestCase<uint8_t>(zp_input_shape, zp_value_shape, axis, kMSDomain);
+  RunQuantizeLinearTestCase<uint16_t>(zp_input_shape, zp_value_shape, axis, kMSDomain);
+  RunQuantizeLinearTestCase<int16_t>(zp_input_shape, zp_value_shape, axis, kMSDomain);
 #endif
 }
 
 TEST(TransposeOptimizerTests, TestQuantizeLinearVectorUnknownRank) {
-  auto test_case = [&](const std::string& q_domain = "") {
-    auto build_test_case_1 = [&](ModelTestBuilder& builder) {
-      auto* input0_arg = MakeInput<float>(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, 0.0, 1.0);
-      auto* input1_arg = MakeInput<float>(builder, std::nullopt, {3}, {2.3f, 2.4f, 2.5f});
-      auto* input2_arg = MakeInput<uint8_t>(builder, std::nullopt, {3}, {10, 12, 13});
-      auto* transpose_1_out_0 = builder.MakeIntermediate();
-      auto* quantizelinear_1_out_0 = builder.MakeIntermediate();
-      auto* transpose_2_out_0 = builder.MakeOutput();
-
-      auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0});
-      transpose_1.AddAttribute("perm", std::vector<int64_t>{0, 3, 1, 2});
-      auto& quantizelinear_1 = builder.AddNode("QuantizeLinear", {transpose_1_out_0, input1_arg, input2_arg},
-                                               {quantizelinear_1_out_0}, q_domain);
-      quantizelinear_1.AddAttribute("axis", (int64_t)1);
-      auto& transpose_2 = builder.AddNode("Transpose", {quantizelinear_1_out_0}, {transpose_2_out_0});
-      transpose_2.AddAttribute("perm", std::vector<int64_t>{0, 2, 3, 1});
-    };
+  std::optional<std::vector<int64_t>> zp_unknown_shape;  // Empty shape
+  std::vector<int64_t> zp_value_shape = {3};
+  auto axis = utils::MakeAttribute("axis", static_cast<int64_t>(1));
 
-    auto check_optimized_graph_1 = [&](InferenceSessionWrapper& session) {
-      int transpose_cost = EstimateTransposeCost(session.GetGraph());
-      EXPECT_EQ(transpose_cost, 0);
-    };
+  RunQuantizeLinearTestCase<uint8_t>(zp_unknown_shape, zp_value_shape, axis, kOnnxDomain);
 
-    TransformerTester(build_test_case_1,
-                      check_optimized_graph_1,
-                      TransformerLevel::Default,
-                      TransformerLevel::Level1,
-                      /*opset_version*/ {15, 18});
-  };
-
-  test_case();
 #if !defined(DISABLE_CONTRIB_OPS)
-  test_case(kMSDomain);  // Use com.microsoft.QuantizeLinear
+  // Use com.microsoft.QuantizeLinear op.
+ RunQuantizeLinearTestCase(zp_unknown_shape, zp_value_shape, axis, kMSDomain); + RunQuantizeLinearTestCase(zp_unknown_shape, zp_value_shape, axis, kMSDomain); + RunQuantizeLinearTestCase(zp_unknown_shape, zp_value_shape, axis, kMSDomain); #endif } @@ -3676,158 +3644,158 @@ TEST(TransposeOptimizerTests, TestQuantizeLinearScalarOpset10) { /*opset_version*/ 10); } -TEST(TransposeOptimizerTests, TestDequantizeLinearScalarIgnoreAxis) { - auto test_case = [&](const std::string& dq_domain = "") { - auto build_test_case_1 = [&](ModelTestBuilder& builder) { - auto* input0_arg = MakeInput(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, 0, 5); - auto* input1_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {2.3f}); - auto* input2_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {10}); - auto* transpose_1_out_0 = builder.MakeIntermediate(); - auto* dequantizelinear_1_out_0 = builder.MakeIntermediate(); - auto* transpose_2_out_0 = builder.MakeOutput(); - - auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0}); - transpose_1.AddAttribute("perm", std::vector{0, 3, 1, 2}); - auto& dequantizelinear_1 = builder.AddNode("DequantizeLinear", {transpose_1_out_0, input1_arg, input2_arg}, - {dequantizelinear_1_out_0}, dq_domain); - dequantizelinear_1.AddAttribute("axis", (int64_t)10); - auto& transpose_2 = builder.AddNode("Transpose", {dequantizelinear_1_out_0}, {transpose_2_out_0}); - transpose_2.AddAttribute("perm", std::vector{0, 2, 3, 1}); - }; +// Utility function that runs TransformerTester for the graph Transpose -> DequantizeLinear -> Transpose. +// Expects the Tranpose nodes to cancel. +template +static void RunDequantizeLinearTestCase(const std::optional>& zp_input_shape, + const std::vector& zp_value_shape, + std::optional axis, + const std::string& q_domain = "") { + auto build_test_case = [&](ModelTestBuilder& builder) { + constexpr QuantType qmin = std::numeric_limits::min(); + constexpr QuantType qmax = std::numeric_limits::max(); - auto check_optimized_graph_1 = [&](InferenceSessionWrapper& session) { - int transpose_cost = EstimateTransposeCost(session.GetGraph()); - EXPECT_EQ(transpose_cost, 0); - }; + auto* input0_arg = MakeInput(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, qmin, qmax); + + NodeArg* scale_arg = nullptr; + NodeArg* zero_point_arg = nullptr; + + if (zp_value_shape.empty()) { // Per-tensor quantization + QuantType zp = (qmax + qmin) / 2; + scale_arg = MakeInput(builder, zp_input_shape, zp_value_shape, {0.05f}); + zero_point_arg = MakeInput(builder, zp_input_shape, zp_value_shape, {zp}); + } else { // Per-axis quantization + scale_arg = MakeInput(builder, zp_input_shape, zp_value_shape, 0.0f, 1.0f); + zero_point_arg = MakeInput(builder, zp_input_shape, zp_value_shape, qmin, qmax); + } + auto* transpose_1_out_0 = builder.MakeIntermediate(); + auto* dequantizelinear_1_out_0 = builder.MakeIntermediate(); + auto* transpose_2_out_0 = builder.MakeOutput(); + + auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0}); + transpose_1.AddAttribute("perm", std::vector{0, 3, 1, 2}); + auto& dequantizelinear_1 = builder.AddNode("DequantizeLinear", {transpose_1_out_0, scale_arg, zero_point_arg}, + {dequantizelinear_1_out_0}, q_domain); + + if (axis.has_value()) { + dequantizelinear_1.AddAttributeProto(*axis); + } + + auto& transpose_2 = builder.AddNode("Transpose", {dequantizelinear_1_out_0}, {transpose_2_out_0}); + transpose_2.AddAttribute("perm", std::vector{0, 2, 3, 1}); + }; - TransformerTester(build_test_case_1, - 
check_optimized_graph_1, - TransformerLevel::Default, - TransformerLevel::Level1, - /*opset_version*/ {15, 18}); + auto check_optimized_graph = [](InferenceSessionWrapper& session) { + int transpose_cost = EstimateTransposeCost(session.GetGraph()); + EXPECT_EQ(transpose_cost, 0); }; - test_case(); + TransformerTester(build_test_case, + check_optimized_graph, + TransformerLevel::Default, + TransformerLevel::Level1, + /*opset_version*/ {15, 18}); +} + +TEST(TransposeOptimizerTests, TestDequantizeLinearScalarIgnoreAxis) { + std::optional> zp_input_shape = std::vector{}; + std::vector zp_value_shape{}; + auto ignored_axis = utils::MakeAttribute("axis", static_cast(10)); // Should be ignored for per-tensor Q + + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, ignored_axis, kOnnxDomain); #if !defined(DISABLE_CONTRIB_OPS) - test_case(kMSDomain); // Use com.microsoft.DequantizeLinear + // Use com.microsoft.DequantizeLinear ops + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, ignored_axis, kMSDomain); + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, ignored_axis, kMSDomain); + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, ignored_axis, kMSDomain); #endif } TEST(TransposeOptimizerTests, TestDequantizeLinearVector) { - auto test_case = [&](const std::string& dq_domain = "") { - auto build_test_case_1 = [&](ModelTestBuilder& builder) { - auto* input0_arg = MakeInput(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, 0, 5); - auto* input1_arg = MakeInput(builder, {{2}}, {2}, {2.3f, 2.4f}); - auto* input2_arg = MakeInput(builder, {{2}}, {2}, {10, 12}); - auto* transpose_1_out_0 = builder.MakeIntermediate(); - auto* dequantizelinear_1_out_0 = builder.MakeIntermediate(); - auto* transpose_2_out_0 = builder.MakeOutput(); - - auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0}); - transpose_1.AddAttribute("perm", std::vector{0, 3, 1, 2}); - auto& dequantizelinear_1 = builder.AddNode("DequantizeLinear", {transpose_1_out_0, input1_arg, input2_arg}, - {dequantizelinear_1_out_0}, dq_domain); - dequantizelinear_1.AddAttribute("axis", (int64_t)-4); - auto& transpose_2 = builder.AddNode("Transpose", {dequantizelinear_1_out_0}, {transpose_2_out_0}); - transpose_2.AddAttribute("perm", std::vector{0, 2, 3, 1}); - }; + std::optional> zp_input_shape = std::vector{2}; + std::vector zp_value_shape = {2}; + auto axis = utils::MakeAttribute("axis", static_cast(-4)); - auto check_optimized_graph_1 = [&](InferenceSessionWrapper& session) { - int transpose_cost = EstimateTransposeCost(session.GetGraph()); - EXPECT_EQ(transpose_cost, 0); - }; + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, axis, kOnnxDomain); +#if !defined(DISABLE_CONTRIB_OPS) + // Use com.microsoft.DequantizeLinear ops + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, axis, kMSDomain); + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, axis, kMSDomain); + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, axis, kMSDomain); +#endif +} - TransformerTester(build_test_case_1, - check_optimized_graph_1, - TransformerLevel::Default, - TransformerLevel::Level1, - /*opset_version*/ {15, 18}); - }; +TEST(TransposeOptimizerTests, TestDequantizeLinearNoAxis) { + std::optional> zp_input_shape = std::vector{}; + std::vector zp_value_shape{}; + std::optional no_axis; // Empty axis value will not be set. 
- test_case(); + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, no_axis, kOnnxDomain); #if !defined(DISABLE_CONTRIB_OPS) - test_case(kMSDomain); // Use com.microsoft.DequantizeLinear + // Use com.microsoft.DequantizeLinear ops + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, no_axis, kMSDomain); + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, no_axis, kMSDomain); + RunDequantizeLinearTestCase(zp_input_shape, zp_value_shape, no_axis, kMSDomain); #endif } -TEST(TransposeOptimizerTests, TestDequantizeLinearNoAxis) { - auto build_test_case_1 = [&](ModelTestBuilder& builder) { - auto* input0_arg = MakeInput(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, 0, 5); - auto* input1_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {2.3f}); - auto* input2_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {10}); - auto* transpose_1_out_0 = builder.MakeIntermediate(); +// Utility function that runs TransformerTester for the graph in which a single DequantizeLinear node is +// the parent of two Transpose nodes. The DQ should be duplicated by EnsureUniqueDQForNodeUnit, and the +// Transposes should be pushed. +template +static void RunDequantizeLinearTransposePropagationTestCase(const std::string& dq_domain = "") { + auto build_test_case = [dq_domain](ModelTestBuilder& builder) { + auto* input0_arg = MakeInput(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, 0, 5); + auto* scale_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {2.3f}); + auto* zero_point_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {10}); auto* dequantizelinear_1_out_0 = builder.MakeIntermediate(); + auto* transpose_1_out_0 = builder.MakeOutput(); auto* transpose_2_out_0 = builder.MakeOutput(); - auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0}); + builder.AddNode("DequantizeLinear", {input0_arg, scale_arg, zero_point_arg}, {dequantizelinear_1_out_0}, + dq_domain); + + auto& transpose_1 = builder.AddNode("Transpose", {dequantizelinear_1_out_0}, {transpose_1_out_0}); transpose_1.AddAttribute("perm", std::vector{0, 3, 1, 2}); - builder.AddNode("DequantizeLinear", {transpose_1_out_0, input1_arg, input2_arg}, {dequantizelinear_1_out_0}); + auto& transpose_2 = builder.AddNode("Transpose", {dequantizelinear_1_out_0}, {transpose_2_out_0}); transpose_2.AddAttribute("perm", std::vector{0, 2, 3, 1}); }; - auto check_optimized_graph_1 = [&](InferenceSessionWrapper& session) { - int transpose_cost = EstimateTransposeCost(session.GetGraph()); - EXPECT_EQ(transpose_cost, 0); + auto check_graph = [dq_domain](InferenceSessionWrapper& session) { + const auto& graph = session.GetGraph(); + + const char* dq_count_key = (dq_domain == kMSDomain) ? 
"com.microsoft.DequantizeLinear" : "DequantizeLinear"; + const auto op_count = CountOpsInGraph(graph); + decltype(op_count) expected_op_count{ + {dq_count_key, 2}, // EnsureUniqueDQForNodeUnit should duplicate the original DQ + {"Transpose", 2}, + }; + ASSERT_EQ(op_count, expected_op_count); + + // Transposes should be pushed, so check for Transpose -> DQ edges + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "Transpose") { + ASSERT_EQ(node.GetOutputEdgesCount(), static_cast(1)); + ASSERT_EQ(node.OutputEdgesBegin()->GetNode().OpType(), "DequantizeLinear"); + } + } }; - TransformerTester(build_test_case_1, - check_optimized_graph_1, + TransformerTester(build_test_case, + check_graph, TransformerLevel::Default, TransformerLevel::Level1, /*opset_version*/ 10); } TEST(TransposeOptimizerTests, TestDequantizeLinearTransposePropagation) { - auto test_case = [&](const std::string& dq_domain = "") { - auto build_test_case_1 = [&](ModelTestBuilder& builder) { - auto* input0_arg = MakeInput(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, 0, 5); - auto* input1_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {2.3f}); - auto* input2_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {10}); - auto* dequantizelinear_1_out_0 = builder.MakeIntermediate(); - auto* transpose_1_out_0 = builder.MakeOutput(); - auto* transpose_2_out_0 = builder.MakeOutput(); - - builder.AddNode("DequantizeLinear", {input0_arg, input1_arg, input2_arg}, {dequantizelinear_1_out_0}, - dq_domain); - - auto& transpose_1 = builder.AddNode("Transpose", {dequantizelinear_1_out_0}, {transpose_1_out_0}); - transpose_1.AddAttribute("perm", std::vector{0, 3, 1, 2}); - - auto& transpose_2 = builder.AddNode("Transpose", {dequantizelinear_1_out_0}, {transpose_2_out_0}); - transpose_2.AddAttribute("perm", std::vector{0, 2, 3, 1}); - }; - - auto check_graph = [&](InferenceSessionWrapper& session) { - const auto& graph = session.GetGraph(); - - const char* dq_count_key = (dq_domain == kMSDomain) ? 
"com.microsoft.DequantizeLinear" : "DequantizeLinear"; - const auto op_count = CountOpsInGraph(graph); - decltype(op_count) expected_op_count{ - {dq_count_key, 2}, // EnsureUniqueDQForNodeUnit should duplicate the original DQ - {"Transpose", 2}, - }; - ASSERT_EQ(op_count, expected_op_count); - - // Transposes should be pushed, so check for Transpose -> DQ edges - for (const auto& node : graph.Nodes()) { - if (node.OpType() == "Transpose") { - ASSERT_EQ(node.GetOutputEdgesCount(), static_cast(1)); - ASSERT_EQ(node.OutputEdgesBegin()->GetNode().OpType(), "DequantizeLinear"); - } - } - }; - - TransformerTester(build_test_case_1, - check_graph, - TransformerLevel::Default, - TransformerLevel::Level1, - /*opset_version*/ 10); - }; - - test_case(); + RunDequantizeLinearTransposePropagationTestCase(); #if !defined(DISABLE_CONTRIB_OPS) - test_case(kMSDomain); // Use com.microsoft.DequantizeLinear + // Use com.microsoft.DequantizeLinear + RunDequantizeLinearTransposePropagationTestCase(kMSDomain); + RunDequantizeLinearTransposePropagationTestCase(kMSDomain); + RunDequantizeLinearTransposePropagationTestCase(kMSDomain); #endif } diff --git a/onnxruntime/test/providers/qnn/conv_test.cc b/onnxruntime/test/providers/qnn/conv_test.cc index e9e285411f0a7..0549051bc2387 100644 --- a/onnxruntime/test/providers/qnn/conv_test.cc +++ b/onnxruntime/test/providers/qnn/conv_test.cc @@ -21,7 +21,8 @@ static GetTestModelFn BuildF32ConvTestCase(const std::string& conv_op_type, cons const std::vector& pads, const std::vector& dilations, const std::string& auto_pad = "NOTSET") { - return [conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations, auto_pad](ModelTestBuilder& builder) { + return [conv_op_type, input_def, weights_def, bias_def, strides, pads, + dilations, auto_pad](ModelTestBuilder& builder) { std::vector conv_inputs = { MakeTestInput(builder, input_def), MakeTestInput(builder, weights_def)}; @@ -77,29 +78,33 @@ static void RunCPUConvOpTest(const std::string& conv_op_type, const TestInputDef } // Creates a graph with a single Q/DQ Conv operator. Used for testing HTP backend. 
-template -static GetTestQDQModelFn BuildQDQConvTestCase(const std::string& conv_op_type, const TestInputDef& input_def, - const TestInputDef& weights_def, - const TestInputDef& bias_def, - const std::vector& strides, - const std::vector& pads, - const std::vector& dilations, - const std::string& auto_pad = "NOTSET") { +template +static GetTestQDQModelFn BuildQDQConvTestCase(const std::string& conv_op_type, + const TestInputDef& input_def, + const TestInputDef& weights_def, + const TestInputDef& bias_def, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilations, + const std::string& auto_pad = "NOTSET", + bool use_contrib_qdq = false) { return [conv_op_type, input_def, weights_def, bias_def, strides, pads, - dilations, auto_pad](ModelTestBuilder& builder, - std::vector>& output_qparams) { + dilations, auto_pad, use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { std::vector conv_inputs; // input -> Q/DQ -> auto* input = MakeTestInput(builder, input_def); - QuantParams input_qparams = GetTestInputQuantParams(input_def); - auto* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + auto* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); conv_inputs.push_back(input_qdq); // weights -> Q/DQ -> auto* weights = MakeTestInput(builder, weights_def); - QuantParams weights_qparams = GetTestInputQuantParams(weights_def); - auto* weights_qdq = AddQDQNodePair(builder, weights, weights_qparams.scale, weights_qparams.zero_point); + QuantParams weights_qparams = GetTestInputQuantParams(weights_def); + auto* weights_qdq = AddQDQNodePair(builder, weights, weights_qparams.scale, + weights_qparams.zero_point, use_contrib_qdq); conv_inputs.push_back(weights_qdq); // bias -> @@ -107,7 +112,7 @@ static GetTestQDQModelFn BuildQDQConvTestCase(const std::string& con // Bias requirement taken from python quantization tool: onnx_quantizer.py::quantize_bias_static() const float bias_scale = input_qparams.scale * weights_qparams.scale; - conv_inputs.push_back(MakeTestQDQBiasInput(builder, bias_def, bias_scale)); + conv_inputs.push_back(MakeTestQDQBiasInput(builder, bias_def, bias_scale, use_contrib_qdq)); } auto* conv_output = builder.MakeIntermediate(); @@ -125,13 +130,14 @@ static GetTestQDQModelFn BuildQDQConvTestCase(const std::string& con conv_node.AddAttribute("dilations", dilations); } - AddQDQNodePairWithOutputAsGraphOutput(builder, conv_output, output_qparams[0].scale, output_qparams[0].zero_point); + AddQDQNodePairWithOutputAsGraphOutput(builder, conv_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); }; } // Runs a Conv model on the QNN HTP backend. Checks the graph node assignment, and that inference // outputs for QNN EP and CPU EP match. 
-template +template static void RunHTPConvOpTest(const std::string& conv_op_type, const TestInputDef& input_def, const TestInputDef& weights_def, const TestInputDef& bias_def, @@ -140,6 +146,7 @@ static void RunHTPConvOpTest(const std::string& conv_op_type, const TestInputDef const std::vector& dilations, const std::string& auto_pad, ExpectedEPNodeAssignment expected_ep_assignment, + bool use_contrib_qdq = false, int opset = 13, float fp32_abs_err = 1e-5f) { ProviderOptions provider_options; @@ -150,9 +157,11 @@ static void RunHTPConvOpTest(const std::string& conv_op_type, const TestInputDef provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildF32ConvTestCase(conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations, auto_pad), - BuildQDQConvTestCase(conv_op_type, input_def, weights_def, bias_def, - strides, pads, dilations, auto_pad), + TestQDQModelAccuracy(BuildF32ConvTestCase(conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations, + auto_pad), + BuildQDQConvTestCase(conv_op_type, input_def, weights_def, + bias_def, strides, pads, dilations, + auto_pad, use_contrib_qdq), provider_options, opset, expected_ep_assignment, @@ -279,52 +288,56 @@ TEST_F(QnnCPUBackendTests, Convf32_large_input2_nopad_bias_initializer) { // Test 1D Conv with static weights (implemented in QNN EP as 2D convolution with height of 1). TEST_F(QnnCPUBackendTests, Conv1Df32_StaticWeights_DefaultBias) { + std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; RunCPUConvOpTest("Conv", - TestInputDef({1, 2, 4}, false, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}), // Dynamic input - TestInputDef({1, 2, 2}, true, {1.0f, 2.0f, 3.0f, 4.0f}), // Static weights - TestInputDef({1}, true, {1.0f}), // Bias of 1.f - {1}, // Strides - {0, 0}, // Pads - {1}, // Dilations + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({1, 2, 2}, true, {1.0f, 2.0f, 3.0f, 4.0f}), // Static weights + TestInputDef({1}, true, {1.0f}), // Initializer Bias + {1}, // Strides + {0, 0}, // Pads + {1}, // Dilations "NOTSET", ExpectedEPNodeAssignment::All); } // Test 1D Conv with dynamic weights (implemented in QNN EP as 2D convolution with height of 1). TEST_F(QnnCPUBackendTests, Conv1Df32_DynamicWeights_DefaultBias) { + std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; RunCPUConvOpTest("Conv", - TestInputDef({1, 2, 4}, false, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}), // Dynamic input - TestInputDef({1, 2, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f}), // Dynamic weights - TestInputDef(), // Default bias - {1}, // Strides - {0, 0}, // Pads - {1}, // Dilations + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({1, 2, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f}), // Dynamic weights + TestInputDef(), // Default bias + {1}, // Strides + {0, 0}, // Pads + {1}, // Dilations "NOTSET", ExpectedEPNodeAssignment::All); } // Test 1D ConvTranspose with static weights (implemented in QNN EP as 2D convolution with height of 1). 
TEST_F(QnnCPUBackendTests, ConvTranspose1Df32_StaticWeights_DefaultBias) { + std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; RunCPUConvOpTest("ConvTranspose", - TestInputDef({1, 2, 4}, false, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}), // Dynamic input - TestInputDef({2, 1, 2}, true, {1.0f, 2.0f, 3.0f, 4.0f}), // Static weights - TestInputDef({1}, true, {0.0f}), // Zero bias - {1}, // Strides - {0, 0}, // Pads - {1}, // Dilations + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({2, 1, 2}, true, {1.0f, 2.0f, 3.0f, 4.0f}), // Static weights + TestInputDef({1}, true, {0.0f}), // Zero bias + {1}, // Strides + {0, 0}, // Pads + {1}, // Dilations "NOTSET", ExpectedEPNodeAssignment::All); } // Test 1D ConvTranspose with dynamic weights (implemented in QNN EP as 2D convolution with height of 1). TEST_F(QnnCPUBackendTests, ConvTranspose1Df32_DynamicWeights_DefaultBias) { + std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; RunCPUConvOpTest("ConvTranspose", - TestInputDef({1, 2, 4}, false, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}), // Dynamic input - TestInputDef({2, 1, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f}), // Dynamic weights - TestInputDef({1}, true, {0.0f}), // Zero bias - {1}, // Strides - {0, 0}, // Pads - {1}, // Dilations + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({2, 1, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f}), // Dynamic weights + TestInputDef({1}, true, {0.0f}), // Zero bias + {1}, // Strides + {0, 0}, // Pads + {1}, // Dilations "NOTSET", ExpectedEPNodeAssignment::All); } @@ -397,218 +410,448 @@ TEST_F(QnnHTPBackendTests, Test_QDQConvWithDynamicWeightsFromMul) { // Check that QNN compiles DQ -> Conv -> Q as a single unit. // Tests bias as a dynamic input. -TEST_F(QnnHTPBackendTests, ConvU8S32_bias_dynamic_input) { - RunHTPConvOpTest("Conv", - TestInputDef({1, 1, 5, 5}, false, 0.0f, 10.0f), // Random dynamic input - TestInputDef({1, 1, 3, 3}, true, -10.0f, 10.0f), // Random static input - TestInputDef({1}, false, {2.0f}), // Dynamic bias = 2.0f - {1, 1}, // Strides - {0, 0, 0, 0}, // Pads - {1, 1}, // Dilations - "NOTSET", - ExpectedEPNodeAssignment::All); +TEST_F(QnnHTPBackendTests, ConvU8U8S32_bias_dynamic_input) { + RunHTPConvOpTest("Conv", + TestInputDef({1, 1, 5, 5}, false, 0.0f, 10.0f), // Random dynamic input + TestInputDef({1, 1, 3, 3}, true, -10.0f, 10.0f), // Random static input + TestInputDef({1}, false, {2.0f}), // Dynamic bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All); +} + +// Tests 16-bit QDQ Conv with dynamic weights and bias (uses QNN's Conv2d) +// TODO: Inaccuracy detected for output 'output', element 0. +// Output quant params: scale=0.0040235077030956745, zero_point=0. 
+// Expected val: 87.354057312011719 +// QNN QDQ val: 0 (err 87.354057312011719) +// CPU QDQ val: 87.3583984375 (err 0.00434112548828125) +TEST_F(QnnHTPBackendTests, DISABLED_ConvU16S16S32_DynamicBias) { + TestInputDef input_def({1, 2, 5, 5}, false, GetFloatDataInRange(-10.0f, 10.0f, 50)); + TestInputDef weight_def({1, 2, 3, 3}, false, GetFloatDataInRange(-1.0f, 5.0f, 18)); + RunHTPConvOpTest("Conv", + input_def, // Input + weight_def.OverrideValueRange(-5.0f, 5.0f), // Weights (symmetric quant range) + TestInputDef({1}, false, {2.0f}), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All, + true); // Use com.microsoft QDQ ops for 16-bit +} + +// Tests 16-bit QDQ Conv with dynamic weights and bias (uses QNN's DepthwiseConv2d) +// TODO(adrianlizarraga): FAIL: Failed to finalize QNN graph. Error code 1002 +TEST_F(QnnHTPBackendTests, DISABLED_DepthwiseConvU16S16S32_DynamicBias) { + TestInputDef input_def({1, 1, 5, 5}, false, GetFloatDataInRange(-10.0f, 10.0f, 25)); + TestInputDef weight_def({1, 1, 3, 3}, false, GetFloatDataInRange(-1.0f, 5.0f, 9)); + RunHTPConvOpTest("Conv", + input_def, // Input + weight_def.OverrideValueRange(-5.0f, 5.0f), // Weights (symmetric quant range) + TestInputDef({1}, false, {2.0f}), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All, + true); // Use com.microsoft QDQ ops for 16-bit +} + +// Tests 16-bit QDQ Conv with dynamic weights and no bias. +// TODO: Inaccuracy detected for output 'output', element 0. +// Output quant params: scale=0.0039929896593093872, zero_point=0. +// Expected val: 85.354057312011719 +// QNN QDQ val: 0 (err 85.354057312011719) +// CPU QDQ val: 85.358139038085938 (err 0.00408172607421875) +TEST_F(QnnHTPBackendTests, DISABLED_ConvU16S16S32_NoBias) { + TestInputDef input_def({1, 2, 5, 5}, false, GetFloatDataInRange(-10.0f, 10.0f, 50)); + TestInputDef weight_def({1, 2, 3, 3}, false, GetFloatDataInRange(-1.0f, 5.0f, 18)); + RunHTPConvOpTest("Conv", + input_def, // Input + weight_def.OverrideValueRange(-5.0f, 5.0f), // Weights (symmetric quant range) + TestInputDef(), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All, + true); // Use com.microsoft QDQ ops for 16-bit +} + +// Tests 16-bit QDQ Conv with dynamic weights and no bias (uses QNN's DepthWiseConv2d) +// TODO(adrianlizarraga): FAIL: Failed to finalize QNN graph. Error code 1002 +TEST_F(QnnHTPBackendTests, DISABLED_DepthwiseConvU16S16S32_NoBias) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 25); + std::vector weight_data = GetFloatDataInRange(-10.0f, 10.0f, 9); + RunHTPConvOpTest("Conv", + TestInputDef({1, 1, 5, 5}, false, input_data), // Input + TestInputDef({1, 1, 3, 3}, false, weight_data), // Weights + TestInputDef(), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All, + true); // Use com.microsoft QDQ ops for 16-bit +} + +// Tests 16-bit activations, 8-bit static weights QDQ Conv with static bias. +// Uses QNN's DepthwiseConv2d operator. +// TODO: Inaccuracy detected for output 'output', element 8. +// Output quant params: scale=0.0027466239407658577, zero_point=10194. 
+// Expected val: 152 +// QNN QDQ val: 151.8004150390625 (err 0.1995849609375) +// CPU QDQ val: 151.9981689453125 (err 0.0018310546875) +TEST_F(QnnHTPBackendTests, DepthwiseConvU16U8S32_StaticBias) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 25); + std::vector weight_data = GetFloatDataInRange(-1.0f, 5.0f, 9); + RunHTPConvOpTest("Conv", + TestInputDef({1, 1, 5, 5}, false, input_data), // Input + TestInputDef({1, 1, 3, 3}, true, weight_data), // Weights + TestInputDef({1}, true, {2.0f}), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All, + true, // Use com.microsoft QDQ ops for 16-bit + 13, + 0.2f); +} + +// Tests 16-bit activations, 8-bit static weights QDQ Conv with static bias. +// TODO: Inaccuracy detected for output 'output', element 0. +// Output quant params: scale=0.0040235077030956745, zero_point=0. +// Expected val: 87.354057312011719 +// QNN QDQ val: 87.559577941894531 (err 0.2055206298828125) +// CPU QDQ val: 87.398635864257812 (err 0.04457855224609375) +TEST_F(QnnHTPBackendTests, ConvU16U8S32_StaticBias) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 50); + std::vector weight_data = GetFloatDataInRange(-1.0f, 5.0f, 18); + RunHTPConvOpTest("Conv", + TestInputDef({1, 2, 5, 5}, false, input_data), // Input + TestInputDef({1, 2, 3, 3}, true, weight_data), // Weights + TestInputDef({1}, true, {2.0f}), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All, + true, // Use com.microsoft QDQ ops for 16-bit + 13, + 0.6f); +} + +// Tests 16-bit activations, 8-bit static weights QDQ Conv with dynamic bias. +// Uses QNN's DepthwiseConv2d operator. +// TODO: Inaccuracy detected for output 'output', element 1. +// Output quant params: scale=0.0027466239407658577, zero_point=10194. +// Expected val: -13.000001907348633 +// QNN QDQ val: -13.095903396606445 (err 0.0959014892578125) +// CPU QDQ val: -12.999771118164062 (err 0.0002307891845703125) +TEST_F(QnnHTPBackendTests, DepthwiseConvU16U8S32_DynamicBias) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 25); + std::vector weight_data = GetFloatDataInRange(-1.0f, 5.0f, 9); + RunHTPConvOpTest("Conv", + TestInputDef({1, 1, 5, 5}, false, input_data), // Input + TestInputDef({1, 1, 3, 3}, true, weight_data), // Weights + TestInputDef({1}, false, {2.0f}), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All, + true, // Use com.microsoft QDQ ops for 16-bit + 13, + 0.2f); +} + +// Tests 16-bit activations, 8-bit static weights QDQ Conv with dynamic bias. +// TODO: Inaccuracy detected for output 'output', element 0. +// Output quant params: scale=0.0040235077030956745, zero_point=0. 
+// Expected val: 87.354057312011719 +// QNN QDQ val: 87.559577941894531 (err 0.2055206298828125) +// CPU QDQ val: 87.398635864257812 (err 0.04457855224609375) +TEST_F(QnnHTPBackendTests, ConvU16U8S32_DynamicBias) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 50); + std::vector weight_data = GetFloatDataInRange(-1.0f, 5.0f, 18); + RunHTPConvOpTest("Conv", + TestInputDef({1, 2, 5, 5}, false, input_data), // Input + TestInputDef({1, 2, 3, 3}, true, weight_data), // Weights + TestInputDef({1}, false, {2.0f}), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All, + true, // Use com.microsoft QDQ ops for 16-bit + 13, + 0.57f); +} + +// Tests 16-bit activations, 8-bit static weights QDQ Conv with no bias +// TODO: Inaccuracy detected for output 'output', element 7. +// Output quant params: scale=0.0039929896593093872, zero_point=0. +// Expected val: 246.98667907714844 +// QNN QDQ val: 247.82090759277344 (err 0.834228515625) +// CPU QDQ val: 247.24192810058594 (err 0.2552490234375) +TEST_F(QnnHTPBackendTests, ConvU16U8S32_NoBias) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 50); + std::vector weight_data = GetFloatDataInRange(-1.0f, 5.0f, 18); + RunHTPConvOpTest("Conv", + TestInputDef({1, 2, 5, 5}, false, input_data), // Input + TestInputDef({1, 2, 3, 3}, true, weight_data), // Weights + TestInputDef(), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All, + true, // Use com.microsoft QDQ ops for 16-bit + 13, + 0.58f); +} + +// Tests 16-bit activations, 8-bit static weights QDQ Conv with no bias +// Uses QNN's DepthwiseConv2d operator. +// TODO: Inaccuracy detected for output 'output', element 8. +// Output quant params: scale=0.0027466239407658577, zero_point=10923. +// Expected val: 150 +// QNN QDQ val: 149.80087280273438 (err 0.199127197265625) +// CPU QDQ val: 149.99862670898438 (err 0.001373291015625) +TEST_F(QnnHTPBackendTests, DepthwiseConvU16U8S32_NoBias) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 25); + std::vector weight_data = GetFloatDataInRange(-1.0f, 5.0f, 9); + RunHTPConvOpTest("Conv", + TestInputDef({1, 1, 5, 5}, false, input_data), // Input + TestInputDef({1, 1, 3, 3}, true, weight_data), // Weights + TestInputDef(), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All, + true, // Use com.microsoft QDQ ops for 16-bit + 13, + 0.2f); } // Test that dynamic weights with default bias works for Conv. This was previously not working // on older versions of QNN sdk. -TEST_F(QnnHTPBackendTests, ConvU8S32_DynamicWeight_NoBias) { - RunHTPConvOpTest("Conv", - TestInputDef({1, 3, 32, 32}, false, -10.0f, 10.0f), // Random dynamic input - TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), // Random dynamic weights - TestInputDef(), // Default bias - {1, 1}, // Strides - {0, 0, 0, 0}, // Pads - {1, 1}, // Dilations - "NOTSET", - ExpectedEPNodeAssignment::All); +TEST_F(QnnHTPBackendTests, ConvU8U8S32_DynamicWeight_NoBias) { + RunHTPConvOpTest("Conv", + TestInputDef({1, 3, 32, 32}, false, -10.0f, 10.0f), // Input + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), // Weights + TestInputDef(), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All); } // Test that dynamic weights with default bias works for ConvTranspose. 
This was previously not working // on older versions of QNN sdk. -TEST_F(QnnHTPBackendTests, ConvTransposeU8S32_DynamicWeight_NoBias) { - RunHTPConvOpTest("ConvTranspose", - TestInputDef({1, 3, 32, 32}, false, -10.0f, 10.0f), // Random dynamic input - TestInputDef({3, 1, 4, 4}, false, -10.0f, 10.0f), // Random dynamic weights - TestInputDef(), // Default bias - {1, 1}, // Strides - {0, 0, 0, 0}, // Pads - {1, 1}, // Dilations - "NOTSET", - ExpectedEPNodeAssignment::All); +TEST_F(QnnHTPBackendTests, ConvTransposeU8U8S32_DynamicWeight_NoBias) { + RunHTPConvOpTest("ConvTranspose", + TestInputDef({1, 3, 32, 32}, false, -10.0f, 10.0f), // Input + TestInputDef({3, 1, 4, 4}, false, -10.0f, 10.0f), // Weights + TestInputDef(), // Bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All); } // Check that QNN compiles DQ -> Conv -> Q as a single unit. // Tests bias as an initializer. TEST_F(QnnHTPBackendTests, ConvU8U8S32_bias_initializer) { - RunHTPConvOpTest("Conv", - TestInputDef({1, 1, 5, 5}, false, 0.0f, 10.0f), // Random dynamic input - TestInputDef({1, 1, 3, 3}, true, -10.0f, 10.0f), // Random static weight - TestInputDef({1}, true, {2.0f}), // Initializer bias = 2.0f - {1, 1}, // Strides - {0, 0, 0, 0}, // Pads - {1, 1}, // Dilations - "NOTSET", - ExpectedEPNodeAssignment::All); + RunHTPConvOpTest("Conv", + TestInputDef({1, 1, 5, 5}, false, 0.0f, 10.0f), // Random dynamic input + TestInputDef({1, 1, 3, 3}, true, -10.0f, 10.0f), // Random static weight + TestInputDef({1}, true, {2.0f}), // Initializer bias + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + "NOTSET", + ExpectedEPNodeAssignment::All); } // Tests 1D Conv with bias as an initializer. -TEST_F(QnnHTPBackendTests, Conv1DU8S32_bias_initializer) { - RunHTPConvOpTest("Conv", - TestInputDef({1, 2, 4}, false, {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}), // Dynamic input - TestInputDef({1, 2, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight - TestInputDef({1}, true, {1.0f}), // Initializer bias = 1.0f - {1}, // strides - {0, 0}, // pads - {1}, // dilations - "NOTSET", - ExpectedEPNodeAssignment::All); +TEST_F(QnnHTPBackendTests, Conv1DU8U8S32_bias_initializer) { + std::vector input_data = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; + RunHTPConvOpTest("Conv", + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({1, 2, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight + TestInputDef({1}, true, {1.0f}), // Initializer bias + {1}, // strides + {0, 0}, // pads + {1}, // dilations + "NOTSET", + ExpectedEPNodeAssignment::All); } // Tests 1D ConvTranspose with bias as an initializer. 
-TEST_F(QnnHTPBackendTests, ConvTranspose1DU8S32_bias_initializer) { - RunHTPConvOpTest("ConvTranspose", - TestInputDef({1, 2, 4}, false, {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}), // Dynamic input - TestInputDef({2, 1, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight - TestInputDef({1}, true, {1.0f}), // Initializer bias = 1.0f - {1}, // strides - {0, 0}, // pads - {1}, // dilations - "NOTSET", - ExpectedEPNodeAssignment::All); +TEST_F(QnnHTPBackendTests, ConvTranspose1DU8U8S32_bias_initializer) { + std::vector input_data = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; + RunHTPConvOpTest("ConvTranspose", + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({2, 1, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight + TestInputDef({1}, true, {1.0f}), // Initializer bias + {1}, // strides + {0, 0}, // pads + {1}, // dilations + "NOTSET", + ExpectedEPNodeAssignment::All); } // Tests auto_pad value "SAME_UPPER" on HTP backend (compares to CPU EP). -TEST_F(QnnHTPBackendTests, ConvU8S32_AutoPadUpper) { - RunHTPConvOpTest("Conv", - TestInputDef({1, 1, 5, 5}, false, 0.f, 10.f), // Dynamic input - TestInputDef({1, 1, 4, 4}, true, -1.f, 1.f), // Static weights - TestInputDef({1}, true, {1.0f}), // Initializer bias = 1.0f - {1, 1}, // strides - {}, // pads - {1, 1}, // dilations - "SAME_UPPER", // auto_pad - ExpectedEPNodeAssignment::All, - 13); +TEST_F(QnnHTPBackendTests, ConvU8U8S32_AutoPadUpper) { + RunHTPConvOpTest("Conv", + TestInputDef({1, 1, 5, 5}, false, 0.f, 10.f), // Dynamic input + TestInputDef({1, 1, 4, 4}, true, -1.f, 1.f), // Static weights + TestInputDef({1}, true, {1.0f}), // Initializer bias + {1, 1}, // strides + {}, // pads + {1, 1}, // dilations + "SAME_UPPER", // auto_pad + ExpectedEPNodeAssignment::All, + false, // use_contrib_qdq + 13); } // Tests Conv1d auto_pad value "SAME_UPPER" on HTP backend (compares to CPU EP). TEST_F(QnnHTPBackendTests, Conv1DU8U8S32_AutoPadUpper) { - RunHTPConvOpTest("Conv", - TestInputDef({1, 2, 4}, false, {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}), // Dynamic input - TestInputDef({1, 2, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight - TestInputDef({1}, true, {1.0f}), // Initializer bias = 1.0f - {1}, // strides - {0}, // pads - {1}, // dilations - "SAME_UPPER", // auto_pad - ExpectedEPNodeAssignment::All, - 13); + std::vector input_data = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; + RunHTPConvOpTest("Conv", + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({1, 2, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight + TestInputDef({1}, true, {1.0f}), // Initializer bias + {1}, // strides + {0}, // pads + {1}, // dilations + "SAME_UPPER", // auto_pad + ExpectedEPNodeAssignment::All, + false, // use_contrib_qdq + 13); } // Tests TransposeConv1d auto_pad value "SAME_UPPER" on HTP backend (compares to CPU EP). 
TEST_F(QnnHTPBackendTests, ConvTranspose1DU8U8S32_AutoPadUpper) { - RunHTPConvOpTest("ConvTranspose", - TestInputDef({1, 2, 4}, false, {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}), // Dynamic input - TestInputDef({2, 1, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight - TestInputDef({1}, true, {1.0f}), // Initializer bias = 1.0f - {1}, // strides - {0}, // pads - {1}, // dilations - "SAME_UPPER", // auto_pad - ExpectedEPNodeAssignment::All, - 13); + std::vector input_data = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; + RunHTPConvOpTest("ConvTranspose", + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({2, 1, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight + TestInputDef({1}, true, {1.0f}), // Initializer bias + {1}, // strides + {0}, // pads + {1}, // dilations + "SAME_UPPER", // auto_pad + ExpectedEPNodeAssignment::All, + false, // use_contrib_qdq + 13); } // Tests Conv's auto_pad value "SAME_LOWER" on HTP backend (compares to CPU EP). TEST_F(QnnHTPBackendTests, ConvU8U8S32_AutoPadLower) { - RunHTPConvOpTest("Conv", - TestInputDef({1, 1, 5, 5}, false, 0.f, 10.f), // Dynamic input - TestInputDef({1, 1, 4, 4}, true, -1.f, 1.f), // Static weights - TestInputDef({1}, true, {1.0f}), // Initializer bias = 1.0f - {1, 1}, // strides - {}, // pads - {1, 1}, // dilations - "SAME_LOWER", // auto_pad - ExpectedEPNodeAssignment::All, - 13); + RunHTPConvOpTest("Conv", + TestInputDef({1, 1, 5, 5}, false, 0.f, 10.f), // Dynamic input + TestInputDef({1, 1, 4, 4}, true, -1.f, 1.f), // Static weights + TestInputDef({1}, true, {1.0f}), // Initializer bias + {1, 1}, // strides + {}, // pads + {1, 1}, // dilations + "SAME_LOWER", // auto_pad + ExpectedEPNodeAssignment::All, + false, // use_contrib_qdq + 13); } // Tests ConvTranspose's auto_pad value "SAME_LOWER" on HTP backend (compares to CPU EP). TEST_F(QnnHTPBackendTests, ConvTransposeU8U8S32_AutoPadLower) { - RunHTPConvOpTest("ConvTranspose", - TestInputDef({1, 1, 5, 5}, false, 0.f, 10.f), // Dynamic input - TestInputDef({1, 1, 4, 4}, true, -1.f, 1.f), // Static weights - TestInputDef({1}, true, {1.0f}), // Initializer bias = 1.0f - {1, 1}, // strides - {}, // pads - {1, 1}, // dilations - "SAME_LOWER", // auto_pad - ExpectedEPNodeAssignment::All, - 13); + RunHTPConvOpTest("ConvTranspose", + TestInputDef({1, 1, 5, 5}, false, 0.f, 10.f), // Dynamic input + TestInputDef({1, 1, 4, 4}, true, -1.f, 1.f), // Static weights + TestInputDef({1}, true, {1.0f}), // Initializer bias + {1, 1}, // strides + {}, // pads + {1, 1}, // dilations + "SAME_LOWER", // auto_pad + ExpectedEPNodeAssignment::All, + false, // use_contrib_qdq + 13); } // Tests Conv1d auto_pad value "SAME_LOWER" on HTP backend (compares to CPU EP). 
TEST_F(QnnHTPBackendTests, Conv1DU8U8S32_AutoPadLower) { - RunHTPConvOpTest("Conv", - TestInputDef({1, 2, 4}, false, {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}), // Dynamic input - TestInputDef({1, 2, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight - TestInputDef({1}, true, {1.0f}), // Initializer bias = 1.0f - {1}, // strides - {0}, // pads - {1}, // dilations - "SAME_LOWER", // auto_pad - ExpectedEPNodeAssignment::All, - 13); + std::vector input_data = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; + RunHTPConvOpTest("Conv", + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({1, 2, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight + TestInputDef({1}, true, {1.0f}), // Initializer bias + {1}, // strides + {0}, // pads + {1}, // dilations + "SAME_LOWER", // auto_pad + ExpectedEPNodeAssignment::All, + false, // use_contrib_qdq + 13); } // Tests ConvTranspose 1d auto_pad value "SAME_LOWER" on HTP backend (compares to CPU EP). TEST_F(QnnHTPBackendTests, ConvTranspose1DU8U8S32_AutoPadLower) { - RunHTPConvOpTest("ConvTranspose", - TestInputDef({1, 2, 4}, false, {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}), // Dynamic input - TestInputDef({2, 1, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight - TestInputDef({1}, true, {1.0f}), // Initializer bias = 1.0f - {1}, // strides - {0}, // pads - {1}, // dilations - "SAME_LOWER", // auto_pad - ExpectedEPNodeAssignment::All, - 13); + std::vector input_data = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; + RunHTPConvOpTest("ConvTranspose", + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({2, 1, 2}, true, {1.f, 2.f, 3.f, 4.f}), // Static weight + TestInputDef({1}, true, {1.0f}), // Initializer bias + {1}, // strides + {0}, // pads + {1}, // dilations + "SAME_LOWER", // auto_pad + ExpectedEPNodeAssignment::All, + false, // use_contrib_qdq + 13); } TEST_F(QnnHTPBackendTests, ConvU8U8S32_large_input1_padding_bias_initializer) { - RunHTPConvOpTest("Conv", - TestInputDef({1, 3, 60, 452}, false, 0.f, 10.f), // Dynamic input - TestInputDef({16, 3, 3, 3}, true, -1.f, 1.f), // Static weights - TestInputDef({16}, true, std::vector(16, 1.f)), // Initializer bias = 1.f, 1.f, ... 
- {1, 1}, - {1, 1, 1, 1}, - {1, 1}, - "NOTSET", - ExpectedEPNodeAssignment::All); -} - -TEST_F(QnnHTPBackendTests, ConvU8S32_large_input2_bias_initializer) { - RunHTPConvOpTest("Conv", - TestInputDef({1, 128, 8, 56}, false, 0.f, 10.f), // Dynamic input - TestInputDef({32, 128, 1, 1}, true, -1.f, 1.f), // Random static weights - TestInputDef({32}, true, -1.f, 1.f), // Random initializer bias - {1, 1}, - {0, 0, 0, 0}, - {1, 1}, - "NOTSET", - ExpectedEPNodeAssignment::All); + RunHTPConvOpTest("Conv", + TestInputDef({1, 3, 60, 452}, false, 0.f, 10.f), // Dynamic input + TestInputDef({16, 3, 3, 3}, true, -1.f, 1.f), // Static weights + TestInputDef({16}, true, std::vector(16, 1.f)), // Initializer bias + {1, 1}, + {1, 1, 1, 1}, + {1, 1}, + "NOTSET", + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, ConvU8U8S32_large_input2_bias_initializer) { + RunHTPConvOpTest("Conv", + TestInputDef({1, 128, 8, 56}, false, 0.f, 10.f), // Dynamic input + TestInputDef({32, 128, 1, 1}, true, -1.f, 1.f), // Random static weights + TestInputDef({32}, true, -1.f, 1.f), // Random initializer bias + {1, 1}, + {0, 0, 0, 0}, + {1, 1}, + "NOTSET", + ExpectedEPNodeAssignment::All); } TEST_F(QnnHTPBackendTests, ConvU8U8S32_LargeInput_Dilations_Pads) { - RunHTPConvOpTest("Conv", - TestInputDef({1, 3, 768, 1152}, false, 0.f, 10.f), // Dynamic input - TestInputDef({64, 3, 7, 7}, true, -1.f, 1.f), // Random static weights - TestInputDef({64}, true, -1.f, 1.f), // Random initializer bias - {2, 2}, // strides - {3, 3, 3, 3}, // pads - {1, 1}, // dilations - "NOTSET", // auto_pad - ExpectedEPNodeAssignment::All); + RunHTPConvOpTest("Conv", + TestInputDef({1, 3, 768, 1152}, false, 0.f, 10.f), // Dynamic input + TestInputDef({64, 3, 7, 7}, true, -1.f, 1.f), // Static weights + TestInputDef({64}, true, -1.f, 1.f), // Initializer bias + {2, 2}, // strides + {3, 3, 3, 3}, // pads + {1, 1}, // dilations + "NOTSET", // auto_pad + ExpectedEPNodeAssignment::All); } #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) diff --git a/onnxruntime/test/providers/qnn/matmul_test.cpp b/onnxruntime/test/providers/qnn/matmul_test.cpp index 6edb6ecdcfb1a..e721ccbcb45a9 100644 --- a/onnxruntime/test/providers/qnn/matmul_test.cpp +++ b/onnxruntime/test/providers/qnn/matmul_test.cpp @@ -27,28 +27,31 @@ static GetTestModelFn BuildMatMulOpTestCase(const TestInputDef& input1_de } // Returns a function that creates a graph with a QDQ MatMul operator. 
-template -static GetTestQDQModelFn BuildMatMulOpQDQTestCase(const TestInputDef& input1_def, - const TestInputDef& input2_def) { - return [input1_def, input2_def](ModelTestBuilder& builder, - std::vector>& output_qparams) { +template +static GetTestQDQModelFn BuildMatMulOpQDQTestCase(const TestInputDef& input1_def, + const TestInputDef& input2_def, + bool use_contrib_qdq) { + return [input1_def, input2_def, use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { // input1 -> Q -> DQ -> NodeArg* input1 = MakeTestInput(builder, input1_def); - QuantParams input1_qparams = GetTestInputQuantParams(input1_def); - auto* input1_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale, input1_qparams.zero_point); + QuantParams input1_qparams = GetTestInputQuantParams(input1_def); + auto* input1_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale, input1_qparams.zero_point, + use_contrib_qdq); // input2 -> Q -> DQ -> NodeArg* input2 = MakeTestInput(builder, input2_def); - QuantParams input2_qparams = GetTestInputQuantParams(input2_def); - auto* input2_qdq = AddQDQNodePair(builder, input2, input2_qparams.scale, input2_qparams.zero_point); + QuantParams input2_qparams = GetTestInputQuantParams(input2_def); + auto* input2_qdq = AddQDQNodePair(builder, input2, input2_qparams.scale, input2_qparams.zero_point, + use_contrib_qdq); // MatMul auto* op_output = builder.MakeIntermediate(); builder.AddNode("MatMul", {input1_qdq, input2_qdq}, {op_output}); // op_output -> Q -> DQ -> output - AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, - output_qparams[0].zero_point); + AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); }; } @@ -75,11 +78,13 @@ static void RunMatMulOpOpTest(const TestInputDef& input1_def, // Runs a QDQ MatMul model on the QNN HTP backend. Checks the graph node assignment, and that the // QDQ model is accurate on QNN EP (compared to CPU EP). -template +template static void RunQDQMatMulOpOpTest(const TestInputDef& input1_def, const TestInputDef& input2_def, ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 18) { + int opset = 18, + bool use_contrib_qdq = false, + float fp32_abs_err = 1e-4f) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -88,11 +93,12 @@ static void RunQDQMatMulOpOpTest(const TestInputDef& input1_def, #endif TestQDQModelAccuracy(BuildMatMulOpTestCase(input1_def, input2_def), - BuildMatMulOpQDQTestCase(input1_def, input2_def), + BuildMatMulOpQDQTestCase(input1_def, input2_def, + use_contrib_qdq), provider_options, opset, expected_ep_assignment, - 1e-5f); + fp32_abs_err); } // @@ -127,16 +133,68 @@ TEST_F(QnnCPUBackendTests, MatMulOp_Broadcast) { // TEST_F(QnnHTPBackendTests, MatMulOp_HTP_u8) { - RunQDQMatMulOpOpTest(TestInputDef({2, 3}, false, {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}), - TestInputDef({3, 2}, false, {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}), - ExpectedEPNodeAssignment::All, 18); + std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; + std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; + RunQDQMatMulOpOpTest(TestInputDef({2, 3}, false, input0_data), + TestInputDef({3, 2}, false, input1_data), + ExpectedEPNodeAssignment::All, 18); } -// Test MatMul broadcasting +// Test QDQ MatMul with 16-bit act, 8-bit weights (static) +// TODO: (SLIGHT) Inaccuracy detected for output 'output', element 0. 
+// Output quant params: scale=0.0015259021893143654, zero_point=0. +// Expected val: 98 +// QNN QDQ val: 97.720298767089844 (err 0.27970123291015625) +// CPU QDQ val: 97.726402282714844 (err 0.27359771728515625) +TEST_F(QnnHTPBackendTests, MatMulOp_HTP_A16_W8Static) { + std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; + std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; + RunQDQMatMulOpOpTest(TestInputDef({2, 3}, false, input0_data), + TestInputDef({3, 2}, true, input1_data), + ExpectedEPNodeAssignment::All, + 18, + true, // Use com.microsoft Q/DQ ops + 7e-3f); +} + +// Test 16-bit QDQ MatMul with static weights +// TODO: Inaccuracy detected for output 'output', element 0. +// Output quant params: scale=0.0015259021893143654, zero_point=0. +// Expected val: 98 +// QNN QDQ val: 0.65461206436157227 (err 97.345390319824219) +// CPU QDQ val: 98.002593994140625 (err 0.002593994140625) +TEST_F(QnnHTPBackendTests, DISABLED_MatMulOp_HTP_A16_W16) { + std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; + std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; + RunQDQMatMulOpOpTest(TestInputDef({2, 3}, false, input0_data), + TestInputDef({3, 2}, true, input1_data), + ExpectedEPNodeAssignment::All, + 18, + true); // Use com.microsoft Q/DQ ops +} + +// Test 8-bit QDQ MatMul broadcasting TEST_F(QnnHTPBackendTests, MatMulOp_Broadcast) { - RunQDQMatMulOpOpTest(TestInputDef({28, 1, 64}, false, -10.0f, 10.0f), - TestInputDef({64, 32}, false, -10.0f, 10.0f), - ExpectedEPNodeAssignment::All, 18); + RunQDQMatMulOpOpTest(TestInputDef({28, 1, 64}, false, -10.0f, 10.0f), + TestInputDef({64, 32}, false, -10.0f, 10.0f), + ExpectedEPNodeAssignment::All, 18); +} + +// Test 16-bit QDQ MatMul broadcasting +// TODO: Inaccuracy detected for output 'output', element 0. +// Output quant params: scale=0.0028538699261844158, zero_point=6050. 
+// Expected val: 169.76341247558594 +// QNN QDQ val: -16.675161361694336 (err 186.43856811523438) +// CPU QDQ val: 169.762451171875 (err 0.0009613037109375) +TEST_F(QnnHTPBackendTests, DISABLED_MatMulOp_Broadcast_A16_W16) { + std::vector input_a = GetFloatDataInRange(-10.0f, 10.0f, 28 * 64); + std::vector input_b = GetFloatDataInRange(-10.0f, 10.0f, 64 * 32); + + RunQDQMatMulOpOpTest(TestInputDef({28, 1, 64}, false, input_a), + TestInputDef({64, 32}, true, input_b), + ExpectedEPNodeAssignment::All, + 18, + true); // Use com.microsoft Q/DQ ops } #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 80b929e9dafbe..a441e828c0cc6 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -260,8 +260,8 @@ TEST_F(QnnCPUBackendTests, TestNHWCResizeShapeInference_sizes_opset18) { TEST_F(QnnHTPBackendTests, TestNHWCResizeShapeInference_qdq_sizes_opset18) { RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", true); } -#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #endif // !defined(ORT_MINIMAL_BUILD) } // namespace test diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 548f80675a622..724e9a11cd781 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -116,7 +116,8 @@ void InferenceModel(const std::string& model_data, const char* log_id, ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &output_vals)); } -NodeArg* MakeTestQDQBiasInput(ModelTestBuilder& builder, const TestInputDef& bias_def, float bias_scale) { +NodeArg* MakeTestQDQBiasInput(ModelTestBuilder& builder, const TestInputDef& bias_def, float bias_scale, + bool use_contrib_qdq) { NodeArg* bias_int32 = nullptr; // Bias must be int32 to be detected as a QDQ node unit. @@ -124,7 +125,8 @@ NodeArg* MakeTestQDQBiasInput(ModelTestBuilder& builder, const TestInputDef bias_int32_def(bias_def.GetShape(), bias_def.IsInitializer(), static_cast(rand_info.min / bias_scale), + TestInputDef bias_int32_def(bias_def.GetShape(), bias_def.IsInitializer(), + static_cast(rand_info.min / bias_scale), static_cast(rand_info.max / bias_scale)); bias_int32 = MakeTestInput(builder, bias_int32_def); } else { @@ -143,7 +145,7 @@ NodeArg* MakeTestQDQBiasInput(ModelTestBuilder& builder, const TestInputDef(bias_int32, bias_scale, 0, bias); + builder.AddDequantizeLinearNode(bias_int32, bias_scale, 0, bias, use_contrib_qdq); return bias; } diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index 1b0b85319918f..fd572fa17f2b1 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -266,6 +266,8 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe std::vector output_names; InferenceModel(f32_model_data, "f32_model_logger", nullptr, ExpectedEPNodeAssignment::All, f32_helper.feeds_, output_names, cpu_f32_outputs); + ASSERT_FALSE(cpu_f32_outputs.empty()); + const size_t num_outputs = cpu_f32_outputs.size(); // Compute output range(s) and quantization params. 
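For readers tracing the new use_contrib_qdq parameter through these helpers: it ultimately selects the domain of the emitted Q/DQ nodes. The sketch below illustrates the idea only; the helper name and the exact ModelTestBuilder signatures are assumptions inferred from calls visible elsewhere in this diff, not verbatim repository code.

// Hedged sketch (not repository code) of what use_contrib_qdq selects when a
// Q -> DQ pair is emitted. The AddQuantizeLinearNode/AddDequantizeLinearNode
// signatures are assumed from the builder calls appearing in this diff.
template <typename QType>
NodeArg* AddQDQNodePairSketch(ModelTestBuilder& builder, NodeArg* input,
                              float scale, QType zero_point,
                              bool use_contrib_qdq = false) {
  NodeArg* q_output = builder.MakeIntermediate();
  NodeArg* dq_output = builder.MakeIntermediate();
  // use_contrib_qdq == true requests com.microsoft.QuantizeLinear and
  // com.microsoft.DequantizeLinear, which accept the 16-bit quantized types
  // (uint16_t/int16_t) that the ONNX-domain Q/DQ ops do not support at the
  // opsets exercised by these tests.
  builder.AddQuantizeLinearNode<QType>(input, scale, zero_point, q_output, use_contrib_qdq);
  builder.AddDequantizeLinearNode<QType>(q_output, scale, zero_point, dq_output, use_contrib_qdq);
  return dq_output;
}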
@@ -432,7 +434,8 @@ inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef manual quantization (int32) => DQ => final float bias -NodeArg* MakeTestQDQBiasInput(ModelTestBuilder& builder, const TestInputDef& bias_def, float bias_scale); +NodeArg* MakeTestQDQBiasInput(ModelTestBuilder& builder, const TestInputDef& bias_def, float bias_scale, + bool use_contrib_qdq = false); /** * Returns a function that builds a model with a single operator with N inputs of the same element type. @@ -479,9 +482,10 @@ template inline GetTestQDQModelFn BuildQDQOpTestCase(const std::string& op_type, const std::vector>& input_defs, const std::vector& attrs, - const std::string& op_domain = kOnnxDomain) { - return [op_type, input_defs, attrs, op_domain](ModelTestBuilder& builder, - std::vector>& output_qparams) { + const std::string& op_domain = kOnnxDomain, + bool use_contrib_qdq = false) { + return [op_type, input_defs, attrs, op_domain, + use_contrib_qdq](ModelTestBuilder& builder, std::vector>& output_qparams) { std::vector op_inputs; op_inputs.reserve(input_defs.size()); @@ -489,7 +493,7 @@ inline GetTestQDQModelFn BuildQDQOpTestCase(const std::string& op_ty NodeArg* input = MakeTestInput(builder, input_def); QuantParams input_qparams = GetTestInputQuantParams(input_def); NodeArg* input_after_qdq = AddQDQNodePair(builder, input, input_qparams.scale, - input_qparams.zero_point); + input_qparams.zero_point, use_contrib_qdq); op_inputs.push_back(input_after_qdq); } @@ -503,7 +507,7 @@ inline GetTestQDQModelFn BuildQDQOpTestCase(const std::string& op_ty // op_output -> Q -> DQ -> output AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, - output_qparams[0].zero_point); + output_qparams[0].zero_point, use_contrib_qdq); }; } @@ -563,4 +567,4 @@ bool ReduceOpHasAxesInput(const std::string& op_type, int opset_version); } // namespace test } // namespace onnxruntime -#endif // !defined(ORT_MINIMAL_BUILD) \ No newline at end of file +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/reduce_op_test.cc b/onnxruntime/test/providers/qnn/reduce_op_test.cc index c3c2b578a1bd0..57252f93492e5 100644 --- a/onnxruntime/test/providers/qnn/reduce_op_test.cc +++ b/onnxruntime/test/providers/qnn/reduce_op_test.cc @@ -648,4 +648,4 @@ TEST_F(QnnHTPBackendTests, ReduceMeanS8Opset18) { } // namespace test } // namespace onnxruntime -#endif \ No newline at end of file +#endif diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index eed12af3c703c..63498982930f5 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -104,6 +104,7 @@ static void RunQDQOpTest(const std::string& op_type, int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, const std::string& op_domain = kOnnxDomain, + bool use_contrib_qdq = false, float fp32_abs_err = 1e-4f) { ProviderOptions provider_options; #if defined(_WIN32) @@ -113,7 +114,7 @@ static void RunQDQOpTest(const std::string& op_type, #endif TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, attrs, op_domain), - BuildQDQOpTestCase(op_type, input_defs, attrs, op_domain), + BuildQDQOpTestCase(op_type, input_defs, attrs, op_domain, use_contrib_qdq), provider_options, opset_version, expected_ep_assignment, @@ -151,6 +152,17 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Sigmoid) { ExpectedEPNodeAssignment::All); } +// Tests accuracy of 16-bit QDQ Sigmoid. 
+TEST_F(QnnHTPBackendTests, UnaryOp_Sigmoid_U16) {
+  RunQDQOpTest<uint16_t>("Sigmoid",
+                         {TestInputDef<float>({1, 2, 3}, false, GetFloatDataInRange(-10.0f, 10.0f, 6))},
+                         {},
+                         13,
+                         ExpectedEPNodeAssignment::All,
+                         kOnnxDomain,
+                         true);  // Use MS domain Q/DQ ops
+}
+
 // Test the accuracy of QDQ Tanh.
 TEST_F(QnnHTPBackendTests, UnaryOp_Tanh) {
   RunQDQOpTest("Tanh",
@@ -160,6 +172,17 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Tanh) {
                ExpectedEPNodeAssignment::All);
 }
 
+// Tests accuracy of 16-bit QDQ Tanh.
+TEST_F(QnnHTPBackendTests, UnaryOp_Tanh_U16) {
+  RunQDQOpTest<uint16_t>("Tanh",
+                         {TestInputDef<float>({1, 2, 3}, false, GetFloatDataInRange(-10.0f, 10.0f, 6))},
+                         {},
+                         13,
+                         ExpectedEPNodeAssignment::All,
+                         kOnnxDomain,
+                         true);  // Use MS domain Q/DQ ops
+}
+
 // Check that QNN compiles DQ -> Gelu -> Q as a single unit.
 // Use an input of rank 3.
 TEST_F(QnnHTPBackendTests, UnaryOp_Gelu) {
@@ -171,6 +194,24 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Gelu) {
                kMSDomain);  // GeLu is a contrib op.
 }
 
+// Tests accuracy of 16-bit QDQ GeLu.
+// TODO(adrianlizarraga): Inaccuracy detected for output 'output', element 5.
+// Output quant params: scale=0.00015259021893143654, zero_point=0.
+// Expected val: 10
+// QNN QDQ val: 9.997406005859375 (err 0.002593994140625)
+// CPU QDQ val: 9.999847412109375 (err 0.000152587890625)
+TEST_F(QnnHTPBackendTests, UnaryOp_Gelu_U16) {
+  const std::vector<float> input_data = {-10.0f, -8.4f, 0.0f, 4.3f, 7.1f, 10.0f};
+  RunQDQOpTest<uint16_t>("Gelu",
+                         {TestInputDef<float>({1, 2, 3}, false, input_data)},
+                         {},
+                         11,
+                         ExpectedEPNodeAssignment::All,
+                         kMSDomain,  // GeLu is a contrib op.
+                         true,       // Use MS domain Q/DQ ops.
+                         0.0025f);   // TODO(adrianlizarraga): Accuracy
+}
+
 // Check that QNN compiles DQ -> Elu -> Q as a single unit.
 // Use an input of rank 3.
 TEST_F(QnnHTPBackendTests, UnaryOp_Elu) {
@@ -181,6 +222,23 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Elu) {
                ExpectedEPNodeAssignment::All);
 }
 
+// Tests accuracy of 16-bit QDQ Elu.
+// TODO(adrianlizarraga): Re-enable. This works on QNN SDK 2.14.1!
+// Inaccuracy detected for output 'output', element 1.
+// Output quant params: scale=0.00011093531065853313, zero_point=8992.
+// Expected val: -0.99751651287078857
+// QNN QDQ val: 6.2726154327392578 (err 7.2701320648193359)
+// CPU QDQ val: -0.99753034114837646 (err 1.3828277587890625e-05)
+TEST_F(QnnHTPBackendTests, DISABLED_UnaryOp_Elu_U16) {
+  RunQDQOpTest<uint16_t>("Elu",
+                         {TestInputDef<float>({1, 2, 3}, false, GetFloatDataInRange(-10.0f, 10.0f, 6))},
+                         {},
+                         11,
+                         ExpectedEPNodeAssignment::All,
+                         kOnnxDomain,
+                         true);
+}
+
 // Tests accuracy of QDQ Relu
 // TODO: Relu does not set negative values to zero!
 // Could be due to ORT's ReluQuantFusion!
@@ -208,6 +266,24 @@ TEST_F(QnnHTPBackendTests, UnaryOp_HardSwish) {
                ExpectedEPNodeAssignment::All);
 }
 
+// Tests accuracy of 16-bit QDQ HardSwish
+// TODO(adrianlizarraga): Inaccuracy detected for output 'output', element 5.
+// Output quant params: scale=0.00015259021893143654, zero_point=0.
+// Expected val: 10
+// QNN QDQ val: 9.999237060546875 (err 0.000762939453125)
+// CPU QDQ val: 9.999847412109375 (err 0.000152587890625)
TEST_F(QnnHTPBackendTests, UnaryOp_HardSwish_U16) {
+  const std::vector<float> input_data = {-10.0f, -8.4f, 0.0f, 4.3f, 7.1f, 10.0f};
+  RunQDQOpTest<uint16_t>("HardSwish",
+                         {TestInputDef<float>({1, 2, 3}, false, input_data)},
+                         {},
+                         14,
+                         ExpectedEPNodeAssignment::All,
+                         kOnnxDomain,
+                         true,
+                         0.001f);  // TODO(adrianlizarraga): Remove additional tolerance needed for inaccuracy
+}
+
 // Check that QNN compiles DQ -> Atan -> Q as a single unit.
 // Use an input of rank 3.
 TEST_F(QnnHTPBackendTests, UnaryOp_Atan) {
@@ -218,6 +294,24 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Atan) {
                ExpectedEPNodeAssignment::All);
 }
 
+// Tests accuracy of 16-bit QDQ Atan
+// TODO(adrianlizarraga): Inaccuracy detected for output 'output', element 1.
+// Output quant params: scale=4.4895936298416927e-05, zero_point=32768.
+// Expected val: -1.4219063520431519
+// QNN QDQ val: -1.4220787286758423 (err 0.00017237663269042969)
+// CPU QDQ val: -1.4218991994857788 (err 7.152557373046875e-06)
+TEST_F(QnnHTPBackendTests, UnaryOp_Atan_U16) {
+  const std::vector<float> input_data = GetFloatDataInRange(-10.0f, 10.0f, 6);
+  RunQDQOpTest<uint16_t>("Atan",
+                         {TestInputDef<float>({1, 2, 3}, false, input_data)},
+                         {},
+                         14,
+                         ExpectedEPNodeAssignment::All,
+                         kOnnxDomain,  // Atan domain
+                         true,         // Q/DQ op domain is com.microsoft
+                         1.8e-4f);
+}
+
 // Check that QNN compiles DQ -> Asin -> Q as a single unit.
 // Use an input of rank 3.
 TEST_F(QnnHTPBackendTests, UnaryOp_Asin) {
@@ -238,6 +332,18 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Sign) {
                ExpectedEPNodeAssignment::All);
 }
 
+// Tests accuracy of 16-bit QDQ Sign
+TEST_F(QnnHTPBackendTests, UnaryOp_Sign_U16) {
+  const std::vector<float> input_data = GetFloatDataInRange(-10.0f, 10.0f, 6);
+  RunQDQOpTest<uint16_t>("Sign",
+                         {TestInputDef<float>({1, 2, 3}, false, input_data)},
+                         {},
+                         13,
+                         ExpectedEPNodeAssignment::All,
+                         kOnnxDomain,  // Sign op domain
+                         true);        // Use com.microsoft Q/DQ op domains
+}
+
 // Check that QNN compiles DQ -> Sin -> Q as a single unit.
 // Use an input of rank 3.
 TEST_F(QnnHTPBackendTests, UnaryOp_Sin) {
@@ -260,7 +366,7 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Cos) {
 
 // Check that QNN compiles DQ -> Cos -> Q as a single unit.
 // Use an input of rank 3.
-TEST_F(QnnHTPBackendTests, UnaryOp_Cos_Inaccurate) {
+TEST_F(QnnHTPBackendTests, UnaryOp_Cos_InaccurateFixed) {
   RunQDQOpTest("Cos",
                {TestInputDef<float>({1, 2, 3}, false, {-3.14159f, -1.88436f, -0.542863f, 0.0f, 1.05622f, 3.14159f})},
                {},
@@ -326,6 +432,18 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Round) {
                ExpectedEPNodeAssignment::All);
 }
 
+// Tests accuracy of 16-bit QDQ Log
+TEST_F(QnnHTPBackendTests, UnaryOp_Log_U16) {
+  const std::vector<float> input_data = GetFloatDataInRange(1.0f, 128.0f, 6);
+  RunQDQOpTest<uint16_t>("Log",
+                         {TestInputDef<float>({1, 2, 3}, false, input_data)},
+                         {},
+                         11,
+                         ExpectedEPNodeAssignment::All,
+                         kOnnxDomain,  // Log op domain
+                         true);        // Use com.microsoft domain for Q/DQ ops
+}
+
 // Check that QNN compiles DQ -> Softmax -> Q as a single unit.
 // Test that the default axis (-1) for SoftMax opset 13 works.
 TEST_F(QnnHTPBackendTests, UnaryOp_Softmax13_DefaultAxis) {
@@ -336,6 +454,18 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Softmax13_DefaultAxis) {
                ExpectedEPNodeAssignment::All);
 }
 
+// Tests accuracy of 16-bit QDQ Softmax (opset 13) with default axis
+TEST_F(QnnHTPBackendTests, UnaryOp_Softmax13_U16_DefaultAxis) {
+  const std::vector<float> input_data = GetFloatDataInRange(-5.0f, 5.0f, 6);
+  RunQDQOpTest<uint16_t>("Softmax",
+                         {TestInputDef<float>({1, 2, 3}, false, input_data)},
+                         {},  // Uses default axis of -1 for opset 13
+                         13,
+                         ExpectedEPNodeAssignment::All,
+                         kOnnxDomain,  // Softmax's domain
+                         true);        // Use com.microsoft domain for Q/DQ ops
+}
+
 // Check that QNN compiles DQ -> Softmax -> Q as a single unit.
 // Test that an axis != -1 is not supported.
 TEST_F(QnnHTPBackendTests, UnaryOp_Softmax13_UnsupportedAxis) {
@@ -410,7 +540,7 @@ TEST_F(QnnHTPBackendTests, UnaryOp_LogSoftmax11_SetValidAxis) {
                ExpectedEPNodeAssignment::All);
 }
 
-// Test QDQ Abs op.
+// Test accuracy of QDQ Abs op.
 TEST_F(QnnHTPBackendTests, UnaryOp_Abs) {
   RunQDQOpTest("Abs",
                {TestInputDef<float>({1, 2, 3}, false, GetFloatDataInRange(-10.0f, 10.0f, 6))},
@@ -419,7 +549,19 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Abs) {
                ExpectedEPNodeAssignment::All);
 }
 
-// Test QDQ Ceil op.
+// Test accuracy of 16-bit QDQ Abs op.
+TEST_F(QnnHTPBackendTests, UnaryOp_Abs_U16) {
+  const std::vector<float> input_data = GetFloatDataInRange(-10.0f, 10.0f, 6);
+  RunQDQOpTest<uint16_t>("Abs",
+                         {TestInputDef<float>({1, 2, 3}, false, input_data)},
+                         {},
+                         13,
+                         ExpectedEPNodeAssignment::All,
+                         kOnnxDomain,  // Abs op's domain
+                         true);        // Use com.microsoft domain for Q/DQ ops
+}
+
+// Test accuracy of QDQ Ceil op.
 TEST_F(QnnHTPBackendTests, UnaryOp_Ceil) {
   const std::vector<float> input_data = GetFloatDataInRange(-12.0f, 12.0f, 6);
   RunQDQOpTest("Ceil",
@@ -429,6 +571,18 @@
                ExpectedEPNodeAssignment::All);
 }
 
+// Test accuracy of 16-bit QDQ Ceil op.
+TEST_F(QnnHTPBackendTests, UnaryOp_Ceil_U16) {
+  const std::vector<float> input_data = GetFloatDataInRange(-12.0f, 12.0f, 6);
+  RunQDQOpTest<uint16_t>("Ceil",
+                         {TestInputDef<float>({1, 2, 3}, false, input_data)},
+                         {},
+                         13,
+                         ExpectedEPNodeAssignment::All,
+                         kOnnxDomain,  // Ceil op's domain
+                         true);        // Use com.microsoft domain for Q/DQ ops
+}
+
 // Test QDQ Floor op.
 TEST_F(QnnHTPBackendTests, UnaryOp_Floor) {
   const std::vector<float> input_data = GetFloatDataInRange(-12.0f, 12.0f, 6);
@@ -457,6 +611,26 @@ TEST_F(QnnHTPBackendTests, DepthToSpaceOp_CRD) {
                ExpectedEPNodeAssignment::All);
 }
 
+// Test 16-bit QDQ DepthToSpace.
+TEST_F(QnnHTPBackendTests, DepthToSpaceOp_U16_CRD) {
+  const std::vector<float> X = {0., 1., 2.,
+                                3., 4., 5.,
+                                9., 10., 11.,
+                                12., 13., 14.,
+                                18., 19., 20.,
+                                21., 22., 23.,
+                                27., 28., 29.,
+                                30., 31., 32.};
+  RunQDQOpTest<uint16_t>("DepthToSpace",
+                         {TestInputDef<float>({1, 4, 2, 3}, false, X)},
+                         {utils::MakeAttribute("blocksize", static_cast<int64_t>(2)),
+                          utils::MakeAttribute("mode", "CRD")},
+                         11,
+                         ExpectedEPNodeAssignment::All,
+                         kOnnxDomain,  // Op's domain
+                         true);        // Use com.microsoft domain for Q/DQ ops
+}
+
 // Test QDQ DepthToSpace.
 TEST_F(QnnHTPBackendTests, DepthToSpaceOp_DCR) {
   const std::vector<float> X = {0., 1., 2.,
@@ -489,6 +663,22 @@ TEST_F(QnnHTPBackendTests, SpaceToDepthOp) {
                ExpectedEPNodeAssignment::All);
 }
 
+// Test 16-bit QDQ SpaceToDepth.
+TEST_F(QnnHTPBackendTests, SpaceToDepthOp_U16) {
+  const std::vector<float> X = {0.0f, 0.1f, 0.2f, 0.3f,
+                                1.0f, 1.1f, 1.2f, 1.3f,
+
+                                2.0f, 2.1f, 2.2f, 2.3f,
+                                3.0f, 3.1f, 3.2f, 3.3f};
+  RunQDQOpTest<uint16_t>("SpaceToDepth",
+                         {TestInputDef<float>({1, 2, 2, 4}, false, X)},
+                         {utils::MakeAttribute("blocksize", static_cast<int64_t>(2))},
+                         11,
+                         ExpectedEPNodeAssignment::All,
+                         kOnnxDomain,  // Op's domain
+                         true);        // Use com.microsoft domain for Q/DQ ops
+}
+
 // Run QDQ model on HTP twice
 // 1st run will generate the Qnn context cache binary file
 // 2nd run will load and run from Qnn context cache binary file
@@ -561,7 +751,7 @@ TEST_F(QnnHTPBackendTests, QuantAccuracyTest) {
                ExpectedEPNodeAssignment::All);
 }
 
-// Test QDQ Add
+// Test 8-bit QDQ Add
 TEST_F(QnnHTPBackendTests, BinaryOp_Add4D) {
   RunQDQOpTest("Add",
                {TestInputDef<float>({1, 2, 2, 2}, false, -10.0f, 10.0f),
@@ -571,7 +761,20 @@ TEST_F(QnnHTPBackendTests, BinaryOp_Add4D) {
                ExpectedEPNodeAssignment::All);
 }
 
+// Test 16-bit QDQ Add
+TEST_F(QnnHTPBackendTests, BinaryOp_Add4D_U16) {
+  std::vector<float> input_data = GetFloatDataInRange(-10.0f, 10.0f, 8);
+  RunQDQOpTest<uint16_t>("Add",
+                         {TestInputDef<float>({1, 2, 2, 2}, false, input_data),
+                          TestInputDef<float>({1, 2, 2, 2}, false, input_data)},
+                         {},
+                         17,
+                         ExpectedEPNodeAssignment::All,
+                         kOnnxDomain,
+                         true);  // Use com.microsoft Q/DQ ops
+}
+
-// Test QDQ Sub
+// Test 8-bit QDQ Sub
 TEST_F(QnnHTPBackendTests, BinaryOp_Sub4D) {
   RunQDQOpTest("Sub",
               {TestInputDef<float>({1, 3, 8, 8}, false, -10.0f, 10.0f),
@@ -581,6 +784,20 @@ TEST_F(QnnHTPBackendTests, BinaryOp_Sub4D) {
                ExpectedEPNodeAssignment::All);
 }
 
+// Test 16-bit QDQ Sub
+TEST_F(QnnHTPBackendTests, BinaryOp_Sub4D_U16) {
+  std::vector<float> input0_data = GetFloatDataInRange(-10.0f, 10.0f, 8);
+  std::vector<float> input1_data = GetFloatDataInRange(0.0f, 20.0f, 8);
+  RunQDQOpTest<uint16_t>("Sub",
+                         {TestInputDef<float>({1, 2, 2, 2}, false, input0_data),
+                          TestInputDef<float>({1, 2, 2, 2}, false, input1_data)},
+                         {},
+                         17,
+                         ExpectedEPNodeAssignment::All,
+                         kOnnxDomain,
+                         true);  // Use com.microsoft Q/DQ ops
+}
+
 TEST_F(QnnHTPBackendTests, BinaryOp_Sub4D_LargeInputs) {
   RunQDQOpTest("Sub",
                {TestInputDef<float>({1, 3, 768, 1152}, false, -1.0f, 1.0f),
@@ -656,6 +873,20 @@ TEST_F(QnnHTPBackendTests, BinaryOp_Div4D_SmallInputs) {
                ExpectedEPNodeAssignment::All);
 }
 
+// Test 16-bit QDQ Div with small input values.
+TEST_F(QnnHTPBackendTests, BinaryOp_Div4D_U16_SmallInputs) {
+  std::vector<float> input0_data = {-10.0f, -8.0f, -1.0f, 0.0f, 1.0f, 2.1f, 8.0f, 10.0f};
+  std::vector<float> input1_data = {5.0f, 4.0f, 1.0f, 1.0f, 1.0f, 4.0f, 4.0f, 5.0f};
+  RunQDQOpTest<uint16_t>("Div",
+                         {TestInputDef<float>({1, 2, 2, 2}, false, input0_data),
+                          TestInputDef<float>({1, 2, 2, 2}, false, input1_data)},
+                         {},
+                         17,
+                         ExpectedEPNodeAssignment::All,
+                         kOnnxDomain,
+                         true);  // Use com.microsoft Q/DQ ops
+}
+
 // TODO: Enable when this is fixed.
 // QNN v2.13: Inaccuracy detected for output 'output', element 2551923.
 // Output quant params: scale=4100.92626953125, zero_point=126.
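Aside: the tests above lean on GetFloatDataInRange for deterministic inputs. Presumably it returns evenly spaced values across [min, max]; a self-contained sketch of that assumed behavior:

#include <cstddef>
#include <vector>

// Sketch of the assumed behavior of GetFloatDataInRange: num_vals floats
// evenly spaced over [min_val, max_val]. Not the repository's exact code.
static std::vector<float> GetFloatDataInRangeSketch(float min_val, float max_val, size_t num_vals) {
  std::vector<float> data;
  data.reserve(num_vals);
  if (num_vals == 0) {
    return data;
  }
  if (num_vals == 1) {
    data.push_back(min_val);
    return data;
  }
  // Evenly spaced: first value is min_val, last value is max_val.
  const float step = (max_val - min_val) / static_cast<float>(num_vals - 1);
  for (size_t i = 0; i < num_vals; ++i) {
    data.push_back(min_val + step * static_cast<float>(i));
  }
  return data;
}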
@@ -680,7 +911,7 @@ TEST_F(QnnHTPBackendTests, BinaryOp_Div4D_Broadcast) { ExpectedEPNodeAssignment::All); } -// Test QDQ Mul +// Test 8-bit QDQ Mul TEST_F(QnnHTPBackendTests, BinaryOp_Mul4D) { std::vector input_data = GetFloatDataInRange(-10.0, 10.0f, 8); RunQDQOpTest("Mul", @@ -691,6 +922,19 @@ TEST_F(QnnHTPBackendTests, BinaryOp_Mul4D) { ExpectedEPNodeAssignment::All); } +// Test 16-bit QDQ Mul +TEST_F(QnnHTPBackendTests, BinaryOp_Mul4D_U16) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunQDQOpTest("Mul", + {TestInputDef({1, 2, 2, 2}, false, input_data), + TestInputDef({1, 2, 2, 2}, false, input_data)}, + {}, + 17, + ExpectedEPNodeAssignment::All, + kOnnxDomain, + true); // Use com.microsoft Q/DQ ops +} + // Test And TEST_F(QnnHTPBackendTests, BinaryOp_And4D) { RunOpTest("And", @@ -711,7 +955,7 @@ TEST_F(QnnHTPBackendTests, BinaryOp_HTP_Or_Unsupported) { ExpectedEPNodeAssignment::None); } -// Test QDQ GridSample with bilinear +// Test 8-bit QDQ GridSample with bilinear TEST_F(QnnHTPBackendTests, GridSample_Bilinear) { RunQDQOpTest("GridSample", {TestInputDef({1, 1, 3, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 6)), @@ -723,7 +967,21 @@ TEST_F(QnnHTPBackendTests, GridSample_Bilinear) { ExpectedEPNodeAssignment::All); } -// Test QDQ GridSample with align corners +// Test 16-bit QDQ GridSample with bilinear +TEST_F(QnnHTPBackendTests, GridSample_U16_Bilinear) { + RunQDQOpTest("GridSample", + {TestInputDef({1, 1, 3, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 6)), + TestInputDef({1, 2, 4, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 16))}, + {utils::MakeAttribute("align_corners", static_cast(0)), + utils::MakeAttribute("mode", "bilinear"), + utils::MakeAttribute("padding_mode", "zeros")}, + 17, + ExpectedEPNodeAssignment::All, + kOnnxDomain, + true); // Use com.microsoft Q/DQ ops +} + +// Test 8-bit QDQ GridSample with align corners TEST_F(QnnHTPBackendTests, GridSample_AlignCorners) { RunQDQOpTest("GridSample", {TestInputDef({1, 1, 3, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 6)), @@ -735,6 +993,20 @@ TEST_F(QnnHTPBackendTests, GridSample_AlignCorners) { ExpectedEPNodeAssignment::All); } +// Test 16-bit QDQ GridSample with align corners +TEST_F(QnnHTPBackendTests, GridSample_U16_AlignCorners) { + RunQDQOpTest("GridSample", + {TestInputDef({1, 1, 3, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 6)), + TestInputDef({1, 2, 4, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 16))}, + {utils::MakeAttribute("align_corners", static_cast(1)), + utils::MakeAttribute("mode", "bilinear"), + utils::MakeAttribute("padding_mode", "zeros")}, + 17, + ExpectedEPNodeAssignment::All, + kOnnxDomain, + true); // Use com.microsoft Q/DQ ops +} + // Test QDQ GridSample with padding mode: border // Inaccuracy detected for output 'output', element 0. // Output quant params: scale=0.046370312571525574, zero_point=129. 
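The scale/zero_point pairs quoted in the accuracy TODOs follow from standard asymmetric quantization over the output's float range. The sketch below shows the usual formulation (not necessarily ORT's exact helper); for example, a [0, 100] output range quantized to uint16 gives 100 / 65535 ≈ 0.0015259, matching the 16-bit scales quoted earlier in this diff.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Standard asymmetric quantization parameter selection (sketch only).
// qmax is 255.0f for 8-bit data or 65535.0f for 16-bit data.
struct QuantParamsSketch {
  float scale;
  uint32_t zero_point;
};

static QuantParamsSketch ComputeQuantParamsSketch(float rmin, float rmax, float qmax) {
  rmin = std::min(rmin, 0.0f);  // representable range must include 0
  rmax = std::max(rmax, 0.0f);
  const float scale = (rmax - rmin) / qmax;
  if (scale == 0.0f) {
    return {1.0f, 0};  // degenerate all-zero range
  }
  // The zero-point maps real 0.0 onto the integer grid: q = round(x / scale) + zp.
  const float zero_point = -rmin / scale;
  return {scale, static_cast<uint32_t>(std::round(std::clamp(zero_point, 0.0f, qmax)))};
}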
@@ -751,7 +1023,7 @@ TEST_F(QnnHTPBackendTests, DISABLED_GridSample_BorderPadding) {
               ExpectedEPNodeAssignment::All);
 }
 
-// Test QDQ GridSample with nearest mode
+// Test 8-bit QDQ GridSample with nearest mode
 TEST_F(QnnHTPBackendTests, GridSample_Nearest) {
   RunQDQOpTest("GridSample",
               {TestInputDef<float>({1, 1, 3, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 6)),
@@ -761,6 +1033,18 @@ TEST_F(QnnHTPBackendTests, GridSample_Nearest) {
               ExpectedEPNodeAssignment::All);
 }
 
+// Test 16-bit QDQ GridSample with nearest mode
+TEST_F(QnnHTPBackendTests, GridSample_U16_Nearest) {
+  RunQDQOpTest<uint16_t>("GridSample",
+                         {TestInputDef<float>({1, 1, 3, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 6)),
+                          TestInputDef<float>({1, 2, 4, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 16))},
+                         {utils::MakeAttribute("mode", "nearest")},
+                         17,
+                         ExpectedEPNodeAssignment::All,
+                         kOnnxDomain,
+                         true);
+}
+
 // Test QDQ GridSample with reflection padding mode
 // Inaccuracy detected for output 'output', element 2.
 // Output quant params: scale=0.024269860237836838, zero_point=0.
@@ -801,4 +1085,4 @@ TEST_F(QnnHTPBackendTests, VariadicOp_Concat_2Inputs_2ndAxis) {
 
 }  // namespace test
 }  // namespace onnxruntime
-#endif
\ No newline at end of file
+#endif
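The test_qdq.py changes below drive 16-bit QDQ through quantize_static by passing the UseQDQContribOps extra option, which makes the quantizer emit Q/DQ nodes in the com.microsoft domain. A minimal standalone sketch of the same call (the model paths, input shape, and calibration reader here are illustrative, not taken from the patch; only the quantize_static arguments mirror the verify_qdq test below):

    # Hypothetical paths and calibration reader; the quantize_static arguments
    # mirror the verify_qdq test that follows.
    import numpy as np
    from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static

    class OneBatchReader(CalibrationDataReader):
        def __init__(self):
            self._batches = iter([{"input": np.random.rand(1, 8, 33, 33).astype(np.float32)}])

        def get_next(self):
            # One calibration batch, then None to signal the end of calibration data.
            return next(self._batches, None)

    quantize_static(
        "conv_relu_fp32.onnx",   # hypothetical float model path
        "conv_relu_qdq16.onnx",  # hypothetical quantized output path
        OneBatchReader(),
        quant_format=QuantFormat.QDQ,
        activation_type=QuantType.QUInt16,         # 16-bit activations
        weight_type=QuantType.QUInt8,              # 8-bit weights (mixed 16/8-bit precision)
        extra_options={"UseQDQContribOps": True},  # emit Q/DQ nodes in the com.microsoft domain
    )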
diff --git a/onnxruntime/test/python/quantization/test_qdq.py b/onnxruntime/test/python/quantization/test_qdq.py
index 3c5f516af4846..5c2db435d7fb5 100644
--- a/onnxruntime/test/python/quantization/test_qdq.py
+++ b/onnxruntime/test/python/quantization/test_qdq.py
@@ -566,28 +566,30 @@ def construct_model_conv_relu(self, output_model_path, input_shape, weight_shape
 
         onnx.save(model, output_model_path)
 
-    def verify(self, per_channel, is_quant_type_int8):
+    def verify_qdq(self, per_channel, activation_type, weight_type, extra_options=None):
         np.random.seed(1)
         model_fp32_path = str(Path(self._tmp_model_dir.name) / f"conv_relu_fp32.{per_channel}.onnx")
-        model_int8_qdq_path = str(Path(self._tmp_model_dir.name) / f"conv_relu_quant_qdq.{per_channel}.onnx")
-        model_int8_qop_path = str(Path(self._tmp_model_dir.name) / f"conv_relu_quant_qop.{per_channel}.onnx")
+        model_qdq_path = str(
+            Path(self._tmp_model_dir.name) / f"conv_relu_quant_qdq.{activation_type}.{weight_type}.{per_channel}.onnx"
+        )
         data_reader = self.input_feeds(1, {"input": [1, 8, 33, 33]})
         self.construct_model_conv_relu(model_fp32_path, [1, 8, 33, 33], [16, 8, 3, 3], [1, 16, 31, 31])
         quantize_static(
             model_fp32_path,
-            model_int8_qdq_path,
+            model_qdq_path,
             data_reader,
             quant_format=QuantFormat.QDQ,
             per_channel=per_channel,
             reduce_range=per_channel,
-            activation_type=QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8,
-            weight_type=QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8,
+            activation_type=activation_type,
+            weight_type=weight_type,
+            extra_options=extra_options,
         )
         data_reader.rewind()
         # topo sort check
         check_op_type_order(
             self,
-            model_int8_qdq_path,
+            model_qdq_path,
             [
                 "DequantizeLinear",
                 "QuantizeLinear",
@@ -597,9 +599,15 @@ def verify(self, per_channel, is_quant_type_int8):
                 "DequantizeLinear",
             ],
         )
-        check_model_correctness(self, model_fp32_path, model_int8_qdq_path, data_reader.get_next())
+        check_model_correctness(self, model_fp32_path, model_qdq_path, data_reader.get_next())
+
+    def verify_qop(self, per_channel, is_quant_type_int8):
+        np.random.seed(1)
+        model_fp32_path = str(Path(self._tmp_model_dir.name) / f"conv_relu_fp32.{per_channel}.onnx")
+        model_int8_qop_path = str(Path(self._tmp_model_dir.name) / f"conv_relu_quant_qop.{per_channel}.onnx")
+        data_reader = self.input_feeds(1, {"input": [1, 8, 33, 33]})
+        self.construct_model_conv_relu(model_fp32_path, [1, 8, 33, 33], [16, 8, 3, 3], [1, 16, 31, 31])
-        data_reader.rewind()
         quantize_static(
             model_fp32_path,
             model_int8_qop_path,
@@ -617,10 +625,25 @@ def verify(self, per_channel, is_quant_type_int8):
 
     def test_quantize_conv_without_bias(self):
         # only test cases per_channel=True and reduce_range=True to avoid saturation on avx2 and avx512 for weight type int8
-        self.verify(True, True)  # per_channel:False, is_quant_type_int8:True
+        self.verify_qdq(True, QuantType.QInt8, QuantType.QInt8)  # per_channel:True
+        self.verify_qop(True, True)  # per_channel:True, is_quant_type_int8:True
 
-        self.verify(False, False)  # per_channel:False, is_quant_type_int8:False
-        self.verify(True, False)  # per_channel:True, is_quant_type_int8:False
+        self.verify_qdq(False, QuantType.QUInt8, QuantType.QUInt8)  # per_channel:False
+        self.verify_qop(False, False)  # per_channel:False, is_quant_type_int8:False
+
+        self.verify_qdq(True, QuantType.QUInt8, QuantType.QUInt8)  # per_channel:True
+        self.verify_qop(True, False)  # per_channel:True, is_quant_type_int8:False
+
+        # 16-bit QDQ via contrib ops
+        self.verify_qdq(False, QuantType.QUInt16, QuantType.QUInt16, {"UseQDQContribOps": True})
+        self.verify_qdq(False, QuantType.QInt16, QuantType.QInt16, {"UseQDQContribOps": True})
+        self.verify_qdq(False, QuantType.QUInt16, QuantType.QUInt8, {"UseQDQContribOps": True})
+        self.verify_qdq(False, QuantType.QInt16, QuantType.QInt8, {"UseQDQContribOps": True})
+
+        self.verify_qdq(True, QuantType.QUInt16, QuantType.QUInt16, {"UseQDQContribOps": True})
+        self.verify_qdq(True, QuantType.QInt16, QuantType.QInt16, {"UseQDQContribOps": True})
+        self.verify_qdq(True, QuantType.QUInt16, QuantType.QUInt8, {"UseQDQContribOps": True})
+        self.verify_qdq(True, QuantType.QInt16, QuantType.QInt8, {"UseQDQContribOps": True})
 
     def test_quantize_relu_conv(self):
         float_model_path = str(Path(self._tmp_model_dir.name) / "float_relu_convs_model.onnx")
diff --git a/onnxruntime/test/testdata/transform/convert_qdq_ops_to_ms_domain.py b/onnxruntime/test/testdata/transform/convert_qdq_ops_to_ms_domain.py
index 3df127f5d356d..f74342403f4c3 100644
--- a/onnxruntime/test/testdata/transform/convert_qdq_ops_to_ms_domain.py
+++ b/onnxruntime/test/testdata/transform/convert_qdq_ops_to_ms_domain.py
@@ -1,59 +1,154 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
 """
 Loads a model and updates the domain of QuantizeLinear and DequantizeLinear nodes to 'com.microsoft'.
+Optionally updates zero-points to 16-bit data types.
+
 This is used to create models for testing QDQ transformations with the contrib QDQ ops.
-Usage: python3 convert_qdq_ops_to_ms_domain.py <onnx model>
+Usage:
+python3 convert_qdq_ops_to_ms_domain.py --input_model <input model path> --output_model <output model path> --use_16bit_qdq
 
 Models created with this script:
 - qdq_with_multi_consumer_dq_nodes.fixed.qdq_contrib.onnx
+- qdq_with_multi_consumer_dq_nodes.fixed.qdq16_contrib.onnx
 - fusion/constant_folding_dequantizelinear.qdq_contrib.onnx
+- fusion/constant_folding_dequantizelinear.qdq16_contrib.onnx
 - fusion/constant_folding_qdq_node_unit.qdq_contrib.onnx
+- fusion/constant_folding_qdq_node_unit.qdq16_contrib.onnx
 - fusion/constant_folding_qdq_node_unit.graph_output.qdq_contrib.onnx
+- fusion/constant_folding_qdq_node_unit.graph_output.qdq16_contrib.onnx
 """
+from __future__ import annotations
+
+import argparse
 import os
+import struct
 import sys
 
 import onnx
+from onnx import shape_inference
 
 QDQ_OPS = ("QuantizeLinear", "DequantizeLinear")
-
-
-def print_usage(prog_name: str):
-    """
-    Prints the program's command-line arguments and usage.
-    """
-
-    print(f"Usage: {prog_name} <onnx model>")
-
-
-def update_qdq_node_domains(graph):
-    """
-    Updates the domain of all QuantizeLinear and DequantizeLinear nodes
-    in a graph to 'com.microsoft'.
-    """
+QDQ_CONVERT_TYPES = {onnx.TensorProto.UINT8: onnx.TensorProto.UINT16, onnx.TensorProto.INT8: onnx.TensorProto.INT16}
+TYPE_TO_STRUCT_LABEL = {
+    onnx.TensorProto.UINT8: "B",
+    onnx.TensorProto.INT8: "b",
+    onnx.TensorProto.UINT16: "H",
+    onnx.TensorProto.INT16: "h",
+}
+
+
+def convert_initializer_to_16bits(initializer: onnx.TensorProto, target_type: onnx.TensorProto.DataType):
+    byte_order = ">" if sys.byteorder == "big" else "<"
+    byte_label = TYPE_TO_STRUCT_LABEL[initializer.data_type]
+    short_label = TYPE_TO_STRUCT_LABEL[target_type]
+
+    # Do not support external data
+    if initializer.HasField("data_location") and initializer.data_location == onnx.TensorProto.EXTERNAL:
+        raise Exception("Do not support initializers with external data")
+
+    # Need to convert raw_data bytes to 16-bit values.
+    # NOTE: For tensors that use .int32_data instead of .raw_data, we don't need any special handling
+    # other than updating the data type. This is because the upper 24 bits are already cleared to zero.
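+    # (Concrete illustration of the repacking below: a uint8 zero-point holding [0, 128, 255]
+    # has raw_data b"\x00\x80\xff"; unpacking with format "3B" yields (0, 128, 255), and
+    # repacking with "3H" yields b"\x00\x00\x80\x00\xff\x00" on a little-endian host, i.e.
+    # the same values widened to 16 bits.)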
+    if initializer.HasField("raw_data"):
+        num_byte_vals = len(initializer.raw_data)
+
+        # Extract 8-bit values as int32s
+        int32_vals = struct.unpack(f"{byte_order}{num_byte_vals}{byte_label}", initializer.raw_data)
+
+        # Repack int32 values as 16-bit values
+        initializer.raw_data = struct.pack(f"{byte_order}{num_byte_vals}{short_label}", *int32_vals)
+
+    initializer.data_type = target_type
+
+
+def convert_qdq_op_to_16bit(
+    name_to_initializer: dict[str, onnx.TensorProto],
+    name_to_values: dict[str, onnx.ValueInfoProto],
+    name_to_inputs: dict[str, onnx.ValueInfoProto],
+    name_to_outputs: dict[str, onnx.ValueInfoProto],
+    node: onnx.NodeProto,
+):
+    zp_input = node.input[2] if len(node.input) > 2 else None
+
+    if zp_input in name_to_initializer:
+        zp_initializer = name_to_initializer[zp_input]
+
+        zp_target_type = QDQ_CONVERT_TYPES.get(zp_initializer.data_type)
+        if zp_target_type:
+            convert_initializer_to_16bits(zp_initializer, zp_target_type)
+
+        if node.op_type == "DequantizeLinear":
+            input0 = node.input[0]
+
+            if input0 in name_to_initializer:
+                input_initializer = name_to_initializer[input0]
+                input_target_type = QDQ_CONVERT_TYPES.get(input_initializer.data_type)
+                if input_target_type:
+                    convert_initializer_to_16bits(input_initializer, input_target_type)
+            elif input0 in name_to_values:
+                input_val = name_to_values[input0]
+                input_target_type = QDQ_CONVERT_TYPES.get(input_val.type.tensor_type.elem_type)
+                if input_target_type:
+                    input_val.type.tensor_type.elem_type = input_target_type
+            elif input0 in name_to_inputs:
+                input_val = name_to_inputs[input0]
+                input_target_type = QDQ_CONVERT_TYPES.get(input_val.type.tensor_type.elem_type)
+                if input_target_type:
+                    input_val.type.tensor_type.elem_type = input_target_type
+        else:
+            # QuantizeLinear
+            output0 = node.output[0]
+
+            if output0 in name_to_values:
+                output_val = name_to_values[output0]
+                output_target_type = QDQ_CONVERT_TYPES.get(output_val.type.tensor_type.elem_type)
+                if output_target_type:
+                    output_val.type.tensor_type.elem_type = output_target_type
+            elif output0 in name_to_outputs:
+                output_val = name_to_outputs[output0]
+                output_target_type = QDQ_CONVERT_TYPES.get(output_val.type.tensor_type.elem_type)
+                if output_target_type:
+                    output_val.type.tensor_type.elem_type = output_target_type
+    else:
+        raise Exception("Only support Q/DQ ops with explicit zero-point inputs")
+
+
+def update_qdq_node_domains(graph: onnx.GraphProto, use_16bit_qdq: bool):
+    name_to_initializer = {initializer.name: initializer for initializer in graph.initializer}
+    name_to_values = {value.name: value for value in graph.value_info}
+    name_to_inputs = {g_input.name: g_input for g_input in graph.input}
+    name_to_outputs = {g_output.name: g_output for g_output in graph.output}
 
     for node in graph.node:
         # Handle subgraphs:
         for attr in node.attribute:
             if attr.type == onnx.AttributeProto.GRAPH:
-                update_qdq_node_domains(attr.g)
+                update_qdq_node_domains(attr.g, use_16bit_qdq)
             elif attr.type == onnx.AttributeProto.GRAPHS:
                 for subgraph in attr.graphs:
-                    update_qdq_node_domains(subgraph)
+                    update_qdq_node_domains(subgraph, use_16bit_qdq)
 
         # Update Q/DQ domains
         if node.op_type in QDQ_OPS:
             node.domain = "com.microsoft"
 
+            if use_16bit_qdq:
+                convert_qdq_op_to_16bit(name_to_initializer, name_to_values, name_to_inputs, name_to_outputs, node)
+
 
 def main():
-    prog_name, *argv = sys.argv
+    parser = argparse.ArgumentParser(description="Convert Q/DQ ops to com.microsoft domain (or 16-bit)")
+    parser.add_argument("--input_model", type=str, required=True, help="Input onnx
model path") + parser.add_argument("--output_model", type=str, required=False, help="Output onnx model path") + parser.add_argument("--use_16bit_qdq", required=False, action="store_true", help="Convert to 16-bit QDQ") - if len(argv) != 1: - print_usage(prog_name) - sys.exit(1) + args = parser.parse_args() - model = onnx.load(argv[0]) + model = onnx.load(args.input_model) has_ms_domain = False for opset in model.opset_import: @@ -64,10 +159,18 @@ def main(): if not has_ms_domain: model.opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)]) - update_qdq_node_domains(model.graph) + update_qdq_node_domains(model.graph, args.use_16bit_qdq) + model = shape_inference.infer_shapes(model) onnx.checker.check_model(model, True) - base_model_name = os.path.splitext(argv[0])[0] - onnx.save_model(model, base_model_name + ".qdq_contrib.onnx") + + output_model_path = args.output_model + if not output_model_path: + base_model_name = os.path.splitext(args.input_model)[0] + suffix = ".qdq16_contrib" if args.use_16bit_qdq else ".qdq_contrib" + output_model_path = base_model_name + suffix + ".onnx" + + onnx.save_model(model, output_model_path) + print(f"[INFO] Saved model: {output_model_path}") if __name__ == "__main__": diff --git a/onnxruntime/test/testdata/transform/fusion/constant_folding_dequantizelinear.qdq16_contrib.onnx b/onnxruntime/test/testdata/transform/fusion/constant_folding_dequantizelinear.qdq16_contrib.onnx new file mode 100644 index 0000000000000000000000000000000000000000..8fc884024b00f22f1b7d95cc7003004177cf784e GIT binary patch literal 3938 zcmb7{ZBSJA6~@mZnC(qXHydLtIqQHUm1O^i)5+P8GlerP{*+Rn7onSSl3&h+KO_IEF8Rz=&)&fUHD z|3BwB&w0){ckeYnVO`|Yv^raD*rA0yF*P$|wJpcCdi^`gl`fO=jMJw@>=hnI(0g>!GBfX;+Bd3qF^W8z>Ih_69$d+OQb~8z zd4fw#6Q;EFi<~~M&FgUnd_iATNI7p(o;ARts>kIH%})qA_SMRgRJ)iO@Y(%7j~22n zPhEd7c;;hwKO@;gCMCn8`NN@9aaJ4TEXKM$h1TM^mEt=km{TWn-qoz9?zeiZVnZfn zUXcHX(~>z(4~g`M_0K3zv+@WW<*`LjK5Z=PLSHz9zL|4Nny*sLO5Z@YW`ADEa0Pr} zztKe5&njj|I79eSB_X(Sz#J*8#cNSC-c_REy~QGO(KAZ?MT+~7DQTjhYv za71Mnm^qxqVcY@7dNS#f6L1^>=_fKk4#(u@a#Y5_@{!#HYwgmHU9J3DZgV)vteR{a zxH=|Y`3TE{G3Za({^RkGmckw3#`2i7C_`sEF=vNkEN${M+d7F#$UBA;QIM$H*< zkWVKOPr`nX&p3YDxc-{#l{}f1F%WJ_4^}}r%i0kXdI>~b?IcmkL3#C%4dQ)bRWr(aN!FBga;_7m)2m)SPC&4(nF5a?9J;Cgfe}xl_ux z9+G#!+(i`gGgPb{jHto53H6$YvK!=1Y%AdO23hyXDi9{In}*dc^z)fNUbypI5nri-(%wDDgqvT&lv}sUp@rlR}(IOWX zV^sW?RI*Wi3->V0R>9>kXHEEM#P5gUO?Xz&FJU@$2$fxA5vL>8f!HRmz%VX*=#Cp8 z3!V)I(rx1&HweoiFdO0D2a6%H*~7gsNKPYYc>}DA_;<78f@=%-F&NbITaU_4tWTlv zZg4epchNPkVSR|)+UTwi$$JOLZCI+ruLU813d88T8^1I7&!YbZIB8V!0{6K_kjGFb zj)t{p(#?Dq<1>Q#qp<%yb<(5s!lZ_39Kv@yT@hiwoxPppGA38$W8$`wNt9_mN8Azm zYlK)IGhbD7UxEAv>jCEfO*FRSooiqnKszVB=V4mf={$jJm(GDFQoj!}_Gc<*6W zs$i67=zA8lLb6^*{Tfi;fm#D(!MFR*fgDXS%-e$VM7 zhl5N(50S%I^rA{GJYq0EPsS>Fc(CfE0(!T0a@T8Mmd9Zir54-KWFPAuBiaV^uLS7^ zD<{dh089OT-Hv80%)&7+6XZ5R^z~F|3Z|`mH_!u(MC{@Ic>!KQ`0JBZhu0Whe?XPP z==NPQTtuG%UnBAJOaBJ7;FOXAT|00kw?k2kG(d!dO-e44Ucnv9Gpw|yodDx zSPwAgJw&OcQwFGYF^b1fvlfm=s6(9370!2&`43QM0@X*zWFx(+-=7kAC0IR<$93*m zz3`84tf%(t$ZZDP&0M>oj)1LWMF^=w)x5-)jC12dUjQI(zO3%2&1~H}wh1cuvoL{?239Bc9i@OU)L`!umRee|3As(zJDi^|1t}+Ao4D2e`oEhiiqhS z$z=92v&mvg*GQ*le;|lr*!6!_4tz z{iyhoa_9N<0Sn!7g??1q=6^Of?9(4ctmds|{zvi$BHZIpz2@-`(_)OYPquX?y5haMm%cNma_Y_a{r4X4oj8AM%Bd)nL&2w7wn)^D6CC=iaIR3&*TLD*ylh literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/constant_folding_qdq_node_unit.graph_output.qdq16_contrib.onnx b/onnxruntime/test/testdata/transform/fusion/constant_folding_qdq_node_unit.graph_output.qdq16_contrib.onnx new file mode 100644 index 
0000000000000000000000000000000000000000..b9cae7f59f8e8bbcb9a1e33e1b7ba6ab1c7f0447 GIT binary patch literal 1433 zcmbVM&2G~`5cZ}ZsZdM&j* za^MxX@)SG-kHDRm-~#KE7_xDQ6j>v!XXodeZ)V4`oE^^CMK3ENCg~M%Zv?%5FF5)E z@Et*i3z@%<1c`Bg`|6vU#c*4{$|up3fUB#oh~PNolu75Qes@gf@m^6zB7si|nHKVK z7H7DSbeiM1NO9l>>l}3t&(=d9?M!&mOXx)KoS#Yfj9^nZPAlcP#0ATH()n#LqYIJ`i~$-nzMy7!=S3WoxJy}AiiqVZ8tL*>(nRqu;41_O3Vux;LK}{|8)Pl4dMNlEGI%LjHj2~@$vYz;AwOoNy^!q67d-q2@yJ_ z-w-Ti-30pnEJw-89O%I{&{;`w=DMp4;%vF`8y1s6lQK1nIDq!w>3gD{Kme`7@DTqq z*G~DrNTFM!(`nZE`-q^%Ykm6NGQB?UZmo8i9v#F9)H=etrZ|?Gr5RstQ;jzc`v=~Z zIymiW#vdg-CR6{Og=}-Udb5lghT-qqXjs2tMkb~7neV9>|JgO8#xKL~s33atM&^T` zpPu`B%3W(?nfKk@s_o7m>{`gNpUdR$zLFe-kh#u_2yiDOG2 zSG-u``tO?4mQGag#@hW~pPoAV(p-Ttjr-lq-uJ&e|7U-m|MP#(6)$|z*Y+%5KJeWuR?S(ya?!eZSAEqhzWbW*e#NWz{mPZI zzI5rV*Sxvsu0>1dEuXjYhvqJr_pZ0y^745Dy>ph&oi}&(-95`!&wk|--(UW^e0$~V zXWzAGPXFwwvF6Tt-E- zZwd)FFIqltPTz~KeEOAFf9>5p%f5TrqPzQg`g`WDob@M{&ia9_GxhfMES`7w%4flL zcStYT{|R)z>A68M`#gMKcQHd<^+Ola>mo2d1NOf5xv}?!yU&oly9{z3y1TLWbR%_T z@98>S1hQwq-q*hTg6uh~8+&)@KSTEJHp+SU?#ABJ4%L;tr|oqS7@q-qU-#VD`@-F4 z$lhHBIS<|4*n7H>y0Z6loh|~|GhpxQo+EpA=|4mE?l#JK`0mEu(+<^@y{GMU5g7mE zC9}RcE=d2%+4Fmr&Yj)UyK?sI>t6Ay^PG=wn(lVF6yGx4t@D=t(&_F$-GW~=qgLH5 z|FkQ)=6f!`E@oUDrN4cjeZFS282+$|z4???tZURo7jRCb|7= zG|8UkLR#b5XL=XiQl_R-5`JrA?~V|Pyz4K! z{xVI!Vbn$h)hnO0{%HP0UiC%Yv#C=F zlqa2pXZQ?PeI*8YPOSI@*G{owbA9_i+&OCoZqZYSdG;lCwXQQ6WQK0hjD|W-YPvZS zGaBz{;`Tg@*-cJo;1*rPO8zkuuA7~i(Q==z2d;WvM(wV|c6GO#7qy;Q{~&(Bi(cB* zY|fcG_o^ASdhuniojZ5d&P!*_sP(L?XEO5>s^jG|YI~k;=!;a zk!Kn&YnpszR{~A9we_Tb{;C=H=T8p6-+WF^(Tx1_fb^sfoX=15&s{lBn@su9E&uF>N&_&P7sNI#=A79cnuXR!Dne`R1?62tRfM5Ia_!(Ut@V3>A9q_hwSNExH zHA4ry+MXx;-5k2=@HFYZm`>fqcLomlbg(?X*y-k~6-3Vvp^NL*T}Ec$fd9kesGEFE zkE8$34)~0QdFB7iIJ~B79mo`*5J$>|J&nu}P;KW4&p3qNMG+-K;b=VjFHN^Ax}%|)%}D`S4I=^E5r zH>04YZPf-fFM8q1`LA@g=IX1iyyV4~|IFw# zQvY7f;MFfN#ufj3;`?uSW?5@m)?n|{=)_gvdF-|?rN&cFG)PG`@5>vTTy+nvrI z{a&ZD?hiVhfiHAAXMd~HnVWuZcuCWIF@0Y0;ZEnnuW6dk{!yp%-`?Cb@4T{UzW;YS zoqIC=pMPJ|Z2jh@dF+2}n)m;Fr*riTSJH(&vXTJQ`RJuh^UHbvS9u=_wBHp*`Otsw zbWY~|%lW<}t6Y%p|2?1I7&f{+uNPm^H0yHz_j2tg@_K((_}BUV%URJE{%fc6OKHFQ zvZk4ndt37Us;_ODEopOE`hCw|cRDZ3eEZXWh78$-(yb8Qc==~;dHGBQ=c%pRuPpC8 zr8A}-HyFPu*tkAedBrz2&0GGa(|J)a_3ioo>VMKS-y950eqGc2Wib9Xp~mkAOYh3- z`h1@ASDnto!S0{@ZKw17d9BH{eYyAh>HF2`_p-cRpVypV_tCFznmt!E&0GIVr}Khf zd}XfvO2&WV3ucV*C%0{;UB{FP_|ePX`0A@?ru@3_;?#&Q`JKOc_0{jY{)%tCV)G?e zXkdvmfGY<7h)$Mbo2Gu)hLCYyogP_wkz zl>5hZv$wgu?X|JlpF4Zf??CfN`W?xqA8GDM|FLE$*AL{=_GWXlBR!7f z|5(0FWQ6|aR6dP0XVPL@bFdk0K9un{HEWtxX}LeM|3b4qeecQDzFgUrJ{tqsU|LQD zvYzH>dhO4(1I_%*vpW!u<$YgXn*!s8jI+79HT{leT|+G#hnr=&zBw)SXO!{G`=P+F zEq#xs-&o#v=Y2ANhXe0;Ub`|E%Y}aYb zJ&u4R^K-Y(v!^xvZ@oo%aP`j zX#-;)ZI*<_Kaw^#hDxW>{=I?h4Z+dTaLRb|Q1H1TQ1pj?&ZNhJ^u9AKMl;^F=4H8a zJkV|mz19S_GlAg^;fdbRY;oG|%-;j~HrhNIhzGNdBk8{^kUXCDTY|?$;m?uqPH*P^ z@ys|o*In%mg+JKb*W8gg$C`hg84l+Cr^9Kxnwy(n&8WxH>P$FnRk#IJCWGlM&G$45 z(`#FJ>u7pD7VNBT-j;i(TJAZS)g5Wx9n6nqwZqL@GS1%ozqOf@Rcz1b`0!Nv;n49w zdLUQG+xF1oM00&sw5OSyb~x=w)5*;5$XrXavSq>U)UwZX-*K=F7pFZ$-ujB$)@R{GqJ(Et!2=#y!}4B)vWmZrK^U?atVrNS}dr zoV8i)j`SP}4JVtAg*v;Np9_W8XY8ZF%KH327arS`@lOPfErEJGu#AS^MnZx8`F6N@ zAdsI9HfW=f;Ho!RAt^)Yb8qN48R@$>ux!fz$+S3>(U%1S`!d4v+#3odKA15M<{qgW z%=^L2zAX@M4d(6(tZOpo+Vt+v=i?b+e|oLV`_VwWDpwBX&cRk%`m@p_fod?XN7L)E z^!l;%N7u2)2Q9WEpC;3PQ6LzPBz_<>q2sdj*_n|KWehrDSK8bY3J&N0ucgONhC0Jp z>E>|K_Q10zJ?ZIn>ANL&_GPx^;nm~e`GMTooKJ_d!oiIBVCb_XvcqP)E7X0krQJ|k z91T#aYe%@C@kZW)-{AXJ77Z&%6V{ z{=PteAoGyf`*L?A>sc6{+>`HXL-mpTKNOA_%-_kh+LO5+iEN!pyRnSPPHfNpiF_Vw z-k82SGvBVjKbD?b({6k2Y)h|wfni(jEY4r_(R)Y6I~bhN00%>ny@B?0*0w*rwzb$L 
zPpsrb@W^%?2{uR5>Tu|~E!Sv#d`q@=wYWZzzsDje$Fn|Mz}^*qo(pBh(~8~d&)<>U ze<+f)IZ%&>;-mRQe{Rg*QyFnjJMR8mJ)Is8wfwcM<%;E@{hm;AJW_g~xh?Rl4|cyZ zaI%oY=`oQ}ZV$h03J!KeZD_P!{yN9FG$D3aaKOPA`-=F#S zXY~8izO3=KU}keTd%VT%@!T5bN_H4+mtokm$i>%lzqVs zUBf@XQ6OHM*GTX{Q{$LrX?r}gvP`$OGo4Ah^})`*%!D6s$cDUI?ZNaJ%`AIze_cLr zPH!6MTzhYSaC<6a?+p}ut-XPXHQJg!JJWVB)Lxy@2Xk*l;CU$j+2hHKvOVK3$owmF zZ!}okmHF4@&QNBVZ1rbfdOX%*t#sW)=zJoxvv$L|K9V`;0^Z=3^dHK|+cMwE+?mLW z^`6KqTXS`!g>6f|t;p)=-=(>SqsaQcKzKNGa`zTa`Y?gT0A- z+L&vkWLy4kPyY{RW|;cESjN4XZ*%7P`Cw#K*0nY>@d^i&zM z(aFQ71NFSp-qNka21GGqHXT1eOV~}Y0Pl^J%2yIJGD@W`61hi56Kcrf4A2kKM#%O0=D)jjz}Le2yi>#~-I0`X|pD`q?t z&V5Jbn#|aTLWkaP)$z!e$ga1A`MW~*-J!;@=KBNp+cMVvz$3;Q3)hYX;v-q}nS5Ro z+-=UR_vh37(V%NH((-(MF!VYb$VWr3eJ!uxzdPEsJQ7Yjmp(t2wl`kgRL{^ zc`ST>HrSb)84d*YM*`<)^S_7gWOs2^u{>ie%E~u|s$=;+89(KN`L-jpqj3i^+s5FL z&$uRQ=*_)jv1A(p>0sbH*h2VV+MJ4PZA!}rLoE``54Fa7^ZlXpq=C=n+PZuer}A;Y zem38KBiu3)OizYV!&%FM;BYjOzBe$_Jf{Q82ZLRH|DND|B=8Id;>q;e5ZoTi`;I^) zc3K~L9tmW7Tdv`M-0dZLnQ?fO2iiL$`8<(cw6b`fUSJgtrtj9^ z^hD-7lFtXz?s!^0kbBm4Apb>PhjWicJD!$&Co%1sz`HGYKa+PH$~wtC@asgQG7!fz z)4p8W6)f|8jUcuzJdI|o{@mG{xpwC7K)$VNX~X|`Z^qe{*TKxbAtQ?9hSPd~up^rO z@m!gmE9IAwWj^7_tmiDIy=t=sVJOYxLB?d>3CW z3|G5G5>G~ww4{kiveB<^U|iLz_*=~ydsBjJ}5;UDnIZQPou^F*|5e_-2?nOUDZ z1J6hsjmjLY3m#5qyjNx=JJa(-sIoPq4MfAAY-Lyc!iEgx|2q$iT6OqTq^7nO#*fs@+OClNl!PmoyD|xVMVgv8Z^}V?&6Lfblb~f{@jOAIFRlYA8 z^H}bl3N(8H!|BM-8zSFB;p3yt7h9Zc$>+nlySl|D&umMys0@Yt4E)i#pk?3IW_H;h zF_}1OZK%8|ck$pz-uVeUdhxrwP35HA5xMbm#@vr zDkG4&Bl$-A3}mkNg)(c>dwXUk^&^>+kIX;Nf?ht5Z>*`6_l5h;WGs0vSwC4Xd9*FL zwl*+sYS%&si=WP>?TOqw7rO5WM;s08EQzR@S1Us!PdS;-<7qz}$k(Uu!B`Rzp;2)i zk9}|E`dF?k&P(oRZARipPo($O+})7tMey^-I^{Qs%+oXkA( z66-Vjy0n_i`@zgPlJ^IL_pMopv0&pwS`>uhPw@cxT^%}aZ`aAjZizIZO>w?#50}WA_$(ibBbEkU5e!+81-&pSy*B6!*#`@A77 z@62_X8oVU4+LzWM6TJB=`ONPh3v43Z;mo}|axL#Ci?J@FFGw5MIgn9BD z;nnHCHL%HLA5Kda0;e6%S`H>x@j!ILeSu_MMi((33@mGdyH)xAfz0w)dJngdkLUW` z(Sqc6V~h8rX-fhJGma?vBhf!+a_7Fl!P}!j-``x*d~fD|W3J+#eWBJw?84^U?~m3w z9nWA(Am1K8si*m^eBT`kug-V=@ZnavRs>hFNhI-}@Z8R5WYHKq+84Q8l0IvrQ{_Tg z3|Q#P40i>Z@r?V^8Bx@FD(mDy$lCmLKG9lhf*&>*zwXcLOs-63w28pBB3?cZZA-8P z0@;m)X~S;uTg852ZygQR>j{#n7HZAF$eFN>p1`C{M% zbHU29km0wN@xJ^uyNK5~MzaRd!*FQ%!Nk7XqbFB~cD#x=g%aC>=iTW~lZYihn01^D zWp)LJq=LWn&b;pp7W+a!S!Uh@eKOg6uGMJz%NidHZ?KtL!wW0I8TV#3oU$RF@~Mn} zHhtmkSX$G^_heK#{f&X+NH}?M#(W?=bVuNLB$V#U^)10%Us~}0@Pr+hrCGz__#m5d zb$_^OM|gjGIN)He9L*J0Ttsj-qmSp>f$%>|Y}ZAW;K8haG%(?>!L&dhdm(3n4PMk_ zypFfExa-LbAmiKc5%)zdSEmP0mB+dx-}(Ne^YMJ!5*ol9-E?oX)rvr}IlX!^-&Rk z!T4vq$d6_2eeoeHM(`<>+ST(`N(qhS2p$R!zUJPHX-t3pN-mjt17#!N#$y zU?_h6(tLVr+R=IUr1k#5EA#(wpt(JF9}cbGm(jLlh32gt#S?*BtZ{#4rBB4vyFz)n z<=(X07aGgz9tbQWu{6iC;*Vv-_hhxJL#=xPyPZ*cM<;`c^;zfrx%WWedw*u#7kl(r zM(Yne^U{MibXRC`Abn0Xzn_?PbD$C-tq)v3nmOb&>;@eUL<=LEYoc#opVkkB59~-z zWRyAKrQ5RFiO_#CGY1zhmWM);pTVK_TJ2aA8!jpM&Fjv-W%MC0$v+_+8@50n>8E= ztrvv1w?ry8=KG?oo>w*&{t)L@PM4KFm054eHM^H~Zhs(nuoFCxu~1fSU#4L&u*+ZA zv-?11*_K&(2qgS$`q?cWZ9j|n?JdC9%3PD1vlqHI-)OI)tnA+O-4_~@S^ZdjzL#w2 zv3%MbYRa$6kH9)FmwvuKW8D*IXbK!6g58<%WFUtEp?FwEi4EBpjF-1`EPb{FCiZfD zM&N@ErQMIDuXSt=MdWW~tBq(^W30LQlb5GjAeBAO?McixZod5Cl&-wy|LJtT{2w}< z?|ebi{4duu%^#(};$PT4f*~Jf7a>zUHbmr=Q^DmFKe1dQ^4`l`F>Nzx;g!>%Usu|_~UaK>re_gK9z!> zKS<%no4=IOo80?@+&h`#jmOe=a|%ALOrKj{n*yL*TabC)o_YRiR@2;+em|A+nz#H# zr}HO)AWg!?@zyfk-0vec5l!1&jj|jU92d=G~Tcd?|%2X@3BrZ8nV;|~VDH)pJYf7|K2 zGX1`s@lO6q*@Ic}!eHpvQ?B#xLyhkY^xu|i?+%s@ z1dE>zhSp_ufBc7?&KG{I(|Jewen;B=a8~;Ld4E|T{+@8@AE(WG!&UFfSf9@7e=l?W zSuk`}u&^cbq-p0*GVT@mygOt3tF*l{6u&9$KbMvF=Kfy=#>X=M7ed8}jJqc7Cvtyt 
z?tMOR{B7oZP2d_$Y1_Bw`lA^K)<2Qgn#?&A_+J`Ke=!(cnF(JInEpEJ4ADE6GO-rX7dH*#%X+FTYm|2*^E9Xu%(`?-AHllPy^9KV}6ZqMA`k@Y;3c6VfqRk?pI zFdq&)M?=j88S`_2;kLm3@6zT=f#F+lwtJcD(tKZ^Hop}Zzbe-z)9(jD{dZ>k-wO;s zo_ohK{;B8jJz-(op-F`iNeU4)mn9_LYu0!+Hu3@Q!kg|+Dis* z_zFMU@aneA_Ofd}-o4CrMo&$wGTWII#dcf$Gm2toYCbMpg&3i~CJcOEnBg1p{+D5l zt@*4j?{~tOuMM+KhOv9|`qePi4~O~Q9o2u+pQo}iO!+Noe>(3s<^76(-|2k)FGe`S zyeHH4V`13U8ROfd5WW;fd>~`o7NPiJ81d`F$bT5-eR-~}2;;ses`4F~_tzsZJsInf z2-re=@K#ZRhg-N15yY4vxMx^ZeI*{=2kMFL^`8x+CAu z^`4b+fQYjD?)>K1D!ef{1<8em+7}JYkN5J#RF(M8S@3XM>${p zZk6g2mvoh;r&#pa4+Nbj{8ich=R1Dl)zehA=c#)o9c`KX&~#<;Y4Q})U;XIExJ#+9 zyzGk0rY)~b^Yh9pKX$ojymr*p7mUWk?2E^(p5V?zwcDc?MFaMyH^$?W8<0<3pKoLF zIv2*Hv#0K^{go~0S-Tx-(kkaP5+8C$>w#{K_cxdp18Fbv98KSBiRU9d2Rb|H>t-%5}Lw(fOuaIUC4i=Qrltct({4ku4p{Y;ey% zR*A*a-=3Aq8o0Zr{Vp#c4=gWj1{vgy`L9yK-klu}ch&klkZLyBS@l}t5%=LpegSPn z1?v(;%LU6!tGD96t7B0`V8r!lyCK)@Te~(dan94!%EE1_C1*;f~?}^-T6p26Gn$MNjfZ z>k|{L3S25$%#TO#njFd$5BFs(++|;-FMI{tKI3H>7dv#QfmXi5FdT~k$qCxv22{i& z_*funaP)>COZmY$0fS>kN8Qx#-4+h=(@ zl~O9{VNpn(LjT@-&Y4BdfK(s!H7CYI7ca&TX_JLX0QyZb(slQXd!dFGrjC8kr#L~ z*SfSvf9qCFgHlFN(@#sXFLF|2nH#6qo}L}V?a@ziv-o#Y>YG#;>`G5Nd!YaJgj~gFz;3b=sh)Iy)IOHim@LAoyNAoouXH z2o~af$tmG>Rdq(+k_;8g)|;{9w(QEGyEaC!%RU?nVpnfvTFTw{KO9*m`7ipYls zkRIO5()6>BiBf1k)SS*Nn=`_Je5RRdKgC{&dJR&tKlxSK&CbSqlk2J-F>=AGFAg+4 zS?$lH?O^5`NCZ1KBfTX(HU}1ZO0E06RJol@Myo%Zr_y5}y=q^%w9<)IGS&rF8WAjZ zCG4-S$Q5`YA$MfVJ<-540Kettw7VzS&)4OuYWKZuzI`B=+sBp%WmQzWkT8DTyIU&P9yZKPq}QCx zxI8b?$WpD(HJRT>!UJTMKA4}GU{lSGs)vU%i;7j6<#5{6KK*d0_1W}X(!#Bda6#bx zOvc8SRPEFwL z?Jmulv{I+HB3iB||8Gi*HJQUclC0?bz(_{yJ{^rdr=Mr%Guh^Os9vr5y#0X!U$co` zRW&-49`-WkBD`vm9OrWcSo;-mxY$Gr~OxFJlY9XX+D*=R!zS6!L%=p08?~eX#l5Z zXcS%(OKDGmFTu9+C{cyS$D_d;a?WK z>;ni?$DuKg6$SCf_RL-$qw@t;M>~SJtOTugS6Z{V_>H$Pn0OG}_+l~<3O#hV-6g0< zg0|)@js)I;Kw=jTr--w_4F=I4%2>l&-|!xA5q#CR^1)O| z6vaV>eyR?5qvab046wk0X;c0d&QX6& zhIrVumtK~O6&(tM?~0T!jL)ySjqZX&F%9m=GrRyY&le!iI|B8}+=l^FH(yy@Pq1CG^L`_(q z@}hXNdRhUgB`NN!g(h*Tu=}E;=m=M6yRv}rs`jJm=G6GB>jynAodtwtenOp>aGpTD zC5>dAv>OcqV?1CKLOJ?;ZLAjDEzWAitN2f3jnnYIS>!iZTULNvikDGJWgX9l9J16b z7_C0h{^yayH$A;NSmM)@-{H);Cf|p$;)=g`YbdoM*YKlWCtA5uv4#`SlWyCP`}mE1 zVtYwiU!X=Ya)cs$OgsDhE|!CBQ}JgUGR10ch-DTN^MdR!(`_OZTFYIu_LrnO?F#0K zJ*}6Q&FA6yiDlRm+M}{4;tkpxr-8~zin4>SgpT~M^6|Bz{nUrE+p4T-V7zx%T6qTv zK5!D!8qw2PCwx0U*I`5CCvwEQ_zZ4TRiog-e5d)XeSP{ZjRz_Qs;i*Ek*q#Slc|D= zM4$|wRY`XsuoVVr5gcdTq?FYw9-tY@GODH<4J|&Hma0*zS_$9rL@W9zd$1$Ah1W>C zXk~t$JWkmX-)(I)XkTC^!JEP*qI$U#Q5x=a4wp{ArKEvIz?V+P;STx5C2ebb&jyGn zNzRJgTOA%)o*rU*X9ZVfWFCHh`lx@~7C+ee*psmbB7Z*OrdUvZBdzN!)TxY4gQ?H0 z$n-?ZIcM@+)v9Vcbt6sC3$mnq1E7c8bf^LtR))bd)1iga>Pq&_FSQntD2Rh z$gSZzT&2HW{A`>i3V_SXc(7_GG8djCdvJ;G{2$1eDoYn;<*x~amgaR&OWC?x+5*1R zX048|CE~LS-yhm63#Yy**ZI9?vZmQ-wKsoh-lc(n%^1r{s=kY~^v8?t3B;;URRXC{ z8qJjrsZm%LsAQekL;4YK(bBRB3v-u_-kP_uD>8{|Grrci zqO=)~r7L9qRPFH_O4`gK3xK2OIsVwj@Pph6X=k13W03<+rK4C>HSs(?x)c>jL)ot3 zw2<8_$*>wLqxoIqFPW2U$kMC{CdAMDMN){0vc1}?xUL!ROG^;)a>#FQJf-3YHkkZZ z49`}8n4MQ!L9W!2J(WR_@d6EOfrDpT_O(uCIvECUD&^E&xF`2RzO1!8IbIDf${Nci z&?Iy%uTovd`rKziiqLORvEI|3r-`Z-T6{msSS1s= z4O!&sPs5ZH6)f~1n^$A90PtRGE}uoasnNh?ZJ@%9@KS5Fj^cf=fQuBC1i5FPdN&gq z6rS|fl1(zp)T+^h^cDWVSJ(hQOygGbg9oLx^WU5-hBXw$iiS9Q3U*M{8VdTU6-+@m zMXBM8s;!mQ$nL;2%#-1gD_32q89-rvd`@O%kj+5D(DJJzeI(FnA-;qxxHy$x%aZ~b zi(T7PGV#F;!4J85i}DA>7B8k;}KHxomwZ0)D3E7~#CdD4+~A~%IDr0%wi zAg)}K5#F2?tVnhaeDo-;=cz9VX6;GAi2XHZ1}oc&UUcfZZ40|p6Is1Tgij;0#E$LGEF0oI8AXm@ zG%>>WSsJR`K%%@+R1ZR zb_Tz!+{*Ch*34>eifqYe;$=FE#h+-`r4E{YQNesmuyrn@EzCW$lS6qCT#e(!|EKfG zKFN`GHqoFg!_H)n@c^Ao^6bxn2`^M^LnDaZODbWutQtL5)_ZOsJ{A1(PRSDqA-5}A 
zdY8>{F4lTZrk6OX@{^)QE33#wJSf&+!%^NCqA)T-2Jph1wg*lwo$qAWL<!@nCU4pAPU2`FoW|&F?IoY5i~rR+!RTnUtxz22Cl_v zi>g~glggvW&+zBV!jbuvX;GC3AZ2@KSbHCA5S=&{9P&hIOj%=lI#y3-!^E+?Zw$Yw z7GPKCUz*8zYF>je&mGs@j&7;vd^7=AjsipCVsNg7szOn9odp8Uy5EJQC1`8V&|40!Y(^ySlAG; zZ;qa@qsm9cLo_%okKSc>;g+o`f0@05GgS=u%PQnjRJ@d?6|Ef3j9cOXJe2uxkaGs& z{B4m6`voi@*zvZLC6)VRg{lHV54l(6IRzL1l&}1 zQ;bwn#sCD0vWI{pjkhE?wFH0Hvz!(H?VIVu?i)s#(V zL5vPVcu)oaf8Y(44@H~~mj~i8XibN(6SOT`QBi>%VgF@B@S^-YI7OD_C(tu23N2qg ztqKTy&o-7Gr}s!V`N3iIo?INAOJl*G|MVL9u9&eR0JNO)xd(z_ye6i_E5DKF$T^kXi z)3L0mc05e`59p5CcH-qHV2am3XH>44EY+1o8EY_gO`g7H=ZmECZaljUE}X_9jn=6` z1ijJHzQ()KiX4kmRMt#}f9}d>xsReTYiAdjUV>%XNTz^~=-edvQ`sbThHri8(jUxq zPs-tMxq~_=Cxb;vtl$LBQ3)W;2;l88lo@$$!tqezJipqfGz2<4mFLFHN9l^%1m(iORmF0yR681;OoTh8SZ z-NfF}4};-n7?4r6N_;E|EvXvL9hr5KD)QxZYRBq(RbHSxKVE7@@n#o8@h^;N<;D>y zla(oqqcR&}FO3&PQ@O(KitgP95uAt`(1teM7>Uw)OIR!Xfrg(#Z=6am z&&phqD>S3bwH>iXT3A#k@fFl8tMe|OLHuG2`EIzAkx*x`xIM9Ge0(G8Ko5yn=^GfZ zGJ3E4m>^>@`otU($Ihy>U&|0JEQLv6ahjh>4fC96__f37f)3uqJRv z#}x0oi~C6<9LU?lFmF)}sJMU*E05X}>%{*&lo|y;pc2o_h_x57FH!!UT*rA<@2LU! zvFt-t{mCofC@=XZys|$!r9T*E!ATY=7qi&O8P0dre{=H*4M`WtqLcBey(0D<*QxLp zyRc^^rM(%&GY&z3f9O+IiEoa7owQZ^fQF(7*2q&d3qnr|;aPb~L0|6D^V+F_IF^~o zoHee<6>&E!OQTc-f!-z6=FxLq_7M3|r9;pHFUUpklh6;2?HI8{D9irSBv!6oQ-*VE z_*4a9S#R*rP3ExX>L(^w$6odCI|424t=s|*G~!saqEo`O5D&Tbx6D9#SPPs6TDrjb zRL|gDo&UE74m)h}30BUt@pMHa(66)sX!t5*1Ba{+jXbAQ-cw=;V;`tWPNAi8?9)=mp*xi_3$6ds+FCeE2O*+#Bx^ZbatdH+0}j zcxHzzGV4+iI{m=wASrBGSp^b9R?0_r0-8?XGl*wZ&U&&HZa23T$lFcz_MBzDAo(f( zlzpY&$TE2)wP2D_109* z$o^Pfo^pkvs7=QCY&E}UvpQ?;%Hgz=2U3-}CVk-uT)3<_TOasy9k<{VyuyFx*RqZ* zg_r>cl;4hvaFv`T38W!p-FBy?Crs(<#0T#4x$?(66uJKz9kub5vE+y3GEs$>S2Yp# zs}E!ao)AN`^N$zj{-LY|R`_aGN`^h#*UmLAA^Y@aW&cD42Q#m{%-YBd?XAkBqIsNN z85LO*x`_7!(v8uNxJv9sl6gHix8fL$ zC<#Gfw&}5UCbG!pl8~W{DeClJyui=CFZV&g+aSB1jK=3X7D#9vxS^LyH{%&&k)+}! z62+UcbKtY=mH(c91eZJ~eaNP~lX!yE*eg?2jd~({T!cqGcb6SO;r{gR4>W8yJy1C& zPb^iPT6waiX|X!b=)mVH6=@52z(IU7tLIzF45I*DUl5_ScnfspL9UKYR{tq4;R-n} z4Nprt`=+M*u3WE7P5ERjrV%Pm0=Kxf^e8`<)OgCHI(c3i9)J&;nRK!0_C@)To@r&@ z4fhov;tLvt%z$1*02)>d6zae$o`4T~X5<`A*m_YHU%hrDjl^$|vlbiENY<}~>+%ig zD#E4t(XRe_q7RIybLJbdo1&AkV8n{q2;8C;kzbCVa5~;sX-x~-AtQro)Q2-J4P7=% zU)*NJB?s>G<#SorvessXKlAX|X9lO!kdufL`>Jp8nz;nR|Xl@geA}lw_b|`D#`S z^YBgv%Clluc@6l0l=Ek-tuz0{^FXzCqBrDFHxt+b$!^8$Wp0IwPi zge~C7R=hLE~onHHFem6s$ zUMR6Ec1(x&7Q_bf}lt5$Nz>2J#U8o3sZ+7BPpM6-JYy_%S6-Vu>lTSQ>f9V(% z;q;<4h(c&***+B=wP(jyEGeMla4$)ORUAt$%vIW&KWu%lst@>i#lIC43HZsL%e0D?smju=v)_wBPm_%F1QkNIzZzAsMSU zWiWH!8?Kb~pc&BA^S1UxC%!Aa>=1b>+ap;y8D~rA4*T0Mm9}Ep54F~P zS>`XBxHz!P3BB%0j0$2k+IQyfnZTpgf<94&;ukZjRvAWcz3gZ0Z(CXYh6i53ad6<7 zh6AZ?v`d3x{G&O!TTd#kOsl6-P35ZLna_(FY?wV38Q*R3ZP0&VKG``}C%v$RizY!? 
zltUYymC;B*Puif$9g#>i#(W3#I+&|6GcswV!_2?~5e1_&j$W3(xMU!>xHA~q)1EWp zZL(;zD!CCg+6$30*%3Oba={zJ`3xf);?q>jOqZ)S!U0uJ&(~g%S@^g(o&ILWMJ$z< ztjbKe7}(?uERB>o6E{EC)Uj*7E^x@;h)C#9+{<6Y2P6`vXlPj*+-)9l3r!{8Tsxb* zfBAJMCiqqS16c_+pQaTXEYAEeKRctzY=}{o=2Pu_;>5CUXuxN$eH6BZT-0xZ(A@l> zC4obhPAi!!=MjA7Emf``pZG-;Q6u|OkpNlsY+8Ehy}^icb<6Wjd@Z6DA=zQ$$<-4K z*5}looVTnDn#q2z&d4l;`g|4cc3|wVlhm=u6)XBcB)}7>ooOKPBgxi*l~$gzqK7g{ zNrI@6FX?0_K9qB_Upku8dLkod%SATiT?WrDb@4=0l2qlCtP;-Uv8WN6pZmNICsD=4 zmB%Bq>=C-LEh0C(K|5=W0<4g+S#i82R+M4oo51)x(yvbQs7k=C$kaZt~TgQlB~J1*Jw-wI)Cpk|B(e3OW1jQN=H`qoOLK4Mb^NM5lha; zQVWFso}gPk9(#>9m*&O)i|2_B&>`?Cf%0%w9}ZS*`lMY4Hka#VIwE20sKuo+P4f&^Q(8+vR7^MAoqGUdbQsffJ_y?5OZ1aJq_Sdkr9` z6Y&s9VO=X;=EK1GiQIKkMr8(kgN-hwcb|%6a2OtQwWv^f6)d~~BUMI$zH!dWuOn{B zoYpNe2M2vwlEBxJkE7MqOxQu->#>MP&D{j6=7{%-~7698_@9RE-}xdv>qqO>a&91$L^Un@$&3sx74z>|pWzWaK@g zm2Q#u-4^f1h&}mhjG}%;3H$<8y7Ed^sYZcsh|csG8u~8Fq*BJ-pNvWAE_MkNbedcp z|I$9FCl=!)zKT)Qvx|JmA)jnxo{x)*?7)!*Sl^YNWMCjYWgU0K>yurOfj}j2dGdhu zpex-;7Z)Fk9(nVdWAkfg159ic8N-n{P+bqI{9^8u9YkYr@!(N`MlFAgT^`AL#J9ME z?V+3T6Pp3*cc%wG32*bAJQH8MAZEgY{GEr=o;1P&9_M??rm;^nE&cLn+Bjh^IzpE{ zi5=+Qs{O9mRi&%!nV#y+{cfI`7f(Vidq_5Tr}E8aGpBtz6f{%4^U{pS7UOkKbiv(h zJ>1e474wrecxTW2#sNPW?%_g2>k18tUp)tZZ?qkJpeZanEdes}S2-4SH}2C!Xj}S% z*WlMq;G#4DZjzbvTpn<(%5_|5PjX%;)faDmNv_dmxV+-qMfuDcl=k+k88n7^FuDz% z-xF_FG{zH_Em#o>7HzE^y*$Ot+NNxE<#Ng&1D|!+v4l-oo#DU*T2G?KQEElhKYRYP zYzmlF)8LsUS-thKCj6TE?XBUK?^##U;S>|uvtD^Yzd;3ud>T18wE`$^-DoJo<3wTY zL(%1a84oYB*|d}$1^EVXDLFzRe$B3MCJiHM7|y%Dl?N>CXtmBx(kQCs$T_E18YG+q`@g2YqnDS z#Gk|A##xi=d~dvEG&QKKB)?as__k!R=tI=Ou|@`aQF2b&^DAGQzBoxtfKI$6Yp_Op zbmC=CKeLY_V&V-|MyBGZin;2ya#&ov?Y%ygEt6d%zjA0ei-v)zs$aFcSy;AeI4qf| ztj2Jz;9k6>8ij>5kI`HugS@h`jWniN{Q4`e3FhQ^EAIq@A`$n=b6HT3Ssnhs z?Yv+7;pugBJo$zhTI~J&!khd#N}^k5}BUi-_h)6vD3`Z25 ztpfb!7i$$qky~B>PlaWw^C{@bSEAQ>;35o=ZHiyX)3Be#pOd|}iqSv{6XwDZ;!@mU zr8pKgXj8IA;?aUcu%X&kore+RH>?t*@QRA=cm}(4M$eR3W1vRs=CyNK=C z5xvCRvJrB0aI0dG-V%@76{VHfH@rs|vFrS*^4V|@YWe2;+TyHnQ)u>RuE}bY7V!MK z#Tf;?K|zzy2YeWwDmpFBxcqQy!F8f%`Cia2%)AGq%~&2;7#=w647U~SIDdOAD?PP4 zYhTjJ*uvb04^Lmf;doCZyf|0(G@3~d>{`5`<@bE?^zG^7rQe%KeI3G(Vjl?dtdx! zULUwUkrn?==EalqtIUTp!h*E)426Y(#P72CRh9W^aV+0(nBT~HPkQ;qaZf4PpBX(J z(_DT*Z&jq>WUejDI5PHVQ4x%+v)mr;$6MQ?p;hy!pz&-i*_Ye$?QlNhdRiAJ@D=55 zop(ImBl9-}P`JH-@ zYqIEiNh^xOwr~E@>Qx6)Tx@)@P*DVJE#jnc$QTVp$7;*gSsOV*F(XZ#!J}1-j`z(- z3f!$x`I}(F5oR(&;fO|}d;E6={2&6e&t*f+UzWhTnXQoz!8W;D_-2{$Rb^00E?Gx9 z8YE?>%jQ}$-AG4`hd;{3fSRn7C9Rl>e?ueCNvxi^3(oQ_3*V&Q%!NVYls=^I$*9?Q ziKV&elA@&crQ_jDTl57Ld>hj~3lH2FS`;PGpnL#hRV4x~qW9)_?R4Mq#9W?r$b-Bw zt5}+LpU!u`p)DG<$3XYUchMwvRQdL|L>Az3d$a=2|EI%~%OV?27A($Hn$a)l_zj|+ z8Bf%*uH6yvRHX90>2gv-kD$cy{5=uxn@ni6BAR1yX8cs{tquQgYJMT(spy@Xdn)do zEA&K{!D!IM>EqOjeO145Lne7ae%((!1I&v<@#H%rV=y)+Yg2va*P+A;huiyTxGnl& zURuz|hjO(vhDsjKc$Dz~resgZ#bdwFK;3v8l98d7!p!)bjecv%=|4`n{B7DwLq zMAt0LJ=W9j1MW{w`7Mtym--h$@(z_&Gflg-Dt=s)JE~9EK|3h*RK)VY_?7&LO~F9%|A)13T@FYpSNp7_2=nSiZG% zu(PDNh6gMkCwp}?SfH87S@~3=Fp>{StC2&lN@B6*RPv-?64Fd!13@^dV@?l0Y@*%Ez$+NIDWMDc0y?e`8qT! zAKnZy$|~K=i;s98)(Eq*jJ&8txyS!NPu5_4GOXx<+Yd$3coKdskAK4pu-~Z`(Kv8s z_lF%*we6fCob%j_%D)AnClGGT93RcvYoAICTlocEJ&IK%P?434AYahE4bd(! z85OqGb>eEVp(=lz2@hym^p>0C8Nm=8M4sm6osVBUEsw^7g^P;4#i%GuAF?z!9}QsJ zs4_LQsjF%caGsb6X3LYt1^i#t6xR*(z&VtcSE z@)qTaUU9Z(M~IJ$)-(*i7gv)*xSWU<<=u$1K}o_vhTBD+*0n3vQl3zR3)}XLkL0^M zI8t;&55f|7#SwDDkK`MR%rfAw(d>eW<#-}h=U~?mZS488=+^4l!SL+lkSAtrjdidh z(V3hmd9(J?YsS+0ed$^8Z0+#ykHI8{n$o*GPZULexZqomS(JuJoN_#4(Wo+tYK=Ur ztfZKZP+LOd!ah!(@-yL%A6=vI>&xRI;jlt4kdM;MtPzjT?f^SYj`5p}I#04<)~b#? 
z)b^^HY%rp%7*9L?hLgkckQII588DH9nh*VXgCggOM%jVVuH`xNHQ`}HxJCX4FR2(*pfSB4ZHv-aW~Ssa+;Vcn3sEQg$zPvTTsNFE4Q;J>s8 zDHEBy%Zo9B79|^M35tfG(!!dou%HCn>9naQao{o6`Bdf`4&0@uH#v+0Tp3Az)`N20>y%#;`AMpZ7C5J4es2Bgqm9tLH!>Q}W6}%cdAm+E9 zkN02?1;oI1T0KDmL{+`zS)1h6xhVhZe47&to=|SidJc5uk?kTGdr!RWEwM&8Xd)gS z3RIjRZ;e-ALk-!^$i?Bzjx*}{x|7k5B8BqAM3VOPJ&|7Bqj8K$llr8moQ%B;(9pt# zH6zkvFlYq0v7v=hK%RW*gwyD(Fs9 zT7Duu0Sn}AY5EVRc5Ghmqb!d@-p*6&)d$kYY>%F>iw<%we2V_aDakUfTnxQ`NB%l{ zh(pMkoG6~8)zqSn#WSY4z(D5Yzfg)-Wfga3)RoQLcuFKplq3p30lK9sRMb7l$%!D@ z8L~(I#dNeJoy4nVFL*@m4~5h5v!}AxQ>R&Od}MbTfAULf7I8_{HCQd*W@V^^VjFZe%ra+I{g8Kn2sagX7s!YS2k0XlUUxfkQ3`Woral zusq?}ZHtNUsj4a0XnNJ;;;luYnv)g%8Hs;EuV_hswMQ>2Pui`WwK^2&_J&3a zLLoQ@FHG`z;Ti|s+K$e<9Zx)I57g}L$;#0fBjJD>a<-B_mxVi>yYiBLtG-SO;GxPW zki+uy$`7-H#UF+X`*){XiouJ@Dt;_o?R-vU+Wn5t@nFLK9$YO>E9)g`Wl?xUWJz|0 z#0*A%s&ix~=`5LS+~mYFAB&_p^QF$xQw5giUQez(l-JJC<#0w>m0tb9v>fU!;Z&!E z?7%q{sNTlTKw%G6JP&ng@F!ECLeBp~Ee-1FL}X+~o0UN~XXWnB`_klI?`qrQbNj*8 zLr0Tb{KeOj%dAs0vbH2(VeSrRUGU1oqMKw6-dgrkyAow7^` zYs>c1cRVlJrA}YTXQ-4$S3W6Q%NnxIcnt4*)+;HX0qhTo8&Bs&lk)uRll)ZkO}YM1 za1P@#8Zd?{QLawT_XaX@CR%5GWRuumc9HK=R+|n*Tb>i_&@4Pf^2vu2ztM|RzNT56 zT@`V&NLGk{7v@F!t-#21A#Fj|dWIM;TJeBQ+ghA@N*)=eZ%6?9p%qM;$!h3PI|n%RBioBdD{5k$%da9; zd{{aUpM&~WBX@i}mW`hCOgYb=#=W8=e0p2fQqM2Kd*Wm@9b^D+u}FS1Ob z1+Q|6PW+b!SNF9)*-qHT5&UR9So5-9c+zLOhUDt$)P`J;@##&gV?ku9wJkr%DSf|l z&+pPtM#htXSyi6uL^R&9tV}(&S@F$Og)A;YL+4tH4|zOzhFwJ&^mgW`Y%%*#zf>UC zcr-QwUTNSHSqsl@e!FV*^Oe1!XDSBgH-QYMWjRDWr7QeuAq$D;=rIypSzi`h@2U#| zt&FsYX)ylYn}So7PGGD(G}M-hVkOAk-oT|^Z$bXzSN8p(NVA&sqb;q-A>RxaqbqL! z-g$6%oCd<5{5a>{!AQPw3!k!ht#p>(_;yx6CIvUJZ$|^WSb;Bw!z<5)!^x?5L5-R^ zUpyf;#Iq!SQRc%>6OoF)$=#E&6mWr0c-m^(kL50#%2Od*?44Lo#)FkOoH6}I32wG? zdrxeV$lj{x9N7xlTt2qgttP$*2=`4F;yRaXKMC$RnEJPuV2QfI2ygv0$_Q(xN4 z=<;~%aTtv^NAqvZx~xbov3W#TwUf^a2QfJDrHHer!(LP*P`hpTUj4A=VDYC*lWfeD z%HGLEmd4;)$^-Juc|}WuE4e5!HUCvDrkXBVl*Xz&iBovE&vQo*4c?bZJ3-)8#O80Lm@u4=j+mJX0iVv`vo(LV*SSg&5fDQoD~JoY;Xc1 zrjkU zDKDJvE>BvUMQKT=@NVP}Ws2Bo5!lhdI-I}mk_tJI<+-vnKC2vo8W_~Uv#`N48VpCG zsb5e)OR_D`PLkzp3(vHhSc%6V3ug@aZGH4G??+FOHlL0aKpB{4*Vq}}g^UpD@~jJU z_=6T>ac|2#zre{pv3xs|d=|DdBWt`!B=fSOtH29Snn1|Qx`DVTaS$$=;%crh0zI-!0 zBYI$O$vGb%NAM$QBob2j5!CQ3Bo>Mu=JmlXeJeiD1J#a)rntxXV>^msJKu3K2`m2e z%nEh2Y?hqE`_kK!8tH8sb#o|KT*J4AZ)xPahjNBtJ6Mw1m1_N(}}DNJ^8}2GxLKnTCZx&OFr?iClk;n zMKh3!0_=nFGo3QQ!!pQqQb4rqGtMk~x;k8^>P2n>7nCnv`b=(A{zueq)r-RcveK*{ zDyRggc#T#wFK=-;W7GAOORMK3!6S?7nc;Yb6xXhf_3=8`W;~0sWgGY|tft%rkDvcf z!^p>ZqJx@q`UHh&5xLFMbFh!vPSC4ef-hJ`EBoBM6A=2f(Scx^W~*0a&1y%ud=B~% zH^D8M;U9N!k2!IM-vwh8_+{lAgVcH|--0jKgZe(pp6qtjo+Lh_Yb*N$Ui3m?r@zUm9C~G!<|l7O zM_V5VNe6x`%MS+hBxk%u9#UmEoWBPVYl1UrS>9Ulg_Sy0$JWxtqyfgsC5gc8?&&37 zL05RPT5U^8jpkcD=d>gejP;aDuANe0^dxY)eqnUfuIN-M(yY|z4`$M`ZRxG->4c2nhmxLZ|V>3uv+ z@;#3Zto&&sktUHPXxR}sWi8Za$?Bj+`88tR;zYR#7!eoJYv3-O!w2H!qXhmc?8$bQ zhv+nbHSyhx*XfwjlxTofu*mzPt6>X{?Oc!$SjXQyL=iFWuH3zF=L*qK9JS_mf0&r_nAa2+nx4~}4%yr?IX znSssH@|FMLE%2Pph)?)*G=S&{?6{j=G%HM~43VvL;(bnw6LuCSJBNm|Wj=5!ezGE# zUR<&>HeGhMCs*x^!^?_bvGz8~TN;g2^4H)8FO{zbV=#@|<;%_n3c3^>#F4C$GiRP& zvMK#HM&rA8EGOW0{=1;i18& zw)l?L6eHmhIOn^eN!6j_+-=#RqnC`JwflSwW0s`T<>gsu%|Dc9e<0(rYcN1lsjm_f zl7Zqfd_{MaKB3#?ZA+^g2Y1T&@?B^Ux{Mx_Vf71Yu;aBcv(ZSEPvyfOYI&7SpW4OS z9t@TalW{3KGMQ2ED!+}4i4IA0<%aNPc{HR>hS7}9VbCp=HQ~#b_J=w8sPdNOhm%CU zJ`CWl;%d004?H8Q^n|;^8AslSG}#BoGw|XQTw6aKv!uO>$F{aBT9s@eEg_#vC&?14 z`qG{~E_&XV(P?s#!qmFR7arhKqEPV__)EXgk$l?HV{{~W6Y z!;#>I6V&k{6BR4t1<|WkurE)fo)|}}@~1b%lvazXC@jbAT!|`1aFJLumrq=J#hfzm zAQrW)2!43+D>AyZ*YiPfn-#MeAgyduMO9>=o_l0I8lyBXTGVfJ(t^(T@v{7y4_n3x zqLDV{Fg{w96;l=N_vY$6ea&4SQPCEU!2}7S<9Vay>&SYSXIAsO4r2B7m1)C+yu0!; 
zX&m%7!qI3EeuH!TPAT`N4KD((m%U~qtk5%r)UTD-YbB-k@S=zS^*~29Y0io~D>gY@LTP+(9^VstR~n*lc0b*QisocE-doQWMx@zvQ@bR z9=CkFJc6eZdKQiP*V=_6U240m`g9<-US1U)%R(E4JgiEe-gtD!gYUi392Gf;RY}B@ z6+&TIU_8kZv*qYjQ5Sp78lnLHEtzE{cno%iJXz9u;SWTjZ5Z-YVG*rrVSW7qf%AM~ z32U=DTn5W9!(Nqqsi@|2v83)9)2<@A`7TFFONh^1K|`7ap37I(3Jk0(sYH{S0}q&q zo+IPzvKh*sE!@$}Wx1`3g|lk&f`?AwWB2FXeqmXr(vIk2K72*$**U#@WrgEFL^LK?$#KbpcfCs9OON{u9;+aop4kX%TH-4aF-%k) zlf5%s#GV+<>R{Ix{u>1rXtU~xQ)xEqtq8a{wWx)^U{Wty+Rg=8LvO3X>oPVnpJ-DW zz-aup;t^12TU^Mh*dwFC)j7Z>I-nR!Zl(1X=kdgSqmkXgNic3*R<%2l;dHXBczIPI zC_P)Wg9Ubi{FqHHK$f1)C!eKvK!ZZH>o~>36%kvx6`)If9?thVqsXrDpv7K=+wwF^ z>R83HxwX?l_lY?~>Yl?P7B4^7ex58hPE%FX8+^(f%ZbRY_GcdT*q-W2MsZQ`D*htL zr7_`_7Q&w>rv_Lqh&F^ln1cs<(zu&{!q=1;vpWC5W-r9kGDS62(ZoA5I$l%%MSl5< z=C|trrruZ+TxK3mfZ-{ZCD*Hb7|&|3TR<|vEQ_}?dT`38him&r@K$4?22Ty2psDf2 z*Kq8+ysikB_eFaa=Np^W)t8rL&1I|Tx#A`s-efXD^e+o_AnW4+&d)bH4OSuNqcYj9 zQ)P}wqw}r}Rr73vaGzDH| zgRBN`x<3^a?TBY@f zW@h#U=E)3RJCPaJhF`^*uBoKj+m6Fhh`2mcuB<`%x}^(cF=>AOKTph_xX8yI>FVIY zldR|DZ*eXi23PX%_`oOC1AIvPM|3!i$d0nGWlwA8^kC{m#J8*=8*fEq)EdYNyD7^l zbA&6@$c%(ub21`W$%P6YBj5p?2FHGd=X5(3zGZ_*w9G*15xC%0fLT@y=b|<4(4#22 zAoEvE4=lqFOK8VnZ8*m%|9bY&y2L4-f>(JR`qi&z@)r(dG+q(h;hLJ)H%}oun>*m; zS<|TUHr7mIOzBp#j=uO+e9c!xJrrFTXzY~13?G9vm8ZuGFi7TTRNO$Xpi}wk{vOVK ztAbs=qa2-BgU>?q;C`7g`FS1}obZxCL0@>1ohbTjzPXCpq?hkqI0=0Dmz9$)s; zna7TdiYVTq%CSZ1Au3QkB%kh?8rFv+wj_IscSsN|R@FIjYx1}JtL=%At@*B;XvZOJ z9yrJ{t&1YFlbu_bt2`7rG&oQty*;nRxmt6HEA8#l@8TD`3@VgVEr=K3jrQD5xV0{Os1JrmSWP_PnVFstg|B%Ao;pR-@wrG9X<8R( z>x^?*JsiyzI5moEX)3zNPRFtQMcKWP0=lO+7~++(NVm80*_%5%)5FuP9?f4p-!nDL~YL!_h&p1@wX$Q3YUcOLYVrOzUKds7=9t>1G@CB_lm+ki} zhify#iQvF4Ep1EJ;TrbT&bgK0cG*<*2sBLPfyt96-IJKI-lU1=L_@HvD(2`Kcg6W~ zkSFt5yz059G}-DFPj)!%%F>+UZOtuyCanjt%AsrmOP!_Z-BAg#k*g6>Gx?$m`nB)h)+0dXN2LHUXeDn@}tHiln8i;2u= zz~Zp-Tv?s6ycM<4{P>wPsjIKL0a~M;LCKpa%PVt4`cx;1w0P1apFc_qz@4lZI-3{;F)XHM+}hHwQ|B9E2i(@?B60ml&n|6`mbSLo3o}vIBD3d&B*_eb$$+1?HX> ze!DFoW&g-7n+GfW1Dsdhzwv6fWPK~q;KZG9$QF}WeiV4hY8peF4F^-Y*-BBH6$3Y{ zSr?hd*CZbn@IMYLU0wNXxf3(GN5A9775Th0`p1a@r!QGoy94w{<$lrHeC0`sBw>)f zvujqe<=N4F8O85cu8CJ87bN3dej3l+TJVo*O|x4iyU7z{MMybVjXj#*7bCB#r+393 zY>#%P+vQY?m)NdlEmSAcs=OI11KrY$Qanj}<6g3+NA4*vSKaj@%@Udz_^N! 
zd-500Ra6b9c#nM*J<%8}k=kXEGFhg5cs3-^sPw&OVYmw~r7L=x9|-S}I1-_^>I$o# z^72?Batp(FkS6e(PAFUO;7L9TPRAv*!LGdja7IF5-hei08fXTzC@qKbWDd1u%{<%X zt!ZhmmCuEbco^(2DWd0m;w==ODxc?f;pt(m;RB6mJaUg3d}PmyTNTH^|^7G&1If{;|>z#r~4^S~41po3m=rpCjYMX~Y7 zCEBuwCA}ai>qv{(=FKz8<2|@0l+VP;i;T-CtX2i**Y3olxS7jn~j{6hafvvI~yxoiL+{+Kwnnr z*B-^x=|!wcmx83k;11I1S) z0sphQ#^O=(LP4r6X`7!`GE`+3X%pEj{LSA-PrAFPYOQu5)mj?=Xuk7`i(~%>Yj+mp zS9+d@y&3a@cTp}}Wjk@ovMswrB_b(Pwj3=MM<$2jY#d+)#9(VQx*NUk*tfybzyO17 z2Kx@oa2axj!=bo{B5gUeWm&dtmn>IeJ5HR7l&exnDwRrAu96Gq`JIOy^MX~$MGDpE zf1m$3-|{ZcyL|8WonoOVk>oWZthbHsXbU?U!F5HnXduULj(kS%l_qaR3l!ud(8_vK zV*l4ieh@yt7O8F2yPi9~82&dVBl)mN1w7kRsY7X)!(VBfZx%Agd@;S_=hdA5OjjoJ zSMkkB@N=2VRQ`6pt19D(JmE|gHly^Zl~n~ue)!VPLWV8IuW$aWT^K53aJK8y{6dn% z-WrKq98I+@imNb1%qcfUa(Fgh4Q_SKvECq#HM78;RswE?Kk`YuL6vY527T^xIheNw{_)`zXGi zSW~^KahjR^J|YThh$yK#L`=*lpBZ?IdVE{{HV+W`@WkOYYl(Z!XtVDJ(-Pjbb=8+L zF)RSPz@}TDBrSBF^{_|#OUctTDzR&r6)r#6zMgtvM+R&WA@|HhCnee~$9`IE@5-{h z5WHOelhuBf2`}ZzRx&nm_l>lny&J($dsCJ&(?j{zI*J_U`OxSa(M2a3JxMm>PDa#` z(;k?D+kC0D@NzA!k*&KqWr7UTQ(n0>-0lB=hjECr;43}j`HN!U z8_VRB4yQeRHzRe*5lVKSL~ClHKsAO==%cvF^Q)e?(zskNGM_jqivVNb4*H1xn*VD> zjeFs4k+&5#yb~UYecft%_)-v%cHp;|TqaS*2~F7qatV<-7CvR~f}k+EnSWTKFIJjt zuux&|+3qJkkiq9Xxb&_8B(QzQ)8|PmpQl7xPcVizLfz+zWw?(=~XH zGIR_c*tc$_Kb{3l!xrHV-9=IxVZJI$HzdN;ohM&i7MNSrC^ zsqV59{RgEQ|M`Y&kyjQ%zr6KkB#Mc(9PF3~axr(@8D;{1%Dd)@P)TOe`Q1D$`A+zZ z*RotR2GZLPd^C};{VtF4N!HA4Rx@(v%{PBU#$mPpwSg+=R!QX9_|Cc>{8gLP{DFSR zKOebgx2bfLQ)-1h@5at7u`F5Qm5U{ry%2ZAfN8$o25`reFu9cM*)NJ5Wk!Ra6e<%mv%_2wG(-V?hb*% zR{kz#rexa=bMb)4?s2~J9!_$OljNvA1DWBE+9AeTo2xv4KG_f4=8?{291kM*POM`^ z?&N>EYA+|Bo9%@4hazYHAR{2%Xyh&{gl4pqdwBcorf7pVL6@E1bUcx!^#X0N8uq?z zPYdh>&Ep%$rl?M3UyQ#tznL=UvU{)1{_K9mv*OxnuB|1j&q{U&rfSYd`SwX#ygRhc zfhsox)F_`H1p} zs9lUkCB6i7GXuNz`xkzTyF3kZd9(8Q)_yLA?r_xpGIdzK>^R{meI*LA8nO20x z9W|^hy_mh6_4%`DkCe;X;293FY@U-HFfKJn&8Fh8@yLmz2AVY2Rv!mP1Eor$3PR0eFo3*KQ z$f?Q})MD}X(ZZbS1DP>-CPVxTRDk8=gapBMPxrd{v3ih>eR5SxW{W~RQvRoZ&6xNG z#kE+wuf6&-B(G0uk)>Q8?(x|y#q6w|yIAG=i=HqCazf77t*)EGcD>4pz+1d#ZE#aQ z%{L^*H8ruto#ub~)*^GDPo=Z!Bcol<^Hw#jhp>r8z~bZd8__s3L`_mrj|FEc5tTo+ z(>yhmCM=?eu|BrgpqVfGzvy0f|Mm1JI4v(u?qqZ9JJS!3NXAC2!80&Uwwxr$L%5@P zi&|0G4`)4xt}LX#d~TeCo@}xfX*qP_A(93)o0sE-$T!)mLE}(`hCw*?!cTME-hRLn zUkrZP-T3*)vJ)G4ylkwxXzOG1xrc6`HtnMY=0?AqCDZo!i;bu4DG*PPJ1ePnZki9j zuc#-uI zY~!}i;?!_DC?5hni3_4B$Ff6Kn|viw3rVfA*9sY_F+a$U8Sr+DSB4XR=^IN$ zPM#aaN@97QtWZB04HT6d-;!%Itoy#u{>ETCyq%R~9}m@e?@oy%D~ow{HG8f}igCbd z`b;~}yAmvaYJ{v64~ti#0*n6BdYL17F`E&y7BCYY(FwdZTS!2Lpt?DmnSWXp4?Dj| z|D;(wNgn<48&2^dco;k|(qr!^U4e1>cS0jsg1gY(7@eMH{h&TMOw^Y#vMxLqYre1K z9^5uYu@L<1-Z1{Tsw4OXyGUu*mw1LENH{}Jov+m^=!pKfA~$LMfppZ)lO9)SoLaHB z6J_!X_o+1BlQ*`^bmKV-?4b&sKFyLUo?P2=A<&9*}vvP%$9$}PVdTR za;!F|tG&(hi*ej1mjJ)SYKystJi*s`d?Gt|bTKP*5Sc}5tsqURAHYsR5bIq=0VT*L zoWC{9Q4HALP6QLgQ-9I5ekO^T-Y&Qr$7YS`iIZMwx+ zxhz)fy|me}c=L{o)~gUp`wqZ618B=Z>9>~na%SVVP(<^9jmNYrw%wX z*7*}l=@9%{5ZGHC8nJZ$Zx6%;g!H55IO zD1I+%uU>ECtyQR2duM&y$e@AvAA;=;YF7qell>OAGIH1sfz<0A4;RD->-llI9!RQLB#x4o-<~eDUz&Zs&%Iq zJdm!+;+-0&ONlFI^MrFdW#Z)-ojuO4=b;+C{@6cl!VH$cE@WBYS~JM;D0i`V?ET@` zR(e3oRVdgU>;1eJQXhd5%lY41hA)r&jlomsTSwtm{U9xosXL*(Qx@>5^|~;!Z*(;~ zv9{AfMJ)+JihH-qSu= zuup@$z`yg!JRJ;Pe|zZRR=$^CQE^6^n=g}-wel?!LhjK*EhI0oD_>n*+K5jei7X6G zlNfuw(ccL42^UBUoS;?L@|%>vXZ-TR_IgGphxbOx#kKrezBmfPOO~A{y+30?Z}gRa7nU9B}U5Pz}h!HB#;DWYNuCUZ>J(4Uv2_-l0X!wpsDK6Uxq|4~WW&d6$hE z%nc9ic<3%Kv=sg2{iz?i7K$zA895vgoVnsD3yV(dhs?%wR&&lon$>}zfpa?XAC4CV z79&aD&7J=7p;#m`N`{RQHd`~B$|ttc`nk5)5_4ijYkSbl99gffNn01A9dJrzk*q0X zLPt>zPmb+C6+H2OV;+pfJ7ynI-rrj6l@7`kc zV*2DgS=Hfnz*_RbYY<<;KL=jog9=Q=bH|WTQqvzqSAMXCa+D56#ey{1>aos?nWFr4i86 
zdak^=qO=*}A5mB5blX2cVYlh17toRC#X$YR9p1C;7tFb4~p0|pA9Sv*abNvM)*JC}0 zR_~aPu{I}T%%qFFvEnFB=H25rFCAaC3w=crW@qG$mOSTMeHoof+_phZR3D=S6f58T zFWU_J@Uk8dn@`*D%O{vozZq>dV&O_UtZ*++g|t_)R`%J9 zp#Fb789qKAsaY>NFCPxQWJ~bH^D6hrmbi@v<1E*XO5f-qh9Lz;3)wx#9{3Hzoaiqb zAm0riNFX%f`86hy!RH5W`4LW`EY7O3LOog_?W^)?RpN`ETHFzxB zW%Ocqdc(`Z9r8l+%-NcmwLTbVMUmdtvm+g>)%tM1kv~cs)c|}pXB_vvXCN6bna^(A zVqx*GToo?4m)A0#zSRodi2qBY>%W>a?89U8WU2IF1xT-|GWaN}2iYpIyVk{P>IAe? z{A^}P9mnNZ2UhoHo~(trlHTnUptC9;c{8%~#Q|E#s+>qPK{9x1w}%zt8F`mc4(VAIzu}&0ADT`RRFATw zBoYT}8Jc%;-QS*MQRHTPHUj;L$ZJti8x2`6GU7827O(2N+Nn@$D$X)bz0(ic(zhU3 z7hqk;y*G^Td}~qQH_fJH(Bxt6dNVQ(DS9@t+Q&=DN2zPRH^`H@YMD+#R^FIA6RTsq z(Bw{jTb+Ut;sVIBAzh@|eBCdm$H&s1MBprH70IlJdajv%bkh!+Rb%kQthHDfF`{~- zJf7yGyWY^SlCr zFF3)TxXb!F-pLcl;PAPtH{^kbAY#UM7-Tga4zn*byIzTQi}~)PDU}`=(3aT^RPNvK zg=82Z=|Er9?MWuEt7viZ|#TMX6uA=t=*u zV+67rmlJ71YxNp*sA%AScb998OYxYl$c@`qQ;!_2*)vGcoCWz*w<5P=uOWNS&!ClR zhuL`+jBH#<9+!`7$BHb1>s@IrEog~)3Gb2SAIp`nk#Gt&wSKT85pRh|$f}Y1pPy1& z$p*qgnF!y>MnKla9Yw_ERB0+HG$Ou(Jd=7R_sU50v{P?x)l)rE*%Ek;g79614j++> z7^R(jOPQw%32z5`kTD!fU6YkUJB^K`wNmy9$Yb8m(?-F^6Bpp0m0f#Ed`p65EZF1e z{3WYMUfF2&nzdn%aMH>)dDD~i98xIbRIkC^H1KMqP|gbiiGXVxMAY=pnaOfXs@(Pt zRGuGtYs7+cxTSnS~%K=NE1$s|sY+=dw!;`MMMmg5 zjihgI!JIt>+syA=-hRm{(>%xsE3W2uZ52U&=SEeRHI6*&Xz%{+cnj(c?ZrAIUaPYPa=MFPMZ9?*`hN3%c*+*M{B6adc6vA z$(ukWeT(SG7yAo4MQ~8Xx#==ps#HiS-Lk)6f5zn0t;PJ7DP`G2C{|k=kB|=%iGE@& zBjWe4`TE8~_gL{ENqy@VyiSA z4?f9fh^Nro>SS{ll@;i&%A^@sD240jh8JiOKbjrFS_{pSQ!rY1 z3NIi$U0TT8j^(N4TqCDva^=?lQR#9|x=tB<9H?+p5B2U*WPvnZVgq!J! z-6f+RgopMbm_PeLn)#rpPybv00j;c&(kVD?PXf!%d$WV#UfN_Iy#Inw0}K0nq z>>H}e%~>;T&SCU8y)RoJA%4m!2 zSf#%EP6on1qjcSz_`Kp{Bf>$JjHQC>A}&1Pee!Hy&Dk1ku}GqGsCQ2f*>U}!&N?h? zsWzbxNs@`>ouM@E&ib2i^Axnia}oXX3}7&6f#JL~nK*n#JN~$NLMeIqxxord1ZnCU zSgSE})(G$6GMmNbnYB^WbLIb#16WaeA+q4v^U{w+-+O|oQ)5Y;r(m-)Bq2JUi8%3@ zB{SC3sbYc+w2PH73V9t;W!|((+bE4*e13P|&u{jZcc`uIaD=_aPpes^`ljvIH}I48 z<45JN_O0s)>VglX<=H&zy#wZ6&k#Q$Yj-J3>?A=^CocpImo664@vIc-w8B1lR2Rg{bTXkAx=Zhg?ESGFJ+d7--uo#VPVq zPZ}jf&U_m>uO;`0v{7G=^cI)Uv%bICqH3c^wDKHj_jJ7@`oeb6Hd4R?{rA2g8LDs3 z_c1z>&RX+Y&6vmOIeLy~{xx$*e%AAumn9NqM`^LOEB@d=$|tqu_=i?bIdFfoBzTM7 zkowvspFt3Us2qVk-zmLtivs9%9m2t$m5b%;0y32`38NjeDe>dQax)oOREOR zh&5wK;oU1YBVG1G{$QwGVR_Mq$IgG|OL&5I9EmSsWOyhSVsiLvcpcZg7^ny)i z2dioi2We*0?7-uF^4Qoba=es#d2PldD!Cdo;WP1U>Lnb?t|RzBYkBBKPX6>yuTkUG z)E0_Q$h;LY+#pBI3X%oB$U^Rwn}GL?Q&}@Du&*$yr;KZ-+@nEDxr1EN*ydXMEZ^4WT^D1+(VM|Vm=23s`9XS6(rR-@dU0n9#u6% znrWrc*fHB&R&#-H=yvu>^So#z?V_9Q1B%tAIeWxsss^Iy;8OqN>^znnCN{ zNcM|_pa$IGV_UsucN;O+PlaCigUV>m`|*$5vl>V83}leI5#Oqj*&j-~CKIeS*UGi~ zSz7#OO}?9Ne4;&QAl5Sr@?-2!>G{y#XvsxACDMiuGB`3NaDe9FzN!zH^dvK|&qYim z()T7(JF1P>+BNCgJA6+@dB3om_kxCv1;h-GGTQ6$c%2T=)uC^tFA+<mlw$L#t&h5eoCVc)&0YRq2G0im7-*WQ^6U4Uvy{eaP>USf6PnaM~S7fzIB`Q%zKXrY%-XEORc>R?R1mq(~)?tgu$}4cmJ%q_(f!LXVYzFKeqf|1k0eye`I|VyazPO|i$mlOr8?5}p&#Qw zZQ)Ub>bFEE#QE3nXEu<2hUl zmFTF*xOBOnd*}-pSkM3RMi4=cdN~@v1GQJ$swFQ){`FR@rTEx7p;|`O*uOFKL6=3? 
zo>E;zPt~oEhDYhgz6>!5oa*~sdBFO{%`=(9bo2@8v$xiePA9X)j?;bn`B=MG^67BC z)ykRJg_(RW#%UXT5HSd!L=6EXxSi)-9Oh;95!Q>r`&K;X{&p47yW?6#N_N#OUmbRK zoXM!v9z&D+=?^#MXT8;CJ=~F#gdno+J#`WitHFi))@vuz^6B6s+pPL|A!D;DWj~qs z>he-#t$BPfTjU_uBo5?{LUkG2*K&yBcmBw!9*YE7!zR?O zED^a!GrG(BIG9snc-wq~>Yn#n+m&)EK8-iEd_B)BWNa+Ld`3T?C)GAQ3?=&>1(*ay zP;bLZZ0I8`lCdEJ)@14<@;aMcgKs!`J>$QVd)0lgypWnqLsOUntEZw5PQSswCqs>v z*7?73ANGnmXKFcghHT!%`ED@qW`3H#hxNR#dLYH >Qkb*Vfcd&tU2)CW)BgXpo zab+{h-b})li)+NbYEo_w(g2B?>$o0ns~ka{o}=C!2FB9Uztqv%^TjuD4dST4_jVZg z@L=Rik-=M;@6GU#eO!nvuH*@Cut6mf_f~kpzmN$)QEO=`V5T!JT1N}4$(bQtWa-HS zyD8q|XOV|SUS=UYpTu7U@jor#W9+qegmr>D4!e0wl<#qgc+W?Q6 zjes<*=}_xpkMJkbn+yj(iB#Rp{PpOqG^9$71K*6k*|YKyk8qO5#h*qIcGZmd$D*Fv zS*?(8b1)zLYy@*@c;{wkiwf|>nmDABV-xXM6VXC3hW5AGl>H^y(ENO`j#O(wg`hp* z{8ekDI7gmQ7_FVip#@raD_0n^Jy^6WA{P6HvHhg+8k|^1tAY~t8ots)zd|_ zgSnH3O*_%xaO6~!ZXBe6w_ZP{b3-HkiQIz7Np4L1g-d9oHP(T5&L`yevsm>7%)%!t zKGtq-GO1Ceo7hc1~nPYM0qZ??R#)`s%v#T?C1i+WraJD4jggo}O#JGxY;I0*}7`*v^KZR`rG#A%gVxJ);Um(;5*ZoI_<$1j@2 zdv>OdR!Ee*1dj5JV2L=zVb!PE4Hg^9md>nGBlP-_st2s&sq63t6m{YDEs}!T zL>Ognwg?sW8EbdxT+K{Lt#MWp%pc|Ww$|KGz{&Ew6;Whug}Yf_nu{iEu2%KvlruU? zce--b+!w4SOJ+$k$tY`WJ}B5&mc-+~e17RmyPCE6V(bQ;NBOzv6dwV_{ZE>CAtDHT zAbrrW@eYen-C<=bHANCyd_Jvh&5scs&TMaGHV~+~sRfeyc97gI2$D7xf2JNdN;?;h zEqg6^`OTm`4aSMNPy+tgJE#iC8CCF*oQVTom?aH;sH9zy1~T_%Sp#%HSH;o^O5mVEg>7sH#Cbw@fWs) zODn;~mHdq+*Fy*IVftErtgU%J!A$OvJ&_5(H=|N{`EGERv}>7VFKX~Yc@*v!m6Jw3 z3_O*8dneT5cf()^O(ytk8+@jZk8&UW$k3?x;77G_)jREb_mUCh~dJ-V5rq`sE7NRJ-Y84pf7&pD6fgZqMc11o2Za{N0 zClxZI%}mBiNXQv_k0ieVLpuP^ucK8o5MKWxj_~4uy=+v;gi&^TB(Hu zPTxTnh=n3NOI&mgI*(4hg#8+0`Oa!Q&azow&y$bSChe$>UkZm|y}A)nC{t&xr&V#R zo%y-O=IBd?q&fL|bRxXJp$H)-s@5RQic(2tmnwwl5-=V3AEL_)%hQu z#yh7}ux`*;CIou3b?iPbn}?0Fb_Y#mc3KlJ+kIlyubN$ZBl@J{vdH_>#>4#8RXv)Z zHzTJ#stcMcx$2}x*2wVL_zPmR)gZBOwfr7~$nU>28!X=Y6eaL_1W4^Irz z@NPY4?xf1N{p~4zp@;dmWxgfo_mrVA>1&*TMzD;}M7mi@J>V9ZHCLQ>Hu%kfOZW^)(1%uvQjL%&p#so* zJ>)v%twdV<+O;98)@;&LdOkJsojkz;;T@0i@W{W*H7jNG{b1&0ZBs<*4azE8@wF8~ zjZ1l1jnQxw)|kK3bLsy~{uedU1(EA|;xgI1R*KSX{*ipYvrNv$qj)b*pU({N6-Bq@ z$GJ(OBO`!ydNQt4kxVod5u>0Lh6kaKObTkyv#t$a3(DO~FSM4YAn%0J;z71bl%$gI zX0E~lXO+<9CmA6$$6Xc`il{X(AKr<0dsFT-irM_8Su#;_aALD!9m?YoWPt>{U6RfY zp%%|d?E&8g?vpz!J@pu7lXIeHR+Goli>y+$ABEovmrsPMb`O)M3&TpWGpu+$k8}5W zxQS~I@)XU14(K3%KsL;uw4$pq**kMI{p=j%dT-jVmwzKxMJ+t;^F>9_57UPVAS+b% zS6s~-#ON0ZSc#v=bDYUH2lCl$WFDR4PnPL0^q?dB1{9Wi<;|mS*8<&lD?gWW`WY(l z#YALIHWtgZ#sK1wG}TaAVnf(m5+(j*-_4vnv@bmEx&5bdju{VGmc}|iB$bKg-N-NE zt&XvUHiChn4a8iM+3VDKF@DAF(olj0~m5s5xMdk$e;iR6S&qA)C zG-`+_%u>||56L||{)Jqtly2xZoAbuN7xzGL{u7j>QL3v|9(zN`TK=zei5_Ve?_Jdh ziF=U$%@Ah8I~g*c>6-Z5?C=vBu$;v_nn|xoF@J>xt{2Q&(pSjg>^s;CJxM$Lhp4I$ zctQSeE=OQ8&&qFPaJg+PbKXk8X@4P88)?-UeTZ zCYd$lHEvgFsP^iC)kD{^JcE1i3`Wvm{-gQ|b1z@4Me2t()Pt+6;U>PIzwv1UMK`1# z_l+5fm?dt~&Sp@`cTe(====s|^3L_3t#a0!T!Ze$4asPYG1oS`(p(ZNiu>f;Dj>gB zUns8!Sshf?G89B@h-7S9?w!5bNXm1@?R&_fRrMlfPxI^f>CuRc6L>5_Vd87o=!xf4(DeB*(K_l6}xGy2Dd5}O+yLtGj`meE_ydTT=UP1|uy3#5Q$lUC*?Fq z-sSwZe)=R=uBKy74gPd!`f zeW`MC5QQJ8J$@|S@^j>9_eZkKrYJz?p(0fB4CzR=k8WwW&N zd}b!H9eYgw#=IR}P(;Z1AFoaBn87&&ga4t385qi7Tf6j|W6?7nu3vk(y)&~xaqo~LD3$pvbl zpi@Er^&m0q0kk3W?37uvKH5VKcx)`5;QP?_IdOC?3PoXQVuokkMPjt%_vL|;( z6GSmQ4Y?|tWfi&;S;~J^34yDaj|dIc<{JdYSc1#_~)%-BH7)Ee9>Bc$oX8sU-pr_ z;x}5j+j|x84)?4m$?>6kR{~T@Ear)3L&!M)K&@3V$vWV{%zrFb#Y!?Z@Cyw#JQZ;t z3%)1M(97bSF{rYl$#^avDVJok$b&qkOp2b!W^vovruqpQY8T(}V9l+3E+UeZ^-z*z z;15jl2}enTd?2~z1JG*Hc|OuM8EVhv>}%(}%X*1y*k0$7kx~+%Mya(&wHVMAQj1~O zH?nMX-_CLNLT-icf-k(G%D=n`I?NBhaZa_C*m--Na8elDpq)H{<^}9^yqEj#CD(I3 zW3~=rhBVa~p}aJ4DEVnNLd4n~1lE)0^K2`pWd897O`jx-ax0$YcalN5JY)jtVo?$X8Eb7bj47{iMplzv 
ztf%*859BsjUr5>Lh}?$Br6KZV-LgzijU05+zGZ^QUQ$-P7c(vd!aQt;Iw|QI^>#MB%uK z6V{rqQj7ZGFRMY&)6bQc2@&UX3_bPQ|Wob zKdg1GKW;`e6Q9W-#Ix?CuAz zw7Q+YhT86yQ!qnOBY#~F-3bVTS{-(G;0}(_a7YcK#CJwdH@iAmJFh($gc@Yl7_HFy zZ53R*@_zes&HEfwfXcMWkFh@!xyyX~)1zz~kF~f4$yxhaNS147#(gkohBEc7NO;$X zTH~Z16q?!5LmprgF9YtG3B=-S&@BE#*BtDKmLXIj*!&?GAVtM{_ZPda#hMrY@e7vJ zhbN=vzHTAhx(`7!$2zv@k@rn0U#Dz_=kY)|Z zv`CvsAiNcO)u}jPj{M;(Z71p4sOMfx5o1&r*+?;~ zIu2+pzf`Y>_v9X^r%zG?Ye}3j@^RJIipb22$HIDvq^#?~rJkm46$YxSHbUFE{H)d>zoGBiRrE1a#nGgH@gr49T- zWj@b=;F{_Q*USzdWow~GD{9K=QjjlqJ^#)m+Ocjx$E+F2-!#skRWuNi)fe&?oYKOk z{BP~KsH2wGT2RlivbIH{?qse`Osd|P9h5O*F(dl~kN9cz9S%lkVZvm*QF#m1Hg;ic zU?gc#$IKJBoBL#K_WpBG-+EH+*6gd4P*u8IWV54mBnc#tZtVKJIJ0SB_Hlj#Wv#|H#gizL#b2Q0n# z1MNu~OeIHJ;o<4?a%7RVt!B2od+Ug|GxGUdvuXrE&P3+Cq0pW=PtZAWGSudMxO#P< z8cDQoLl%r|oXayP%LkXEfRGRt57lb2g58D5H-@}u#n!?U+Qyerk0yWQ?RU6L6X_C~ z!+Gx^QMWgdD=K`nBI+S4sM*R(>zCuXTECu$yb&d7r&NTu_1E+Om5WRpKAUO~{o0F5 zuSh%!7!m9zbuS7Ez&9{z`ysP8#i&fjAV zI?6_&<_D3XWW5{c!Gf;jA0O!H$O}56iIKN@kX1u1dO?nH+Wh4zjjWM#v53T|-aeEu zk$tkkH$Y>W0dergnGjl_F{n@4pA3J=fzcv8+|Ov)i`pYqCS8qyw6wvKyeAsYTC;2y zq6bAs*Kp!c-dad%c)R4vb21OC2XsRdIL`M&8GjossYfkzko!@yy`DQ+7yg`keMj~; z{6?7)wz*f#`b6HsV6?1-A90d;26NDs47~g)uDibxFnu8(df-d44n}D=u@U-&L+B~G zBu!c#8{`L#y|YU`lEo*pM(Dn?v0PfmXEDAL5xV37A zu-w;vb+>qKCgg^PRSH3QR#psxK5&er$qVuuo1uZMq^ugM4Rm6S*&ueW2um_OXC%gE zZs^jL{0%FEOZuujh&AOn{WIkCJ!zZ+LB8E?L z^gObjeMDbzjm+`&Skjxh6HRbizD?gahcYCB2faT!1NY3DwxBu5mG{AY+-}Yoc8R>8 zuAMNb4zs)?h&-y~>)mJ5CTigw*?280dz1NWuN-c>+bTo#+#QhWZfr*7z!;2%9e2O~ z#cldsNo%Z>T&U;Su4)eKS1&r9Z=7LSR4`Yc*bl4!bkC@b3jKJxWO-(g-qwV$hAL$A zXq*s&@8aYz@@!skQqEpXWYnrFSyp}zv@ixJqX!i_Fx~%+d3u_Z|55jB5F6#3yM-5( z2lOYeqCc`>PZMp1`m~oGv2Hl^(f{P9KK#)SjXd=yH%-p(A07R~FZ>rDA3ZX;x-`1H zG_$zxqkrb7KJ@25d}ZswUw!JQe(Z*|@v-^Ik8a76|H0i4J@x!^OM7PKCl@F8eR^VQ z@}F1YL7e4lx{cDpG|IeRW z{>qQ{xmlkV{;w^+_SA3w|Iu;>`xi92_(wnYD^LCD3~FuP=-#FIiP5FyeWRnF{P-`9 z_NKpNMyEdX)PMSuZuv2)y#MPT`EwsR`Js>G-%ow`sXzOZ?q1nHwzzNR$mGQ5PX4+| z%;vpFoOc?$A36DQB*6KAJ;7orc9QpaAt%(^4=1ms5{zy# zU(Z|6*uzKpp(QZkZtl1{@|Pp=Ytg4$Bfpe>to)0I)LQ>Hk z&6GLh-CWLWttOH`=m%|R6dhn08(-28SO_U$DA^Vns*a;w-c41T;{-n6$~lp2z8l?! 
znfhR5?bVb&g$&x4|0`0ls(cu@#}AQFX&lxQ8X$e6GO=3oevJmY+D&_{-}#G|(5x>n zRxYPmQPuV|c~9O_`stjA^uwIhaXAs?-u(6T&~-L+S8Xpty%x^S2JgIe^eg%2w1k=D zwzsC`**rC#ofB}=p3^*b=Q8Sr^!iFh z^iH_>C_gV`wDvOdxYXV(XQkkuqz!TBw=$B8k-@#0{oNq8(;9ZC^@&jAR-W*Viwk+P z3F%T9vLkc2lql>;X7oGZ^uHPT>$!eEPskDC+~v&c+qv`h$j|4VyW!EhdE#)M_@`;{ zAEnJtr^PdA!}`+mp}9=_yLtLZbnju>dXVwG96E?Hujltv{@Rz!%2GylX5>%jNuB^r z`ir59+>W^1sS?xS(v^{)$z4BQd|I9#`9gXh&-kWt=f917e=Ti% zFWj@+jBk(X@;Gm0WbdV}!pY39ga@aFd%S(rx!QA~>P$X;HS@LRL}z3sn>BhqGd+^g zKFR!lCq4dr#(FgE*gLI$(0Iu?Z&x19xv=~4>BaEwK&W>&{o7Bym^%*UJ1Yb*pm{4^ z$auJ`N=BWsDv+tnlK+1+cOMPqZicpMZhtR#S=kzo2JFZv-%GD*_a0?jyVLu{;HLbc zGXq}A&s1joO#V6)eNsPgG&FlB?K~Q0w^`N_xK>efa&f!J=;pZ|J9% zyI22O06-yUrdcOPoJ&Z zoOWo`%E(CWmSxrhYx-ZjBG&*mOHy_{Lozeo9W zG1@hm&#G>`S!pJ^F+1`yp}>$dGrg>z?IMgzi>$1wez4B((9ve zxBT@+T75Q| zJ-gDKqx@c;g89=K<8Q~K_-xwbkA5}3@8z4fqaRKmni=`qX?atk^o7vE-b3drsG+-; z{^_2#U4JoG-_Iww`)WpUH{)E+6&7MVv$D(KO2(?jZ!AAD^fyA882x^^-#vGj|7?TXlZ-t_J z)83izUfw~BYR|3+K+cwCMG5*_uCeAOP%gm32H!G0^k?V`m(W@DcQ*9oGUiUMW^&rIG3?Fxgzq`_cQ@hV) z+&??WmN$<=2vWHoc@aIcEHWc1)=}@v`Ms7AObyb?I}=Ys!CSe<{$Kl^}%q-W}+r#lFAi*vQB7=~Q@jCDf3EITwFz zPo8@>G=luM((2v({;$H9-^yI}Mbkc?`<@Q3F2ssl3WwCx{xHw{QSkqTNUv4AxAOaD zaDQ(`^B_;%i)0>%W<$@VaN5Z;PPRFnu|3SQ@?x`j>ZgO6XL8@)j&(Z~-XG4mWuY#I zqfe(58IMcR37KQ{AhV$fw1Jb_R4u8-i=4igS?|nTPvw3w!dG(dDOz${Fp0o5A*k%`NV^B?iue>9{)*xKNDKK9{Qio z%Yf@ur|Aq9nlwSB^>b!R3 z+UIkRcX``i_K?Ux{X03ptci^YQt{aPI2)Sjg3k{7O)#S$#Gi1-}{!|97F=tvr7> zRCG@3)6p{8>pWw%-*2YnUyr3^-MlH~EBX6We8i=U>1v{~?`GciQ+xM>jLTwZV#frz zujH~|!9ZOZfC&F?$82ai;%IsZTgnWV9@+jJ|z2PB=7Mr}a4xcRY8`=j!Po=(hB6KY!oN->jB3etySv=yV~L+^Wj; zcrnWv*Rf=Hys7(qJb-JVr`im=t1qQr`&7L7fTZz>Z{^x+@xx>!`CB{l_jo={=Kg!R z`um}fH67;=yqVEVq@}&#H5Bs3^si)ecmRivMZ13^&uz-{w<3Y+sb(`mnHp>3H`3zu zjCvunkQ>zEfwbpIW1ma^R>AiTXuK8;QO#jz{?no9gGk8T*!`~$R@yo{AM%lP%*jxXq}n=fV7&!!(EkbitSy^3Ndqb&>RnFo6gRU;6g@c(Vzs4yzY@PVU z+U6P?+Jfl~8`QPQ){_o@7OlAEY9a z^llGl_&6Kyqgc%R79*opxF$0KIJ1x^j6jZg-!KZY`}28zF|si}NG(Khp3qvPa?5}l zqF46iWb{ML;AHw_twq|$GfP(5Dwq{g$OxOey2YMY$?EEBcjU;8$oYtF+%H;?WxkQW zAv}HPIy&#c->QcA5>^$|u(XQCdhJ~PMo%`wx;ig~KWzohnhC6?&sN&H9$-BJhRd$H zekPyJ=BMkqdNeMhv-SXeHmYnm%J#e&IiLU zS%*WRGT+P&F6+)tdT~<2+2EZwCO*vfd!rY$%=^vwe&#imk-&EqNHUAwAlQ|37;FEL zigmuF)kbSmu>ExAwUV~1Gs)b_G8_#Z_^|w8h$xpQ*J?$lRZ^Z{p^Xv$&6`B34U(1E zm%dQdDvP;TUtP*GGU_Y@AJ)39)uNgFWd}NMw6zbac2$iY4KAQiS9$!6&+3V9hKl?Z z^Mm%sqrvLtn$2is5vp2)wKImN%o9W_+Cb0QZ7s|2pB(u!p|0v6*}6YS?;nH$cQZCC z>@Q_>TXO#&4%~FsIO*J_XekFIZUT>w9PeliKxaShtiD<3C|90^1 z6B*41=}8?=_rty)p7ZGbW^~c|+hS;TCB4cuT};n7emNPwtC_>5jKLZ6FN8iP)7JNM z%}z~ugOBE}$|+tt9mF|2A?`8|sNdnd@dH7z-@^lH#wq-$5ul~{#e4g$+;-p-TOg5;iF zPaWV}X+xHBUq(f;e>>Lx`OwPB`k7!Vi~J-b-xXiu#XRGU=1z${5L)ieILX7IL>j0q zw>}$6iIlwQTcl-;!*$umoAG?*;pFDMx#!r(U&)x?%d_V)S1Ysk^Xy`5>aC3Lsn8iV z|3N;BY`rs)5Am&V?8U)KUr%qM$=i{bx6<0rg*FGH>3=Dd-IMXFW!)9*he?O>+^%S( z9b%6&e<$|Ngi11iUyL0<;mw23BzOELGyi*e-uWa4<5|9vF&U-m-_3~^AoaC;zd3lb z7S8NR3;!Y`Ih;HHaYjNab`CAG=(|F3C+1!X#SUcEu^#6{=<<8{{Kee29_l}rKEId0 zSwd@Vto<92!E+gdNcU)Dfu|{gV!@&6foKQcX)>H1%biY_li9Yjwb_*JS9+2W*psxJ zxwBf=lci!MC(<_yA{ylTosC3z2cKM<_kf9Sv;Q=zk(g4sN=nE#!5b1YZ9t79RU zusM<;*5&bgv$O29_jEc#S8rmD7xUdjKFf)#;FoV-&i|vi=A6dq^mQdq?9G$nE0V{0 z!)O-A$u&;!ogR9b%6*e*aW4N)=eL~7M84mh>nlM(Eg#CyOy;wcKG*WexjD1>SsU7N zM$~vl=035eRayDZ`Lrod$~P0$Pvsbqlv zMMl0i-yF#m)N)$T^e~^rw7PX@$K89<a!6(*c9VU<1C-O6lSHTUEIDz4 zqM1z&G@Q$*CiCr7?lNOul2zErT-ly`CqpMZ_iniPe745DKX;N#92?8CDp`AarI&LMQ0hLpRcCLC-xu=Rs)adQ zT`Z^R)ci0bx-^zCkLEA`Pv)Ps%lXV=C7&UvifTNSdp(@HCWbL>&G+MZ)(F(4%%@Kh zP6n)OtPJBpQ95BBj>6*>GViH=uD>I@&&ZTqW0(xr$kDPjjMzWjqcOg2tH&U=W-%RAm zSAy6R18T3vLzoY(p35D(@*7{b<_f8rjx23Y3#-X*puYNn-Fb=zuBLaRw#K_HW7v_a 
[GIT binary patch data omitted: the base85-encoded payload here was corrupted in extraction and is not recoverable.]

Date: Tue, 19 Sep 2023 02:43:32 +0400
Subject: [PATCH 044/121] [js/web] FP16 binary and unary ops (#17515)

### Description
Binary and unary ops with fp16 support
---
 js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts   |  32 +++---
 .../core/providers/js/operators/binary.cc     |  18 ++--
 .../core/providers/js/operators/unary.cc      | 100 ++++++++++--------
 3 files changed, 81 insertions(+), 69 deletions(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
index 7e52954734216..f08d7a77d1099 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
@@ -7,7 +7,7 @@ import {MAX_CLIP, MIN_CLIP, ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
 
-import {inputVariable, outputVariable, ShaderHelper} from './common';
+import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';
 
 type BuiltinFunctionName = string;
 type ElementwiseCustomExpression = (expression: string) => string;
@@ -101,6 +101,9 @@ export const parseCastAttributes = (attributes: Record<string, unknown>): CastAt
 export const cast = (context: ComputeContext, attributes: CastAttributes): void => {
   let func: ElementwiseFunctionCall;
   switch (attributes.to) {
+    case DataType.float16:
+      func = 'vec4<f16>';
+      break;
     case DataType.float:
       func = 'vec4<f32>';
       break;
@@ -126,11 +129,12 @@ export interface ClipAttributes extends AttributeWithCacheKey {
 }
 
 export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => {
+  const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
   context.compute(
       createElementwiseProgramInfoLoader(
           context.inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, `
-    const clip_min_: vec4<f32> = vec4(f32(${attributes.min}));
-    const clip_max_: vec4<f32> = vec4(f32(${attributes.max}));
+    const clip_min_: vec4<${dataType}> = vec4(${dataType}(${attributes.min}));
+    const clip_max_: vec4<${dataType}> = vec4(${dataType}(${attributes.max}));
 `,
           attributes.cacheKey),
       {inputs: [0]});
@@ -180,13 +184,13 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void
       attributes.cacheKey));
 };
 
-export const erfImpl = (dataType: string) => `
-const r0: f32 = 0.3275911;
-const r1: f32 = 0.254829592;
-const r2: f32 = -0.284496736;
-const r3: f32 = 1.421413741;
-const r4: f32 = -1.453152027;
-const r5: f32 = 1.061405429;
+export const erfImpl = (dataType: string, varType = 'f32') => `
+const r0: ${varType} = 0.3275911;
+const r1: ${varType} = 0.254829592;
+const r2: ${varType} = -0.284496736;
+const r3: ${varType} = 1.421413741;
+const r4: ${varType} = -1.453152027;
+const r5: ${varType} = 1.061405429;
 
 fn erf_vf32(v: ${dataType}) -> ${dataType} {
   let absv = abs(v);
@@ -195,8 +199,9 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} {
 }`;
 
 export const erf = (context: ComputeContext): void => {
-  context.compute(
-      createElementwiseProgramInfoLoader(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl('vec4<f32>')));
+  const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
+  context.compute(createElementwiseProgramInfoLoader(
+      context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType)));
 };
 
 export const exp = (context: ComputeContext): void => {
@@ -208,9 +213,10 @@ export const floor = (context: ComputeContext): void => {
 };
 
 export const gelu = (context: ComputeContext): void => {
+  const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
   context.compute(createElementwiseProgramInfoLoader(
       context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`,
-      erfImpl('vec4<f32>')));
+      erfImpl(`vec4<${dataType}>`, dataType)));
 };
 
 export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
diff --git a/onnxruntime/core/providers/js/operators/binary.cc b/onnxruntime/core/providers/js/operators/binary.cc
index 98f7ca6e613b0..e61cb1094736d 100644
--- a/onnxruntime/core/providers/js/operators/binary.cc
+++ b/onnxruntime/core/providers/js/operators/binary.cc
@@ -6,14 +6,13 @@
 namespace onnxruntime {
 namespace js {
 
-#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS)                          \
-  ONNX_OPERATOR_KERNEL_EX(                                                              \
-      OP_TYPE,                                                                          \
-      kOnnxDomain,                                                                      \
-      VERSION,                                                                          \
-      kJsExecutionProvider,                                                             \
-      KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),     \
-                                              DataTypeImpl::GetTensorType<int32_t>()}), \
+#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS)            \
+  ONNX_OPERATOR_KERNEL_EX(                                                \
+      OP_TYPE,                                                            \
+      kOnnxDomain,                                                        \
+      VERSION,                                                            \
+      kJsExecutionProvider,                                               \
+      KernelDefBuilder().TypeConstraint("T", JsepSupportedDataTypes()),   \
       KERNEL_CLASS);
 
 #define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \
@@ -22,8 +21,7 @@ namespace js {
       kOnnxDomain,                                                                        \
       VERSION_FROM, VERSION_TO,                                                           \
       kJsExecutionProvider,                                                               \
-      KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),       \
-                                              DataTypeImpl::GetTensorType<int32_t>()}),   \
+      KernelDefBuilder().TypeConstraint("T", JsepSupportedDataTypes()),                   \
       KERNEL_CLASS);
 
 JSEP_KERNEL_IMPL(Add, Add)
diff --git a/onnxruntime/core/providers/js/operators/unary.cc b/onnxruntime/core/providers/js/operators/unary.cc
index 5e972e43e4566..e9bbfabcf86bd 100644
--- a/onnxruntime/core/providers/js/operators/unary.cc
+++ b/onnxruntime/core/providers/js/operators/unary.cc
@@ -6,22 +6,29 @@
 namespace onnxruntime {
 namespace js {
 
-#define JSEP_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS)              \
+#define JSEP_ELEMENTWISE_TYPED_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS)        \
   ONNX_OPERATOR_KERNEL_EX(                                                         \
       OP_TYPE, kOnnxDomain, VERSION, kJsExecutionProvider,                         \
       KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()), \
       KERNEL_CLASS);
 
-#define JSEP_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, TYPE, KERNEL_CLASS) \
-  ONNX_OPERATOR_VERSIONED_KERNEL_EX(                                                             \
-      OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kJsExecutionProvider,                      \
-      KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()),               \
+#define JSEP_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS)           \
+  ONNX_OPERATOR_KERNEL_EX(                                                \
+      OP_TYPE, kOnnxDomain, VERSION, kJsExecutionProvider,                \
+      KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()),  \
+      KERNEL_CLASS);
+
+#define JSEP_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \
+  ONNX_OPERATOR_VERSIONED_KERNEL_EX(                                                       \
+      OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kJsExecutionProvider,                \
+      KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()),                   \
       KERNEL_CLASS);
 
 #define JSEP_ELEMENTWISE_MULTI_TYPED_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS)               \
   ONNX_OPERATOR_KERNEL_EX(                                                                \
       OP_TYPE, kOnnxDomain, VERSION, kJsExecutionProvider,                                \
       KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),       \
+                                              DataTypeImpl::GetTensorType<MLFloat16>(),   \
                                               DataTypeImpl::GetTensorType<int32_t>()}),   \
       KERNEL_CLASS);
@@ -29,6 +36,7 @@ namespace js {
   ONNX_OPERATOR_VERSIONED_KERNEL_EX(                                                      \
       OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kJsExecutionProvider,               \
       KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),       \
+                                              DataTypeImpl::GetTensorType<MLFloat16>(),   \
                                               DataTypeImpl::GetTensorType<int32_t>()}),   \
       KERNEL_CLASS);
 // math
@@ -42,115 +50,115 @@ JSEP_ELEMENTWISE_MULTI_TYPED_VERSIONED_KERNEL(Neg, 6, 12, Neg)
 JSEP_ELEMENTWISE_MULTI_TYPED_KERNEL(Neg, 13, Neg)
 
 JSEP_KERNEL_IMPL(Floor, Floor)
-JSEP_ELEMENTWISE_VERSIONED_KERNEL(Floor, 6, 12, float, Floor)
-JSEP_ELEMENTWISE_KERNEL(Floor, 13, float, Floor)
+JSEP_ELEMENTWISE_VERSIONED_KERNEL(Floor, 6, 12, Floor)
+JSEP_ELEMENTWISE_KERNEL(Floor, 13, Floor)
 
 JSEP_KERNEL_IMPL(Ceil, Ceil)
-JSEP_ELEMENTWISE_VERSIONED_KERNEL(Ceil, 6, 12, float, Ceil)
-JSEP_ELEMENTWISE_KERNEL(Ceil, 13, float, Ceil)
+JSEP_ELEMENTWISE_VERSIONED_KERNEL(Ceil, 6, 12, Ceil)
+JSEP_ELEMENTWISE_KERNEL(Ceil, 13, Ceil)
 
 JSEP_KERNEL_IMPL(Reciprocal, Reciprocal)
-JSEP_ELEMENTWISE_VERSIONED_KERNEL(Reciprocal, 6, 12, float, Reciprocal)
-JSEP_ELEMENTWISE_KERNEL(Reciprocal, 13, float, Reciprocal)
+JSEP_ELEMENTWISE_VERSIONED_KERNEL(Reciprocal, 6, 12, Reciprocal)
+JSEP_ELEMENTWISE_KERNEL(Reciprocal, 13, Reciprocal)
 
 JSEP_KERNEL_IMPL(Sqrt, Sqrt)
-JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sqrt, 6, 12, float, Sqrt)
-JSEP_ELEMENTWISE_KERNEL(Sqrt, 13, float, Sqrt)
+JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sqrt, 6, 12, Sqrt)
+JSEP_ELEMENTWISE_KERNEL(Sqrt, 13, Sqrt)
 
 JSEP_KERNEL_IMPL(Exp, Exp)
-JSEP_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, float, Exp)
-JSEP_ELEMENTWISE_KERNEL(Exp, 13, float, Exp)
+JSEP_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, Exp)
+JSEP_ELEMENTWISE_KERNEL(Exp, 13, Exp)
 
 JSEP_KERNEL_IMPL(Erf, Erf)
-JSEP_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, float, Erf)
-JSEP_ELEMENTWISE_KERNEL(Erf, 13, float, Erf)
+JSEP_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, Erf)
+JSEP_ELEMENTWISE_KERNEL(Erf, 13, Erf)
 
 JSEP_KERNEL_IMPL(Sigmoid, Sigmoid)
-JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, float, Sigmoid)
-JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, float, Sigmoid)
+JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, Sigmoid)
+JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, Sigmoid)
 
 JSEP_KERNEL_IMPL(Log, Log)
-JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, float, Log)
-JSEP_ELEMENTWISE_KERNEL(Log, 13, float, Log)
+JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, Log)
+JSEP_ELEMENTWISE_KERNEL(Log, 13, Log)
 
 JSEP_KERNEL_IMPL(Sin, Sin)
-JSEP_ELEMENTWISE_KERNEL(Sin, 7, float, Sin)
+JSEP_ELEMENTWISE_KERNEL(Sin, 7, Sin)
 
 JSEP_KERNEL_IMPL(Cos, Cos)
-JSEP_ELEMENTWISE_KERNEL(Cos, 7, float, Cos)
+JSEP_ELEMENTWISE_KERNEL(Cos, 7, Cos)
 
 JSEP_KERNEL_IMPL(Tan, Tan)
-JSEP_ELEMENTWISE_KERNEL(Tan, 7, float, Tan)
+JSEP_ELEMENTWISE_KERNEL(Tan, 7, Tan)
 
 JSEP_KERNEL_IMPL(Asin, Asin)
-JSEP_ELEMENTWISE_KERNEL(Asin, 7, float, Asin)
+JSEP_ELEMENTWISE_KERNEL(Asin, 7, Asin)
 
 JSEP_KERNEL_IMPL(Acos, Acos)
-JSEP_ELEMENTWISE_KERNEL(Acos, 7, float, Acos)
+JSEP_ELEMENTWISE_KERNEL(Acos, 7, Acos)
 
 JSEP_KERNEL_IMPL(Atan, Atan)
-JSEP_ELEMENTWISE_KERNEL(Atan, 7, float, Atan)
+JSEP_ELEMENTWISE_KERNEL(Atan, 7, Atan)
 
 JSEP_KERNEL_IMPL(Sinh, Sinh)
-JSEP_ELEMENTWISE_KERNEL(Sinh, 9, float, Sinh)
+JSEP_ELEMENTWISE_KERNEL(Sinh, 9, Sinh)
 
 JSEP_KERNEL_IMPL(Cosh, Cosh)
-JSEP_ELEMENTWISE_KERNEL(Cosh, 9, float, Cosh)
+JSEP_ELEMENTWISE_KERNEL(Cosh, 9, Cosh)
 
 JSEP_KERNEL_IMPL(Asinh, Asinh)
-JSEP_ELEMENTWISE_KERNEL(Asinh, 9, float, Asinh)
+JSEP_ELEMENTWISE_KERNEL(Asinh, 9, Asinh)
 
 JSEP_KERNEL_IMPL(Acosh, Acosh)
-JSEP_ELEMENTWISE_KERNEL(Acosh, 9, float, Acosh)
+JSEP_ELEMENTWISE_KERNEL(Acosh, 9, Acosh)
 
 JSEP_KERNEL_IMPL(Atanh, Atanh)
-JSEP_ELEMENTWISE_KERNEL(Atanh, 9, float, Atanh)
+JSEP_ELEMENTWISE_KERNEL(Atanh, 9, Atanh)
 
 JSEP_KERNEL_IMPL(Tanh, Tanh)
-JSEP_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, float, Tanh)
-JSEP_ELEMENTWISE_KERNEL(Tanh, 13, float, Tanh)
+JSEP_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, Tanh)
+JSEP_ELEMENTWISE_KERNEL(Tanh, 13, Tanh)
 
 JSEP_KERNEL_IMPL(Not, Not)
-JSEP_ELEMENTWISE_KERNEL(Not, 1, bool, Not)
+JSEP_ELEMENTWISE_TYPED_KERNEL(Not, 1, bool, Not)
 
 // activation
 
 JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, ClipV10, min, 3.402823e+38f, max, -3.402823e+38f)
-JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, float, ClipV10)
+JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, ClipV10)
 JSEP_KERNEL_IMPL(Clip, Clip)
 ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, kJsExecutionProvider,
                                   KernelDefBuilder()
-                                      .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T", JsepSupportedFloatTypes())
                             .InputMemoryType(OrtMemTypeCPU, 1)
                             .InputMemoryType(OrtMemTypeCPU, 2),
                         Clip);
 ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 12, 12, kJsExecutionProvider,
                                   KernelDefBuilder()
-                                      .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
+                                      .TypeConstraint("T", JsepSupportedFloatTypes())
                                       .InputMemoryType(OrtMemTypeCPU, 1)
                                       .InputMemoryType(OrtMemTypeCPU, 2),
                                   Clip);
 ONNX_OPERATOR_KERNEL_EX(Clip, kOnnxDomain, 13, kJsExecutionProvider,
                         KernelDefBuilder()
-                            .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
+                            .TypeConstraint("T", JsepSupportedFloatTypes())
                             .InputMemoryType(OrtMemTypeCPU, 1)
                             .InputMemoryType(OrtMemTypeCPU, 2),
                         Clip);
 
 JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(ClipV10, ClipV10, min, 3.402823e+38f, max, -3.402823e+38f)
 JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(Elu, Elu, alpha, 1.0)
-JSEP_ELEMENTWISE_KERNEL(Elu, 6, float, Elu)
+JSEP_ELEMENTWISE_KERNEL(Elu, 6, Elu)
 
 JSEP_KERNEL_IMPL(Relu, Relu)
-JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, float, Relu)
-JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, float, Relu)
-JSEP_ELEMENTWISE_KERNEL(Relu, 14, float, Relu)
+JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, Relu)
+JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, Relu)
+JSEP_ELEMENTWISE_KERNEL(Relu, 14, Relu)
 
 JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(LeakyRelu, LeakyRelu, alpha, 0.01)
-JSEP_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, float, LeakyRelu)
-JSEP_ELEMENTWISE_KERNEL(LeakyRelu, 16, float, LeakyRelu)
+JSEP_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, LeakyRelu)
+JSEP_ELEMENTWISE_KERNEL(LeakyRelu, 16, LeakyRelu)
 
 JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(ThresholdedRelu, ThresholdedRelu, alpha, 1.0)
-JSEP_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, float, ThresholdedRelu)
+JSEP_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, ThresholdedRelu)
 
 } // namespace js
 } // namespace onnxruntime

From 7116e66c4b575ed816b2c836132e183b594c173a Mon Sep 17 00:00:00 2001
From: Yi Zhang
Date: Tue, 19 Sep 2023 07:36:17 +0800
Subject: [PATCH 045/121] Improve Win QNNEP pipeline (#17586)

### Description
1. Use the standard Windows build template.
2. Enable the compiler cache.

### Motivation and Context
Make the Windows build task easier to maintain and speed up the pipeline.
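Reviewer note on the cache settings this patch adds below (`CCACHE_COMPILERCHECK: content` and the `CCACHE_SLOPPINESS` list): ccache derives a hash key from the compiler, the preprocessed source, and the compile flags, and only replays a cached object file on a key hit. Checking the compiler by content rather than by timestamp keeps keys stable when a CI agent reinstalls a byte-identical toolchain. A minimal toy sketch of that idea in Python (illustration only, not ccache's real key derivation):

```python
import hashlib
from pathlib import Path


def toy_cache_key(compiler: Path, preprocessed_source: bytes, args: list[str]) -> str:
    """Toy model of a compiler-cache key (not ccache's actual algorithm).

    Hashing the compiler binary's bytes (the idea behind
    CCACHE_COMPILERCHECK=content) keeps the key stable across reinstalls of
    an identical toolchain, where an mtime-based check would miss the cache.
    """
    h = hashlib.sha256()
    h.update(compiler.read_bytes())     # compiler identity, by content
    h.update(preprocessed_source)       # what actually gets compiled
    h.update("\0".join(args).encode())  # flags that affect code generation
    return h.hexdigest()


# Example (hypothetical paths/flags):
# print(toy_cache_key(Path("/usr/bin/cc"), b"int main(){return 0;}", ["-O2"]))
```

The `CCACHE_SLOPPINESS` entries (`time_macros`, `include_file_mtime`, and so on) push in the same direction: they drop volatile inputs such as `__TIME__` expansions and header timestamps from the key so that fresh CI checkouts can still hit the cache.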
--- .../templates/jobs/win-ci-build-steps.yml | 28 ++++++++++++++++ .../templates/jobs/win-ci-prebuild-steps.yml | 1 + .../azure-pipelines/win-qnn-ci-pipeline.yml | 33 +++++++------------ 3 files changed, 41 insertions(+), 21 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-build-steps.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-build-steps.yml index a81dd1e9cf240..2f67398908d5d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-build-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-build-steps.yml @@ -12,6 +12,10 @@ parameters: type: string default: "$(Agent.TempDirectory)/ort_ccache" +- name: DebugCache + type: boolean + default: false + - name: AdditionalKey type: string default: "" @@ -45,6 +49,18 @@ parameters: steps: - ${{ if eq(parameters.WithCache, true) }}: + - powershell: | + if ([string]::IsNullOrEmpty((Get-Command ccache -errorAction SilentlyContinue))) + { + choco install ccache -y --version 4.7.4 + $ccache_path = (Get-Command ccache).Source + $ccache_parent_dir = (Split-Path -parent $ccache_path) + Copy-Item "C:\ProgramData\chocolatey\lib\ccache\tools\ccache-4.7.4-windows-x86_64\ccache.exe" -Destination "C:\ProgramData\chocolatey\bin\cl.exe" + Get-ChildItem $ccache_parent_dir + ccache --version + } + displayName: Install ccache + - task: Cache@2 inputs: ${{if eq(variables['Build.SourceBranchName'], 'merge')}}: @@ -83,6 +99,11 @@ steps: createLogFile: true env: CCACHE_DIR: ${{parameters.CacheDir}} + CCACHE_SLOPPINESS: file_macro,time_macros,include_file_mtime,include_file_ctime + CCACHE_COMPILERCHECK: content + ${{if eq(parameters.DebugCache, true)}}: + CCACHE_DEBUG: 1 + CCACHE_DEBUGDIR: $(Agent.TempDirectory)/cache_debug - ${{ if eq(parameters.WithCache, true) }}: - powershell: | @@ -91,3 +112,10 @@ steps: displayName: cache stat env: CCACHE_DIR: ${{parameters.CacheDir}} + + - ${{if eq(parameters.DebugCache, true)}}: + - task: PublishPipelineArtifact@0 + displayName: 'publish cache log' + inputs: + artifactName: 'cache-log' + targetPath: $(Agent.TempDirectory)/cache_debug diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml index e29d9d2cab91c..8868e671a5fa5 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml @@ -1,6 +1,7 @@ parameters: - name: EnvSetupScript type: string + default: setup_env.bat - name: BuildConfig type: string diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index 7db0c7302cd6f..68e0d51480a63 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -46,7 +46,8 @@ jobs: BuildConfig: 'RelWithDebInfo' ALLOW_RELEASED_ONNX_OPSET_ONLY: '1' QNN_SDK_ROOT: 'C:\data\qnnsdk\${{parameters.QnnSdk}}' - timeoutInMinutes: 150 + TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + timeoutInMinutes: 120 workspace: clean: all steps: @@ -70,29 +71,19 @@ jobs: # '_Ret &std::_Visit_strategy<1>::_Visit2<_Ret,_ListOfIndexVectors,_Callable,const # std::variant&>(size_t,_Callable &&, # const std::variant &)' - - task: PythonScript@0 - displayName: 'Generate cmake config' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - 
arguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --compile_no_warning_as_error --update --cmake_generator "Visual Studio 17 2022" --use_qnn --qnn_home $(QNN_SDK_ROOT) --parallel'
-      workingDirectory: '$(Build.BinariesDirectory)'
-
-  - task: VSBuild@1
-    displayName: 'Build'
-    inputs:
-      solution: '$(Build.BinariesDirectory)\$(BuildConfig)\onnxruntime.sln'
-      platform: 'x64'
-      configuration: $(BuildConfig)
-      msbuildArgs: $(MsbuildArguments)
-      msbuildArchitecture: $(buildArch)
-      maximumCpuCount: true
-      logProjectEvents: false
-      workingFolder: '$(Build.BinariesDirectory)\$(BuildConfig)'
-      createLogFile: true
+  - template: templates/jobs/win-ci-build-steps.yml
+    parameters:
+      WithCache: True
+      Today: $(TODAY)
+      AdditionalKey: "win-qnn | $(BuildConfig)"
+      BuildPyArguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --compile_no_warning_as_error --update --cmake_generator "Visual Studio 17 2022" --use_qnn --qnn_home $(QNN_SDK_ROOT) --parallel'
+      MsbuildArguments: $(MsbuildArguments)
+      BuildArch: $(buildArch)
+      Platform: 'x64'
+      BuildConfig: $(BuildConfig)
 
   - powershell: |
       python $(Build.SourcesDirectory)\tools\ci_build\build.py --config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --test --cmake_generator "Visual Studio 17 2022" --enable_onnx_tests
-    workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)'
     displayName: 'Run unit tests'

From f969e7f8d866585e79c72df8a9c043f7e5be69bb Mon Sep 17 00:00:00 2001
From: Jambay Kinley
Date: Mon, 18 Sep 2023 16:41:11 -0700
Subject: [PATCH 046/121] Provide kwargs to remove_shared_initializers (#17539)

### Description
Fixes a bug in `get_shared_initializers` where `signature_cache1, signature_cache2` are passed as positional arguments to `remove_shared_initializers` but their positions don't match the function signature. As a result, `signature_cache1` is passed to `min_elements` and causes a comparison error at line 907. Pass the arguments as kwargs so that the call doesn't rely on their positions.

### Motivation and Context
Fixes the bug described above.
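The failure mode is easy to reproduce in isolation. A minimal sketch follows; the signature below is a simplified stand-in inferred from the commit message (where `min_elements` sits between `shared_prefix` and the signature caches), not the exact function in convert_generation.py:

```python
# Simplified stand-in for remove_shared_initializers; the parameter list is
# assumed from the commit message, not copied from the ORT source.
def remove_shared_initializers(graph_1, graph_2, shared_prefix="shared_",
                               min_elements=1024, signature_cache1=None,
                               signature_cache2=None):
    num_elements = 2048  # stand-in for an initializer's element count
    if num_elements < min_elements:  # the comparison that raised in the bug
        return []
    return [f"{shared_prefix}initializer_0"]


cache1, cache2 = {}, {}

# Buggy call: "s_" binds to shared_prefix, but cache1 (a dict) falls into
# min_elements, so the `<` comparison raises TypeError.
try:
    remove_shared_initializers("g1", "g2", "s_", cache1, cache2)
except TypeError as exc:
    print("positional call failed:", exc)

# Fixed call: keywords bind by name, so parameter order cannot drift.
print(remove_shared_initializers("g1", "g2", shared_prefix="s_",
                                 signature_cache1=cache1,
                                 signature_cache2=cache2))
```

Binding by keyword also keeps the call site robust if the function later grows or reorders optional parameters, which is exactly what bit the positional version here.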
---
 onnxruntime/python/tools/transformers/convert_generation.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py
index 63c991167d235..c0cabbb5e9759 100644
--- a/onnxruntime/python/tools/transformers/convert_generation.py
+++ b/onnxruntime/python/tools/transformers/convert_generation.py
@@ -993,7 +993,11 @@ def get_shared_initializers(encoder_model: ModelProto, decoder_model: ModelProto
     encoder.remove_duplicated_initializer(signature_cache1)
     decoder.remove_duplicated_initializer(signature_cache2)
     initializers = remove_shared_initializers(
-        decoder.model.graph, encoder.model.graph, "s_", signature_cache1, signature_cache2
+        decoder.model.graph,
+        encoder.model.graph,
+        shared_prefix="s_",
+        signature_cache1=signature_cache1,
+        signature_cache2=signature_cache2,
     )
     return initializers
 

From d350ab31d7ca5a6649022ef426caf51e56430bfe Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Mon, 18 Sep 2023 18:40:09 -0700
Subject: [PATCH 047/121] Remove reference to internals in torch.onnx in test (#17550)

- https://github.com/microsoft/onnxruntime/issues/11901
---
 onnxruntime/test/python/test_pytorch_export_contrib_ops.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/onnxruntime/test/python/test_pytorch_export_contrib_ops.py b/onnxruntime/test/python/test_pytorch_export_contrib_ops.py
index a378721932a35..5e20d6b4e692a 100644
--- a/onnxruntime/test/python/test_pytorch_export_contrib_ops.py
+++ b/onnxruntime/test/python/test_pytorch_export_contrib_ops.py
@@ -49,9 +49,7 @@ def to_numpy(tensor):
 # These set of tests verify ONNX model export and compares outputs between
 # PyTorch and ORT.
 class ONNXExporterTest(unittest.TestCase):
-    from torch.onnx.symbolic_helper import _export_onnx_opset_version
-
-    opset_version = _export_onnx_opset_version
+    opset_version = 17
     keep_initializers_as_inputs = True  # For IR version 3 type export.
 
     def setUp(self):

From 068300d97eb25e5b52324e7af54a45ed1fa6a4c3 Mon Sep 17 00:00:00 2001
From: Wei-Sheng Chin
Date: Mon, 18 Sep 2023 19:31:04 -0700
Subject: [PATCH 048/121] Pin beartype version (#17599)

PyTorch doesn't like the latest beartype:
https://github.com/pytorch/pytorch/pull/109510
---
 .../linux/docker/scripts/manylinux/install_deps_lort.sh | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh
index b2e575f38e425..3bca6413100a2 100755
--- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh
+++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh
@@ -20,6 +20,10 @@ export ONNX_ML=1
 export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF"
 
 /opt/python/cp39-cp39/bin/python3.9 -m pip install transformers
+# beartype is installed here so that the onnxscript installation step won't
+# install a version PyTorch doesn't like. Once beartype fixes this problem,
+# we can remove this line.
+/opt/python/cp39-cp39/bin/python3.9 -m pip install beartype==0.15.0
 
 cd /usr/local/
 echo "Cloning ONNX Script"

From 730fab3050c631b6a888990f55840ed6481bd07b Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Tue, 19 Sep 2023 09:49:21 -0700
Subject: [PATCH 049/121] Refactor Attention cuda kernel (#17578)

* Break QkvToContext into small functions. Each fused and unfused kernel
  will have a separate function.
* Move DecoderAttention kernel to separated file * Move KV cache related kernel to attention_kv_cache.cu ### Motivation and Context To make the code easier to maintain. --- cmake/onnxruntime_rocm_hipify.cmake | 2 + .../contrib_ops/cuda/bert/attention_impl.cu | 931 +++++++----------- .../contrib_ops/cuda/bert/attention_impl.h | 56 +- ...ention_concat.cu => attention_kv_cache.cu} | 191 +++- .../cuda/bert/attention_prepare_qkv.cu | 58 +- .../cuda/bert/attention_softmax.cu | 1 - .../cuda/bert/decoder_attention.cc | 14 +- .../cuda/bert/decoder_attention_impl.cu | 263 +++++ .../cuda/bert/decoder_attention_impl.h | 41 + .../contrib_ops/rocm/bert/attention_impl.cu | 1 + .../contrib_ops/rocm/bert/attention_impl.h | 28 - .../rocm/bert/decoder_attention_impl.h | 46 + 12 files changed, 904 insertions(+), 728 deletions(-) rename onnxruntime/contrib_ops/cuda/bert/{attention_concat.cu => attention_kv_cache.cu} (67%) create mode 100644 onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h create mode 100644 onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index b483e6de81a47..cf71b6bcf7c7d 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -11,6 +11,8 @@ set(contrib_ops_excluded_files "bert/attention_softmax.h" "bert/attention_softmax.cu" "bert/attention_prepare_qkv.cu" + "bert/decoder_attention_impl.h" + "bert/decoder_attention_impl.cu" "bert/decoder_masked_multihead_attention.h" "bert/decoder_masked_multihead_attention.cc" "bert/decoder_masked_self_attention.h" diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 366d8fee1473b..b4a4ae208ceb1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -153,408 +153,285 @@ size_t GetAttentionWorkspaceSize( } template -__global__ void AddBiasTransAppendKvToPresentSmall( - const T* qkv, const T* biases, T* present, - const int head_size, const int past_sequence_length, const int max_sequence_length) { - // Input: BxSxMxNxH (Format 1) - // Output: (2, B, N, [P..P+S) of MaxS, H), - // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size - const int n = threadIdx.y; - const int s = blockIdx.x; - const int b = blockIdx.y; - const int N = blockDim.y; - const int S = gridDim.x; - const int B = gridDim.y; - - constexpr int M = 3; // Matrix count in qkv - const int m = blockIdx.z + 1; // k = 1, v = 2 - - const int NH = N * head_size; - const int NHS = NH * S; - - qkv += (n * head_size + (s * M + m) * NH + b * M * NHS); - if (biases) { - biases += (m * NH + n * head_size); - } +Status FusedTrtCrossAttention( + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data) { + assert(data.qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H); - const int MsH = max_sequence_length * head_size; - const int NMsH = N * MsH; - const int BNMsH = B * NMsH; - present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH); + // We only enable fused cross attention when there is no key padding mask. + // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. + assert(data.mask_index == nullptr); - for (int h = threadIdx.x; h < head_size; h += blockDim.x) { - T bias = (biases ? 
biases[h] : (T)0.0f); - present[h] = qkv[h] + bias; - } -} + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + int* q_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, + data.mask_index, batch_size, + sequence_length, stream, + data.scratch); -template -__global__ void AddBiasTransAppendKvToPresent( - const T* qkv, const T* biases, T* present, - const int head_size, const int past_sequence_length, const int max_sequence_length) { - // Input: BxSxMxNxH (Format 1) - // Output: (2, B, N, [P..P+S) of MaxS, H), - // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size - const int n = blockIdx.x; - const int s = blockIdx.y; - const int b = (blockIdx.z >> 1); - const int N = gridDim.x; - const int S = gridDim.y; - const int B = (gridDim.z >> 1); - - constexpr int M = 3; // Matrix count in qkv - const int m = (blockIdx.z & 0x1) + 1; // k = 1, v = 2 - - const int NH = N * head_size; - const int NHS = NH * S; - - qkv += (n * head_size + (s * M + m) * NH + b * M * NHS); - if (biases) { - biases += (m * NH + n * head_size); - } + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1); - const int MsH = max_sequence_length * head_size; - const int NMsH = N * MsH; - const int BNMsH = B * NMsH; - present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH); + int* kv_sequence_offset = q_sequence_offset + (GetSequenceOffsetSize(batch_size, false) / sizeof(int)); + kv_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_kv_cache, + data.mask_index, batch_size, parameters.kv_sequence_length, stream, + kv_sequence_offset); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); - for (int h = threadIdx.x; h < head_size; h += blockDim.x) { - T bias = (biases ? biases[h] : (T)0.0f); - present[h] = qkv[h] + bias; - } -} + DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1); -// qkv buffer is merged tensor of shape (B,S,3,N,H), k v is the second/third of the 3. -// bias is of shape (3, NxH) or nullptr -// append to present of (2, B, N, (P..T) of M, H), -template -Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, - const int max_sequence_length, - const int past_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const T* biases, - const T* qkv_buffer, - T* present) { - assert(head_size <= (1 << 30)); - - int64_t nh = (int64_t)head_size * num_heads; - if (nh <= max_threads_per_block) { - const dim3 grid(sequence_length, batch_size, 2); // 2 for k and v - const dim3 block(max_threads_per_block / num_heads, num_heads, 1); - - AddBiasTransAppendKvToPresentSmall<<>>( - qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length); - } else { - const dim3 grid(num_heads, sequence_length, batch_size * 2); // 2 for k and v - const dim3 block(std::min(head_size, max_threads_per_block), 1, 1); - AddBiasTransAppendKvToPresent<<>>( - qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length); + FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel = + reinterpret_cast(data.fused_cross_attention_kernel); + + // When there is no bias, we can directly use q and packed kv from inputs. 
+ void const* query = data.q; + void const* packed_kv = data.k; + if (data.value == nullptr && data.bias == nullptr) { + query = data.query; + packed_kv = data.key; } - return CUDA_CALL(cudaGetLastError()); + run_fused_cross_attention( + query, // Q + packed_kv, // packed KV + q_sequence_offset, // cumulated sequence length of Q + kv_sequence_offset, // cumulated sequence length of KV + data.output, // output + cross_attention_kernel, // kernels + batch_size, // batch size + parameters.num_heads, // number of heads + parameters.head_size, // head size of Q/K/V + sequence_length, // sequence length of Q + parameters.kv_sequence_length, // sequence length of KV + stream); + + DUMP_TENSOR("trt cross output", data.output, + batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); + return Status::OK(); } -template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, - const int max_sequence_length, - const int total_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const float* bias, - const float* qkv_buffer, - float* present); - -template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, - const int max_sequence_length, - const int total_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const half* bias, - const half* qkv_buffer, - half* present); +template <> +Status FusedTrtCrossAttention( + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, + "Trt fused cross attention does not support float tensor"); +} template -Status QkvToContext( - const cudaDeviceProp& device_prop, - cublasHandle_t& cublas, - Stream* ort_stream, +Status FusedTrtSelfAttention( + cudaStream_t stream, contrib::AttentionParameters& parameters, AttentionData& data) { - auto stream = static_cast(ort_stream->GetHandle()); - constexpr size_t element_size = sizeof(T); - const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; - const int total_sequence_length = parameters.total_sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - const bool past_present_share_buffer = parameters.past_present_share_buffer; - const float mask_filter_value = parameters.mask_filter_value; - void* fused_runner = data.fused_runner; - - // At most one fused kernel is enabled. - assert((int(data.use_flash_attention) + - int(data.use_memory_efficient_attention) + - int(fused_runner != nullptr) + - int(data.fused_cross_attention_kernel != nullptr)) <= 1); - - const int batches = batch_size * num_heads; - - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); - - QkvData qkv; - ORT_RETURN_IF_ERROR(PrepareQkv(parameters, data, stream, max_threads_per_block, qkv)); - T* scratch1 = data.has_qkv_workspace ? 
qkv.after_v : data.workspace; - - int present_size_per_batch_k = 0; - int present_size_per_batch_v = 0; - if (!past_present_share_buffer) { - present_size_per_batch_k = total_sequence_length * qk_head_size; - present_size_per_batch_v = total_sequence_length * v_head_size; - ORT_RETURN_IF_ERROR(ConcatPastToPresent(batch_size, num_heads, qk_head_size, v_head_size, - sequence_length, total_sequence_length, parameters.pass_past_in_kv, - stream, max_threads_per_block, data, qkv)); - - } else { // past_present_share_buffer - assert(qk_head_size == v_head_size); - assert(data.fused_cross_attention_kernel == nullptr); - assert(!use_fused_kernel); - assert(data.gemm_buffer != nullptr); - assert(!data.use_memory_efficient_attention); - assert(!data.use_flash_attention); - assert(data.has_qkv_workspace); - - if (nullptr != data.past_key || nullptr != data.present_key) { - // TODO: support this case. - ORT_THROW("buffer sharing for no bias case between past and present is not supported yet."); - } - - if (data.present != data.past) { - // For easy testing. Production should better avoid this path. - int64_t kv_size = 2LL * (int64_t)batch_size * num_heads * parameters.max_sequence_length * qk_head_size; - cudaMemcpyAsync(data.present, data.past, kv_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - } - - // append last k v to present - ORT_RETURN_IF_ERROR(LaunchAddBiasTransAppendKvToPresent( - stream, parameters.max_sequence_length, parameters.past_sequence_length, sequence_length, - batch_size, qk_head_size, num_heads, max_threads_per_block, - use_fused_causal ? nullptr : data.bias, // For fused causal, bias has been added to gemm_buffer - data.gemm_buffer, data.present)); + const bool causal = parameters.is_unidirectional; - present_size_per_batch_k = parameters.max_sequence_length * qk_head_size; - present_size_per_batch_v = present_size_per_batch_k; - qkv.k = data.present; - qkv.v = data.present + batches * present_size_per_batch_k; - } + int* sequence_offset = reinterpret_cast(data.scratch); - // Q, K and V are ready now DUMP_TENSOR_INIT(); - - if (data.fused_cross_attention_kernel != nullptr) { - assert(qkv.format == AttentionQkvFormat::Q_KV_BSNH_BSN2H); - - // We only enable fused cross attention when there is no key padding mask. - // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. - assert(data.mask_index == nullptr); - - int* q_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, - data.mask_index, batch_size, sequence_length, stream, - scratch1); - - DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1); - - int* kv_sequence_offset = q_sequence_offset + (GetSequenceOffsetSize(batch_size, false) / sizeof(int)); - kv_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_kv_cache, - data.mask_index, batch_size, kv_sequence_length, stream, - kv_sequence_offset); - CUDA_RETURN_IF_ERROR(cudaGetLastError()); - - DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1); - - FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel = - reinterpret_cast(data.fused_cross_attention_kernel); - - // When there is no bias, we can directly use q and packed kv from inputs. 
- void const* query = qkv.q; - void const* packed_kv = qkv.k; - if (data.value == nullptr && data.bias == nullptr) { - query = data.query; - packed_kv = data.key; - } - - run_fused_cross_attention( - query, // Q - packed_kv, // packed KV - q_sequence_offset, // cumulated sequence length of Q - kv_sequence_offset, // cumulated sequence length of KV - data.output, // output - cross_attention_kernel, // kernels - batch_size, // batch size - num_heads, // number of heads - qk_head_size, // head size of Q/K/V - sequence_length, // sequence length of Q - kv_sequence_length, // sequence length of KV - stream); - - DUMP_TENSOR("trt cross output", data.output, batch_size, sequence_length, num_heads, v_head_size); - return Status::OK(); + if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { + DUMP_TENSOR_D("mask", reinterpret_cast(data.mask_index), batch_size, sequence_length); + LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream); + } else { + sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, + data.mask_index, batch_size, sequence_length, stream, + sequence_offset); } + DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); - // Run TRT fused attention. - if (use_fused_kernel || use_fused_causal) { - int* sequence_offset = reinterpret_cast(scratch1); - if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { - DUMP_TENSOR_D("mask", reinterpret_cast(data.mask_index), batch_size, sequence_length); - LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream); - } else { - sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, - data.mask_index, batch_size, sequence_length, stream, - sequence_offset); - } - DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1); - CUDA_RETURN_IF_ERROR(cudaGetLastError()); + FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast(data.fused_runner); - FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast(fused_runner); + const int S = causal ? sequence_length : fused_fp16_runner->getSFromMaxSeqLen(sequence_length); - const int S = use_fused_causal ? sequence_length : fused_fp16_runner->getSFromMaxSeqLen(sequence_length); + // B = 2 * batch_size when there is padding in input, and B = batch_size when padding is removed. + const int B = (nullptr == data.mask_index ? batch_size : 2 * batch_size); - // B = 2 * batch_size when there is padding in input, and B = batch_size when padding is removed. - const int B = (nullptr == data.mask_index ? batch_size : 2 * batch_size); + fused_fp16_runner->setup(S, B); - fused_fp16_runner->setup(S, B); + if (!causal) { + assert(data.qkv_format == AttentionQkvFormat::QKV_BSN3H); - if (use_fused_kernel) { - assert(qkv.format == AttentionQkvFormat::QKV_BSN3H); - - // When there is no bias, we can directly use packed qkv from inputs. 
- void const* packed_qkv = qkv.q; - if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) { - packed_qkv = data.query; - } - - fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream); - DUMP_TENSOR("fused output", data.output, batch_size, sequence_length, num_heads, v_head_size); - } else { - assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); - fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream); - DUMP_TENSOR("fused causal output", data.output, batch_size, sequence_length, num_heads, v_head_size); + // When there is no bias, we can directly use packed qkv from inputs. + void const* packed_qkv = data.q; + if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) { + packed_qkv = data.query; } - return Status::OK(); + + fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream); + DUMP_TENSOR("fused output", data.output, + batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); + } else { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); + fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream); + DUMP_TENSOR("fused causal output", data.output, + batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); } + return Status::OK(); +} - // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation. - const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) - : parameters.scale; +// Template Specialization for float type +template <> +Status FusedTrtSelfAttention( + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, + "Trt fused attention does not support float tensor"); +} #if USE_FLASH_ATTENTION - if (data.use_flash_attention) { - assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH); - assert(nullptr == data.mask_index); - assert(nullptr == data.relative_position_bias); - assert(parameters.head_size == parameters.v_head_size); - - void* query = reinterpret_cast(qkv.q); - void* key = reinterpret_cast(qkv.k); - void* value = reinterpret_cast(qkv.v); - // For packed KV, we can use query input directly. 
- if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr && data.bias == nullptr) { - query = reinterpret_cast(const_cast(data.query)); - } - - DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", qkv.k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", qkv.v, batch_size, parameters.total_sequence_length, num_heads, v_head_size); - - constexpr bool is_causal = false; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( - device_prop, stream, query, key, value, data.output, reinterpret_cast(scratch1), - parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, - parameters.sequence_length, parameters.total_sequence_length, scale, is_causal)); +template +Status FlashAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(nullptr == data.mask_index); + assert(nullptr == data.relative_position_bias); + assert(parameters.head_size == parameters.v_head_size); + + void* query = reinterpret_cast(data.q); + void* key = reinterpret_cast(data.k); + void* value = reinterpret_cast(data.v); + // For packed KV, we can use query input directly. + if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr && data.bias == nullptr) { + query = reinterpret_cast(const_cast(data.query)); + } - DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, v_head_size); + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size); + DUMP_TENSOR_D("k(BSNH)", data.k, + parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.head_size); + DUMP_TENSOR_D("v(BSNH)", data.v, + parameters.batch_size, parameters.total_sequence_length, + parameters.num_heads, parameters.v_head_size); + + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( + device_prop, stream, query, key, value, data.output, reinterpret_cast(data.scratch), + parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, + parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional)); + + DUMP_TENSOR("flash attention output", data.output, + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); + + return Status::OK(); +} - return Status::OK(); - } +template <> +Status FlashAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "flash attention does not support float tensor"); +} #endif #if USE_MEMORY_EFFICIENT_ATTENTION - if (data.use_memory_efficient_attention) { - // We only enable fused cross attention when there is no key padding mask. - // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. - assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH); - - const void* query = qkv.q; - const void* key = qkv.k; - const void* value = qkv.v; - // For packed KV, we can use query input directly. 
- if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) { - assert(data.bias == nullptr); - query = data.query; - } +template +Status EfficientAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + // We only enable fused cross attention when there is no key padding mask. + // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + + const void* query = data.q; + const void* key = data.k; + const void* value = data.v; + // For packed KV, we can use query input directly. + if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) { + assert(data.bias == nullptr); + query = data.query; + } - DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", qkv.k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", qkv.v, batch_size, parameters.total_sequence_length, num_heads, v_head_size); - - MemoryEfficientAttentionParams p; - p.sm = device_prop.major * 10 + device_prop.minor; - p.is_half = sizeof(T) == 2; - p.batch_size = parameters.batch_size; - p.num_heads = parameters.num_heads; - p.sequence_length = parameters.sequence_length; - p.kv_sequence_length = parameters.total_sequence_length; - p.qk_head_size = parameters.head_size; - p.v_head_size = parameters.v_head_size; - p.causal = parameters.is_unidirectional; - p.scale = scale; - p.seqlen_k_ptr = nullptr == data.mask_index + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size); + DUMP_TENSOR_D("k(BSNH)", data.k, + parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.head_size); + DUMP_TENSOR_D("v(BSNH)", data.v, + parameters.batch_size, parameters.total_sequence_length, + parameters.num_heads, parameters.v_head_size); + + MemoryEfficientAttentionParams p; + p.sm = device_prop.major * 10 + device_prop.minor; + p.is_half = sizeof(T) == 2; + p.batch_size = parameters.batch_size; + p.num_heads = parameters.num_heads; + p.sequence_length = parameters.sequence_length; + p.kv_sequence_length = parameters.total_sequence_length; + p.qk_head_size = parameters.head_size; + p.v_head_size = parameters.v_head_size; + p.causal = parameters.is_unidirectional; + p.scale = scale; + p.seqlen_k_ptr = nullptr == data.mask_index + ? nullptr + : const_cast(reinterpret_cast(data.mask_index)); + p.seqstart_q_ptr = nullptr == data.mask_index ? nullptr - : const_cast(reinterpret_cast(data.mask_index)); - p.seqstart_q_ptr = nullptr == data.mask_index - ? nullptr - : const_cast(reinterpret_cast(data.mask_index + batch_size)); - p.seqstart_k_ptr = nullptr == data.mask_index - ? nullptr - : const_cast(reinterpret_cast(data.mask_index + 2 * batch_size + 1)); - p.query = query; - p.key = key; - p.value = value; - p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; - p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; - p.output = data.output; - p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) - ? 
scratch1 - : nullptr; - p.stream = stream; - run_memory_efficient_attention(p); - DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, v_head_size); - - return Status::OK(); - } + : const_cast(reinterpret_cast( + data.mask_index + parameters.batch_size)); + p.seqstart_k_ptr = nullptr == data.mask_index + ? nullptr + : const_cast(reinterpret_cast( + data.mask_index + 2 * parameters.batch_size + 1)); + p.query = query; + p.key = key; + p.value = value; + p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; + p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; + p.output = data.output; + p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float)) + ? data.scratch + : nullptr; + p.stream = stream; + run_memory_efficient_attention(p); + DUMP_TENSOR("efficient attention output", data.output, + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); + + return Status::OK(); +} #endif - // The following are unfused attention. - assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH); +template +Status UnfusedAttention( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH); + + auto stream = static_cast(ort_stream->GetHandle()); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int total_sequence_length = parameters.total_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + const int batches = batch_size * num_heads; + const int* mask_index = data.mask_index; gsl::span& mask_index_dims = data.mask_index_dims; // Raw attention mask could be 2D (BxT) or 3D (BxSxT) or 4D(Bx1xMxM), where M is the max sequence length. bool use_raw_attention_mask = (nullptr != mask_index && mask_index_dims.size() >= 2); - // Compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxT + // Compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch: BxNxSxT // Q: BxNxSxH, K (present_k): BxNxTxH, Q*K': BxNxSxT float one = 1.0f; float zero = 0.f; @@ -563,22 +440,31 @@ Status QkvToContext( cublasSetStream(cublas, stream); - DUMP_TENSOR_D("q[BNSH]", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k[BNSH]", k, batch_size, num_heads, total_sequence_length, qk_head_size); + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q[BNSH]", data.q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k[BNSH]", data.k, batch_size, num_heads, total_sequence_length, qk_head_size); + + const int present_sequence_length = parameters.past_present_share_buffer + ? 
parameters.max_sequence_length + : total_sequence_length; + const int present_size_per_batch_k = present_sequence_length * qk_head_size; + const int present_size_per_batch_v = present_sequence_length * v_head_size; + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_T, CUBLAS_OP_N, total_sequence_length, sequence_length, qk_head_size, - &alpha, qkv.k, qk_head_size, present_size_per_batch_k, - qkv.q, qk_head_size, sequence_length * qk_head_size, - &zero, scratch1, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop)); + &alpha, data.k, qk_head_size, present_size_per_batch_k, + data.q, qk_head_size, sequence_length * qk_head_size, + &zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop)); - DUMP_TENSOR_D("Q", qkv.q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("K", qkv.k, batch_size, num_heads, qk_head_size, sequence_length); - DUMP_TENSOR_D("QK", scratch1, batch_size, num_heads, sequence_length, total_sequence_length); + DUMP_TENSOR_D("Q", data.q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("K", data.k, batch_size, num_heads, qk_head_size, sequence_length); + DUMP_TENSOR_D("QK", data.scratch, batch_size, num_heads, sequence_length, total_sequence_length); + constexpr size_t element_size = sizeof(T); const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, total_sequence_length); - T* scratch2 = scratch1 + (bytes / element_size); + T* scratch2 = data.scratch + (bytes / element_size); // Apply softmax and store result R to scratch2: BxNxSxT if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask @@ -588,14 +474,15 @@ Status QkvToContext( const TransformerOptions* options = TransformerOptions::GetInstance(); bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax(); - T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score for persistent softmax. + // replace Q*K' in place with masked score for persistent softmax. + T* persistent_softmax_workspace = data.scratch; ORT_RETURN_IF_ERROR( ComputeSoftmaxWithRawMask( ort_stream, total_sequence_length, sequence_length, batch_size, num_heads, mask_index, nullptr, data.relative_position_bias, parameters.broadcast_res_pos_bias, - scratch1, scratch2, parameters.is_unidirectional, scale, mask_dimension, + data.scratch, scratch2, parameters.is_unidirectional, scale, mask_dimension, parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace, - mask_filter_value)); + parameters.mask_filter_value)); } else if (nullptr != mask_index) { // 1d mask index assert(mask_index_dims.size() == 1); // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions. 
@@ -603,277 +490,123 @@ Status QkvToContext( ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D( stream, total_sequence_length, sequence_length, batch_size, num_heads, mask_index, mask_start, data.relative_position_bias, parameters.broadcast_res_pos_bias, - scratch1, scratch2, parameters.is_unidirectional)); + data.scratch, scratch2, parameters.is_unidirectional)); } else { // no mask ORT_RETURN_IF_ERROR( ComputeSoftmax( stream, total_sequence_length, sequence_length, batch_size, num_heads, data.relative_position_bias, - parameters.broadcast_res_pos_bias, scratch1, scratch2, parameters.is_unidirectional)); + parameters.broadcast_res_pos_bias, data.scratch, scratch2, parameters.is_unidirectional)); } DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length); - DUMP_TENSOR_D("V", qkv.v, batch_size, num_heads, sequence_length, v_head_size); + DUMP_TENSOR_D("V", data.v, batch_size, num_heads, sequence_length, v_head_size); // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v - T* temp_output = qkv.q; + T* temp_output = data.q; CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, v_head_size, sequence_length, total_sequence_length, - &one, qkv.v, v_head_size, present_size_per_batch_v, + &one, data.v, v_head_size, present_size_per_batch_v, scratch2, total_sequence_length, sequence_length * total_sequence_length, &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop)); // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, temp_output, data.output); + device_prop.maxThreadsPerBlock, false, temp_output, data.output); DUMP_TENSOR("unfused output", data.output, batch_size, sequence_length, num_heads, v_head_size); return result; } template -Status DecoderQkvToContext( +Status QkvToContext( const cudaDeviceProp& device_prop, - Stream* ort_stream, cublasHandle_t& cublas, - const size_t element_size, - const int batch_size, - const int sequence_length, - const int kv_sequence_length, - const int num_heads, - const int head_size, - const bool static_kv, - const bool use_past, - const bool has_layer_state, - const bool has_key_padding_mask, - const float mask_filter_value, - const T* gemm_query_buffer, - const T* gemm_kv_buffer, - const bool* key_padding_mask, - const T* key_cache, - const T* value_cache, - T* qkv_buffer, - T* workspace_buffer, - T* output, - T* new_key_cache, - T* new_value_cache) { + Stream* ort_stream, + contrib::AttentionParameters& parameters, + AttentionData& data) { + auto stream = static_cast(ort_stream->GetHandle()); const int max_threads_per_block = device_prop.maxThreadsPerBlock; - const int BN = batch_size * num_heads; - const int BHN = BN * head_size; - const int BNS = BN * sequence_length; - const int k_buffer_offset = sequence_length * BHN; - const int v_buffer_offset = (sequence_length + kv_sequence_length) * BHN; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int total_sequence_length = parameters.total_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + void* fused_runner = data.fused_runner; - T* temp_qkv_buffer = workspace_buffer; - auto stream = static_cast(ort_stream->GetHandle()); + // At most one fused kernel is enabled. 
+ assert((int(data.use_flash_attention) + + int(data.use_memory_efficient_attention) + + int(fused_runner != nullptr) + + int(data.fused_cross_attention_kernel != nullptr)) <= 1); - const T* q = qkv_buffer; - // transpose q and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_query_buffer, qkv_buffer)); - - const T* k = qkv_buffer + k_buffer_offset; - const T* v = qkv_buffer + v_buffer_offset; - if (!has_layer_state || !use_past) { - if (!static_kv) { - // transpose kv and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); - } else { - // transpose kv and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); - } - } else { - if (!static_kv) { - // transpose kv and copy them to temp_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer)); - // concat cache-k with k and copy to qkv_buffer - if (nullptr != key_cache) { - ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, - sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, 1, - key_cache, - temp_qkv_buffer, - qkv_buffer + k_buffer_offset)); - } - // concat cache-v with v and copy to qkv_buffer - if (nullptr != value_cache) { - ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, - sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, 1, - value_cache, - temp_qkv_buffer + k_buffer_offset, - qkv_buffer + v_buffer_offset)); - } + ORT_RETURN_IF_ERROR(PrepareQkv(parameters, data, stream, max_threads_per_block)); + + if (!parameters.past_present_share_buffer) { + ORT_RETURN_IF_ERROR(ConcatPastToPresent(batch_size, num_heads, qk_head_size, v_head_size, + sequence_length, total_sequence_length, parameters.pass_past_in_kv, + stream, max_threads_per_block, data)); + + } else { // past_present_share_buffer + assert(qk_head_size == v_head_size); + assert(data.fused_cross_attention_kernel == nullptr); + assert(nullptr == fused_runner || parameters.is_unidirectional); + assert(data.gemm_buffer != nullptr); + assert(!data.use_memory_efficient_attention); + assert(!data.use_flash_attention); + assert(data.has_qkv_workspace); + + if (nullptr != data.past_key || nullptr != data.present_key) { + // TODO: support this case. + ORT_THROW("buffer sharing for no bias case between past and present is not supported yet."); } - } - if (has_layer_state) { - if (use_past && static_kv) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, key_cache, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, value_cache, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); - } else { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, k, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, v, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); + if (data.present != data.past) { + // For easy testing. Production should better avoid this path. 
+ int64_t kv_size = 2LL * (int64_t)batch_size * num_heads * parameters.max_sequence_length * qk_head_size; + cudaMemcpyAsync(data.present, data.past, kv_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); } - } - // scratch1: BxNxSxL buffer - // scratch2: BxNxSxL buffer - // scratch3: BxNxSxH buffer - T* scratch1 = temp_qkv_buffer + 3 * BHN * sequence_length; - T* scratch2 = scratch1 + BNS * kv_sequence_length; - T* scratch3 = scratch2 + BNS * kv_sequence_length; - - // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxL - // Q: BxNxSxH, K (present_k): BxNxLxH, Q*K': BxNxSxL - const float rsqrt_head_size = 1.f / sqrt(static_cast(head_size)); - const int temp_matrix_size = sequence_length * kv_sequence_length; - float one = 1.0f; - float zero = 0.f; + // For fused causal, bias has been added to gemm_buffer. + const T* bias = (nullptr != fused_runner && parameters.is_unidirectional) ? nullptr : data.bias; - float alpha = rsqrt_head_size; - const int strideA = kv_sequence_length * head_size; - const int strideB = sequence_length * head_size; - if (use_past && static_kv) { - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, - kv_sequence_length, sequence_length, head_size, - &alpha, key_cache, head_size, strideA, - q, head_size, strideB, - &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); - } else { - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, - kv_sequence_length, sequence_length, head_size, - &alpha, k, head_size, strideA, - q, head_size, strideB, - &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + // append last k v to present + ORT_RETURN_IF_ERROR(LaunchAddBiasTransAppendKvToPresent( + stream, parameters.max_sequence_length, parameters.past_sequence_length, sequence_length, + batch_size, qk_head_size, num_heads, max_threads_per_block, + bias, data.gemm_buffer, data.present)); + + data.k = data.present; + data.v = data.present + batch_size * num_heads * parameters.max_sequence_length * qk_head_size; } - constexpr bool is_unidirectional = false; - const T* add_before_softmax = nullptr; - if (has_key_padding_mask) { - constexpr int mask_dimension = 2; - constexpr int max_sequence_length = 0; - ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask( - ort_stream, kv_sequence_length, sequence_length, batch_size, - num_heads, nullptr, key_padding_mask, add_before_softmax, - false /*broadcast rpb*/, scratch1, scratch2, is_unidirectional, - 1.0f, mask_dimension, max_sequence_length, false, nullptr, - mask_filter_value)); - } else { - ORT_RETURN_IF_ERROR(ComputeSoftmax( - stream, kv_sequence_length, sequence_length, batch_size, num_heads, - add_before_softmax, false /*broadcast rpb*/, scratch1, scratch2, - is_unidirectional)); + // Q, K and V are ready now + if (data.fused_cross_attention_kernel != nullptr) { + return FusedTrtCrossAttention(stream, parameters, data); } - // compute P*V (as V*P), and store in scratch3: BxNxSxH - if (use_past && static_kv) { - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_N, CUBLAS_OP_N, - head_size, sequence_length, kv_sequence_length, - &one, value_cache, head_size, strideA, - scratch2, kv_sequence_length, temp_matrix_size, - &zero, scratch3, head_size, strideB, BN, device_prop)); - } else { - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_N, CUBLAS_OP_N, - head_size, sequence_length, kv_sequence_length, - &one, v, head_size, strideA, - scratch2, 
kv_sequence_length, temp_matrix_size, - &zero, scratch3, head_size, strideB, BN, device_prop)); + // Run TRT fused attention. + if (nullptr != fused_runner) { + return FusedTrtSelfAttention(stream, parameters, data); } - // scratch3 is BxNxSxH, transpose to output SxBxNxH - return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, scratch3, output); -} + // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation. + const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) + : parameters.scale; -Status LaunchDecoderAttentionKernel( - const cudaDeviceProp& device_prop, - Stream* stream, - cublasHandle_t& cublas, - const size_t element_size, - const int batch_size, - const int sequence_length, - const int kv_sequence_length, - const int num_heads, - const int head_size, - const bool static_kv, - const bool use_past, - const bool has_layer_state, - const bool has_key_padding_mask, - const float mask_filter_value, - const void* gemm_query_buffer, - const void* gemm_kv_buffer, - const bool* key_padding_mask, - const void* key_cache, - const void* value_cache, - void* qkv_buffer, - void* workspace_buffer, - void* output, - void* new_key_cache, - void* new_value_cache) { - if (element_size == 2) { - return DecoderQkvToContext( - device_prop, - stream, - cublas, - element_size, - batch_size, - sequence_length, - kv_sequence_length, - num_heads, - head_size, - static_kv, - use_past, - has_layer_state, - has_key_padding_mask, - mask_filter_value, - reinterpret_cast(gemm_query_buffer), - reinterpret_cast(gemm_kv_buffer), - key_padding_mask, - reinterpret_cast(key_cache), - reinterpret_cast(value_cache), - reinterpret_cast(qkv_buffer), - reinterpret_cast(workspace_buffer), - reinterpret_cast(output), - reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); - } else { - return DecoderQkvToContext( - device_prop, - stream, - cublas, - element_size, - batch_size, - sequence_length, - kv_sequence_length, - num_heads, - head_size, - static_kv, - use_past, - has_layer_state, - has_key_padding_mask, - mask_filter_value, - reinterpret_cast(gemm_query_buffer), - reinterpret_cast(gemm_kv_buffer), - key_padding_mask, - reinterpret_cast(key_cache), - reinterpret_cast(value_cache), - reinterpret_cast(qkv_buffer), - reinterpret_cast(workspace_buffer), - reinterpret_cast(output), - reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); +#if USE_FLASH_ATTENTION + if (data.use_flash_attention) { + return FlashAttention(device_prop, stream, parameters, data, scale); + } +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION + if (data.use_memory_efficient_attention) { + return EfficientAttention(device_prop, stream, parameters, data, scale); } +#endif + + return UnfusedAttention(device_prop, cublas, ort_stream, parameters, data, scale); } // Template Instantiation diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index c361a47c364d3..d0a5fb51a25d6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -81,24 +81,20 @@ struct AttentionData { mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache = nullptr; mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache = nullptr; -}; -// Intermediate data pointers available after PrepareQKV -template -struct QkvData { + // Intermediate data T* q = nullptr; T* k = nullptr; T* v 
= nullptr; - T* after_v = nullptr; // pointer right after v - AttentionQkvFormat format = AttentionQkvFormat::Q_K_V_BSNH; + T* scratch = nullptr; + AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH; }; template Status PrepareQkv(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - QkvData& qkv_data); + int max_threads_per_block); template Status QkvToContext( @@ -108,33 +104,6 @@ Status QkvToContext( contrib::AttentionParameters& parameters, AttentionData& data); -Status LaunchDecoderAttentionKernel( - const cudaDeviceProp& prop, // Device Properties - Stream* stream, // ORT Stream - cublasHandle_t& cublas, // Cublas handle - const size_t element_size, // Element size of input tensor - const int batch_size, // Batch size (B) - const int sequence_length, // Sequence length (S) - const int kv_sequence_length, // Key/Value/Cache sequence length - const int num_heads, // Number of attention heads (N) - const int head_size, // Hidden size per head (H) - const bool static_kv, // Whether cross attention or not - const bool use_past, // Whether use cache or not - const bool has_layer_state, // Whether output cache or not - const bool has_key_padding_mask, // Whether use key_padding_mask or not - const float mask_filter_value, // Mask filter value - const void* gemm_query_buffer, // Query buffer - const void* gemm_kv_buffer, // Key and value buffer - const bool* key_padding_mask, // Key padding mask - const void* key_cache, // Input key cache - const void* value_cache, // Input value cache - void* qkv_buffer, // Temporary buffer - void* workspace_buffer, // Temporary buffer - void* output, // Output tensor - void* new_key_cache, // New_key_cache tensor - void* new_value_cache // New_value_cache tensor -); - // BxNxSxH => BxSxNxH or SxBxNxH (reversed_bs is true) Status LaunchTransCtx(cudaStream_t stream, const int sequence_length, const int batch_size, const int head_size, const int num_heads, @@ -184,14 +153,27 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int int sequence_length, int total_sequence_length, bool pass_past_in_kv, cudaStream_t stream, int max_threads_per_block, - AttentionData& data, - QkvData& qkv); + AttentionData& data); + +template +Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, + const int max_sequence_length, + const int past_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const T* biases, + const T* qkv_buffer, + T* present); template Status LaunchStridedCopy(cudaStream_t stream, const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) T* out, longlong4 out_strides, // coord (b,n,s,h) int max_threads_per_block); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu similarity index 67% rename from onnxruntime/contrib_ops/cuda/bert/attention_concat.cu rename to onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu index 8378ee2691c6b..89be0f1115f41 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu @@ -1,8 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/bert/attention_impl.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" using namespace onnxruntime::cuda; @@ -244,48 +245,48 @@ Status LaunchConcatPastToPresent(cudaStream_t stream, present); } -#ifndef USE_ROCM // exclude from hipify +#ifndef USE_ROCM // exclude the following from hipify since they are not used in ROCM EP + template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, int sequence_length, int total_sequence_length, bool pass_past_in_kv, cudaStream_t stream, int max_threads_per_block, - AttentionData& data, - QkvData& qkv) { + AttentionData& data) { // Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length. // past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH) // past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH) // When there is past state, the head size for Q/K/V shall be same: H == H_v. if (nullptr != data.present) { - assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH || qkv.format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); + ORT_RETURN_IF_ERROR( LaunchConcatPastToPresent( stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, data.past, qkv.k, data.present)); + max_threads_per_block, data.past, data.k, data.present)); // Update pointers to present_k and present_v. - qkv.k = data.present; - qkv.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size; - } - - if (nullptr != data.past_key || nullptr != data.present_key) { + data.k = data.present; + data.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size; + } else if (nullptr != data.past_key || nullptr != data.present_key) { if (nullptr != data.past_key && nullptr == data.present_key) { - qkv.k = const_cast(data.past_key); - qkv.v = const_cast(data.past_value); + data.k = const_cast(data.past_key); + data.v = const_cast(data.past_value); } else if (nullptr == data.past_key && nullptr != data.present_key) { - if (qkv.format == AttentionQkvFormat::Q_K_V_BNSH) { - qkv.k = data.present_key; - qkv.v = data.present_value; + if (data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) { + data.k = data.present_key; + data.v = data.present_value; } else { - assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH); - qkv.k = data.temp_k_workspace; - qkv.v = data.temp_v_workspace; + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + data.k = data.temp_k_workspace; + data.v = data.temp_v_workspace; } } else if (pass_past_in_kv) { // past_key and past_value are used directly as key and value in attention computations - qkv.k = const_cast(data.past_key); - qkv.v = const_cast(data.past_value); + data.k = const_cast(data.past_key); + data.v = const_cast(data.past_value); // This path has a memory copy from past_key and past_value to present_key and present_value // Avoid this path since the memory copy is unnecessary because past_key == present_key and @@ -298,14 +299,14 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int ORT_RETURN_IF_ERROR( LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, 1, data.past_key, qkv.k, data.present_key)); + max_threads_per_block, 1, 
data.past_key, data.k, data.present_key)); ORT_RETURN_IF_ERROR( LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, 1, data.past_value, qkv.v, data.present_value)); + max_threads_per_block, 1, data.past_value, data.v, data.present_value)); // Update pointers to present_k and present_v. - qkv.k = data.present_key; - qkv.v = data.present_value; + data.k = data.present_key; + data.v = data.present_value; } } @@ -317,15 +318,147 @@ template Status ConcatPastToPresent(int batch_size, int num_heads, int qk int sequence_length, int total_sequence_length, bool pass_past_in_kv, cudaStream_t stream, int max_threads_per_block, - AttentionData& data, - QkvData& qkv); + AttentionData& data); template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, int sequence_length, int total_sequence_length, bool pass_past_in_kv, cudaStream_t stream, int max_threads_per_block, - AttentionData& data, - QkvData& qkv); + AttentionData& data); + +// ---------------------------------------------------------------------------------- +// Below kernels are for past and present sharing buffer +// ---------------------------------------------------------------------------------- + +template +__global__ void AddBiasTransAppendKvToPresentSmall( + const T* qkv, const T* biases, T* present, + const int head_size, const int past_sequence_length, const int max_sequence_length) { + // Input: BxSxMxNxH (Format 1) + // Output: (2, B, N, [P..P+S) of MaxS, H), + // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + const int N = blockDim.y; + const int S = gridDim.x; + const int B = gridDim.y; + + constexpr int M = 3; // Matrix count in qkv + const int m = blockIdx.z + 1; // k = 1, v = 2 + + const int NH = N * head_size; + const int NHS = NH * S; + + qkv += (n * head_size + (s * M + m) * NH + b * M * NHS); + if (biases) { + biases += (m * NH + n * head_size); + } + + const int MsH = max_sequence_length * head_size; + const int NMsH = N * MsH; + const int BNMsH = B * NMsH; + present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH); + + for (int h = threadIdx.x; h < head_size; h += blockDim.x) { + T bias = (biases ? 
biases[h] : (T)0.0f); + present[h] = qkv[h] + bias; + } +} + +template +__global__ void AddBiasTransAppendKvToPresent( + const T* qkv, const T* biases, T* present, + const int head_size, const int past_sequence_length, const int max_sequence_length) { + // Input: BxSxMxNxH (Format 1) + // Output: (2, B, N, [P..P+S) of MaxS, H), + // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size + const int n = blockIdx.x; + const int s = blockIdx.y; + const int b = (blockIdx.z >> 1); + const int N = gridDim.x; + const int S = gridDim.y; + const int B = (gridDim.z >> 1); + + constexpr int M = 3; // Matrix count in qkv + const int m = (blockIdx.z & 0x1) + 1; // k = 1, v = 2 + + const int NH = N * head_size; + const int NHS = NH * S; + + qkv += (n * head_size + (s * M + m) * NH + b * M * NHS); + if (biases) { + biases += (m * NH + n * head_size); + } + + const int MsH = max_sequence_length * head_size; + const int NMsH = N * MsH; + const int BNMsH = B * NMsH; + present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH); + + for (int h = threadIdx.x; h < head_size; h += blockDim.x) { + T bias = (biases ? biases[h] : (T)0.0f); + present[h] = qkv[h] + bias; + } +} + +// qkv buffer is merged tensor of shape (B,S,3,N,H), k v is the second/third of the 3. +// bias is of shape (3, NxH) or nullptr +// append to present of (2, B, N, (P..T) of M, H), +template +Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, + const int max_sequence_length, + const int past_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const T* biases, + const T* qkv_buffer, + T* present) { + assert(head_size <= (1 << 30)); + + int64_t nh = (int64_t)head_size * num_heads; + if (nh <= max_threads_per_block) { + const dim3 grid(sequence_length, batch_size, 2); // 2 for k and v + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + + AddBiasTransAppendKvToPresentSmall<<>>( + qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length); + } else { + const dim3 grid(num_heads, sequence_length, batch_size * 2); // 2 for k and v + const dim3 block(std::min(head_size, max_threads_per_block), 1, 1); + AddBiasTransAppendKvToPresent<<>>( + qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length); + } + + return CUDA_CALL(cudaGetLastError()); +} + +template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, + const int max_sequence_length, + const int total_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const float* bias, + const float* qkv_buffer, + float* present); + +template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, + const int max_sequence_length, + const int total_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const half* bias, + const half* qkv_buffer, + half* present); #endif } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index cd4137ab11de9..5c65a30918ece 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -1,9 +1,8 @@ // Copyright (c) Microsoft Corporation. 
All rights reserved. // Licensed under the MIT License. -#include -#include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/bert/attention_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/bert/add_bias_transpose.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" @@ -406,22 +405,25 @@ Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters, // Query (BxSxNxH) => Q (BxNxSxH) constexpr int format = 0; - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, q, - true, -1); + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, q, + true, -1); // Key (BxLxNxH) => K (BxNxLxH) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k, - true, -1); + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k, + true, -1); // Value (BxLxNxH_v) => K (BxNxLxH_v) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, v_head_size, - data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v, - true, -1); + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, v_head_size, + data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v, + true, -1); DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size); @@ -435,8 +437,8 @@ template Status PrepareQkv(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - QkvData& qkv) { + int max_threads_per_block) { + data.scratch = data.workspace; if (data.has_qkv_workspace) { const int size_per_batch_q = parameters.sequence_length * parameters.head_size; const int size_per_batch_k = parameters.kv_sequence_length * parameters.head_size; @@ -445,27 +447,27 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, const size_t elements_q = static_cast(batches) * static_cast(size_per_batch_q); const size_t elements_k = static_cast(batches) * static_cast(size_per_batch_k); const size_t elements_v = static_cast(batches) * static_cast(size_per_batch_v); - qkv.q = data.workspace; - qkv.k = data.workspace + elements_q; - qkv.v = qkv.k + elements_k; - qkv.after_v = qkv.v + elements_v; + data.q = data.workspace; + data.k = data.workspace + elements_q; + data.v = data.k + elements_k; + data.scratch = data.v + elements_v; } if (nullptr != data.gemm_buffer) { // Attention operator ORT_RETURN_IF_ERROR(PrepareQkv_Attention(parameters, data, stream, max_threads_per_block, - qkv.format)); + data.qkv_format)); } else if (data.past_key != nullptr || data.present_key != nullptr) { // mha operator with past/present state ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block, - qkv.q, qkv.k, qkv.v, qkv.format)); + data.q, data.k, data.v, data.qkv_format)); } else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv 
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block, - qkv.q, qkv.k, qkv.v, qkv.format)); + data.q, data.k, data.v, data.qkv_format)); } else if (data.value == nullptr) { // multihead attention operator, no past, packed kv ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block, - qkv.q, qkv.k, qkv.v, qkv.format)); + data.q, data.k, data.v, data.qkv_format)); } else { // multihead attention operator, no past, separated Q/K/V inputs ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block, - qkv.q, qkv.k, qkv.v, qkv.format)); + data.q, data.k, data.v, data.qkv_format)); } CUDA_RETURN_IF_ERROR(cudaGetLastError()); @@ -477,15 +479,13 @@ template Status PrepareQkv( contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - QkvData& qkv); + int max_threads_per_block); template Status PrepareQkv( contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - QkvData& qkv); + int max_threads_per_block); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu index e7d2255fb46b8..01ea02f48d3ab 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu @@ -18,7 +18,6 @@ limitations under the License. */ #include -#include #include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc index f907d300607f9..3f703ae3d05e6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "contrib_ops/cuda/bert/attention_impl.h" #include "contrib_ops/cuda/bert/decoder_attention.h" +#include "contrib_ops/cuda/bert/decoder_attention_impl.h" #include "contrib_ops/cuda/bert/transformer_cuda_common.h" #include "core/framework/op_kernel.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" @@ -85,7 +85,8 @@ Status CheckInputs(const TensorShape& query_shape, } if (kv_weights_dims[0] != hidden_size || kv_weights_dims[1] != 2 * static_cast(hidden_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "kv_weights shall have shape (hidden size, 2 * hidden size)"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "kv_weights shall have shape (hidden size, 2 * hidden size)"); } const auto& bias_dims = bias_shape.GetDims(); @@ -137,7 +138,8 @@ Status CheckInputs(const TensorShape& query_shape, const auto& value_cache_dims = value_cache->Shape().GetDims(); if (value_cache_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value_cache' is expected to have 4 dimension, got ", + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'value_cache' is expected to have 4 dimension, got ", value_cache_dims.size()); } @@ -353,10 +355,12 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { } } - size_t bytes = element_size * batch_size * (static_cast(sequence_length) + static_cast(2) * kv_sequence_length) * hidden_size; + size_t bytes = element_size * batch_size * + (static_cast(sequence_length) + static_cast(2) * kv_sequence_length) * hidden_size; auto qkv_buffer_p = GetScratchBuffer(bytes, context->GetComputeStream()); - bytes = element_size * 2 * batch_size * sequence_length * num_heads_ * (static_cast(2) * head_size + static_cast(kv_sequence_length)); + bytes = element_size * 2 * batch_size * sequence_length * num_heads_ * + (static_cast(2) * head_size + static_cast(kv_sequence_length)); auto workspace_p = GetScratchBuffer(bytes, context->GetComputeStream()); Tensor* output(context->Output(0, query_shape)); diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu new file mode 100644 index 0000000000000..1dc22a9c8ea98 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu @@ -0,0 +1,263 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "contrib_ops/cuda/bert/decoder_attention_impl.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/bert/attention_softmax.h" + +using namespace onnxruntime::contrib::attention_softmax_cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status DecoderQkvToContext( + const cudaDeviceProp& device_prop, + Stream* ort_stream, + cublasHandle_t& cublas, + const size_t element_size, + const int batch_size, + const int sequence_length, + const int kv_sequence_length, + const int num_heads, + const int head_size, + const bool static_kv, + const bool use_past, + const bool has_layer_state, + const bool has_key_padding_mask, + const float mask_filter_value, + const T* gemm_query_buffer, + const T* gemm_kv_buffer, + const bool* key_padding_mask, + const T* key_cache, + const T* value_cache, + T* qkv_buffer, + T* workspace_buffer, + T* output, + T* new_key_cache, + T* new_value_cache) { + const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int BN = batch_size * num_heads; + const int BHN = BN * head_size; + const int BNS = BN * sequence_length; + const int k_buffer_offset = sequence_length * BHN; + const int v_buffer_offset = (sequence_length + kv_sequence_length) * BHN; + + T* temp_qkv_buffer = workspace_buffer; + auto stream = static_cast(ort_stream->GetHandle()); + + const T* q = qkv_buffer; + // transpose q and copy them to qkv_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_query_buffer, qkv_buffer)); + + const T* k = qkv_buffer + k_buffer_offset; + const T* v = qkv_buffer + v_buffer_offset; + if (!has_layer_state || !use_past) { + if (!static_kv) { + // transpose kv and copy them to qkv_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); + } else { + // transpose kv and copy them to qkv_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); + } + } else { + if (!static_kv) { + // transpose kv and copy them to temp_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer)); + // concat cache-k with k and copy to qkv_buffer + if (nullptr != key_cache) { + ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, + sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, 1, + key_cache, + temp_qkv_buffer, + qkv_buffer + k_buffer_offset)); + } + // concat cache-v with v and copy to qkv_buffer + if (nullptr != value_cache) { + ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, + sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, 1, + value_cache, + temp_qkv_buffer + k_buffer_offset, + qkv_buffer + v_buffer_offset)); + } + } + } + + if (has_layer_state) { + if (use_past && static_kv) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, key_cache, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, value_cache, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + } else { + 
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, k, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, v, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + } + } + + // scratch1: BxNxSxL buffer + // scratch2: BxNxSxL buffer + // scratch3: BxNxSxH buffer + T* scratch1 = temp_qkv_buffer + 3 * BHN * sequence_length; + T* scratch2 = scratch1 + BNS * kv_sequence_length; + T* scratch3 = scratch2 + BNS * kv_sequence_length; + + // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxL + // Q: BxNxSxH, K (present_k): BxNxLxH, Q*K': BxNxSxL + const float rsqrt_head_size = 1.f / sqrt(static_cast(head_size)); + const int temp_matrix_size = sequence_length * kv_sequence_length; + float one = 1.0f; + float zero = 0.f; + + float alpha = rsqrt_head_size; + const int strideA = kv_sequence_length * head_size; + const int strideB = sequence_length * head_size; + if (use_past && static_kv) { + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, CUBLAS_OP_T, CUBLAS_OP_N, + kv_sequence_length, sequence_length, head_size, + &alpha, key_cache, head_size, strideA, + q, head_size, strideB, + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + } else { + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, CUBLAS_OP_T, CUBLAS_OP_N, + kv_sequence_length, sequence_length, head_size, + &alpha, k, head_size, strideA, + q, head_size, strideB, + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + } + + constexpr bool is_unidirectional = false; + const T* add_before_softmax = nullptr; + if (has_key_padding_mask) { + constexpr int mask_dimension = 2; + constexpr int max_sequence_length = 0; + ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask( + ort_stream, kv_sequence_length, sequence_length, batch_size, + num_heads, nullptr, key_padding_mask, add_before_softmax, + false /*broadcast rpb*/, scratch1, scratch2, is_unidirectional, + 1.0f, mask_dimension, max_sequence_length, false, nullptr, + mask_filter_value)); + } else { + ORT_RETURN_IF_ERROR(ComputeSoftmax( + stream, kv_sequence_length, sequence_length, batch_size, num_heads, + add_before_softmax, false /*broadcast rpb*/, scratch1, scratch2, + is_unidirectional)); + } + + // compute P*V (as V*P), and store in scratch3: BxNxSxH + if (use_past && static_kv) { + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, CUBLAS_OP_N, CUBLAS_OP_N, + head_size, sequence_length, kv_sequence_length, + &one, value_cache, head_size, strideA, + scratch2, kv_sequence_length, temp_matrix_size, + &zero, scratch3, head_size, strideB, BN, device_prop)); + } else { + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, CUBLAS_OP_N, CUBLAS_OP_N, + head_size, sequence_length, kv_sequence_length, + &one, v, head_size, strideA, + scratch2, kv_sequence_length, temp_matrix_size, + &zero, scratch3, head_size, strideB, BN, device_prop)); + } + + // scratch3 is BxNxSxH, transpose to output SxBxNxH + return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, scratch3, output); +} + +Status LaunchDecoderAttentionKernel( + const cudaDeviceProp& device_prop, + Stream* stream, + cublasHandle_t& cublas, + const size_t element_size, + const int batch_size, + const int sequence_length, + const int kv_sequence_length, + const int num_heads, + const int head_size, + const bool static_kv, + const bool use_past, + const bool has_layer_state, 
+ const bool has_key_padding_mask, + const float mask_filter_value, + const void* gemm_query_buffer, + const void* gemm_kv_buffer, + const bool* key_padding_mask, + const void* key_cache, + const void* value_cache, + void* qkv_buffer, + void* workspace_buffer, + void* output, + void* new_key_cache, + void* new_value_cache) { + if (element_size == 2) { + return DecoderQkvToContext( + device_prop, + stream, + cublas, + element_size, + batch_size, + sequence_length, + kv_sequence_length, + num_heads, + head_size, + static_kv, + use_past, + has_layer_state, + has_key_padding_mask, + mask_filter_value, + reinterpret_cast(gemm_query_buffer), + reinterpret_cast(gemm_kv_buffer), + key_padding_mask, + reinterpret_cast(key_cache), + reinterpret_cast(value_cache), + reinterpret_cast(qkv_buffer), + reinterpret_cast(workspace_buffer), + reinterpret_cast(output), + reinterpret_cast(new_key_cache), + reinterpret_cast(new_value_cache)); + } else { + return DecoderQkvToContext( + device_prop, + stream, + cublas, + element_size, + batch_size, + sequence_length, + kv_sequence_length, + num_heads, + head_size, + static_kv, + use_past, + has_layer_state, + has_key_padding_mask, + mask_filter_value, + reinterpret_cast(gemm_query_buffer), + reinterpret_cast(gemm_kv_buffer), + key_padding_mask, + reinterpret_cast(key_cache), + reinterpret_cast(value_cache), + reinterpret_cast(qkv_buffer), + reinterpret_cast(workspace_buffer), + reinterpret_cast(output), + reinterpret_cast(new_key_cache), + reinterpret_cast(new_value_cache)); + } +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h new file mode 100644 index 0000000000000..9db9ccb45e330 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "contrib_ops/cuda/bert/attention_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +Status LaunchDecoderAttentionKernel( + const cudaDeviceProp& prop, // Device Properties + Stream* stream, // ORT Stream + cublasHandle_t& cublas, // Cublas handle + const size_t element_size, // Element size of input tensor + const int batch_size, // Batch size (B) + const int sequence_length, // Sequence length (S) + const int kv_sequence_length, // Key/Value/Cache sequence length + const int num_heads, // Number of attention heads (N) + const int head_size, // Hidden size per head (H) + const bool static_kv, // Whether cross attention or not + const bool use_past, // Whether use cache or not + const bool has_layer_state, // Whether output cache or not + const bool has_key_padding_mask, // Whether use key_padding_mask or not + const float mask_filter_value, // Mask filter value + const void* gemm_query_buffer, // Query buffer + const void* gemm_kv_buffer, // Key and value buffer + const bool* key_padding_mask, // Key padding mask + const void* key_cache, // Input key cache + const void* value_cache, // Input value cache + void* qkv_buffer, // Temporary buffer + void* workspace_buffer, // Temporary buffer + void* output, // Output tensor + void* new_key_cache, // New_key_cache tensor + void* new_value_cache // New_value_cache tensor +); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu index 9a150c9e6cd77..b0ed3ff82226a 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu @@ -30,6 +30,7 @@ limitations under the License. 
#include "contrib_ops/cpu/bert/attention_base.h" #include "contrib_ops/rocm/bert/attention_impl.h" #include "contrib_ops/rocm/bert/attention_softmax.h" +#include "contrib_ops/rocm/bert/decoder_attention_impl.h" using namespace onnxruntime::rocm; diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 19b2bc34efaec..3164e8c211099 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -28,34 +28,6 @@ size_t GetAttentionWorkspaceSize( int sequence_length, int past_sequence_length); -Status LaunchDecoderAttentionKernel( - const hipDeviceProp_t& prop, // Device Properties - RocmTuningContext* tuning_ctx, // context for tuning - Stream* stream, // ORT Stream - rocblas_handle& rocblas, // Rocblas handle - const size_t element_size, // Element size of input tensor - const int batch_size, // Batch size (B) - const int sequence_length, // Sequence length (S) - const int kv_sequence_length, // Key/Value/Cache sequence length - const int num_heads, // Number of attention heads (N) - const int head_size, // Hidden layer size per head (H) - const bool static_kv, // Whether cross attention or not - const bool use_past, // Whether use cache or not - const bool has_layer_state, // Whether output cache or not - const bool has_key_padding_mask, // Whether use key_padding_mask or not - const float mask_filter_value, // Mask filter value - const void* gemm_query_buffer, // Query buffer - const void* gemm_kv_buffer, // Key and value buffer - const bool* key_padding_mask, // Key padding mask - const void* key_cache, // Input key cache - const void* value_cache, // Input value cache - void* qkv_buffer, // Temporary buffer - void* workspace_buffer, // Temporary buffer - void* output, // Output tensor - void* new_key_cache, // New_key_cache tensor - void* new_value_cache // New_value_cache tensor -); - Status LaunchTransCtx(hipStream_t stream, const int sequence_length, const int batch_size, const int head_size, const int num_heads, const int max_threads_per_block, const bool reversed_bs, const float* input, float* output); diff --git a/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h new file mode 100644 index 0000000000000..d71c6d8440a44 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+#pragma once
+
+#include
+#include
+#include "contrib_ops/cpu/bert/attention_common.h"
+#include "core/providers/rocm/shared_inc/rocm_utils.h"
+#include "core/providers/rocm/tunable/rocm_tunable.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace rocm {
+
+Status LaunchDecoderAttentionKernel(
+    const hipDeviceProp_t& prop,       // Device Properties
+    RocmTuningContext* tuning_ctx,     // context for tuning
+    Stream* stream,                    // ORT Stream
+    rocblas_handle& rocblas,           // Rocblas handle
+    const size_t element_size,         // Element size of input tensor
+    const int batch_size,              // Batch size (B)
+    const int sequence_length,         // Sequence length (S)
+    const int kv_sequence_length,      // Key/Value/Cache sequence length
+    const int num_heads,               // Number of attention heads (N)
+    const int head_size,               // Hidden layer size per head (H)
+    const bool static_kv,              // Whether cross attention or not
+    const bool use_past,               // Whether use cache or not
+    const bool has_layer_state,        // Whether output cache or not
+    const bool has_key_padding_mask,   // Whether use key_padding_mask or not
+    const float mask_filter_value,     // Mask filter value
+    const void* gemm_query_buffer,     // Query buffer
+    const void* gemm_kv_buffer,        // Key and value buffer
+    const bool* key_padding_mask,      // Key padding mask
+    const void* key_cache,             // Input key cache
+    const void* value_cache,           // Input value cache
+    void* qkv_buffer,                  // Temporary buffer
+    void* workspace_buffer,            // Temporary buffer
+    void* output,                      // Output tensor
+    void* new_key_cache,               // New_key_cache tensor
+    void* new_value_cache              // New_value_cache tensor
+);
+
+} // namespace rocm
+} // namespace contrib
+} // namespace onnxruntime

From 152e61da37514e0b43bf1d9929ac6b43d54859be Mon Sep 17 00:00:00 2001
From: Bowen Bao
Date: Tue, 19 Sep 2023 10:42:27 -0700
Subject: [PATCH 050/121] Avoid `get_logger` overriding root logger level
 (#17569)

### Description
Instead, set the level to DEBUG only on the logger that is returned.

### Motivation and Context
Otherwise, this function call overrides the root logger's level setting, which affects the logging facilities of other Python packages.
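Below is a minimal standalone sketch of the before/after behavior (the demo logger names are hypothetical, not part of this patch):

```python
import logging

def get_logger(name):
    # Configure the handler/format once. Crucially, no `level=` argument:
    # passing level=logging.DEBUG to basicConfig sets it on the ROOT logger
    # and silently enables DEBUG output for every package in the process.
    logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s")
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)  # scoped to this named logger only
    return logger

log = get_logger("ort.util.demo")                      # hypothetical name
log.debug("emitted: DEBUG is enabled for this logger")  # printed
logging.getLogger("otherpkg").debug("not emitted: root level is still WARNING")
```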
---
 tools/python/util/logger.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tools/python/util/logger.py b/tools/python/util/logger.py
index 9deb4475721ee..15e04528ac7ac 100644
--- a/tools/python/util/logger.py
+++ b/tools/python/util/logger.py
@@ -5,6 +5,7 @@ def get_logger(name):
-    logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.DEBUG)
-
-    return logging.getLogger(name)
+    logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s")
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.DEBUG)
+    return logger

From 460f17fbb84e29cd693c6009199070e0851aae0d Mon Sep 17 00:00:00 2001
From: Hariharan Seshadri
Date: Tue, 19 Sep 2023 12:20:18 -0700
Subject: [PATCH 051/121] [JS/WebGPU] Support If on WebGPU (#17478)

---
 js/web/docs/webgpu-operators.md               |  1 +
 js/web/test/suite-test-list.jsonc             |  5 ++
 onnxruntime/core/framework/session_state.cc   |  5 +-
 .../providers/js/js_execution_provider.cc     | 12 +++-
 onnxruntime/core/providers/js/operators/if.cc | 65 +++++++++++++++++++
 onnxruntime/core/providers/js/operators/if.h  | 24 +++++++
 6 files changed, 108 insertions(+), 4 deletions(-)
 create mode 100644 onnxruntime/core/providers/js/operators/if.cc
 create mode 100644 onnxruntime/core/providers/js/operators/if.h

diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md
index 71d98f5d73671..a87a894e3b3c5 100644
--- a/js/web/docs/webgpu-operators.md
+++ b/js/web/docs/webgpu-operators.md
@@ -46,6 +46,7 @@ Do not modify directly.*
 | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
 | Greater | ai.onnx(7-8,9-12,13+) | |
 | GreaterOrEqual | ai.onnx(12-15,16+) | |
+| If | ai.onnx(1-10,11-12,13-18,19+) | |
 | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
 | LayerNormalization | ai.onnx(17+) | |
 | LeakyRelu | ai.onnx(6-15,16+) | |
diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc
index 94592884ccad6..6e65645ef4756 100644
--- a/js/web/test/suite-test-list.jsonc
+++ b/js/web/test/suite-test-list.jsonc
@@ -602,6 +602,11 @@
     // // "test_hardsigmoid",
     // // "test_hardswish_expanded",
     // // "test_hardswish",
+    "test_if",
+    // TODO: Uncomment 'test_if_seq' and 'test_if_opt' once the test infra
+    // supports Sequence and Optional types
+    // "test_if_seq",
+    // "test_if_opt",
     "test_instancenorm_epsilon",
     "test_instancenorm_example",
     // "test_isinf_negative",
diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc
index b7d26d87f2705..f0e5fbbd38721 100644
--- a/onnxruntime/core/framework/session_state.cc
+++ b/onnxruntime/core/framework/session_state.cc
@@ -1030,7 +1030,10 @@ Status SessionState::CreateSubgraphSessionState() {
   for (auto& node : graph_.Nodes()) {
     for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
       const auto& ep = node.GetExecutionProviderType();
-      if (!ep.empty() && ep != kCpuExecutionProvider && ep != kCudaExecutionProvider && ep != kRocmExecutionProvider && ep != kDmlExecutionProvider) {
+      if (!ep.empty() &&
+          ep != kCpuExecutionProvider && ep != kCudaExecutionProvider &&
+          ep != kRocmExecutionProvider && ep != kDmlExecutionProvider &&
+          ep != kJsExecutionProvider) {
         // SessionState is only used when ORT is executing the subgraph. If a non-ORT EP has taken the control flow
         // node containing the subgraph it will create whatever state it needs internally.
continue; diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 9dccd7c47fbb6..0674fe02d093d 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -318,7 +318,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Til class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization); - class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Pad); @@ -327,6 +326,11 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 18, If); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, If); + std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); @@ -580,15 +584,17 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/js/operators/if.cc b/onnxruntime/core/providers/js/operators/if.cc new file mode 100644 index 0000000000000..ef072bb1635dd --- /dev/null +++ b/onnxruntime/core/providers/js/operators/if.cc @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "if.h" + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 1, 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + If); +// output shape rules requiring the output shapes of the 'THEN' and 'ELSE' +// branches to be the same were relaxed in opset-11 +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 11, 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + If); + +// opset-13 supports sequence type for If's subgraph outputs +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 13, 18, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + // Support sequence/optional tensors when all JSEP infra + // (including tests runner) supports it + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + If); + +// opset-19 supports float8 +ONNX_OPERATOR_KERNEL_EX(If, + kOnnxDomain, + 19, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + // Support sequence/optional tensors when all JSEP infra + // (including tests runner) supports it + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + If); + +Status If::Compute(OpKernelContext* ctx) const { + // call the base CPU version. + return onnxruntime::If::Compute(ctx); +} + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/if.h b/onnxruntime/core/providers/js/operators/if.h new file mode 100644 index 0000000000000..d060444ccc1d2 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/if.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+#pragma once
+#include
+
+#include "core/providers/js/js_kernel.h"
+#include "core/common/common.h"
+#include "core/providers/cpu/controlflow/if.h"
+
+namespace onnxruntime {
+class SessionState;
+
+namespace js {
+
+// Use the CPU implementation for the logic
+class If final : public onnxruntime::If {
+ public:
+  If(const OpKernelInfo& info) : onnxruntime::If(info) {}
+
+  Status Compute(OpKernelContext* ctx) const override;
+};
+} // namespace js
+} // namespace onnxruntime

From d522cc7cc47e6651446a1104d8f94e3b37b34ea6 Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Tue, 19 Sep 2023 14:42:08 -0700
Subject: [PATCH 052/121] Update npm-packaging-pipeline.yml to always use
 artifacts from main branch (#17604)

### Description
Update npm-packaging-pipeline.yml to always use artifacts from the main branch.

---
 .../ci_build/github/azure-pipelines/npm-packaging-pipeline.yml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml
index ec85002503c0e..2e7ac9508a41e 100644
--- a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml
@@ -73,7 +73,8 @@ stages:
       project: '530acbc4-21bc-487d-8cd8-348ff451d2ff'
       definition: '940'
       specificBuildWithTriggering: true
-      buildVersionToDownload: 'latest'
+      buildVersionToDownload: 'latestFromBranch'
+      branchName: 'refs/heads/main'
       artifactName: 'NPM_packages'
       targetPath: '$(Pipeline.Workspace)'
       displayName: 'Download onnxruntime-node Pipeline Artifact'

From f297d4dfb9f181942a224c96186d07db7155e61e Mon Sep 17 00:00:00 2001
From: Numfor Tiapo
Date: Tue, 19 Sep 2023 17:12:14 -0700
Subject: [PATCH 053/121] Remove onnxruntime extensions from list of
 gitmodules (#17615)

The extensions submodule was removed in [this PR](https://github.com/microsoft/onnxruntime/pull/17097) but was never deleted from the list of git modules. This breaks code that ingests ORT and relies on .gitmodules for an accurate list of submodules. This change removes the extensions entry from the list of git modules to resolve the issue.

---
 .gitmodules | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/.gitmodules b/.gitmodules
index 036a248070855..7bb49e98bfec1 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -8,6 +8,3 @@
   path = cmake/external/emsdk
   url = https://github.com/emscripten-core/emsdk.git
   branch = 3.1.44
-[submodule "cmake/external/onnxruntime-extensions"]
-  path = cmake/external/onnxruntime-extensions
-  url = https://github.com/microsoft/onnxruntime-extensions.git

From e6301eee6aa3f22f31bdcf431d0ed0ca828d20be Mon Sep 17 00:00:00 2001
From: Vincent Wang
Date: Wed, 20 Sep 2023 11:02:58 +0800
Subject: [PATCH 054/121] Bump Up Version to 1.17.0 (#17587)

Bump the version to 1.17.0, as the 1.16.0 release branch has been branched out.
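The diffs below bump ORT_API_VERSION and rename the API table to `ort_api_1_to_17`. As a reminder of what this gates, here is a minimal consumer-side sketch (not part of the patch) of how a client binds to the versioned C API:

```c
#include <stdio.h>
#include "onnxruntime_c_api.h"

int main(void) {
  /* GetApi returns the API table only if the requested version is within
   * [1, ORT_API_VERSION] of the runtime that was loaded; otherwise NULL.
   * A binary compiled against these headers now asks for version 17 and
   * receives the table defined as ort_api_1_to_17 in the diff below. */
  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
  if (api == NULL) {
    fprintf(stderr, "loaded onnxruntime is older than the build headers\n");
    return 1;
  }
  printf("onnxruntime version: %s\n", OrtGetApiBase()->GetVersionString());
  return 0;
}
```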
--- VERSION_NUMBER | 2 +- .../Training/NativeTrainingMethods.shared.cs | 4 ++-- docs/python/README.rst | 5 +++++ include/onnxruntime/core/session/onnxruntime_c_api.h | 2 +- js/common/lib/version.ts | 2 +- js/common/package-lock.json | 4 ++-- js/common/package.json | 2 +- js/node/lib/version.ts | 2 +- js/node/package-lock.json | 6 +++--- js/node/package.json | 2 +- js/react_native/lib/version.ts | 2 +- js/react_native/package.json | 2 +- js/react_native/yarn.lock | 2 +- js/web/lib/version.ts | 2 +- js/web/package-lock.json | 6 +++--- js/web/package.json | 2 +- onnxruntime/__init__.py | 2 +- onnxruntime/core/session/onnxruntime_c_api.cc | 8 ++++---- 18 files changed, 31 insertions(+), 26 deletions(-) diff --git a/VERSION_NUMBER b/VERSION_NUMBER index 15b989e398fc7..092afa15df4df 100644 --- a/VERSION_NUMBER +++ b/VERSION_NUMBER @@ -1 +1 @@ -1.16.0 +1.17.0 diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index ac790242409e3..1868ff509bfc3 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -62,10 +62,10 @@ static NativeTrainingMethods() DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(NativeMethods.OrtGetApiBase().GetApi, typeof(DOrtGetApi)); // TODO: Make this save the pointer, and not copy the whole structure across - api_ = (OrtApi)OrtGetApi(16 /*ORT_API_VERSION*/); + api_ = (OrtApi)OrtGetApi(17 /*ORT_API_VERSION*/); OrtGetTrainingApi = (DOrtGetTrainingApi)Marshal.GetDelegateForFunctionPointer(api_.GetTrainingApi, typeof(DOrtGetTrainingApi)); - trainingApiPtr = OrtGetTrainingApi(16 /*ORT_API_VERSION*/); + trainingApiPtr = OrtGetTrainingApi(17 /*ORT_API_VERSION*/); if (trainingApiPtr != IntPtr.Zero) { trainingApi_ = (OrtTrainingApi)Marshal.PtrToStructure(trainingApiPtr, typeof(OrtTrainingApi)); diff --git a/docs/python/README.rst b/docs/python/README.rst index 7d978b0941235..32bb3729e01d0 100644 --- a/docs/python/README.rst +++ b/docs/python/README.rst @@ -8,6 +8,11 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime `_ or the `Github project `_. """ -__version__ = "1.16.0" +__version__ = "1.17.0" __author__ = "Microsoft" # we need to do device version validation (for example to check Cuda version for an onnxruntime-training package). diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 15fe5acfe0fd2..4c0adcdd374aa 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2397,7 +2397,7 @@ Second example, if we wanted to add and remove some members, we'd do this: In GetApi we now make it return ort_api_3 for version 3. */ -static constexpr OrtApi ort_api_1_to_16 = { +static constexpr OrtApi ort_api_1_to_17 = { // NOTE: The ordering of these fields MUST not change after that version has shipped since existing binaries depend on this ordering. 
 // Shipped as version 1 - DO NOT MODIFY (see above text for more information)
@@ -2745,16 +2745,16 @@ static_assert(offsetof(OrtApi, GetBuildInfoString) / sizeof(void*) == 254, "Size
 static_assert(offsetof(OrtApi, GetCUDAProviderOptionsByName) / sizeof(void*) == 264, "Size of version 16 API cannot change");

 // So that nobody forgets to finish an API version, this check will serve as a reminder:
-static_assert(std::string_view(ORT_VERSION) == "1.16.0",
+static_assert(std::string_view(ORT_VERSION) == "1.17.0",
               "ORT_Version change detected, please follow below steps to ensure OrtApi is updated properly");
 // 1. Update the hardcoded version string in above static_assert to silence it
-// 2. If there were any APIs added to ort_api_1_to_16 above:
+// 2. If there were any APIs added to ort_api_1_to_17 above:
 //    a. Add the 'End of version #' markers (pattern above should be obvious)
 //    b. Add a static_assert in the directly above list of version sizes to ensure nobody adds any more functions to the just shipped API version

 ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) {
   if (version >= 1 && version <= ORT_API_VERSION)
-    return &ort_api_1_to_16;
+    return &ort_api_1_to_17;

   fprintf(stderr,
           "The requested API version [%u] is not available, only API versions [1, %u] are supported in this build."

From c65e892089e2e6f383e80339334450ac06be5ba0 Mon Sep 17 00:00:00 2001
From: Hariharan Seshadri
Date: Wed, 20 Sep 2023 10:35:15 -0700
Subject: [PATCH 055/121] [CUDA] Fix performance bug in
 DecoderMaskedMultiheadAttention for BeamSearch (#17613)

---
 ...decoder_masked_multihead_attention_impl.cu | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
index 5827bdfee1ab5..c8877a5e3f872 100644
--- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
@@ -174,7 +174,6 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
   q = add_vec(q, q_bias);
 }
-
 T* params_k_cache = reinterpret_cast(params.k_cache);

 const float inv_sqrt_dh = params.scale;

@@ -350,24 +349,22 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
   // The keys loaded from the key cache.
   K_vec_k k_vec[K_VECS_PER_THREAD];
+  if (ti < tlength) {
+    if (has_beams) {
+      const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size;

-  if (has_beams) {
 #pragma unroll
-    for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
-      int jj = ii * params.max_sequence_length + ti;
+      for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
+        int jj = ii * params.max_sequence_length + ti;

-      if (ti < tlength) {
-        const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size;
         k_vec[ii] = vec_conversion(
             (*reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
       }
-    }
-  } else {
+    } else {
 #pragma unroll
-    for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
-      int jj = ii * params.max_sequence_length + ti;
+      for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
+        int jj = ii * params.max_sequence_length + ti;

-      if (ti < tlength) {
         k_vec[ii] = vec_conversion(
             (*reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B])));
       }

From c55da45e20435b8aa9edb78179b6027502b778b0 Mon Sep 17 00:00:00 2001
From: Adrian Lizarraga
Date: Wed, 20 Sep 2023 14:31:01 -0700
Subject: [PATCH 056/121] [QNN EP] Add more op unit tests (fix Clip, TopK,
 Tile) (#17457)

### Description
Adds more operator unit tests (every op type should now have at least one unit test):

- [x] Reshape
- [x] Flatten
- [x] Squeeze
- [x] Unsqueeze
- [x] Gemm
- [x] Clip
  - Enable QDQ Clip on HTP backend (when not optimized away by the L1 ClipQuantFusion optimizer)
  - Add support for 16-bit QDQ Clip to the ClipQuantFusion optimizer
- [x] Split
- [x] TopK
  - Enable QDQ TopK on HTP backend
- [x] Tile
  - Enable QDQ Tile on HTP backend

### Motivation and Context
Increase QNN operator support and test coverage.

---
 .../qdq_transformer/clip_quantizelinear.cc    |  25 +-
 .../selectors_actions/qdq_selectors.cc        |  36 ++
 .../selectors_actions/qdq_selectors.h         |   8 +
 .../selectors_actions/shared/utils.cc         |  18 +-
 .../qnn/builder/opbuilder/clip_op_builder.cc  | 127 +++---
 .../providers/qnn/builder/opbuilder/topk.cc   |  15 +-
 .../test/optimizer/qdq_transformer_test.cc    |   7 +
 .../test/providers/qnn/average_pool_test.cc   |   6 +-
 .../test/providers/qnn/clip_op_test.cc        | 188 +++++++++
 .../test/providers/qnn/flatten_op_test.cc     | 202 +++++++++
 .../test/providers/qnn/gather_op_htp_test.cc  |  64 +--
 .../test/providers/qnn/gemm_op_test.cc        | 341 +++++++++++++++
 .../providers/qnn/instance_norm_htp_test.cc   |  21 +-
 .../test/providers/qnn/layer_norm_test.cc     |   4 +-
 .../providers/qnn/leakyrelu_op_htp_test.cc    |  46 +--
 .../test/providers/qnn/max_min_op_test.cc     |   9 +-
 .../test/providers/qnn/pool_op_test.cpp       |  19 +-
 .../test/providers/qnn/qnn_test_utils.cc      |  14 +-
 .../test/providers/qnn/qnn_test_utils.h       |  82 ++--
 .../test/providers/qnn/reshape_op_test.cc     | 225 ++++++++++
 .../test/providers/qnn/simple_op_htp_test.cc  |  16 +-
 .../test/providers/qnn/slice_htp_test.cc      |  64 +--
 .../test/providers/qnn/split_op_test.cc       | 387 ++++++++++++++++++
 .../qnn/squeeze_unsqueeze_op_test.cc          | 324 +++++++++++++++
 .../test/providers/qnn/tile_op_test.cc        | 132 ++++++
 .../test/providers/qnn/topk_op_test.cc        | 209 ++++++++++
 26 files changed, 2273 insertions(+), 316 deletions(-)
 create mode 100644 onnxruntime/test/providers/qnn/clip_op_test.cc
 create mode 100644 onnxruntime/test/providers/qnn/flatten_op_test.cc
 create mode 100644 onnxruntime/test/providers/qnn/gemm_op_test.cc
 create mode 100644 onnxruntime/test/providers/qnn/reshape_op_test.cc
 create mode 100644 onnxruntime/test/providers/qnn/split_op_test.cc
 create mode 100644
onnxruntime/test/providers/qnn/squeeze_unsqueeze_op_test.cc create mode 100644 onnxruntime/test/providers/qnn/tile_op_test.cc create mode 100644 onnxruntime/test/providers/qnn/topk_op_test.cc diff --git a/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc b/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc index a0942c31b0161..50653b368857d 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc @@ -1,8 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/optimizer/initializer.h" #include "core/optimizer/qdq_transformer/clip_quantizelinear.h" + +#include + +#include "core/optimizer/initializer.h" #include "core/optimizer/qdq_transformer/qdq_util.h" #include "core/optimizer/utils.h" #include "core/graph/graph_utils.h" @@ -50,14 +53,26 @@ static bool GetQConstantLowerUpper(const Graph& graph, const Node& node, float& switch (zp_initializer.data_type()) { case ONNX_NAMESPACE::TensorProto_DataType_INT8: { const int8_t zero_point = zp_initializer.data()[0]; - lower = scale * (-128 - zero_point); - upper = scale * (127 - zero_point); + lower = scale * (std::numeric_limits::lowest() - zero_point); + upper = scale * (std::numeric_limits::max() - zero_point); break; } case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { const uint8_t zero_point = zp_initializer.data()[0]; - lower = scale * (0 - zero_point); - upper = scale * (255 - zero_point); + lower = scale * (std::numeric_limits::lowest() - zero_point); + upper = scale * (std::numeric_limits::max() - zero_point); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT16: { + const int16_t zero_point = zp_initializer.data()[0]; + lower = scale * (std::numeric_limits::lowest() - zero_point); + upper = scale * (std::numeric_limits::max() - zero_point); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { + const uint16_t zero_point = zp_initializer.data()[0]; + lower = scale * (std::numeric_limits::lowest() - zero_point); + upper = scale * (std::numeric_limits::max() - zero_point); break; } default: diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 16c7bd5fce960..5015e48fdb7b8 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -496,6 +496,42 @@ bool LogicalComparisonNodeGroupSelector::Check(const GraphViewer& graph_viewer, return dt_input_1 == dt_input_2; } +bool TopKNodeGroupSelector::Check(const GraphViewer& graph_viewer, + const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const { + constexpr int num_dq_inputs = 1; + constexpr int num_q_outputs = 1; + if (num_dq_inputs != gsl::narrow_cast(dq_nodes.size())) { + return false; + } + + if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes); + !dq_validation_status.IsOK()) { + return false; + } + + if (num_q_outputs != gsl::narrow_cast(q_nodes.size())) { + return false; + } + + const Node& dq_node = *dq_nodes.front(); + const Node& q_node = *q_nodes.front(); + + int32_t dt_input = dq_node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_output = q_node.OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + + if (dt_input != dt_output) { + return 
false; + } + + auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { + return graph_viewer.GetConstantInitializer(initializer_name, true); + }; + + return IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath()); +} + } // namespace QDQ } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index d8fefdd8dc3d9..be7f7e0288eda 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -220,6 +220,14 @@ class LogicalComparisonNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; }; +// TopK has 1 DQ input node and 1 Q output node. +// Zero point and scale are constant scalars and must match +class TopKNodeGroupSelector : public NodeGroupSelector { + bool Check(const GraphViewer& graph_viewer, const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const override; +}; + /* * NodeSelector instances for use in the QDQ::SelectorActionTransformer. */ diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index f1bdd7a99c329..3f1b2f0458bc0 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -36,7 +36,8 @@ static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() { {"Resize", {}}, {"Split", {}}, {"Squeeze", {}}, - {"Unsqueeze", {}}}; + {"Unsqueeze", {}}, + {"Tile", {}}}; } static const OpVersionsAndSelector::OpVersionsMap GetDropDQOpVersionsMap() { @@ -78,7 +79,8 @@ static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() { {"Abs", {}}, {"Neg", {}}, {"DepthToSpace", {}}, - {"SpaceToDepth", {}}}; + {"SpaceToDepth", {}}, + {"Clip", {}}}; } static const OpVersionsAndSelector::OpVersionsMap GetBinaryOpVersionsMap() { return {{"Add", {}}, @@ -127,6 +129,10 @@ static const OpVersionsAndSelector::OpVersionsMap GetPadOpVersionsMap() { return {{"Pad", {}}}; } +static const OpVersionsAndSelector::OpVersionsMap GetTopKOpVersionsMap() { + return {{"TopK", {}}}; +} + /* Selector rules registration related */ void RegisterMiscSelectors(Selectors& qdq_selectors) { /* register selectors for miscellaneous ops */ @@ -227,6 +233,13 @@ void RegisterPadSelectors(Selectors& qdq_selectors) { std::move(selector)); } +void RegisterTopKSelector(Selectors& qdq_selectors) { + /* register selector for TopK op */ + std::unique_ptr selector = std::make_unique(); + qdq_selectors.RegisterSelector(GetTopKOpVersionsMap(), + std::move(selector)); +} + void SelectorManager::CreateSelectors() { RegisterMiscSelectors(qdq_selectors_); RegisterDropDQSelectors(qdq_selectors_); @@ -242,6 +255,7 @@ void SelectorManager::CreateSelectors() { RegisterLogicalComparisonSelectors(qdq_selectors_); RegisterWhereSelectors(qdq_selectors_); RegisterPadSelectors(qdq_selectors_); + RegisterTopKSelector(qdq_selectors_); } void SelectorManager::InitializeSelectorsMap() { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc index 92a7feea7fc54..df4c718949269 100644 --- 
a/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include + #include "core/providers/common.h" #include "core/providers/shared/utils/utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" @@ -9,8 +12,6 @@ #include "base_op_builder.h" -#include - namespace onnxruntime { namespace qnn { class ClipOpBuilder : public BaseOpBuilder { @@ -33,8 +34,6 @@ class ClipOpBuilder : public BaseOpBuilder { private: Status ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const; - mutable float min_value_ = std::numeric_limits::lowest(); - mutable float max_value_ = std::numeric_limits::max(); }; Status ClipOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const { @@ -61,61 +60,8 @@ Status ClipOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, if (do_op_validation) { ORT_RETURN_IF_ERROR(ExplictOpCheck(qnn_model_wrapper, node_unit)); } - Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32; - - auto inputs = node_unit.Inputs(); - for (size_t input_i = 0; input_i < inputs.size(); ++input_i) { - Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT; - bool is_quantized_tensor = inputs[input_i].quant_param.has_value(); - utils::InitializeQuantizeParam(quantize_param, is_quantized_tensor); - - auto& input_name = inputs[input_i].node_arg.Name(); - if (input_name.empty()) { - // Ignore unspecified/unused optional input - continue; - } - if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_name)) { - LOGS(logger, VERBOSE) << "Tensor already added or the input is not named, skip it: " << input_name; - input_names.push_back(input_name); - continue; - } - - const auto* type_proto = inputs[input_i].node_arg.TypeAsProto(); - ORT_RETURN_IF_ERROR(utils::GetQnnDataType(is_quantized_tensor, type_proto, qnn_data_type)); - - std::vector input_shape; - ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[input_i].node_arg, input_shape), "Cannot get shape"); - - ORT_RETURN_IF_NOT(qnn_model_wrapper.ProcessQuantizationParameter(inputs[input_i].quant_param, - quantize_param.scaleOffsetEncoding.scale, - quantize_param.scaleOffsetEncoding.offset), - "Cannot get quantization parameter"); - - float* ini_data = nullptr; - std::vector unpacked_tensor; - bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name); - if (is_initializer_input) { - const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name); - ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor)); - ini_data = reinterpret_cast(unpacked_tensor.data()); - if (input_i == 1) { - min_value_ = *ini_data; - continue; - } else if (input_i == 2) { - max_value_ = *ini_data; - continue; - } - } - ORT_ENFORCE(input_i == 0, "QNN ReluMinMax operator expects only one input. Min and max are expected to be parameters, ie. 
initializer inputs in ONNX model");
-
-    Qnn_TensorType_t tensor_type = GetInputTensorType(qnn_model_wrapper, input_name);
-    QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, qnn_data_type, quantize_param,
-                                         std::move(input_shape), std::move(unpacked_tensor));
-    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");
-    input_names.push_back(input_name);
-  }
-
-  return Status::OK();
+  return ProcessInput(qnn_model_wrapper, node_unit.Inputs()[0], logger, input_names);
 }

 Status ClipOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
@@ -123,20 +69,59 @@ Status ClipOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
                                                   std::vector<std::string>&& input_names,
                                                   const logging::Logger& logger,
                                                   bool do_op_validation) const {
+  const auto& inputs = node_unit.Inputs();
+  const size_t num_inputs = inputs.size();
+
+  const Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32;
   std::vector<std::string> param_tensor_names;
-  Qnn_Scalar_t min_qnn_scalar = QNN_SCALAR_INIT;
-  min_qnn_scalar.dataType = QNN_DATATYPE_FLOAT_32;
-  min_qnn_scalar.floatValue = min_value_;
-  QnnParamWrapper min_value_param(node_unit.Index(), node_unit.Name(), QNN_OP_RELU_MIN_MAX_PARAM_MIN_VALUE, min_qnn_scalar);
-  param_tensor_names.push_back(min_value_param.GetParamTensorName());
-  qnn_model_wrapper.AddParamWrapper(std::move(min_value_param));
-
-  Qnn_Scalar_t max_qnn_scalar = QNN_SCALAR_INIT;
-  max_qnn_scalar.dataType = QNN_DATATYPE_FLOAT_32;
-  max_qnn_scalar.floatValue = max_value_;
-  QnnParamWrapper max_value_param(node_unit.Index(), node_unit.Name(), QNN_OP_RELU_MIN_MAX_PARAM_MAX_VALUE, max_qnn_scalar);
-  param_tensor_names.push_back(max_value_param.GetParamTensorName());
-  qnn_model_wrapper.AddParamWrapper(std::move(max_value_param));
+
+  auto get_f32_from_bytes = [](const std::vector<uint8_t>& bytes, float default_val) -> float {
+    return bytes.empty() ? default_val : *reinterpret_cast<const float*>(bytes.data());
+  };
+
+  // Set the 'min' parameter.
+  {
+    std::vector<uint8_t> min_val_bytes;
+
+    if (num_inputs > 1 && !inputs[1].node_arg.Name().empty()) {
+      OnnxInputInfo min_input_info = {};
+      ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[1], min_input_info));
+      ORT_RETURN_IF_NOT(min_input_info.qnn_data_type == qnn_data_type,
+                        "QNN EP: The 'min' input of the Clip operator must be of type float32.");
+      assert(min_input_info.is_initializer);  // Checked by ExplicitOpCheck().
+      ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*min_input_info.initializer_tensor, min_val_bytes));
+    }
+
+    Qnn_Scalar_t min_qnn_scalar = QNN_SCALAR_INIT;
+    min_qnn_scalar.dataType = qnn_data_type;
+    min_qnn_scalar.floatValue = get_f32_from_bytes(min_val_bytes, std::numeric_limits<float>::lowest());
+    QnnParamWrapper min_value_param(node_unit.Index(), node_unit.Name(), QNN_OP_RELU_MIN_MAX_PARAM_MIN_VALUE,
+                                    min_qnn_scalar);
+    param_tensor_names.push_back(min_value_param.GetParamTensorName());
+    qnn_model_wrapper.AddParamWrapper(std::move(min_value_param));
+  }
+
+  // Set the 'max' parameter.
+  {
+    std::vector<uint8_t> max_val_bytes;
+
+    if (num_inputs > 2 && !inputs[2].node_arg.Name().empty()) {
+      OnnxInputInfo max_input_info = {};
+      ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[2], max_input_info));
+      ORT_RETURN_IF_NOT(max_input_info.qnn_data_type == qnn_data_type,
+                        "QNN EP: The 'max' input of the Clip operator must be of type float32.");
+      assert(max_input_info.is_initializer);  // Checked by ExplicitOpCheck().
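      // Note on the byte decoding (illustrative, assuming little-endian IEEE-754 storage,
      // which is how ONNX serializes float initializer data): the bytes unpacked below are
      // the scalar's raw float32 encoding, so get_f32_from_bytes() above simply reinterprets
      // the first four bytes, e.g. {0x00, 0x00, 0xA0, 0x40} decodes to 5.0f, while an empty
      // byte vector falls back to the type's default limit.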
+ ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*max_input_info.initializer_tensor, max_val_bytes)); + } + + Qnn_Scalar_t max_qnn_scalar = QNN_SCALAR_INIT; + max_qnn_scalar.dataType = qnn_data_type; + max_qnn_scalar.floatValue = get_f32_from_bytes(max_val_bytes, std::numeric_limits::max()); + QnnParamWrapper max_value_param(node_unit.Index(), node_unit.Name(), QNN_OP_RELU_MIN_MAX_PARAM_MAX_VALUE, + max_qnn_scalar); + param_tensor_names.push_back(max_value_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(max_value_param)); + } ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc index 6ca36736f2f7f..047972294f78c 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc @@ -63,9 +63,20 @@ Status TopKOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const N auto rank = input_shape.size(); auto axis = node_helper.Get("axis", -1); - if (-1 == axis && axis != static_cast(rank - 1)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN TopK axis is always the last dimension"); + ORT_RETURN_IF_NOT(axis == -1 || axis == static_cast(rank - 1), + "QNN TopK's axis is always the last dimension"); + + // ONNX TopK outputs int64 indices, but the equivalent QNN op outputs uint32 indices. + // The QNN HTP backend does not generally support the int64 type, but QNN EP can just use the uint32 type + // for TopK ops within the graph. However, if the TopK op **generates** a graph output, + // then we cannot support it on the HTP backend. + bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); + if (is_npu_backend) { + const std::string& output_name = node_unit.Outputs()[0].node_arg.Name(); + ORT_RETURN_IF(qnn_model_wrapper.IsGraphOutput(output_name), + "QNN EP does not support TopK ops that generate a graph output."); } + return Status::OK(); } diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index a438a61cb9b36..d3616a14d8a5d 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -2497,10 +2497,15 @@ TEST(QDQTransformerTests, Clip) { epsilon); }; + constexpr int16_t int16_min = std::numeric_limits::min(); + constexpr uint16_t uint16_min = std::numeric_limits::min(); + std::vector opsets{12, 18, 19}; for (auto opset : opsets) { test_case(.0235294122248888f, static_cast(-128), 0, opset); // [0, 6] test_case(.0235294122248888f, static_cast(-128), 0, opset, true); // [0, 6] contrib qdq + test_case(9.15541313801785e-5f, int16_min, 0, opset, true); // [0, 6] contrib 16-bit qdq + test_case(0.0009f, int16_min, 1, opset, true); // [0, 58.98] contrib 16-bit qdq test_case(.02f, static_cast(-128), 0, opset); // [0, 5.1] test_case(.02f, static_cast(-128), 0, opset, true); // [0, 5.1] contrib qdq test_case(.03f, static_cast(-128), 1, opset); // [0, 7.65] @@ -2513,6 +2518,8 @@ TEST(QDQTransformerTests, Clip) { test_case(.04f, static_cast(-97), 1, opset, true); // [-1.24, 8.96] contrib qdq test_case(.02352941176f, static_cast(0), 0, opset); // [0, 6] test_case(.02352941176f, static_cast(0), 0, opset, true); // [0, 6] contrib qdq + test_case(9.15541313801785e-5f, uint16_min, 0, opset, true); // [0, 6] contrib 16-bit qdq + test_case(0.0009f, uint16_min, 1, opset, true); // [0, 58.98] 
contrib 16-bit qdq test_case(.02f, static_cast(0), 0, opset); // [0, 5.1] test_case(.02f, static_cast(0), 0, opset, true); // [0, 5.1] contrib qdq test_case(.03f, static_cast(0), 1, opset); // [0, 7.65] diff --git a/onnxruntime/test/providers/qnn/average_pool_test.cc b/onnxruntime/test/providers/qnn/average_pool_test.cc index 79ec07796c0e8..0ee52f7fec21a 100644 --- a/onnxruntime/test/providers/qnn/average_pool_test.cc +++ b/onnxruntime/test/providers/qnn/average_pool_test.cc @@ -32,7 +32,7 @@ static void RunAveragePoolOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnCpu.so"; #endif - RunQnnModelTest(BuildOpTestCase(op_type, input_defs, attrs), + RunQnnModelTest(BuildOpTestCase(op_type, input_defs, {}, attrs), provider_options, opset, expected_ep_assignment); @@ -53,8 +53,8 @@ static void RunQDQAveragePoolOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, attrs), - BuildQDQOpTestCase(op_type, input_defs, attrs), + TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, {}, attrs), + BuildQDQOpTestCase(op_type, input_defs, {}, attrs), provider_options, opset, expected_ep_assignment); diff --git a/onnxruntime/test/providers/qnn/clip_op_test.cc b/onnxruntime/test/providers/qnn/clip_op_test.cc new file mode 100644 index 0000000000000..15ba3b5de2fa1 --- /dev/null +++ b/onnxruntime/test/providers/qnn/clip_op_test.cc @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Clip operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunClipTestOnCPU(const TestInputDef& input_def, + const std::vector>& min_max_defs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Clip", {input_def}, min_max_defs, {}), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that Clip with a dynamic min or max input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Clip_Dynamic_MinMax_Unsupported) { + // Dynamic min input is not supported. + RunClipTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {TestInputDef({}, false /* is_initializer */, {-5.0f})}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. + // Dynamic max input is not supported. + RunClipTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {TestInputDef({}, true, {-5.0f}), + TestInputDef({}, false, {5.0f})}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test Clip with default min/max. +TEST_F(QnnCPUBackendTests, Clip_4D_f32_DefaultMinMax) { + RunClipTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {}, // Don't specify min/max inputs. + ExpectedEPNodeAssignment::All); +} + +// Test Clip with 5D input. 
+TEST_F(QnnCPUBackendTests, Clip_5D_f32) { + RunClipTestOnCPU(TestInputDef({1, 1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {TestInputDef({}, true, {-5.0f}), + TestInputDef({}, true, {5.0f})}, + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Runs a QDQ Clip model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQClipTestOnHTP(const TestInputDef& input_def, + const std::vector>& min_max_defs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildOpTestCase("Clip", {input_def}, {min_max_defs}, {}); + auto qdq_model_builder = BuildQDQOpTestCase("Clip", {input_def}, {min_max_defs}, {}, + kOnnxDomain, use_contrib_qdq); + + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test 8-bit QDQ Clip with default min/max. +// NOTE: The Clip operator is *optimized* away during L1 optimizations, so QNN EP does not get a graph with a Clip op. +// Instead, QNN EP will get a graph with a Q -> DQ. +// - Original sequence: Q1 -> DQ1 -> Clip -> Q2 -> DQ2 +// - ClipQuantFusion: Fuses Clip -> QuantizeLinear resulting in Q1 -> DQ1 -> Q2' -> DQ2 +// - DoubleQDQPairsRemover: Simplifies remaining Q1 -> DQ1 -> Q2' -> DQ2 sequence to Q1 -> DQ2. +TEST_F(QnnHTPBackendTests, Clip_U8_DefaultMinMax_Rank4) { + RunQDQClipTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {}, // Don't specify min/max inputs. + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Clip with default min/max. +// NOTE: The Clip operator is *optimized* away during L1 optimizations, so QNN EP does not get a graph with a Clip op. +// Instead, QNN EP will get a graph with a Q -> DQ. +// - Original sequence: Q1 -> DQ1 -> Clip -> Q2 -> DQ2 +// - ClipQuantFusion: Fuses Clip -> QuantizeLinear resulting in Q1 -> DQ1 -> Q2' -> DQ2 +// - DoubleQDQPairsRemover: Simplifies remaining Q1 -> DQ1 -> Q2' -> DQ2 sequence to Q1 -> DQ2. +TEST_F(QnnHTPBackendTests, Clip_U16_DefaultMinMax_Rank4) { + RunQDQClipTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {}, // Don't specify min/max inputs. + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test 8-bit QDQ Clip with non-default min and max inputs. QNN EP will get a graph with a Clip operator. +TEST_F(QnnHTPBackendTests, Clip_U8_Rank4) { + RunQDQClipTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {TestInputDef({}, true, {-5.0f}), + TestInputDef({}, true, {5.0f})}, + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Clip with non-default min and max inputs. QNN EP will get a graph with a Clip operator. +TEST_F(QnnHTPBackendTests, Clip_U16_Rank4) { + RunQDQClipTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {TestInputDef({}, true, {-5.0f}), + TestInputDef({}, true, {5.0f})}, + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Clip of rank 5. 
+TEST_F(QnnHTPBackendTests, Clip_U8_Rank5) { + // We can't use the usual model-building functions because they add standalone Quantize and Dequantize nodes + // at the input and output. These Q/DQ ops get lowered to QNN's Quantize and Dequantize operators, which DO NOT + // support rank 5 tensors. Therefore, we have to create a test model that only instantiates the DQ -> Clip -> Q + // QDQ node group, which gets lowered to a single QNN Clip node. + GetTestModelFn model_fn = [](ModelTestBuilder& builder) { + // input (u8) -> DQ -> + NodeArg* quant_input = builder.MakeInput({1, 1, 2, 2, 2}, {0, 1, 6, 10, 20, 100, 128, 255}); + NodeArg* input_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(quant_input, 1.0f, 0, input_dq); // scale = 1.0, zp = 0 + + // Min/Max initializers + NodeArg* min_input = builder.MakeScalarInitializer(5.0f); + NodeArg* max_input = builder.MakeScalarInitializer(100.0f); + + // Clip -> + NodeArg* clip_output = builder.MakeIntermediate(); + builder.AddNode("Clip", {input_dq, min_input, max_input}, {clip_output}); + + // Q -> output (u8) + NodeArg* output = builder.MakeOutput(); + builder.AddQuantizeLinearNode(clip_output, 1.0f, 0, output); // scale = 1.0, zp = 0 + }; + + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(model_fn, + provider_options, + 13, // opset + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/flatten_op_test.cc b/onnxruntime/test/providers/qnn/flatten_op_test.cc new file mode 100644 index 0000000000000..637d3257ddea7 --- /dev/null +++ b/onnxruntime/test/providers/qnn/flatten_op_test.cc @@ -0,0 +1,202 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Flatten operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunFlattenTestOnCPU(const TestInputDef& input_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Flatten", {input_def}, {}, attrs), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that Flatten input (rank4) with axis == 0. +TEST_F(QnnCPUBackendTests, Flatten_Rank4_Axis0) { + RunFlattenTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(0))}, + ExpectedEPNodeAssignment::All); +} + +// Test that Flatten input (rank4) with axis == -1. +TEST_F(QnnCPUBackendTests, Flatten_Rank4_AxisNeg1) { + RunFlattenTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(-1))}, + ExpectedEPNodeAssignment::All); +} + +// Test that Flatten input (rank5) with axis == 2. 
+TEST_F(QnnCPUBackendTests, Flatten_Rank5_Axis2) { + RunFlattenTestOnCPU(TestInputDef({1, 2, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(2))}, + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Runs a model with a non-QDQ Flatten operator on the QNN HTP backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunFlattenTestOnHTP(const TestInputDef& input_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Flatten", {input_def}, {}, attrs), + provider_options, + opset, + expected_ep_assignment); +} + +// Runs a QDQ Flatten model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQFlattenTestOnHTP(const TestInputDef& input_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildOpTestCase("Flatten", {input_def}, {}, attrs); + auto qdq_model_builder = BuildQDQOpTestCase("Flatten", {input_def}, {}, attrs, kOnnxDomain, use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test 8-bit QDQ Flatten input (rank4) with axis == 0. +TEST_F(QnnHTPBackendTests, Flatten_Rank4_Axis0) { + RunQDQFlattenTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(0))}, + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Flatten input (rank4) with axis == 0. +TEST_F(QnnHTPBackendTests, Flatten_Rank4_Axis0_U16) { + RunQDQFlattenTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(0))}, + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test 8-bit QDQ Flatten input (rank4) with axis == -1. +TEST_F(QnnHTPBackendTests, Flatten_Rank4_AxisNeg1) { + RunQDQFlattenTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(-1))}, + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Flatten input (rank4) with axis == -1. +TEST_F(QnnHTPBackendTests, Flatten_Rank4_AxisNeg1_U16) { + RunQDQFlattenTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(-1))}, + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test 8-bit QDQ Flatten with an input of rank5. +TEST_F(QnnHTPBackendTests, Flatten_QDQ8bit_Rank5) { + // We can't use the usual model-building functions because they add standalone Quantize and Dequantize nodes + // at the input and output. These Q/DQ ops get lowered to QNN's Quantize and Dequantize operators, which DO NOT + // support rank 5 tensors. 
Therefore, we have to create a test model that only instantiates the DQ -> Flatten -> Q + // QDQ node group, which gets lowered to a single QNN Reshape node. + GetTestModelFn model_fn = [](ModelTestBuilder& builder) { + // input (u8) -> DQ -> + NodeArg* quant_input = builder.MakeInput({1, 2, 3, 4, 5}, 0, 255); + NodeArg* input_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(quant_input, 1.0f, 0, input_dq); // scale = 1.0, zp = 0 + + // Flatten -> + NodeArg* flatten_output = builder.MakeIntermediate(); + Node& flatten_node = builder.AddNode("Flatten", {input_dq}, {flatten_output}); + flatten_node.AddAttribute("axis", static_cast(2)); + + // Q -> output (u8) + NodeArg* output = builder.MakeOutput(); + builder.AddQuantizeLinearNode(flatten_output, 1.0f, 0, output); // scale = 1.0, zp = 0 + }; + + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(model_fn, + provider_options, + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test that int32 non-QDQ Flatten runs on HTP backend. +TEST_F(QnnHTPBackendTests, Flatten_Int32_Rank4_Axis2) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + RunFlattenTestOnHTP(TestInputDef({1, 3, 2, 2}, false, input_data), + {utils::MakeAttribute("axis", static_cast(2))}, + ExpectedEPNodeAssignment::All); +} + +// Test that rank 5 int32 Flatten runs on HTP backend. +TEST_F(QnnHTPBackendTests, Flatten_Int32_Rank5_Axis2) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}; + RunFlattenTestOnHTP(TestInputDef({1, 3, 2, 2, 2}, false, input_data), + {utils::MakeAttribute("axis", static_cast(2))}, + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/gather_op_htp_test.cc b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc index 5b05b39f34a27..37e0db906d054 100644 --- a/onnxruntime/test/providers/qnn/gather_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc @@ -5,6 +5,7 @@ #include #include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" #include "test/providers/qnn/qnn_test_utils.h" @@ -14,47 +15,14 @@ namespace onnxruntime { namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -// Function that builds a float model with a Gather op. -template -static GetTestModelFn BuildGatherOpTestCase(const TestInputDef& input_def, - const TestInputDef& indices_def, - int64_t axis = 0) { - return [input_def, indices_def, axis](ModelTestBuilder& builder) { - NodeArg* input = MakeTestInput(builder, input_def); - NodeArg* indices = MakeTestInput(builder, indices_def); - NodeArg* output = builder.MakeOutput(); - - Node& gather_node = builder.AddNode("Gather", {input, indices}, {output}); - gather_node.AddAttribute("axis", axis); - }; -} - -// Function that builds a QDQ model with a Gather op. 
-template -static GetTestQDQModelFn BuildQDQGatherOpTestCase(const TestInputDef& input_def, - const TestInputDef& indices_def, - int64_t axis = 0) { - return [input_def, indices_def, axis](ModelTestBuilder& builder, - std::vector>& output_qparams) { - NodeArg* input = MakeTestInput(builder, input_def); - QuantParams input_qparams = GetTestInputQuantParams(input_def); - NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point); - - NodeArg* indices = MakeTestInput(builder, indices_def); - - NodeArg* gather_output = builder.MakeIntermediate(); - Node& gather_node = builder.AddNode("Gather", {input_qdq, indices}, {gather_output}); - gather_node.AddAttribute("axis", axis); - - AddQDQNodePairWithOutputAsGraphOutput(builder, gather_output, output_qparams[0].scale, output_qparams[0].zero_point); - }; -} - // Test the accuracy of a QDQ Gather model on QNN EP. Checks if the QDQ model on QNN EP as accurate as the QDQ model on CPU EP // (compared to float32 model). template -static void RunQDQGatherOpTest(const TestInputDef& input_def, const TestInputDef& indices_def, - int64_t axis, int opset, ExpectedEPNodeAssignment expected_ep_assignment) { +static void RunQDQGatherOpTest(const TestInputDef& input_def, + const TestInputDef& indices_def, + const std::vector& attrs, + int opset, + ExpectedEPNodeAssignment expected_ep_assignment) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -62,12 +30,14 @@ static void RunQDQGatherOpTest(const TestInputDef& input_def, const TestI provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildGatherOpTestCase(input_def, indices_def, axis), - BuildQDQGatherOpTestCase(input_def, indices_def, axis), + auto f32_model_builder = BuildOpTestCase("Gather", {input_def}, {indices_def}, attrs); + auto qdq_model_builder = BuildQDQOpTestCase("Gather", {input_def}, {indices_def}, attrs); + + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, provider_options, opset, - expected_ep_assignment, - 1e-5f); + expected_ep_assignment); } // Test creates a DQ -> Gather -> Q -> DQ graph, and checks that all @@ -77,7 +47,7 @@ static void RunQDQGatherOpTest(const TestInputDef& input_def, const TestI TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt64_Axis0) { RunQDQGatherOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f}), TestInputDef({2, 2}, true, {0, 1, 1, 2}), - 0, + {utils::MakeAttribute("axis", static_cast(0))}, 13, ExpectedEPNodeAssignment::All); } @@ -86,7 +56,7 @@ TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt64_Axis0) { TEST_F(QnnHTPBackendTests, GatherOp_IndicesDynamicInt64_Axis0) { RunQDQGatherOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f}), TestInputDef({2, 2}, false, {0, 1, 1, 2}), - 0, + {utils::MakeAttribute("axis", static_cast(0))}, 13, ExpectedEPNodeAssignment::None); } @@ -98,7 +68,7 @@ TEST_F(QnnHTPBackendTests, GatherOp_IndicesDynamicInt64_Axis0) { TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt32_Axis0) { RunQDQGatherOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f}), TestInputDef({2, 2}, true, {0, 1, 1, 2}), - 0, + {utils::MakeAttribute("axis", static_cast(0))}, 13, ExpectedEPNodeAssignment::All); } @@ -110,7 +80,7 @@ TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt32_Axis0) { TEST_F(QnnHTPBackendTests, GatherOp_IndicesDynamicInt32_Axis0) { RunQDQGatherOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f}), TestInputDef({2, 2}, 
false, {0, 1, 1, 2}), - 0, + {utils::MakeAttribute("axis", static_cast(0))}, 13, ExpectedEPNodeAssignment::All); } @@ -122,7 +92,7 @@ TEST_F(QnnHTPBackendTests, GatherOp_IndicesDynamicInt32_Axis0) { TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt32_Axis1) { RunQDQGatherOpTest(TestInputDef({3, 3}, false, {1.0f, 1.2f, 1.9f, 2.3f, 3.4f, 3.9f, 4.5f, 5.7f, 5.9f}), TestInputDef({1, 2}, true, {0, 2}), - 1, + {utils::MakeAttribute("axis", static_cast(1))}, 13, ExpectedEPNodeAssignment::All); } diff --git a/onnxruntime/test/providers/qnn/gemm_op_test.cc b/onnxruntime/test/providers/qnn/gemm_op_test.cc new file mode 100644 index 0000000000000..15f26717b06fd --- /dev/null +++ b/onnxruntime/test/providers/qnn/gemm_op_test.cc @@ -0,0 +1,341 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Gemm operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunGemmTestOnCPU(const std::vector>& input_defs, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Gemm", input_defs, {}, attrs), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that Gemm with non-default 'alpha' or 'beta' attributes is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Gemm_NonDefaultAlphaBeta_Unsupported) { + // Check that alpha != 1.0f is not supported. + RunGemmTestOnCPU({TestInputDef({1, 2}, false, -10.0f, 10.0f), + TestInputDef({2, 4}, false, -10.0f, 10.0f)}, + {utils::MakeAttribute("alpha", 1.5f)}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. + + // Check that beta != 1.0f is not supported. + RunGemmTestOnCPU({TestInputDef({1, 2}, false, -10.0f, 10.0f), + TestInputDef({2, 4}, false, -10.0f, 10.0f), + TestInputDef({1, 4}, false, -1.0f, 1.0f)}, + {utils::MakeAttribute("beta", 1.2f)}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that Gemm with general 2D bias (M, N) is NOT supported (unless M == 1). +// QNN's FullyConnected operator only supports `outputVector = ( inputAsVector * weightsMatrix ) + biasesVector` +TEST_F(QnnCPUBackendTests, Gemm_2D_Bias_Unsupported) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 12); + + // 2D matrix mul with bias not supported. + RunGemmTestOnCPU({TestInputDef({2, 3}, false, input_a_data), + TestInputDef({3, 4}, false, input_b_data), + TestInputDef({2, 4}, false, -1.0f, 1.0f)}, + {}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. + + // However, 2D matrix mul without a bias is supported. Input A's 0th dimension is interpreted as `batch_size`. + RunGemmTestOnCPU({TestInputDef({2, 3}, false, input_a_data), + TestInputDef({3, 4}, false, input_b_data)}, + {}, + ExpectedEPNodeAssignment::All); // Assigned to QNN EP. +} + +// Test Gemm with dynamic (i.e., not initializer) inputs (A, B, Bias). 
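// (Why the (2, 4) bias above is rejected: QNN lowers Gemm to FullyConnected, which only
// computes output[n] = sum_k input[k] * weights[k][n] + bias[n], i.e. a single output row
// with a 1-D bias, matching `outputVector = (inputAsVector * weightsMatrix) + biasesVector`.
// A hedged reference sketch of that formula in plain C++, with hypothetical names:)
static void FullyConnectedRef(const float* input, const float* weights, const float* bias,
                              float* output, size_t K, size_t N) {
  for (size_t n = 0; n < N; ++n) {
    float acc = (bias != nullptr) ? bias[n] : 0.0f;  // biasesVector: one entry per output
    for (size_t k = 0; k < K; ++k) {
      acc += input[k] * weights[k * N + n];  // inputAsVector * weightsMatrix
    }
    output[n] = acc;
  }
}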
+TEST_F(QnnCPUBackendTests, Gemm_Dynamic_A_B_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunGemmTestOnCPU({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, false, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); +} + +// Test Gemm with static B and Bias inputs. +TEST_F(QnnCPUBackendTests, Gemm_Static_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunGemmTestOnCPU({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); +} + +// Test Gemm with transposed A/B and static B and Bias inputs. +TEST_F(QnnCPUBackendTests, Gemm_TransAB_Static_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunGemmTestOnCPU({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All); +} + +// Test Gemm with transposed A/B and dynamic (i.e., not initializer) B and Bias inputs. +TEST_F(QnnCPUBackendTests, Gemm_TransAB_Dynamic_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunGemmTestOnCPU({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, false, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Returns a function that builds a model with a QDQ Gemm node. 
+template +inline GetTestQDQModelFn BuildQDQGemmTestCase(const std::vector>& input_defs, + const std::vector& attrs, + bool use_contrib_qdq = false) { + return [input_defs, attrs, use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + const size_t num_inputs = input_defs.size(); + assert(num_inputs == 2 || num_inputs == 3); + + std::vector op_inputs; + op_inputs.reserve(num_inputs); + + // Process input 0 + NodeArg* input0 = MakeTestInput(builder, input_defs[0]); + QuantParams input0_qparams = GetTestInputQuantParams(input_defs[0]); + NodeArg* input0_after_qdq = AddQDQNodePair(builder, input0, input0_qparams.scale, + input0_qparams.zero_point, use_contrib_qdq); + op_inputs.push_back(input0_after_qdq); + + // Process input 1 + NodeArg* input1 = MakeTestInput(builder, input_defs[1]); + QuantParams input1_qparams = GetTestInputQuantParams(input_defs[1]); + NodeArg* input1_after_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale, + input1_qparams.zero_point, use_contrib_qdq); + op_inputs.push_back(input1_after_qdq); + + // Process bias + if (num_inputs == 3) { + NodeArg* bias_input = MakeTestQDQBiasInput(builder, input_defs[2], input0_qparams.scale * input1_qparams.scale, + use_contrib_qdq); + op_inputs.push_back(bias_input); + } + + // Op -> op_output + auto* gemm_output = builder.MakeIntermediate(); + Node& gemm_node = builder.AddNode("Gemm", op_inputs, {gemm_output}); + + for (const auto& attr : attrs) { + gemm_node.AddAttributeProto(attr); + } + + // op_output -> Q -> DQ -> output + AddQDQNodePairWithOutputAsGraphOutput(builder, gemm_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); + }; +} + +// Runs a QDQ Gemm model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQGemmTestOnHTP(const std::vector>& input_defs, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13, + float f32_abs_err = 1e-4f, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + auto f32_model_builder = BuildOpTestCase("Gemm", input_defs, {}, attrs); + auto qdq_model_builder = BuildQDQGemmTestCase(input_defs, attrs, use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment, + f32_abs_err); +} + +// Test 8-bit QDQ Gemm with dynamic inputs A and Bias. The B input is an initializer. +TEST_F(QnnHTPBackendTests, Gemm_Dynamic_A_Static_B_Dynamic_Bias_U8) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, true, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Gemm with dynamic inputs A and Bias. The B input is an initializer. +// TODO: Inaccuracy detected for output 'output_0', element 0. +// Output quant params: scale=0.001872879103757441, zero_point=0. 
+// Expected val: 120.73912048339844 +// QNN QDQ val: 0 (err 120.73912048339844) +// CPU QDQ val: 120.73889923095703 (err 0.00022125244140625) +TEST_F(QnnHTPBackendTests, DISABLED_Gemm_Dynamic_A_Static_B_Dynamic_Bias_U16) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, true, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All, + 13, // opset + 1e-4f, // f32_abs_err + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Gemm (16bit act, 8bit weight) with dynamic inputs A and Bias. The B input is an initializer. +// TODO: Allow small inaccuracies based on % of expected value. +// Inaccuracy detected for output 'output_0', element 0. +// Output quant params: scale=0.001872879103757441, zero_point=0. +// Expected val: 120.73912048339844 +// QNN QDQ val: 120.48043823242188 (err 0.2586822509765625) +// CPU QDQ val: 120.48980712890625 (err 0.2493133544921875) +TEST_F(QnnHTPBackendTests, Gemm_Dynamic_A_Static_B_Dynamic_Bias_U16Act_U8Weight) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, true, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All, + 13, // opset + 0.15f, // f32_abs_err + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Gemm with dynamic A and B inputs. The Bias is static. +// TODO: Inaccuracy detected for output 'output', element 0. +// Output quant params: scale=0.48132994771003723, zero_point=0. +// Expected val: 120.73912048339844 +// QNN QDQ val: 77.012794494628906 (err 43.726325988769531) +// CPU QDQ val: 119.85115814208984 (err 0.88796234130859375) +TEST_F(QnnHTPBackendTests, DISABLED_Gemm_Dynamic_A_B_Static_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, false, input_b_data), // Dynamic => inaccuracy + TestInputDef({1, 4}, true, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); +} + +// Test QDQ Gemm with static B and Bias inputs. +TEST_F(QnnHTPBackendTests, Gemm_Static_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); +} + +// Test 8-bit QDQ Gemm with transposed A/B and static B and Bias inputs. 
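// (The "err" values quoted in the TODO comments above are absolute deviations of a
// quantize->dequantize roundtrip from the float32 baseline. A minimal sketch of that
// roundtrip, assuming <algorithm>, <cmath>, and <limits>:)
template <typename QType>
static float QdqRoundTrip(float value, float scale, QType zero_point) {
  const float qmin = static_cast<float>(std::numeric_limits<QType>::lowest());
  const float qmax = static_cast<float>(std::numeric_limits<QType>::max());
  float q = std::nearbyint(value / scale) + static_cast<float>(zero_point);  // quantize
  q = std::min(std::max(q, qmin), qmax);                                     // clamp to QType's range
  return (q - static_cast<float>(zero_point)) * scale;                       // dequantize
}
// E.g. with scale=0.001872879 and zero_point=0 (uint16), 120.73912f survives the roundtrip
// to within ~0.0002, which matches the CPU QDQ error quoted above; the QNN-side errors are
// what the disabled tests flag.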
+TEST_F(QnnHTPBackendTests, Gemm_TransAB_Static_B_And_Bias_U8) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All); +} + +// Test QDQ Gemm (16bit activation, 8bit weight) with transposed A/B and static B and Bias inputs. +// TODO: Allow small inaccuracies based on % of expected value. +// Inaccuracy detected for output 'output_0', element 0. +// Output quant params: scale=0.00047966410056687891, zero_point=0. +// Expected val: 29.434776306152344 +// QNN QDQ val: 29.191877365112305 (err 0.24289894104003906) +// CPU QDQ val: 29.197153091430664 (err 0.23762321472167969) +TEST_F(QnnHTPBackendTests, Gemm_TransAB_Static_B_And_Bias_U16Act_U8Weight) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All, + 13, // opset + 0.15f, // f32_abs_err + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Gemm with transposed A/B and dynamic (i.e., not initializer) B and Bias inputs. +TEST_F(QnnHTPBackendTests, Gemm_TransAB_Dynamic_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, false, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc b/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc index 594973e37ef0b..f662ac14336f8 100644 --- a/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc +++ b/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc @@ -16,25 +16,6 @@ namespace onnxruntime { namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -// Function that builds a float32 model with an InstanceNormalization operator. 
-GetTestModelFn BuildInstanceNormTestCase(const TestInputDef& input_def, - const TestInputDef& scale_def, - const TestInputDef& bias_def, - const std::vector& attrs) { - return [input_def, scale_def, bias_def, attrs](ModelTestBuilder& builder) { - NodeArg* input = MakeTestInput(builder, input_def); - NodeArg* scale = MakeTestInput(builder, scale_def); - NodeArg* bias = MakeTestInput(builder, bias_def); - - NodeArg* output = builder.MakeOutput(); - Node& op_node = builder.AddNode("InstanceNormalization", {input, scale, bias}, {output}); - - for (const auto& attr : attrs) { - op_node.AddAttributeProto(attr); - } - }; -} - // Function that builds a QDQ model with an InstanceNormalization operator. template static GetTestQDQModelFn BuildQDQInstanceNormTestCase(const TestInputDef& input_def, @@ -93,7 +74,7 @@ static void RunInstanceNormQDQTest(const TestInputDef& input_def, #endif // Runs model with DQ-> InstanceNorm -> Q and compares the outputs of the CPU and QNN EPs. - TestQDQModelAccuracy(BuildInstanceNormTestCase(input_def, scale_def, bias_def, attrs), + TestQDQModelAccuracy(BuildOpTestCase("InstanceNormalization", {input_def, scale_def, bias_def}, {}, attrs), BuildQDQInstanceNormTestCase(input_def, scale_def, bias_def, attrs), provider_options, 18, diff --git a/onnxruntime/test/providers/qnn/layer_norm_test.cc b/onnxruntime/test/providers/qnn/layer_norm_test.cc index aa6c6a142e6d1..085454004e5a5 100644 --- a/onnxruntime/test/providers/qnn/layer_norm_test.cc +++ b/onnxruntime/test/providers/qnn/layer_norm_test.cc @@ -29,7 +29,7 @@ static void RunLayerNormCpuTest(const TestInputDef& input_def, provider_options["backend_path"] = "libQnnCpu.so"; #endif - RunQnnModelTest(BuildOpTestCase("LayerNormalization", {input_def, scale_def}, attrs), + RunQnnModelTest(BuildOpTestCase("LayerNormalization", {input_def, scale_def}, {}, attrs), provider_options, 17, expected_ep_assignment); @@ -114,7 +114,7 @@ static void RunLayerNormQDQTest(const TestInputDef& input_def, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildOpTestCase("LayerNormalization", {input_def, scale_def}, attrs), + TestQDQModelAccuracy(BuildOpTestCase("LayerNormalization", {input_def, scale_def}, {}, attrs), BuildQDQLayerNormTestCase(input_def, scale_def, attrs), provider_options, 17, // opset diff --git a/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc b/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc index a8237817c71df..e3077ec569923 100644 --- a/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc @@ -5,6 +5,7 @@ #include #include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" #include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" @@ -15,42 +16,10 @@ namespace onnxruntime { namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -// Creates a function that builds a model with a LeakyRelu operator. -static GetTestModelFn BuildLeakyReluOpTestCase(const TestInputDef& input_def, float alpha) { - return [input_def, alpha](ModelTestBuilder& builder) { - NodeArg* input = MakeTestInput(builder, input_def); - NodeArg* output = builder.MakeOutput(); - Node& leakyrelu_node = builder.AddNode("LeakyRelu", {input}, {output}); - leakyrelu_node.AddAttribute("alpha", alpha); - }; -} - -// Creates a function that builds a QDQ model with a LeakyRelu operator. 
-template -static GetTestQDQModelFn BuildQDQLeakyReluOpTestCase(const TestInputDef& input_def, - float alpha) { - return [input_def, alpha](ModelTestBuilder& builder, - std::vector>& output_qparams) { - // input => Q => DQ => - NodeArg* input = MakeTestInput(builder, input_def); - QuantParams input_qparams = GetTestInputQuantParams(input_def); - NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point); - - // LeakryRelu - auto* leakyrelu_output = builder.MakeIntermediate(); - Node& leakyrelu_node = builder.AddNode("LeakyRelu", {input_qdq}, {leakyrelu_output}); - leakyrelu_node.AddAttribute("alpha", alpha); - - // => Q => DQ -> final output - AddQDQNodePairWithOutputAsGraphOutput(builder, leakyrelu_output, output_qparams[0].scale, - output_qparams[0].zero_point); - }; -} - // Checks the accuracy of a QDQ LeakyRelu model by comparing to ORT CPU EP. template static void RunLeakyReluOpQDQTest(const TestInputDef& input_def, - float alpha, + const std::vector& attrs, int opset, ExpectedEPNodeAssignment expected_ep_assignment) { ProviderOptions provider_options; @@ -60,12 +29,11 @@ static void RunLeakyReluOpQDQTest(const TestInputDef& input_def, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildLeakyReluOpTestCase(input_def, alpha), - BuildQDQLeakyReluOpTestCase(input_def, alpha), + TestQDQModelAccuracy(BuildOpTestCase("LeakyRelu", {input_def}, {}, attrs), + BuildQDQOpTestCase("LeakyRelu", {input_def}, {}, attrs), provider_options, opset, - expected_ep_assignment, - 1e-5f); + expected_ep_assignment); } // Test creates a DQ -> Gather -> Q -> DQ graph, and checks that all @@ -74,7 +42,7 @@ static void RunLeakyReluOpQDQTest(const TestInputDef& input_def, // - Uses uint8 as the quantization type. TEST_F(QnnHTPBackendTests, LeakyReluOpSet15) { RunLeakyReluOpQDQTest(TestInputDef({1, 2, 3}, false, {-40.0f, -20.0f, 0.0f, 10.0f, 30.0f, 40.0f}), - 0.2f, + {utils::MakeAttribute("alpha", 0.2f)}, 15, ExpectedEPNodeAssignment::All); } @@ -85,7 +53,7 @@ TEST_F(QnnHTPBackendTests, LeakyReluOpSet15) { // - Uses uint8 as the quantization type. 
TEST_F(QnnHTPBackendTests, LeakyReluOpSet16) { RunLeakyReluOpQDQTest(TestInputDef({1, 2, 3}, false, {-40.0f, -20.0f, 0.0f, 10.0f, 30.0f, 40.0f}), - 0.2f, + {utils::MakeAttribute("alpha", 0.2f)}, 16, ExpectedEPNodeAssignment::All); } diff --git a/onnxruntime/test/providers/qnn/max_min_op_test.cc b/onnxruntime/test/providers/qnn/max_min_op_test.cc index 09ea71e5f03eb..3deff121f3c72 100644 --- a/onnxruntime/test/providers/qnn/max_min_op_test.cc +++ b/onnxruntime/test/providers/qnn/max_min_op_test.cc @@ -27,7 +27,7 @@ static void RunCPUMinOrMaxOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnCpu.so"; #endif - RunQnnModelTest(BuildOpTestCase(op_type, input_defs, {}, kOnnxDomain), + RunQnnModelTest(BuildOpTestCase(op_type, input_defs, {}, {}, kOnnxDomain), provider_options, opset, expected_ep_assignment); @@ -48,12 +48,11 @@ static void RunQDQMinOrMaxOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, {}, kOnnxDomain), // baseline float32 model - BuildQDQOpTestCase(op_type, input_defs, {}, kOnnxDomain), // QDQ model + TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, {}, {}, kOnnxDomain), // baseline float32 model + BuildQDQOpTestCase(op_type, input_defs, {}, {}, kOnnxDomain), // QDQ model provider_options, opset, - expected_ep_assignment, - 1e-4f); + expected_ep_assignment); } // diff --git a/onnxruntime/test/providers/qnn/pool_op_test.cpp b/onnxruntime/test/providers/qnn/pool_op_test.cpp index fee10a542fb82..7ed9072a95b32 100644 --- a/onnxruntime/test/providers/qnn/pool_op_test.cpp +++ b/onnxruntime/test/providers/qnn/pool_op_test.cpp @@ -17,21 +17,6 @@ namespace onnxruntime { namespace test { -// Returns a function that creates a graph with a single MaxPool operator. -static GetTestModelFn BuildPoolTestCase(const std::string& op_type, - const TestInputDef& input_def, - const std::vector& attrs) { - return [op_type, input_def, attrs](ModelTestBuilder& builder) { - NodeArg* input = MakeTestInput(builder, input_def); - NodeArg* output = builder.MakeOutput(); - Node& pool_node = builder.AddNode(op_type, {input}, {output}); - - for (const auto& attr : attrs) { - pool_node.AddAttributeProto(attr); - } - }; -} - // Returns a function that creates a graph with a QDQ MaxPool operator. 
template GetTestQDQModelFn BuildPoolQDQTestCase(const std::string& op_type, @@ -74,7 +59,7 @@ static void RunPoolOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnCpu.so"; #endif - RunQnnModelTest(BuildPoolTestCase(op_type, input_def, attrs), + RunQnnModelTest(BuildOpTestCase(op_type, {input_def}, {}, attrs), provider_options, opset, expected_ep_assignment); @@ -95,7 +80,7 @@ static void RunQDQPoolOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildPoolTestCase(op_type, input_def, attrs), + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, attrs), BuildPoolQDQTestCase(op_type, input_def, attrs), provider_options, opset, diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 724e9a11cd781..51df93f8853ec 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -73,7 +73,7 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, const ProviderOption void InferenceModel(const std::string& model_data, const char* log_id, std::unique_ptr execution_provider, ExpectedEPNodeAssignment expected_ep_assignment, const NameMLValMap& feeds, - std::vector& output_names, std::vector& output_vals) { + std::vector& output_vals) { SessionOptions so; so.session_logid = log_id; RunOptions run_options; @@ -102,14 +102,12 @@ void InferenceModel(const std::string& model_data, const char* log_id, } const auto& outputs = graph.GetOutputs(); + std::vector output_names; - // fetch all outputs if necessary. - if (output_names.empty()) { - output_names.reserve(outputs.size()); - for (const auto* node_arg : outputs) { - if (node_arg->Exists()) { - output_names.push_back(node_arg->Name()); - } + output_names.reserve(outputs.size()); + for (const auto* node_arg : outputs) { + if (node_arg->Exists()) { + output_names.push_back(node_arg->Name()); } } diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index fd572fa17f2b1..14c62f98f6a3e 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -213,13 +213,12 @@ inline QuantParams GetTestInputQuantParams(const TestInputDef& inp * \param execution_provider The EP on which to run the model. Set to nullptr for CPU EP. * \param expected_ep_assignment Describes "which nodes" should be assigned to the EP. * \param feeds The input feeds. - * \param output_names If empty, the function will write the output names. * \param output_vals Initialized to the inference results. */ void InferenceModel(const std::string& model_data, const char* log_id, std::unique_ptr execution_provider, ExpectedEPNodeAssignment expected_ep_assignment, const NameMLValMap& feeds, - std::vector& output_names, std::vector& output_vals); + std::vector& output_vals); /** * Tests the accuracy of a QDQ model on QNN EP by runnning 3 inferences: @@ -263,9 +262,8 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe // Run f32 model on CPU EP and collect outputs. 
std::vector cpu_f32_outputs; - std::vector output_names; InferenceModel(f32_model_data, "f32_model_logger", nullptr, ExpectedEPNodeAssignment::All, - f32_helper.feeds_, output_names, cpu_f32_outputs); + f32_helper.feeds_, cpu_f32_outputs); ASSERT_FALSE(cpu_f32_outputs.empty()); const size_t num_outputs = cpu_f32_outputs.size(); @@ -304,13 +302,13 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe // Run QDQ model on QNN EP and collect outputs. std::vector qnn_qdq_outputs; InferenceModel(qdq_model_data, "qdq_model_logger", QnnExecutionProviderWithOptions(qnn_options), - expected_ep_assignment, qdq_helper.feeds_, output_names, qnn_qdq_outputs); + expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs); if (expected_ep_assignment != ExpectedEPNodeAssignment::None) { // Run QDQ model on CPU EP and collect outputs. std::vector cpu_qdq_outputs; InferenceModel(qdq_model_data, "qdq_model_logger", nullptr, ExpectedEPNodeAssignment::All, - qdq_helper.feeds_, output_names, cpu_qdq_outputs); + qdq_helper.feeds_, cpu_qdq_outputs); ASSERT_EQ(cpu_qdq_outputs.size(), num_outputs); ASSERT_EQ(qnn_qdq_outputs.size(), num_outputs); @@ -320,7 +318,9 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe // Compare accuracy of QDQ results with float model. // QNN EP must be at least as accurate as CPU EP when running the QDQ model. + const std::string base_output_name = "output_"; for (size_t i = 0; i < num_outputs; i++) { + std::string debug_output_name = base_output_name + std::to_string(i); auto& cpu_qdq_tensor = cpu_qdq_outputs[i].Get(); auto& qnn_qdq_tensor = qnn_qdq_outputs[i].Get(); @@ -353,8 +353,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe } EXPECT_TRUE(is_as_accurate_as_cpu_qdq) - << "Inaccuracy detected for output '" - << output_names[i] + << "Inaccuracy detected for output '" << debug_output_name << "', element " << j << ".\nOutput quant params: scale=" << output_qparams[i].scale << ", zero_point=" << static_cast(output_qparams[i].zero_point) @@ -363,7 +362,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe << "CPU QDQ val: " << cpu_qdq_val << " (err " << cpu_err << ")"; } } else { - VerifyOutput(output_names[i], cpu_f32_outputs[i].Get(), qnn_qdq_tensor, fp32_abs_err); + VerifyOutput(debug_output_name, cpu_f32_outputs[i].Get(), qnn_qdq_tensor, fp32_abs_err); } } } @@ -438,25 +437,33 @@ NodeArg* MakeTestQDQBiasInput(ModelTestBuilder& builder, const TestInputDef +template inline GetTestModelFn BuildOpTestCase(const std::string& op_type, - const std::vector>& input_defs, + const std::vector>& input_defs_1, + const std::vector>& input_defs_2, const std::vector& attrs, const std::string& op_domain = kOnnxDomain) { - return [op_type, input_defs, attrs, op_domain](ModelTestBuilder& builder) { + return [op_type, input_defs_1, input_defs_2, attrs, op_domain](ModelTestBuilder& builder) { std::vector op_inputs; - op_inputs.reserve(input_defs.size()); + op_inputs.reserve(input_defs_1.size() + input_defs_2.size()); + + for (const auto& input_def : input_defs_1) { + NodeArg* input = MakeTestInput(builder, input_def); + op_inputs.push_back(input); + } - for (const auto& input_def : input_defs) { - NodeArg* input = MakeTestInput(builder, input_def); + for (const auto& input_def : input_defs_2) { + NodeArg* input = MakeTestInput(builder, input_def); op_inputs.push_back(input); } @@ -470,7 +477,8 @@ inline GetTestModelFn BuildOpTestCase(const std::string& op_type, } /** - * 
Returns a function that builds a model with a single QDQ operator with N inputs of the same element type. + * Returns a function that builds a model with a single QDQ operator with N float (quantizeable) inputs + * and M inputs of a potentially different type. * * \param op_type The operator to instantiate. * \param input_defs List of input definitions. @@ -478,25 +486,33 @@ inline GetTestModelFn BuildOpTestCase(const std::string& op_type, * \param op_domain The operator's domain. Defaults to the ONNX domain (i.e., ""). * \returns A model building function. */ -template -inline GetTestQDQModelFn BuildQDQOpTestCase(const std::string& op_type, - const std::vector>& input_defs, - const std::vector& attrs, - const std::string& op_domain = kOnnxDomain, - bool use_contrib_qdq = false) { - return [op_type, input_defs, attrs, op_domain, - use_contrib_qdq](ModelTestBuilder& builder, std::vector>& output_qparams) { +template +inline GetTestQDQModelFn BuildQDQOpTestCase(const std::string& op_type, + const std::vector>& quant_input_defs, + const std::vector>& non_quant_input_defs, + const std::vector& attrs, + const std::string& op_domain = kOnnxDomain, + bool use_contrib_qdq = false) { + return [op_type, quant_input_defs, non_quant_input_defs, attrs, op_domain, + use_contrib_qdq](ModelTestBuilder& builder, std::vector>& output_qparams) { std::vector op_inputs; - op_inputs.reserve(input_defs.size()); + op_inputs.reserve(quant_input_defs.size() + non_quant_input_defs.size()); - for (const auto& input_def : input_defs) { + // Create QDQ inputs + for (const auto& input_def : quant_input_defs) { NodeArg* input = MakeTestInput(builder, input_def); - QuantParams input_qparams = GetTestInputQuantParams(input_def); - NodeArg* input_after_qdq = AddQDQNodePair(builder, input, input_qparams.scale, - input_qparams.zero_point, use_contrib_qdq); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_after_qdq = AddQDQNodePair(builder, input, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); op_inputs.push_back(input_after_qdq); } + // Create non-QDQ inputs + for (const auto& input_def : non_quant_input_defs) { + NodeArg* input = MakeTestInput(builder, input_def); + op_inputs.push_back(input); + } + // Op -> op_output auto* op_output = builder.MakeIntermediate(); Node& onnx_node = builder.AddNode(op_type, op_inputs, {op_output}, op_domain); @@ -506,8 +522,8 @@ inline GetTestQDQModelFn BuildQDQOpTestCase(const std::string& op_ty } // op_output -> Q -> DQ -> output - AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, - output_qparams[0].zero_point, use_contrib_qdq); + AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); }; } diff --git a/onnxruntime/test/providers/qnn/reshape_op_test.cc b/onnxruntime/test/providers/qnn/reshape_op_test.cc new file mode 100644 index 0000000000000..eb495e44ec770 --- /dev/null +++ b/onnxruntime/test/providers/qnn/reshape_op_test.cc @@ -0,0 +1,225 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Reshape operator on the QNN CPU backend. 
Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunReshapeTestOnCPU(const TestInputDef& input_def, + const TestInputDef& shape_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Reshape", {input_def}, {shape_def}, attrs), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that Reshape with a dynamic shape input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Reshape_DynamicShape_Unsupported) { + RunReshapeTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, false /* is_initializer */, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + 19); // Opset +} + +// Test that Reshape with an enabled 'allowzero' attribute is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Reshape_AllowZeroAttr_Unsupported) { + RunReshapeTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, true, {1, 48}), + {utils::MakeAttribute("allowzero", static_cast(1))}, + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + 19); // Opset +} + +// Test Reshape of rank 4 -> rank 2. +TEST_F(QnnCPUBackendTests, Reshape_4D_f32) { + RunReshapeTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({2}, true, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Returns a function that creates a graph with a QDQ Reshape operator. +template +GetTestQDQModelFn BuildQDQReshapeTestCase(const TestInputDef& input_def, + const TestInputDef& shape_def, + const std::vector& attrs, + bool use_contrib_qdq = false) { + return [input_def, shape_def, attrs, + use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + // input -> Q -> DQ -> + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + + // shape input + NodeArg* shape_input = MakeTestInput(builder, shape_def); + + // Reshape op + NodeArg* reshape_output = builder.MakeIntermediate(); + Node& reshape_node = builder.AddNode("Reshape", {input_qdq, shape_input}, {reshape_output}); + + for (const auto& attr : attrs) { + reshape_node.AddAttributeProto(attr); + } + + // op_output -> Q -> DQ -> output + // NOTE: Input and output quantization parameters must be equal for Reshape. + output_qparams[0] = input_qparams; // Overwrite! + AddQDQNodePairWithOutputAsGraphOutput(builder, reshape_output, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + }; +} + +// Runs a model with a non-QDQ Reshape operator on the QNN HTP backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. 
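// Note (editorial sketch, not lines from the patch): the QDQ Reshape builder above overwrites
// output_qparams[0] with the input's quantization parameters because Reshape only moves data;
// reusing the input's (scale, zero_point) keeps the Q -> DQ round-trip bit-exact. A minimal
// illustration, assuming 8-bit affine quantization (QuantizeU8/DequantizeU8 are hypothetical
// helpers for this sketch, not ORT or QNN APIs):
#include <algorithm>
#include <cmath>
#include <cstdint>

static uint8_t QuantizeU8(float v, float scale, int32_t zero_point) {
  const int32_t q = static_cast<int32_t>(std::nearbyint(v / scale)) + zero_point;
  return static_cast<uint8_t>(std::clamp(q, 0, 255));  // saturate to the uint8 range
}

static float DequantizeU8(uint8_t q, float scale, int32_t zero_point) {
  return scale * (static_cast<int32_t>(q) - zero_point);
}

// With identical (scale, zero_point) on both sides of the Reshape, QuantizeU8(DequantizeU8(q))
// returns q for every stored element, so no requantization error is introduced.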
+template +static void RunReshapeTestOnHTP(const TestInputDef& input_def, + const TestInputDef& shape_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Reshape", {input_def}, {shape_def}, attrs), + provider_options, + opset, + expected_ep_assignment); +} + +// Runs a QDQ Reshape model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQReshapeTestOnHTP(const TestInputDef& input_def, + const TestInputDef& shape_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildOpTestCase("Reshape", {input_def}, {shape_def}, attrs); + auto qdq_model_builder = BuildQDQReshapeTestCase(input_def, shape_def, attrs, use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test that QDQ Reshape with a dynamic shape input is not supported by QNN EP. +TEST_F(QnnHTPBackendTests, Reshape_DynamicShape_Unsupported) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, false /* is_initializer */, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + 19); // Opset +} + +// Test that QDQ Reshape with an enabled 'allowzero' attribute is not supported by QNN EP. +TEST_F(QnnHTPBackendTests, Reshape_AllowZeroAttr_Unsupported) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, true, {1, 48}), + {utils::MakeAttribute("allowzero", static_cast(1))}, + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + 19); // Opset +} + +// Test 8-bit QDQ Reshape of rank 4 -> rank 2. +TEST_F(QnnHTPBackendTests, Reshape_4D_u8) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({2}, true, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +// Test 16-bit QDQ Reshape of rank 4 -> rank 2. +TEST_F(QnnHTPBackendTests, Reshape_4D_u16) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({2}, true, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19, // Opset + true); // Use com.microsoft Q/DQ ops +} + +// Test that int32 Reshape runs on HTP backend. 
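// Note (editorial sketch, not lines from the patch): the 0MeansCopy and Neg1MeansInfer tests
// below rely on ONNX Reshape's shape-tensor semantics: an entry of 0 copies the matching input
// dimension, and a single -1 is inferred from the remaining element count. A hedged sketch of
// that resolution rule (ResolveReshape is a hypothetical helper, not an ORT API; it assumes a
// well-formed shape with at most one -1 and 0-entries only where an input dimension exists):
#include <cstdint>
#include <vector>

static std::vector<int64_t> ResolveReshape(const std::vector<int64_t>& input_dims,
                                           std::vector<int64_t> shape) {
  int64_t known = 1;
  int64_t infer_index = -1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == 0) shape[i] = input_dims[i];  // 0 => copy the matching input dimension
    if (shape[i] == -1) {
      infer_index = static_cast<int64_t>(i);      // at most one -1 may appear
    } else {
      known *= shape[i];
    }
  }
  if (infer_index >= 0) {
    int64_t total = 1;
    for (int64_t d : input_dims) total *= d;
    shape[static_cast<size_t>(infer_index)] = total / known;  // infer from the element count
  }
  return shape;
}

// Example: ResolveReshape({1, 3, 4, 4}, {1, 0, 16}) and ResolveReshape({1, 3, 4, 4}, {1, 3, -1})
// both yield {1, 3, 16}, the shapes expected by the tests below.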
+TEST_F(QnnHTPBackendTests, Reshape_4D_int32) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + RunReshapeTestOnHTP(TestInputDef({1, 3, 2, 2}, false, input_data), + TestInputDef({3}, true, {1, 1, 12}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +// Test QDQ Reshape with a shape value of 0 (copy dimension from input) +TEST_F(QnnHTPBackendTests, Reshape_4D_0MeansCopy) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({3}, true, {1, 0, 16}), // zero means copy => '(1, 3, 16)' + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +// Test QDQ Reshape with a shape value of -1 (dimension is inferred from the expected number of elements) +TEST_F(QnnHTPBackendTests, Reshape_4D_Neg1MeansInfer) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({3}, true, {1, 3, -1}), // -1 means infer => '(1, 3, 16)' + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 63498982930f5..f77c098f72116 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -32,7 +32,7 @@ static void RunOpTestOnCPU(const std::string& op_type, provider_options["backend_path"] = "libQnnCpu.so"; #endif - RunQnnModelTest(BuildOpTestCase(op_type, input_defs, attrs, op_domain), + RunQnnModelTest(BuildOpTestCase(op_type, input_defs, {}, attrs, op_domain), provider_options, opset_version, expected_ep_assignment); @@ -113,8 +113,8 @@ static void RunQDQOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, attrs, op_domain), - BuildQDQOpTestCase(op_type, input_defs, attrs, op_domain, use_contrib_qdq), + TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, {}, attrs, op_domain), + BuildQDQOpTestCase(op_type, input_defs, {}, attrs, op_domain, use_contrib_qdq), provider_options, opset_version, expected_ep_assignment, @@ -137,7 +137,7 @@ static void RunOpTest(const std::string& op_type, #endif // Runs model with a Q/DQ binary op and compares the outputs of the CPU and QNN EPs. - RunQnnModelTest(BuildOpTestCase(op_type, input_defs, attrs, op_domain), + RunQnnModelTest(BuildOpTestCase(op_type, input_defs, {}, attrs, op_domain), provider_options, opset_version, expected_ep_assignment); @@ -698,8 +698,8 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheTest) { // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. 
// 1st run will generate the Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}), + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def}, {}, {}), provider_options, 14, ExpectedEPNodeAssignment::All); @@ -708,8 +708,8 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheTest) { EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); // 2nd run will load and run from Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}), + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def}, {}, {}), provider_options, 14, ExpectedEPNodeAssignment::All); diff --git a/onnxruntime/test/providers/qnn/slice_htp_test.cc b/onnxruntime/test/providers/qnn/slice_htp_test.cc index f7163f04736a5..edc079dc65276 100644 --- a/onnxruntime/test/providers/qnn/slice_htp_test.cc +++ b/onnxruntime/test/providers/qnn/slice_htp_test.cc @@ -16,51 +16,6 @@ namespace onnxruntime { namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -// Function that builds a model with a Slice operator. -template -GetTestModelFn BuildSliceTestCase(const TestInputDef& data_def, - const TestInputDef& starts_def, - const TestInputDef& ends_def, - const TestInputDef& axes_def, - const TestInputDef& steps_def) { - return [data_def, starts_def, ends_def, axes_def, steps_def](ModelTestBuilder& builder) { - NodeArg* data = MakeTestInput(builder, data_def); - NodeArg* starts = MakeTestInput(builder, starts_def); - NodeArg* ends = MakeTestInput(builder, ends_def); - NodeArg* axes = MakeTestInput(builder, axes_def); - NodeArg* steps = MakeTestInput(builder, steps_def); - - NodeArg* output = builder.MakeOutput(); - builder.AddNode("Slice", {data, starts, ends, axes, steps}, {output}); - }; -} - -// Function that builds a QDQ model with a Slice operator. -template -static GetTestQDQModelFn BuildQDQSliceTestCase(const TestInputDef& data_def, - const TestInputDef& starts_def, - const TestInputDef& ends_def, - const TestInputDef& axes_def, - const TestInputDef& steps_def) { - return [data_def, starts_def, ends_def, axes_def, steps_def](ModelTestBuilder& builder, - std::vector>& output_qparams) { - NodeArg* data = MakeTestInput(builder, data_def); - QuantParams data_qparams = GetTestInputQuantParams(data_def); - NodeArg* data_qdq = AddQDQNodePair(builder, data, data_qparams.scale, data_qparams.zero_point); - - NodeArg* starts = MakeTestInput(builder, starts_def); - NodeArg* ends = MakeTestInput(builder, ends_def); - NodeArg* axes = MakeTestInput(builder, axes_def); - NodeArg* steps = MakeTestInput(builder, steps_def); - - auto* slice_output = builder.MakeIntermediate(); - builder.AddNode("Slice", {data_qdq, starts, ends, axes, steps}, {slice_output}); - - // Add output -> Q -> output_u8 - AddQDQNodePairWithOutputAsGraphOutput(builder, slice_output, output_qparams[0].scale, output_qparams[0].zero_point); - }; -} - /** * Runs an Slice model on the QNN HTP backend. Checks the graph node assignment, and that inference * outputs for QNN and CPU match. @@ -86,13 +41,14 @@ static void RunSliceQDQTest(const TestInputDef& data_def, provider_options["backend_path"] = "libQnnHtp.so"; #endif - // Runs model with DQ-> Slice -> Q and compares the outputs of the CPU and QNN EPs. 
- TestQDQModelAccuracy(BuildSliceTestCase(data_def, starts_def, ends_def, axes_def, steps_def), - BuildQDQSliceTestCase(data_def, starts_def, ends_def, axes_def, steps_def), + const std::vector> f32_inputs = {data_def}; + const std::vector> int64_inputs = {starts_def, ends_def, axes_def, steps_def}; + + TestQDQModelAccuracy(BuildOpTestCase("Slice", f32_inputs, int64_inputs, {}), + BuildQDQOpTestCase("Slice", f32_inputs, int64_inputs, {}), provider_options, 18, - expected_ep_assignment, - 1e-5f); + expected_ep_assignment); } /** @@ -119,12 +75,12 @@ static void RunSliceNonQDQOnHTP(const TestInputDef& data_def, #else provider_options["backend_path"] = "libQnnHtp.so"; #endif - - RunQnnModelTest(BuildSliceTestCase(data_def, starts_def, ends_def, axes_def, steps_def), + auto f32_model_builder = BuildOpTestCase("Slice", {data_def}, + {starts_def, ends_def, axes_def, steps_def}, {}); + RunQnnModelTest(f32_model_builder, provider_options, 13, - expected_ep_assignment, - 1e-5f); + expected_ep_assignment); } // Check that QNN compiles DQ -> Slice -> Q as a single unit. diff --git a/onnxruntime/test/providers/qnn/split_op_test.cc b/onnxruntime/test/providers/qnn/split_op_test.cc new file mode 100644 index 0000000000000..57e4b211777bb --- /dev/null +++ b/onnxruntime/test/providers/qnn/split_op_test.cc @@ -0,0 +1,387 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +template +GetTestModelFn BuildSplitTestCase(const TestInputDef& input_def, + const std::vector& split, bool split_is_input, + int64_t axis, int64_t num_outputs) { + return [input_def, split, split_is_input, axis, num_outputs](ModelTestBuilder& builder) { + std::vector op_inputs; + + op_inputs.push_back(MakeTestInput(builder, input_def)); + + if (split_is_input && !split.empty()) { + op_inputs.push_back(builder.Make1DInitializer(split)); + } + + // Determine the actual number of outputs from the 'split' or 'num_outputs' arguments. + // In opset 18, the num_outputs attribute or the split input can determine the actual number of outputs. + // In opset 13, the split input determines the number of actual outputs. + // In opsets < 13, the split attribute determines the number of actual outputs. + size_t actual_num_outputs = (num_outputs > -1) ? 
static_cast(num_outputs) : split.size(); + + std::vector split_outputs; + for (size_t i = 0; i < actual_num_outputs; i++) { + split_outputs.push_back(builder.MakeOutput()); + } + + Node& split_node = builder.AddNode("Split", op_inputs, split_outputs); + + if (!split_is_input && !split.empty()) { + split_node.AddAttribute("split", split); + } + + if (num_outputs > -1) { + split_node.AddAttribute("num_outputs", num_outputs); + } + + split_node.AddAttribute("axis", axis); + }; +} + +template +static void RunSplitOpTestOnCPU(const TestInputDef& input_def, + const std::vector& split, + int64_t axis, + int64_t num_outputs, + int opset, + ExpectedEPNodeAssignment expected_ep_assignment) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + const bool split_is_input = opset >= 13; + RunQnnModelTest(BuildSplitTestCase(input_def, split, split_is_input, axis, num_outputs), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test Split opset 18 on CPU backend: equal split of axis 0 via 'num_outputs' attribute +// and 'split' input. +TEST_F(QnnCPUBackendTests, Split_Equal_Axis0_Opset18) { + // Use 'split' input (initializer). + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {2, 2}, // split + 0, // axis + -1, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); + + // Use 'num_outputs' attribute. + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {}, // split (use num_outputs instead) + 0, // axis + 2, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {}, // split (use num_outputs instead) + 0, // axis + 2, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 13 on CPU backend: equal split of axis 0 +TEST_F(QnnCPUBackendTests, Split_Equal_Axis0_Opset13) { + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 11 on CPU backend: equal split of axis 0 +TEST_F(QnnCPUBackendTests, Split_Equal_Axis0_Opset11) { + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 13 on CPU backend: unequal split of axis 1 +TEST_F(QnnCPUBackendTests, Split_Unequal_Axis1_Opset13) { + RunSplitOpTestOnCPU(TestInputDef({2, 4}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + 
ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({2, 4}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 11 on CPU backend: unequal split of axis 1 +TEST_F(QnnCPUBackendTests, Split_Unequal_Axis1_Opset11) { + RunSplitOpTestOnCPU(TestInputDef({2, 4}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({2, 4}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Return function that builds a model with a QDQ Split. +template +GetTestQDQModelFn BuildQDQSplitTestCase(const TestInputDef& input_def, + const std::vector& split, + bool split_is_input, + int64_t axis, + int64_t num_outputs, + bool use_contrib_qdq = false) { + return [input_def, split, split_is_input, axis, num_outputs, + use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + std::vector op_inputs; + + // Add QDQ input + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_after_qdq = AddQDQNodePair(builder, input, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + op_inputs.push_back(input_after_qdq); + + // Add split input + if (split_is_input && !split.empty()) { + op_inputs.push_back(builder.Make1DInitializer(split)); + } + + // Determine the actual number of outputs from the 'split' or 'num_outputs' arguments. + // In opset 18, the num_outputs attribute or the split input can determine the actual number of outputs. + // In opset 13, the split input determines the number of actual outputs. + // In opsets < 13, the split attribute determines the number of actual outputs. + size_t actual_num_outputs = (num_outputs > -1) ? static_cast(num_outputs) : split.size(); + + std::vector split_outputs; + for (size_t i = 0; i < actual_num_outputs; i++) { + split_outputs.push_back(builder.MakeIntermediate()); + } + + Node& split_node = builder.AddNode("Split", op_inputs, split_outputs); + + if (!split_is_input && !split.empty()) { + split_node.AddAttribute("split", split); + } + + if (num_outputs > -1) { + split_node.AddAttribute("num_outputs", num_outputs); + } + + split_node.AddAttribute("axis", axis); + + // op_output -> Q -> DQ -> output + assert(output_qparams.size() == actual_num_outputs); + for (size_t i = 0; i < actual_num_outputs; i++) { + // NOTE: Input and output quantization parameters must be equal for Split. + output_qparams[i] = input_qparams; + AddQDQNodePairWithOutputAsGraphOutput(builder, split_outputs[i], output_qparams[i].scale, + output_qparams[i].zero_point, use_contrib_qdq); + } + }; +} + +// Runs a non-QDQ Split operator on the HTP backend. 
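// Note (editorial sketch, not lines from the patch): the Split builders above take the output
// count from 'num_outputs' when it is set and from 'split' otherwise. A hedged sketch of how the
// per-output sizes fall out on the split axis (SplitSizes is a hypothetical helper, not an ORT
// API; it assumes an evenly divisible axis, whereas ONNX opset 18 also permits an uneven split
// in which the last chunk is smaller):
#include <cstdint>
#include <vector>

static std::vector<int64_t> SplitSizes(int64_t axis_dim,
                                       const std::vector<int64_t>& split,
                                       int64_t num_outputs) {
  // Explicit sizes win: the 'split' attribute in opsets < 13, the 'split' input from opset 13 on.
  if (!split.empty()) return split;
  // Otherwise (opset 18 'num_outputs'), divide the axis into equal chunks.
  return std::vector<int64_t>(static_cast<size_t>(num_outputs), axis_dim / num_outputs);
}

// Example: SplitSizes(4, {2, 2}, -1) and SplitSizes(4, {}, 2) describe the same two outputs of
// size 2 on axis 0, which is why the opset-18 tests exercise identical shapes via either path.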
+template +static void RunSplitOpTestOnHTP(const TestInputDef& input_def, + const std::vector& split, + int64_t axis, + int64_t num_outputs, + int opset, + ExpectedEPNodeAssignment expected_ep_assignment) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + const bool split_is_input = opset >= 13; + RunQnnModelTest(BuildSplitTestCase(input_def, split, split_is_input, axis, num_outputs), + provider_options, + opset, + expected_ep_assignment); +} + +// Runs a QDQ Split operator on the HTP backend. +template +static void RunQDQSplitOpTestOnHTP(const TestInputDef& input_def, + const std::vector& split, + int64_t axis, + int64_t num_outputs, + int opset, + ExpectedEPNodeAssignment expected_ep_assignment, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + const bool split_is_input = opset >= 13; + auto f32_model_builder = BuildSplitTestCase(input_def, split, split_is_input, axis, num_outputs); + auto qdq_model_builder = BuildQDQSplitTestCase(input_def, split, split_is_input, axis, num_outputs, + use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test that HTP can run non-QDQ Split (int32 input). +TEST_F(QnnHTPBackendTests, Split_Int32_Opset13) { + // Equal split. + RunSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test 8-bit QDQ Split opset 18 on HTP backend: equal split of axis 0 via 'num_outputs' attribute +// and 'split' input. +TEST_F(QnnHTPBackendTests, Split_Equal_Axis0_Opset18) { + // Use 'split' input (initializer). + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); + + // Use 'num_outputs' attribute. + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {}, // split (use num_outputs instead) + 0, // axis + 2, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Split opset 18 on HTP backend: equal split of axis 0 via 'num_outputs' attribute +// and 'split' input. +TEST_F(QnnHTPBackendTests, Split_Equal_Axis0_Opset18_U16) { + // Use 'split' input (initializer). + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All, + true); // Use com.microsoft Q/DQ ops + + // Use 'num_outputs' attribute. + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {}, // split (use num_outputs instead) + 0, // axis + 2, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All, + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Split op on HTP backend: equal split on axis 0 with opset 13. 
+TEST_F(QnnHTPBackendTests, Split_Equal_Axis0_Opset13) { + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test QDQ Split op on HTP backend: equal split on axis 0 with opset 11. +TEST_F(QnnHTPBackendTests, Split_Equal_Axis0_Opset11) { + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 13 on HTP backend: unequal split of axis 1 +TEST_F(QnnHTPBackendTests, Split_Unequal_Axis1_Opset13) { + RunQDQSplitOpTestOnHTP(TestInputDef({2, 4}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 11 on HTP backend: unequal split of axis 1 +TEST_F(QnnHTPBackendTests, Split_Unequal_Axis1_Opset11) { + RunQDQSplitOpTestOnHTP(TestInputDef({2, 4}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/squeeze_unsqueeze_op_test.cc b/onnxruntime/test/providers/qnn/squeeze_unsqueeze_op_test.cc new file mode 100644 index 0000000000000..33d2f64c0315e --- /dev/null +++ b/onnxruntime/test/providers/qnn/squeeze_unsqueeze_op_test.cc @@ -0,0 +1,324 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Squeeze (or Unsqueeze) operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunSqueezeTestOnCPU(const std::string& op_type, // Squeeze or Unsqueeze + const TestInputDef& input_def, + const TestInputDef& axes_def, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase(op_type, {input_def}, {axes_def}, {}), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that Squeeze with a dynamic axes input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Squeeze_DynamicAxes_Unsupported) { + RunSqueezeTestOnCPU("Squeeze", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({1}, false /* is_initializer */, {0}), + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that Unsqueeze with a dynamic axes input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Unsqueeze_DynamicAxes_Unsupported) { + RunSqueezeTestOnCPU("Unsqueeze", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({1}, false /* is_initializer */, {0}), + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. 
+} + +// Test Squeeze of rank 5 -> rank 2. +TEST_F(QnnCPUBackendTests, Squeeze_Rank5_Rank2_f32) { + RunSqueezeTestOnCPU("Squeeze", + TestInputDef({1, 3, 1, 2, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, true, {0, 2}), // Squeeze axes 0 and 2 => (3, 2, 4) + ExpectedEPNodeAssignment::All); +} + +// Test Squeeze of rank 4 -> rank 3 with a negative axes value. +TEST_F(QnnCPUBackendTests, Squeeze_Rank4_Rank3_NegAxes_f32) { + RunSqueezeTestOnCPU("Squeeze", + TestInputDef({1, 3, 2, 1}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Squeeze last axis => (1, 3, 2) + ExpectedEPNodeAssignment::All); +} + +// Test Unsqueeze of rank 3 -> rank 5. +TEST_F(QnnCPUBackendTests, Unsqueeze_Rank3_Rank5_f32) { + RunSqueezeTestOnCPU("Unsqueeze", + TestInputDef({3, 2, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, true, {0, 2}), // Add 1's => (1, 3, 1, 2, 4) + ExpectedEPNodeAssignment::All); +} + +// Test Unsqueeze of rank 3 -> rank 4 with a negative axes value. +TEST_F(QnnCPUBackendTests, Unsqueeze_Rank3_Rank4_NegAxes_f32) { + RunSqueezeTestOnCPU("Unsqueeze", + TestInputDef({1, 3, 2}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Add 1 as last axis => (1, 3, 2, 1) + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Returns a function that creates a graph with a QDQ (Un)Squeeze operator. +template +GetTestQDQModelFn BuildQDQSqueezeTestCase(const std::string& op_type, // Squeeze or Unsqueeze + const TestInputDef& input_def, + const TestInputDef& axes_def, + bool use_contrib_qdq = false) { + return [op_type, input_def, axes_def, + use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + // input -> Q -> DQ -> + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + + // axes input + NodeArg* axes_input = MakeTestInput(builder, axes_def); + + // (Un)Squeeze op + NodeArg* op_output = builder.MakeIntermediate(); + builder.AddNode(op_type, {input_qdq, axes_input}, {op_output}); + + // op_output -> Q -> DQ -> output + // NOTE: Input and output quantization parameters must be equal for (Un)Squeeze. + output_qparams[0] = input_qparams; // Overwrite! + AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + }; +} + +// Runs a model with a non-QDQ (Un)Squeeze operator on the QNN HTP backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunSqueezeTestOnHTP(const std::string& op_type, // Squeeze or Unsqueeze + const TestInputDef& input_def, + const TestInputDef& axes_def, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(BuildOpTestCase(op_type, {input_def}, {axes_def}, {}), + provider_options, + opset, + expected_ep_assignment); +} + +// Runs a QDQ (Un)Squeeze model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and +// that inference running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP +// (when compared to the baseline float32 model). 
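// Note (editorial sketch, not lines from the patch): several tests in this file pass negative
// axes (e.g. {-1}), which ONNX Squeeze normalizes by adding the input rank. A hedged sketch of
// the Squeeze shape rule (SqueezeShape is a hypothetical helper, not an ORT API):
#include <algorithm>
#include <cstdint>
#include <vector>

static std::vector<int64_t> SqueezeShape(const std::vector<int64_t>& input_dims,
                                         std::vector<int64_t> axes) {
  const int64_t rank = static_cast<int64_t>(input_dims.size());
  for (int64_t& a : axes) {
    if (a < 0) a += rank;  // normalize negative axes, e.g. -1 -> rank - 1
  }
  std::vector<int64_t> out;
  for (int64_t i = 0; i < rank; ++i) {
    // Keep a dimension unless it is named in 'axes' (named dims must be 1 to be squeezable).
    if (std::find(axes.begin(), axes.end(), i) == axes.end()) out.push_back(input_dims[i]);
  }
  return out;
}

// Example: SqueezeShape({1, 3, 2, 1}, {-1}) yields {1, 3, 2}, the shape asserted by the
// Squeeze_Rank4_Rank3_NegAxes tests in this file.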
+template +static void RunQDQSqueezeTestOnHTP(const std::string& op_type, + const TestInputDef& input_def, + const TestInputDef& axes_def, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildOpTestCase(op_type, {input_def}, {axes_def}, {}); + auto qdq_model_builder = BuildQDQSqueezeTestCase(op_type, input_def, axes_def, use_contrib_qdq); + + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test that QDQ Squeeze with a dynamic axes input is not supported by QNN EP. +TEST_F(QnnHTPBackendTests, Squeeze_DynamicAxes_Unsupported) { + RunQDQSqueezeTestOnHTP("Squeeze", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({1}, false /* is_initializer */, {0}), + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that Unsqueeze with a dynamic axes input is not supported by QNN EP. +TEST_F(QnnHTPBackendTests, Unsqueeze_DynamicAxes_Unsupported) { + RunQDQSqueezeTestOnHTP("Unsqueeze", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({1}, false /* is_initializer */, {0}), + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test Squeeze of rank 5 -> rank 2. +TEST_F(QnnHTPBackendTests, Squeeze_Rank5_Rank2_f32) { + // We can't use the usual model-building functions because they add standalone Quantize and Dequantize nodes + // at the input and output. These Q/DQ ops get lowered to QNN's Quantize and Dequantize operators, which DO NOT + // support rank 5 tensors. Therefore, we have to create a test model that only instantiates the DQ -> Squeeze -> Q + // QDQ node group, which gets lowered to a single QNN Reshape node. + GetTestModelFn model_fn = [](ModelTestBuilder& builder) { + // input (u8) -> DQ -> + NodeArg* quant_input = builder.MakeInput({1, 3, 1, 2, 4}, 0, 255); + NodeArg* input_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(quant_input, 1.0f, 0, input_dq); // scale = 1.0, zp = 0 + + // axes_input -> + NodeArg* axes_input = builder.Make1DInitializer({0, 2}); // Squeeze axes 0 and 2 => (3, 2, 4) + + // Squeeze -> + NodeArg* squeeze_output = builder.MakeIntermediate(); + builder.AddNode("Squeeze", {input_dq, axes_input}, {squeeze_output}); + + // Q -> output (u8) + NodeArg* output = builder.MakeOutput(); + builder.AddQuantizeLinearNode(squeeze_output, 1.0f, 0, output); // scale = 1.0, zp = 0 + }; + + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(model_fn, + provider_options, + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test 8-bit QDQ Squeeze of rank 4 -> rank 3 with a negative axes value. +TEST_F(QnnHTPBackendTests, Squeeze_Rank4_Rank3_NegAxes_u8) { + RunQDQSqueezeTestOnHTP("Squeeze", + TestInputDef({1, 3, 2, 1}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Squeeze last axis => (1, 3, 2) + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Squeeze of rank 4 -> rank 3 with a negative axes value. 
+TEST_F(QnnHTPBackendTests, Squeeze_Rank4_Rank3_NegAxes_u16) { + RunQDQSqueezeTestOnHTP("Squeeze", + TestInputDef({1, 3, 2, 1}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Squeeze last axis => (1, 3, 2) + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Unsqueeze of rank 3 -> rank 5. +TEST_F(QnnHTPBackendTests, Unsqueeze_Rank3_Rank5_f32) { + // We can't use the usual model-building functions because they add standalone Quantize and Dequantize nodes + // at the input and output. These Q/DQ ops get lowered to QNN's Quantize and Dequantize operators, which DO NOT + // support rank 5 tensors. Therefore, we have to create a test model that only instantiates the DQ -> Unsqueeze -> Q + // QDQ node group, which gets lowered to a single QNN Reshape node. + GetTestModelFn model_fn = [](ModelTestBuilder& builder) { + // input (u8) -> DQ -> + NodeArg* quant_input = builder.MakeInput({3, 2, 4}, 0, 255); + NodeArg* input_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(quant_input, 1.0f, 0, input_dq); // scale = 1.0, zp = 0 + + // axes_input -> + NodeArg* axes_input = builder.Make1DInitializer({0, 2}); // Add 1's => (1, 3, 1, 2, 4) + + // Unsqueeze -> + NodeArg* unsqueeze_output = builder.MakeIntermediate(); + builder.AddNode("Unsqueeze", {input_dq, axes_input}, {unsqueeze_output}); + + // Q -> output (u8) + NodeArg* output = builder.MakeOutput(); + builder.AddQuantizeLinearNode(unsqueeze_output, 1.0f, 0, output); // scale = 1.0, zp = 0 + }; + + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(model_fn, + provider_options, + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test 8-bit QDQ Unsqueeze of rank 3 -> rank 4 with a negative axes value. +TEST_F(QnnHTPBackendTests, Unsqueeze_Rank3_Rank4_NegAxes_u8) { + RunQDQSqueezeTestOnHTP("Unsqueeze", + TestInputDef({1, 3, 2}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Add 1 as last axis => (1, 3, 2, 1) + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Unsqueeze of rank 3 -> rank 4 with a negative axes value. +TEST_F(QnnHTPBackendTests, Unsqueeze_Rank3_Rank4_NegAxes_u16) { + RunQDQSqueezeTestOnHTP("Unsqueeze", + TestInputDef({1, 3, 2}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Add 1 as last axis => (1, 3, 2, 1) + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test that int32 Squeeze runs on HTP backend. +TEST_F(QnnHTPBackendTests, Squeeze_Int32_Rank4_Rank3) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + RunSqueezeTestOnHTP("Squeeze", + TestInputDef({1, 3, 2, 2}, false, input_data), + TestInputDef({1}, true, {0}), // Squeeze 0th axis => (3, 2, 2) + ExpectedEPNodeAssignment::All); +} + +// Test that int32 Unsqueeze runs on HTP backend. 
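// Note (editorial sketch, not lines from the patch): Unsqueeze mirrors Squeeze, but negative
// axes normalize against the *output* rank (input rank + number of inserted axes). A hedged
// sketch (UnsqueezeShape is a hypothetical helper, not an ORT API):
#include <algorithm>
#include <cstdint>
#include <vector>

static std::vector<int64_t> UnsqueezeShape(const std::vector<int64_t>& input_dims,
                                           std::vector<int64_t> axes) {
  const int64_t out_rank = static_cast<int64_t>(input_dims.size() + axes.size());
  for (int64_t& a : axes) {
    if (a < 0) a += out_rank;  // e.g. -1 -> out_rank - 1
  }
  std::vector<int64_t> out;
  size_t next_input = 0;
  for (int64_t i = 0; i < out_rank; ++i) {
    if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
      out.push_back(1);  // inserted singleton dimension
    } else {
      out.push_back(input_dims[next_input++]);
    }
  }
  return out;
}

// Example: UnsqueezeShape({1, 3, 2}, {-1}) yields {1, 3, 2, 1}, matching the
// Unsqueeze_Rank3_Rank4_NegAxes tests in this file.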
+TEST_F(QnnHTPBackendTests, Unsqueeze_Int32_Rank3_Rank4) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + RunSqueezeTestOnHTP("Unsqueeze", + TestInputDef({3, 2, 2}, false, input_data), + TestInputDef({1}, true, {0}), // Unsqueeze 0th axis => (1, 3, 2, 2) + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/tile_op_test.cc b/onnxruntime/test/providers/qnn/tile_op_test.cc new file mode 100644 index 0000000000000..2b35c730ee5fe --- /dev/null +++ b/onnxruntime/test/providers/qnn/tile_op_test.cc @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Tile operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunTileTestOnCPU(const TestInputDef& input_def, + const TestInputDef& repeats_def, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Tile", {input_def}, {repeats_def}, {}), + provider_options, + opset, + expected_ep_assignment); +} + +// Test that Tile with a dynamic repeats input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Tile_DynamicRepeats_Unsupported) { + RunTileTestOnCPU(TestInputDef({2, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f}), + TestInputDef({2}, false /* is_initializer */, {1, 2}), + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that Tile with rank 4 float input. +TEST_F(QnnCPUBackendTests, Tile_F32_Rank4) { + std::vector input_data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f}; + RunTileTestOnCPU(TestInputDef({1, 2, 2, 2}, false, input_data), + TestInputDef({4}, true /* is_initializer */, {1, 2, 1, 1}), + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Returns a function that creates a graph with a QDQ Tile operator. +template +GetTestQDQModelFn BuildQDQTileTestCase(const TestInputDef& input_def, + const TestInputDef& repeats_def, + bool use_contrib_qdq = false) { + return [input_def, repeats_def, use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + // input -> Q -> DQ -> + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + + // repeats input + NodeArg* repeats_input = MakeTestInput(builder, repeats_def); + + // Tile op + NodeArg* tile_output = builder.MakeIntermediate(); + builder.AddNode("Tile", {input_qdq, repeats_input}, {tile_output}); + + // op_output -> Q -> DQ -> output + // NOTE: Input and output quantization parameters must be equal for Tile. + output_qparams[0] = input_qparams; // Overwrite! 
+ AddQDQNodePairWithOutputAsGraphOutput(builder, tile_output, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + }; +} + +// Runs a QDQ Tile model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQTileTestOnHTP(const TestInputDef& input_def, + const TestInputDef& repeats_def, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildOpTestCase("Tile", {input_def}, {repeats_def}, {}); + auto qdq_model_builder = BuildQDQTileTestCase(input_def, repeats_def, use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test 8-bit QDQ Tile with rank 4 input. +TEST_F(QnnHTPBackendTests, Tile_U8_Rank4) { + std::vector input_data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f}; + RunQDQTileTestOnHTP(TestInputDef({1, 2, 2, 2}, false, input_data), + TestInputDef({4}, true /* is_initializer */, {1, 2, 1, 1}), + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Tile with rank 4 input. +TEST_F(QnnHTPBackendTests, Tile_U16_Rank4) { + std::vector input_data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f}; + RunQDQTileTestOnHTP(TestInputDef({1, 2, 2, 2}, false, input_data), + TestInputDef({4}, true /* is_initializer */, {1, 2, 1, 1}), + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/topk_op_test.cc b/onnxruntime/test/providers/qnn/topk_op_test.cc new file mode 100644 index 0000000000000..93e725af5f20e --- /dev/null +++ b/onnxruntime/test/providers/qnn/topk_op_test.cc @@ -0,0 +1,209 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Returns a function that builds a model with a TopK operator. +template +inline GetTestModelFn BuildTopKTestCase(const TestInputDef& input_def, + const TestInputDef& k_def, + const std::vector& attrs, + bool cast_output_indices = true) { + return [input_def, k_def, attrs, cast_output_indices](ModelTestBuilder& builder) { + NodeArg* input = MakeTestInput(builder, input_def); + NodeArg* k_input = MakeTestInput(builder, k_def); + + NodeArg* values_output = builder.MakeOutput(); + NodeArg* indices_output = cast_output_indices ? 
builder.MakeIntermediate() : builder.MakeOutput(); + Node& topk_node = builder.AddNode("TopK", {input, k_input}, {values_output, indices_output}); + + for (const auto& attr : attrs) { + topk_node.AddAttributeProto(attr); + } + + // Cast indices to uint32 + if (cast_output_indices) { + auto* uint32_indices_output = builder.MakeOutput(); + Node& cast_node = builder.AddNode("Cast", {indices_output}, {uint32_indices_output}); + const auto dst_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32; + cast_node.AddAttribute("to", static_cast(dst_type)); + } + }; +} + +// Runs a model with a TopK operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunTopKTestOnCPU(const TestInputDef& input_def, + const TestInputDef& k_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildTopKTestCase(input_def, k_def, attrs, false /*cast_output_indices*/), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that TopK with a dynamic K input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, TopK_DynamicK_Unsupported) { + RunTopKTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, false /* is_initializer */, {2}), + {}, // Attributes + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that TopK with an axis attribute that is not the last dimension is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, TopK_NonLastAxis_Unsupported) { + RunTopKTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, true /* is_initializer */, {2}), + {utils::MakeAttribute("axis", static_cast(1))}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that TopK that returns the top k minimum values is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, TopK_MinValues_Unsupported) { + RunTopKTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, true /* is_initializer */, {2}), + {utils::MakeAttribute("largest", static_cast(0))}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test TopK on CPU backend: top 2 largest floats from last axis +TEST_F(QnnCPUBackendTests, TopK_LargestFloats_LastAxis) { + RunTopKTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, true /* is_initializer */, {2}), + {}, // Attributes + ExpectedEPNodeAssignment::All); +} + +// Test TopK on CPU backend: top 2 largest int32s from last axis +TEST_F(QnnCPUBackendTests, TopK_LargestInt32s_LastAxis) { + std::vector input_data = {-6, -5, -4, -3, -2, 0, 1, 2, 3, 4, 5, 6}; + RunTopKTestOnCPU(TestInputDef({1, 2, 2, 3}, false, input_data), + TestInputDef({1}, true /* is_initializer */, {2}), + {}, // Attributes + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Returns a function that creates a graph with a QDQ TopK operator. 
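// Note (editorial sketch, not lines from the patch): the CPU tests above pin down what QNN EP
// accepts for TopK -- a static K, the last axis, and largest-values mode. A hedged reference for
// what "top 2 largest floats from last axis" computes on one row (TopKLargest is a hypothetical
// helper, not an ORT API):
#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

static std::vector<float> TopKLargest(std::vector<float> row, size_t k) {
  k = std::min(k, row.size());
  // Partially sort so the k largest values land at the front in descending order.
  std::partial_sort(row.begin(), row.begin() + static_cast<std::ptrdiff_t>(k), row.end(),
                    std::greater<float>());
  row.resize(k);
  return row;
}

// Example: TopKLargest({-6, 4, 1, 5}, 2) returns {5, 4}. The ONNX op additionally emits int64
// indices, which these tests cast to uint32 because the HTP backend lacks int64 graph outputs.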
+template <typename QType>
+GetTestQDQModelFn<QType> BuildQDQTopKTestCase(const TestInputDef<float>& input_def,
+                                              const TestInputDef<int64_t>& k_def,
+                                              const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs,
+                                              bool use_contrib_qdq = false) {
+  return [input_def, k_def, attrs, use_contrib_qdq](ModelTestBuilder& builder,
+                                                    std::vector<QuantParams<QType>>& output_qparams) {
+    // input -> Q -> DQ ->
+    NodeArg* input = MakeTestInput(builder, input_def);
+    QuantParams<QType> input_qparams = GetTestInputQuantParams<QType>(input_def);
+    NodeArg* input_qdq = AddQDQNodePair<QType>(builder, input, input_qparams.scale, input_qparams.zero_point,
+                                               use_contrib_qdq);
+
+    // K input
+    NodeArg* k_input = MakeTestInput(builder, k_def);
+
+    // TopK op
+    NodeArg* values_output = builder.MakeIntermediate();
+    NodeArg* indices_output = builder.MakeIntermediate();
+    Node& topk_node = builder.AddNode("TopK", {input_qdq, k_input}, {values_output, indices_output});
+
+    for (const auto& attr : attrs) {
+      topk_node.AddAttributeProto(attr);
+    }
+
+    // op_output -> Q -> DQ -> output
+    // NOTE: Input and output quantization parameters must be equal for TopK.
+    output_qparams[0] = input_qparams;  // Overwrite!
+    AddQDQNodePairWithOutputAsGraphOutput<QType>(builder, values_output, input_qparams.scale,
+                                                 input_qparams.zero_point, use_contrib_qdq);
+
+    // Cast indices to uint32 (HTP backend does not support int64 graph outputs)
+    auto* uint32_indices_output = builder.MakeOutput();
+    Node& cast_node = builder.AddNode("Cast", {indices_output}, {uint32_indices_output});
+    const auto dst_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32;
+    cast_node.AddAttribute("to", static_cast<int64_t>(dst_type));
+  };
+}
+
+// Runs a QDQ TopK model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference
+// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model).
+template <typename QType>
+static void RunQDQTopKTestOnHTP(const TestInputDef<float>& input_def,
+                                const TestInputDef<int64_t>& k_def,
+                                const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs,
+                                ExpectedEPNodeAssignment expected_ep_assignment,
+                                int opset = 19,
+                                bool use_contrib_qdq = false) {
+  ProviderOptions provider_options;
+
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  auto f32_model_builder = BuildTopKTestCase<float>(input_def, k_def, attrs, true /*cast_output_indices*/);
+  auto qdq_model_builder = BuildQDQTopKTestCase<QType>(input_def, k_def, attrs, use_contrib_qdq);
+  TestQDQModelAccuracy(f32_model_builder,
+                       qdq_model_builder,
+                       provider_options,
+                       opset,
+                       expected_ep_assignment);
+}
+
+// Test 8-bit QDQ TopK on HTP backend: top 2 largest floats from last axis
+TEST_F(QnnHTPBackendTests, TopK_LargestFloats_U8_LastAxis) {
+  RunQDQTopKTestOnHTP<uint8_t>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
+                               TestInputDef<int64_t>({1}, true /* is_initializer */, {2}),
+                               {},  // Attributes
+                               ExpectedEPNodeAssignment::All);
+}
+
+// Test 16-bit QDQ TopK on HTP backend: top 2 largest floats from last axis
+// TODO: Inaccuracy detected for output 'output_0', element 6.
+// Output quant params: scale=0.00061036087572574615, zero_point=32768. 
+// Expected val: -7.2340402603149414
+// QNN QDQ val: -17.446556091308594 (err 10.212515830993652)
+// CPU QDQ val: -7.2339968681335449 (err 4.3392181396484375e-05)
+TEST_F(QnnHTPBackendTests, DISABLED_TopK_LargestFloats_U16_LastAxis) {
+  RunQDQTopKTestOnHTP<uint16_t>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-20.0f, 20.0f, 48)),
+                                TestInputDef<int64_t>({1}, true /* is_initializer */, {2}),
+                                {},  // Attributes
+                                ExpectedEPNodeAssignment::All,
+                                19,     // opset
+                                true);  // Use com.microsoft Q/DQ ops
+}
+
+#endif  // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
+}  // namespace test
+}  // namespace onnxruntime
+#endif  // !defined(ORT_MINIMAL_BUILD)

From dd561f201524ea5c78d9fa26df818397139f80af Mon Sep 17 00:00:00 2001
From: Changming Sun
Date: Wed, 20 Sep 2023 18:44:23 -0700
Subject: [PATCH 057/121] Upgrade sympy (#17639)

AB#17015
---
 .../docker/inference/x64/python/cpu/scripts/requirements.txt | 2 +-
 .../github/linux/docker/scripts/manylinux/requirements.txt   | 2 +-
 tools/ci_build/github/linux/docker/scripts/requirements.txt  | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt
index c0c6505ca010d..8a9c4dac1dd58 100644
--- a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt
+++ b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt
@@ -6,5 +6,5 @@ setuptools>=41.4.0
 wheel
 git+http://github.com/onnx/onnx.git@e2525550194ce3d8a2c4a3af451c9d9b3ae6650e#egg=onnx
 protobuf==3.20.2
-sympy==1.10.1
+sympy==1.12
 flatbuffers
diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt
index c8ff7a804e1df..6b8003c01c24d 100644
--- a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt
+++ b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt
@@ -6,6 +6,6 @@ setuptools>=41.4.0
 wheel
 git+http://github.com/onnx/onnx.git@e2525550194ce3d8a2c4a3af451c9d9b3ae6650e#egg=onnx
 protobuf==3.20.2
-sympy==1.10.1
+sympy==1.12
 flatbuffers
 neural-compressor>=2.2.1
diff --git a/tools/ci_build/github/linux/docker/scripts/requirements.txt b/tools/ci_build/github/linux/docker/scripts/requirements.txt
index 2248652c98043..9dbe856753faa 100644
--- a/tools/ci_build/github/linux/docker/scripts/requirements.txt
+++ b/tools/ci_build/github/linux/docker/scripts/requirements.txt
@@ -7,7 +7,7 @@ setuptools>=41.4.0
 wheel>=0.35.1
 git+http://github.com/onnx/onnx.git@e2525550194ce3d8a2c4a3af451c9d9b3ae6650e#egg=onnx
 argparse
-sympy==1.10.1
+sympy==1.12
 flatbuffers
 protobuf==3.20.2
 packaging

From 1f991f27f1d4a8d19c877bc4d33457dd45994c80 Mon Sep 17 00:00:00 2001
From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com>
Date: Thu, 21 Sep 2023 10:45:16 +0800
Subject: [PATCH 058/121] [ROCm] add manylinux build test for ROCm CI (#17621)

The manylinux build is used for nightly packaging generation, and it is hard to catch issues in time when the related files change. This PR adds a manylinux build to the CI.
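For reference, the core of the new job is the same `tools/ci_build/build.py` invocation the packaging pipeline already uses, just run inside the manylinux ROCm image built from `Dockerfile.manylinux2_28_rocm`. A minimal local sketch of what the CI step executes (the image tag and the concrete ROCm version are assumptions that stand in for the pipeline's `$(RocmVersion)` variable):

```bash
# Sketch only: approximates the CI step below; the image tag and ROCm version are assumed.
export ROCM_HOME=/opt/rocm
docker run --rm \
  --user "$UID:$(id -g "$USER")" \
  --volume "$PWD":/onnxruntime_src \
  --volume "$PWD/build":/build \
  --workdir /onnxruntime_src \
  onnxruntimetrainingrocm-cibuild-rocm5.6-manylinux-build \
  /opt/python/cp38-cp38/bin/python3 tools/ci_build/build.py \
    --config Release --enable_training \
    --use_rocm --rocm_version=5.6 --rocm_home ${ROCM_HOME} --nccl_home ${ROCM_HOME} \
    --update --build --build_dir /build --parallel --build_wheel \
    --skip_submodule_sync --skip_tests
```

The CI job below additionally wires in ccache and the gcc-toolset-12 environment, but the build command itself is unchanged.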
--- .../orttraining-pai-ci-pipeline.yml | 107 +++++++++++++++++- .../docker/Dockerfile.manylinux2_28_rocm | 7 ++ 2 files changed, 111 insertions(+), 3 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml index 523390debc887..3333a7d22a41b 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml @@ -11,6 +11,14 @@ pr: - 'onnxruntime/core/providers/js' name: 'orttraining_ci_$(Date:yyyyMMdd)_$(Rev:r)' +resources: + repositories: + - repository: manylinux + type: Github + endpoint: Microsoft + name: pypa/manylinux + ref: 5eda9aded5462201e6310105728d33016e637ea7 + variables: - name: video value: 44 @@ -22,7 +30,101 @@ variables: value: Release jobs: -- job: Linux_Build +- job: Linux_Build_manylinux + variables: + skipComponentGovernanceDetection: true + CCACHE_DIR: $(Pipeline.Workspace)/ccache + TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + workspace: + clean: all + pool: onnxruntime-Ubuntu2004-AMD-CPU + timeoutInMinutes: 120 + + steps: + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + + - checkout: self + clean: true + submodules: recursive + + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: >- + --build-arg INSTALL_DEPS_EXTRA_ARGS=-tmur + --network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 + --build-arg BUILD_UID=$(id -u) + --build-arg ROCM_VERSION=$(RocmVersion) + --build-arg DEVTOOLSET_ROOTPATH=/opt/rh/gcc-toolset-12/root + --build-arg PREPEND_PATH=/opt/rh/gcc-toolset-12/root/usr/bin: + --build-arg LD_LIBRARY_PATH_ARG=/opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64:/usr/local/lib + Repository: onnxruntimetrainingrocm-cibuild-rocm$(RocmVersion)-manylinux-build + + - task: Cache@2 + inputs: + key: '"manylinux" | "$(TODAY)" | "$(Build.SourceBranch)" | "$(Build.SourceVersion)"' + path: $(CCACHE_DIR) + cacheHitVar: CACHE_RESTORED + restoreKeys: | + "manylinux" | "$(TODAY)" | "$(Build.SourceBranch)" + "manylinux" | "$(TODAY)" | + displayName: Cache Task + + - script: mkdir -p $(CCACHE_DIR) + condition: ne(variables.CACHE_RESTORED, 'true') + displayName: Create Cache Dir + + - task: CmdLine@2 + inputs: + script: |- + export ROCM_HOME=/opt/rocm + docker run --rm \ + --ipc=host \ + --network=host \ + --cap-add=SYS_PTRACE \ + --security-opt seccomp=unconfined \ + --shm-size=1024m \ + --user $UID:$(id -g $USER) \ + -e CC=/opt/rh/gcc-toolset-12/root/usr/bin/cc -e CXX=/opt/rh/gcc-toolset-12/root/usr/bin/c++ -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" \ + -e CCACHE_DIR=/cache \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + --volume $(CCACHE_DIR):/cache \ + --workdir /onnxruntime_src \ + onnxruntimetrainingrocm-cibuild-rocm$(RocmVersion)-manylinux-build \ + /bin/bash 
-c " + set -ex; \ + ccache -s; \ + /opt/python/cp38-cp38/bin/python3 tools/ci_build/build.py \ + --config $(BuildConfig) \ + --enable_training \ + --mpi_home /opt/ompi \ + --cmake_extra_defines \ + CMAKE_HIP_COMPILER=${ROCM_HOME}/llvm/bin/clang++ \ + onnxruntime_BUILD_UNIT_TESTS=OFF \ + FETCHCONTENT_TRY_FIND_PACKAGE_MODE=NEVER \ + --use_cache \ + --use_rocm \ + --rocm_version=$(RocmVersion) \ + --rocm_home ${ROCM_HOME} \ + --nccl_home ${ROCM_HOME}\ + --update \ + --build_dir /build \ + --build \ + --parallel \ + --build_wheel \ + --skip_submodule_sync \ + --skip_tests; \ + ccache -sv; \ + ccache -z" + displayName: 'Build onnxruntime' + + - template: templates/explicitly-defined-final-tasks.yml + +- job: Linux_Build_ubuntu variables: skipComponentGovernanceDetection: true CCACHE_DIR: $(Pipeline.Workspace)/ccache @@ -115,8 +217,7 @@ jobs: - template: templates/explicitly-defined-final-tasks.yml - -- job: Linux_Test +- job: Linux_Test_ubuntu workspace: clean: all pool: AMD-GPU diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm index 10ce8f0ed65f7..19599c9f613d4 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm @@ -185,6 +185,13 @@ RUN cd /tmp/scripts && \ rm -rf /tmp/scripts +# Install ccache to reuse this dockerfile for CI +RUN mkdir -p /tmp/ccache && \ + cd /tmp/ccache && \ + wget -q -O - https://github.com/ccache/ccache/releases/download/v4.7.4/ccache-4.7.4-linux-x86_64.tar.xz | tar --strip 1 -J -xf - && \ + cp /tmp/ccache/ccache /usr/bin && \ + rm -rf /tmp/ccache + ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER From 4f3f4366d5838f80b45060c26bcd648ee387a252 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Wed, 20 Sep 2023 19:51:50 -0700 Subject: [PATCH 059/121] Fix API 16's marker (#17640) --- onnxruntime/core/session/onnxruntime_c_api.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 4c0adcdd374aa..60b6296f7f539 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2711,9 +2711,8 @@ static constexpr OrtApi ort_api_1_to_17 = { &OrtApis::GetTensorRTProviderOptionsByName, &OrtApis::UpdateCUDAProviderOptionsWithValue, &OrtApis::GetCUDAProviderOptionsByName, - // End of Version 16 - DO NOT MODIFY ABOVE (see above text for more information) - &OrtApis::KernelContext_GetResource, + // End of Version 16 - DO NOT MODIFY ABOVE (see above text for more information) }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. 
@@ -2742,7 +2741,7 @@ static_assert(offsetof(OrtApi, ReleaseKernelInfo) / sizeof(void*) == 218, "Size static_assert(offsetof(OrtApi, ReleaseCANNProviderOptions) / sizeof(void*) == 224, "Size of version 13 API cannot change"); static_assert(offsetof(OrtApi, GetSessionConfigEntry) / sizeof(void*) == 238, "Size of version 14 API cannot change"); static_assert(offsetof(OrtApi, GetBuildInfoString) / sizeof(void*) == 254, "Size of version 15 API cannot change"); -static_assert(offsetof(OrtApi, GetCUDAProviderOptionsByName) / sizeof(void*) == 264, "Size of version 16 API cannot change"); +static_assert(offsetof(OrtApi, KernelContext_GetResource) / sizeof(void*) == 265, "Size of version 16 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: static_assert(std::string_view(ORT_VERSION) == "1.17.0", From 038c76378fdee45261d43af45466a0797e6ad124 Mon Sep 17 00:00:00 2001 From: Pranav Sharma Date: Thu, 21 Sep 2023 00:08:10 -0700 Subject: [PATCH 060/121] Include onnxruntime_float16.h in the package. (#17637) ### Description Include onnxruntime_float16.h in the package. ### Motivation and Context This was missed in the recently released 1.16 pkgs (except Nuget). --- tools/ci_build/github/linux/copy_strip_binary.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/ci_build/github/linux/copy_strip_binary.sh b/tools/ci_build/github/linux/copy_strip_binary.sh index b875a3937aaa9..63690b69fc91a 100755 --- a/tools/ci_build/github/linux/copy_strip_binary.sh +++ b/tools/ci_build/github/linux/copy_strip_binary.sh @@ -48,6 +48,7 @@ fi cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_c_api.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_cxx_api.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_cxx_inline.h $BINARY_DIR/$ARTIFACT_NAME/include +cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_float16.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/providers/cpu/cpu_provider_factory.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h $BINARY_DIR/$ARTIFACT_NAME/include From 57dfd15d7bc9d9c5779896f6685ec473875dc6e1 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 21 Sep 2023 07:33:29 -0700 Subject: [PATCH 061/121] Remove dnf update from docker build scripts (#17551) ### Description 1. Remove 'dnf update' from docker build scripts, because it upgrades TRT packages from CUDA 11.x to CUDA 12.x. To reproduce it, you can run the following commands in a CentOS CUDA 11.x docker image such as nvidia/cuda:11.8.0-cudnn8-devel-ubi8. 
```
export v=8.6.1.6-1.cuda11.8
dnf install -y libnvinfer8-${v} libnvparsers8-${v} libnvonnxparsers8-${v} libnvinfer-plugin8-${v} libnvinfer-vc-plugin8-${v} libnvinfer-devel-${v} libnvparsers-devel-${v} libnvonnxparsers-devel-${v} libnvinfer-plugin-devel-${v} libnvinfer-vc-plugin-devel-${v} libnvinfer-headers-devel-${v} libnvinfer-headers-plugin-devel-${v}
dnf update -y
```
The last command generates the following output:
```
========================================================================================================================
 Package                                  Architecture       Version                     Repository        Size
========================================================================================================================
Upgrading:
 libnvinfer-devel                         x86_64             8.6.1.6-1.cuda12.0          cuda             542 M
 libnvinfer-headers-devel                 x86_64             8.6.1.6-1.cuda12.0          cuda             118 k
 libnvinfer-headers-plugin-devel          x86_64             8.6.1.6-1.cuda12.0          cuda              14 k
 libnvinfer-plugin-devel                  x86_64             8.6.1.6-1.cuda12.0          cuda              13 M
 libnvinfer-plugin8                       x86_64             8.6.1.6-1.cuda12.0          cuda              13 M
 libnvinfer-vc-plugin-devel               x86_64             8.6.1.6-1.cuda12.0          cuda             107 k
 libnvinfer-vc-plugin8                    x86_64             8.6.1.6-1.cuda12.0          cuda             251 k
 libnvinfer8                              x86_64             8.6.1.6-1.cuda12.0          cuda             543 M
 libnvonnxparsers-devel                   x86_64             8.6.1.6-1.cuda12.0          cuda             467 k
 libnvonnxparsers8                        x86_64             8.6.1.6-1.cuda12.0          cuda             757 k
 libnvparsers-devel                       x86_64             8.6.1.6-1.cuda12.0          cuda             2.0 M
 libnvparsers8                            x86_64             8.6.1.6-1.cuda12.0          cuda             854 k
Installing dependencies:
 cuda-toolkit-12-0-config-common          noarch             12.0.146-1                  cuda             7.7 k
 cuda-toolkit-12-config-common            noarch             12.2.140-1                  cuda             7.9 k
 libcublas-12-0                           x86_64             12.0.2.224-1                cuda             361 M
 libcublas-devel-12-0                     x86_64             12.0.2.224-1                cuda             397 M

Transaction Summary
========================================================================================================================
```
As you can see from the output, these are CUDA 12 packages. The problem can also be solved by locking the packages' versions with the "dnf versionlock" command right after installing the CUDA/TRT packages. However, going forward, for better reproducibility, I suggest manually pinning the dnf package versions in the installation scripts, like we do for TRT now.
```bash
v="8.6.1.6-1.cuda11.8" &&\
   yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo &&\
   yum -y install libnvinfer8-${v} libnvparsers8-${v} libnvonnxparsers8-${v} libnvinfer-plugin8-${v} libnvinfer-vc-plugin8-${v}\
 libnvinfer-devel-${v} libnvparsers-devel-${v} libnvonnxparsers-devel-${v} libnvinfer-plugin-devel-${v} libnvinfer-vc-plugin-devel-${v} libnvinfer-headers-devel-${v} libnvinfer-headers-plugin-devel-${v}
```
When we need to upgrade a package due to a security alert or some other reason, we manually change the version string instead of relying on "dnf update". Though this approach takes more effort, it makes our pipelines more stable.

2. 
Move the Python tests to run inside Docker.

### Motivation and Context
Right now the nightly GPU package mixes CUDA 11.x and CUDA 12.x components, and the resulting package is completely unusable (it crashes every time).
---
 .../azure-pipelines/linux-ci-pipeline.yml     |   7 +-
 .../py-package-test-pipeline.yml              |  37 +++---
 .../templates/c-api-linux-cpu.yml             |   2 +-
 .../templates/py-package-smoking-test.yml     |  28 ++---
 .../templates/py-packaging-linux-test-cpu.yml | 117 ++++++++++++++++++
 .../py-packaging-linux-test-cuda.yml          |  98 +++++++++++++++
 .../templates/py-packaging-linux-test.yml     |  85 -------------
 .../linux/docker/Dockerfile.manylinux2_28_cpu |   9 +-
 .../docker/Dockerfile.manylinux2_28_cuda11    |   5 +-
 ...kerfile.manylinux2_28_cuda11_6_tensorrt8_4 |   5 +-
 ...kerfile.manylinux2_28_cuda11_6_tensorrt8_5 |   5 +-
 ...kerfile.manylinux2_28_cuda11_8_tensorrt8_6 |   5 +-
 ...Dockerfile.manylinux2_28_training_cuda11_8 |   3 -
 ...erfile.package_ubuntu_cuda11_8_tensorrt8_6 |  20 +--
 .../default/cpu/scripts/install_centos.sh     |   7 +-
 .../default/cpu/scripts/install_deps.sh       |  24 ++--
 .../inference/x64/default/cpu/Dockerfile      |   4 +-
 .../x64/default/cpu/scripts/install_centos.sh |   8 +-
 .../inference/x64/default/gpu/Dockerfile      |   2 +
 .../x64/default/gpu/scripts/install_centos.sh |   8 +-
 .../python/cpu/Dockerfile.manylinux2_28_cpu   |   3 -
 .../x64/python/cpu/scripts/install_centos.sh  |   6 +-
 .../github/linux/docker/manylinux.patch       |   9 +-
 .../linux/docker/scripts/install_dotnet.sh    |  10 +-
 .../scripts/manylinux/install_centos.sh       |   9 +-
 .../docker/scripts/manylinux/install_deps.sh  |  26 ++--
 .../scripts/manylinux/install_deps_aten.sh    |   2 +-
 .../scripts/manylinux/install_deps_eager.sh   |   2 +-
 .../github/linux/run_python_dockertest.sh     |  29 +++++
 .../ci_build/github/linux/run_python_tests.sh |  20 ++-
 tools/scripts/python_test.sh                  |   0
 tools/scripts/symbolic_shape_infer_test.sh    |   0
 32 files changed, 351 insertions(+), 244 deletions(-)
 create mode 100644 tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml
 create mode 100644 tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml
 delete mode 100644 tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test.yml
 create mode 100755 tools/ci_build/github/linux/run_python_dockertest.sh
 mode change 100644 => 100755 tools/scripts/python_test.sh
 mode change 100644 => 100755 tools/scripts/symbolic_shape_infer_test.sh

diff --git a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml
index 21bc1c481b3e6..33fc9d94bac09 100644
--- a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml
@@ -200,8 +200,11 @@ stages:
 - stage: arm64_test
   dependsOn: ['arm64_build']
   jobs:
-  - template: templates/py-packaging-linux-test.yml
+  - template: templates/py-packaging-linux-test-cpu.yml
     parameters:
       arch: 'aarch64'
       machine_pool: 'onnxruntime-linux-ARM64-CPU-2019'
-      device: 'CPU'
+      base_image: 'arm64v8/almalinux:8'
+      devtoolset_rootpath: /opt/rh/gcc-toolset-12/root
+      ld_library_path_arg: /opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64
+      prepend_path: '/opt/rh/gcc-toolset-12/root/usr/bin:'
diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml
index c684e08ba1258..2161a9205f22d 100644
--- 
a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -3,24 +3,38 @@ resources: - pipeline: build source: 'Python packaging pipeline' trigger: true + branch: main # branch to pick the artifact, Used only for manual triggered pipeline runs for testing the pipeline itself + #TODO: Remove the following dependency. Running python tests should not need to use manylinux. + repositories: + - repository: manylinux # The name used to reference this repository in the checkout step + type: Github + endpoint: Microsoft + name: pypa/manylinux + ref: 5eda9aded5462201e6310105728d33016e637ea7 stages: - stage: Linux_Test_CPU_x86_64_stage jobs: - - template: templates/py-packaging-linux-test.yml + - template: templates/py-packaging-linux-test-cpu.yml parameters: arch: 'x86_64' machine_pool: 'onnxruntime-Ubuntu2004-AMD-CPU' - device: 'CPU' + base_image: 'registry.access.redhat.com/ubi8/ubi' + devtoolset_rootpath: /opt/rh/gcc-toolset-12/root + ld_library_path_arg: /opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64 + prepend_path: '/opt/rh/gcc-toolset-12/root/usr/bin:' - stage: Linux_Test_CPU_aarch64_stage dependsOn: [] jobs: - - template: templates/py-packaging-linux-test.yml + - template: templates/py-packaging-linux-test-cpu.yml parameters: arch: 'aarch64' machine_pool: 'aiinfra-linux-ARM64-CPU-2019' - device: 'CPU' + base_image: 'arm64v8/almalinux:8' + devtoolset_rootpath: /opt/rh/gcc-toolset-12/root + ld_library_path_arg: /opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64 + prepend_path: '/opt/rh/gcc-toolset-12/root/usr/bin:' - stage: Packages_Somking_Test dependsOn: [] @@ -31,19 +45,6 @@ stages: machine_pool: vmImage: 'macOS-13' itemPattern: '*/*mac*x86_64.whl' - - template: templates/py-package-smoking-test.yml - parameters: - job_name: Test_WIN_64_Wheels - itemPattern: '*/*win_amd64.whl' - machine_pool: - vmImage: 'windows-2022' - - template: templates/py-package-smoking-test.yml - parameters: - job_name: Test_WIN_32_Wheels - itemPattern: '*/*win32.whl' - python_arch: 'x86' - machine_pool: - vmImage: 'windows-2022' - template: templates/py-package-smoking-test.yml parameters: job_name: Test_LINUX_x86_64_Wheels @@ -61,7 +62,7 @@ stages: - Linux_Test_CPU_aarch64_stage - Packages_Somking_Test jobs: - - template: templates/py-packaging-linux-test.yml + - template: templates/py-packaging-linux-test-cuda.yml parameters: arch: 'x86_64' machine_pool: 'Onnxruntime-Linux-GPU' diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml index 796938dc22a67..15fcec0511741 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml @@ -68,7 +68,7 @@ jobs: script: | mkdir -p $HOME/.onnx docker run --rm -e CFLAGS="${{parameters.OnnxruntimeCFlags}}" -e CXXFLAGS="${{parameters.OnnxruntimeCXXFlags}}" --volume /data/onnx:/data/onnx:ro --volume $(Build.SourcesDirectory):/onnxruntime_src --volume $(Build.BinariesDirectory):/build \ - --volume $HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecpubuildcentos8${{parameters.OnnxruntimeArch}} /bin/bash -c "python3 \ + --volume 
$HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecpubuildcentos8${{parameters.OnnxruntimeArch}} /bin/bash -c "python3.9 \ /onnxruntime_src/tools/ci_build/build.py --build_java --build_nodejs --build_dir /build --config Release \ --skip_submodule_sync --parallel --build_shared_lib ${{ parameters.AdditionalBuildFlags }} && cd /build/Release && make install DESTDIR=/build/linux-${{parameters.OnnxruntimeArch}}" workingDirectory: $(Build.SourcesDirectory) diff --git a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml index cee3bd9c9e968..8d5ca19a73535 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml @@ -39,36 +39,22 @@ jobs: versionSpec: $(PythonVersion) architecture: ${{ parameters.python_arch }} - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact' - inputs: - artifactName: 'onnxruntime' - targetPath: '$(Build.BinariesDirectory)/whl' - itemPattern: ${{parameters.itemPattern}} - # The public ADO project - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: - buildType: current - # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - project: '530acbc4-21bc-487d-8cd8-348ff451d2ff' - definition: 841 - preferTriggeringPipeline: true - runVersion: 'latest' - buildType: specific + - download: build # pipeline resource identifier. + artifact: 'onnxruntime' - task: Bash@3 inputs: targetType: 'inline' script: | set -ex - files=(whl/*.whl) + files=(*.whl) FILE_NAME="${files[0]}" FILE_NAME=$(basename $FILE_NAME) PYTHON_PACKAGE_NAME=$(echo "$FILE_NAME" | cut -f 1 -d '-') - python3 -m pip install --find-links "$(Build.BinariesDirectory)/whl" $PYTHON_PACKAGE_NAME - pip show $PYTHON_PACKAGE_NAME - python -c "import onnxruntime as ort; print(ort.__version__)" - workingDirectory: $(Build.BinariesDirectory) + python3 -m pip install --find-links "$(Pipeline.Workspace)/build/onnxruntime" $PYTHON_PACKAGE_NAME + python3 -m pip show $PYTHON_PACKAGE_NAME + python3 -c "import onnxruntime as ort; print(ort.__version__)" + workingDirectory: $(Pipeline.Workspace)/build/onnxruntime displayName: Test Package Installation - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml new file mode 100644 index 0000000000000..cc90085e184dc --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml @@ -0,0 +1,117 @@ +parameters: +- name: arch + type: string + +- name: base_image + type: string + +- name: devtoolset_rootpath + type: string + +- name: ld_library_path_arg + type: string + +- name: prepend_path + type: string + +- name: machine_pool + type: string + +- name: extra_job_id + type: string + default: '' + +- name: python_wheel_suffix + type: string + default: '' + + +# TODO: Ideally it should fetch information from the build that triggers it +- name: cmake_build_type + type: string + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + +- name: timeout + type: number + default: 120 + +jobs: +- job: Linux_Test_CPU${{ parameters.extra_job_id }}_${{ parameters.arch }} + timeoutInMinutes: ${{ 
parameters.timeout }} + variables: + skipComponentGovernanceDetection: true + workspace: + clean: all + pool: ${{ parameters.machine_pool }} + steps: + - checkout: self + clean: true + submodules: none + # The public ADO project + - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: + - download: current # pipeline resource identifier. + artifact: 'drop-linux-cpu-${{ parameters.arch }}' + + - download: current # pipeline resource identifier. + artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}' + + - bash: | + set -e -x + mv "$(Pipeline.Workspace)/drop-linux-cpu-${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Pipeline.Workspace)/onnxruntime${{ parameters.python_wheel_suffix }}" "$(Build.BinariesDirectory)/whl" + cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp + find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; + # The private ADO project + - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: + - download: build # pipeline resource identifier. + artifact: 'drop-linux-cpu-${{ parameters.arch }}' + + - download: build # pipeline resource identifier. + artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}' + + - bash: | + set -e -x + ls $(Pipeline.Workspace)/build + mv "$(Pipeline.Workspace)/build/drop-linux-cpu-${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Pipeline.Workspace)/build/onnxruntime${{ parameters.python_wheel_suffix }}" "$(Build.BinariesDirectory)/whl" + cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp + find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; + + # The BinSkim task uses a dotnet program which doesn't support ARM CPUs yet + - ${{ if eq(parameters.arch, 'x86_64') }}: + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '$(Build.BinariesDirectory)/tmp/**/*.so' + continueOnError: true + + #- task: PostAnalysis@2 + # inputs: + # GdnBreakAllTools: true + # GdnBreakPolicy: M365 + # GdnBreakPolicyMinSev: Error + + - template: get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2_28_cpu + Context: tools/ci_build/github/linux/docker/inference/x64/python/cpu + DockerBuildArgs: "--build-arg POLICY=manylinux_2_28 --build-arg BUILD_UID=$( id -u ) --build-arg BASEIMAGE=${{ parameters.base_image }} --build-arg PLATFORM=${{ parameters.arch }} --build-arg PREPEND_PATH=${{ parameters.prepend_path }} --build-arg LD_LIBRARY_PATH_ARG=${{ parameters.ld_library_path_arg }} --build-arg DEVTOOLSET_ROOTPATH=${{ parameters.devtoolset_rootpath }}" + Repository: onnxruntimecpubuildpython${{ parameters.arch }} + ${{ if eq(parameters.arch, 'aarch64') }}: + UpdateDepsTxt: false + + - task: Bash@3 + displayName: 'Bash Script' + inputs: + targetType: filePath + filePath: tools/ci_build/github/linux/run_python_dockertest.sh + arguments: -d CPU -c ${{parameters.cmake_build_type}} -i onnxruntimecpubuildpython${{ parameters.arch }} + + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml new file mode 100644 index 
0000000000000..43ed0172825bc --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml @@ -0,0 +1,98 @@ +parameters: +- name: arch + type: string + +- name: device + type: string + values: + - CPU + - GPU + +- name: machine_pool + type: string + +- name: extra_job_id + type: string + default: '' + +- name: python_wheel_suffix + type: string + default: '' + + +# TODO: Ideally it should fetch information from the build that triggers it +- name: cmake_build_type + type: string + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + +- name: timeout + type: number + default: 120 + +jobs: +- job: Linux_Test_GPU${{ parameters.extra_job_id }}_${{ parameters.arch }} + timeoutInMinutes: ${{ parameters.timeout }} + variables: + skipComponentGovernanceDetection: true + workspace: + clean: all + pool: ${{ parameters.machine_pool }} + steps: + - checkout: self + clean: true + submodules: none + # The public ADO project + # - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: + + # The private ADO project + - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: + - download: build # pipeline resource identifier. + artifact: 'drop-linux-gpu-${{ parameters.arch }}' + + - download: build # pipeline resource identifier. + artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}' + + - bash: | + set -e -x + ls $(Pipeline.Workspace)/build + mv "$(Pipeline.Workspace)/build/drop-linux-gpu-${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Pipeline.Workspace)/build/onnxruntime${{ parameters.python_wheel_suffix }}" "$(Build.BinariesDirectory)/whl" + cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp + find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; + + # The BinSkim task uses a dotnet program which doesn't support ARM CPUs yet + - ${{ if eq(parameters.arch, 'x86_64') }}: + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '$(Build.BinariesDirectory)/tmp/**/*.so' + continueOnError: true + + #- task: PostAnalysis@2 + # inputs: + # GdnBreakAllTools: true + # GdnBreakPolicy: M365 + # GdnBreakPolicyMinSev: Error + + - template: get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: "--network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 --build-arg PREPEND_PATH=/usr/local/cuda/bin --build-arg LD_LIBRARY_PATH_ARG=/usr/local/lib64 --build-arg DEVTOOLSET_ROOTPATH=/usr --build-arg BUILD_UID=$( id -u ) --build-arg PLATFORM=${{ parameters.arch }}" + Repository: onnxruntimecuda118xtrt86build${{ parameters.arch }} + + - task: Bash@3 + displayName: 'Bash Script' + inputs: + targetType: filePath + filePath: tools/ci_build/github/linux/run_python_dockertest.sh + arguments: -d GPU -c ${{parameters.cmake_build_type}} -i onnxruntimecuda118xtrt86build${{ parameters.arch }} + + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test.yml deleted file mode 100644 index 8ddc917e8591e..0000000000000 --- 
a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test.yml +++ /dev/null @@ -1,85 +0,0 @@ -parameters: -- name: arch - type: string - -- name: device - type: string - -- name: machine_pool - type: string - -- name: extra_job_id - type: string - default: '' - -- name: python_wheel_suffix - type: string - default: '' - - -# TODO: Ideally it should fetch information from the build that triggers it -- name: cmake_build_type - type: string - default: 'Release' - values: - - Debug - - Release - - RelWithDebInfo - - MinSizeRel - -- name: timeout - type: number - default: 120 - -jobs: -- job: Linux_Test_${{ parameters.device }}${{ parameters.extra_job_id }}_${{ parameters.arch }} - timeoutInMinutes: ${{ parameters.timeout }} - variables: - skipComponentGovernanceDetection: true - workspace: - clean: all - pool: ${{ parameters.machine_pool }} - steps: - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact' - inputs: - artifactName: 'drop-linux-${{ lower(parameters.device) }}-${{ parameters.arch }}' - targetPath: '$(Build.BinariesDirectory)/${{parameters.cmake_build_type}}' - # The public ADO project - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: - buildType: current - # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - project: '530acbc4-21bc-487d-8cd8-348ff451d2ff' - definition: 841 - preferTriggeringPipeline: true - runVersion: 'latest' - buildType: specific - - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact' - inputs: - artifactName: 'onnxruntime${{ parameters.python_wheel_suffix }}' - targetPath: '$(Build.BinariesDirectory)/whl' - # The public ADO project - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: - buildType: current - # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - project: '530acbc4-21bc-487d-8cd8-348ff451d2ff' - definition: 841 - preferTriggeringPipeline: true - runVersion: 'latest' - buildType: specific - - - - task: Bash@3 - displayName: 'Bash Script' - inputs: - targetType: filePath - filePath: tools/ci_build/github/linux/run_python_tests.sh - arguments: -d ${{ parameters.device }} -c ${{parameters.cmake_build_type}} - - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu index a9a1e6b39a8cb..af87852561e0a 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu @@ -1,9 +1,9 @@ -ARG BASEIMAGE=amd64/almalinux:8 +ARG BASEIMAGE=registry.access.redhat.com/ubi8/ubi ARG POLICY=manylinux_2_28 ARG PLATFORM=x86_64 ARG DEVTOOLSET_ROOTPATH=/opt/rh/gcc-toolset-12/root ARG LD_LIBRARY_PATH_ARG=${DEVTOOLSET_ROOTPATH}/usr/lib64:${DEVTOOLSET_ROOTPATH}/usr/lib:${DEVTOOLSET_ROOTPATH}/usr/lib64/dyninst:${DEVTOOLSET_ROOTPATH}/usr/lib/dyninst:/usr/local/lib64 -ARG PREPEND_PATH=${DEVTOOLSET_ROOTPATH}/usr/bin: +ARG PREPEND_PATH=/usr/lib/jvm/msopenjdk-11/bin:${DEVTOOLSET_ROOTPATH}/usr/bin: #Build manylinux2014 docker image begin FROM $BASEIMAGE AS runtime_base @@ -26,7 +26,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` 
with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -35,7 +34,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -137,9 +135,7 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ - build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ build_scripts/requirements3.10.txt \ @@ -156,6 +152,7 @@ CMD ["/bin/bash"] #Build manylinux2014 docker image end ENV PATH ${DEVTOOLSET_ROOTPATH}/usr/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 index dab8df6703c4f..933b0211b0e6c 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -140,7 +138,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ @@ -156,7 +153,7 @@ ENV SSL_CERT_FILE=/opt/_internal/certs.pem CMD ["/bin/bash"] #Build manylinux2014 docker image end - +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 #Add our own dependencies ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_4 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_4 index 303e83eb23bca..003bb2324c049 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_4 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_4 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ 
@@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -140,7 +138,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ @@ -163,7 +160,7 @@ RUN v="8.4.1-1.cuda11.6" &&\ yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo &&\ yum -y install libnvinfer8-${v} libnvparsers8-${v} libnvonnxparsers8-${v} libnvinfer-plugin8-${v} \ libnvinfer-devel-${v} libnvparsers-devel-${v} libnvonnxparsers-devel-${v} libnvinfer-plugin-devel-${v} - +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 #Add our own dependencies ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_5 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_5 index d17e4b24582fe..0337ffc5e00a0 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_5 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_5 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -140,7 +138,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ @@ -163,7 +160,7 @@ RUN v="8.5.1-1.cuda11.8" &&\ yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo &&\ yum -y install libnvinfer8-${v} libnvparsers8-${v} libnvonnxparsers8-${v} libnvinfer-plugin8-${v} \ libnvinfer-devel-${v} libnvparsers-devel-${v} libnvonnxparsers-devel-${v} libnvinfer-plugin-devel-${v} - +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 #Add our own dependencies ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 index 3c0ac22e38b5a..2c953a10cbf64 100644 --- 
a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -147,7 +145,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.7.txt \ build_scripts/requirements3.8.txt \ @@ -171,7 +168,7 @@ RUN v="8.6.1.6-1.cuda11.8" &&\ yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo &&\ yum -y install libnvinfer8-${v} libnvparsers8-${v} libnvonnxparsers8-${v} libnvinfer-plugin8-${v} libnvinfer-vc-plugin8-${v}\ libnvinfer-devel-${v} libnvparsers-devel-${v} libnvonnxparsers-devel-${v} libnvinfer-plugin-devel-${v} libnvinfer-vc-plugin-devel-${v} libnvinfer-headers-devel-${v} libnvinfer-headers-plugin-devel-${v} - +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 #Add our own dependencies ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 index 326e15d58456a..09ab7951552a0 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -140,7 +138,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 index c211fa9b9e2b8..83a974469234f 100644 --- 
a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 @@ -7,40 +7,30 @@ # Build base image with required system packages FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 AS base -# The local directory into which to build and install CMAKE -ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code - -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.27.3-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update &&\ - apt-get install -y sudo git bash unattended-upgrades wget -RUN unattended-upgrade + apt-get install -y git bash wget # Install python3 RUN apt-get install -y --no-install-recommends \ python3 \ python3-pip \ python3-dev \ - python3-wheel &&\ - cd /usr/local/bin &&\ - ln -s /usr/bin/python3 python &&\ - ln -s /usr/bin/pip3 pip; + python3-wheel + RUN pip install --upgrade pip -RUN pip install setuptools>=41.0.0 # Install TensorRT RUN v="8.6.1.6-1+cuda11.8" &&\ apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub &&\ apt-get update &&\ - sudo apt-get install -y libnvinfer8=${v} libnvonnxparsers8=${v} libnvparsers8=${v} libnvinfer-plugin8=${v} libnvinfer-lean8=${v} libnvinfer-vc-plugin8=${v} libnvinfer-dispatch8=${v}\ + apt-get install -y libnvinfer8=${v} libnvonnxparsers8=${v} libnvparsers8=${v} libnvinfer-plugin8=${v} libnvinfer-lean8=${v} libnvinfer-vc-plugin8=${v} libnvinfer-dispatch8=${v}\ libnvinfer-headers-dev=${v} libnvinfer-headers-plugin-dev=${v} libnvinfer-dev=${v} libnvonnxparsers-dev=${v} libnvparsers-dev=${v} libnvinfer-plugin-dev=${v} libnvinfer-lean-dev=${v} libnvinfer-vc-plugin-dev=${v} libnvinfer-dispatch-dev=${v}\ python3-libnvinfer=${v} libnvinfer-samples=${v} tensorrt-dev=${v} tensorrt-libs=${v} -# Install Valgrind -RUN apt-get install -y valgrind - ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_dotnet.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_centos.sh index a1ade39e57e16..adb0464d6496a 100755 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_centos.sh +++ b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_centos.sh @@ -1,9 +1,8 @@ #!/bin/bash set -e -x -os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) +os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. 
-f1) echo "installing for CentOS version : $os_major_version" - -dnf install -y glibc-langpack-\* glibc-locale-source which gdb redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel java-11-openjdk-devel graphviz gcc-toolset-12-binutils gcc-toolset-12-gcc gcc-toolset-12-gcc-c++ gcc-toolset-12-gcc-gfortran -locale \ No newline at end of file +dnf install -y python39-devel glibc-langpack-\* glibc-locale-source which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel java-11-openjdk-devel graphviz gcc-toolset-12-binutils gcc-toolset-12-gcc gcc-toolset-12-gcc-c++ gcc-toolset-12-gcc-gfortran +locale diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh index 7ecd0525c7e7e..7598ab0a7a536 100755 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh +++ b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh @@ -14,20 +14,20 @@ function GetFile { echo "File '$path' already exists. Skipping download" return 0 else - rm -rf $path + rm -rf "$path" fi fi if [[ -f $uri ]]; then echo "'$uri' is a file path, copying file to '$path'" - cp $uri $path + cp "$uri" "$path" return $? fi echo "Downloading $uri" # Use aria2c if available, otherwise use curl if command -v aria2c > /dev/null; then - aria2c -q -d $(dirname $path) -o $(basename $path) "$uri" + aria2c -q -d "$(dirname $path)" -o "$(basename $path)" "$uri" else curl "$uri" -sSL --retry $download_retries --retry-delay $retry_wait_time_seconds --create-dirs -o "$path" --fail fi @@ -38,9 +38,10 @@ mkdir -p /tmp/src cd /tmp/src +CPU_ARCH=$(uname -m) echo "Installing cmake" -GetFile https://github.com/Kitware/CMake/releases/download/v3.27.3/cmake-3.27.3-linux-`uname -m`.tar.gz /tmp/src/cmake-3.27.3-linux-`uname -m`.tar.gz -tar -zxf /tmp/src/cmake-3.27.3-linux-`uname -m`.tar.gz --strip=1 -C /usr +GetFile "https://github.com/Kitware/CMake/releases/download/v3.27.3/cmake-3.27.3-linux-$CPU_ARCH.tar.gz" "/tmp/src/cmake.tar.gz" +tar -zxf /tmp/src/cmake.tar.gz --strip=1 -C /usr echo "Installing Ninja" GetFile https://github.com/ninja-build/ninja/archive/v1.10.0.tar.gz /tmp/src/ninja-linux.tar.gz @@ -52,7 +53,7 @@ mv ./build-cmake/ninja /usr/bin popd echo "Installing Node.js" -CPU_ARCH=`uname -m` + if [[ "$CPU_ARCH" = "x86_64" ]]; then NODEJS_ARCH=x64 elif [[ "$CPU_ARCH" = "aarch64" ]]; then @@ -64,16 +65,5 @@ fi GetFile https://nodejs.org/dist/v18.17.1/node-v18.17.1-linux-${NODEJS_ARCH}.tar.gz /tmp/src/node-v18.17.1-linux-${NODEJS_ARCH}.tar.gz tar --strip 1 -xf /tmp/src/node-v18.17.1-linux-${NODEJS_ARCH}.tar.gz -C /usr -# The Python version in CentOS 7's python3 package is no longer supported (3.6) so we will build Python from source. 
-echo "Installing Python" -PYTHON_VERSION="3.8.17" -GetFile https://www.python.org/ftp/python/${PYTHON_VERSION}/Python-${PYTHON_VERSION}.tgz /tmp/src/Python-${PYTHON_VERSION}.tgz -tar -zxf Python-${PYTHON_VERSION}.tgz -pushd Python-${PYTHON_VERSION} -./configure -make -make install -popd - cd / rm -rf /tmp/src diff --git a/tools/ci_build/github/linux/docker/inference/x64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x64/default/cpu/Dockerfile index 0324f377b8e9e..caf9583807b62 100644 --- a/tools/ci_build/github/linux/docker/inference/x64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x64/default/cpu/Dockerfile @@ -5,10 +5,10 @@ ARG BASEIMAGE=amd64/almalinux:8 FROM $BASEIMAGE -ENV PATH /opt/rh/gcc-toolset-12/root/usr/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin +ENV PATH /usr/lib/jvm/msopenjdk-11/bin:/opt/rh/gcc-toolset-12/root/usr/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 - +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/x64/default/cpu/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/x64/default/cpu/scripts/install_centos.sh index 8e18a237a807e..b5f8bf1a49a19 100755 --- a/tools/ci_build/github/linux/docker/inference/x64/default/cpu/scripts/install_centos.sh +++ b/tools/ci_build/github/linux/docker/inference/x64/default/cpu/scripts/install_centos.sh @@ -1,9 +1,9 @@ #!/bin/bash set -e -x -os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) +os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. 
-f1) echo "installing for CentOS version : $os_major_version" - -dnf install -y python39-devel glibc-langpack-\* glibc-locale-source which gdb redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel java-11-openjdk-devel graphviz gcc-toolset-12-binutils gcc-toolset-12-gcc gcc-toolset-12-gcc-c++ gcc-toolset-12-gcc-gfortran -locale \ No newline at end of file +rpm -Uvh https://packages.microsoft.com/config/centos/$os_major_version/packages-microsoft-prod.rpm +dnf install -y python39-devel glibc-langpack-\* glibc-locale-source which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel msopenjdk-11 graphviz gcc-toolset-12-binutils gcc-toolset-12-gcc gcc-toolset-12-gcc-c++ gcc-toolset-12-gcc-gfortran +locale diff --git a/tools/ci_build/github/linux/docker/inference/x64/default/gpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x64/default/gpu/Dockerfile index 386759890d085..318791072f46d 100644 --- a/tools/ci_build/github/linux/docker/inference/x64/default/gpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x64/default/gpu/Dockerfile @@ -4,8 +4,10 @@ # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline FROM nvidia/cuda:11.8.0-cudnn8-devel-ubi8 +ENV PATH /usr/lib/jvm/msopenjdk-11/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/x64/default/gpu/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/x64/default/gpu/scripts/install_centos.sh index 3cf259dc7240e..31e3e40f1b7ee 100755 --- a/tools/ci_build/github/linux/docker/inference/x64/default/gpu/scripts/install_centos.sh +++ b/tools/ci_build/github/linux/docker/inference/x64/default/gpu/scripts/install_centos.sh @@ -1,9 +1,9 @@ #!/bin/bash set -e -x -os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) +os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. 
-f1) echo "installing for CentOS version : $os_major_version" - -dnf install -y python39-devel python3-devel glibc-langpack-\* glibc-locale-source which gdb redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel java-11-openjdk-devel -locale \ No newline at end of file +rpm -Uvh https://packages.microsoft.com/config/centos/$os_major_version/packages-microsoft-prod.rpm +dnf install -y python39-devel glibc-langpack-\* glibc-locale-source which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel msopenjdk-11 +locale diff --git a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2_28_cpu b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2_28_cpu index 33660cbb3f2e5..06e75ee1a39f6 100644 --- a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2_28_cpu +++ b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2_28_cpu @@ -26,7 +26,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -35,7 +34,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -132,7 +130,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ diff --git a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/install_centos.sh index 98bb730a43776..c81e57c60c9da 100755 --- a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/install_centos.sh +++ b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/install_centos.sh @@ -1,11 +1,11 @@ #!/bin/bash set -e -os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) +os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. -f1) echo "installing for os major version : $os_major_version" dnf install -y glibc-langpack-\* -yum install -y which gdb redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget +yum install -y which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget # export PATH=/opt/python/cp38-cp38/bin:$PATH @@ -17,4 +17,4 @@ mkdir build cd build cmake .. cmake --install . -cd ../.. \ No newline at end of file +cd ../.. 
diff --git a/tools/ci_build/github/linux/docker/manylinux.patch b/tools/ci_build/github/linux/docker/manylinux.patch index f1821f9197525..75923e746f93c 100644 --- a/tools/ci_build/github/linux/docker/manylinux.patch +++ b/tools/ci_build/github/linux/docker/manylinux.patch @@ -94,7 +94,7 @@ index 9ef1e99..ec52833 100755 +fi \ No newline at end of file diff --git a/install-runtime-packages.sh b/install-runtime-packages.sh -index 137d2e2..4269afb 100755 +index 137d2e2..203b4bc 100755 --- a/install-runtime-packages.sh +++ b/install-runtime-packages.sh @@ -33,7 +33,7 @@ source $MY_DIR/build_utils.sh @@ -130,7 +130,7 @@ index 137d2e2..4269afb 100755 elif [ "${AUDITWHEEL_ARCH}" == "aarch64" ] || [ "${AUDITWHEEL_ARCH}" == "ppc64le" ] || [ "${AUDITWHEEL_ARCH}" == "s390x" ]; then # Software collection (for devtoolset-10) yum -y install centos-release-scl-rh -@@ -86,19 +88,18 @@ if [ "${AUDITWHEEL_POLICY}" == "manylinux2014" ]; then +@@ -86,19 +88,21 @@ if [ "${AUDITWHEEL_POLICY}" == "manylinux2014" ]; then fi elif [ "${AUDITWHEEL_POLICY}" == "manylinux_2_28" ]; then PACKAGE_MANAGER=dnf @@ -148,6 +148,9 @@ index 137d2e2..4269afb 100755 - TOOLCHAIN_DEPS="gcc-toolset-12-binutils gcc-toolset-12-gcc gcc-toolset-12-gcc-c++ gcc-toolset-12-gcc-gfortran" - if [ "${AUDITWHEEL_ARCH}" == "x86_64" ]; then - TOOLCHAIN_DEPS="${TOOLCHAIN_DEPS} yasm" ++ if test -f "/etc/yum.repos.d/ubi.repo"; then ++ sed -i 's/enabled\s*=\s*1/enabled = 1\nexclude=dotnet* aspnet* netstandard*/g' /etc/yum.repos.d/ubi.repo ++ fi + if [[ -d /usr/local/cuda ]]; then + TOOLCHAIN_DEPS="gcc gcc-c++" + else @@ -155,7 +158,7 @@ index 137d2e2..4269afb 100755 fi elif [ "${AUDITWHEEL_POLICY}" == "musllinux_1_1" ]; then TOOLCHAIN_DEPS="binutils gcc g++ gfortran" -@@ -121,12 +122,6 @@ else +@@ -121,12 +125,6 @@ else exit 1 fi diff --git a/tools/ci_build/github/linux/docker/scripts/install_dotnet.sh b/tools/ci_build/github/linux/docker/scripts/install_dotnet.sh index b9accb134b26d..c4689ed19c148 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_dotnet.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_dotnet.sh @@ -2,13 +2,15 @@ set -e -x if [ -f /etc/redhat-release ]; then - dnf update --refresh -y \ - && dnf install -y dotnet-sdk-6.0 + # If the following command succeeds but the dotnet command still reports that no SDK was found, it is most likely + # because the dotnet packages were installed from more than one dnf repo.
+ dnf install -y dotnet-sdk-6.0 dotnet-runtime-6.0 elif [ -f /etc/os-release ]; then # Get Ubuntu version - declare repo_version=$(if command -v lsb_release &> /dev/null; then lsb_release -r -s; else grep -oP '(?<=^VERSION_ID=).+' /etc/os-release | tr -d '"'; fi) + declare repo_version + repo_version=$(if command -v lsb_release &> /dev/null; then lsb_release -r -s; else grep -oP '(?<=^VERSION_ID=).+' /etc/os-release | tr -d '"'; fi) # Download Microsoft signing key and repository - wget https://packages.microsoft.com/config/ubuntu/$repo_version/packages-microsoft-prod.deb -O packages-microsoft-prod.deb + wget "https://packages.microsoft.com/config/ubuntu/$repo_version/packages-microsoft-prod.deb" -O packages-microsoft-prod.deb # Install Microsoft signing key and repository dpkg -i packages-microsoft-prod.deb # Clean up diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh index 4f544a50cb94d..63b953a95add6 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh @@ -1,17 +1,18 @@ #!/bin/bash set -e -os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) +os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. -f1) echo "installing for os major version : $os_major_version" if [ "$os_major_version" -gt 7 ]; then PACKAGE_MANAGER="dnf" - $PACKAGE_MANAGER install -y which gdb redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget + $PACKAGE_MANAGER install -y which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget else PACKAGE_MANAGER="yum" - $PACKAGE_MANAGER install -y which gdb redhat-lsb-core expat-devel tar unzip zlib-devel make libunwind bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget + $PACKAGE_MANAGER install -y which redhat-lsb-core expat-devel tar unzip zlib-devel make libunwind bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget fi +rpm -Uvh https://packages.microsoft.com/config/centos/$os_major_version/packages-microsoft-prod.rpm # Install Java # Install automatic documentation generation dependencies -$PACKAGE_MANAGER install -y java-11-openjdk-devel graphviz +$PACKAGE_MANAGER install -y msopenjdk-11 graphviz diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps.sh index a1cb4be5b72c9..8c79918120d8d 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps.sh @@ -3,18 +3,20 @@ set -e -x # Development tools and libraries if [ -f /etc/redhat-release ]; then - yum update && yum -y install graphviz - os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) + dnf -y install graphviz elif [ -f /etc/os-release ]; then apt-get update && apt-get install -y graphviz - os_major_version=$(cat /etc/os-release | tr -dc '0-9.'|cut -d \. -f1) else echo "Unsupported OS" exit 1 fi # Install dotnet -source $(cd "$(dirname "${BASH_SOURCE[0]}")/.." &> /dev/null && pwd)/install_dotnet.sh +LOCAL_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)" +PARENT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." &> /dev/null && pwd)" +# ShellCheck is unable to follow dynamic paths, such as source "$somedir/file". 
+# shellcheck disable=SC1091 +source "$PARENT_DIR/install_dotnet.sh" if [ ! -d "/opt/conda/bin" ]; then PYTHON_EXES=("/opt/python/cp38-cp38/bin/python3.8" "/opt/python/cp39-cp39/bin/python3.9" "/opt/python/cp310-cp310/bin/python3.10" "/opt/python/cp311-cp311/bin/python3.11") @@ -22,23 +24,17 @@ else PYTHON_EXES=("/opt/conda/bin/python") fi -SYS_LONG_BIT=$(getconf LONG_BIT) mkdir -p /tmp/src -GLIBC_VERSION=$(getconf GNU_LIBC_VERSION | cut -f 2 -d \.) - -if [[ $SYS_LONG_BIT = "64" ]]; then - LIBDIR="lib64" -else - LIBDIR="lib" -fi cd /tmp/src -source $(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)/install_shared_deps.sh +# shellcheck disable=SC1091 +source "$LOCAL_DIR/install_shared_deps.sh" cd /tmp/src if ! [ -x "$(command -v protoc)" ]; then - source ${0/%install_deps.sh/..\/install_protobuf.sh} +# shellcheck disable=SC1091 + source "$PARENT_DIR/install_protobuf.sh" fi export ONNX_ML=1 @@ -46,7 +42,7 @@ export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF" for PYTHON_EXE in "${PYTHON_EXES[@]}" do - ${PYTHON_EXE} -m pip install -r ${0/%install_deps\.sh/requirements\.txt} + ${PYTHON_EXE} -m pip install -r "${0/%install_deps\.sh/requirements\.txt}" done cd / diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_aten.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_aten.sh index ed220b487d06c..1f85f72aef423 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_aten.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_aten.sh @@ -11,7 +11,7 @@ else PYTHON_EXES=("/opt/conda/bin/python") fi -os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) +os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. -f1) SYS_LONG_BIT=$(getconf LONG_BIT) mkdir -p /tmp/src diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_eager.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_eager.sh index e141e0793a2bd..ad3366b0bb3b6 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_eager.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_eager.sh @@ -11,7 +11,7 @@ else PYTHON_EXES=("/opt/conda/bin/python") fi -os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) +os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. 
-f1) SYS_LONG_BIT=$(getconf LONG_BIT) mkdir -p /tmp/src diff --git a/tools/ci_build/github/linux/run_python_dockertest.sh b/tools/ci_build/github/linux/run_python_dockertest.sh new file mode 100755 index 0000000000000..332dd9c7284c0 --- /dev/null +++ b/tools/ci_build/github/linux/run_python_dockertest.sh @@ -0,0 +1,29 @@ +#!/bin/bash +set -e -x +BUILD_CONFIG="Release" + +while getopts "i:d:x:c:" parameter_Option +do case "${parameter_Option}" +in +i) DOCKER_IMAGE=${OPTARG};; +d) DEVICE=${OPTARG};; +c) BUILD_CONFIG=${OPTARG};; +esac +done + +if [ $DEVICE = "GPU" ]; then + ADDITIONAL_DOCKER_PARAMETER="--gpus all" +fi + +mkdir -p $HOME/.onnx +docker run --rm \ + --volume /data/onnx:/data/onnx:ro \ + --volume $BUILD_SOURCESDIRECTORY:/onnxruntime_src \ + --volume $BUILD_BINARIESDIRECTORY:/build \ + --volume /data/models:/build/models:ro \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + -w /onnxruntime_src \ + -e NIGHTLY_BUILD \ + -e BUILD_BUILDNUMBER \ + $ADDITIONAL_DOCKER_PARAMETER \ + $DOCKER_IMAGE tools/ci_build/github/linux/run_python_tests.sh -d $DEVICE -c $BUILD_CONFIG diff --git a/tools/ci_build/github/linux/run_python_tests.sh b/tools/ci_build/github/linux/run_python_tests.sh index c11ea42cd0541..f080c7e8c39d8 100755 --- a/tools/ci_build/github/linux/run_python_tests.sh +++ b/tools/ci_build/github/linux/run_python_tests.sh @@ -15,7 +15,8 @@ c) BUILD_CONFIG=${OPTARG};; esac done -cd $BUILD_BINARIESDIRECTORY +export PATH=/opt/python/cp38-cp38/bin:$PATH +cd /build files=(whl/*.whl) FILE_NAME="${files[0]}" FILE_NAME=$(basename $FILE_NAME) @@ -23,7 +24,7 @@ PYTHON_PACKAGE_NAME=$(echo "$FILE_NAME" | cut -f 1 -d '-') echo "Package name:$PYTHON_PACKAGE_NAME" -BUILD_ARGS="--build_dir $BUILD_BINARIESDIRECTORY --config $BUILD_CONFIG --test --skip_submodule_sync --parallel --enable_lto --build_wheel " +BUILD_ARGS="--build_dir /build --config $BUILD_CONFIG --test --skip_submodule_sync --parallel --enable_lto --build_wheel " ARCH=$(uname -m) @@ -35,20 +36,15 @@ if [ $BUILD_DEVICE == "GPU" ]; then BUILD_ARGS="$BUILD_ARGS --use_cuda --use_tensorrt --cuda_version=11.8 --tensorrt_home=/usr --cuda_home=/usr/local/cuda-11.8 --cudnn_home=/usr/local/cuda-11.8" fi # We assume the machine doesn't have gcc and python development header files, so we don't build onnxruntime from source -sudo rm -rf /build /onnxruntime_src -sudo ln -s $BUILD_SOURCESDIRECTORY /onnxruntime_src python3 -m pip install --upgrade pip -python3 -m pip uninstall -y $PYTHON_PACKAGE_NAME ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml ort-nightly-directml onnx -qq # Install the packages that are needed for installing the onnxruntime python package -python3 -m pip install -r $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/requirements.txt +python3 -m pip install -r /build/$BUILD_CONFIG/requirements.txt # Install the packages that are needed for running test scripts -# Install the latest ONNX release which may contain not fixed bugs. However, it is what most people use. -python3 -m pip install onnx pytest +python3 -m pip install pytest # The "--no-index" flag is crucial. The local whl folder is just an additional source. 
Pypi's doc says "there is no # ordering in the locations that are searched" if we don't disable the default one with "--no-index" -python3 -m pip install --no-index --find-links $BUILD_BINARIESDIRECTORY/whl $PYTHON_PACKAGE_NAME -ln -s /data/models $BUILD_BINARIESDIRECTORY -cd $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG +python3 -m pip install --no-index --find-links /build/whl $PYTHON_PACKAGE_NAME +cd /build/$BUILD_CONFIG # Restore file permissions xargs -a perms.txt chmod a+x -python3 $BUILD_SOURCESDIRECTORY/tools/ci_build/build.py $BUILD_ARGS --ctest_path '' +python3 /onnxruntime_src/tools/ci_build/build.py $BUILD_ARGS --ctest_path '' diff --git a/tools/scripts/python_test.sh b/tools/scripts/python_test.sh old mode 100644 new mode 100755 diff --git a/tools/scripts/symbolic_shape_infer_test.sh b/tools/scripts/symbolic_shape_infer_test.sh old mode 100644 new mode 100755 From 5b9cd91a9cddbe7c461c1ad7ca44edd5111ea920 Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Thu, 21 Sep 2023 22:37:50 +0800 Subject: [PATCH 062/121] [ROCm] fix CI (#17648) Fix CI, follow-up to #17621. --- .../github/azure-pipelines/orttraining-pai-ci-pipeline.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml index 3333a7d22a41b..8dd1f0c5c6461 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml @@ -222,7 +222,7 @@ jobs: clean: all pool: AMD-GPU dependsOn: - - Linux_Build + - Linux_Build_ubuntu timeoutInMinutes: 120 steps: From f299016cbe87a5341e0a8aa69b621555c9d49a35 Mon Sep 17 00:00:00 2001 From: George Nash Date: Thu, 21 Sep 2023 09:25:41 -0700 Subject: [PATCH 063/121] Fix crash on Windows server 2016 on Intel Gen4 Xeon processors (#17611) ### Description This adds an additional check before enabling MlasGemmU8S8DispatchAmx for GEMM operations. After checking the CPUID for AMX-TILE and AMX-INT8, an additional check is added that checks the value of the XCR0 register. The value in the XCR0 register is set by the OS and indicates support for various CPU features; in this case the bits indicating XTILECFG and XTILEDATA support are checked. ### Motivation and Context Fix for a crash reported directly by a customer when running an older Windows Server OS on newer Gen4 Xeon processors.
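For illustration only, here is a minimal standalone sketch of this kind of OS-level gate. It is not the exact MLAS code: the helper name `OsEnablesAmxTileState` is invented for the example, and it assumes CPUID has already confirmed OSXSAVE support (so `xgetbv`/`_xgetbv` is safe to execute). The bit positions mirror the XFEATURE_* macros added by the patch.

```cpp
#include <cstdint>
#if defined(_MSC_VER)
#include <immintrin.h>  // _xgetbv
#endif

// XCR0 bits for AMX tile state; values mirror the patch's XFEATURE_* macros.
constexpr uint64_t kXTileCfgMask = 1ull << 17;   // XTILECFG
constexpr uint64_t kXTileDataMask = 1ull << 18;  // XTILEDATA

// Returns true only when the OS saves/restores AMX tile registers.
// Without this, AMX instructions fault even though CPUID advertises
// AMX-TILE and AMX-INT8, which is the crash mode described above.
inline bool OsEnablesAmxTileState() {
#if defined(_MSC_VER)
  const uint64_t xcr0 = _xgetbv(0);  // 0 == _XCR_XFEATURE_ENABLED_MASK
#else
  uint32_t eax, edx;
  __asm__ volatile("xgetbv" : "=a"(eax), "=d"(edx) : "c"(0));
  const uint64_t xcr0 = (static_cast<uint64_t>(edx) << 32) | eax;
#endif
  return (xcr0 & (kXTileCfgMask | kXTileDataMask)) ==
         (kXTileCfgMask | kXTileDataMask);
}
```

The actual change enables the AMX dispatch only when both the CPUID feature bits and this XCR0 condition hold, as the platform.cpp hunk below shows.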
Signed-off-by: Nash --- onnxruntime/core/mlas/lib/platform.cpp | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 7e2b117d6f249..96bc1d8010bed 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -112,6 +112,14 @@ MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const int16_t MlasOpmask16BitTableAvx512[ #define _XCR_XFEATURE_ENABLED_MASK 0 #endif +#if !defined(XFEATURE_MASK_XTILE) +#define XFEATURE_XTILECFG 17 +#define XFEATURE_XTILEDATA 18 +#define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG) +#define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA) +#define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA) +#endif + inline uint64_t MlasReadExtendedControlRegister( @@ -142,11 +150,6 @@ bool MlasInitAMX() { #if defined(__linux__) -#define XFEATURE_XTILECFG 17 -#define XFEATURE_XTILEDATA 18 -#define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG) -#define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA) -#define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA) #define ARCH_GET_XCOMP_PERM 0x1022 #define ARCH_REQ_XCOMP_PERM 0x1023 @@ -417,7 +420,9 @@ Return Value: // Check if the processor supports AMX-TILE and AMX-INT8 // features. // - if ((Cpuid7[3] & 0b1 << 24) != 0 && (Cpuid7[3] & 0b1 << 25) != 0) { + if ((Cpuid7[3] & 0b1 << 24) != 0 && + (Cpuid7[3] & 0b1 << 25) != 0 && + (xcr0 & XFEATURE_MASK_XTILE) == XFEATURE_MASK_XTILE) { if (MlasInitAMX()) { this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAmx; this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAmx; From d56fc7ebf5377abc96db728eafaffd8bf79a3b81 Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Thu, 21 Sep 2023 14:16:41 -0700 Subject: [PATCH 064/121] Layer norm fusion deepspeed stage3 changes (#17614) ### Description Layer norm fusion changes required for DeepSpeed stage 3; a test case is included. ### Motivation and Context Enables LayerNorm fusion for DeepSpeed stage 3. A test case is added to verify that the fusion works correctly in this scenario.
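To make the operand-selection change below easier to follow, here is a hedged sketch of the relaxed logic (the helper name `PickOperandByRank` is invented for this example; the real transformer inlines the loop per node). The old code only accepted a scale/bias operand that was a constant initializer or a graph input; the new code accepts any input whose shape has the expected rank, which covers DeepSpeed stage-3 graphs where scale and bias arrive through an Identity node.

```cpp
#include "core/graph/graph.h"  // Node, NodeArg (onnxruntime)

namespace onnxruntime {
// Sketch: pick the Mul/Add input whose rank matches the reduced axes.
// Unlike the old logic, this does not insist the operand is a constant
// initializer or graph input, so values produced by other nodes match.
NodeArg* PickOperandByRank(Node& node, int expected_rank) {
  for (NodeArg* def : node.MutableInputDefs()) {
    if (def->Shape() == nullptr) {
      continue;  // unknown shape: cannot validate rank, skip
    }
    if (def->Shape()->dim_size() == expected_rank) {
      return def;
    }
  }
  return nullptr;  // caller abandons the fusion in this case
}
}  // namespace onnxruntime
```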
--- .../core/optimizer/layer_norm_fusion.cc | 42 ++++----- .../graph_transform_test_layernorm.cc | 34 ++++++++ .../fusion/layer_norm_fusion_scale_bias.onnx | Bin 0 -> 854 bytes .../fusion/layer_norm_fusion_scale_bias.py | 81 ++++++++++++++++++ 4 files changed, 136 insertions(+), 21 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.py diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index bf36f11521be2..159e3b23d1ab0 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -414,20 +414,20 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, NodeArg* scale = nullptr; NodeArg* bias = nullptr; for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) { - if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) || - graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) { - if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { - scale = mul_node.MutableInputDefs()[i]; - } + if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) { + continue; + } + if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { + scale = mul_node.MutableInputDefs()[i]; } } for (size_t i = 0; i < last_add_node.MutableInputDefs().size(); i++) { - if (graph_utils::NodeArgIsConstant(graph, *(last_add_node.MutableInputDefs()[i])) || - graph_utils::IsGraphInput(graph, last_add_node.MutableInputDefs()[i])) { - if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { - bias = last_add_node.MutableInputDefs()[i]; - } + if (last_add_node.MutableInputDefs()[i]->Shape() == nullptr) { + continue; + } + if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { + bias = last_add_node.MutableInputDefs()[i]; } } if (scale == nullptr || bias == nullptr) { @@ -667,20 +667,20 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr // because SkipLayerNorm kernel, for example, has dependency on single dim size NodeArg* scale = nullptr; for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) { - if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) || - graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) { + if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) { + continue; + } #ifdef ENABLE_TRAINING_CORE - if (axes_values.empty() || - mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { - scale = mul_node.MutableInputDefs()[i]; - } + if (axes_values.empty() || + mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { + scale = mul_node.MutableInputDefs()[i]; + } #else - // Scale must be 1d. - if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) { - scale = mul_node.MutableInputDefs()[i]; - } -#endif + // Scale must be 1d. 
+ if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) { + scale = mul_node.MutableInputDefs()[i]; } +#endif } if (scale == nullptr) { diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc index 1f671e90090ba..a55238396cea3 100755 --- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc @@ -429,6 +429,40 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) { } } +// It tests the scenario where scale or bias are not graph inputs and are not initializers in the graph. +// To test this, an Identity node is added after the Scale and Bias terms to ensure LayerNormFusion works properly. +TEST_F(GraphTransformationTests, LayerNormScaleBiasTest) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_fusion_scale_bias.onnx"; + std::shared_ptr<Model> p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map<std::string, int> op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["ReduceMean"], 0); + ASSERT_EQ(op_to_count["Sub"], 0); + ASSERT_EQ(op_to_count["Cast"], 0); + ASSERT_EQ(op_to_count["Pow"], 0); + ASSERT_EQ(op_to_count["Add"], 0); + ASSERT_EQ(op_to_count["Sqrt"], 0); + ASSERT_EQ(op_to_count["Div"], 0); + ASSERT_EQ(op_to_count["Mul"], 0); + ASSERT_EQ(op_to_count["LayerNormalization"], 1); + + for (const Node& node : graph.Nodes()) { + if (node.OpType() == "LayerNormalization") { + // LayerNormalization should have three inputs. + EXPECT_EQ(node.InputDefs().size(), 3u) << "LayerNormalization number of inputs does not equal to 3. Got:" << node.InputDefs().size(); + // LayerNormalization input "scale" and "bias" should have the same dimension. + const TensorShapeProto* scale_shape = node.InputDefs()[1]->Shape(); + EXPECT_EQ(scale_shape->dim_size(), 1) << "LayerNormalization scale should be 1D. Got: " << scale_shape->dim_size(); + } + } +} + // If EP is non-GPU EP or unknown, the sub-graph will be not fused because CPU impl for SimplifiedLayerNormalization // doesn't support input and scale having different data types.
TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTest) { diff --git a/onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.onnx b/onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ec0f9a97815b888701198d94c92e3f61c581d6dc GIT binary patch literal 854 zcmbVLO;6h}7_J-BxG#kTYZE93gnT22m1Yu$ooG5~nzW*c-nc|*?V^aLfqd|B%VCEd z_!0b!9ry|SC$N(kk+Bn&96fr!@;r}i(*63k1K$A+X(!=+oM*O~2%gWxfWb)##v)ic z9{~q9B0YN23*95r`2gfxhzlM@>6Q$%VOtJ@dJr|!d|FO4Bw)rQpTZvWWxz-=(HP>i32O%=i^w!!hU}H52Z>Cg;9~+&<_rur`aAl7<+#{``weNx=D_oR1Y^ z#*lN^ftN5P>1C2t1qv}dk>9s!bQLvucvY#9fEnMyE9gV-EQq4O4@;YCBkDS8M){&@ zkboKEd;zSzP?b?MvK3XgqUwV7nl}8kiFTXek@Vf^LOYAAqdEXhvhLB8 zJ7tgC=m2%NeOM_K(1s9uUCR>7EX-~h`N1m$Ln*TIxg<{C$gn@X&P!+h9YMQ4gIkdt z$4TT+3o+bkwT`@(TjOl1*yF?9p4U84XNzD99K0cy*C0`6R*RdWK*euV{6StN>vU7S p0tyxZ+JiQ+ Date: Fri, 22 Sep 2023 01:52:13 +0400 Subject: [PATCH 065/121] [js/web] fp16 Pool & Reduce (#17512) ### Description Two more ops to support fp16 --- js/web/lib/wasm/jsep/webgpu/ops/pool.ts | 6 +- js/web/lib/wasm/jsep/webgpu/ops/reduce.ts | 14 +- .../providers/js/js_execution_provider.cc | 256 +++++++++--------- .../core/providers/js/operators/pool.cc | 112 ++++---- .../core/providers/js/operators/pool.h | 8 +- .../core/providers/js/operators/reduce.cc | 28 +- .../core/providers/js/operators/reduce.h | 2 +- 7 files changed, 206 insertions(+), 220 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts index 8c8c12fc54ddb..120a0e9de5490 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {PoolConvUtil, ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -22,9 +21,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (inputs[0].dims.length !== 4) { throw new Error('Pool ops supports 2-D inputs only for now.'); } - if (inputs[0].dataType !== DataType.float) { - throw new Error('Invalid input type.'); - } }; const getAdjustedPoolAttributesAndOutputShape = ( @@ -248,7 +244,7 @@ const createAveragePoolProgramInfo = const kernelSize = ShapeUtil.size(adjustedAttributes.kernelShape); const x = inputVariable('x', input.dataType, input.dims); - const dataType = 'f32'; + const dataType = x.type.value; const op1 = 'value += x_val;'; let op2 = ''; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts index 0b8d03ea73b6b..598b1db033c61 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts @@ -17,10 +17,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (inputs.length === 2 && inputs[1].dims.length !== 1) { throw new Error('Invalid axes input dims.'); } - - if (inputs[0].dataType !== DataType.float) { - throw new Error('Invalid input type.'); - } }; export interface ReduceAttributes extends AttributeWithCacheKey { @@ -161,7 +157,7 @@ export const reduceL1 = (context: ComputeContext, attributes: ReduceAttributes): export const reduceL2 = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => - [`var t = f32(0); var value = ${output.type.storage}(0);`, + [`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`, '', `t = ${input.getByOffset('inputOffset')}; value += (t * t);`, 'value = sqrt(value);', @@ -212,10 +208,10 @@ export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes } return [ - `var value = ${output.type.storage}(0);`, + 'var sum = f32(0);', '', - `value += ${input.getByOffset('inputOffset')};`, - `value = value / ${size}.;`, + `sum += f32(${input.getByOffset('inputOffset')});`, + `let value = ${output.type.value}(sum / ${size});`, ]; }; context.compute(createReduceProgramInfoLoader(context.inputs, 'ReduceMean', attributes, reduceOp), {inputs: [0]}); @@ -266,7 +262,7 @@ export const reduceSum = (context: ComputeContext, attributes: ReduceAttributes) export const reduceSumSquare = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => - [`var t = f32(0); var value = ${output.type.storage}(0);`, + [`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`, '', `t = ${input.getByOffset('inputOffset')}; value += t * t;`, '', diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 0674fe02d093d..72e36a161e9aa 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -129,56 +129,56 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, Rel class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 15, LeakyRelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, LeakyRelu); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceMax); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceMean); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceMin); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceProd); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ReduceSum); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceL1); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceL2); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSum); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceLogSum); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceSumSquare); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceMax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceMax); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceMean); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceMean); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceMean); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceMean); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceMin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceMin); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceProd); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceProd); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceProd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceProd); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceSum); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, ReduceSum); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceL1); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceL1); +class 
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceL1); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceL1); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceL2); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceL2); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceL2); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceL2); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceLogSum); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceLogSum); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceSumSquare); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceSumSquare); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceLogSumExp); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceLogSumExp); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, ThresholdedRelu); @@ -234,11 +234,11 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Tra class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, float, Conv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, float, ConvTranspose); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 11, float, MaxPool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 12, float, MaxPool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, float, AveragePool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, float, GlobalAveragePool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, float, GlobalMaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalMaxPool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, Conv); class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Conv); @@ -251,16 +251,16 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gem class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 9, float, AveragePool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, float, AveragePool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, AveragePool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, float, GlobalAveragePool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 7, float, MaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 9, float, MaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, float, MaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, float, MaxPool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, MaxPool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, float, GlobalMaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 9, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalAveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 7, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 9, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalMaxPool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMax); @@ -438,71 +438,71 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -515,16 +515,16 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/pool.cc b/onnxruntime/core/providers/js/operators/pool.cc index 03e6caef7e5b8..7fdb4e5d114ea 100644 --- a/onnxruntime/core/providers/js/operators/pool.cc +++ b/onnxruntime/core/providers/js/operators/pool.cc @@ -8,69 +8,65 @@ namespace onnxruntime { namespace js { -#define POOLING_KERNEL(op_name, domain, is_channels_last, data_type, pool_type, since_version) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - op_name, \ - domain, \ - since_version, \ - data_type, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Pool); +#define POOLING_KERNEL(op_name, domain, is_channels_last, pool_type, since_version) \ + ONNX_OPERATOR_KERNEL_EX( \ + op_name, \ + domain, \ + since_version, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), \ + Pool); -#define POOLING_KERNEL_VERSIONED(op_name, domain, is_channels_last, data_type, pool_type, since_version, end_version) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - op_name, \ - 
domain, \ - since_version, \ - end_version, \ - data_type, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Pool); +#define POOLING_KERNEL_VERSIONED(op_name, domain, is_channels_last, pool_type, since_version, end_version) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + op_name, \ + domain, \ + since_version, \ + end_version, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", JsepSupportedFloatTypes()), \ + Pool); -#define POOLING_KERNEL_WITH_INDICES(op_name, domain, is_channels_last, data_type, pool_type, since_version) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - op_name, \ - domain, \ - since_version, \ - data_type, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ - Pool); +#define POOLING_KERNEL_WITH_INDICES(op_name, domain, is_channels_last, pool_type, since_version) \ + ONNX_OPERATOR_KERNEL_EX( \ + op_name, \ + domain, \ + since_version, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", JsepSupportedFloatTypes()) \ + .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ + Pool); -#define POOLING_KERNEL_VERSIONED_WITH_INDICES(op_name, domain, is_channels_last, data_type, pool_type, since_version, end_version) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - op_name, \ - domain, \ - since_version, \ - end_version, \ - data_type, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ - Pool); +#define POOLING_KERNEL_VERSIONED_WITH_INDICES(op_name, domain, is_channels_last, pool_type, since_version, end_version) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + op_name, \ + domain, \ + since_version, \ + end_version, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", JsepSupportedFloatTypes()) \ + .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ + Pool); -POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, float, AveragePool, 7, 9) -POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, float, AveragePool, 10, 10) -POOLING_KERNEL(AveragePool, kOnnxDomain, false, float, AveragePool, 11) -POOLING_KERNEL(AveragePool, kMSInternalNHWCDomain, true, float, AveragePool, 11) -POOLING_KERNEL(GlobalAveragePool, kOnnxDomain, false, float, AveragePool, 1) -POOLING_KERNEL(GlobalAveragePool, kMSInternalNHWCDomain, true, float, AveragePool, 1) +POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 7, 9) +POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 10, 10) +POOLING_KERNEL(AveragePool, kOnnxDomain, false, AveragePool, 11) +POOLING_KERNEL(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 11) +POOLING_KERNEL(GlobalAveragePool, kOnnxDomain, false, AveragePool, 1) +POOLING_KERNEL(GlobalAveragePool, kMSInternalNHWCDomain, true, AveragePool, 1) -POOLING_KERNEL_VERSIONED(MaxPool, kOnnxDomain, false, float, MaxPool<1>, 1, 7) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, float, MaxPool<8>, 8, 9) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, float, MaxPool<8>, 10, 10) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, float, MaxPool<8>, 11, 11) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, float, MaxPool<8>, 11, 11) -POOLING_KERNEL_WITH_INDICES(MaxPool, kOnnxDomain, 
false, float, MaxPool<8>, 12) -POOLING_KERNEL_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, float, MaxPool<8>, 12) -POOLING_KERNEL(GlobalMaxPool, kOnnxDomain, false, float, MaxPool<1>, 1) -POOLING_KERNEL(GlobalMaxPool, kMSInternalNHWCDomain, true, float, MaxPool<1>, 1) +POOLING_KERNEL_VERSIONED(MaxPool, kOnnxDomain, false, MaxPool<1>, 1, 7) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 8, 9) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 10, 10) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 11, 11) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, MaxPool<8>, 11, 11) +POOLING_KERNEL_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 12) +POOLING_KERNEL_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, MaxPool<8>, 12) +POOLING_KERNEL(GlobalMaxPool, kOnnxDomain, false, MaxPool<1>, 1) +POOLING_KERNEL(GlobalMaxPool, kMSInternalNHWCDomain, true, MaxPool<1>, 1) } // namespace js } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/pool.h b/onnxruntime/core/providers/js/operators/pool.h index 5dbe5d0b8881d..5723123c0c3b8 100644 --- a/onnxruntime/core/providers/js/operators/pool.h +++ b/onnxruntime/core/providers/js/operators/pool.h @@ -41,7 +41,7 @@ namespace js { #define GLOBAL_POOL_ATTRIBUTES_JS_OBJ_MAPPING ({"format" : $1 ? "NHWC" : "NCHW"}) #define GLOBAL_POOL_ATTRIBUTES_PARAM_LIST static_cast(is_channels_last) -template +template class Pool : public JsKernel, public PoolBase { public: Pool(const OpKernelInfo& info) : JsKernel(info), PoolBase(info) { @@ -65,10 +65,10 @@ class Pool : public JsKernel, public PoolBase { } }; -template -class Pool, is_channels_last> final : public Pool, is_channels_last> { +template +class Pool, is_channels_last> final : public Pool, is_channels_last> { public: - Pool(const OpKernelInfo& info) : Pool, is_channels_last>(info) {} + Pool(const OpKernelInfo& info) : Pool, is_channels_last>(info) {} }; } // namespace js diff --git a/onnxruntime/core/providers/js/operators/reduce.cc b/onnxruntime/core/providers/js/operators/reduce.cc index 21854fccc37ca..2679cfed86124 100644 --- a/onnxruntime/core/providers/js/operators/reduce.cc +++ b/onnxruntime/core/providers/js/operators/reduce.cc @@ -7,32 +7,30 @@ namespace onnxruntime { namespace js { #define REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceOp, sinceVersion, endVersion) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ ReduceOp, \ kOnnxDomain, \ sinceVersion, endVersion, \ - float, \ kJsExecutionProvider, \ (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ReduceOp); + .TypeConstraint("T", JsepSupportedFloatTypes()), \ + ReduceOp); // macro REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL does not set .InputMemoryType(OrtMemTypeCPU, 1), so in future if // a new opset version update applies to Reduce* operators, we may need to add another macro like // REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT to set input memory type. // i.e. we cannot use REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL to version 18 when the opset version is increased. 
-#define REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceOp, sinceVersion) \
-  ONNX_OPERATOR_TYPED_KERNEL_EX( \
-      ReduceOp, \
-      kOnnxDomain, \
-      sinceVersion, \
-      float, \
-      kJsExecutionProvider, \
-      (*KernelDefBuilder::Create()) \
-          .TypeConstraint("T", DataTypeImpl::GetTensorType()) \
-          .InputMemoryType(OrtMemTypeCPU, 1), \
-      ReduceOp);
+#define REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceOp, sinceVersion) \
+  ONNX_OPERATOR_KERNEL_EX( \
+      ReduceOp, \
+      kOnnxDomain, \
+      sinceVersion, \
+      kJsExecutionProvider, \
+      (*KernelDefBuilder::Create()) \
+          .TypeConstraint("T", JsepSupportedFloatTypes()) \
+          .InputMemoryType(OrtMemTypeCPU, 1), \
+      ReduceOp);

 REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 1, 10);
 REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 11, 12);
diff --git a/onnxruntime/core/providers/js/operators/reduce.h b/onnxruntime/core/providers/js/operators/reduce.h
index 19a6d298c7696..a5a4aa834c2ca 100644
--- a/onnxruntime/core/providers/js/operators/reduce.h
+++ b/onnxruntime/core/providers/js/operators/reduce.h
@@ -9,7 +9,7 @@ namespace onnxruntime {
 namespace js {

 #define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \
-  template \
+  template \
   class ReduceKernel : public JsKernel, public ReduceKernelBase { \
    public: \
     using ReduceKernelBase::axes_; \

From 6b7bce5ec992f2b3333ee22066201f53e7978faf Mon Sep 17 00:00:00 2001
From: pengwa
Date: Fri, 22 Sep 2023 08:54:25 +0800
Subject: [PATCH 066/121] Model post process for zero stage3 training (#17187)

### Model post process for zero stage3 training

This is the last change needed to make both single-GPU and multi-GPU runs pass.

Design details:
https://microsoft.sharepoint.com/:p:/t/ONNX2/EfNfJ43necpIoPI6x5M2zvYBVbfjoPQmG4Boc_F7-tHm1w?e=ekQwA6&nav=eyJzSWQiOjMxNiwiY0lkIjoxMDE1Nzg3NDZ9

`PyTorch` runs with ZeROOffloadSubscriber:
```
model = prepare_model(...)

from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3
configure_ort_compatible_zero_stage3()
```

`ORTModule` runs with ZeROOffloadSubscriber:
```
os.environ['ORTMODULE_ENABLE_ZERO_STAGE3'] = '1'

from onnxruntime.training.ortmodule import ORTModule
model = ORTModule(self.model)
```

Debugging convergence issues becomes much easier when both ORT and PyTorch can run the same offload path.
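As an illustrative sketch (not part of the original PR text): the `debug`, `stats_output_dir`, and `stats_overwrite` parameters this change adds to `configure_ort_compatible_zero_stage3` let the PyTorch-side run dump per-module statistics that can then be diffed against an ORTModule run; `prepare_model` is the same hypothetical helper used above.
```
# A minimal sketch, assuming the signature introduced in this change:
#   configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="./", stats_overwrite=False)
# With debug=True, a StatisticsSubscriber is attached next to the ZeROOffloadSubscriber and
# writes per-module activation statistics under stats_output_dir for step-by-step comparison.
from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3

model = prepare_model(...)  # hypothetical setup helper
configure_ort_compatible_zero_stage3(debug=True, stats_output_dir="pt_output", stats_overwrite=True)
```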
### Motivation and Context --- .../_custom_autograd_function_exporter.py | 28 +- .../_custom_autograd_function_runner.py | 10 + .../ortmodule/_graph_execution_manager.py | 62 +++- .../training/ortmodule/_inference_manager.py | 4 + .../python/training/ortmodule/_io.py | 8 +- .../training/ortmodule/_training_manager.py | 4 + .../ortmodule/_zero_stage3_compatibility.py | 312 ++++++++++++++++++ .../python/training/utils/__init__.py | 3 +- .../utils/hooks/_statistics_subscriber.py | 171 +++++----- .../utils/hooks/_subscriber_manager.py | 17 +- .../utils/hooks/_zero_offload_subscriber.py | 155 ++++++--- .../python/training/utils/torch_type_map.py | 9 + .../torch_custom_function_kernel_base.cc | 7 +- 13 files changed, 619 insertions(+), 171 deletions(-) create mode 100644 orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 4c72b6d98a088..f75d553a5f460 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -28,7 +28,8 @@ class PythonOpShapeInferStore: @classmethod def register(cls, kclass: torch.autograd.Function) -> None: - """Register a shape inference function for a torch.autograd.Function if there is staticmethod "infer_shape" defined. + """Register a shape inference function for a torch.autograd.Function if there is staticmethod + "infer_shape" defined. The signature of the shape inference function should be: @staticmethod @@ -51,6 +52,11 @@ def infer_shape( if hasattr(kclass, "infer_shape") and kclass_name not in cls._CLASS_MAP: cls._CLASS_MAP[kclass_name] = kclass.infer_shape + @classmethod + def register_func(cls, name: str, func: Callable) -> None: + """Register a shape inference function for a torch.autograd.Function by name.""" + cls._CLASS_MAP[name] = func + @classmethod def get_shape_infer(cls, name: str) -> Optional[Callable]: return cls._CLASS_MAP.get(name, None) @@ -228,9 +234,9 @@ def _export_pt_1_10(g, n, *args, **kwargs): input_float_tuples.extend(list(arg)) continue - is_inspect_activation = ( - func_full_qual_name == "onnxruntime.training.utils.hooks._subscriber_manager._InspectActivation" - ) + from onnxruntime.training.utils.hooks._statistics_subscriber import _InspectActivation + + is_inspect_activation = func_full_qual_name == get_fully_qualified_class_name(_InspectActivation) if is_inspect_activation and isinstance(arg, str): # _InspectActivation is a special case where the first argument is a string # that is used to determine the activation name to be inspected. @@ -307,14 +313,7 @@ def _export_pt_1_10(g, n, *args, **kwargs): _export = wrap_custom_export_function(_export_pt_1_10) -def _post_process_after_export(exported_model: ModelProto, enable_custom_autograd_function: bool) -> ModelProto: - """Post process the exported model.""" - if enable_custom_autograd_function: - exported_model = _post_process_enabling_autograd_function(exported_model) - return exported_model - - -def _post_process_enabling_autograd_function(exported_model: ModelProto) -> ModelProto: +def post_process_enabling_autograd_function(exported_model: ModelProto) -> ModelProto: # Loop all PythonOp, append "_ctx" as the first output. 
index = 0 for node in exported_model.graph.node: @@ -330,8 +329,7 @@ def _post_process_enabling_autograd_function(exported_model: ModelProto) -> Mode op_name_prefix = kclass_name break - if not node.name: - node.name = f"{op_name_prefix}_id_{index}" - index += 1 + node.name = f"{op_name_prefix}_id_{index}" + index += 1 return exported_model diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py index 845c7d83c2e7b..a5b96c4e37140 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py @@ -376,6 +376,16 @@ def wrap_all_outputs(result): result = backward_function(*wrapped_args) # Extract results as DLPack tensor list. + if isinstance(result, torch.Tensor): + result = [result] + elif isinstance(result, (tuple, list)): + result = list(result) + else: + raise wrap_exception( + ORTModuleIOError, + TypeError(f"ORTModule does not support the following model output type {type(result)}."), + ) + wrapped_returned_args = wrap_all_outputs(result) torch_interop_utils.unregister_grad_fn(id(ctx)) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 2227b630aee23..dfaac5f0fa836 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -19,11 +19,10 @@ import onnxruntime from onnxruntime.capi import _pybind_state as C from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference -from onnxruntime.training.utils import ORTModelInputOutputSchemaType +from onnxruntime.training.utils import ORTModelInputOutputSchemaType, onnx_dtype_to_pytorch from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3 from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils -from ._custom_autograd_function_exporter import _post_process_after_export from ._fallback import ( ORTModuleDeviceException, ORTModuleONNXModelException, @@ -141,9 +140,14 @@ def __init__( register_triton_op_executor() + self._zero_stage3_param_map = {} if self._runtime_options.enable_zero_stage3_support: # Cannot toggle feature enabling/disabling after the first time enabled. - configure_ort_compatible_zero_stage3() + from onnxruntime.training.utils.hooks._zero_offload_subscriber import _get_all_zero_stage3_params + + self._zero_stage3_param_map = _get_all_zero_stage3_params(self._flattened_module) + + configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="ort_output", stats_overwrite=True) def _get_torch_gpu_allocator_function_addresses(self): if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available(): @@ -345,7 +349,8 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu ) if os.path.exists(cache_dir) and os.path.isfile(filename): self._logger.info( - f"Cached model detected! Cached model will be used to save export and initialization time. If you want the model to be re-exported then DELETE {filename}." + f"Cached model detected! Cached model will be used to save export and initialization time." + f"If you want the model to be re-exported then DELETE {filename}." 
) exported_model = onnx.load(filename) return exported_model @@ -409,9 +414,24 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu ) exported_model = onnx.load_model_from_string(f.getvalue()) - exported_model = _post_process_after_export( - exported_model, self._runtime_options.enable_custom_autograd_function - ) + if self._runtime_options.enable_custom_autograd_function: + from ._custom_autograd_function_exporter import post_process_enabling_autograd_function + + exported_model = post_process_enabling_autograd_function(exported_model) + + if self._runtime_options.enable_zero_stage3_support: + from ._zero_stage3_compatibility import post_processing_enable_zero_stage3_compat + + exported_model = post_processing_enable_zero_stage3_compat( + exported_model, + self._zero_stage3_param_map, + [name for name, _ in self._flattened_module.named_parameters()], + ) + + # Cannot append pull weight trigger name to input names as following, otherwise, the later check ( + # https://github.com/microsoft/onnxruntime/blob/068300d97eb25e5b52324e7af54a45ed1fa6a4c3/orttraining/orttraining/python/training/ortmodule/_training_manager.py#L466C18-L466C18) + # find input info mismatch, will re-initialize the graph builder. + # self._input_info.require_grad_names.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME) # Cache model for future runs if cache_dir: @@ -477,7 +497,14 @@ def _initialize_graph_builder(self): grad_builder_config = C.OrtModuleGraphBuilderConfiguration() grad_builder_config.initializer_names = initializer_names grad_builder_config.initializer_names_to_train = initializer_names_to_train - grad_builder_config.input_names_require_grad = self._input_info.require_grad_names + + input_names_require_grad = self._input_info.require_grad_names + if self._runtime_options.enable_zero_stage3_support: + from ._zero_stage3_compatibility import STAGE3_PULL_WEIGHT_TRIGGER_NAME + + # Add stage3 pull weight trigger name to require_grad_names, so that it will be included in the gradient graph. + input_names_require_grad.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME) + grad_builder_config.input_names_require_grad = input_names_require_grad grad_builder_config.build_gradient_graph = self._export_mode == torch.onnx.TrainingMode.TRAINING grad_builder_config.enable_caching = self._runtime_options.enable_grad_acc_optimization grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel( @@ -553,6 +580,9 @@ def _enable_conditional_optimizations( inputs, kwargs ) + if self._runtime_options.enable_zero_stage3_support: + self._append_pull_weight_trigger_as_input(kwargs, detected_device) + _, embed_sparsity_results, label_sparsity_results = _io._combine_input_buffers_initializers( self._graph_initializers, self._graph_builder.get_graph_info().user_input_names, @@ -562,6 +592,7 @@ def _enable_conditional_optimizations( kwargs, detected_device, self._runtime_inspector, + self._zero_stage3_param_map, ) # Enable sparsity-based optimization when applicable. 
@@ -587,6 +618,21 @@ def _enable_conditional_optimizations( if self._runtime_options.print_memory_stat: self._runtime_inspector.enable_memory_inspector(self._original_module) + def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.device): + from ._zero_stage3_compatibility import ( + STAGE3_PULL_WEIGHT_TRIGGER_NAME, + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, + ) + + kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = torch.zeros( + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, + dtype=onnx_dtype_to_pytorch(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE), + device=device, + ).requires_grad_() + + return kwargs + def _log_feature_stats(self): if get_rank() != 0: return diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index b7c01a1f5baf9..8d8be81c549d1 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -159,6 +159,9 @@ def forward(self, *inputs, **kwargs): # Assert that the input and model device match _utils._check_same_device(self._device, "Input argument to forward", *inputs) + if self._runtime_options.enable_zero_stage3_support: + self._append_pull_weight_trigger_as_input(kwargs, self._device) + prepared_input_list, _, _ = _io._combine_input_buffers_initializers( self._graph_initializers, self._graph_info.user_input_names, @@ -168,6 +171,7 @@ def forward(self, *inputs, **kwargs): kwargs, self._device, self._runtime_inspector, + self._zero_stage3_param_map, ) user_outputs, _ = InferenceManager.execution_session_run_forward( diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 18b965c549645..e7c1b30daae0d 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -168,6 +168,7 @@ def _combine_input_buffers_initializers( kwargs: Mapping[str, ORTModelInputOutputType], device: torch.device, rt_inspector: RuntimeInspector, + zero_stage3_offload_param_map: Optional[Dict[str, torch.nn.parameter.Parameter]], ): """Creates forward `*inputs` list from user input and PyTorch initializers @@ -254,7 +255,12 @@ def _expand_inputs(current_input, non_none_inputs, name=""): ) # params is a list of all initializers known to the onnx graph - result.extend(params) + if zero_stage3_offload_param_map: + for p in params: + if p not in zero_stage3_offload_param_map.values(): + result.append(p) + else: + result.extend(params) return result, embed_sparsity_results, label_sparsity_results diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 3be4c05797978..19effe2086e0a 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -311,6 +311,9 @@ def forward(self, *inputs, **kwargs): self._gradient_accumulation_manager.maybe_update_cache_before_run() + if self._runtime_options.enable_zero_stage3_support: + self._append_pull_weight_trigger_as_input(kwargs, self._device) + prepared_input_list, _, _ = _io._combine_input_buffers_initializers( self._graph_initializers, self._graph_info.user_input_names, @@ -320,6 +323,7 @@ def forward(self, *inputs, **kwargs): kwargs, self._device, 
self._runtime_inspector, + self._zero_stage3_param_map, ) outputs = unflatten_user_output( diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py new file mode 100644 index 0000000000000..17756600d601e --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py @@ -0,0 +1,312 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from typing import Dict, List, Optional, Tuple, Union + +import torch +from onnx import ModelProto, NodeProto, TensorProto, ValueInfoProto, helper + +from onnxruntime.capi._pybind_state import register_torch_autograd_function +from onnxruntime.training.utils import pytorch_dtype_to_onnx + +from ._custom_autograd_function_exporter import PythonOpShapeInferStore +from ._utils import get_fully_qualified_class_name + +STAGE3_PULL_WEIGHT_TRIGGER_NAME = "pull_weight_trigger" +STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE = TensorProto.FLOAT +STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE = [1] + + +def post_processing_enable_zero_stage3_compat( + exported_model: ModelProto, + zero_stage3_named_params: Dict[str, torch.nn.parameter.Parameter], + all_param_names: List[str], +) -> ModelProto: + """This function is used to enable zero stage3 compatibility. + + Args: + exported_model (ModelProto): The exported model. + zero_stage3_named_params (Optional[Dict[str, torch.nn.parameter.Parameter]]): The offload named parameters. + all_param_names (List[str]): All parameter names. + """ + + # Register symbolic shape inference functions for PythonOp used in DeepSpeed ZeRO stage3. + _register_symbolic_shape_infer_functions() + + # Create weight retrieving function using zero_stage3_named_params. + func_full_qual_name = _create_weight_retrieval_function(zero_stage3_named_params) + + consumer_map = {} + for node in exported_model.graph.node: + for inp in node.input: + if inp not in consumer_map: + consumer_map[inp] = [] + + if node not in consumer_map[inp]: + consumer_map[inp].append(node) + + def _get_param_pull_trigger_name(param_name: str) -> str: + return f"pull_{param_name}" + + def _get_func_name(node: NodeProto) -> Optional[str]: + for attr in node.attribute: + if attr.name == "func_name": + return attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s + return None + + # Create weight retrieving PythonOp. + new_input, weight_pull_node = _create_weight_retrieval_pythonop( + zero_stage3_named_params, + func_full_qual_name, + STAGE3_PULL_WEIGHT_TRIGGER_NAME, + [_get_param_pull_trigger_name(pname) for pname in zero_stage3_named_params], + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, + ) + + from onnxruntime.training.utils.hooks._zero_offload_subscriber import ORTZeROOffloadPreForwardFunction + + prefowrad_function_name = get_fully_qualified_class_name(ORTZeROOffloadPreForwardFunction) + + # Connect weight consumers to use the full-sized parameter output of ORTZeROOffloadPreForwardFunction. 
+ for graph_input in exported_model.graph.input: + if graph_input.name not in zero_stage3_named_params: + continue + + if graph_input.name not in consumer_map: + continue + + consumers = consumer_map[graph_input.name] + pre_forward_pythonop_node = None + + for c in consumers: + if c.op_type != "PythonOp": + continue + + func_name = _get_func_name(c) + if func_name == prefowrad_function_name: + assert ( + pre_forward_pythonop_node is None + ), "Multiple ORTZeROOffloadPreForwardFunction nodes found, it should not happen" + pre_forward_pythonop_node = c + + if pre_forward_pythonop_node is None: + raise RuntimeError( + "Fail to find ORTZeROOffloadPreForwardFunction for partitioned param: " + graph_input.name + ) + + index_offset_on_python_op_input = [] + for i, input_name in enumerate(pre_forward_pythonop_node.input): + if input_name == graph_input.name: + index_offset_on_python_op_input.append(i) + + assert ( + len(index_offset_on_python_op_input) == 1 + ), f"index_offset_on_python_op_input length is not 1: {index_offset_on_python_op_input}" + + reverse_index_among_inputs = index_offset_on_python_op_input[0] - len(pre_forward_pythonop_node.input) + new_input_name = _get_param_pull_trigger_name(graph_input.name) + pre_forward_pythonop_node.input[index_offset_on_python_op_input[0]] = new_input_name + + _update_python_op_input_related_attributes( + pre_forward_pythonop_node, + new_input_name, + len(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE), # new rank + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, # new data type + ) + + output_index = reverse_index_among_inputs + len(pre_forward_pythonop_node.output) + pre_forward_pythonop_node.output[output_index] = graph_input.name + + # If the consumer of original `graph_input.name` is PythonOp, we need also update its attributes because now + # `graph_input.name` as output of pre_forward_pythonop_node, is full-sized parameter, the rank might differ + # from the original one. + for c in consumers: + if c == pre_forward_pythonop_node or c.op_type != "PythonOp": + continue + _update_python_op_input_related_attributes( + c, + graph_input.name, + len(zero_stage3_named_params[graph_input.name].ds_shape), # new rank + pytorch_dtype_to_onnx(zero_stage3_named_params[graph_input.name].dtype), # new data type + ) + + # Delete exported_model.graph.input + graph_inputs_to_remove = [ + graph_input for graph_input in exported_model.graph.input if graph_input.name in zero_stage3_named_params + ] + for input_to_remove in graph_inputs_to_remove: + exported_model.graph.input.remove(input_to_remove) + + # Re-order graph input to make sure the weight pull trigger is before all parameter inputs. 
+ offset = 0 + for graph_input in exported_model.graph.input: + if graph_input.name in all_param_names: + break + offset += 1 + + exported_model.graph.input.insert(offset, new_input) + exported_model.graph.node.insert(0, weight_pull_node) + + return exported_model + + +def _create_weight_retrieval_function( + zero_stage3_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]] +) -> str: + """This function is used to create a weight retrieving function using zero_stage3_named_params.""" + + class WeightRetrievalFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, weight_in_trigger): + params = list(zero_stage3_named_params.values()) + ctx.params = params + ctx.dtype = weight_in_trigger.dtype + ctx.device = weight_in_trigger.device + ctx.shape = weight_in_trigger.shape + return (torch.zeros(ctx.shape, device=ctx.device, dtype=ctx.dtype),) * len(params) + + @staticmethod + def backward(ctx, *grad_outputs): + return torch.zeros(ctx.shape, device=ctx.device, dtype=ctx.dtype) + + @staticmethod + def infer_shape( + node: NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + param_count = len(zero_stage3_named_params.values()) + tensor_output_shapes = [ + tensor_input_shapes[0], + ] * param_count + tensor_output_dtypes = [ + tensor_input_dtypes[0], + ] * param_count + return tensor_output_shapes, tensor_output_dtypes + + func_full_qual_name = get_fully_qualified_class_name(WeightRetrievalFunction) + register_torch_autograd_function(func_full_qual_name, WeightRetrievalFunction) + PythonOpShapeInferStore.register(WeightRetrievalFunction) + + return func_full_qual_name + + +def _register_symbolic_shape_infer_functions(): + """This function is used to register symbolic shape inference functions for PythonOp used in + DeepSpeed ZeRO stage3.""" + + def _simple_pass_through_infer_shape( + node: NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + return tensor_input_shapes, tensor_input_dtypes + + PythonOpShapeInferStore.register_func( + "deepspeed.runtime.zero.parameter_offload.PreBackwardFunction", _simple_pass_through_infer_shape + ) + PythonOpShapeInferStore.register_func( + "deepspeed.runtime.zero.parameter_offload.PostBackwardFunction", _simple_pass_through_infer_shape + ) + + def _linear_infer_shape( + node: NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + # output = input.matmul(weight.t()) + tensor_input_shapes[0] # input + shape2 = tensor_input_shapes[1] # weight + output_shape = tensor_input_shapes[0] + output_shape[-1] = shape2[-2] + return [output_shape], [tensor_input_dtypes[0]] + + PythonOpShapeInferStore.register_func( + "deepspeed.runtime.zero.linear.LinearFunctionForZeroStage3", _linear_infer_shape + ) + + +def _create_weight_retrieval_pythonop( + zero_stage3_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]], + func_full_qual_name: str, + input_name: str, + output_names: List[str], + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE: List[int], +) -> Tuple[ValueInfoProto, NodeProto]: + """This 
function is used to create a weight retrieving PythonOp.""" + offload_param_count = 0 if zero_stage3_named_params is None else len(zero_stage3_named_params) + new_input = helper.make_tensor_value_info( + input_name, STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE + ) + output_rank_for_pull_weight_trigger = len(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE) + output_dtype_for_pull_weight_trigger = STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE + output_tensor_ranks = [ + output_rank_for_pull_weight_trigger, + ] * offload_param_count + output_tensor_types = [ + output_dtype_for_pull_weight_trigger, + ] * offload_param_count + + node_attributes = { + "comment": "", + "inplace": 0, + "input_convention": "d", + "input_tensor_ranks": [len(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE)], + "input_tensor_types": [STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE], + "output_tensor_ranks": output_tensor_ranks, + "output_tensor_types": output_tensor_types, + "training_mode": 1, + "func_name": func_full_qual_name, + } + + weight_pull_node = helper.make_node( + "PythonOp", + [input_name], + ["pull_weight_trigger_ctx", *output_names], + "pull_weight_trigger", # node name + "PythonOp for weight retrieving.", + "com.microsoft", + **node_attributes, + ) + + return new_input, weight_pull_node + + +def _update_python_op_input_related_attributes(node: NodeProto, input_name: str, new_rank: int, new_dtype: int): + """This function is used to update PythonOp's input related attributes, e.g. + input_tensor_ranks and input_tensor_types. + + Args: + node (NodeProto): The PythonOp node. + input_name (str): The input name to be updated. + new_rank (int): The new rank of the input, to be used in input_tensor_ranks. + new_dtype (int): The new data type of the input, to be used in input_tensor_types. 
+ """ + input_tensor_ranks = None + input_tensor_dtypes = None + rank_attr = None + dtype_attr = None + for attr in node.attribute: + if attr.name == "input_tensor_ranks": + input_tensor_ranks = attr.ints + rank_attr = attr + if attr.name == "input_tensor_types": + input_tensor_dtypes = attr.ints + dtype_attr = attr + + assert input_tensor_ranks is not None, "input_tensor_ranks is None" + assert input_tensor_dtypes is not None, "input_tensor_dtypes is None" + + for index, node_input_name in enumerate(node.input): + if node_input_name == input_name: + input_tensor_ranks[index] = new_rank + input_tensor_dtypes[index] = new_dtype + + node.attribute.remove(rank_attr) + node.attribute.remove(dtype_attr) + node.attribute.append(helper.make_attribute("input_tensor_ranks", input_tensor_ranks)) + node.attribute.append(helper.make_attribute("input_tensor_types", input_tensor_dtypes)) diff --git a/orttraining/orttraining/python/training/utils/__init__.py b/orttraining/orttraining/python/training/utils/__init__.py index acf2698d55eaf..fa7c9f2750cdd 100644 --- a/orttraining/orttraining/python/training/utils/__init__.py +++ b/orttraining/orttraining/python/training/utils/__init__.py @@ -9,7 +9,7 @@ extract_data_and_schema, unflatten_data_using_schema, ) -from onnxruntime.training.utils.torch_type_map import pytorch_dtype_to_onnx +from onnxruntime.training.utils.torch_type_map import onnx_dtype_to_pytorch, pytorch_dtype_to_onnx __all__ = [ "PrimitiveType", @@ -18,4 +18,5 @@ "extract_data_and_schema", "unflatten_data_using_schema", "pytorch_dtype_to_onnx", + "onnx_dtype_to_pytorch", ] diff --git a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py index 6c8027b2fefaa..db1c69cf95ba4 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py @@ -6,6 +6,7 @@ import os import shutil import warnings +from io import TextIOWrapper from pathlib import Path from typing import List, Optional, Tuple, Union @@ -178,87 +179,97 @@ def _summarize_activations(self, tensor: torch.Tensor, depth: int, name: str, st order_file_path = step_path / "order.txt" tensor_file_path = step_path / output_file_name - # This is to try the best effort to align the count of numbers per line for easier comparison in diff views, - # though it does not always guarantee to do this way. - torch.set_printoptions(precision=6, linewidth=128) - - tensor_shape = tensor.shape - tensor_dtype = tensor.dtype - flatten_array = tensor.flatten().view(-1) - - if self._run_on_cpu: - flatten_array = flatten_array.to("cpu") - - if self._run_on_cpu: - num_nan = torch.isnan(flatten_array).sum() - num_inf = torch.isinf(flatten_array).sum() - num_neg = (flatten_array < 0).sum() - num_pos = (flatten_array > 0).sum() - num_zero = (flatten_array == 0).sum() - min_value = flatten_array.min() - max_value = flatten_array.max() - mean_value = flatten_array.mean() - std_value = flatten_array.std() - else: - # Split the calculation for each bucket, then do another round of calculation on the bucket results. - # This can at the best effort reduce the peak memory impact. 
- bucket_size = self._bucket_size - element_count = flatten_array.numel() - ceil_bucket_count = (element_count + bucket_size - 1) // (bucket_size) - nan_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) - inf_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) - neg_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) - pos_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) - zero_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) - min_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) - max_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) - mean_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) - std_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) - - # Summary for each bucket - element_count_per_bucket = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) - for i in range(ceil_bucket_count): - end = min((i + 1) * bucket_size, element_count) - bucket = flatten_array[i * bucket_size : end] - element_count_per_bucket[i] = bucket.numel() - - nan_buckets[i] = torch.isnan(bucket).sum() - inf_buckets[i] = torch.isinf(bucket).sum() - neg_buckets[i] = (bucket < 0).sum() - pos_buckets[i] = (bucket > 0).sum() - zero_buckets[i] = (bucket == 0).sum() - min_buckets[i] = bucket.min() - max_buckets[i] = bucket.max() - mean_buckets[i] = bucket.sum() - std_buckets[i] = bucket.std() - - # Reduction across all buckets - num_nan = nan_buckets.sum() - num_inf = inf_buckets.sum() - num_neg = neg_buckets.sum() - num_pos = pos_buckets.sum() - num_zero = zero_buckets.sum() - min_value = min_buckets.min() - max_value = max_buckets.max() - mean_value = float(mean_buckets.sum()) / float(element_count) - # Here we refer to - # https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups - # to calculate the combined standard deviation of all buckets. - s = (element_count_per_bucket - 1) * (std_buckets**2) + element_count_per_bucket * ( - (mean_buckets - mean_value) ** 2 - ) - std_value = torch.sqrt(s.sum() / (element_count - 1)) - with order_file_path.open(mode="a", encoding="utf-8") as f: f.write(f"{output_file_name}\n") with tensor_file_path.open(mode="w", encoding="utf-8") as f: - f.write( - f"{'>'*max(0, depth) + display_name} shape: {tensor_shape} dtype: {tensor_dtype} size: {flatten_array.size()} \n" - f"min: {min_value} max: {max_value}, mean: {mean_value}, " - f"std: {std_value} \n" - f"nan: {num_nan}, inf: {num_inf}\n" - ) - f.write(f"samples(top 128): {flatten_array[:128]}\n") - f.write(f"neg: {num_neg}, pos: {num_pos}, zero: {num_zero},\n") - f.write(f"{'='*16}\n") + _summarize_tensor(display_name, tensor, f, depth, self._run_on_cpu, self._bucket_size) + + +def _summarize_tensor( + display_name: str, + tensor: torch.Tensor, + f: TextIOWrapper, + depth: int = 0, + run_on_cpu: bool = False, + bucket_size: int = 1024 * 1024 * 1024 // 2, +): + # This is to try the best effort to align the count of numbers per line for easier comparison in diff views, + # though it does not always guarantee to do this way. 
+ torch.set_printoptions(precision=6, linewidth=128) + + tensor_shape = tensor.shape + tensor_dtype = tensor.dtype + flatten_array = tensor.flatten().view(-1) + + if run_on_cpu: + flatten_array = flatten_array.to("cpu") + + if run_on_cpu: + num_nan = torch.isnan(flatten_array).sum() + num_inf = torch.isinf(flatten_array).sum() + num_neg = (flatten_array < 0).sum() + num_pos = (flatten_array > 0).sum() + num_zero = (flatten_array == 0).sum() + min_value = flatten_array.min() + max_value = flatten_array.max() + mean_value = flatten_array.mean() + std_value = flatten_array.std() + else: + # Split the calculation for each bucket, then do another round of calculation on the bucket results. + # This can at the best effort reduce the peak memory impact. + element_count = flatten_array.numel() + ceil_bucket_count = (element_count + bucket_size - 1) // (bucket_size) + nan_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + inf_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + neg_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + pos_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + zero_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + min_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) + max_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) + mean_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) + std_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) + + # Summary for each bucket + element_count_per_bucket = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + for i in range(ceil_bucket_count): + end = min((i + 1) * bucket_size, element_count) + bucket = flatten_array[i * bucket_size : end] + element_count_per_bucket[i] = bucket.numel() + + nan_buckets[i] = torch.isnan(bucket).sum() + inf_buckets[i] = torch.isinf(bucket).sum() + neg_buckets[i] = (bucket < 0).sum() + pos_buckets[i] = (bucket > 0).sum() + zero_buckets[i] = (bucket == 0).sum() + min_buckets[i] = bucket.min() + max_buckets[i] = bucket.max() + mean_buckets[i] = bucket.sum() + std_buckets[i] = bucket.std() + + # Reduction across all buckets + num_nan = nan_buckets.sum() + num_inf = inf_buckets.sum() + num_neg = neg_buckets.sum() + num_pos = pos_buckets.sum() + num_zero = zero_buckets.sum() + min_value = min_buckets.min() + max_value = max_buckets.max() + mean_value = float(mean_buckets.sum()) / float(element_count) + # Here we refer to + # https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups + # to calculate the combined standard deviation of all buckets. 
+ s = (element_count_per_bucket - 1) * (std_buckets**2) + element_count_per_bucket * ( + (mean_buckets - mean_value) ** 2 + ) + std_value = torch.sqrt(s.sum() / (element_count - 1)) + + f.write( + f"{'>'*max(0, depth) + display_name} shape: {tensor_shape} dtype: {tensor_dtype} size: {flatten_array.size()} \n" + f"min: {min_value} max: {max_value}, mean: {mean_value}, " + f"std: {std_value} \n" + f"nan: {num_nan}, inf: {num_inf}\n" + ) + f.write(f"samples(top 128): {flatten_array[:128]}\n") + f.write(f"neg: {num_neg}, pos: {num_pos}, zero: {num_zero},\n") + f.write(f"{'='*16}\n") diff --git a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py index db38f58d8f324..b2bc64be42fc1 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py +++ b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py @@ -29,14 +29,6 @@ def no_increase_global_step(): finally: ORT_NO_INCREASE_GLOBAL_STEP[0] = False - @staticmethod - def infer_shape( - node: onnx.NodeProto, - tensor_input_shapes: List[Optional[List[Union[int, str]]]], - tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], - ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: - return tensor_input_shapes, tensor_input_dtypes - class _IncrementStep(torch.autograd.Function): """This class is used to manage the global execution step, e.g. @@ -55,8 +47,9 @@ def forward(ctx, run_ctx: RuntimeStates, *input_tensor_list: Tuple[torch.Tensor, ctx.current_step = run_ctx.global_states.execution_step ctx.run_ctx = run_ctx - if ctx.current_step >= 0: - print(f"{'='*6} Completed forward pass for STEP {ctx.current_step} {'='*6}") + # Uncomment the following line for debugging purposes. + # if ctx.current_step >= 0: + # print(f"{'='*6} Completed forward pass for STEP {ctx.current_step} {'='*6}") if ORT_NO_INCREASE_GLOBAL_STEP[0] is False: ctx.run_ctx.global_states.execution_step += 1 @@ -191,7 +184,7 @@ def _reset_recursively(module: torch.nn.Module, depth: int, next_module_index: L next_module_index: list of int, carrying a global unique module index that can be used next. """ module_index = next_module_index[0] - module.id = module_index # STAGE3WARN: needed by DeepSpeed + module.id = module_index # STAGE3WARN#1: needed by DeepSpeed self._run_ctx.global_states.module_index_to_depth[module_index] = depth self._run_ctx.global_states.module_to_module_index[module] = module_index @@ -217,7 +210,7 @@ def _register_hooks_recursively(self, module: torch.nn.Module, depth: int, next_ next_module_index: list of int, carrying a global unique module index that can be used next. 
""" module_index = next_module_index[0] - module.id = module_index # STAGE3WARN: needed by DeepSpeed + module.id = module_index # STAGE3WARN#2: needed by DeepSpeed self._run_ctx.global_states.module_index_to_depth[module_index] = depth self._run_ctx.global_states.module_to_module_index[module] = module_index diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py index 3d42e172eea82..ad1297962db71 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py @@ -23,25 +23,37 @@ from ._subscriber_base import RuntimeStates, SubscriberBase -# Used to monkey patch the original function -# Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/parameter_offload.py#L333 -def _setup_zero_stage3_ort_compatible_hooks(self): - self.hierarchy = 0 +def _get_ort_compatible_zero_stage3_hook_function(debug, stats_output_dir, stats_overwrite): + """Create ort compatible hook function for DeepSpeed ZeRO stage3. - from onnxruntime.training.utils.hooks import SubscriberManager, ZeROOffloadSubscriber - from onnxruntime.training.utils.hooks._zero_offload_subscriber import _zero_offload_one_time_initializer + Args: + debug: whether to enable convergence debugging. + stats_output_dir: the directory to store convergence stats. + stats_overwrite: whether to overwrite the stats file if it already exists. + """ + + # Used to monkey patch the original function + # Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/parameter_offload.py#L333 + def _setup_zero_stage3_ort_compatible_hooks(self): + self.hierarchy = 0 + + from onnxruntime.training.utils.hooks import StatisticsSubscriber, SubscriberManager, ZeROOffloadSubscriber + from onnxruntime.training.utils.hooks._zero_offload_subscriber import _zero_offload_one_time_initializer - # Each DeepSpeed engine has a separate subscriber manager. - self._offload_subscriber_manager = SubscriberManager() - self._offload_subscriber_manager.subscribe( - self.module, [ZeROOffloadSubscriber(self, _zero_offload_one_time_initializer)] - ) - self.forward_hooks.extend(self._offload_subscriber_manager._pre_forward_hooks) - self.forward_hooks.extend(self._offload_subscriber_manager._post_forward_hooks) + subscribers = [ZeROOffloadSubscriber(self, _zero_offload_one_time_initializer)] + if debug is True: + subscribers.append(StatisticsSubscriber(output_dir=stats_output_dir, override_output_dir=stats_overwrite)) + # Each DeepSpeed engine has a separate subscriber manager. 
+ self._offload_subscriber_manager = SubscriberManager() + self._offload_subscriber_manager.subscribe(self.module, subscribers) + self.forward_hooks.extend(self._offload_subscriber_manager._pre_forward_hooks) + self.forward_hooks.extend(self._offload_subscriber_manager._post_forward_hooks) - # Add top module to stack trace - global FWD_MODULE_STACK # noqa: PLW0602 - FWD_MODULE_STACK.append(self.module) + # Add top module to stack trace + global FWD_MODULE_STACK # noqa: PLW0602 + FWD_MODULE_STACK.append(self.module) + + return _setup_zero_stage3_ort_compatible_hooks # Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/linear.py#L104 @@ -86,14 +98,16 @@ def collect_code(self, function: Callable): _zero_offload_one_time_initializer.collect_code(DeepSpeedZeRoOffload.setup_zero_stage3_hooks) # This is the function to enable ORT ZeRO offload. - def configure_ort_compatible_zero_stage3(): + def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="./", stats_overwrite=False): """Configure ZeRO stage3 to be ORT compatible. This function will overwrite the original DeepSpeed ZeRO stage3 hooks to make it ORT compatible. """ # Only done once no matter how many times this function is called for different modules. - DeepSpeedZeRoOffload.setup_zero_stage3_hooks = _setup_zero_stage3_ort_compatible_hooks + DeepSpeedZeRoOffload.setup_zero_stage3_hooks = _get_ort_compatible_zero_stage3_hook_function( + debug, stats_output_dir, stats_overwrite + ) from deepspeed.runtime.zero.linear import zero3_linear_wrap @@ -103,7 +117,7 @@ def configure_ort_compatible_zero_stage3(): except ImportError as e: warnings.warn(f"DeepSpeed import error {e}") - def configure_ort_compatible_zero_stage3(): + def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir=None, stats_overwrite=False): raise RuntimeError("DeepSpeed is not installed, cannot configure ORT compatible ZeRO stage3.") @@ -115,13 +129,13 @@ def _get_params_for_current_module(module: torch.nn.Module) -> List[torch.nn.par """ from deepspeed.runtime.zero.partitioned_param_coordinator import iter_params - # Retrive the parameters that are not available for this module. + # Retrieve all parameters for this module. partitioned_params = [param for param in iter_params(module)] return partitioned_params -def _get_all_offloaded_params(module: torch.nn.Module) -> Dict[str, torch.nn.parameter.Parameter]: +def _get_all_zero_stage3_params(module: torch.nn.Module) -> Dict[str, torch.nn.parameter.Parameter]: """Retrieve all the parameters that are offloaded.""" from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus @@ -134,16 +148,13 @@ def _get_all_offloaded_params(module: torch.nn.Module) -> Dict[str, torch.nn.par class ORTZeROOffloadPreForwardFunction(torch.autograd.Function): - """This function is a common bridge to call original PyTorch's - pre_forward_function and post_backward_function. 
- """ + """This function is a common bridge to call original PyTorch's pre_forward_function""" @staticmethod def forward( ctx, module, pre_forward_with_kwargs_function, - post_backward_function, args_schema, kwargs_schema, args_tensor_count, @@ -155,7 +166,6 @@ def forward( ctx: context object module: the module to be called pre_forward_with_kwargs_function: the function to be called before forward (PyTorch's pre_forward_function) - post_backward_function: the function to be called after backward (PyTorch's post_backward_function) args_schema: the schema of the args, used to reconstruct the args in original form in PyTorch's pre_forward_function's inputs. kwargs_schema: the schema of the kwargs, used to reconstruct the kwargs in original form in @@ -168,6 +178,17 @@ def forward( args_tensors = tensor_list[:args_tensor_count] kwargs_tensors = tensor_list[args_tensor_count : args_tensor_count + kwargs_tensor_count] + # For PyTorch runs, the sizes are all 0, it does not need a gradient because + # param._detach().requires_grad_(False) is called. + # But for ORT runs, the sizes are all [1], as output of weight retrieval function. + # So we keep track of the shapes and dtypes of the passed-in tensors, then generate the grads in backward. + # While for both PyTorch and ORT runs, the grad is not important because they are not param grads + # anymore, they are only used for completing the full backward propagation. + passed_in_param_tensors = tensor_list[args_tensor_count + kwargs_tensor_count :] + ctx.shapes = [p.shape for p in passed_in_param_tensors] + ctx.dtypes = [p.dtype for p in passed_in_param_tensors] + ctx.devices = [p.device for p in passed_in_param_tensors] + args = unflatten_data_using_schema(args_tensors, args_schema) kwargs = unflatten_data_using_schema(kwargs_tensors, kwargs_schema) @@ -179,6 +200,8 @@ def forward( partitioned_params = _get_params_for_current_module(module) ctx.partitioned_params = partitioned_params + assert len(partitioned_params) == len(passed_in_param_tensors) + f_ret = pre_forward_with_kwargs_function(module, args, kwargs) if f_ret is None: @@ -188,7 +211,6 @@ def forward( updated_args, updated_kwargs = f_ret ctx.module = module - ctx.post_backward_function = post_backward_function updated_args_tensors, _ = extract_data_and_schema(updated_args) updated_kwargs_tensors, _ = extract_data_and_schema(updated_kwargs) @@ -203,17 +225,32 @@ def forward( @staticmethod def backward(ctx, *grads): updated_grads = grads - if ctx.post_backward_function is not None: - ret = ctx.post_backward_function(ctx.module, grads) - if ret is not None: - updated_grads = ret - # TODO(pengwa) Update grad for partitioned parameters. input_count = len(updated_grads) - len(ctx.partitioned_params) - zeros = [torch.zeros(0, dtype=p.dtype, device=p.device) for p in ctx.partitioned_params] - zero_grads = updated_grads[:input_count] + tuple(zeros) - - return (None, None, None, None, None, None, None, *zero_grads) + param_start_offset = input_count + + # Only need to accumulate grad explicitly for ORT run (e.g. ctx.shapes[0] == (1,)); + # In the PyTorch run, the accumulation happens automatically. + need_manual_grad_acc = len(ctx.shapes) > 0 and ctx.shapes[0] == (1,) + if need_manual_grad_acc: + for param_index, p in enumerate(ctx.partitioned_params): + g = updated_grads[param_index + param_start_offset] + if g is None: + raise RuntimeError(f"param {p} has no grad, this should not happen.") + # Param gradient accumulation is triggered here, along with the attached hooks, done by PyTorch. 
+ assert p.shape == g.shape, f"param_index: {param_index} - param shape {p.shape} != grad shape {g.shape}" + p.backward(g) + + # At this point, the **real** param grads are already updated, the following grads are only used for + # completing the full backward propagation, will not affect parameter updates. + passed_in_param_grad = [ + torch.zeros(shape, dtype=dtype, device=device) + for shape, dtype, device in zip(ctx.shapes, ctx.dtypes, ctx.devices) + ] + + zero_grads = updated_grads[:input_count] + tuple(passed_in_param_grad) + + return (None, None, None, None, None, None, *zero_grads) @staticmethod def infer_shape( @@ -258,14 +295,14 @@ def forward( module: the module to be called post_forward_function: the function to be called after forward (PyTorch's post_forward_function) pre_backward_function: the function to be called before backward (PyTorch's pre_backward_function) - output_schema: the schema of the output, used to reconstruct the output in original form in + output_schema: the schema of the output, used to reconstruct the output in its original form in PyTorch's post_forward_function's inputs. output_tensors: the list of tensors. """ outputs = unflatten_data_using_schema(output_tensors, output_schema) - # STAGE3WARN: _post_forward_module_hook's second argument `input is not used, so we just pass a None here. + # STAGE3WARN#3: _post_forward_module_hook's second argument `input is not used, so we just pass a None here. updated_outputs = post_forward_function(module, None, outputs) if updated_outputs is None: @@ -341,11 +378,19 @@ def pre_forward_module_apply_impl( input and output for torch.autograd.Function, so we do flatten and unflatten here. """ + ## Handle `_post_backward_module_hook` - args_tensors, args_schema = extract_data_and_schema(args) - kwargs_tensors, kwargs_schema = extract_data_and_schema(kwargs) + # Put `_post_backward_module_hook` first because in backward, it is responsible for unloading parameters, + # we want ORTZeROOffloadPreForwardFunction's backward still be able to access the full sized parameters. + _post_backward_module_hook = self._functions.get("_post_backward_module_hook") + # STAGE3WARN#4: most logic in _post_backward_module_hook can be traced correctly so we don't need to + # wrap with PythonOp. For those cannot be traced, we handle them in STAGE3WARN#5. + updated_args = _post_backward_module_hook(module, args) - partitioned_params = _get_params_for_current_module(module) + ## Handle `_pre_forward_module_hook` + + args_tensors, args_schema = extract_data_and_schema(updated_args) + kwargs_tensors, kwargs_schema = extract_data_and_schema(kwargs) _pre_forward_module_hook = self._functions.get("_pre_forward_module_hook") @@ -358,18 +403,29 @@ def _wrap_pre_forward_module_hook(module, args, kwargs): if rets is not None: updated_args = rets - # STAGE3WARN: Moved from _post_backward_module_hook to make sure ORT run will trigger every iteration. + # STAGE3WARN#5: Moved from _post_backward_module_hook to make sure ORT run will trigger every iteration. 
module.ds_grads_remaining = 0 + return updated_args, updated_kwargs - all_tensors = args_tensors + kwargs_tensors + partitioned_params + # Need to pass the parameters as input to let the exporter trace the related weights for + # current ORTZeROOffloadPreForwardFunction + partitioned_params = _get_params_for_current_module(module) + # Don't require grad for passed-in parameter, otherwise it will be treated as a leaf node, in backward + # returned 0-sized grad did not match the param's gradient accumulator function's input shape metadata, + # PyTorch run will fail during backward. + # This will not harm parameter gradient build either in ORT or PyTorch, imagine the weights are used by + # computation anyway, so the gradient will be built. This hook only references the parameter, but won't + # generate a gradient path for it. + detached_partitioned_params = [p.detach().requires_grad_(False) for p in partitioned_params] + + all_tensors = args_tensors + kwargs_tensors + detached_partitioned_params self._check_all_tensor(all_tensors, module, "pre_forward_module_apply_impl input check") rets = ORTZeROOffloadPreForwardFunction.apply( module, _wrap_pre_forward_module_hook, - None, args_schema, kwargs_schema, args_tensor_count, @@ -385,11 +441,6 @@ def _wrap_pre_forward_module_hook(module, args, kwargs): updated_args = unflatten_data_using_schema(updated_args_tensors, args_schema) updated_kwargs = unflatten_data_using_schema(updated_kwargs_tensors, kwargs_schema) - _post_backward_module_hook = self._functions.get("_post_backward_module_hook") - # STAGE3WARN: Other part of _post_backward_module_hook can be traced correctly so we don't need to - # wrap with PythonOp. - updated_args = _post_backward_module_hook(module, updated_args) - return updated_args, updated_kwargs def post_forward_module_apply_impl( @@ -411,7 +462,7 @@ def post_forward_module_apply_impl( _post_forward_module_hook = self._functions.get("_post_forward_module_hook") def _wrap_post_forward_module_hook(module, input, outputs): - # STAGE3WARN: _post_forward_module_hook applied this for each tensor output, so we do a simple wrap here. + # STAGE3WARN#6: _post_forward_module_hook applied this for each tensor output, so we do a simple wrap here. from deepspeed.runtime.zero.partition_parameters import is_zero_param updated_outputs = _post_forward_module_hook(module, input, outputs) @@ -438,8 +489,8 @@ def _wrap_post_forward_module_hook(module, input, outputs): updated_outputs = unflatten_data_using_schema(updated_outputs_tensors, outputs_schema) _pre_backward_module_hook = self._functions.get("_pre_backward_module_hook") - # STAGE3WARN: _pre_backward_module_hook's second argument `input is not used, so we just pass a None here. - # STAGE3WARN: part of the original _pre_backward_module_hook can be traced correctly so we moved them into + # STAGE3WARN#7: _pre_backward_module_hook's second argument `input is not used, so we just pass a None here. + # STAGE3WARN#8: part of the original _pre_backward_module_hook can be traced correctly so we moved them into # _wrap_post_forward_module_hook above. 
         updated_outputs = _pre_backward_module_hook(module, None, updated_outputs)
diff --git a/orttraining/orttraining/python/training/utils/torch_type_map.py b/orttraining/orttraining/python/training/utils/torch_type_map.py
index 699747723f457..bdacab8ad04fe 100644
--- a/orttraining/orttraining/python/training/utils/torch_type_map.py
+++ b/orttraining/orttraining/python/training/utils/torch_type_map.py
@@ -33,6 +33,8 @@
 _DTYPE_TO_ONNX = {torch_dtype: onnx_dtype for k, (onnx_dtype, torch_dtype) in _CAST_PYTORCH_TO_ONNX.items()}

+_ONNX_TO_DTYPE = {onnx_dtype: torch_dtype for torch_dtype, onnx_dtype in _DTYPE_TO_ONNX.items()}
+

 def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torch.onnx.TensorProtoDataType:
     """Converts a pytorch dtype or scalar type string to an onnx dtype."""
@@ -45,3 +47,10 @@ def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torc
     if dtype not in _DTYPE_TO_ONNX:
         raise RuntimeError(f"Unsupported dtype {dtype}")
     return _DTYPE_TO_ONNX[dtype]
+
+
+def onnx_dtype_to_pytorch(dtype: torch.onnx.TensorProtoDataType) -> torch.dtype:
+    """Converts an onnx dtype to a pytorch dtype."""
+    if dtype not in _ONNX_TO_DTYPE:
+        raise RuntimeError(f"Unsupported dtype {dtype}")
+    return _ONNX_TO_DTYPE[dtype]
diff --git a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc
index 4e7fcbc95bb1d..e1d4be24861f5 100644
--- a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc
+++ b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc
@@ -153,8 +153,11 @@ void PythonOpBase::RunForward(OpKernelContext* context,
                         inplace_ != 0,
                         kernel_invoke_id_);

-  ORT_ENFORCE(1 + returned_ortvalues.size() == static_cast<size_t>(context->OutputCount()),
-              "Output count mismatch for PythonOp run");
+  const size_t returned_output_count = 1 + returned_ortvalues.size();
+  const size_t kernel_output_count = static_cast<size_t>(context->OutputCount());
+  ORT_ENFORCE(returned_output_count == kernel_output_count, "Output count mismatch for PythonOp run, ",
+              "returned_output_count: ", returned_output_count, ", expected kernel_output_count: ",
+              kernel_output_count);
 }

 void PythonOpBase::SetOutputs(OpKernelContext* context, void* diff_ctx, std::vector<OrtValue>& returned_args) const {

From 1bc215e1d1c1e3509a1dd0bc413b1537563dedb5 Mon Sep 17 00:00:00 2001
From: Yiming Hu
Date: Thu, 21 Sep 2023 19:22:28 -0700
Subject: [PATCH 067/121] [VITISAI] add float16 and bfloat16 support (#17438)

### Description
Add float16 and bfloat16 data type support for the VitisAI EP.

### Motivation and Context
The VitisAI EP has added bfloat16 data type support, so we would like to register these data types on the ONNX Runtime side to enable them.
--------- Signed-off-by: Yiming Hu --- onnxruntime/core/providers/vitisai/README.md | 2 +- onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/vitisai/README.md b/onnxruntime/core/providers/vitisai/README.md index 15e0c804489c5..6ddb58b8d96ae 100644 --- a/onnxruntime/core/providers/vitisai/README.md +++ b/onnxruntime/core/providers/vitisai/README.md @@ -1,4 +1,4 @@ -VitsAI Execution Prividers +VitisAI Execution Provider ============================ diff --git a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc index 544e18350635d..ee8dfc6d03d12 100644 --- a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc +++ b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc @@ -34,9 +34,12 @@ static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) { updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT64); } else if (data_type->s() == "int1") { updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); + } else if (data_type->s() == "bfloat16") { + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BFLOAT16); + } else if (data_type->s() == "float16") { + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::FLOAT16); } else { - std::cerr << "not supported data_type " << data_type->s(); - abort(); + vai_assert(false, ", not supported data_type: " + data_type->s()); } if (shape != nullptr) { for (auto i = 0; i < shape->ints_size(); ++i) { From cd3fb377ea867570796cf61bc420cd985129a2a0 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Fri, 22 Sep 2023 11:55:08 +0800 Subject: [PATCH 068/121] [js/webgpu] Allow binary ops with scalar to use the vectorize path (#17589) ### Description 1. For binary ops, the number of components is always 4, so the dispatchGroup should be `{x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}` instead of `{x: Math.ceil(outputSize / 64 /* workgroup size */ / (vectorize ? 4 : 1) /* vec size */)}`. 2. If either a or b has only one element, we can still use the vectorize path, since the same value will be broadcast. --- js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts | 23 +++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index 13d3a91bb339e..9c05080f7e118 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -62,14 +62,24 @@ const createBinaryOpProgramShader = let assignment: string; if (vectorize) { if (doBroadcast) { - assignment = ` + const isAOneElement = ShapeUtil.size(dimsA) === 1; + const isBOneElement = ShapeUtil.size(dimsB) === 1; + if (isAOneElement || isBOneElement) { + assignment = output.setByOffset( + 'global_idx', + expressionVector( + isAOneElement ? `${a.type.value}(${a.getByOffset('0')}.x)` : a.getByOffset('global_idx'), + isBOneElement ?
`${b.type.value}(${b.getByOffset('0')}.x)` : b.getByOffset('global_idx'))); + } else { + assignment = ` let outputIndices = ${output.offsetToIndices('global_idx * 4u')}; let offsetA = calcOffsetA(outputIndices); let offsetB = calcOffsetB(outputIndices); ${ - output.setByOffset( - 'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))} + output.setByOffset( + 'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))} `; + } } else { assignment = output.setByOffset( 'global_idx', expressionVector(a.getByOffset('global_idx'), b.getByOffset('global_idx'))); @@ -141,6 +151,8 @@ const createBinaryOpProgramInfo = } outputShape = calculatedShape; outputSize = ShapeUtil.size(outputShape); + const isAOneElement = ShapeUtil.size(a.dims) === 1; + const isBOneElement = ShapeUtil.size(b.dims) === 1; // check whether vectorize can be enabled let sharedDimension = 1; @@ -153,7 +165,7 @@ const createBinaryOpProgramInfo = break; } } - if (sharedDimension % 4 === 0) { + if (sharedDimension % 4 === 0 || isAOneElement || isBOneElement) { vectorize = true; } } else { @@ -167,8 +179,7 @@ const createBinaryOpProgramInfo = shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, funcCall, a.dataType, b.dataType, outputDataType, additionalImplementation), outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}], - dispatchGroup: () => - ({x: Math.ceil(outputSize / 64 /* workgroup size */ / (vectorize ? 4 : 1) /* vec size */)}) + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}) }; }; From 891fba3b9cd71e2e1afdeab9fb3c5b5497db20cf Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Fri, 22 Sep 2023 12:00:36 +0800 Subject: [PATCH 069/121] [js/webgpu] Optimize Gather op (#17625) ### Description This PR optimizes the gather op, improving the Segment Anything model by ~6 ms on ADL. The problem with the original algorithm is that it uses a for loop to copy a block of data, and the block size may be very large, like `65536`. In a GPU shader, we should avoid large loops and instead use more threads to do the work in parallel. Before: ``` [profiling] kernel "41771992|[Gather] 41771992" input[0]: [4,65536] | float32, input[1]: [1] | int64, output[0]: [1,65536] | float32, execution time: 6886207 ns ``` After: ``` [profiling] kernel "41771992|[Gather] 41771992" input[0]: [4,65536] | float32, input[1]: [1] | int64, output[0]: [1,65536] | float32, execution time: 11719 ns ``` --- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/gather.ts | 91 ++++++++++------------- 2 files changed, 42 insertions(+), 51 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index c054da51a3098..0ab777bfbdee9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -366,7 +366,7 @@ const createIndicesHelper = const getByIndicesImplementation = rank < 2 ? '' : ` fn get_${name}ByIndices(indices: ${type.indices}) -> ${valueType} { - return ${name}[i2o_${name}(indices)]; + return ${getByOffset(`i2o_${name}(indices)`)}; }`; const getImplementation = rank < 2 ?
'' : (() => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index 0db060dbec54a..47aae13d6799d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -1,13 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; -import {ShaderHelper} from './common'; +import {inputVariable, outputVariable, ShaderHelper} from './common'; export interface GatherAttributes extends AttributeWithCacheKey { axis: number; @@ -30,63 +29,55 @@ const createGatherProgramInfo = const outputShape = inputShape.slice(0); outputShape.splice(axis, 1, ...indicesShape); - const inputDataType = inputs[0].dataType; - const block = ShapeUtil.sizeFromDimension(inputShape, axis + 1); - const elementSize = [DataType.int64, DataType.uint64, DataType.double].includes(inputDataType) ? 2 : 1; - const indicesElementSize = inputs[1].dataType === DataType.int64 ? 2 : 1; - const blockSize = elementSize * block; - const M = ShapeUtil.sizeToDimension(inputShape, axis); - const N = ShapeUtil.size(indicesShape); - const dataBatchElements = ShapeUtil.sizeFromDimension(inputShape, axis) * elementSize; - const gatheredBatchElements = N * block * elementSize; const axisDimLimit = inputShape[axis]; + const outputSize = ShapeUtil.size(outputShape); + + const data = inputVariable('data', inputs[0].dataType, inputs[0].dims); + const indices = inputVariable('inputIndices', inputs[1].dataType, inputs[1].dims); + const output = outputVariable('output', inputs[0].dataType, outputShape); + const calcDataIndices = (): string => { + const indicesRank = indicesShape.length; + let calcStr = `var indicesIndices = ${indices.type.indices}(0);`; + for (let i = 0; i < indicesRank; i++) { + calcStr += `${indicesRank > 1 ? `indicesIndices[${i}]` : 'indicesIndices'} = ${ + outputShape.length > 1 ? `outputIndices[${axis + i}]` : 'outputIndices'};`; + } + calcStr += ` + var idx = ${indices.getByIndices('indicesIndices')}; + if (idx < 0) { + idx = idx + ${axisDimLimit}; + } + var dataIndices = ${data.type.indices}(0); + `; + for (let i = 0, j = 0; i < inputRank; i++) { + if (i === axis) { + calcStr += `${inputRank > 1 ? `dataIndices[${i}]` : 'dataIndices'} = u32(idx);`; + j += indicesRank; + } else { + calcStr += `${inputRank > 1 ? `dataIndices[${i}]` : 'dataIndices'} = ${ + outputShape.length > 1 ? 
`outputIndices[${j}]` : 'outputIndices'};`; + j++; + } + } + return calcStr; + }; - const inputSize = ShapeUtil.size(inputShape) * elementSize; - const outputSize = ShapeUtil.size(outputShape) * elementSize; - - const totalGathers = M * N; - // int64 indices would be treated as little endian i32 with assumption they fall in i32 limits - // That assumption is safe as it's not possible to allocate >2gb buffer for input tensor - // Input data will be treated as u32 or two u32 for 8-byte tensors const getShaderSource = (shaderHelper: ShaderHelper) => ` - const N: u32 = ${N}; - const elementSize: u32 = ${elementSize}; - const indicesElementSize: u32 = ${indicesElementSize}; - - @group(0) @binding(0) var input : array; - @group(0) @binding(1) var inputIndices : array; - @group(0) @binding(2) var output: array; - - ${shaderHelper.mainStart()} - let batch: u32 = global_idx / N; - let i: u32 = global_idx % N; - - let srcOffsetBatch: u32 = batch * ${dataBatchElements}; - let dstOffsetBatch: u32 = batch * ${gatheredBatchElements}; - var idx = inputIndices[i * indicesElementSize]; - if (idx < 0) { - idx = idx + ${axisDimLimit}; - } - - let srcOffset = srcOffsetBatch + u32(idx) * ${blockSize}; - let dstOffset = dstOffsetBatch + i * ${blockSize}; - if (srcOffset >= ${inputSize}) { - return; - } - if (dstOffset >= ${outputSize}) { - return; - } - for (var j: u32 = 0; j < ${blockSize}; j++) { - output[dstOffset + j] = input[srcOffset + j]; - } - }`; + ${shaderHelper.declareVariables(data, indices, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + let outputIndices = ${output.offsetToIndices('global_idx')}; + ${calcDataIndices()}; + let value = ${data.getByIndices('dataIndices')}; + ${output.setByOffset('global_idx', 'value')}; + }`; return { ...metadata, outputs: [ {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}, ], getShaderSource, - dispatchGroup: () => ({x: Math.ceil(totalGathers / 64 /* workgroup size */)}) + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) }; }; From 55b16d347cbcde41b35c3ed12f34eeca1a1b05d6 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Sat, 23 Sep 2023 00:50:36 +0800 Subject: [PATCH 070/121] Read model zoo test (#17666) --- onnxruntime/test/providers/cpu/model_tests.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index ef2d7e31654ba..9b41ba8c0d2ba 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -1133,11 +1133,15 @@ ::std::vector<::std::basic_string> GetParameterStrings() { #if defined(NDEBUG) || defined(RUN_MODELTEST_IN_DEBUG_MODE) #ifdef _WIN32 ORT_STRING_VIEW model_test_root_path = ORT_TSTR("..\\models"); + // thus, only the root path should be mounted. + ORT_STRING_VIEW model_zoo_path = ORT_TSTR("..\\models\\zoo"); #else ORT_STRING_VIEW model_test_root_path = ORT_TSTR("../models"); + ORT_STRING_VIEW model_zoo_path = ORT_TSTR("../models/zoo"); #endif for (auto p : kvp.second) { paths.push_back(ConcatPathComponent(model_test_root_path, p)); + paths.push_back(ConcatPathComponent(model_zoo_path, p)); } #endif From 6d7bc2a097a1a08541cd0d4628831c79ab8092d5 Mon Sep 17 00:00:00 2001 From: Lukas Berbuer <36054362+lukasberbuer@users.noreply.github.com> Date: Fri, 22 Sep 2023 18:54:38 +0200 Subject: [PATCH 071/121] Fix ARMv7 build (#13891) Fix ARMv7 build error on Linux. 
### Description `cpuinfo_*` functions are only available if `CPUINFO_SUPPORTED` is set and therefore `"cpuinfo.h"` is included. Fixed by extending the conditional code. ### Motivation and Context Compilation for ARMv7 on a Linux system fails. --- onnxruntime/core/common/cpuid_info.cc | 54 +++++++++++++-------------- 1 file changed, 25 insertions(+), 29 deletions(-) diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index a23409292bb74..6a82b3fcc734d 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -135,38 +135,34 @@ void CPUIDInfo::ArmLinuxInit() { LOGS_DEFAULT(WARNING) << "Failed to init pytorch cpuinfo library, may cause CPU EP performance degradation due to undetected CPU features."; return; } + is_hybrid_ = cpuinfo_get_uarchs_count() > 1; + has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot(); + has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); + const uint32_t core_cnt = cpuinfo_get_cores_count(); + core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); + is_armv8_narrow_ld_.resize(core_cnt, false); + for (uint32_t c = 0; c < core_cnt; c++) { + const struct cpuinfo_processor* proc = cpuinfo_get_processor(c); + if (proc == nullptr) { + continue; + } + const struct cpuinfo_core* corep = proc->core; + if (corep == nullptr) { + continue; + } + auto coreid = proc->linux_id; + auto uarch = corep->uarch; + core_uarchs_[coreid] = uarch; + if (uarch == cpuinfo_uarch_cortex_a53 || uarch == cpuinfo_uarch_cortex_a55r0 || + uarch == cpuinfo_uarch_cortex_a55) { + is_armv8_narrow_ld_[coreid] = true; + } + } #else pytorch_cpuinfo_init_ = false; + has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); + has_fp16_ |= has_arm_neon_dot_; #endif - - if (pytorch_cpuinfo_init_) { - is_hybrid_ = cpuinfo_get_uarchs_count() > 1; - has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot(); - has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); - const uint32_t core_cnt = cpuinfo_get_cores_count(); - core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); - is_armv8_narrow_ld_.resize(core_cnt, false); - for (uint32_t c = 0; c < core_cnt; c++) { - const struct cpuinfo_processor* proc = cpuinfo_get_processor(c); - if (proc == nullptr) { - continue; - } - const struct cpuinfo_core* corep = proc->core; - if (corep == nullptr) { - continue; - } - auto coreid = proc->linux_id; - auto uarch = corep->uarch; - core_uarchs_[coreid] = uarch; - if (uarch == cpuinfo_uarch_cortex_a53 || uarch == cpuinfo_uarch_cortex_a55r0 || - uarch == cpuinfo_uarch_cortex_a55) { - is_armv8_narrow_ld_[coreid] = true; - } - } - } else { - has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); - has_fp16_ |= has_arm_neon_dot_; - } } #elif defined(_WIN32) From e70a23f8dc6fc181218106f0e12730f980cc867e Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Fri, 22 Sep 2023 10:52:47 -0700 Subject: [PATCH 072/121] [QNN EP] Integrate Resize op fixes from QNN 2.14.1 (#17641) ### Description QNN SDK version 2.14.1 fixed several issues with the QNN Resize operator. This PR integrates the fixes and simplifies the implementation. ### Motivation and Context Improve Resize operator and test coverage.
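The core simplification in this patch is a table-driven lookup: ONNX attribute strings map directly to QNN enum values, and unsupported modes produce a descriptive error instead of relying on array indices. A rough Python sketch of that pattern follows; the numeric values come from the comments in the removed code (HALF_PIXEL = 0, PYTORCH_HALF_PIXEL = 1, ALIGN_CORNERS = 2, ASYMMETRIC = 3), while the real C++ code uses the QNN_OP_RESIZE_TRANSFORMATION_MODE_* constants rather than raw integers.

```python
# Sketch of the string -> enum lookup pattern from the rewritten
# resize_op_builder.cc. Values mirror comments in the removed code;
# the C++ implementation uses QNN_OP_RESIZE_* constants, not raw ints.
SUPPORTED_COORD_TRANSF_MODES = {
    "half_pixel": 0,
    "pytorch_half_pixel": 1,
    "align_corners": 2,
    "asymmetric": 3,
}


def get_qnn_mode_val(supported_modes: dict, onnx_attr_value: str, onnx_attr_name: str) -> int:
    """Map an ONNX attribute string to its QNN value, failing with a clear message."""
    if onnx_attr_value not in supported_modes:
        raise RuntimeError(
            f"QNN EP: Resize operator does not support {onnx_attr_name} {onnx_attr_value}"
        )
    return supported_modes[onnx_attr_value]


assert get_qnn_mode_val(SUPPORTED_COORD_TRANSF_MODES, "align_corners",
                        "coordinate_transformation_mode") == 2
```

Compared with the old index-based arrays, an explicit map keeps each string and its enum value on the same line, so reordering entries can no longer silently change the emitted QNN parameter.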
--- .../builder/opbuilder/resize_op_builder.cc | 379 ++++++------------ .../providers/cpu/tensor/resize_op_test.cc | 38 +- onnxruntime/test/providers/qnn/resize_test.cc | 224 ++++++++--- 3 files changed, 308 insertions(+), 333 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc index 511f2a5149f2e..4039c4fbf8d70 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc @@ -2,7 +2,8 @@ // Licensed under the MIT License. #include -#include +#include +#include #include "core/providers/common.h" #include "core/providers/shared/utils/utils.h" @@ -42,76 +43,6 @@ class ResizeOpBuilder : public BaseOpBuilder { bool do_op_validation) const override ORT_MUST_USE_RESULT; private: - /** - * Returns the QNN integer value that corresponds to the given ONNX mode (string). - * - * /param onnx_modes Array of ONNX modes supported by QNN. The index of each mode corresponds to the QNN value. - * /param onnx_mode The ONNX mode for which to get the corresponding QNN value. - * /param onnx_model_label Mode label to print out in case of error (e.g., "nearest_mode"). - * /param qnn_mode Output parameter that is set to the appropriate QNN value from the given ONNX mode. - * - * /returns A status indicating failure or success. - */ - template - Status GetQnnModeFromString(const std::array& onnx_modes, std::string_view onnx_mode, - const char* onnx_mode_label, QnnValType& qnn_mode) const ORT_MUST_USE_RESULT; - - /** - * Called by IsOpSupported to validate the op for non-quantized models. - * - * /param qnn_model_wrapper The QNN model wrapper instance. - * /param node_unit The node unit containing metadata for the ONNX Resize operator. - * - * /returns A status indicating failure or success. - */ - Status ValidateOp(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const ORT_MUST_USE_RESULT; - - /** - * Called by IsOpSupported to validate the op for quantized models. - * - * /param qnn_model_wrapper The QNN model wrapper instance. - * /param node_unit The node unit containing metadata for the ONNX Resize operator and its Q/DQ nodes. - * - * /returns A status indicating failure or success. - */ - Status ValidateQDQOp(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const ORT_MUST_USE_RESULT; - - /** - * Called by ProcessAttributesAndOutputs to process the op's attributes and outputs - * for non-quantized models. - * - * /param qnn_model_wrapper The QNN model wrapper instance. - * /param node_unit The node unit containing metadata for the ONNX Resize operator. - * /param input_names The operator's input names. - * /param logger A logger. - * /param do_op_validation Set to true if the op should be validated using QNN's validation API. - * - * /returns A status indicating failure or success. - */ - Status ProcessOpAttrsAndOutputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - std::vector&& input_names, - const logging::Logger& logger, - bool do_op_validation) const ORT_MUST_USE_RESULT; - - /** - * Called by ProcessAttributesAndOutputs to process the op's attributes and outputs - * for quantized models. - * - * /param qnn_model_wrapper The QNN model wrapper instance. - * /param node_unit The node unit containing metadata for the ONNX Resize operator and its Q/DQ nodes. - * /param input_names The operator's input names. - * /param logger A logger. 
- * /param do_op_validation Set to true if the op should be validated using QNN's validation API. - * - * /returns A status indicating failure or success. - */ - Status ProcessQDQOpAttrsAndOutputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - std::vector&& input_names, - const logging::Logger& logger, - bool do_op_validation) const ORT_MUST_USE_RESULT; - // Info for each ONNX attribute of interest (attribute name + default value) static const OnnxAttrInfo onnx_mode_attr; static const OnnxAttrInfo onnx_coord_transf_mode_attr; @@ -119,21 +50,29 @@ class ResizeOpBuilder : public BaseOpBuilder { static const OnnxAttrInfo onnx_antialias_attr; static const OnnxAttrInfo onnx_exclude_outside_attr; - // Arrays of supported QNN modes for QNN's Resize op. The index of each mode is used as the corresponding - // QNN parameter value. Ex: The "nearest" mode is represented as the value 0 in QNN. Note, that - // not all modes are supported by every QNN backend. + // Tables that map an ONNX attribute value (string) to the corresponding integer (enum) QNN parameter value. + // Ex: The "half_pixel" coordinate_transformation_mode is represented as the value 0 in QNN. + // Only the modes supported by QNN Resize are mapped by these tables. + static const std::unordered_map supported_modes; + static const std::unordered_map supported_coord_transf_modes; + static const std::unordered_map supported_nearest_modes; +}; - // QNN values: NEAREST = 0, LINEAR = 1 - static constexpr std::array supported_modes = {"nearest", "linear"}; +const std::unordered_map ResizeOpBuilder::supported_modes = { + {"nearest", QNN_OP_RESIZE_INTERPOLATION_MODE_NEAREST}, + {"linear", QNN_OP_RESIZE_INTERPOLATION_MODE_LINEAR}}; - // QNN values: HALF_PIXEL = 0, PYTORCH_HALF_PIXEL = 1, ALIGN_CORNERS = 2, ASYMMETRIC = 3 - static constexpr std::array supported_coord_transf_modes = {"half_pixel", "pytorch_half_pixel", - "align_corners", "asymmetric"}; +const std::unordered_map ResizeOpBuilder::supported_coord_transf_modes = { + {"half_pixel", QNN_OP_RESIZE_TRANSFORMATION_MODE_HALF_PIXEL}, + {"pytorch_half_pixel", QNN_OP_RESIZE_TRANSFORMATION_MODE_PYTORCH_HALF_PIXEL}, + {"align_corners", QNN_OP_RESIZE_TRANSFORMATION_MODE_ALIGN_CORNERS}, + {"asymmetric", QNN_OP_RESIZE_TRANSFORMATION_MODE_ASYMMETRIC}}; - // QNN values: ROUND_PREFER_FLOOR = 0, ROUND_PREFER_CEIL = 1, FLOOR = 2, CEIL = 3 - static constexpr std::array supported_nearest_modes = {"round_prefer_floor", "round_prefer_ceil", - "floor", "ceil"}; -}; +const std::unordered_map ResizeOpBuilder::supported_nearest_modes = { + {"round_prefer_floor", QNN_OP_RESIZE_NEAREST_MODE_ROUND_PREFER_FLOOR}, + {"round_prefer_ceil", QNN_OP_RESIZE_NEAREST_MODE_ROUND_PREFER_CEIL}, + {"floor", QNN_OP_RESIZE_NEAREST_MODE_FLOOR}, + {"ceil", QNN_OP_RESIZE_NEAREST_MODE_CEIL}}; const OnnxAttrInfo ResizeOpBuilder::onnx_mode_attr = {"mode", "nearest"}; const OnnxAttrInfo ResizeOpBuilder::onnx_coord_transf_mode_attr = {"coordinate_transformation_mode", @@ -143,19 +82,26 @@ const OnnxAttrInfo ResizeOpBuilder::onnx_nearest_mode_attr = {"near const OnnxAttrInfo ResizeOpBuilder::onnx_antialias_attr = {"antialias", 0}; const OnnxAttrInfo ResizeOpBuilder::onnx_exclude_outside_attr = {"exclude_outside", 0}; -template -Status ResizeOpBuilder::GetQnnModeFromString(const std::array& onnx_modes, - std::string_view onnx_mode, const char* onnx_mode_label, - QnnValType& qnn_mode) const { - for (size_t i = 0; i < onnx_modes.size(); ++i) { - if (onnx_modes[i] == onnx_mode) { - qnn_mode = SafeInt(i); - return 
Status::OK(); - } +// Returns the QNN parameter integer value that corresponds to the given ONNX attribute mode string value. +static Status GetQnnModeValFromOnnxString(const std::unordered_map& supported_qnn_modes, + const std::string& onnx_attr_value, + const char* onnx_attr_name, + uint32_t& qnn_mode_value) { + auto it = supported_qnn_modes.find(onnx_attr_value); + if (it != supported_qnn_modes.end()) { + qnn_mode_value = it->second; + return Status::OK(); } - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Resize operator does not support ", onnx_mode_label, - " ", std::string(onnx_mode)); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Resize operator does not support ", onnx_attr_name, + " ", std::string(onnx_attr_value)); +} + +// Returns true if the given ONNX attribute mode value is generally supported on QNN. Note that +// different QNN backends may support a smaller subset of modes. +static bool IsOnnxAttrModeSupported(const std::unordered_map& supported_qnn_modes, + const std::string& onnx_attr_value) { + return supported_qnn_modes.find(onnx_attr_value) != supported_qnn_modes.end(); } // Resize ops are sensitive with data layout, no special validation so far @@ -169,118 +115,95 @@ Status ResizeOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true); } + const bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); + NodeAttrHelper node_helper(node_unit); + // QNN doesn't support anti-aliasing (added in opset 18) if (node_unit.SinceVersion() >= 18) { - NodeAttrHelper node_helper(node_unit); const bool antialias = GetOnnxAttr(node_helper, onnx_antialias_attr) != 0; ORT_RETURN_IF(antialias, "QNN EP: Resize doesn't support anti-aliasing."); } - // The QNN Resize op does not currently work with the QNN cpu backend, but works with the HTP backend. Therefore, we - // currently use QNN's Resize op for quantized models and either ResizeBilinear or ResizeNearestNeighbor for - // non-quantized models. This requires separate validation for quantized models. - // TODO: Use only Resize once QNN's Resize op works in the QNN cpu backend. - bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); - return is_npu_backend ? ValidateQDQOp(qnn_model_wrapper, node_unit) : ValidateOp(qnn_model_wrapper, node_unit); -} - -Status ResizeOpBuilder::ValidateOp(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const { - NodeAttrHelper node_helper(node_unit); - const std::string resize_mode = GetOnnxAttr(node_helper, onnx_mode_attr); - ORT_RETURN_IF((resize_mode != "nearest") && (resize_mode != "linear"), - "QNN EP: Resize doesn't support mode '", resize_mode.c_str(), "'.", - "Only 'nearest' and 'linear' are supported."); - - const std::string coordinate_mode = GetOnnxAttr(node_helper, onnx_coord_transf_mode_attr); - ORT_RETURN_IF((coordinate_mode != "half_pixel") && (coordinate_mode != "align_corners"), - "QNN EP: coordinate transformation mode '", coordinate_mode.c_str(), "' not supported for Resize op.", - "Only 'align_corners' and 'half_pixel' are supported."); - - // Check for a valid "nearest_mode" if the mode is "nearest". - if (resize_mode == "nearest") { - // NOTE: QNN's ResizeNearestNeighbor operator does not have a way to specify rounding (i.e., "nearest_mode"). - // The output of the QNN ResizeNearestNeighbor operator is not always equivalent to ONNX's Resize - // operator with any single specific "nearest_mode". 
- // - // For some input/output shapes, QNN's ResizeNearestNeighbor is equivalent to ONNX's Resize with "round_prefer_floor". - // For other shapes, QNN's ResizeNearestNeighbor is equivalent to ONNX Resize with "round_prefer_ceil". - // - // From unit tests, I've found a relationship between input/output shapes and the equivalent ONNX "nearest_mode". - // If the new and old spatial dimensions are evenly divisible, the "nearest_mode" is "round_prefer_floor". - // Otherwise, the "nearest_mode" is "round_prefer_ceil". - // - // This relationship is probably incomplete/wrong. - // - // TODO: Ask Qualcomm what the correct "nearest_mode" should be, - // OR use QNN's own Resize operator once it works on QnnCpu. - const std::string& nearest_mode = GetOnnxAttr(node_helper, onnx_nearest_mode_attr); - ORT_RETURN_IF_NOT("floor" == nearest_mode, "QNN Resize only supports nearest_mode: floor!"); // This is wrong! - } - - auto& input_0 = node_unit.Inputs()[0]; - std::vector input_shape; - ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input_0.node_arg, input_shape), - "QNN EP: Cannot get input shape for Resize op"); - - const auto& output_0 = node_unit.Outputs()[0]; - std::vector output_shape; - ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(output_0.node_arg, output_shape), - "QNN EP: Cannot get output shape for Resize op"); - - ORT_RETURN_IF(input_shape.size() != 4 || output_shape.size() != 4, "QNN Resize only supports 4D!"); - - ONNX_NAMESPACE::DataType input_data_type = input_0.node_arg.Type(); - ORT_RETURN_IF(input_data_type != ONNX_NAMESPACE::Utils::DataTypeUtils::ToType("float"), - "QNN EP: Data type ", input_data_type->c_str(), - " is not supported for Resize operator in CPU backend."); - - return Status::OK(); -} - -Status ResizeOpBuilder::ValidateQDQOp(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const { - NodeAttrHelper node_helper(node_unit); - - using namespace onnxruntime::qnn::utils; // Check mode const std::string interp_mode = GetOnnxAttr(node_helper, onnx_mode_attr); - ORT_RETURN_IF_NOT(ArrayHasString(supported_modes, interp_mode), "QNN EP: Resize does not support mode ", + ORT_RETURN_IF_NOT(IsOnnxAttrModeSupported(supported_modes, interp_mode), "QNN EP: Resize does not support mode ", interp_mode.c_str()); // Check coordinate transformation mode const std::string transformation_mode = GetOnnxAttr(node_helper, onnx_coord_transf_mode_attr); - ORT_RETURN_IF_NOT(ArrayHasString(supported_coord_transf_modes, transformation_mode), + ORT_RETURN_IF_NOT(IsOnnxAttrModeSupported(supported_coord_transf_modes, transformation_mode), "QNN EP: Resize does not support coordinate_transformation_mode ", transformation_mode.c_str()); - // Check nearest mode + const auto& input_0 = node_unit.Inputs()[0]; + std::vector input_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input_0.node_arg, input_shape), + "QNN EP: Cannot get shape for Resize input"); + const size_t input_rank = input_shape.size(); + + // Validate Resize w/ "nearest" mode. + // Translation matrix of ONNX Resize w/ "nearest" mode on HTP backend. + // Table entries correspond to the QNN operator used for the given configuration + // (Resize = QNN Resize op, RNN = QNN ResizeNearestNeighbor op, X = Unsupported). 
+ // + // nearest_mode: + // coordinate_transformation_mode: | round_prefer_floor round_prefer_ceil floor ceil + // ----------------------------------------------------------------------------------------- + // half_pixel | Resize X RNN X + // pytorch_half_pixel | Resize X X X + // align_corners | Resize X RNN X + // asymmetric | Resize X RNN X + if (interp_mode == "nearest") { const std::string nearest_mode = GetOnnxAttr(node_helper, onnx_nearest_mode_attr); - ORT_RETURN_IF_NOT(ArrayHasString(supported_nearest_modes, nearest_mode), + ORT_RETURN_IF_NOT(IsOnnxAttrModeSupported(supported_nearest_modes, nearest_mode), "QNN EP: Resize does not support nearest_mode ", nearest_mode.c_str()); - // TODO: Support 'asymmetric' transformation mode with nearest_mode != 'floor'. - // - // QNN's ONNX converter tool translates 'nearest' + 'asymmetric' (regardless of rounding mode) - // to QNN's ResizeNearestNeighbor with {align_corners: 0, half_pixel: 0}. - // This is only accurate if the rounding mode is "floor". Need to investigate how to handle - // other rounding modes with Qualcomm. Ideally, we would use QNN's Resize operator, but it doesn't support - // the "asymmetric" coordinate transformation mode on HTP. - ORT_RETURN_IF(transformation_mode == "asymmetric" && nearest_mode != "floor", - "QNN EP: Resize with coordinate_transformation_mode 'asymmetric' and nearest_mode '", nearest_mode, - "' is not currently supported on the HTP backend."); + if (is_npu_backend) { + // QNN only supports the following nearest_mode values on HTP: + // - "round_prefer_floor" via QNN's Resize operator + // - "floor" via QNN's ResizeNearestNeighbor operator + // + // QNN validation does not throw an error if unsupported nearest_mode values are used, so we have to + // catch them here. Otherwise, accuracy is significantly degraded. + ORT_RETURN_IF_NOT(nearest_mode == "round_prefer_floor" || nearest_mode == "floor", + "QNN EP: Resize on the NPU does not support nearest_mode ", nearest_mode.c_str()); + + const bool use_resize_nn_op = nearest_mode == "floor"; + + // If HTP uses ResizeNearestNeighbor ("floor"), then the "pytorch_half_pixel" coordinate_transformation_mode + // is not supported. + ORT_RETURN_IF(use_resize_nn_op && transformation_mode == "pytorch_half_pixel", + "QNN EP: Resize on the NPU does not support the combination of nearest_mode == 'floor' ", + " and coordinate_transformation_mode == 'pytorch_half_pixel'."); + + // QNN's ResizeNearestNeighbor requires rank 4 inputs. + ORT_RETURN_IF(use_resize_nn_op && input_rank != 4, + "QNN EP: Resize on the NPU with nearest_mode == 'floor' requires an input with rank 4."); + } } - // Check that input shape has at least a rank of 3. - const auto& input_0 = node_unit.Inputs()[0]; - std::vector input_shape; - ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input_0.node_arg, input_shape), - "QNN EP: Cannot get shape for Resize input"); - ORT_RETURN_IF(input_shape.size() < 3, "QNN EP: Resize input must have a rank >= 3."); + // Check that the input shape has at least a rank of 3 (and a max of 5 on HTP). + ORT_RETURN_IF(input_rank < 3 || (is_npu_backend && input_rank > 5), + "QNN EP: Resize input must have a rank >= 3. 
The maximum rank is 5 on the NPU."); const auto& output_0 = node_unit.Outputs()[0]; std::vector output_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(output_0.node_arg, output_shape), "QNN EP: Cannot get shape for Resize output"); - ORT_RETURN_IF(output_shape.size() < 3, "QNN EP: Resize output must have a rank >= 3."); + + // Check that only the spatial dimensions (width, height) are resized. The batch_size (N) and channels (C) should + // be untouched. This code runs before layout transformation, so we know that the current layout is "channel first" + // (e.g., N, C, S1, S2, ..., SN), and that the minimum rank is 3. + assert(node_unit.Domain() != kMSInternalNHWCDomain); + ORT_RETURN_IF_NOT(input_shape[0] == output_shape[0] && input_shape[1] == output_shape[1], + "QNN EP: Resize may only change the spatial dimensions."); + + if (!is_npu_backend) { + ONNX_NAMESPACE::DataType input_data_type = input_0.node_arg.Type(); + ORT_RETURN_IF(input_data_type != ONNX_NAMESPACE::Utils::DataTypeUtils::ToType("float"), + "QNN EP: Data type ", input_data_type->c_str(), + " is not supported for Resize operator in CPU backend."); + } return Status::OK(); } @@ -305,92 +228,34 @@ Status ResizeOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w std::vector&& input_names, const logging::Logger& logger, bool do_op_validation) const { - // The QNN Resize op does not currently work with the QNN cpu backend, but works with the HTP backend. Therefore, we - // currently use QNN's Resize op for quantized models and either ResizeBilinear or ResizeNearestNeighbor for - // non-quantized models. This requires separate handling for quantized models. - // TODO: Use only Resize once QNN's Resize op works in the QNN cpu backend. - bool is_quantized_node = NodeUnit::Type::QDQGroup == node_unit.UnitType(); - return is_quantized_node ? 
ProcessQDQOpAttrsAndOutputs(qnn_model_wrapper, node_unit, std::move(input_names), logger, do_op_validation) : ProcessOpAttrsAndOutputs(qnn_model_wrapper, node_unit, std::move(input_names), logger, do_op_validation); -} - -Status ResizeOpBuilder::ProcessOpAttrsAndOutputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - std::vector&& input_names, - const logging::Logger& logger, - bool do_op_validation) const { - ORT_UNUSED_PARAMETER(logger); - NodeAttrHelper node_helper(node_unit); - const std::string resize_mode = GetOnnxAttr(node_helper, onnx_mode_attr); - std::string qnn_node_type = "ResizeNearestNeighbor"; - if ("linear" == resize_mode) { - qnn_node_type = "ResizeBilinear"; - } - - const std::string coordinate_mode = GetOnnxAttr(node_helper, onnx_coord_transf_mode_attr); - - Qnn_Scalar_t qnn_align_corners = QNN_SCALAR_INIT; - qnn_align_corners.dataType = QNN_DATATYPE_BOOL_8; - qnn_align_corners.bool8Value = static_cast(0); - - Qnn_Scalar_t qnn_half_pixel = QNN_SCALAR_INIT; - qnn_half_pixel.dataType = QNN_DATATYPE_BOOL_8; - qnn_half_pixel.bool8Value = static_cast(0); - - if ("align_corners" == coordinate_mode) { - qnn_align_corners.bool8Value = static_cast(1); - } else if ("half_pixel" == coordinate_mode) { - qnn_half_pixel.bool8Value = static_cast(1); - } - QnnParamWrapper qnn_align_corners_param(node_unit.Index(), node_unit.Name(), - QNN_OP_RESIZE_BILINEAR_PARAM_ALIGN_CORNERS, qnn_align_corners); - QnnParamWrapper qnn_half_pixel_param(node_unit.Index(), node_unit.Name(), - QNN_OP_RESIZE_BILINEAR_PARAM_HALF_PIXEL_CENTERS, qnn_half_pixel); - - std::vector param_tensor_names; - param_tensor_names.push_back(qnn_align_corners_param.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(qnn_align_corners_param)); - param_tensor_names.push_back(qnn_half_pixel_param.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(qnn_half_pixel_param)); - - return ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), std::move(param_tensor_names), - logger, do_op_validation, qnn_node_type); -} - -Status ResizeOpBuilder::ProcessQDQOpAttrsAndOutputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - std::vector&& input_names, - const logging::Logger& logger, - bool do_op_validation) const { std::vector param_tensor_names; NodeAttrHelper node_helper(node_unit); const std::string interp_mode = GetOnnxAttr(node_helper, onnx_mode_attr); const std::string transformation_mode = GetOnnxAttr(node_helper, onnx_coord_transf_mode_attr); + const std::string nearest_mode = GetOnnxAttr(node_helper, onnx_nearest_mode_attr); + const bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); std::string qnn_op_type = "Resize"; - // Handle Resize with {mode: "nearest", coordinate_transformation_mode: "asymmetric"} uniquely. - // QNN's ONNX converter tool translates this configuration (regardless of rounding mode) - // to QNN's ResizeNearestNeighbor with {align_corners: 0, half_pixel: 0}. - // - // NOTE: This is only accurate if the rounding mode is "floor". Need to investigate how to handle - // other rounding modes with Qualcomm. Ideally, we would use QNN's Resize operator, but it doesn't support - // the "asymmetric" coordinate transformation mode on HTP. - if (interp_mode == "nearest" && transformation_mode == "asymmetric") { + // Translate Resize with {mode: "nearest", nearest_mode: "floor", coordinate_transformation_mode: XXX} to + // QNN's ResizeNearestNeighbor operator on the HTP backend. 
This combination of parameters is not supported on HTP + // via QNN's Resize operator. Note that QNN's ResizeNearestNeighbor operator always uses "floor" rounding. + if (is_npu_backend && interp_mode == "nearest" && nearest_mode == "floor") { qnn_op_type = "ResizeNearestNeighbor"; - // Set parameter 'align_corners' to 0 + // Parameter 'align_corners' Qnn_Scalar_t qnn_align_corners = QNN_SCALAR_INIT; qnn_align_corners.dataType = QNN_DATATYPE_BOOL_8; - qnn_align_corners.bool8Value = static_cast(0); + qnn_align_corners.bool8Value = static_cast(transformation_mode == "align_corners"); QnnParamWrapper qnn_align_corners_param(node_unit.Index(), node_unit.Name(), QNN_OP_RESIZE_BILINEAR_PARAM_ALIGN_CORNERS, qnn_align_corners); param_tensor_names.push_back(qnn_align_corners_param.GetParamTensorName()); qnn_model_wrapper.AddParamWrapper(std::move(qnn_align_corners_param)); - // Set parameter 'half_pixel_centers' to 0 + // Parameter 'half_pixel_centers' Qnn_Scalar_t qnn_half_pixel = QNN_SCALAR_INIT; qnn_half_pixel.dataType = QNN_DATATYPE_BOOL_8; - qnn_half_pixel.bool8Value = static_cast(0); + qnn_half_pixel.bool8Value = static_cast(transformation_mode == "half_pixel"); QnnParamWrapper qnn_half_pixel_param(node_unit.Index(), node_unit.Name(), QNN_OP_RESIZE_BILINEAR_PARAM_HALF_PIXEL_CENTERS, qnn_half_pixel); param_tensor_names.push_back(qnn_half_pixel_param.GetParamTensorName()); @@ -399,11 +264,12 @@ Status ResizeOpBuilder::ProcessQDQOpAttrsAndOutputs(QnnModelWrapper& qnn_model_w // Parameter 'transformation_mode' Qnn_Scalar_t qnn_transformation_mode = QNN_SCALAR_INIT; qnn_transformation_mode.dataType = QNN_DATATYPE_UINT_32; - ORT_RETURN_IF_ERROR(GetQnnModeFromString(supported_coord_transf_modes, transformation_mode, - "coordinate_transformation_mode", qnn_transformation_mode.uint32Value)); + ORT_RETURN_IF_ERROR(GetQnnModeValFromOnnxString(supported_coord_transf_modes, transformation_mode, + "coordinate_transformation_mode", + qnn_transformation_mode.uint32Value)); - QnnParamWrapper qnn_transformation_mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_RESIZE_PARAM_TRANSFORMATION_MODE, - qnn_transformation_mode); + QnnParamWrapper qnn_transformation_mode_param(node_unit.Index(), node_unit.Name(), + QNN_OP_RESIZE_PARAM_TRANSFORMATION_MODE, qnn_transformation_mode); param_tensor_names.push_back(qnn_transformation_mode_param.GetParamTensorName()); qnn_model_wrapper.AddParamWrapper(std::move(qnn_transformation_mode_param)); @@ -420,7 +286,7 @@ Status ResizeOpBuilder::ProcessQDQOpAttrsAndOutputs(QnnModelWrapper& qnn_model_w // Parameter 'interpolation_mode' Qnn_Scalar_t qnn_interp_mode = QNN_SCALAR_INIT; qnn_interp_mode.dataType = QNN_DATATYPE_UINT_32; - ORT_RETURN_IF_ERROR(GetQnnModeFromString(supported_modes, interp_mode, "mode", qnn_interp_mode.uint32Value)); + ORT_RETURN_IF_ERROR(GetQnnModeValFromOnnxString(supported_modes, interp_mode, "mode", qnn_interp_mode.uint32Value)); QnnParamWrapper qnn_interp_mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_RESIZE_PARAM_INTERPOLATION_MODE, qnn_interp_mode); @@ -429,11 +295,10 @@ Status ResizeOpBuilder::ProcessQDQOpAttrsAndOutputs(QnnModelWrapper& qnn_model_w // Parameter 'nearest_mode'. Processed only when 'interpolation_mode' is NEAREST(0). 
if (qnn_interp_mode.uint32Value == 0) { - const std::string nearest_mode = GetOnnxAttr(node_helper, onnx_nearest_mode_attr); Qnn_Scalar_t qnn_nearest_mode = QNN_SCALAR_INIT; qnn_nearest_mode.dataType = QNN_DATATYPE_UINT_32; - ORT_RETURN_IF_ERROR(GetQnnModeFromString(supported_nearest_modes, nearest_mode, "nearest_mode", - qnn_nearest_mode.uint32Value)); + ORT_RETURN_IF_ERROR(GetQnnModeValFromOnnxString(supported_nearest_modes, nearest_mode, "nearest_mode", + qnn_nearest_mode.uint32Value)); QnnParamWrapper qnn_nearest_mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_RESIZE_PARAM_NEAREST_MODE, qnn_nearest_mode); diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 832a8a744c08b..0434b16dc66ce 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -99,9 +99,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extr // CUDA: result mismatch due to not implementing NHWC support // TensorRT: results mismatch // ROCm: results mismatch - // QNN: conflict with layout transformer, need furture investigation test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extrapolation_uint8) { @@ -131,7 +130,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extr test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extrapolation_int8) { @@ -159,7 +158,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extr 10, 10, 10}; test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); + test.Run(); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_extrapolation_uint8) { @@ -188,7 +187,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_e test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_extrapolation_int8) { @@ -215,7 +214,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_e 0, 0, 0}; test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); + test.Run(); } TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear) { @@ -261,9 +260,8 @@ 
TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear) { test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - // QNN: conflict with layout transformer, need furture investigation test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + {kCudaExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_uint8) { @@ -287,7 +285,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_uint8) { test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_int8) { @@ -309,7 +307,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_int8) { std::vector Y = {0, 0}; test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); + test.Run(); } // Since NNAPI(TFLite) only using the scale calculate using the input/output size @@ -399,7 +397,9 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear_align_corners) { std::vector Y = {1.0f, 4.0f}; test.AddOutput("Y", {N, C, static_cast(H * scales[2]), static_cast(W * scales[3])}, Y); - test.Run(); + + // QNN: result mismatch ("NaN" instead of 1.0f on QNN CPU backend) + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); }; run_test(false); @@ -435,7 +435,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_align_corners_uin test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); }; run_test(false); @@ -465,7 +465,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_align_corners_int test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // TensorRT: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); }; run_test(false); @@ -532,7 +532,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixe test.AddOutput("Y", {N, sizes[1], sizes[2], C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixel_int8) { @@ -560,7 +560,7 @@ TEST(ResizeOpTest, 
NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixe std::vector Y = {0, 2, -9}; test.AddOutput("Y", {N, sizes[1], sizes[2], C}, Y); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); // TensorRT: results mismatch + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: results mismatch } TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_4DBilinear_asymmetric) { @@ -641,7 +641,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearUpSampleTest_4DBilinear_asymmetric_uint8) { Y, false, .0f, 1.0f); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); }; run_test(false); @@ -683,7 +683,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearUpSampleTest_4DBilinear_asymmetric_int8) { test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y, false, .0f, 1.0f); // TensorRT: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); }; run_test(false); @@ -1079,7 +1079,7 @@ TEST(ResizeOpTest, ResizeOpNearestUpSample_Floor_Align_Corners) { 13.0f, 13.0f, 13.0f, 14.0f, 14.0f, 15.0f, 15.0f, 16.0f}; test.AddOutput("Y", {N, C, static_cast(H * scales[2]), static_cast(W * scales[3])}, Y); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); // QNN: result diff + test.Run(); } TEST(ResizeOpTest, ResizeOpNearest_OneToOneMappingBetweenInputAndOutputDataDims) { @@ -1887,7 +1887,7 @@ void TestAntialiasing(std::map attributes, test.AddOutput("Y", output_shape, output_data); // TensorRT 8.5 supports operators up to Opset 17. Temporarily exclude TensorRT EP due to accurarcy issue. 
- test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } TEST(ResizeOpTest, Antialias_Bilinear_No_ExcludeOutside) { diff --git a/onnxruntime/test/providers/qnn/resize_test.cc b/onnxruntime/test/providers/qnn/resize_test.cc index cf336ca9eeb8b..cd6865d443cc0 100644 --- a/onnxruntime/test/providers/qnn/resize_test.cc +++ b/onnxruntime/test/providers/qnn/resize_test.cc @@ -120,7 +120,7 @@ static void RunCPUResizeOpTest(const TestInputDef& input_def, const std:: const std::string& mode, const std::string& coordinate_transformation_mode, const std::string& nearest_mode, ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 11) { + int opset = 19) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnCpu.dll"; @@ -138,7 +138,7 @@ static void RunCPUResizeOpTestWithScales(const TestInputDef& input_def, c const std::string& mode, const std::string& coordinate_transformation_mode, const std::string& nearest_mode, ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 11) { + int opset = 19) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnCpu.dll"; @@ -157,7 +157,8 @@ static void RunQDQResizeOpTest(const TestInputDef& input_def, const std::vector& sizes_data, const std::string& mode, const std::string& coordinate_transformation_mode, const std::string& nearest_mode, - ExpectedEPNodeAssignment expected_ep_assignment) { + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -169,27 +170,20 @@ static void RunQDQResizeOpTest(const TestInputDef& input_def, GetQDQResizeModelBuilder(input_def, sizes_data, mode, coordinate_transformation_mode, nearest_mode), provider_options, - 18, // opset - expected_ep_assignment, - 1e-5f); + opset, + expected_ep_assignment); } // // CPU tests: // -// TODO: Our QNN CPU translation of ONNX Resize with "nearest" mode uses QNN's ResizeNearestNeighbor -// operator, which does not have a way to specify rounding (i.e., "nearest_mode" in ONNX). It is not clear -// what kind of rounding QNN's ResizeNearestNeighbor uses. Therefore, we do not yet know how to compare -// ONNX Resize to QNN ResizeNearestNeighbor. These tests should remain disabled until this behavior is -// clarified. If, for example, it turns out that ResizeNearestNeighbor uses "floor" rounding, then we should -// only compare against ONNX resize with "floor" rounding. - // Upsample that uses "round_prefer_floor" as the "nearest_mode". // coordinate_transformation_mode: "half_pixel" -TEST_F(QnnCPUBackendTests, DISABLED_ResizeUpsampleNearestHalfPixel_rpf) { - RunCPUResizeOpTest(TestInputDef({1, 2, 7, 5}, false, -10.0f, 10.0f), // Random input w/ range [-10, 10] - {1, 2, 21, 10}, // Sizes +TEST_F(QnnCPUBackendTests, ResizeUpsampleNearestHalfPixel_rpf) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 70); + RunCPUResizeOpTest(TestInputDef({1, 2, 7, 5}, false, input_data), + {1, 2, 21, 10}, // Sizes "nearest", "half_pixel", "round_prefer_floor", @@ -198,57 +192,72 @@ TEST_F(QnnCPUBackendTests, DISABLED_ResizeUpsampleNearestHalfPixel_rpf) { // Upsample that uses "round_prefer_ceil" as the "nearest_mode". 
// coordinate_transformation_mode: "half_pixel" -TEST_F(QnnCPUBackendTests, DISABLED_ResizeUpsampleNearestHalfPixel_rpc) { - RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, -10.0f, 10.0f), +TEST_F(QnnCPUBackendTests, ResizeUpsampleNearestHalfPixel_rpc) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, input_data), {1, 1, 7, 5}, "nearest", "half_pixel", "round_prefer_ceil", ExpectedEPNodeAssignment::All); } // Downsample that uses "round_prefer_ceil" as the "nearest_mode". // coordinate_transformation_mode: "half_pixel" -TEST_F(QnnCPUBackendTests, DISABLED_ResizeDownsampleNearestHalfPixel_rpc) { - RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, -10.0f, 10.0f), +TEST_F(QnnCPUBackendTests, ResizeDownsampleNearestHalfPixel_rpc) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, input_data), {1, 1, 1, 3}, "nearest", "half_pixel", "round_prefer_ceil", ExpectedEPNodeAssignment::All); } // Downsample that uses "round_prefer_floor" as the "nearest_mode". // coordinate_transformation_mode: "half_pixel" -TEST_F(QnnCPUBackendTests, DISABLED_ResizeDownsampleNearestHalfPixel_rpf) { - RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, -10.0f, 10.0f), +TEST_F(QnnCPUBackendTests, ResizeDownsampleNearestHalfPixel_rpf) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, input_data), {1, 1, 1, 2}, "nearest", "half_pixel", "round_prefer_ceil", ExpectedEPNodeAssignment::All); } // Upsample that uses "round_prefer_floor" as the "nearest_mode". // coordinate_transformation_mode: "align_corners" -// QNN v2.13: index #50 don't match, which is 4.67152 from -1.93515 -TEST_F(QnnCPUBackendTests, DISABLED_ResizeUpsampleNearestAlignCorners_rpf) { - RunCPUResizeOpTest(TestInputDef({1, 2, 7, 5}, false, -10.0f, 10.0f), +TEST_F(QnnCPUBackendTests, ResizeUpsampleNearestAlignCorners_rpf) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 70); + RunCPUResizeOpTest(TestInputDef({1, 2, 7, 5}, false, input_data), {1, 2, 21, 10}, "nearest", "align_corners", "round_prefer_floor", ExpectedEPNodeAssignment::All); } +// Upsample that uses "round_prefer_floor" as the "nearest_mode". +// coordinate_transformation_mode: "asymmetric" +TEST_F(QnnCPUBackendTests, ResizeUpsampleNearestAsymmetric_rpf) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 70); + RunCPUResizeOpTest(TestInputDef({1, 2, 7, 5}, false, input_data), + {1, 2, 21, 10}, "nearest", "asymmetric", "round_prefer_floor", + ExpectedEPNodeAssignment::All); +} + // Upsample that uses "round_prefer_ceil" as the "nearest_mode". // coordinate_transformation_mode: "align_corners" -TEST_F(QnnCPUBackendTests, DISABLED_ResizeUpsampleNearestAlignCorners_rpc) { - RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, -10.0f, 10.0f), +TEST_F(QnnCPUBackendTests, ResizeUpsampleNearestAlignCorners_rpc) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, input_data), {1, 1, 7, 5}, "nearest", "align_corners", "round_prefer_ceil", ExpectedEPNodeAssignment::All); } // Downsample that uses "round_prefer_ceil" as the "nearest_mode". 
 // coordinate_transformation_mode: "align_corners"
-TEST_F(QnnCPUBackendTests, DISABLED_ResizeDownsampleNearestAlignCorners_rpc) {
-  RunCPUResizeOpTest(TestInputDef<float>({1, 1, 2, 4}, false, -10.0f, 10.0f),
+TEST_F(QnnCPUBackendTests, ResizeDownsampleNearestAlignCorners_rpc) {
+  std::vector<float> input_data = GetFloatDataInRange(-10.0f, 10.0f, 8);
+  RunCPUResizeOpTest(TestInputDef<float>({1, 1, 2, 4}, false, input_data),
                      {1, 1, 1, 3}, "nearest", "align_corners", "round_prefer_ceil",
                      ExpectedEPNodeAssignment::All);
 }

 // Downsample that uses "round_prefer_floor" as the "nearest_mode".
 // coordinate_transformation_mode: "align_corners"
-TEST_F(QnnCPUBackendTests, DISABLED_ResizeDownsampleNearestAlignCorners_rpf) {
-  RunCPUResizeOpTest(TestInputDef<float>({1, 1, 2, 4}, false, -10.0f, 10.0f),
+TEST_F(QnnCPUBackendTests, ResizeDownsampleNearestAlignCorners_rpf) {
+  std::vector<float> input_data = GetFloatDataInRange(-10.0f, 10.0f, 8);
+  RunCPUResizeOpTest(TestInputDef<float>({1, 1, 2, 4}, false, input_data),
                      {1, 1, 1, 2}, "nearest", "align_corners", "round_prefer_floor",
                      ExpectedEPNodeAssignment::All);
 }

@@ -258,76 +267,177 @@ TEST_F(QnnCPUBackendTests, DISABLED_ResizeDownsampleNearestAlignCorners_rpf) {
 //

 TEST_F(QnnCPUBackendTests, Resize2xLinearHalfPixel) {
-  RunCPUResizeOpTest(TestInputDef<float>({1, 3, 4, 5}, false, -10.0f, 10.0f),
+  std::vector<float> input_data = GetFloatDataInRange(-10.0f, 10.0f, 60);
+  RunCPUResizeOpTest(TestInputDef<float>({1, 3, 4, 5}, false, input_data),
                      {1, 3, 8, 10}, "linear", "half_pixel", "",
                      ExpectedEPNodeAssignment::All);
 }

 TEST_F(QnnCPUBackendTests, Resize2xLinearHalfPixel_scales) {
-  RunCPUResizeOpTestWithScales(TestInputDef<float>({1, 3, 4, 5}, false, -10.0f, 10.0f),
+  std::vector<float> input_data = GetFloatDataInRange(-10.0f, 10.0f, 60);
+  RunCPUResizeOpTestWithScales(TestInputDef<float>({1, 3, 4, 5}, false, input_data),
                                {1.0f, 1.0f, 2.0f, 2.0f}, "linear", "half_pixel", "",
                                ExpectedEPNodeAssignment::All);
 }

 TEST_F(QnnCPUBackendTests, Resize2xLinearAlignCorners) {
-  RunCPUResizeOpTest(TestInputDef<float>({1, 3, 4, 5}, false, -10.0f, 10.0f),
+  std::vector<float> input_data = GetFloatDataInRange(-10.0f, 10.0f, 60);
+  RunCPUResizeOpTest(TestInputDef<float>({1, 3, 4, 5}, false, input_data),
                      {1, 3, 8, 10}, "linear", "align_corners", "",
                      ExpectedEPNodeAssignment::All);
 }

 TEST_F(QnnCPUBackendTests, Resize2xLinearAlignCorners_scales) {
-  RunCPUResizeOpTestWithScales(TestInputDef<float>({1, 3, 4, 5}, false, -10.0f, 10.0f),
+  std::vector<float> input_data = GetFloatDataInRange(-10.0f, 10.0f, 60);
+  RunCPUResizeOpTestWithScales(TestInputDef<float>({1, 3, 4, 5}, false, input_data),
                                {1.0f, 1.0f, 2.0f, 2.0f}, "linear", "align_corners", "",
                                ExpectedEPNodeAssignment::All);
 }

+// Test Resize downsample with mode: "linear", coordinate_transformation_mode: "align_corners"
+// TODO: Enable ResizeOpTest.ResizeOpLinearDownSampleTest_4DBilinear_align_corners in cpu resize_op tests when fixed.
+//
+// Input f32[1,1,2,4]: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0
+// Expected output f32[1, 1, 1, 2]: 1.0, 4.0
+// Actual output f32[1, 1, 1, 2]: NaN, NaN
+TEST_F(QnnCPUBackendTests, DISABLED_Resize_DownSample_Linear_AlignCorners_scales) {
+  std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
+  RunCPUResizeOpTestWithScales(TestInputDef<float>({1, 1, 2, 4}, false, input_data),
+                               {1.0f, 1.0f, 0.6f, 0.6f}, "linear", "align_corners", "",
+                               ExpectedEPNodeAssignment::All);
+}
+
+// Test Resize downsample with mode: "linear", coordinate_transformation_mode: "half_pixel"
+// TODO: Enable ResizeOpTest.ResizeOpLinearDownSampleTest_4DBilinear in cpu resize_op tests when fixed.
+//
+// Input f32[1,1,2,4]: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0
+// Expected output f32[1, 1, 1, 2]: 2.6666 4.3333
+// Actual output f32[1, 1, 1, 2]: NaN, NaN
+TEST_F(QnnCPUBackendTests, DISABLED_Resize_DownSample_Linear_HalfPixel_scales) {
+  std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
+  RunCPUResizeOpTestWithScales(TestInputDef<float>({1, 1, 2, 4}, false, input_data),
+                               {1.0f, 1.0f, 0.6f, 0.6f}, "linear", "half_pixel", "",
+                               ExpectedEPNodeAssignment::All);
+}
+
 #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
 //
 // HTP tests:
 //

+// Test QDQ Resize downsample with mode: "linear", coordinate_transformation_mode: "align_corners"
+TEST_F(QnnHTPBackendTests, Resize_DownSample_Linear_AlignCorners) {
+  std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
+  RunQDQResizeOpTest(TestInputDef<float>({1, 1, 2, 4}, false, input_data),
+                     {1, 1, 1, 2}, "linear", "align_corners", "",
+                     ExpectedEPNodeAssignment::All);
+}
+
+// Test QDQ Resize downsample with mode: "linear", coordinate_transformation_mode: "half_pixel"
+TEST_F(QnnHTPBackendTests, Resize_DownSample_Linear_HalfPixel) {
+  std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
+  RunQDQResizeOpTest(TestInputDef<float>({1, 1, 2, 4}, false, input_data),
+                     {1, 1, 1, 2}, "linear", "half_pixel", "",
+                     ExpectedEPNodeAssignment::All);
+}
+
+// Test 2x QDQ Resize mode: "linear", coordinate_transformation_mode: "pytorch_half_pixel"
+// QNN EP uses QNN's Resize op.
 TEST_F(QnnHTPBackendTests, ResizeU8_2xLinearPytorchHalfPixel) {
-  RunQDQResizeOpTest(TestInputDef<float>({1, 3, 4, 4}, false, -10.0f, 10.0f),
+  std::vector<float> input_data = GetFloatDataInRange(-10.0f, 10.0f, 48);
+  RunQDQResizeOpTest(TestInputDef<float>({1, 3, 4, 4}, false, input_data),
                      {1, 3, 8, 8}, "linear", "pytorch_half_pixel", "",
                      ExpectedEPNodeAssignment::All);
 }

-TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestHalfPixelRoundPreferFloor) {
-  RunQDQResizeOpTest(TestInputDef<float>({1, 3, 4, 4}, false, -10.0f, 10.0f),
-                     {1, 3, 8, 8}, "nearest", "half_pixel", "round_prefer_floor",
+// Test 2x QDQ Resize mode: "linear", coordinate_transformation_mode: "half_pixel"
+// QNN EP uses QNN's Resize op.
+TEST_F(QnnHTPBackendTests, ResizeU8_2xLinearHalfPixel) {
+  std::vector<float> input_data = GetFloatDataInRange(-10.0f, 10.0f, 48);
+  RunQDQResizeOpTest(TestInputDef<float>({1, 3, 4, 4}, false, input_data),
+                     {1, 3, 8, 8}, "linear", "half_pixel", "",
                      ExpectedEPNodeAssignment::All);
 }

-TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestAsymmetricFloor) {
-  RunQDQResizeOpTest(TestInputDef<float>({1, 3, 4, 4}, false, -10.0f, 10.0f),
-                     {1, 3, 8, 8}, "nearest", "asymmetric", "floor",
+// Test 2x QDQ Resize mode: "linear", coordinate_transformation_mode: "align_corners"
+// QNN EP uses QNN's Resize op.
+TEST_F(QnnHTPBackendTests, ResizeU8_2xLinearAlignCorners) {
+  std::vector<float> input_data = GetFloatDataInRange(-10.0f, 10.0f, 48);
+  RunQDQResizeOpTest(TestInputDef<float>({1, 3, 4, 4}, false, input_data),
+                     {1, 3, 8, 8}, "linear", "align_corners", "",
                      ExpectedEPNodeAssignment::All);
 }

-// TODO: Investigate with Qualcomm. The qnn-onnx-converter tool translates ONNX Resize [nearest, asymmetric, ceil] to
-// QNN ResizeNearestNeighbor {align_corners: 0, half_pixel: 0}, which is NOT equivalent. It would be better to use
-// QNN's own Resize operator (instead of ResizeNearestNeighbor), but it doesn't support the "asymmetric" coordinate
-// transform mode.
-//
-// QNN v2.13: Inaccuracy detected for output 'output', element 189.
-// Output quant params: scale=0.078431375324726105, zero_point=127.
-// Expected val: -2.663428783416748
-// QNN QDQ val: 7.4509806632995605 (err 10.114409446716309)
-// CPU QDQ val: -2.6666667461395264 (err 0.0032379627227783203)
-TEST_F(QnnHTPBackendTests, DISABLED_ResizeU8_2xNearestAsymmetricCeil) {
-  RunQDQResizeOpTest(TestInputDef<float>({1, 3, 4, 4}, false, -10.0f, 10.0f),
-                     {1, 3, 8, 8}, "nearest", "asymmetric", "ceil",
+// Test 2x QDQ Resize mode: "linear", coordinate_transformation_mode: "asymmetric"
+// QNN EP uses QNN's Resize op.
+TEST_F(QnnHTPBackendTests, ResizeU8_2xLinearAsymmetric) {
+  std::vector<float> input_data = GetFloatDataInRange(-10.0f, 10.0f, 48);
+  RunQDQResizeOpTest(TestInputDef<float>({1, 3, 4, 4}, false, input_data),
+                     {1, 3, 8, 8}, "linear", "asymmetric", "",
                      ExpectedEPNodeAssignment::All);
 }

+// Test 2x QDQ Resize mode: "nearest", coordinate_transformation_mode: "half_pixel", nearest_mode: "round_prefer_floor"
+// QNN EP uses QNN's Resize op.
+TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestHalfPixelRoundPreferFloor) {
+  std::vector<float> input_data = GetFloatDataInRange(-10.0f, 10.0f, 48);
+  RunQDQResizeOpTest(TestInputDef<float>({1, 3, 4, 4}, false, input_data),
+                     {1, 3, 8, 8}, "nearest", "half_pixel", "round_prefer_floor",
+                     ExpectedEPNodeAssignment::All);
+}
+
+// Test that the nearest_mode "ceil" is not supported on the HTP backend.
+TEST_F(QnnHTPBackendTests, ResizeU8_NearestModeCeil_Unsupported) {
+  std::vector<float> input_data = GetFloatDataInRange(-10.0f, 10.0f, 48);
+  RunQDQResizeOpTest(TestInputDef<float>({1, 3, 4, 4}, false, input_data),
+                     {1, 3, 8, 8}, "nearest", "asymmetric", "ceil",
+                     ExpectedEPNodeAssignment::None);
+}
+
+// Test 3x QDQ Resize mode: "nearest", coordinate_transformation_mode: "asymmetric", nearest_mode: "floor".
+// QNN EP uses QNN's ResizeNearestNeighbor op.
 TEST_F(QnnHTPBackendTests, ResizeU8_3xNearestAsymmetricFloor) {
-  RunQDQResizeOpTest(TestInputDef<float>({1, 3, 4, 4}, false, -10.0f, 10.0f),
+  std::vector<float> input_data = GetFloatDataInRange(-10.0f, 10.0f, 48);
+  RunQDQResizeOpTest(TestInputDef<float>({1, 3, 4, 4}, false, input_data),
                      {1, 3, 12, 12}, "nearest", "asymmetric", "floor",
                      ExpectedEPNodeAssignment::All);
 }

+// Test 2x QDQ Resize mode: "nearest", coordinate_transformation_mode: "asymmetric", nearest_mode: "round_prefer_floor"
+// QNN EP uses QNN's Resize op.
+TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestAsymmetricRoundPreferFloor) {
+  std::vector<float> input_data = GetFloatDataInRange(-10.0f, 10.0f, 8);
+  RunQDQResizeOpTest(TestInputDef<float>({1, 2, 2, 2}, false, input_data),
+                     {1, 2, 4, 4}, "nearest", "asymmetric", "round_prefer_floor",
+                     ExpectedEPNodeAssignment::All);
+}
+
+// Test 3x QDQ Resize mode: "nearest", coordinate_transformation_mode: "asymmetric", nearest_mode: "round_prefer_floor"
+// QNN EP uses QNN's Resize op.
+//
+// TODO: Inaccuracy detected for output 'output_0', element 2.
+// Output quant params: scale=0.078431375324726105, zero_point=127.
+// Expected val: -3.3333334922790527
+// QNN QDQ val: -9.960784912109375 (err 6.6274514198303223)
+// CPU QDQ val: -3.2941176891326904 (err 0.039215803146362305)
+//
+// More debugging info:
+// Input elements f32[1,1,2,2] = -10.0000000 -3.33333349 3.33333302 10.0000000
+// ORT CPU EP (f32 model) outputs: -10.0000000 -10.0000000 -3.33333349 -3.33333349 -3.33333349 -3.33333349 -10.00 ...
+// ORT CPU EP (qdq model) outputs: -9.96078491 -9.96078491 -3.29411769 -3.29411769 -3.29411769 -3.29411769 -9.961 ...
+// ORT QNN EP (qdq model) outputs: -9.96078491 -9.96078491 -9.96078491 -3.37254906 -3.37254906 -3.37254906 -9.961 ...
+TEST_F(QnnHTPBackendTests, DISABLED_ResizeU8_3xNearestAsymmetricRoundPreferFloor) {
+  std::vector<float> input_data = GetFloatDataInRange(-10.0f, 10.0f, 4);
+  RunQDQResizeOpTest(TestInputDef<float>({1, 1, 2, 2}, false, input_data),
+                     {1, 1, 6, 6}, "nearest", "asymmetric", "round_prefer_floor",
+                     ExpectedEPNodeAssignment::All);
+}
+
+// Test 0.5x QDQ Resize mode: "nearest", coordinate_transformation_mode: "asymmetric", nearest_mode: "floor"
+// QNN EP uses QNN's ResizeNearestNeighbor op.
 TEST_F(QnnHTPBackendTests, ResizeU8_HalfNearestAsymmetricFloor) {
-  RunQDQResizeOpTest(TestInputDef<float>({1, 3, 4, 4}, false, -10.0f, 10.0f),
+  std::vector<float> input_data = GetFloatDataInRange(-10.0f, 10.0f, 48);
+  RunQDQResizeOpTest(TestInputDef<float>({1, 3, 4, 4}, false, input_data),
                      {1, 3, 2, 2}, "nearest", "asymmetric", "floor",
                      ExpectedEPNodeAssignment::All);
 }

From ce287a4e77895e7f6147a044ae5c723a48cb8277 Mon Sep 17 00:00:00 2001
From: Wanming Lin
Date: Sat, 23 Sep 2023 07:06:04 +0800
Subject: [PATCH 073/121] [WebNN EP] Remove workaround for dynamic shape
 (#17644)

Now that we have the FreeDimensionOverrides option to support dynamic shapes,
we can remove the previous workaround.
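For reference, this is roughly how a caller pins the free dimensions before
creating a session. The web bindings expose this as
sessionOptions.freeDimensionOverrides; the sketch below uses the equivalent
free-dimension-override option of the Python API, with a hypothetical model
file and dim_param name "batch":

    import onnxruntime as ort

    # Override every dim_param with a concrete value so that all input shapes
    # are fully static by the time the EP inspects them.
    opts = ort.SessionOptions()
    opts.add_free_dimension_override_by_name("batch", 1)
    sess = ort.InferenceSession("model.onnx", sess_options=opts)

With the overrides applied, every dimension reaching IsInputSupported() carries
a concrete dim_value, so inputs that still have a dim_param can simply be
rejected instead of silently assumed to be 1.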
---
 onnxruntime/core/providers/webnn/builders/helper.cc        | 7 +++++--
 .../core/providers/webnn/builders/model_builder.cc         | 9 +++------
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc
index 31453e005272e..774df067fe347 100644
--- a/onnxruntime/core/providers/webnn/builders/helper.cc
+++ b/onnxruntime/core/providers/webnn/builders/helper.cc
@@ -53,9 +53,12 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons
   }

   for (const auto& dim : shape_proto->dim()) {
-    // For now we workaround dynamic shape support by assuming 1.
+    // WebNN doesn't support dynamic shape - use sessionOptions.freeDimensionOverrides to fix the shape.
     if (!dim.has_dim_value()) {
-      LOGS(logger, VERBOSE) << "Dynamic shape is not supported for now, assume to be 1, for input:" << input_name;
+      LOGS(logger, VERBOSE) << "Dynamic shape is not supported, "
+                            << "use sessionOptions.FreeDimensionOverrides to set a fixed shape for input: "
+                            << input_name;
+      return false;
     }
   }

diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc
index 14ca4f1a1e674..2eae8cebbbd66 100644
--- a/onnxruntime/core/providers/webnn/builders/model_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc
@@ -218,12 +218,9 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i
   } else {
     dims.reserve(shape.size());
     for (const auto& dim : shape) {
-      if (!dim.has_dim_value()) {
-        // FIXME: support dyanmic shape.
-        dims.push_back(1);
-      } else {
-        dims.push_back(SafeInt<int32_t>(dim.dim_value()));
-      }
+      // dim_param free dimensions should have already been excluded by IsInputSupported().
+      assert(dim.has_dim_value());
+      dims.push_back(SafeInt<int32_t>(dim.dim_value()));
     }
   }
 }

From 216214b7d302cb504d1e5a647f65b6fe49c22dbb Mon Sep 17 00:00:00 2001
From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com>
Date: Mon, 25 Sep 2023 10:35:28 +0800
Subject: [PATCH 074/121] [ROCm] Remove ROCm5.4.2, ROCm 5.5 and add ROCm5.7 to
 python package pipeline (#17668)

- Remove ROCm 5.4.2 and ROCm 5.5 from the Python package pipeline, and add ROCm 5.7
- Remove a redundant build arg
---
 ...orttraining-py-packaging-pipeline-rocm.yml | 30 ++++++-------------
 .../github/azure-pipelines/templates/rocm.yml |  1 -
 2 files changed, 9 insertions(+), 22 deletions(-)

diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml
index eb837b35af428..a45b7d57205d1 100644
--- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml
+++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml
@@ -14,51 +14,39 @@ stages:
   - template: templates/rocm.yml
     parameters:
       PythonVersion: '3.8'
-      RocmVersion: '5.4.2'
-  - template: templates/rocm.yml
-    parameters:
-      PythonVersion: '3.9'
-      RocmVersion: '5.4.2'
-  - template: templates/rocm.yml
-    parameters:
-      PythonVersion: '3.10'
-      RocmVersion: '5.4.2'
-  - template: templates/rocm.yml
-    parameters:
-      PythonVersion: '3.8'
-      RocmVersion: '5.5'
+      RocmVersion: '5.6'
   - template: templates/rocm.yml
     parameters:
       PythonVersion: '3.9'
-      RocmVersion: '5.5'
+      RocmVersion: '5.6'
   - template: templates/rocm.yml
     parameters:
      PythonVersion: '3.10'
-      RocmVersion: '5.5'
+      RocmVersion: '5.6'
   - template: templates/rocm.yml
    parameters:
       PythonVersion: '3.8'
-      RocmVersion: '5.6'
+      RocmVersion: '5.7'
   - template: templates/rocm.yml
     parameters:
       PythonVersion: '3.9'
-      RocmVersion: '5.6'
+      RocmVersion: '5.7'
   - template: templates/rocm.yml
     parameters:
       PythonVersion: '3.10'
-      RocmVersion: '5.6'
+      RocmVersion: '5.7'
   - template: templates/rocm.yml
     parameters:
       PythonVersion: '3.8'
-      RocmVersion: '5.6'
+      RocmVersion: '5.7'
       BuildConfig: 'RelWithDebInfo'
   - template: templates/rocm.yml
     parameters:
       PythonVersion: '3.9'
-      RocmVersion: '5.6'
+      RocmVersion: '5.7'
       BuildConfig: 'RelWithDebInfo'
   - template: templates/rocm.yml
     parameters:
       PythonVersion: '3.10'
-      RocmVersion: '5.6'
+      RocmVersion: '5.7'
       BuildConfig: 'RelWithDebInfo'
diff --git a/tools/ci_build/github/azure-pipelines/templates/rocm.yml b/tools/ci_build/github/azure-pipelines/templates/rocm.yml
index cc2e8745e8946..d43029266b4b0 100644
--- a/tools/ci_build/github/azure-pipelines/templates/rocm.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/rocm.yml
@@ -51,7 +51,6 @@ jobs:
           --build-arg INSTALL_DEPS_EXTRA_ARGS=-tmur
           --build-arg BUILD_UID=$(id -u)
           --network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64
-          --build-arg ROCM_VERSION=$(RocmVersion)
          --build-arg DEVTOOLSET_ROOTPATH=/opt/rh/gcc-toolset-12/root
           --build-arg PREPEND_PATH=/opt/rh/gcc-toolset-12/root/usr/bin:
           --build-arg LD_LIBRARY_PATH_ARG=/opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64:/usr/local/lib

From df15a3a335cb4e6703404beeb405f987521e4cf9 Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Mon, 25 Sep 2023 09:22:00 -0700
Subject: [PATCH 075/121] [js/web] configure 5GB memory space for webpack build
 (#17684)

### Description
The ort-web build step that runs webpack consumes an amount of memory close to
Node.js (V8)'s default max-old-space-size limit, so raise the limit to 5 GB to
keep the bundling step from running out of memory.
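The key detail is that the flag is appended to any NODE_OPTIONS already present
rather than replacing them, so flags a developer has exported are preserved. A
minimal sketch of the same environment-merging pattern, shown in Python for
illustration (the child command is hypothetical):

    import os
    import subprocess

    # Merge the parent environment with a larger V8 old-space limit (in MB);
    # appending preserves whatever NODE_OPTIONS the caller already exported.
    env = dict(os.environ)
    env["NODE_OPTIONS"] = (env.get("NODE_OPTIONS", "") + " --max-old-space-size=5120").strip()
    subprocess.run(["npx", "webpack"], env=env, check=True)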
---
 js/web/script/build.ts | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/js/web/script/build.ts b/js/web/script/build.ts
index d3a5be429bfa1..03510ae86b85f 100644
--- a/js/web/script/build.ts
+++ b/js/web/script/build.ts
@@ -176,7 +176,12 @@ npmlog.info('Build', 'Building bundle...');
     webpackArgs.push('--env', `-f=${FILTER}`);
   }
   npmlog.info('Build.Bundle', `CMD: npx ${webpackArgs.join(' ')}`);
-  const webpack = spawnSync('npx', webpackArgs, {shell: true, stdio: 'inherit', cwd: ROOT_FOLDER});
+  const webpack = spawnSync('npx', webpackArgs, {
+    shell: true,
+    stdio: 'inherit',
+    cwd: ROOT_FOLDER,
+    env: {...process.env, NODE_OPTIONS: (process.env.NODE_OPTIONS ?? '') + ' --max-old-space-size=5120'}
+  });
   if (webpack.status !== 0) {
     console.error(webpack.error);
     process.exit(webpack.status === null ? undefined : webpack.status);

From 905faea3b2683383bdb71f4ef3bb0a8f0a0832c2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Mon, 25 Sep 2023 19:11:58 +0200
Subject: [PATCH 076/121] Fix static quantization for QDQ and Percentile
 distribution (#17649)

### Description
One quantization case was not covered by the existing unit tests. This PR adds
a unit test covering that case, together with the fix. It fixes issue #17619.
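The crux of the fix is that quantize_weight_per_channel is forced to
TensorProto.INT8 for per-channel weights: symmetric int8 quantization maps each
channel's range [-m, m] onto [-127, 127], so every channel's zero point is 0,
and QLinearConv's requirement that all per-channel zero points match is
satisfied by construction. A minimal numpy sketch of that invariant (the weight
values are made up for illustration):

    import numpy as np

    def symmetric_int8_scale(channel: np.ndarray) -> float:
        # A symmetric mapping pins the zero point of every channel to exactly 0.
        m = float(np.max(np.abs(channel)))
        return m / 127.0 if m > 0.0 else 1.0

    weights = np.random.randn(4, 8).astype(np.float32)  # per-channel on axis 0
    scales = np.array([symmetric_int8_scale(w) for w in weights])
    q = np.clip(np.round(weights / scales[:, None]), -127, 127).astype(np.int8)
    zero_points = np.zeros(weights.shape[0], dtype=np.int8)  # identical by design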
" + "This happens by design if the quantization is symmetric."); } } diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py index bdf00f21100bf..26e74a6dfbac9 100644 --- a/onnxruntime/python/tools/quantization/calibrate.py +++ b/onnxruntime/python/tools/quantization/calibrate.py @@ -22,7 +22,7 @@ class TensorData: - _allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges"]) + _allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges", "bins"]) def __init__(self, **kwargs): for k, v in kwargs.items(): @@ -55,7 +55,7 @@ def __init__(self, calibration_method, data: Dict[str, Union[TensorData, Tuple]] self.data[k] = TensorData(lowest=v[0], highest=v[1]) continue if len(v) == 4: - self.data[k] = TensorData(lowest=v[0], highest=v[1], histogram=v[2], bins=v[3]) + self.data[k] = TensorData(lowest=v[0], highest=v[1], hist=v[2], bins=v[3]) continue raise TypeError(f"Unexpected tuple for {k:r}, it has {len(v)} elements: {v}.") if not isinstance(v, TensorData): diff --git a/onnxruntime/python/tools/quantization/operators/conv.py b/onnxruntime/python/tools/quantization/operators/conv.py index d23459b478e6a..23f9eaf4b0e0b 100644 --- a/onnxruntime/python/tools/quantization/operators/conv.py +++ b/onnxruntime/python/tools/quantization/operators/conv.py @@ -157,7 +157,7 @@ def quantize(self): nodes, ) = self.quantizer.quantize_activation(node, [0]) quant_weight_tuple = self.quantizer.quantize_weight_per_channel( - node.input[1], onnx_proto.TensorProto.INT8, 0 + node.input[1], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType? ) quantized_input_names.append(quant_weight_tuple[0]) zero_point_names.append(quant_weight_tuple[1]) diff --git a/onnxruntime/python/tools/quantization/operators/lstm.py b/onnxruntime/python/tools/quantization/operators/lstm.py index 7e91f9b76ca36..90a52cb528b32 100644 --- a/onnxruntime/python/tools/quantization/operators/lstm.py +++ b/onnxruntime/python/tools/quantization/operators/lstm.py @@ -47,10 +47,10 @@ def quantize(self): R.dims[0] = R_num_dir * R_4_hidden_size quant_input_weight_tuple = self.quantizer.quantize_weight_per_channel( - node.input[1], onnx_proto.TensorProto.INT8, 0 + node.input[1], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType? ) quant_recurrent_weight_tuple = self.quantizer.quantize_weight_per_channel( - node.input[2], onnx_proto.TensorProto.INT8, 0 + node.input[2], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType? ) W_quant_weight = model.get_initializer(quant_input_weight_tuple[0]) # noqa: N806 diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index e595b580b20df..5c97dd20cf507 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -283,7 +283,13 @@ def _add_qdq_pair_for_initializer(self, weight_proto, tensor_type, axis=None): raise ValueError("Per-Channel support with QDQ format requires onnx opset version 13 or above.") q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel( weight_name, - self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType, + # Quantization type is forced to be TensorProto.INT8. + # when the expected value would be (see below) + # self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType. + # QLinearConv expects to have a unique value for all channels. 
+ # This code does not enforce that but it is necessarily the case when the + # quantization is symmetric (as for INT8). + onnx_proto.TensorProto.INT8, axis, keep_float_weight=self.add_qdq_pair_to_weight, ) diff --git a/onnxruntime/test/python/quantization/resnet_code.py b/onnxruntime/test/python/quantization/resnet_code.py new file mode 100644 index 0000000000000..2f78047c824a6 --- /dev/null +++ b/onnxruntime/test/python/quantization/resnet_code.py @@ -0,0 +1,13757 @@ +import numpy +from onnx import numpy_helper +from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor_value_info, set_model_props + + +def create_model(): + initializers = [] + nodes = [] + inputs = [] + outputs = [] + functions = [] + + # opsets + opsets = {"": 13} + + # initializers + + list_value = [ + -0.013732648454606533, + -0.005861935671418905, + 0.06889285147190094, + -0.1172710582613945, + 0.08841240406036377, + -0.03748627379536629, + 0.016256270930171013, + -0.1059316024184227, + 0.08246039599180222, + 0.14295539259910583, + -0.32958757877349854, + 0.1631188541650772, + 0.05412565544247627, + -0.10758306831121445, + 0.12607362866401672, + -0.4987836182117462, + 0.7441706657409668, + -0.24774713814258575, + -0.30415549874305725, + 0.4033295810222626, + -0.13447114825248718, + 0.04623159021139145, + 0.2380414456129074, + -1.226112723350525, + 2.150630235671997, + -1.702580213546753, + 0.5305419564247131, + -0.06836353242397308, + -0.20055373013019562, + 0.7035881280899048, + -0.8389442563056946, + -0.1904432326555252, + 1.2609282732009888, + -1.0670661926269531, + 0.4142579436302185, + 0.04739700257778168, + -0.3265092074871063, + 1.1873037815093994, + -1.6817731857299805, + 0.9709527492523193, + -0.09095840901136398, + -0.12556785345077515, + 0.0835147574543953, + -0.24109329283237457, + 0.032948240637779236, + 0.46304041147232056, + -0.6594106554985046, + 0.349990576505661, + -0.04113377630710602, + 0.016451245173811913, + 0.008994563482701778, + -0.028321878984570503, + -0.05336569994688034, + 0.16036668419837952, + -0.12088149785995483, + 0.031160499900579453, + -0.0618649423122406, + 0.07205374538898468, + 0.15965768694877625, + -0.3389044404029846, + 0.21603335440158844, + 0.04029613360762596, + -0.0813034325838089, + 0.1019665077328682, + -0.4873599112033844, + 0.7873126268386841, + -0.2951086163520813, + -0.43754327297210693, + 0.5905176401138306, + -0.21821773052215576, + 0.06022067740559578, + 0.26326146721839905, + -1.6453089714050293, + 2.606400728225708, + -1.8939754962921143, + 0.5196341276168823, + 0.0055860355496406555, + -0.2335057258605957, + 0.9807199239730835, + -1.2137882709503174, + -0.2699125409126282, + 1.7379733324050903, + -1.4401814937591553, + 0.435971736907959, + -0.04829222336411476, + -0.24543480575084686, + 1.3292583227157593, + -2.0375823974609375, + 1.2458536624908447, + -0.08251484483480453, + -0.14181238412857056, + 0.10612589120864868, + -0.21671657264232635, + 0.1129523366689682, + 0.3666985034942627, + -0.7546612024307251, + 0.42979565262794495, + -0.0976259633898735, + -0.0008812264422886074, + 0.02994859404861927, + -0.07027778774499893, + 0.01393035613000393, + 0.07363647222518921, + -0.10249849408864975, + 0.06602989137172699, + -0.012129798531532288, + 0.10730132460594177, + -0.04546127840876579, + -0.16065146028995514, + 0.14788293838500977, + -0.05488971993327141, + 0.03601694852113724, + 0.07513345777988434, + -0.23953600227832794, + 0.48062530159950256, + -0.42057543992996216, + -0.02402813360095024, + 0.17920851707458496, + 
-0.10703158378601074, + -0.028666120022535324, + 0.2815375030040741, + -0.860264241695404, + 1.4422725439071655, + -1.2058128118515015, + 0.5272247791290283, + -0.06504356116056442, + -0.20021803677082062, + 0.44968947768211365, + -0.3856053650379181, + -0.1589551419019699, + 0.7579770684242249, + -0.8349987268447876, + 0.3225692808628082, + 0.08153475821018219, + -0.43163740634918213, + 0.8742384910583496, + -0.9722443222999573, + 0.579015851020813, + -0.06688100844621658, + -0.12384293973445892, + 0.08289378881454468, + -0.10082041472196579, + -0.11204896867275238, + 0.3934254050254822, + -0.4511864185333252, + 0.32745760679244995, + -0.06534548103809357, + -0.028830429539084435, + 0.021844232454895973, + 0.01775779016315937, + -0.004250001162290573, + 0.013087524101138115, + -0.001250433037057519, + -0.040545206516981125, + -0.014049320481717587, + -0.024194253608584404, + -0.023865194991230965, + -0.0038033330347388983, + 0.00920871365815401, + -0.006582418456673622, + 0.0032474950421601534, + -0.0369916632771492, + -0.16640843451023102, + -0.28968843817710876, + -0.3531132638454437, + -0.26307201385498047, + -0.13392697274684906, + -0.03747623786330223, + 0.08083077520132065, + 0.2026241272687912, + 0.25018608570098877, + 0.2529378831386566, + 0.2307336926460266, + 0.13928599655628204, + 0.08631229400634766, + 0.13893137872219086, + 0.4867081344127655, + 0.7170669436454773, + 0.8331555724143982, + 0.6734364032745361, + 0.3549460768699646, + 0.16798041760921478, + -0.14487245678901672, + -0.47733625769615173, + -0.7670150995254517, + -0.875726580619812, + -0.6291986703872681, + -0.2910463213920593, + -0.09991979598999023, + -0.009158087894320488, + 0.018850643187761307, + 0.02646111696958542, + -0.009077857248485088, + 0.029430989176034927, + -0.03707962855696678, + -0.05111744999885559, + -0.02076525054872036, + 0.011828843504190445, + 0.017857171595096588, + 0.02548048458993435, + -0.009077494964003563, + 0.0022066361270844936, + -0.02064262516796589, + -0.008582246489822865, + -0.022748643532395363, + -0.03038850985467434, + 0.0006585497176274657, + -0.0016039719339460135, + -0.01612498238682747, + 0.013966801576316357, + -0.05851661041378975, + -0.21422894299030304, + -0.33863192796707153, + -0.3720807433128357, + -0.3030800521373749, + -0.1737397164106369, + -0.05903157964348793, + 0.15018144249916077, + 0.27454254031181335, + 0.31182464957237244, + 0.30118387937545776, + 0.24605700373649597, + 0.14123573899269104, + 0.14992672204971313, + 0.20660799741744995, + 0.5046274662017822, + 0.7706091403961182, + 0.8978630900382996, + 0.7368614673614502, + 0.3929724097251892, + 0.23079657554626465, + -0.21169082820415497, + -0.5920398235321045, + -0.893406867980957, + -0.9499238729476929, + -0.730407178401947, + -0.3615736961364746, + -0.15422092378139496, + -0.024615347385406494, + 0.005115498788654804, + 0.024657316505908966, + 0.028517475351691246, + 0.027910854667425156, + -0.009482389315962791, + -0.042242538183927536, + -0.017875321209430695, + 0.00430292496457696, + 0.015949612483382225, + 0.003636278910562396, + -0.018156034871935844, + -0.0009349065367132425, + -0.0010362856555730104, + -0.013051170855760574, + -0.009141271002590656, + -8.714485738892108e-05, + 0.02399279735982418, + 0.01753612607717514, + -0.013710699044167995, + -0.014245252124965191, + -0.0028008236549794674, + -0.08206935226917267, + -0.1098734438419342, + -0.10250325500965118, + -0.08874496072530746, + -0.031079040840268135, + 0.004536658991128206, + 0.03923843801021576, + 0.08478657901287079, + 
0.07715648412704468, + 0.018803801387548447, + 0.013921198435127735, + 0.015864359214901924, + 0.04947463795542717, + 0.039856068789958954, + 0.1712094396352768, + 0.362756609916687, + 0.4192918539047241, + 0.2668488621711731, + 0.11430513113737106, + 0.06648365408182144, + -0.058979276567697525, + -0.24177154898643494, + -0.3709423542022705, + -0.3979431986808777, + -0.29706764221191406, + -0.11569518595933914, + -0.01848490908741951, + -0.015523962676525116, + 0.05081642046570778, + 0.09057094901800156, + 0.08520761132240295, + 0.04497350752353668, + -0.019453801214694977, + -0.06109466031193733, + 0.011463015340268612, + -0.008522219955921173, + -0.005283404141664505, + -0.017313135787844658, + -0.0015744483098387718, + -0.011845857836306095, + -0.016727561131119728, + -0.006708915811032057, + 0.0008860539528541267, + -0.010050912387669086, + -0.028460539877414703, + -0.0165643822401762, + -0.016545938327908516, + -0.00567589420825243, + -0.0032017906196415424, + -0.0130555285140872, + -0.026848897337913513, + -0.02615198865532875, + 0.002669057110324502, + -0.027966763824224472, + -0.03851256147027016, + -0.014509409666061401, + -0.029059220105409622, + -0.007284109480679035, + 0.04045313969254494, + 0.10005538910627365, + 0.014574537053704262, + -0.044292762875556946, + -0.01750861294567585, + -0.02231375314295292, + -0.004432118032127619, + 0.10051869601011276, + 0.1443023532629013, + 0.0508832149207592, + -0.04350621998310089, + -0.0025447055231779814, + -0.014583000913262367, + -0.02153291553258896, + 0.018860718235373497, + 0.03618147224187851, + 0.007304056081920862, + -0.029104959219694138, + 0.00576505484059453, + -0.016025763005018234, + -0.025094063952565193, + -0.05296780914068222, + -0.037012189626693726, + -0.04414081946015358, + -0.053135257214307785, + -0.028890708461403847, + -0.010220452211797237, + -0.027575822547078133, + -0.01087758969515562, + -0.027209162712097168, + -0.030827227979898453, + -0.007646164856851101, + -0.016133273020386696, + 0.000639698002487421, + -0.0034172122832387686, + 0.03914793208241463, + 0.030786357820034027, + 0.005965455900877714, + 0.020923329517245293, + -0.03435938432812691, + -0.0026781477499753237, + 0.04278327897191048, + 0.20045910775661469, + 0.21770593523979187, + 0.09422573447227478, + 0.03198440372943878, + -0.021056609228253365, + 0.028007682412862778, + 0.19196027517318726, + 0.4791645109653473, + 0.5333831906318665, + 0.3014310598373413, + 0.103666290640831, + -0.03651479259133339, + 0.027079502120614052, + 0.19239209592342377, + 0.5168290138244629, + 0.5564895868301392, + 0.2977963089942932, + 0.07770062237977982, + -0.042239490896463394, + -0.017265107482671738, + 0.08760321140289307, + 0.2775075435638428, + 0.312491774559021, + 0.12284757196903229, + 0.019664151594042778, + -0.026643047109246254, + 0.0009152573184110224, + 0.016156431287527084, + 0.09042830765247345, + 0.08991760015487671, + 0.013326293788850307, + 0.02613811008632183, + 0.021025240421295166, + 0.0198842640966177, + 0.03375901281833649, + 0.028616728261113167, + 0.026605166494846344, + 0.04126269370317459, + 0.029309948906302452, + 0.01408455427736044, + -0.003831037785857916, + 0.01922326348721981, + -0.018229445442557335, + -0.013015883974730968, + 0.017597628757357597, + -0.007964612916111946, + 0.045263469219207764, + 0.0184696726500988, + -0.001163159729912877, + -0.1809321641921997, + -0.22486254572868347, + -0.08606110513210297, + 0.001087217591702938, + 0.037091098725795746, + -0.013625397346913815, + -0.178089901804924, + -0.5483279824256897, + 
-0.612791895866394, + -0.32531827688217163, + -0.06506585329771042, + 0.05076128616929054, + -0.007585812360048294, + -0.20981833338737488, + -0.6155760884284973, + -0.7119701504707336, + -0.354442298412323, + -0.04236743599176407, + 0.045713260769844055, + 0.03192479908466339, + -0.07216271013021469, + -0.310979425907135, + -0.3656359910964966, + -0.13522450625896454, + 0.008291869424283504, + 0.03362602740526199, + -0.0009240762447007, + 0.01604474149644375, + -0.055634208023548126, + -0.06180194392800331, + 0.0222025066614151, + 0.027704820036888123, + -0.034385330975055695, + -0.07050742954015732, + -0.06287489086389542, + 0.03521641716361046, + -0.00020920530369039625, + 0.05458284169435501, + 0.058752644807100296, + -0.08097169548273087, + -0.01668735221028328, + 0.18557283282279968, + 0.26208117604255676, + 0.1253771185874939, + 0.07758381962776184, + -0.022084739059209824, + 0.016727397218346596, + 0.23247942328453064, + 0.35444316267967224, + 0.21802566945552826, + -0.04409221559762955, + -0.08573070168495178, + -0.0994141548871994, + 0.07754423469305038, + 0.14311672747135162, + 0.04036660119891167, + -0.29222917556762695, + -0.38828015327453613, + -0.26185816526412964, + -0.12845511734485626, + 0.04763585329055786, + -0.017382778227329254, + -0.16010743379592896, + -0.2395028918981552, + -0.2049665004014969, + -0.041346337646245956, + 0.091490738093853, + -0.005191737785935402, + -0.07687077671289444, + -0.08105621486902237, + -0.05329642817378044, + -0.03404862806200981, + 0.11478845030069351, + 0.13328343629837036, + -0.037197597324848175, + -0.01787363924086094, + -0.016605347394943237, + 0.007853846065700054, + 0.029950136318802834, + 0.10808859020471573, + 0.02873288467526436, + -0.1766187697649002, + -0.17560969293117523, + -0.03922238200902939, + 0.14447443187236786, + 0.1534212827682495, + 0.11272227019071579, + 0.008810695260763168, + -0.1485181748867035, + 0.07839693129062653, + 0.43013128638267517, + 0.4898712635040283, + 0.26522761583328247, + 0.10202436149120331, + -0.07163076847791672, + 0.09933187812566757, + 0.47377726435661316, + 0.6340300440788269, + 0.36741772294044495, + -0.04812543839216232, + -0.17370514571666718, + -0.17513291537761688, + 0.22105705738067627, + 0.3226463794708252, + 0.09850790351629257, + -0.4044247269630432, + -0.6237908601760864, + -0.4679968059062958, + -0.1954391747713089, + 0.09878316521644592, + -0.004430827684700489, + -0.31550562381744385, + -0.5235733985900879, + -0.4510284662246704, + -0.13843706250190735, + 0.10064390301704407, + -0.006748788990080357, + -0.12714813649654388, + -0.2107744812965393, + -0.18755048513412476, + -0.05646044388413429, + 0.12781813740730286, + 0.18928050994873047, + -0.04337320104241371, + -0.04973407834768295, + -0.04690375551581383, + 0.0245530866086483, + 0.10698680579662323, + 0.1646823137998581, + 0.081840381026268, + -0.01471243891865015, + -0.03138890117406845, + -0.04195617139339447, + 0.012708203867077827, + 0.033312954008579254, + 0.02409377694129944, + -0.0036440726835280657, + -0.06239784508943558, + 0.0037516560405492783, + 0.11261500418186188, + 0.13069754838943481, + 0.05901307612657547, + 0.048614490777254105, + -0.027712708339095116, + 0.027247682213783264, + 0.19195327162742615, + 0.2688453793525696, + 0.1509387195110321, + 0.020540937781333923, + -0.004100556951016188, + -0.012650247663259506, + 0.039176344871520996, + 0.09037251025438309, + -0.004689970053732395, + -0.23859903216362, + -0.2364242821931839, + -0.15189304947853088, + -0.0761493444442749, + -0.0028172829188406467, + 
-0.04328106716275215, + -0.16187387704849243, + -0.21743592619895935, + -0.1282283067703247, + -0.024501819163560867, + 0.04029383510351181, + -0.027387680485844612, + -0.05414740741252899, + -0.08344019204378128, + -0.06591048091650009, + 0.012637111358344555, + 0.06905930489301682, + 0.08426016569137573, + -0.0030199100729078054, + 0.034059297293424606, + 0.01111840270459652, + 0.013492933474481106, + 0.0674189031124115, + 0.08242739737033844, + 0.006129032466560602, + -0.07763395458459854, + -0.03002289868891239, + -0.055725954473018646, + 0.008795201778411865, + 0.02994825504720211, + -0.06114519387483597, + -0.0560108907520771, + -0.008179228752851486, + -0.07149285078048706, + -0.02700420655310154, + -0.01306728646159172, + 0.06276566535234451, + 0.007125973701477051, + -0.03540417551994324, + -0.039717916399240494, + 0.009147526696324348, + -0.06517947465181351, + 0.0720859095454216, + -0.05035398155450821, + 0.06659520417451859, + -0.01841895841062069, + 0.004233633633702993, + -0.020911216735839844, + -0.004646372981369495, + 1.6690073013305664, + 0.4517613649368286, + -0.07667035609483719, + 0.005556757096201181, + -0.02638973295688629, + 0.044588603079319, + -0.020916732028126717, + 0.2571280598640442, + -0.009559552185237408, + -0.043380800634622574, + 0.03196016326546669, + -0.03783237189054489, + -0.03076902963221073, + 0.03180111199617386, + 0.06352709978818893, + 0.020281998440623283, + -0.00741154421120882, + -0.0009214285528287292, + -0.0476187989115715, + -0.07208544760942459, + -0.05323023349046707, + -0.011103631928563118, + 0.02877136506140232, + -0.05324484035372734, + -0.10076326876878738, + 0.026193000376224518, + 0.03536469116806984, + 0.045722659677267075, + -0.03756006807088852, + 0.022998394444584846, + 0.0019359687576070428, + 0.01654801517724991, + 0.047304198145866394, + -0.08431598544120789, + -0.0645647644996643, + -0.17326746881008148, + -0.10692577064037323, + -0.08416426181793213, + -0.04107839986681938, + -0.0012680464424192905, + -0.02600814774632454, + -0.014215772971510887, + 0.2114446610212326, + -0.040954578667879105, + -0.05050172284245491, + 0.004194092936813831, + -0.0025900816544890404, + -0.1359374076128006, + 0.03946976363658905, + 2.3023669719696045, + 0.7484877109527588, + -0.1994970589876175, + -0.06490366160869598, + 0.007983183488249779, + -0.017937449738383293, + -0.12516839802265167, + 0.3313288688659668, + 0.11946671456098557, + -0.16942338645458221, + -0.007721045054495335, + 0.02824605070054531, + -0.05310647189617157, + -0.1122083067893982, + -0.17094524204730988, + -0.08465421944856644, + -0.09679102897644043, + -0.03848385065793991, + 0.040121182799339294, + -0.06661732494831085, + 0.0005764663219451904, + -0.05729356408119202, + -0.04778655245900154, + -0.034835152328014374, + -0.07634143531322479, + -0.05054831504821777, + 0.00597620103508234, + 0.04499154910445213, + -0.03308190405368805, + -0.04915233701467514, + -0.05842791870236397, + 0.003590918146073818, + 0.055837079882621765, + -0.02547842636704445, + -0.018847621977329254, + -0.2073899656534195, + -0.14987564086914062, + -0.03971748799085617, + 0.05886378139257431, + 0.020922083407640457, + -0.039155181497335434, + -0.028855402022600174, + 0.08688661456108093, + -0.1402827501296997, + -0.05810496211051941, + 0.037841811776161194, + -0.04082907736301422, + -0.1191127747297287, + -0.10852136462926865, + 1.6274418830871582, + 0.3678200840950012, + -0.2865799367427826, + -0.05291350558400154, + 0.023858532309532166, + -0.046683818101882935, + -0.2307816743850708, + 
-0.001670230645686388, + -0.17716962099075317, + -0.16724731028079987, + 0.040194038301706314, + -0.023075448349118233, + -0.01538322027772665, + -0.07914327085018158, + -0.19621343910694122, + -0.11628971993923187, + -0.05851752683520317, + 0.06313594430685043, + 0.017808571457862854, + 0.02447943389415741, + 0.048611078411340714, + -0.009247995913028717, + 0.00789090245962143, + 0.06673033535480499, + 0.0661577433347702, + 0.019111329689621925, + 0.038164373487234116, + 0.029342610388994217, + -0.03547409921884537, + -0.11017149686813354, + -0.11077891290187836, + 0.001108204829506576, + -0.0330691784620285, + -0.05039837956428528, + 0.017638904973864555, + 0.277705579996109, + 0.5606598258018494, + 0.5469182133674622, + 0.13591277599334717, + 0.012421006336808205, + 0.046348799020051956, + -0.02721901424229145, + -0.5645118355751038, + -1.072814702987671, + -0.9852984547615051, + -0.3608386516571045, + -0.010197073221206665, + -0.09785731136798859, + -0.02597353421151638, + 0.4627133309841156, + 1.1483618021011353, + 0.9505703449249268, + 0.17471027374267578, + -0.016467586159706116, + 0.026623696088790894, + 0.04765752702951431, + -0.4000166058540344, + -0.8956774473190308, + -0.6268588304519653, + -0.09439487755298615, + 0.02861764468252659, + -0.004155704285949469, + 0.08989865332841873, + 0.27384331822395325, + 0.6518518328666687, + 0.4184596836566925, + 0.13106893002986908, + 0.0050344159826636314, + 0.007061495911329985, + -0.016157688573002815, + -0.1364346295595169, + -0.27324289083480835, + -0.14245718717575073, + -0.04623992741107941, + -0.015541884116828442, + 0.030779436230659485, + 0.03756715729832649, + 0.01957445964217186, + -0.04964561015367508, + -0.0211405660957098, + 0.044496409595012665, + -0.026335055008530617, + -0.11620140820741653, + -0.11803250014781952, + 0.18242181837558746, + 0.5057784914970398, + 0.5045838952064514, + 0.03748183697462082, + 0.05692485347390175, + 0.1608155369758606, + 0.02245517633855343, + -0.7651812434196472, + -1.5504053831100464, + -1.3563542366027832, + -0.4314505457878113, + -0.028384560719132423, + -0.12238024920225143, + 0.106974296271801, + 1.11427903175354, + 2.173083543777466, + 1.747692346572876, + 0.5455064177513123, + 0.03363418206572533, + 0.11388687789440155, + -0.05905687436461449, + -0.8059568405151367, + -1.6196117401123047, + -1.1898213624954224, + -0.2654758095741272, + -0.004251840524375439, + -0.0916782096028328, + -0.024067873135209084, + 0.22692462801933289, + 0.6695711612701416, + 0.3673460781574249, + -0.017016466706991196, + -0.029604146257042885, + 0.020365707576274872, + 0.03215239942073822, + 0.0070981839671730995, + -0.14026938378810883, + -0.02425236999988556, + 0.059152450412511826, + -0.006319367326796055, + 0.003989882301539183, + 0.048541076481342316, + 0.003988460637629032, + -0.03105335496366024, + -0.08329232037067413, + 0.03226872906088829, + 0.02119620516896248, + -0.0953872874379158, + -0.15174035727977753, + 0.07963212579488754, + 0.29094186425209045, + 0.2690921127796173, + -0.020104877650737762, + 0.024988379329442978, + 0.15326620638370514, + 0.1256464123725891, + -0.40941280126571655, + -0.946648120880127, + -0.8358487486839294, + -0.14284957945346832, + -0.07980851829051971, + -0.1435413807630539, + 0.038134895265102386, + 0.8021518588066101, + 1.552701473236084, + 1.2496209144592285, + 0.38152581453323364, + 0.07136060297489166, + 0.14329172670841217, + -0.06546801328659058, + -0.5923707485198975, + -1.253793478012085, + -0.9458200335502625, + -0.156633198261261, + -0.04217473417520523, + 
-0.11199303716421127, + -0.07520301640033722, + 0.15331010520458221, + 0.4794600307941437, + 0.2449675053358078, + -0.10396319627761841, + 0.0034801275469362736, + 0.04475663974881172, + 0.024035215377807617, + 0.056806568056344986, + -0.07363307476043701, + -0.001563104335218668, + 0.05157755687832832, + 0.043718185275793076, + 0.02102719619870186, + 0.11859089881181717, + 0.08675580471754074, + -0.13180124759674072, + -0.15522590279579163, + 0.03273458778858185, + -0.0019622649997472763, + 0.1011638194322586, + -0.10800585150718689, + -0.6884365677833557, + -0.5495791435241699, + 0.0780424103140831, + 0.33674973249435425, + -0.21274283528327942, + -0.4183696210384369, + -0.8053947687149048, + 0.03347628563642502, + 1.3938312530517578, + 0.9454176425933838, + -0.012210174463689327, + 0.04924672842025757, + 0.16284359991550446, + 1.1340152025222778, + 2.0020322799682617, + 0.2796843647956848, + -0.968036413192749, + -0.5768532752990723, + 0.17757350206375122, + 0.37485063076019287, + 0.11534234136343002, + -1.2916942834854126, + -1.692176103591919, + -0.30523377656936646, + 0.14307916164398193, + 0.03928302228450775, + -0.19196964800357819, + -0.4533900022506714, + -0.3294944167137146, + 0.5480389595031738, + 0.4497548043727875, + 0.2170887440443039, + -0.05817069113254547, + -0.06957870721817017, + 0.03169052675366402, + 0.23751793801784515, + 0.0823391005396843, + -0.04811413958668709, + -0.051265716552734375, + -0.0395645909011364, + -0.03849785774946213, + 0.04607917368412018, + 0.09946659207344055, + -0.029992828145623207, + -0.05369366332888603, + -0.005230880342423916, + 0.012808755040168762, + 0.1821947544813156, + 0.05478882044553757, + -0.47736144065856934, + -0.44480830430984497, + -0.036321353167295456, + 0.13646431267261505, + -0.04045571759343147, + -0.21837295591831207, + -0.6888197660446167, + -0.08431777358055115, + 0.96018385887146, + 0.6788493990898132, + 0.011028020642697811, + 0.05917810648679733, + 0.02488739602267742, + 0.6898419857025146, + 1.4259209632873535, + 0.13193827867507935, + -0.8078985810279846, + -0.31056249141693115, + 0.018122224137187004, + 0.137860506772995, + 0.051947757601737976, + -0.9757952094078064, + -1.1060559749603271, + 0.06675099581480026, + 0.2091575562953949, + -0.029623042792081833, + -0.0705878809094429, + -0.18514159321784973, + -0.07947035878896713, + 0.5719470381736755, + 0.2286168485879898, + -0.03433626517653465, + 0.0036030709743499756, + 0.006251791957765818, + 0.04144154116511345, + 0.08598234504461288, + -0.050599172711372375, + -0.10440917313098907, + -0.02927244082093239, + -0.04102599248290062, + -0.07101748138666153, + -0.03579306975007057, + 0.03586365282535553, + 0.06752362847328186, + 0.048901572823524475, + -0.020898710936307907, + -0.009411930106580257, + 0.10169848799705505, + 0.1812015175819397, + -0.014482695609331131, + -0.12548771500587463, + -0.060731250792741776, + -0.034499138593673706, + 0.0829617902636528, + 0.04616715386509895, + -0.20867496728897095, + -0.1990129053592682, + 0.1773940473794937, + 0.13156233727931976, + -0.03437860682606697, + 0.04012921825051308, + -0.11132699251174927, + -0.023460939526557922, + 0.2713286876678467, + -0.06662362813949585, + -0.2709292471408844, + -0.0030232456047087908, + -0.10379529744386673, + -0.07136038690805435, + 0.03757762163877487, + -0.20515622198581696, + -0.1231834888458252, + 0.26915228366851807, + 0.0998353362083435, + -0.031466737389564514, + 0.04657471179962158, + 0.07664929330348969, + 0.10308870673179626, + 0.23429608345031738, + -0.06942534446716309, + 
-0.09051290899515152, + 0.03243685141205788, + 0.04053235426545143, + -0.021392958238720894, + -0.05330868810415268, + -0.11525140702724457, + -0.03889385238289833, + 0.01636480540037155, + -0.009352890774607658, + 0.13151532411575317, + -0.14738643169403076, + -0.18289834260940552, + 0.15955400466918945, + -0.001023759599775076, + 0.028809679672122, + 0.012261062860488892, + 0.29654747247695923, + -0.285063236951828, + -0.40187928080558777, + 0.3713407516479492, + 0.009383893571794033, + -0.023022817447781563, + -0.003799814498052001, + 0.48470190167427063, + -0.43402406573295593, + -0.5858806371688843, + 0.5751441717147827, + 0.05045031011104584, + -0.05559438094496727, + -0.02045449987053871, + 0.5281224250793457, + -0.5058223605155945, + -0.5950849056243896, + 0.6492323279380798, + 0.013408469036221504, + -0.05940670147538185, + -0.0044364179484546185, + 0.3112560212612152, + -0.34908774495124817, + -0.42427319288253784, + 0.43349501490592957, + 0.03724945709109306, + -0.05263671651482582, + -0.010485195554792881, + 0.1261255145072937, + -0.1349790245294571, + -0.2524855136871338, + 0.24608080089092255, + 0.036001257598400116, + -0.028843939304351807, + 0.0056989979930222034, + 0.04458172619342804, + -0.06122935935854912, + -0.166972354054451, + 0.14557687938213348, + 0.018050044775009155, + 0.032598987221717834, + -0.0055792503990232944, + 0.24355076253414154, + -0.21433626115322113, + -0.29646870493888855, + 0.1958809792995453, + 0.015435033477842808, + 0.05235098674893379, + 0.010786890983581543, + 0.47903597354888916, + -0.4127257168292999, + -0.6203306317329407, + 0.47024452686309814, + 0.0823090448975563, + -0.04538045823574066, + -0.004072466865181923, + 0.7509317994117737, + -0.6508772969245911, + -0.8481631278991699, + 0.7875698208808899, + 0.0966777428984642, + -0.10461349785327911, + 0.0063789174892008305, + 0.7535857558250427, + -0.8082649111747742, + -0.8165622353553772, + 0.9064085483551025, + 0.04986630380153656, + -0.10200339555740356, + 0.0314355194568634, + 0.46324053406715393, + -0.5523763298988342, + -0.5632953643798828, + 0.6378755569458008, + 0.07833302766084671, + -0.07979781180620193, + 0.031164664775133133, + 0.1967470794916153, + -0.21681970357894897, + -0.29283079504966736, + 0.3367702066898346, + 0.034929461777210236, + -0.047199901193380356, + -0.0033645557705312967, + 0.05454660952091217, + -0.11264829337596893, + -0.190998375415802, + 0.17961400747299194, + 0.0009085010970011353, + -0.0001827089727157727, + 0.04841821268200874, + 0.019923821091651917, + -0.07004066556692123, + -0.10590090602636337, + 0.054114967584609985, + 0.04302384704351425, + 0.00462615629658103, + 0.022948985919356346, + 0.1673787385225296, + -0.1319379210472107, + -0.2711219787597656, + 0.2387620061635971, + 0.05667697265744209, + -0.018639734014868736, + -0.07672597467899323, + 0.3503187298774719, + -0.2981504797935486, + -0.38647517561912537, + 0.4072522521018982, + 0.010913677513599396, + -0.05246961489319801, + -0.04058554396033287, + 0.39216771721839905, + -0.3605193495750427, + -0.34857264161109924, + 0.46899959444999695, + -0.03358001261949539, + -0.05188553035259247, + -0.023204902186989784, + 0.17140533030033112, + -0.2120431810617447, + -0.2144550085067749, + 0.2837989032268524, + -0.0191226527094841, + -0.020922169089317322, + 0.004324179142713547, + 0.038136694580316544, + -0.042803723365068436, + -0.11487454175949097, + 0.11820490658283234, + 0.003412557765841484, + 0.0035020115319639444, + 0.03646541014313698, + -0.010104459710419178, + -0.010897459462285042, + 
+ [Elided: this hunk adds several thousand lines of floating-point constants (raw numeric data, one literal per added line in the original patch); the extraction collapsed them into unreadable runs, and no prose or code logic is contained in this span.]
-0.06944561004638672, + -0.21481844782829285, + -0.26795628666877747, + -0.24916253983974457, + -0.17833945155143738, + -0.06658200174570084, + -0.00305415247566998, + -0.054028186947107315, + -0.19072681665420532, + -0.256619930267334, + -0.26868295669555664, + -0.21621295809745789, + -0.06564134359359741, + 0.0031192339956760406, + 0.013205861672759056, + -0.08044812828302383, + -0.18137820065021515, + -0.23007699847221375, + -0.13054916262626648, + -0.01135951280593872, + 0.013734308071434498, + 0.010981118306517601, + -0.02249351143836975, + -0.05804377421736717, + -0.10652261227369308, + -0.04163172468543053, + 0.017101088538765907, + -0.028687385842204094, + -0.0019976652693003416, + 0.009987232275307178, + 0.010130539536476135, + 0.0015575449215248227, + -0.000983694102615118, + -0.012845008634030819, + 0.01329281460493803, + 0.0029350779950618744, + -0.003755913581699133, + -0.036475058645009995, + -0.0245466697961092, + -0.0020879909861832857, + 0.025867130607366562, + -0.0065954397432506084, + 0.008656582795083523, + -0.04037104919552803, + -0.11718368530273438, + -0.13506115972995758, + -0.024255141615867615, + 0.014097613282501698, + -0.0009370348998345435, + -0.010953565128147602, + -0.12869219481945038, + -0.18789908289909363, + -0.19098156690597534, + -0.12795749306678772, + -0.002666366985067725, + -0.004907527007162571, + -0.014610078185796738, + -0.11913872510194778, + -0.19921070337295532, + -0.21869640052318573, + -0.1849898099899292, + -0.03470952808856964, + 0.0064156935550272465, + 0.03401843458414078, + -0.04000416398048401, + -0.12354391813278198, + -0.16908879578113556, + -0.10385500639677048, + 0.002833302365615964, + -0.036176733672618866, + -0.001048827893100679, + 0.010002595372498035, + -0.020798830315470695, + -0.0488261841237545, + -0.002972641494125128, + 0.016395021229982376, + -0.045770127326250076, + -0.12710650265216827, + -0.1637774109840393, + -0.1411965787410736, + 0.20447289943695068, + 0.509396493434906, + 0.07264503091573715, + 0.12041529268026352, + -0.015143441036343575, + -0.2673257887363434, + -0.3589763641357422, + 0.11289574205875397, + 0.8517020344734192, + 0.7068799138069153, + 0.067301444709301, + -0.02102830447256565, + -0.5235708355903625, + -1.2064802646636963, + -0.856619656085968, + 0.26774707436561584, + 0.6825867295265198, + 0.13516077399253845, + 0.3054035007953644, + -0.0727991834282875, + -1.4912222623825073, + -1.906838297843933, + -0.8574200868606567, + -0.15282419323921204, + 0.39327505230903625, + 0.9758505821228027, + 1.2323224544525146, + 0.18179064989089966, + -0.947610080242157, + -0.6657719016075134, + -0.19935055077075958, + -0.09150458872318268, + 0.34379544854164124, + 1.2025749683380127, + 0.9517407417297363, + -0.12023784220218658, + -0.3146151900291443, + -0.1049022302031517, + -0.34867578744888306, + -0.32945582270622253, + 0.28920575976371765, + 0.7844374179840088, + 0.35520124435424805, + 0.007452746387571096, + 0.018862545490264893, + -0.0021927610505372286, + 0.0321974977850914, + 0.05439181253314018, + -0.030729038640856743, + -0.03517322614789009, + -0.037830010056495667, + -0.056672073900699615, + -0.017769837751984596, + 0.06385952979326248, + 0.08161566406488419, + 0.07809178531169891, + 0.06333671510219574, + -0.036322008818387985, + -0.06432312726974487, + -0.03629852458834648, + 0.010879911482334137, + 0.088901087641716, + 0.0021402277052402496, + 0.09618857502937317, + 0.02661084569990635, + -0.03414442762732506, + -0.08736730366945267, + -0.048222169280052185, + 0.03507986292243004, + 
-0.053828027099370956, + 0.006044292356818914, + 0.04232194274663925, + 0.001624415279366076, + -0.028371643275022507, + -0.08724038302898407, + -0.005835397634655237, + 0.01057528518140316, + 0.04210871085524559, + 0.06106603890657425, + 0.04250370338559151, + 0.0028668276499956846, + -0.07583706080913544, + -0.06849333643913269, + -0.08538331836462021, + -0.021475542336702347, + 0.044341571629047394, + 0.03604369983077049, + 0.05146002024412155, + 0.00280605535954237, + -0.004615028854459524, + -0.07857430726289749, + -0.03716180846095085, + 0.010876243002712727, + -0.03418488800525665, + 0.007391764782369137, + 0.05969953536987305, + 0.08769611269235611, + 0.066011443734169, + -0.10404568910598755, + -0.27194535732269287, + -0.05224551260471344, + -0.03618992492556572, + -0.023098375648260117, + 0.13832588493824005, + 0.21510572731494904, + -0.07285867631435394, + -0.489085853099823, + -0.33285844326019287, + -0.04830349236726761, + 0.014211038127541542, + 0.2612524926662445, + 0.6911754608154297, + 0.5294638276100159, + -0.2706173360347748, + -0.39350029826164246, + -0.05156399682164192, + -0.16490484774112701, + 0.1161464974284172, + 0.8029336929321289, + 1.1809980869293213, + 0.5025736689567566, + 0.07084998488426208, + -0.1901131123304367, + -0.4918227195739746, + -0.603122889995575, + -0.09460704773664474, + 0.5786081552505493, + 0.35392242670059204, + 0.1328991800546646, + -0.008106965571641922, + -0.2159435749053955, + -0.6369062662124634, + -0.5241336822509766, + 0.06276796758174896, + 0.1139409989118576, + 0.05483332276344299, + 0.1703934520483017, + 0.14603517949581146, + -0.16187912225723267, + -0.4139055907726288, + -0.14918148517608643, + -0.06163417547941208, + 0.005302567034959793, + 0.015524876303970814, + -0.11895350366830826, + -0.19724233448505402, + 0.03412429615855217, + 0.10862118750810623, + 0.08550503104925156, + -0.008599682711064816, + -0.03031114675104618, + -0.33224624395370483, + -0.27994298934936523, + 0.196475550532341, + 0.31109708547592163, + 0.17151644825935364, + -0.04994147643446922, + -0.167176753282547, + -0.5247878432273865, + -0.21136601269245148, + 0.54701828956604, + 0.6110883951187134, + 0.04194486886262894, + -0.27640673518180847, + -0.0795169249176979, + -0.360530287027359, + 0.3472684621810913, + 1.5428175926208496, + 1.0249378681182861, + -0.2724844515323639, + -0.3013695478439331, + 0.020736562088131905, + -0.019495302811264992, + 0.7758124470710754, + 1.5381159782409668, + 0.028625331819057465, + -1.289720892906189, + -0.5894255638122559, + 0.0526396706700325, + 0.11443997919559479, + 0.5935031771659851, + 0.47169724106788635, + -1.2507063150405884, + -1.351940631866455, + -0.03894977271556854, + 0.05095001682639122, + 0.01581231690943241, + 0.11137383431196213, + -0.22327138483524323, + -0.9629225730895996, + -0.2607772946357727, + 0.5907121300697327, + 0.006906076334416866, + 0.002633580705150962, + 0.01940075121819973, + 0.0143396882340312, + 0.020781584084033966, + -0.07249777764081955, + -0.016355905681848526, + 0.016553230583667755, + -0.027528395876288414, + 0.0244428887963295, + 0.024910561740398407, + 0.027229825034737587, + -0.04104151204228401, + 0.007100561633706093, + 0.0157785601913929, + -0.06626633554697037, + 0.006520191207528114, + 0.021171070635318756, + 0.036674920469522476, + -0.06950324773788452, + -0.03003627620637417, + 2.178798422391992e-05, + -0.07278106361627579, + 0.014382920227944851, + 0.0982266515493393, + 0.1454961597919464, + -0.10096189379692078, + 0.022237209603190422, + -0.00040665315464138985, + 
-0.013766243122518063, + 0.06440296769142151, + 0.21751047670841217, + 0.02519127167761326, + -0.23383572697639465, + 0.0038903038948774338, + -0.042271602898836136, + -0.012596859596669674, + 0.023778460919857025, + 0.07685687392950058, + -0.21480663120746613, + -0.19205358624458313, + 0.04876565560698509, + -0.016765035688877106, + -0.02620583213865757, + 0.01641852967441082, + 0.02201787941157818, + -0.07457322627305984, + -0.003633625339716673, + 0.07550841569900513, + 0.024774253368377686, + 0.04710151255130768, + 0.09110233932733536, + -0.017366377636790276, + -0.04366954043507576, + -0.039786458015441895, + 0.005311290733516216, + 0.037867460399866104, + 0.05367766693234444, + 0.07434491813182831, + -0.07251215726137161, + -0.04231821000576019, + -0.023427855223417282, + 0.036294277757406235, + 0.07782749086618423, + 0.11835407465696335, + 0.08753973245620728, + -0.20742319524288177, + -0.13341759145259857, + -0.008225077763199806, + 0.07292432337999344, + 0.006392402108758688, + 0.021914338693022728, + -0.09218581020832062, + -0.44192466139793396, + -0.1744878888130188, + 0.014938815496861935, + 0.10678526759147644, + -0.012087192386388779, + -0.024533385410904884, + -0.1804407387971878, + -0.3253834545612335, + 0.040678758174180984, + 0.2011708915233612, + 0.17262929677963257, + -0.0045212251134216785, + -0.033313386142253876, + -0.10575363039970398, + -0.07636692374944687, + 0.20343273878097534, + 0.28330928087234497, + 0.043149981647729874, + -0.01109551265835762, + -0.0027725452091544867, + 0.003926735837012529, + 0.029440222308039665, + 0.23945140838623047, + 0.09122566133737564, + -0.15140119194984436, + 0.08737201988697052, + 0.07120998948812485, + 0.05722665786743164, + -0.04388495534658432, + 0.02116825245320797, + 0.023315919563174248, + 0.10898162424564362, + 0.11808467656373978, + 0.03412344306707382, + 0.002771642990410328, + -0.1959579437971115, + -0.05181330814957619, + -0.0044630044139921665, + 0.12481725960969925, + 0.09140311926603317, + 0.03444851189851761, + -0.10931172221899033, + -0.3204459846019745, + -0.21193139255046844, + -0.11101037263870239, + 0.04186606407165527, + -0.07420916110277176, + -0.2004990428686142, + -0.26937955617904663, + -0.12928874790668488, + 0.20819628238677979, + -0.17379426956176758, + -0.2181481271982193, + 0.005387924611568451, + -0.24132733047008514, + -0.23942433297634125, + 0.41489261388778687, + 1.0702778100967407, + 0.024913936853408813, + -0.28405970335006714, + 0.083008773624897, + -0.11059781163930893, + -0.17623695731163025, + -0.17386195063591003, + 0.010644182562828064, + -0.32716259360313416, + -0.2135595828294754, + 0.1223129853606224, + 0.07060510665178299, + -0.048680394887924194, + -0.3332099914550781, + -0.25886017084121704, + -0.18619979918003082, + -0.00733158877119422, + 0.03393476828932762, + -0.010564662516117096, + -0.01817108877003193, + -0.05650597810745239, + -0.01891104131937027, + -0.0554141066968441, + -0.004592927638441324, + -0.0013615720672532916, + -0.05552899092435837, + -0.0560498908162117, + -0.1080632209777832, + -0.013965745456516743, + -0.03290533646941185, + -0.02599845454096794, + -0.02877708151936531, + -0.05670137703418732, + -0.07158109545707703, + -0.08808472007513046, + -0.03919175639748573, + -0.08478893339633942, + -0.08045543730258942, + -0.10066724568605423, + -0.048338882625103, + -0.06750114262104034, + 0.08164039999246597, + 0.3343777060508728, + 0.004952755756676197, + -0.14891156554222107, + 0.032855477184057236, + -0.03277512267231941, + 0.0474768728017807, + 
0.6316664814949036, + 1.2214386463165283, + 0.2548498213291168, + -0.13185030221939087, + -0.018188906833529472, + -0.07653989642858505, + -0.01643386110663414, + 0.06630122661590576, + 0.23864209651947021, + -0.013703612610697746, + -0.09347789734601974, + -0.0900193303823471, + -0.04930814355611801, + -0.02791711315512657, + -0.15441712737083435, + -0.01623091846704483, + -0.0447690524160862, + -0.06071227043867111, + -0.04737209901213646, + -0.059769801795482635, + -0.04375007003545761, + -0.00650476710870862, + 0.021540174260735512, + -0.05590728670358658, + -0.13030850887298584, + -0.022067781537771225, + -0.05066747963428497, + 0.00609770929440856, + 0.108611099421978, + 0.1621929407119751, + 0.05232185125350952, + -0.049729123711586, + -0.11906369775533676, + -0.030973592773079872, + 0.057787079364061356, + 0.1610448956489563, + 0.18756121397018433, + 0.07277501374483109, + -0.05777435004711151, + -0.05227195844054222, + 0.14434091746807098, + 0.1889694482088089, + 0.26951169967651367, + 0.4710105359554291, + 0.2164669781923294, + 0.05052375793457031, + -0.0038236663676798344, + 0.20267778635025024, + 0.31214746832847595, + 0.7506387829780579, + 1.2302387952804565, + 0.4363090693950653, + 0.16759593784809113, + -0.049752235412597656, + 0.044786907732486725, + 0.14537742733955383, + 0.2227499932050705, + 0.37362414598464966, + 0.16590620577335358, + 0.0864599421620369, + -0.14058542251586914, + -0.04404178634285927, + -0.0325944609940052, + -0.019113417714834213, + 0.17414243519306183, + 0.11160623282194138, + -0.034911543130874634, + 0.1523953527212143, + 0.04554234445095062, + -0.054958827793598175, + -0.11794494092464447, + -0.19570015370845795, + -0.21358126401901245, + -0.1885669231414795, + -0.08286706358194351, + -0.29818814992904663, + -0.52330082654953, + -0.6190353631973267, + -0.682529091835022, + -0.6171367764472961, + -0.4793100655078888, + -0.11180876195430756, + -0.3490432798862457, + -0.5531057715415955, + -0.6426181793212891, + -0.6420838832855225, + -0.4970071613788605, + -0.27038174867630005, + -0.09740017354488373, + -0.1929621547460556, + -0.30848363041877747, + -0.27204805612564087, + -0.2515120208263397, + -0.07497832179069519, + 0.03551386669278145, + -0.05060403421521187, + 0.08276989310979843, + 0.14321963489055634, + 0.3583574593067169, + 0.40667927265167236, + 0.39398193359375, + 0.27561235427856445, + 0.005085935816168785, + 0.2793635427951813, + 0.48155927658081055, + 0.7088037729263306, + 0.7394692897796631, + 0.6158861517906189, + 0.3986552655696869, + 0.025508087128400803, + 0.38533228635787964, + 0.5305332541465759, + 0.6659612059593201, + 0.6396889090538025, + 0.5396444797515869, + 0.39010515809059143, + -0.03072960674762726, + 0.014305810444056988, + 0.029885446652770042, + 0.038084372878074646, + 0.012448564171791077, + 0.034353457391262054, + 0.048626724630594254, + 0.048866890370845795, + 0.07561437785625458, + 0.09152165800333023, + 0.08432324975728989, + 0.09332144260406494, + 0.07517607510089874, + 0.049146559089422226, + 0.03146318346261978, + 0.06335246562957764, + 0.06438779830932617, + 0.06851581484079361, + 0.09263566881418228, + 0.06460423022508621, + 0.011992924846708775, + 0.03396693989634514, + 0.04433950409293175, + 0.04642309248447418, + 0.0022602551616728306, + -0.0361824594438076, + -0.0005105047021061182, + 0.030808264389634132, + 0.0022333709057420492, + -0.017826544120907784, + -0.03796307370066643, + -0.012887164019048214, + -0.028499294072389603, + -0.03367336839437485, + -0.03668365254998207, + -0.02807682938873768, + 
-0.07444571703672409, + -0.081318199634552, + -0.09610070288181305, + -0.05368436127901077, + -0.09006591141223907, + -0.10038736462593079, + -0.04115951433777809, + -0.056811004877090454, + -0.09935522079467773, + -0.11107856035232544, + -0.07852742075920105, + -0.0942930206656456, + -0.07625897973775864, + -0.12966541945934296, + -0.038938648998737335, + 0.04580259323120117, + 0.10179819911718369, + 0.17127273976802826, + 0.17857632040977478, + 0.13426578044891357, + 0.04687841981649399, + 0.2424812912940979, + 0.42633309960365295, + 0.5291624069213867, + 0.6012980937957764, + 0.5449428558349609, + 0.3945220708847046, + 0.07037744671106339, + 0.26918724179267883, + 0.44614800810813904, + 0.5331310629844666, + 0.568580687046051, + 0.43367546796798706, + 0.25516101717948914, + 0.08428427577018738, + 0.177769735455513, + 0.24885930120944977, + 0.2178547978401184, + 0.13834305107593536, + 0.07452446967363358, + 0.005187708884477615, + 0.050621017813682556, + -0.08428733795881271, + -0.15576106309890747, + -0.25531095266342163, + -0.34646397829055786, + -0.3276817202568054, + -0.24377694725990295, + 0.02817704901099205, + -0.2531633675098419, + -0.3907041549682617, + -0.5944734811782837, + -0.6062930822372437, + -0.5171639919281006, + -0.3501560389995575, + -0.019397703930735588, + -0.2758809030056, + -0.4118667244911194, + -0.5375933051109314, + -0.5525977611541748, + -0.44681206345558167, + -0.2748269736766815, + -0.04229651764035225, + -0.005005967803299427, + -0.011332424357533455, + 0.011387092061340809, + -0.015463154762983322, + -0.012038768269121647, + 0.011360889300704002, + 0.03551746904850006, + 0.05123865604400635, + 0.020377267152071, + 0.1065637394785881, + 0.18875306844711304, + 0.18516196310520172, + 0.12519532442092896, + -0.042940977960824966, + -0.03246130794286728, + -0.016645772382616997, + 0.07807288318872452, + -0.7815885543823242, + -0.5930942296981812, + 0.03312799707055092, + -0.04537777230143547, + -0.022234303876757622, + 0.009241255931556225, + 0.16947965323925018, + -0.0700032040476799, + -0.06346366554498672, + 0.09555318206548691, + 0.02858082763850689, + 0.009246457368135452, + 0.03902693837881088, + 0.007071994710713625, + 0.10085106641054153, + 0.0881502702832222, + 0.011019160971045494, + 0.006030070595443249, + -0.012882355600595474, + -0.01701420359313488, + 0.022596944123506546, + -0.05345382168889046, + 0.02355102449655533, + -0.0091088330373168, + 0.00015542628534603864, + -0.0004997836658731103, + -0.006951311603188515, + 0.01267238613218069, + -0.0033983420580625534, + -0.0030770134180784225, + 0.02975126914680004, + 0.010702245868742466, + -0.016947058960795403, + 0.007774800062179565, + 0.09566964209079742, + 0.07426714897155762, + 0.1621979922056198, + 0.12728945910930634, + 0.06112523376941681, + 0.06061968579888344, + 0.07934501022100449, + 0.11534841358661652, + 0.10001469403505325, + 0.15475066006183624, + 0.1828109323978424, + 0.02134544588625431, + -0.015320047736167908, + 0.012000483460724354, + -0.014393450692296028, + -1.5520576238632202, + -1.2115217447280884, + 0.017239907756447792, + -0.007013735361397266, + 0.0019166347337886691, + 0.025112343952059746, + 0.1803419440984726, + -0.30807924270629883, + -0.33957329392433167, + 0.10846519470214844, + 0.06151076406240463, + 0.054799750447273254, + 0.06235412135720253, + 0.09605015069246292, + 0.16495031118392944, + 0.12624189257621765, + 0.12234552949666977, + 0.006969878450036049, + 0.0033541936427354813, + 0.008165130391716957, + 0.035377491265535355, + -0.03170061111450195, + 
0.019396571442484856, + -0.011411413550376892, + 0.019043665379285812, + 0.00957057997584343, + 0.0055394587107002735, + 0.05569477006793022, + 0.0076510305516421795, + 0.018707536160945892, + 0.06073765829205513, + 0.006503407843410969, + -0.0058801183477044106, + -0.03229741007089615, + 0.0386439748108387, + 0.03167358413338661, + 0.027749545872211456, + -0.04634377732872963, + -0.00019781991431955248, + 0.024982664734125137, + 0.009453915059566498, + 0.1091528981924057, + 0.21055325865745544, + 0.23810525238513947, + 0.13829846680164337, + -0.019112061709165573, + -0.0014926757430657744, + 0.01856786385178566, + 0.10649964213371277, + -0.8599057793617249, + -0.6383436322212219, + 0.10839059948921204, + -0.038730181753635406, + -0.030203847214579582, + -0.033147793263196945, + 0.18132103979587555, + -0.1427767276763916, + -0.11132896691560745, + 0.10957232862710953, + -0.00349965482018888, + 0.03486581891775131, + 0.016247740015387535, + 0.060106489807367325, + 0.1439678966999054, + 0.07201634347438812, + 0.07603273540735245, + -0.0072280303575098515, + 0.01600506529211998, + -0.012912745587527752, + 0.015192546881735325, + -0.034853674471378326, + 0.026164958253502846, + 0.001483929343521595, + 0.0508253313601017, + -0.010546445846557617, + -0.024398569017648697, + -0.0043407524935901165, + 0.0030393539927899837, + -0.009643012657761574, + -0.008882591500878334, + 0.01182172168046236, + 0.003359999740496278, + -0.01145304087549448, + -7.34154018573463e-05, + 0.007416137028485537, + -0.012022661976516247, + 0.013550116680562496, + -0.005982181057333946, + -0.019205773249268532, + -0.0811527743935585, + -0.06323252618312836, + -0.026379290968179703, + -0.04671972244977951, + -0.006205265875905752, + 0.05242094770073891, + 0.05065605416893959, + 0.01961991749703884, + 0.021542323753237724, + 0.04147094115614891, + 0.04451332613825798, + 0.05155060812830925, + 0.15659169852733612, + 0.4448348879814148, + 0.7207449078559875, + 0.8680058717727661, + 0.7269517779350281, + 0.36259666085243225, + 0.10394725203514099, + -0.20449180901050568, + -0.42664405703544617, + -0.7290332317352295, + -0.9376083016395569, + -0.735107958316803, + -0.3541502356529236, + -0.23789332807064056, + -0.10901623964309692, + -0.26809337735176086, + -0.38465574383735657, + -0.44440212845802307, + -0.4070444703102112, + -0.22405119240283966, + -0.14190013706684113, + 0.07151509076356888, + 0.21848519146442413, + 0.41893038153648376, + 0.4783499836921692, + 0.4281534254550934, + 0.28631147742271423, + 0.057699400931596756, + 0.0029010034631937742, + -0.02580493874847889, + -0.02152368798851967, + -0.025850815698504448, + 0.004789783153682947, + 0.021941278129816055, + 0.00574735039845109, + -0.004016151186078787, + -0.014377521350979805, + -0.0828985944390297, + -0.06380187720060349, + -0.048879947513341904, + -0.04580164700746536, + -0.030843649059534073, + 0.024663949385285378, + 0.03409295156598091, + 0.060452476143836975, + 0.037006158381700516, + 0.058853648602962494, + 0.07275765389204025, + 0.02882941998541355, + 0.14549848437309265, + 0.4268765151500702, + 0.7150183320045471, + 0.8942612409591675, + 0.7532845139503479, + 0.3846176564693451, + 0.15604183077812195, + -0.19108416140079498, + -0.42633384466171265, + -0.7508237361907959, + -0.9448286890983582, + -0.719300389289856, + -0.3583783805370331, + -0.2060524821281433, + -0.10382426530122757, + -0.2624296545982361, + -0.4049411416053772, + -0.4338999092578888, + -0.41390693187713623, + -0.22797809541225433, + -0.14593803882598877, + 0.08197329193353653, + 
0.2430788278579712, + 0.3906225562095642, + 0.47147202491760254, + 0.42429792881011963, + 0.29326340556144714, + 0.06683206558227539, + 0.004355552606284618, + -0.007973028346896172, + 0.0035172239877283573, + -0.0018502225866541266, + -0.015291260555386543, + 0.0025160792283713818, + 0.0015979957534000278, + 0.011951611377298832, + -0.0004334237310104072, + -0.00172338483389467, + 0.017284434288740158, + -0.00445173867046833, + -0.004828867502510548, + 0.004030159674584866, + 0.03321678191423416, + -0.016998661682009697, + -0.029765218496322632, + -0.07912255078554153, + -0.0494595468044281, + 0.012136446312069893, + 0.029541414231061935, + -0.01129366084933281, + 0.09502168744802475, + 0.21533286571502686, + 0.3453419804573059, + 0.22987395524978638, + 0.04720258712768555, + 0.0032486498821526766, + -0.0042808204889297485, + -0.10162857174873352, + -0.21601493656635284, + -0.3040534257888794, + -0.19600912928581238, + -0.0568307563662529, + -0.0062937624752521515, + -0.021828925237059593, + -0.03831009939312935, + -0.08992031216621399, + -0.08103442937135696, + -0.07600760459899902, + -0.02319694682955742, + -0.008472982794046402, + -0.004151565954089165, + 0.05002164468169212, + 0.0985124409198761, + 0.11273156106472015, + 0.10279814153909683, + 0.032678257673978806, + -0.023295480757951736, + -0.022312145680189133, + 0.032877422869205475, + 0.08301658928394318, + -0.049675002694129944, + -0.05956050381064415, + 0.006878976244479418, + 0.011597251519560814, + -0.03617611899971962, + -0.005020621232688427, + 0.0066283573396503925, + 0.061849869787693024, + 0.0668889507651329, + -0.1120104044675827, + 0.0215831957757473, + -0.008177083916962147, + 0.019240612164139748, + -0.03794482350349426, + -0.21581093966960907, + 0.3248063623905182, + 0.0525924488902092, + -0.13873063027858734, + -0.030904211103916168, + -0.004122832324355841, + 0.2784009277820587, + -0.42068102955818176, + -0.15351417660713196, + 0.4266241192817688, + -0.10780557245016098, + 0.03840374946594238, + -0.15116721391677856, + 0.2292502224445343, + 0.23400554060935974, + -0.5023872256278992, + 0.14868289232254028, + 0.09809935092926025, + 0.03480924293398857, + -0.046804867684841156, + -0.14212554693222046, + 0.3073779344558716, + -0.029529480263590813, + -0.13998086750507355, + -0.02750661037862301, + 0.010526027530431747, + 0.032874979078769684, + -0.07645174115896225, + -0.02746269293129444, + 0.10902399569749832, + -0.00446560001000762, + -0.01339190173894167, + 0.003540819976478815, + -0.04410126060247421, + -0.10884726047515869, + 0.016081949695944786, + 0.15211890637874603, + 0.04027504846453667, + -0.05552368983626366, + 0.04718002676963806, + 0.014503135345876217, + -0.2764658033847809, + -0.16068166494369507, + 0.3356778621673584, + 0.06485499441623688, + -0.07164154946804047, + 0.084479421377182, + 0.2702949047088623, + -0.1339409202337265, + -0.9642015695571899, + 0.47433769702911377, + 0.4715694189071655, + -0.17669782042503357, + -0.04434441775083542, + 0.2641690671443939, + 0.7357130646705627, + -1.2222046852111816, + -0.8205837607383728, + 0.9091072678565979, + 0.14896778762340546, + -0.09332367032766342, + -0.16173647344112396, + 0.8782246708869934, + 0.3819980323314667, + -1.619883418083191, + 0.059255462139844894, + 0.42745286226272583, + -0.03186821565032005, + -0.16420172154903412, + 0.12124066799879074, + 0.8650834560394287, + -0.3728218674659729, + -0.5816569328308105, + 0.10949260741472244, + -0.010671291500329971, + -0.07903271913528442, + -0.09700250625610352, + 0.3192030191421509, + 
0.2756008505821228, + -0.2616698145866394, + -0.11051242798566818, + 0.016789941117167473, + -0.0484573096036911, + -0.12333080172538757, + 0.0158428642898798, + 0.11172449588775635, + 0.014953864738345146, + -0.011746960692107677, + 0.05310823395848274, + 0.030244171619415283, + -0.23969320952892303, + -0.1039247065782547, + 0.285805881023407, + -0.04652552306652069, + -0.05380000174045563, + 0.05430186912417412, + 0.25547218322753906, + -0.06164371967315674, + -0.7386756539344788, + 0.4393811821937561, + 0.2623714804649353, + -0.1849273294210434, + -0.049713607877492905, + 0.1656467467546463, + 0.6638666391372681, + -0.899787187576294, + -0.5747878551483154, + 0.7465870976448059, + -0.025567445904016495, + -0.051771312952041626, + -0.19754628837108612, + 0.6828271746635437, + 0.4451557695865631, + -1.2559787034988403, + 0.07448688894510269, + 0.27905938029289246, + 0.003908769693225622, + -0.18454433977603912, + -0.011183545924723148, + 0.7449039816856384, + -0.228777676820755, + -0.47592073678970337, + 0.13784541189670563, + 0.019371675327420235, + -0.06424596160650253, + -0.1660400629043579, + 0.2080633044242859, + 0.2942465841770172, + -0.20263032615184784, + -0.0709841251373291, + -0.0021153483539819717, + -0.028180474415421486, + -0.021557176485657692, + 0.012511649169027805, + 0.06533018499612808, + 0.006560645066201687, + -0.01908997632563114, + -0.020228691399097443, + 0.10450740903615952, + 0.04476405307650566, + -0.20389842987060547, + -0.36356496810913086, + -0.18690945208072662, + 0.06581642478704453, + 0.005246834829449654, + -0.14777734875679016, + 0.04554577171802521, + 0.7314760088920593, + 1.1759854555130005, + 0.7747871279716492, + 0.08771117031574249, + 0.04425497353076935, + 0.14875195920467377, + -0.05036012455821037, + -1.0561891794204712, + -1.7835016250610352, + -1.313464879989624, + -0.4041728973388672, + -0.08825081586837769, + -0.18483860790729523, + -0.09619659930467606, + 0.6506555676460266, + 1.2331949472427368, + 1.057729721069336, + 0.3030258119106293, + 0.053314659744501114, + 0.10696353763341904, + 0.19720971584320068, + -0.19457301497459412, + -0.3546113669872284, + -0.3773464560508728, + 0.007737448439002037, + 0.007112926337867975, + -0.026632368564605713, + -0.07708505541086197, + 0.016982559114694595, + 0.03331448882818222, + 0.03235285356640816, + -0.04479134455323219, + 0.0062864539213478565, + -0.04983896017074585, + -0.014209658838808537, + 0.025105496868491173, + 0.07187403738498688, + -0.019782420247793198, + -0.0387532040476799, + 0.01098113413900137, + 0.10765481740236282, + -0.005502769257873297, + -0.29967597126960754, + -0.5370010733604431, + -0.25729984045028687, + 0.0341138020157814, + -0.01927473582327366, + -0.11736954003572464, + 0.09457080066204071, + 0.8881804943084717, + 1.5049697160720825, + 1.0347492694854736, + 0.22410355508327484, + -0.004720119293779135, + 0.1449226438999176, + -0.11916695535182953, + -1.2009364366531372, + -2.080855369567871, + -1.5549882650375366, + -0.5231477618217468, + -0.005029830615967512, + -0.11258674412965775, + 0.03710457682609558, + 0.9192798137664795, + 1.525830626487732, + 1.3018689155578613, + 0.44408130645751953, + 0.006972550880163908, + 0.07937697321176529, + 0.060622286051511765, + -0.4068094491958618, + -0.5964561104774475, + -0.6058750152587891, + -0.1743212193250656, + -0.0038881103973835707, + -0.04932431876659393, + -0.04989266395568848, + 0.07228495925664902, + 0.10359980911016464, + 0.11054171621799469, + 0.017031395807862282, + -0.012849675491452217, + -0.02224516123533249, + 
-0.019851619377732277, + 0.04567919671535492, + 0.12134519219398499, + 0.018673665821552277, + -0.03933878242969513, + 0.03506385162472725, + 0.07499910145998001, + -0.004981306381523609, + -0.269795298576355, + -0.4478399455547333, + -0.3141564130783081, + 0.014856644906103611, + -0.01102763693779707, + -0.11778493225574493, + -0.00048367868294008076, + 0.46917271614074707, + 0.8380635976791382, + 0.5829758048057556, + 0.14924737811088562, + 0.00504975114017725, + 0.1242799386382103, + 0.027800291776657104, + -0.5343790054321289, + -0.9185061454772949, + -0.6974499225616455, + -0.1733488291501999, + 0.028415951877832413, + -0.07513032108545303, + 0.010947657749056816, + 0.5501428246498108, + 0.8556726574897766, + 0.6854383945465088, + 0.21023745834827423, + -0.04757346957921982, + 0.028925150632858276, + -0.05005616322159767, + -0.4106282889842987, + -0.5990055203437805, + -0.5274976491928101, + -0.18928098678588867, + 0.007199999876320362, + 0.004744168370962143, + -0.006203897297382355, + 0.16117095947265625, + 0.20310591161251068, + 0.17358633875846863, + 0.057794276624917984, + 0.0018837900133803487, + -0.021730661392211914, + 0.03705505281686783, + 0.048999205231666565, + 0.017187459394335747, + -0.04760497808456421, + -0.06534644961357117, + 0.027641354128718376, + -0.02722003310918808, + -0.09557735174894333, + 0.2721945643424988, + 0.06861108541488647, + -0.17862513661384583, + 0.029542427510023117, + -0.028343068435788155, + -0.24357359111309052, + 0.2928915321826935, + 0.6317090392112732, + -0.5675624012947083, + -0.31298428773880005, + 0.119928739964962, + -0.04503166303038597, + 0.1997436285018921, + 0.9068917632102966, + -0.6105388402938843, + -1.176649808883667, + 0.391012579202652, + 0.21436090767383575, + 0.06404570490121841, + 0.4306352436542511, + -0.18372972309589386, + -1.6093186140060425, + 0.5129231810569763, + 0.8333584666252136, + -0.11607109010219574, + 0.024050598964095116, + -0.027272621169686317, + -0.8072280883789062, + 0.15613007545471191, + 1.0115277767181396, + -0.1886059194803238, + -0.1662863790988922, + -0.07484262436628342, + -0.11359186470508575, + -0.05765556916594505, + 0.48085057735443115, + 0.031143836677074432, + -0.20803743600845337, + 0.005643316078931093, + -0.011422591283917427, + -0.02063453011214733, + 0.010139239020645618, + 0.026931140571832657, + 0.02650240994989872, + 0.014503400772809982, + -0.030498046427965164, + 0.01038119662553072, + -0.041832923889160156, + -0.11747029423713684, + 0.24838468432426453, + 0.08126607537269592, + -0.17684465646743774, + 0.009867151267826557, + -0.04349489137530327, + -0.22892898321151733, + 0.3097872734069824, + 0.6229272484779358, + -0.5710748434066772, + -0.2540203332901001, + 0.15970031917095184, + -0.05765099450945854, + 0.24631772935390472, + 0.9121918678283691, + -0.6539115309715271, + -1.1680796146392822, + 0.43742635846138, + 0.1981748640537262, + 0.060766786336898804, + 0.48115089535713196, + -0.2704729437828064, + -1.668082594871521, + 0.6258481740951538, + 0.8217618465423584, + -0.17844447493553162, + 0.07583325356245041, + -0.031355466693639755, + -0.884739100933075, + 0.21298757195472717, + 1.0279508829116821, + -0.2118954360485077, + -0.16616611182689667, + -0.025157395750284195, + -0.11329160630702972, + -0.08147483319044113, + 0.46636614203453064, + 0.023730026558041573, + -0.21343427896499634, + -0.015201984904706478, + -0.00498165050521493, + 0.022955382242798805, + 0.020228328183293343, + -0.029405873268842697, + -0.032065436244010925, + 0.047389160841703415, + -0.01793060638010502, 
+ 0.01669210195541382, + 0.05227159336209297, + -0.11703876405954361, + 0.006789325270801783, + 0.03741219639778137, + -0.04651298373937607, + -0.012846981175243855, + 0.024231625720858574, + -0.13399703800678253, + -0.024073680862784386, + 0.2970501184463501, + -0.1497301310300827, + -0.04287628084421158, + 0.08405227214097977, + -0.06020639091730118, + -0.01648692972958088, + 0.4150170087814331, + -0.17000712454319, + -0.43461430072784424, + 0.27202337980270386, + 0.006708468310534954, + -0.04474359005689621, + 0.15199843049049377, + -0.03348325565457344, + -0.6591396331787109, + 0.4057810306549072, + 0.25226324796676636, + -0.16070741415023804, + 0.03464199975132942, + 0.023064177483320236, + -0.35642316937446594, + 0.22774185240268707, + 0.37138837575912476, + -0.24171461164951324, + -0.023513946682214737, + 0.028774995356798172, + -0.02702418342232704, + -0.012504744343459606, + 0.17893734574317932, + -0.1554262489080429, + -0.09501983970403671, + 0.06177212670445442, + -0.013536165468394756, + 0.012441401369869709, + 0.006566522642970085, + -0.018207622691988945, + 0.003373368876054883, + -0.034891802817583084, + 0.002223123563453555, + 0.006169564090669155, + 0.022658145055174828, + -0.005327044054865837, + -0.023764559999108315, + -0.004386506043374538, + -0.02777106687426567, + 0.01950058527290821, + 0.004401096608489752, + 0.02882237359881401, + 0.01790205016732216, + -0.007827110588550568, + -0.005222277250140905, + -0.05361752584576607, + 0.008359426632523537, + -0.026494475081562996, + -0.015572195872664452, + -0.04412947595119476, + -0.006163781508803368, + 0.180303692817688, + 0.17117105424404144, + -0.014117442071437836, + 0.014543564058840275, + 0.03875281661748886, + 0.002004631096497178, + 0.11982911080121994, + 0.609316349029541, + 0.5792325735092163, + 0.10267578810453415, + -0.02287464588880539, + -0.011516223661601543, + -0.02587946131825447, + 0.019127164036035538, + 0.2742871046066284, + 0.23896890878677368, + -0.013414637185633183, + 0.012439075857400894, + 0.01148916780948639, + 0.0024075021501630545, + -0.028374193236231804, + -0.02938784286379814, + -0.061723873019218445, + -0.03288640081882477, + 0.010918691754341125, + 0.01171314436942339, + 0.00894222967326641, + -0.0050367508083581924, + 0.00322812981903553, + -0.01958087645471096, + 0.000401448953198269, + 0.00655051926150918, + 0.008647873997688293, + -0.015351405367255211, + -0.022286182269454002, + -0.0018973759142681956, + -0.032965533435344696, + 0.009401706047356129, + 0.01680464670062065, + 0.01722409576177597, + 0.017367251217365265, + -0.0012145076179876924, + 0.015895379707217216, + -0.013976357877254486, + 0.01587546430528164, + -0.019388504326343536, + -0.004597584251314402, + -0.026080038398504257, + 0.020517753437161446, + 0.20680218935012817, + 0.20302064716815948, + 0.03813354671001434, + 0.027738921344280243, + 0.02183712273836136, + 0.023807305842638016, + 0.14632326364517212, + 0.5991678237915039, + 0.608651340007782, + 0.15929070115089417, + -0.02112223394215107, + -0.020013611763715744, + -0.03723381832242012, + 0.032139480113983154, + 0.27032363414764404, + 0.24862462282180786, + 0.02374681644141674, + 0.007894856855273247, + 0.00042308925185352564, + -0.004832752980291843, + -0.024313796311616898, + -0.0018940505106002092, + -0.02681432105600834, + 0.002362651750445366, + 0.013330202549695969, + 0.012553646229207516, + 0.002630018163472414, + 0.002979951212182641, + 0.0015847217291593552, + -0.03376828506588936, + -0.010844729840755463, + -0.002748559694737196, + 
0.012938202358782291, + -0.011872833594679832, + -0.0025761008728295565, + 0.003677211469039321, + -0.04305516183376312, + 0.001133457524701953, + 0.0020396243780851364, + 0.01797032356262207, + 0.016580887138843536, + 0.04445189982652664, + 0.013270077295601368, + -0.04839251935482025, + 0.011546633206307888, + -0.015829432755708694, + 0.019473392516374588, + -0.011464826762676239, + 0.018693143501877785, + 0.18201367557048798, + 0.16157257556915283, + 0.02082117274403572, + 0.015915032476186752, + 0.010720869526267052, + -0.0020238866563886404, + 0.09329187124967575, + 0.46998023986816406, + 0.5186727046966553, + 0.09814783185720444, + -0.016547314822673798, + 0.00325066689401865, + -0.028936590999364853, + 0.01002424769103527, + 0.21822214126586914, + 0.22012007236480713, + 0.008229314349591732, + 0.015599996782839298, + 0.014740276150405407, + 0.0019725109450519085, + 0.003613655688241124, + -0.03043546713888645, + -0.06308998167514801, + 0.014664110727608204, + 0.06775129586458206, + -0.12990300357341766, + -0.03638269379734993, + -0.03883139044046402, + 0.05194637551903725, + 0.03896122798323631, + -0.05132362246513367, + -0.07234688848257065, + -0.36106064915657043, + -0.2839237451553345, + -0.11496391147375107, + 0.3026673197746277, + 0.3528609871864319, + 0.21559017896652222, + -0.11970120668411255, + -0.5473688244819641, + -0.5362005233764648, + -0.21015112102031708, + 0.4089161455631256, + 0.6033567786216736, + 0.38614287972450256, + -0.12437233328819275, + -0.6394402384757996, + -0.6945835947990417, + -0.3482857942581177, + 0.5189254283905029, + 0.8457668423652649, + 0.6248002648353577, + -0.12700730562210083, + -0.6978924870491028, + -0.7764106392860413, + -0.4171960651874542, + 0.44747814536094666, + 0.8406224846839905, + 0.6821274161338806, + -0.07793218642473221, + -0.5459966659545898, + -0.6139025092124939, + -0.35998886823654175, + 0.27800890803337097, + 0.6048891544342041, + 0.591307520866394, + -0.04850815609097481, + -0.3863481283187866, + -0.3542836606502533, + -0.2491992861032486, + 0.1616278886795044, + 0.3402666747570038, + 0.4610227644443512, + -0.010262396186590195, + 0.0408165417611599, + 0.006382474210113287, + -0.011430315673351288, + -0.027895113453269005, + -0.009767768904566765, + 0.005882019177079201, + 0.05225436016917229, + 0.0415218211710453, + 0.08244743943214417, + 0.026765575632452965, + -0.05404946208000183, + -0.06101839989423752, + -0.028233220800757408, + 0.03128793090581894, + 0.07133004069328308, + 0.0718698799610138, + 0.042146697640419006, + -0.08380170166492462, + -0.09263177216053009, + -0.07569421827793121, + 0.032425008714199066, + 0.12351400405168533, + 0.09103626012802124, + -0.004768018145114183, + -0.05960838869214058, + -0.11922567337751389, + -0.10132396221160889, + 0.044341862201690674, + 0.100867860019207, + 0.09607693552970886, + -0.00129030947573483, + -0.05481477826833725, + -0.1278291642665863, + -0.12058380991220474, + 0.016678951680660248, + 0.09958931058645248, + 0.08456224203109741, + 0.061599165201187134, + -0.049776893109083176, + -0.11354166269302368, + -0.09844806790351868, + 0.004753128159791231, + 0.07868346571922302, + 0.06464104354381561, + 0.020981626585125923, + -0.010770543478429317, + -0.08838209509849548, + -0.07265795767307281, + -0.058313023298978806, + 0.10897739976644516, + 0.026735201478004456, + 0.03972309082746506, + -0.019998662173748016, + -0.048948734998703, + 0.03377270698547363, + 0.053406376391649246, + 0.27304399013519287, + 0.20850272476673126, + 0.07890326529741287, + -0.22241365909576416, + 
-0.2816997468471527, + -0.1745096743106842, + 0.08957889676094055, + 0.4962941110134125, + 0.4586986303329468, + 0.20177948474884033, + -0.3625744581222534, + -0.47758376598358154, + -0.32412785291671753, + 0.0669194757938385, + 0.5394997596740723, + 0.601328432559967, + 0.24388420581817627, + -0.4319041073322296, + -0.6893490552902222, + -0.5106037259101868, + 0.10174300521612167, + 0.5457565784454346, + 0.6549625992774963, + 0.38772058486938477, + -0.3778320252895355, + -0.6820934414863586, + -0.551069438457489, + 0.049600999802351, + 0.45137161016464233, + 0.5143972039222717, + 0.3713279068470001, + -0.26546329259872437, + -0.5121409893035889, + -0.47691628336906433, + 0.03843758627772331, + 0.30808231234550476, + 0.3185756504535675, + 0.22629432380199432, + -0.14860986173152924, + -0.2915389835834503, + -0.3552006185054779, + -0.003137432038784027, + -0.01327254343777895, + -0.027139298617839813, + 0.04800891876220703, + 0.05380738899111748, + -0.01380784809589386, + 0.0022881641052663326, + -0.012132279574871063, + 0.06182793900370598, + 0.03762871399521828, + 0.0966145321726799, + 0.08963571488857269, + 0.06551238149404526, + 0.031640589237213135, + -0.010532311163842678, + 0.07195396721363068, + 0.11343465745449066, + 0.11621421575546265, + 0.047318290919065475, + 0.1111951395869255, + 0.044054243713617325, + 0.016777141019701958, + 0.03392713516950607, + 0.06047024950385094, + -0.7924502491950989, + -0.7310910224914551, + 0.031088173389434814, + 0.0906061977148056, + 0.022829236462712288, + 0.04470035433769226, + 0.025999872013926506, + -0.8246837258338928, + -0.723675549030304, + 0.15835590660572052, + 0.07358791679143906, + -0.015819497406482697, + -0.014207872562110424, + 0.08506257086992264, + 0.08868777751922607, + 0.0976945012807846, + 0.11740022897720337, + 0.016287995502352715, + -0.024363648146390915, + 0.04249691963195801, + 0.02909177541732788, + 0.12011238187551498, + 0.10729824751615524, + 0.05927390977740288, + 0.04731644690036774, + 0.008210064843297005, + 0.03859357163310051, + -0.005175672471523285, + 0.01984376832842827, + -0.0011626111809164286, + -0.0010909241391345859, + 0.02311880886554718, + 0.007646523881703615, + 0.04582137614488602, + -0.0027255103923380375, + 0.027656713500618935, + 0.02781369723379612, + 0.015750093385577202, + 0.040563344955444336, + -0.007784596644341946, + 0.006534814368933439, + 0.002403199439868331, + -0.020037032663822174, + -0.011717663146555424, + 0.07826739549636841, + 0.018203573301434517, + 0.021228624507784843, + 0.014112413860857487, + -0.02866269089281559, + -0.9502679109573364, + -0.825043797492981, + 0.05938851460814476, + 0.06553053110837936, + 0.015418429858982563, + 0.0616452619433403, + -0.0094453701749444, + -0.9471839666366577, + -0.7922234535217285, + 0.13069523870944977, + 0.04939320683479309, + 0.007429714780300856, + 0.022599652409553528, + 0.0820123627781868, + 0.06440276652574539, + 0.09897352755069733, + 0.0856291800737381, + 0.006608777679502964, + -0.0005533680086955428, + 0.021656949073076248, + 0.014818831346929073, + 0.03757459297776222, + -0.001428246614523232, + 0.03473127633333206, + 0.03607869893312454, + 0.017313262447714806, + 0.0025767614133656025, + -0.033292777836322784, + 0.027883101254701614, + -0.007534499745815992, + -0.04302362725138664, + -0.01795666106045246, + -0.007667913101613522, + 0.012547189369797707, + -0.021762438118457794, + 0.03789107874035835, + 0.06384614109992981, + 0.0014223429607227445, + -0.01393786258995533, + -0.041693057864904404, + -0.01813604310154915, + 
0.065328449010849, + 0.15736474096775055, + 0.1531635969877243, + 0.09920474886894226, + -0.04044449329376221, + 0.010558396577835083, + 0.05559245124459267, + 0.10931257158517838, + -0.5784384608268738, + -0.5109886527061462, + 0.17690584063529968, + 0.07484250515699387, + 0.010378374718129635, + 0.0890144556760788, + 0.13172735273838043, + -0.6058865785598755, + -0.49908995628356934, + 0.1835336685180664, + 0.005293308291584253, + -0.03870566934347153, + -0.025229454040527344, + 0.12571711838245392, + 0.14792272448539734, + 0.14905226230621338, + 0.0700206533074379, + -0.035034529864788055, + 0.013128797523677349, + 0.015581230632960796, + 0.005400130525231361, + 0.07070232182741165, + 0.03829728811979294, + -0.013876918703317642, + -0.019958000630140305, + -0.020086020231246948, + -0.019999003037810326, + -0.015111410059034824, + 0.11963249742984772, + -0.08270428329706192, + -0.0025947154499590397, + -0.010668564587831497, + 0.016670405864715576, + -0.03206938877701759, + -0.053453829139471054, + 0.1236601173877716, + -0.020077411085367203, + 0.00779569149017334, + -0.0318986251950264, + 0.03579804673790932, + -0.060723867267370224, + -0.009301809594035149, + 0.09249342232942581, + -0.13378725945949554, + 0.17496798932552338, + -0.0935625433921814, + 0.06569044291973114, + -0.18187756836414337, + 0.06397300213575363, + 0.3793930113315582, + -0.5664302706718445, + 0.23658618330955505, + -0.03206830099225044, + 0.03155658766627312, + 0.039305318146944046, + -0.6008145213127136, + 1.0417630672454834, + -0.5062726140022278, + -0.04698493704199791, + 0.0979752242565155, + -0.037326715886592865, + 0.26255178451538086, + -0.590207576751709, + 0.4195419251918793, + 0.12212422490119934, + -0.26122942566871643, + 0.06442253291606903, + -0.07682429254055023, + 0.12608948349952698, + -0.13872937858104706, + -0.030260663479566574, + 0.2047160565853119, + -0.13068141043186188, + 0.016608506441116333, + -0.021629147231578827, + 0.04659907519817352, + 0.024417348206043243, + 0.06751634925603867, + -0.1705978959798813, + 0.0655774399638176, + -0.0041802311316132545, + -0.02263445220887661, + -0.014069054275751114, + 0.06242800131440163, + 0.08984102308750153, + -0.19382472336292267, + 0.09380361437797546, + -0.0032764992211014032, + -0.03950225189328194, + -0.08896161615848541, + 0.28387022018432617, + 0.1668996810913086, + -0.5457127094268799, + 0.21796099841594696, + 0.012032964266836643, + 0.030721815302968025, + -0.4431600570678711, + 0.3104412257671356, + 1.0070439577102661, + -1.1077969074249268, + 0.08187273889780045, + 0.1387241780757904, + 0.09014563262462616, + -0.25378379225730896, + -0.9253583550453186, + 1.9745515584945679, + -0.6605072617530823, + -0.4394792318344116, + 0.11501576751470566, + 0.03007262572646141, + 0.2538164258003235, + -1.1462018489837646, + 0.7988958954811096, + 0.46934643387794495, + -0.4244523048400879, + -0.0001816617150325328, + -0.04351970925927162, + 0.20500127971172333, + -0.40710335969924927, + -0.15871365368366241, + 0.4640160799026489, + -0.06024328991770744, + -0.016036653891205788, + -0.012419192120432854, + 0.05552554875612259, + 0.050986770540475845, + -0.0171927809715271, + -0.12105240672826767, + 0.03947274759411812, + 0.009537882171571255, + -0.026668362319469452, + 0.017273351550102234, + 0.10812800377607346, + -0.015008139424026012, + -0.14154496788978577, + 0.08008233457803726, + -0.01306608971208334, + -0.05574854835867882, + -0.06091056764125824, + 0.2888447940349579, + 0.05022002384066582, + -0.4581625759601593, + 0.21146118640899658, + 
+        ... (remaining float32 weight values elided; the full list holds the 9,408 entries of the 64x3x7x7 tensor) ...
+        0.10856777429580688,
+    ]
+    value = numpy.array(list_value, dtype=numpy.float32).reshape((64, 3, 7, 7))
+    tensor = numpy_helper.from_array(value, name="onnx::Conv_501")
+
+    initializers.append(tensor)
+
+    list_value = [
+        3.085598945617676,
+        ... (further float32 initializer values elided) ...
4.893764495849609, + 2.6469600200653076, + 2.654136896133423, + 1.9311997890472412, + 2.881012439727783, + 2.6991193294525146, + ] + value = numpy.array(list_value, dtype=numpy.float32) + tensor = numpy_helper.from_array(value, name="onnx::Conv_502") + + initializers.append(tensor) + + list_value = [ + 0.057212892919778824, + 0.06299274414777756, + -0.018499961122870445, + -0.06501776725053787, + -0.015820641070604324, + 0.024293724447488785, + 0.05624663084745407, + -0.025112055242061615, + 0.043546054512262344, + 0.08439744263887405, + 0.005678815301507711, + 0.0034800865687429905, + 0.030301403254270554, + -0.011669250205159187, + -0.005434689112007618, + -0.1591511219739914, + 0.02324092946946621, + -0.018942436203360558, + 0.025366367772221565, + -0.07414374500513077, + 0.03468436002731323, + -0.003742520697414875, + -0.06651683896780014, + 0.005561002530157566, + 0.04527103528380394, + -0.13710148632526398, + 0.0025444801431149244, + 0.03583350405097008, + 0.015219246037304401, + -0.053635064512491226, + 0.004856681916862726, + -0.07223699986934662, + 0.016770021989941597, + 0.0012010147329419851, + 0.014582094736397266, + -0.005172556731849909, + 0.02009868621826172, + -0.0064261858351528645, + -0.029086023569107056, + 0.001915874076075852, + 0.0008194410474970937, + 0.01620865799486637, + 0.03067426010966301, + -0.0018463254673406482, + 0.05358384922146797, + -0.003966080490499735, + -0.05991416424512863, + -0.06455761194229126, + 0.01634763367474079, + -0.013959774747490883, + 0.03615918383002281, + 0.004434086848050356, + 0.02086004987359047, + -0.004025993403047323, + -0.8869641423225403, + 0.05558132007718086, + 0.024729542434215546, + -0.005809253081679344, + -0.025079259648919106, + 0.04757235199213028, + 0.0023902510292828083, + 0.01522061601281166, + 0.011692625470459461, + 0.023033330217003822, + -0.012664714828133583, + -0.29325294494628906, + -0.006855700630694628, + -0.243958979845047, + 0.0024398649111390114, + -0.060877203941345215, + -0.21996521949768066, + -0.008708474226295948, + -0.06639625877141953, + -0.03170674294233322, + -0.09708897024393082, + 0.013403226621448994, + 0.024766888469457626, + 0.2594103217124939, + -0.02221749909222126, + 0.0662861093878746, + -0.15123076736927032, + -0.010314224287867546, + -0.0029192541260272264, + 0.05985910817980766, + 0.021665453910827637, + 0.003247617743909359, + -0.006802591495215893, + 0.00772367138415575, + 0.0399332195520401, + 0.005198766943067312, + 0.006013805978000164, + -0.04212838411331177, + -0.03166411817073822, + 0.13363900780677795, + 0.006383878644555807, + -0.05536859482526779, + 0.02053261175751686, + 0.015062958002090454, + 0.03352641686797142, + -0.2944328486919403, + 0.019855381920933723, + -0.15567174553871155, + -0.06759943068027496, + 0.07467031478881836, + 0.01674237661063671, + 0.004549413453787565, + -0.0032498433720320463, + -0.1837870180606842, + -0.04725493863224983, + -0.111307792365551, + 0.022237055003643036, + 0.004200428258627653, + 0.00970534235239029, + -0.045657914131879807, + -0.024577995762228966, + 0.0035376595333218575, + 0.008936531841754913, + -0.03904002904891968, + 0.05013228952884674, + -0.011168933473527431, + -0.008444730192422867, + 0.0035155978985130787, + -0.023502476513385773, + 0.005275514908134937, + -0.09448224306106567, + -0.009177467785775661, + -0.010720008052885532, + 0.004110944457352161, + -0.0060218218713998795, + 0.058124978095293045, + -0.0016586220590397716, + 0.15812785923480988, + -0.049118027091026306, + -0.007983109913766384, + -0.04265601187944412, + 
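The blocks above all follow the standard onnx.numpy_helper pattern for baking constant weights into a model as graph initializers, which is why each array is given a name such as "onnx::Conv_501" that matches the node input it feeds. A minimal, self-contained sketch of that pattern (the names, shapes, and surrounding graph below are illustrative placeholders, not values taken from this patch):

    import numpy
    from onnx import TensorProto, helper, numpy_helper

    # Turn numpy arrays into TensorProto initializers. The names are
    # placeholders; the generated file above uses e.g. "onnx::Conv_501".
    weight = numpy_helper.from_array(
        numpy.zeros((64, 3, 7, 7), dtype=numpy.float32), name="conv1.weight")
    bias = numpy_helper.from_array(
        numpy.zeros((64,), dtype=numpy.float32), name="conv1.bias")

    # A node consumes an initializer simply by naming it as an input.
    conv = helper.make_node(
        "Conv", inputs=["x", "conv1.weight", "conv1.bias"], outputs=["y"],
        kernel_shape=[7, 7], strides=[2, 2], pads=[3, 3, 3, 3])

    # Initializers are attached to the graph alongside nodes and value infos.
    graph = helper.make_graph(
        [conv], "conv_graph",
        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 224, 224])],
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 64, 112, 112])],
        initializer=[weight, bias])
    model = helper.make_model(graph)

Because initializers are known at session-initialization time, a backend such as the TensorRT EP can treat them as constants when building its engine.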
+
+    list_value = [
+        0.057212892919778824,
+        0.06299274414777756,
+        # ... (intervening float32 values of the next weight tensor elided;
+        # the list continues below) ...
+        -0.09568691998720169,
+        -0.07179264724254608,
+ 0.04683076590299606, + 0.009320023469626904, + -0.11629963666200638, + -0.0016945215174928308, + 0.01624997705221176, + -0.0063682254403829575, + 0.15033549070358276, + -0.5171176791191101, + -0.01525783073157072, + 0.016417231410741806, + -0.00303818890824914, + 0.2500321865081787, + 0.022074062377214432, + 0.01191191840916872, + 0.012274803593754768, + 0.016534989699721336, + -0.028437916189432144, + 0.04241323843598366, + -0.01824999786913395, + -0.34815871715545654, + 0.04734490439295769, + -0.06419701874256134, + -0.022288290783762932, + -0.0004865761147812009, + 0.05369419604539871, + -0.058212973177433014, + -0.2196469008922577, + 0.010950890369713306, + 0.029042819514870644, + -0.07349151372909546, + -0.0422789566218853, + 0.062069639563560486, + 0.05589267984032631, + 0.014877256006002426, + 0.04236084595322609, + 0.03975239768624306, + 0.16930873692035675, + 0.03981085494160652, + 0.11499395221471786, + 0.0271450225263834, + 0.013969083316624165, + -0.0002660648606251925, + 0.010936664417386055, + -0.18389767408370972, + -0.10237602889537811, + 0.03041323646903038, + -0.013864071108400822, + -0.015729930251836777, + 0.037400804460048676, + -0.009598327800631523, + -0.09533312171697617, + -0.014712700620293617, + 0.08537333458662033, + -0.007200485561043024, + -0.31139102578163147, + -0.06366845220327377, + 0.02039063163101673, + -0.023356139659881592, + -0.0029549277387559414, + -0.12494662404060364, + 0.011755092069506645, + -0.26468148827552795, + -0.11541861295700073, + 0.010529865510761738, + -0.05965733155608177, + -0.05945499241352081, + -0.08796169608831406, + -0.014683439396321774, + 0.008732054382562637, + 0.010073489509522915, + 0.09553763270378113, + 0.034884922206401825, + 0.018675342202186584, + -0.009549405425786972, + -0.0007051719003356993, + -0.16936513781547546, + -0.0030460187699645758, + -0.022060535848140717, + -0.06689190864562988, + 0.013926704414188862, + 0.012043816037476063, + -0.0587068572640419, + -0.03814113140106201, + 0.06235629320144653, + 0.013228330761194229, + 0.04154474660754204, + -0.08039120584726334, + 0.028436705470085144, + -0.042226389050483704, + -0.019135186448693275, + 0.03747033327817917, + -0.14261123538017273, + 0.02827540971338749, + 0.0455685593187809, + -0.031124960631132126, + -0.007588588632643223, + 0.0034326373133808374, + -0.07682976871728897, + 0.24654042720794678, + -0.014518304727971554, + -0.07052458822727203, + -0.08241941034793854, + -0.04116151109337807, + -0.048463717103004456, + -0.038745298981666565, + 0.036902472376823425, + 0.0442035011947155, + 0.05572585016489029, + -0.014312628656625748, + 0.010794793255627155, + -0.3440641760826111, + -0.5161325335502625, + 0.0005156552069820464, + -0.010257269255816936, + -0.02412656880915165, + -0.023385023698210716, + 0.05533458665013313, + -0.012186119332909584, + -0.029286568984389305, + 0.04116401448845863, + -0.044610101729631424, + -0.019175484776496887, + 0.06835268437862396, + 0.06366674602031708, + 0.0373748242855072, + 0.03804386034607887, + 0.05369521677494049, + -0.04451881721615791, + 0.0018838117830455303, + 0.34775662422180176, + 0.010958605445921421, + -0.047990139573812485, + 0.04386777803301811, + -0.10427688807249069, + 0.04417382925748825, + 4.402965714689344e-05, + 0.01935163326561451, + -0.06753949075937271, + 0.02735923044383526, + 0.01465953141450882, + 0.06198301538825035, + -0.015980403870344162, + -0.2108263075351715, + 0.008177559822797775, + 0.006046924740076065, + 0.002665479900315404, + 0.20868580043315887, + -0.013740362599492073, + 
0.008203004486858845, + -0.005066391546279192, + 0.026405498385429382, + 0.01383009273558855, + 0.012581533752381802, + 0.009014940820634365, + 0.022820021957159042, + -0.008534795604646206, + 0.2603924572467804, + 0.02297227643430233, + -0.000749691273085773, + 0.044753506779670715, + 0.018596511334180832, + 0.006852792575955391, + -0.008686172775924206, + -0.10452616959810257, + 0.017021872103214264, + 0.003722329391166568, + -0.025453045964241028, + -0.011473417282104492, + -0.017907623201608658, + 0.01400628499686718, + -0.1670989990234375, + 0.004298652987927198, + -0.0022204748820513487, + 0.16521315276622772, + -0.008831127546727657, + 0.026490870863199234, + 0.006190746556967497, + -0.0177209060639143, + 0.08967147767543793, + 0.0033069502096623182, + -0.005021366756409407, + 0.0004906906979158521, + 0.0169216375797987, + -0.06124846637248993, + -0.005200678016990423, + 0.08404737710952759, + -0.010559299029409885, + -0.006309974938631058, + 0.023113396018743515, + -0.010227260179817677, + 0.001256447983905673, + 0.019783375784754753, + -0.006308461539447308, + -0.04529590904712677, + -0.00908862054347992, + -0.043217338621616364, + -0.32200074195861816, + 0.02592635713517666, + 0.030795685946941376, + -0.001814531977288425, + 0.0092842485755682, + 0.07088880985975266, + -0.0867588147521019, + 0.024099843576550484, + -0.0034031609538942575, + 0.007234686985611916, + -0.02505563199520111, + 0.0030480287969112396, + -0.019158190116286278, + 0.26473408937454224, + -0.011918547563254833, + -0.023240016773343086, + -0.06084466353058815, + -0.021916134282946587, + -0.010251260362565517, + -0.0009625791572034359, + 0.082605741918087, + -0.013018425554037094, + 0.007627277635037899, + -0.0010813736589625478, + 0.007952406071126461, + 0.06551267951726913, + -0.026020025834441185, + 0.050048135221004486, + -0.010610008612275124, + -0.02429312653839588, + -0.025263017043471336, + -0.04611891135573387, + 0.04451768472790718, + -0.08045025914907455, + -0.048037610948085785, + 0.008019295521080494, + 0.0160224549472332, + 0.002078550634905696, + -0.0202508345246315, + -0.5446130633354187, + 0.012585492804646492, + -0.0331973135471344, + 0.08371605724096298, + -0.00590998912230134, + -0.013058983720839024, + 0.027742384001612663, + 0.1042199358344078, + -0.3072803318500519, + 0.06284149736166, + -0.28551968932151794, + 0.026768438518047333, + 0.022245990112423897, + 0.018242113292217255, + -0.035077981650829315, + 0.03546127676963806, + 0.10165776312351227, + -0.025475669652223587, + -0.014933750964701176, + 0.040547240525484085, + -0.033055808395147324, + 0.011755919083952904, + -0.014459444209933281, + -0.03455093130469322, + 0.020743343979120255, + 0.02720930427312851, + -0.287664532661438, + 0.008260028436779976, + -0.009877690114080906, + 0.16657423973083496, + -0.010943812318146229, + -0.012381386943161488, + 0.030678801238536835, + 0.1559792459011078, + 0.038967035710811615, + -0.023399239405989647, + 0.015019542537629604, + -0.014201333746314049, + -0.014202176593244076, + -0.006699408870190382, + -0.13175444304943085, + 0.004643211141228676, + 0.012747463770210743, + -0.04086190089583397, + 0.06581410765647888, + -0.12192045897245407, + -0.03126347437500954, + 0.011175516061484814, + -0.00914736744016409, + -0.02883930690586567, + -0.11305265873670578, + -0.04405384883284569, + -0.009120048955082893, + -0.008926079608500004, + -0.03169447183609009, + 0.05464877560734749, + 0.25674498081207275, + 0.08497058600187302, + -0.023222925141453743, + 0.35592252016067505, + -0.006929511670023203, 
+ 0.025255810469388962, + -0.05150032415986061, + 0.039239466190338135, + -0.07082924991846085, + -0.017321549355983734, + 0.17293211817741394, + -0.02155853807926178, + -0.014333213679492474, + 0.0031305316369980574, + -0.013490653596818447, + -0.1376512199640274, + -0.021713266149163246, + -0.029826253652572632, + -0.0011473714839667082, + -0.012434332631528378, + -0.04860873892903328, + 0.013857590034604073, + 0.0703854188323021, + 0.034528713673353195, + -0.014423011802136898, + 0.0882454589009285, + -0.091700978577137, + 0.038885727524757385, + 0.012043441645801067, + -0.03183690831065178, + -0.014495689421892166, + -0.019726552069187164, + -0.010094117373228073, + -0.004218627233058214, + -0.04413086175918579, + -0.1344134360551834, + -0.0004976870259270072, + -0.0008357573533430696, + 0.04518067091703415, + 0.046797975897789, + 0.24766182899475098, + 0.01065139751881361, + -0.0034267394803464413, + -0.016103556379675865, + -0.05139121413230896, + 0.012563390657305717, + -0.03310413286089897, + -0.030157553032040596, + 0.046670909970998764, + 0.012565785087645054, + -0.040275491774082184, + 0.023816417902708054, + -0.38536572456359863, + 0.04508889466524124, + 0.13637560606002808, + -0.010654824785888195, + 0.0459851399064064, + -0.0046302699483931065, + -0.020852191373705864, + 0.10662271827459335, + 0.06486576050519943, + 0.05727925896644592, + 0.09816201776266098, + 0.04878557100892067, + -0.16256237030029297, + 0.014547038823366165, + 0.018567964434623718, + -0.07284612208604813, + 0.017150163650512695, + 0.0246741883456707, + -0.38470372557640076, + -0.07465949654579163, + 0.03010236658155918, + -0.004397575277835131, + -0.06618984788656235, + -0.02908281609416008, + 0.060166433453559875, + -0.0020949048921465874, + 0.007689109072089195, + -0.0047390698455274105, + -0.014199030585587025, + -0.01794746331870556, + -0.02528063952922821, + 0.002218312583863735, + 0.10169881582260132, + 0.010602130554616451, + -0.06605861335992813, + -0.0008762837387621403, + -0.035027723759412766, + -0.011684391647577286, + 0.02247578091919422, + 0.17245104908943176, + 0.22525252401828766, + -0.010771296918392181, + 0.05595310404896736, + 0.06338834017515182, + -0.0038216698449105024, + -0.0032836494501680136, + 0.005779017228633165, + -0.18020786345005035, + -0.05066698044538498, + -0.0035458216443657875, + -0.10578767210245132, + -0.041712939739227295, + 0.2104150652885437, + -0.03753345459699631, + 0.013989892788231373, + 0.01988149993121624, + 0.05108603090047836, + 0.04496738687157631, + -0.3034508526325226, + 0.0226743221282959, + -0.0431472510099411, + -0.025635428726673126, + -0.18961989879608154, + -0.17218825221061707, + 0.03576141223311424, + 0.060613714158535004, + -0.011970550753176212, + -0.21435107290744781, + 0.01422552578151226, + 0.02974064089357853, + -0.061079952865839005, + 0.031064646318554878, + 0.009629320353269577, + -0.13762925565242767, + 0.01928475871682167, + 0.007310172542929649, + 0.06103459745645523, + -0.16216528415679932, + 0.03330384939908981, + 0.09578404575586319, + -0.0037327276077121496, + 0.029233848676085472, + -0.0015759399393573403, + 0.005511409603059292, + -0.4195749759674072, + 0.024169376119971275, + 0.13220365345478058, + 0.007961929775774479, + 0.008045470342040062, + 0.01919495314359665, + -0.023188553750514984, + 0.07084394991397858, + -0.24922333657741547, + 0.02011212892830372, + -0.18514998257160187, + 0.03114209696650505, + 0.09826567023992538, + 0.00592303741723299, + -0.010020115412771702, + 0.027117054909467697, + -0.214133620262146, + 
-0.01214816514402628, + 0.06564164906740189, + 0.02513044886291027, + 0.02132420241832733, + -0.02127540111541748, + -0.041606876999139786, + 0.04196378216147423, + -0.02060609683394432, + 0.01730814389884472, + -0.17418994009494781, + 0.03462710976600647, + -0.017470642924308777, + -0.3992193639278412, + 0.02652592957019806, + 0.025042008608579636, + 0.026447610929608345, + -0.19199316203594208, + 3.27593952533789e-05, + 0.002988220192492008, + -0.21171888709068298, + 0.03300239518284798, + 0.015727035701274872, + -0.008947308175265789, + 0.03924538940191269, + -0.08990193158388138, + 0.023726975545287132, + 0.03463870286941528, + -0.05018220469355583, + 0.13170146942138672, + 0.054000236093997955, + 0.01158218178898096, + 0.062349993735551834, + -0.014724616892635822, + 0.039657603949308395, + 0.04436490684747696, + 0.014076294377446175, + 0.07666806876659393, + 0.09630247205495834, + -0.04152659326791763, + -0.1860806941986084, + -0.07671733945608139, + 0.031573690474033356, + -0.44617798924446106, + -0.004897239152342081, + -0.03991628438234329, + 0.01880800537765026, + -0.04769768565893173, + 0.02198435738682747, + 0.01341161783784628, + -0.12239313870668411, + 0.019765935838222504, + 0.005221452098339796, + -0.025201082229614258, + 0.005132562946528196, + 0.08668412268161774, + 0.0035341952461749315, + 0.008583099581301212, + 0.032979920506477356, + 0.03324040770530701, + 0.04411708936095238, + -0.008390798233449459, + 0.040486790239810944, + -0.059673551470041275, + 0.02003314346075058, + -0.0990666076540947, + 0.03971675783395767, + 0.012021057307720184, + 0.0017271327087655663, + 0.01818535290658474, + 0.0025106174871325493, + 0.043714240193367004, + 0.019146842882037163, + -0.0041794623248279095, + 0.033447377383708954, + 0.06863203644752502, + -0.004350902978330851, + 0.0113364327698946, + -0.05825724080204964, + -0.04649435728788376, + -0.10618306696414948, + 0.02653644233942032, + 0.012514552101492882, + 0.019399365410208702, + -0.0022177041973918676, + 0.017741208896040916, + 0.04115311801433563, + 0.05122101679444313, + 0.055051617324352264, + 0.01687677949666977, + -0.03698579967021942, + 0.10053858160972595, + -0.007528421934694052, + 0.003968802746385336, + 0.02458524890244007, + -0.02144794538617134, + 0.026791265234351158, + -0.016701897606253624, + 0.014119372703135014, + -0.03460531681776047, + -0.02320348098874092, + 0.056146953254938126, + 0.028700685128569603, + -0.14820916950702667, + -0.016996873542666435, + 0.025667931884527206, + 0.08408629894256592, + 0.00034475952270440757, + 0.007573155220597982, + 0.06784884631633759, + 0.025982951745390892, + -0.08363039791584015, + -0.015748541802167892, + -0.0029514851048588753, + -0.01523523684591055, + 0.10500328987836838, + 0.3070858418941498, + -0.024624783545732498, + 0.0058471946977078915, + -0.039751242846250534, + 0.0012745993444696069, + -0.0796508714556694, + 0.024727927520871162, + 0.056764136999845505, + -0.013338261283934116, + -0.04794292524456978, + -0.02609768509864807, + -0.010784422047436237, + -0.048712026327848434, + 0.020345501601696014, + 0.0021618579048663378, + -0.0021724768448621035, + 0.03056410700082779, + -0.01633712649345398, + -0.47168225049972534, + -0.014639903791248798, + -0.012550815008580685, + 0.03358187526464462, + 0.07889427989721298, + -0.03615899011492729, + -0.002809660043567419, + -0.006953644100576639, + 0.02024337276816368, + -0.0738825723528862, + -0.006984011270105839, + -0.04472561925649643, + -0.027498915791511536, + 0.07207506150007248, + -0.09166522324085236, + 
-0.008861960843205452, + 0.05264359340071678, + 0.01889069564640522, + -0.1380404680967331, + -0.010141258127987385, + 0.015403619967401028, + -0.16416165232658386, + -0.03529815003275871, + 0.042106859385967255, + 0.11173021793365479, + -0.3143587112426758, + 0.011045016348361969, + 0.0012351945042610168, + 0.03840603306889534, + 0.0685538575053215, + -0.000746160454582423, + -0.028142500668764114, + 0.027154160663485527, + 0.005731801502406597, + 0.04433267563581467, + -0.8158469796180725, + 0.02226361259818077, + -0.07650655508041382, + 0.026958195492625237, + -0.005810025613754988, + -0.020102059468626976, + -0.0019310436910018325, + 0.07697021961212158, + -0.057701658457517624, + 0.05954534560441971, + 0.0027106746565550566, + -0.06311310827732086, + 0.011713752523064613, + -0.0034454476553946733, + -0.0006881420267745852, + 0.08937360346317291, + -0.0008253820124082267, + -0.031066063791513443, + -0.14708301424980164, + -0.04438449814915657, + 0.004772413522005081, + 0.05992274731397629, + 0.07473544776439667, + -0.1784757375717163, + -0.19057415425777435, + -0.014637955464422703, + -0.24898527562618256, + 0.13606221973896027, + -0.018039124086499214, + -0.047193415462970734, + -0.06526428461074829, + 0.04075757786631584, + 0.049901530146598816, + -0.008585861884057522, + 0.01616351678967476, + -3.091737016802654e-05, + 0.024283329024910927, + 0.008861682377755642, + -0.0005823548417538404, + 0.0997646301984787, + 0.051001910120248795, + 0.009473294951021671, + -0.0032046104315668344, + 0.018362928181886673, + 0.008627718314528465, + -0.4148157835006714, + -0.016077928245067596, + 0.0745391696691513, + 0.00724065862596035, + 0.08948155492544174, + 0.11626332253217697, + -0.052439428865909576, + 0.005599102005362511, + 0.002622961765155196, + 0.07586965709924698, + 0.03274847939610481, + -0.02099076844751835, + -0.04666733741760254, + -0.0013019372709095478, + 0.04945925995707512, + 0.11393380910158157, + 0.006346395239233971, + 0.04721064493060112, + 0.010331138968467712, + 0.08918803185224533, + 0.04288423806428909, + -0.09234773367643356, + 0.020141584798693657, + -3.256054696976207e-05, + -0.02799108810722828, + 0.018966441974043846, + -0.4136410355567932, + -0.07217283546924591, + 0.01840362884104252, + -0.055327851325273514, + 0.003275467548519373, + -0.017174070701003075, + -0.032178670167922974, + 0.09021560847759247, + -0.524413526058197, + 0.01994725503027439, + 0.10380692034959793, + -0.01043684035539627, + -0.00011200909648323432, + 0.01331041194498539, + 0.020127851516008377, + -0.025159789249300957, + 0.05252581834793091, + 0.04759140685200691, + 0.0032084162812680006, + -0.03579062595963478, + 0.054719552397727966, + -0.04674411937594414, + 0.028389262035489082, + 0.001127603929489851, + -0.0006243048119358718, + -0.00550495833158493, + -0.022523507475852966, + -0.024282312020659447, + 0.009519628249108791, + -0.39908328652381897, + -0.009265545755624771, + -0.00037090369733050466, + 0.06425131112337112, + -0.05998316407203674, + -0.015221518464386463, + -0.004825026262551546, + 0.11847284436225891, + -0.011302731931209564, + -0.006884834263473749, + -0.04678218811750412, + -0.012078279629349709, + 0.021638741716742516, + -0.016819776967167854, + -0.009127719327807426, + -0.002491263672709465, + 0.0016752213705331087, + -0.016600262373685837, + 0.011772023513913155, + -0.013447183184325695, + -0.020662957802414894, + -0.011593316681683064, + 0.008270744234323502, + -0.0026990456972271204, + -0.004406482446938753, + -0.023110052570700645, + -0.00208942755125463, + 
-0.1711198389530182, + 0.012432538904249668, + -0.0045453268103301525, + 0.024807902052998543, + -0.0035043740645051003, + -0.004001997876912355, + -0.013488625176250935, + -0.02020987868309021, + -0.01216109935194254, + -0.004432092886418104, + 0.09323672950267792, + -0.015641510486602783, + -0.019307948648929596, + 0.01117538008838892, + -0.01422040443867445, + 0.01705607771873474, + -0.0029596879612654448, + -0.0021530911326408386, + -0.006551788654178381, + 0.00429268553853035, + -0.1620807945728302, + -0.014128226786851883, + -0.005428737495094538, + -0.006771362852305174, + 0.005730633158236742, + 0.0007243106956593692, + 0.0024031582288444042, + -0.00199915561825037, + 0.006133859045803547, + -0.013380909338593483, + 0.00733462069183588, + -0.001863821060396731, + -0.0020169683266431093, + -0.014070986770093441, + -0.006501683499664068, + -0.029421553015708923, + 0.0009377509704791009, + -0.01718256250023842, + -0.05819401144981384, + -0.018859732896089554, + 0.0010356366401538253, + 0.006394123658537865, + -0.021985618397593498, + -0.01204769592732191, + -0.002014884725213051, + -0.019398409873247147, + -0.013122898526489735, + -0.017277296632528305, + -0.002270353492349386, + -0.05294327810406685, + -0.020317314192652702, + -0.018196573480963707, + -0.010375416837632656, + -0.019704729318618774, + -0.016109557822346687, + -0.0167380403727293, + -0.0285252146422863, + -0.02665277197957039, + -0.03554505482316017, + -0.00741522666066885, + -0.013580105267465115, + -0.026335405185818672, + -0.011694515123963356, + -0.004639182705432177, + -0.03996071219444275, + -0.022463932633399963, + -0.007204636000096798, + -0.021065134555101395, + -0.014410646632313728, + 0.0035447971895337105, + -0.0013098351191729307, + -0.024171002209186554, + 0.00047751085367053747, + -0.01870289072394371, + -0.06016797944903374, + -0.025703946128487587, + -0.009730588644742966, + -0.021792838349938393, + -0.024519823491573334, + -0.01843440905213356, + -0.0016325484029948711, + -0.008116388693451881, + -0.017774557694792747, + -0.04375867918133736, + -0.03893980756402016, + -0.018188582733273506, + -0.007122726645320654, + -0.028115490451455116, + -0.01821342669427395, + -0.01011319737881422, + -0.02616124413907528, + -0.013797983527183533, + -0.03202736750245094, + -0.030110370367765427, + -0.01883666031062603, + -0.01185502391308546, + -0.006012012716382742, + -0.017311619594693184, + -0.022577986121177673, + -0.02101938985288143, + 0.0025952248834073544, + -0.005058783106505871, + -0.004162575118243694, + -0.01559755764901638, + -0.017923563718795776, + -0.04231095686554909, + -0.017630560323596, + -0.011938830837607384, + -0.01587115228176117, + 0.004972478374838829, + -0.016601158306002617, + 0.15419845283031464, + 0.0009241115767508745, + 0.051028184592723846, + 0.008128340356051922, + -0.019917558878660202, + -0.0010339801665395498, + 0.022349294275045395, + -0.0072520882822573185, + 0.0017750378465279937, + -0.10526080429553986, + 0.03420695662498474, + 0.019183926284313202, + -0.0006544998032040894, + -0.0032203509472310543, + -0.01216941885650158, + -0.03561796247959137, + 0.024905826896429062, + -0.026948239654302597, + -0.01913355104625225, + -0.014459407888352871, + 0.006972283590584993, + -0.033184293657541275, + 0.04884861409664154, + -0.002296984428539872, + -0.19194477796554565, + 0.00392142403870821, + 0.009490449912846088, + -0.02687196619808674, + -0.06327224522829056, + -0.03684951737523079, + -0.0002613202668726444, + -0.012086644768714905, + 0.03630973398685455, + 
0.007296048104763031, + 0.011186012998223305, + 0.0074085514061152935, + -0.020394617691636086, + -0.010585476644337177, + -0.030289918184280396, + 0.0773506686091423, + 0.008841303177177906, + 0.019423579797148705, + 0.001184571417979896, + 0.005553434602916241, + 0.015373414382338524, + -0.0027953842654824257, + 0.013204757124185562, + 0.029097743332386017, + 0.012627501040697098, + 0.02102004364132881, + -0.09469914436340332, + -0.023324014618992805, + 0.029243655502796173, + 0.002979277865961194, + -0.004492263309657574, + 0.20549021661281586, + -0.3244459927082062, + 0.025892559438943863, + 0.009620796889066696, + -0.05520407855510712, + -0.02271144650876522, + 0.008378816768527031, + -0.0671214610338211, + -0.016056722030043602, + -0.02355658821761608, + 0.0005429868469946086, + -0.007960098795592785, + 0.02513299137353897, + -0.13005328178405762, + -0.0025323680602014065, + -0.02197088487446308, + -0.02404806576669216, + 0.08261960744857788, + 0.17078880965709686, + 0.02880753017961979, + -0.03642067685723305, + 0.021994341164827347, + -0.012368184514343739, + -0.10681373625993729, + 0.16371481120586395, + 0.17881983518600464, + -0.10202010720968246, + -0.08641688525676727, + -0.1259487271308899, + 0.06907707452774048, + 0.023792706429958344, + -0.02534419298171997, + 0.016984017565846443, + -0.06743635982275009, + 0.08445960283279419, + -0.08037827908992767, + -0.11935994029045105, + -0.31716489791870117, + -0.01860150322318077, + 0.060669515281915665, + -0.06137414649128914, + 0.09878886491060257, + 0.01794014871120453, + 0.12382296472787857, + -0.016424886882305145, + 0.09045679122209549, + -0.02998783066868782, + -0.00972777884453535, + -0.024124544113874435, + 0.09879253059625626, + 0.05500243604183197, + -0.06635259836912155, + 0.11268552392721176, + 0.011751363053917885, + -0.04690232127904892, + -0.025168607011437416, + 0.088335320353508, + -0.1140628531575203, + 0.04129032790660858, + -0.04258979484438896, + -0.0903872698545456, + 0.008473021909594536, + -0.026690304279327393, + -0.051559556275606155, + -0.05481572076678276, + -0.05251916125416756, + -0.0018165932269766927, + 0.09836867451667786, + 0.0054859439842402935, + 0.06432581692934036, + 0.10621821135282516, + -0.019325286149978638, + -0.028727786615490913, + 0.014013150706887245, + -0.008022608235478401, + -0.006281842477619648, + -0.0297000203281641, + 0.01525485422462225, + -0.4346403479576111, + 0.07787995040416718, + -0.25380268692970276, + 0.05261845141649246, + 0.010875157080590725, + 0.0014149334747344255, + 0.05021188035607338, + -0.24382442235946655, + 0.0807114690542221, + 0.022907381877303123, + 0.006440790370106697, + -0.017028095200657845, + 0.001552293193526566, + 0.05961666256189346, + -0.14113056659698486, + 0.03398876264691353, + -0.005411976482719183, + -0.014025667682290077, + -0.5433799624443054, + 0.019015472382307053, + 0.04091138765215874, + 0.05059061944484711, + 0.0274446289986372, + -0.010288042947649956, + -0.001335533568635583, + -0.013533512130379677, + 0.018798377364873886, + -0.04099345579743385, + 0.0031264263670891523, + -0.21071769297122955, + -0.014384736306965351, + -0.1045387014746666, + -0.014340974390506744, + 0.001986369490623474, + -0.04118456318974495, + -0.10952988266944885, + 0.049147430807352066, + -0.08382093161344528, + -0.1741400957107544, + -0.0885215476155281, + -0.10934099555015564, + 0.05553343519568443, + 0.02434251271188259, + 0.006634524557739496, + -0.0017163373995572329, + 0.0185443926602602, + 0.06250902265310287, + -0.17145656049251556, + 
-0.07543934881687164, + 0.026583310216665268, + 0.01634727604687214, + 0.003603539662435651, + -0.2817271649837494, + 0.03882112354040146, + 0.011341865174472332, + 0.00826666783541441, + 0.050427842885255814, + -0.22358834743499756, + 0.06419781595468521, + 0.03245265409350395, + -0.04503164440393448, + -0.023194484412670135, + -0.027968740090727806, + 0.08563586324453354, + 0.07954753190279007, + -0.08513130992650986, + 0.02850884199142456, + 0.008976672776043415, + 0.07886530458927155, + 0.0022273347713053226, + -0.09540755301713943, + 0.032016951590776443, + -0.05196075513958931, + 0.10555616766214371, + 0.07629868388175964, + 0.039732079952955246, + -0.0029798501636832952, + 0.014692343771457672, + 0.09200941026210785, + -0.04299614951014519, + -0.023488566279411316, + -0.01851060427725315, + 0.09257487207651138, + 0.055612049996852875, + 0.06423109769821167, + -0.28587806224823, + -0.09950444847345352, + 0.10397437959909439, + 0.025166453793644905, + -0.03235514089465141, + -0.033381711691617966, + 0.1513858139514923, + 0.06468874961137772, + 0.01928441785275936, + 0.0032701045274734497, + -0.0579083226621151, + -0.022929169237613678, + 0.012971373274922371, + -0.018524186685681343, + -0.06484643369913101, + 0.012233717367053032, + 0.06590451300144196, + -0.04558677598834038, + 0.05253027006983757, + 0.048656731843948364, + -0.2288871705532074, + 0.037114787846803665, + -0.20519588887691498, + 0.0058607361279428005, + -0.002009372925385833, + -0.006671734619885683, + -0.07107856124639511, + -0.07407436519861221, + 0.03941629081964493, + 0.0447598397731781, + 0.03509354963898659, + -0.061107732355594635, + -0.09305761009454727, + -0.012180411256849766, + 0.04902016744017601, + 0.07974442094564438, + -0.016854410991072655, + 0.005089411046355963, + -0.08127597719430923, + 0.03258403390645981, + 0.039813362061977386, + -0.01668727956712246, + 0.027226485311985016, + -0.029213925823569298, + -0.008598217740654945, + 0.00931101106107235, + 0.026936721056699753, + -0.03083401545882225, + -0.05799110606312752, + -0.008277476765215397, + -0.014854338951408863, + -0.20012643933296204, + 0.012290815822780132, + 0.007194168865680695, + 0.06858328729867935, + -0.3296163082122803, + -0.11424986273050308, + 0.009912200272083282, + -0.06211454048752785, + 0.0007546336855739355, + 0.03507614880800247, + 0.10649498552083969, + -0.03036407195031643, + 0.0646015852689743, + -0.01595110446214676, + -0.16919563710689545, + 0.0013865949586033821, + -0.08339446783065796, + 0.06962471455335617, + 0.016058098524808884, + -0.04729780554771423, + 0.010602935217320919, + 0.01470863912254572, + 0.06903938204050064, + 0.014901719056069851, + -0.15120048820972443, + 0.016727851703763008, + 0.05003673583269119, + 0.04370126873254776, + 0.029703885316848755, + 0.021875420585274696, + 0.026293285191059113, + -0.01048936415463686, + 0.00040810942300595343, + -0.015616541728377342, + -0.062451593577861786, + 0.010016348212957382, + -0.06790193170309067, + -0.02077331207692623, + 0.007985175587236881, + -0.04435744881629944, + 0.06920231133699417, + 0.018344474956393242, + 0.028591370210051537, + 0.021957838907837868, + 0.0017570338677614927, + 0.036665257066488266, + 0.015438515692949295, + -0.0006347382441163063, + 0.04621066153049469, + -0.001942177303135395, + 0.010664877481758595, + -0.016754357144236565, + 0.006541184149682522, + -0.027716301381587982, + -0.0058586387895047665, + -0.005346015095710754, + 0.020482052117586136, + 0.06882552057504654, + 0.0026622572913765907, + 0.016321638599038124, + 
0.017728103324770927, + -0.13356441259384155, + 0.030281176790595055, + 1.0354949154134374e-05, + 0.050639618188142776, + 0.0013030078262090683, + -0.11136802285909653, + -0.006832807790488005, + -0.09628921747207642, + 0.046699415892362595, + 0.002175685251131654, + 0.008100612089037895, + 0.012449901551008224, + -0.01713990420103073, + -0.000769267207942903, + 0.022544430568814278, + -0.0018787183798849583, + -0.014189678244292736, + 0.37042510509490967, + -0.030317893251776695, + 0.012663356028497219, + -0.04071582853794098, + 0.01653047651052475, + 0.06578584760427475, + 0.005606585182249546, + 0.0029362838249653578, + -0.02035594917833805, + 0.016131827607750893, + -0.06512665003538132, + 0.020292088389396667, + 0.12818951904773712, + -0.00017647731874603778, + 0.0004811069811694324, + 0.013025660999119282, + -0.006004344671964645, + 0.011330580338835716, + 0.0021733916364610195, + -0.0026290342211723328, + 0.008579215034842491, + -0.017107143998146057, + 0.0032798980828374624, + 0.21415431797504425, + -0.011049880646169186, + 0.04915957152843475, + -0.01152863260358572, + 0.01988764852285385, + -0.30189022421836853, + 0.1491061896085739, + 0.022540517151355743, + 0.02323656715452671, + -0.0028044115751981735, + -0.02501249685883522, + 0.0016759912250563502, + 0.023405946791172028, + 0.0865691602230072, + 0.0056661744602024555, + 0.2334042638540268, + -0.05771901085972786, + 0.03428330272436142, + -0.05191519856452942, + 0.025708407163619995, + -0.11474912613630295, + 0.05345827341079712, + 0.050046734511852264, + -0.03785427287220955, + 0.02726786397397518, + 0.008640051819384098, + -0.05810163915157318, + 0.19147679209709167, + 0.12065602838993073, + -0.08667072653770447, + -0.12831886112689972, + 0.027053257450461388, + -0.1771622896194458, + -0.2615586817264557, + 0.112942636013031, + 0.002398239215835929, + 0.00907410029321909, + 0.059947770088911057, + 0.040937639772892, + 0.003431124845519662, + 0.012721046805381775, + -0.10228776186704636, + 0.04169567674398422, + -0.04826785624027252, + -0.021415220573544502, + 0.027615519240498543, + 0.16087181866168976, + 0.03552674129605293, + -0.36409878730773926, + 0.0015418739058077335, + 0.03940089792013168, + -0.12929502129554749, + 0.017082052305340767, + -0.07193783670663834, + 0.10395099222660065, + -0.2240910828113556, + -0.003303584409877658, + -0.0074868109077215195, + -0.13708709180355072, + 0.2098008245229721, + 0.013808795250952244, + -0.03606148064136505, + 0.001965852687135339, + 0.04186573252081871, + 0.02105732634663582, + -0.11873909085988998, + -0.08529136329889297, + 0.0060731275007128716, + 0.04803553968667984, + 0.07665349543094635, + 0.026997262611985207, + 0.05191565304994583, + 0.09013131260871887, + 0.013081093318760395, + 0.04667182266712189, + -0.19899451732635498, + 0.004642056301236153, + 0.0025570227298885584, + -0.2640555500984192, + 0.008254006505012512, + 0.05971720814704895, + -0.002980671590194106, + 0.0011313167633488774, + -0.004445134196430445, + 0.01951296627521515, + -0.006634386721998453, + -0.008033698424696922, + 0.012400158680975437, + -0.15906694531440735, + 0.007047838997095823, + 0.0003521084145177156, + -0.00517050176858902, + -0.0003226286207791418, + -0.01226231548935175, + -0.06750697642564774, + -0.03061128593981266, + -0.0027100055012851954, + 0.004726986400783062, + 0.010185977444052696, + 0.021205933764576912, + -0.05105980113148689, + -0.006725164130330086, + 0.26042309403419495, + 0.003935054875910282, + 0.009450466372072697, + -0.009512278251349926, + 0.036205559968948364, + 
0.0066987741738557816, + 0.05687355250120163, + -0.0070350514724850655, + 0.021287698298692703, + 0.004246287513524294, + -0.004053668584674597, + 0.0030501342844218016, + -0.003596516093239188, + 0.00571554945781827, + 0.039099883288145065, + 0.06648323684930801, + 0.011140268296003342, + 0.002779693342745304, + 0.0004113377653993666, + 0.0019621821120381355, + 0.002047213725745678, + -9.034215327119455e-05, + 0.006674906238913536, + -0.024464793503284454, + 4.372629337012768e-05, + 0.04560312256217003, + 0.029951298609375954, + 0.0053787860088050365, + 0.010052027180790901, + 0.0018156497972086072, + 0.001613074098713696, + -0.3710610568523407, + 0.18385423719882965, + 0.0197732076048851, + -2.409513217571657e-05, + 0.043657880276441574, + 0.029824273660779, + -0.0015272254822775722, + -0.0009817760437726974, + 0.030571524053812027, + 0.05133187025785446, + 0.021092001348733902, + -0.022430723533034325, + -0.011050102300941944, + -0.01653454266488552, + 0.00856624636799097, + 0.007617316208779812, + 0.023697074502706528, + -0.00541776092723012, + -0.06940567493438721, + -0.024501511827111244, + 0.0029131292831152678, + 0.005110545549541712, + 0.02394089475274086, + 0.009317552670836449, + -0.05198051407933235, + -0.14872707426548004, + -0.03553030639886856, + 0.05354774370789528, + 0.053996339440345764, + 0.016679847612977028, + -0.4505158066749573, + 0.006403166800737381, + -0.014287465251982212, + 0.010499212890863419, + 0.00510875741019845, + 0.0230255089700222, + -0.04791099205613136, + -0.08405473828315735, + -0.00807158276438713, + -0.016310568898916245, + -0.018034789711236954, + -0.03381670266389847, + 0.038599055260419846, + 0.01189411524683237, + 0.0038598189130425453, + 0.0077203805558383465, + -0.0006835742969997227, + 0.3038807809352875, + 0.00930703990161419, + -0.017654214054346085, + -0.029550395905971527, + 0.0014829621650278568, + -0.010562432929873466, + -0.011867706663906574, + -0.008104459382593632, + 0.008003979921340942, + -0.028282882645726204, + 0.00898829661309719, + -0.04963170364499092, + 0.014971665106713772, + 0.028662119060754776, + 0.055792808532714844, + 0.018142173066735268, + 0.029526766389608383, + 0.04726170003414154, + 0.020290115848183632, + -0.01347910612821579, + -0.027794860303401947, + -0.033374592661857605, + 0.05699307844042778, + -0.005888971965759993, + 0.009723466821014881, + 0.011825029738247395, + 0.0005665962235070765, + -0.22433574497699738, + 0.04777664318680763, + 0.054696254432201385, + 0.06447272002696991, + 0.006656138692051172, + -0.2656468152999878, + -0.006602808367460966, + -0.04309352487325668, + 0.024392882362008095, + -0.046948980540037155, + 0.17317010462284088, + -0.014694501645863056, + 0.09150391072034836, + 0.05414793640375137, + -0.0034523033536970615, + -0.029682809486985207, + -0.11646991223096848, + 0.036394182592630386, + -0.008510537445545197, + -0.09555189311504364, + 0.012331446632742882, + 0.022554755210876465, + 0.037040166556835175, + 0.011939534917473793, + -0.035405583679676056, + -0.008284371346235275, + 0.008629710413515568, + -0.0017152110813185573, + -0.01656493730843067, + 0.02205522358417511, + -0.008015291765332222, + -0.02198217809200287, + -0.08165504783391953, + 0.018647879362106323, + 0.010489191859960556, + 0.0009643095545470715, + 0.08301698416471481, + 0.00795030314475298, + -0.08973152935504913, + 0.05324552580714226, + 0.0187348835170269, + 0.00770497927442193, + 0.016434336081147194, + 0.0031714467331767082, + 0.031489044427871704, + -0.01682765781879425, + -0.0006042059976607561, + 
0.006229344755411148, + 0.0031935630831867456, + -0.03694210946559906, + -0.027148112654685974, + 0.03319454565644264, + 0.013541879132390022, + 0.04362545907497406, + 0.010766182094812393, + 0.01287879142910242, + 0.02723391354084015, + 0.01831277459859848, + -0.0028144901152700186, + 0.0317537821829319, + -0.05053209140896797, + 0.03341667726635933, + 0.009338690899312496, + 0.030376508831977844, + 0.028512636199593544, + 0.002190604107454419, + 0.031132254749536514, + 0.04174429178237915, + 0.025147251784801483, + 0.02602408640086651, + 0.022863827645778656, + 0.024160150438547134, + 0.04043813422322273, + 0.011693909764289856, + 0.008020071312785149, + 0.010814648121595383, + 0.014862221665680408, + 0.043966785073280334, + 0.04133215174078941, + 0.03920775279402733, + 0.02128027193248272, + -0.0024078795686364174, + 0.03185494989156723, + 0.030951442196965218, + 0.008766901679337025, + -0.0013500713976100087, + 0.012680909596383572, + 0.01911563239991665, + 0.02226334996521473, + 0.03873631730675697, + 0.005242412444204092, + 0.02335301972925663, + 0.00577192846685648, + 0.0019918885082006454, + 0.019501060247421265, + 0.048295676708221436, + 0.027288099750876427, + 0.03500128164887428, + 0.032504353672266006, + 0.03619033470749855, + 0.022762063890695572, + 0.014124974608421326, + 0.04055529460310936, + 0.040181197226047516, + 0.04843837395310402, + 0.019578352570533752, + 0.04370861127972603, + 0.024640914052724838, + 0.027013463899493217, + 0.04700532928109169, + 0.018523193895816803, + 0.03569294884800911, + 0.031140455976128578, + 0.010298499837517738, + 0.03979840502142906, + 0.015059049241244793, + 0.020604899153113365, + 0.010335667058825493, + 0.02557498589158058, + 0.015946611762046814, + 0.018900645896792412, + 0.05494159087538719, + 0.015756357461214066, + 0.0452926866710186, + 0.04820817708969116, + -0.0183499027043581, + 0.04002442955970764, + -0.08226092159748077, + -0.034417178481817245, + 0.059122342616319656, + 0.028960591182112694, + -0.020427608862519264, + -0.043222296983003616, + 0.023134637624025345, + -0.014232538640499115, + -0.06970997899770737, + -0.0035826240200549364, + -0.015384080819785595, + -0.0695020854473114, + 0.03645527362823486, + 0.013986784033477306, + -0.027729706838726997, + -0.05711805075407028, + -0.0763891413807869, + -0.16338491439819336, + -0.02358265034854412, + -0.004730133805423975, + 0.022057903930544853, + -0.011578230187296867, + 0.040772147476673126, + -0.059327173978090286, + -0.03819728270173073, + -0.050089117139577866, + -0.005152902565896511, + -0.3071111738681793, + -0.010683669708669186, + 0.030922774225473404, + 0.08924981951713562, + 0.005679265595972538, + 0.06334424018859863, + 0.016136568039655685, + -0.02575727365911007, + -0.012562219053506851, + 0.007206748705357313, + -0.1373208612203598, + -0.010450832545757294, + -0.05991309881210327, + -0.006700845435261726, + -0.006468744482845068, + -0.02040017955005169, + -0.010068708099424839, + 0.008442427963018417, + 0.012259873561561108, + -0.002103718463331461, + -0.019605906680226326, + -0.010690353810787201, + 0.0005222380859777331, + -0.015031278133392334, + -0.012983204796910286, + -0.03552224859595299, + -0.007792052812874317, + -0.035602111369371414, + -0.03479204699397087, + -0.02480080910027027, + -0.05733964219689369, + 4.38804054283537e-05, + -0.021825626492500305, + -0.03287259489297867, + -0.05437042564153671, + -0.007981077767908573, + 0.023045696318149567, + 0.05785335600376129, + 0.03685669228434563, + 0.04314129799604416, + -0.005843586288392544, + 
-0.024806369096040726, + -0.02562016434967518, + 0.0015172295970842242, + -0.01568800024688244, + -0.005925294477492571, + 0.010173594579100609, + 0.06834683567285538, + 0.024159085005521774, + -0.009547322988510132, + 0.014080812223255634, + 0.013578452169895172, + 0.035671167075634, + 0.01240566186606884, + -0.021352441981434822, + 0.05245270952582359, + -0.008943279273808002, + -0.010131126269698143, + 0.02976749651134014, + 0.0600045844912529, + 0.0014893191400915384, + 0.03796907886862755, + 0.01258794590830803, + -0.025344882160425186, + 0.14140591025352478, + 0.028354406356811523, + 0.0035325682256370783, + 0.05017172172665596, + 0.01994139887392521, + 0.03679897263646126, + -0.009579945355653763, + -0.012607194483280182, + -0.00034231581958010793, + 0.00046832446241751313, + 0.057916246354579926, + 0.02351403795182705, + 0.06157909706234932, + 0.00789523497223854, + -0.018361341208219528, + 0.0018971840618178248, + -0.007180131506174803, + -0.0010631990153342485, + -0.03140748664736748, + -0.028505641967058182, + 0.010669395327568054, + -0.036474280059337616, + 0.01703447848558426, + 0.04667484760284424, + -0.007303370162844658, + 0.01768752932548523, + 0.012412219308316708, + 0.013702306896448135, + 0.07651616632938385, + 0.05469715967774391, + 0.013292597606778145, + -0.006288900971412659, + 0.0215559434145689, + 0.010094149969518185, + -0.024216346442699432, + -0.15225785970687866, + 0.05467289313673973, + 0.019871067255735397, + 0.04662928730249405, + 0.05072600021958351, + -0.011824453249573708, + -0.028083933517336845, + 0.013322187587618828, + -0.044827401638031006, + 0.05955006927251816, + -0.006152187939733267, + 0.013426700606942177, + -0.014220507815480232, + 0.022510837763547897, + 0.019426455721259117, + -0.05546477064490318, + -0.49202534556388855, + 0.026985207572579384, + -0.08852843940258026, + 0.07166163623332977, + 0.05509938299655914, + -0.42284780740737915, + -0.05131356418132782, + 0.0196990966796875, + -0.008681846782565117, + 0.02739996463060379, + 0.0010900507913902402, + 0.04289104416966438, + -0.06694932281970978, + 0.05930810049176216, + -0.02174360118806362, + 0.03464379161596298, + 0.018284866586327553, + 0.018807150423526764, + 0.019874336197972298, + -0.03665176033973694, + -0.2980017066001892, + 0.050937239080667496, + -0.013874954544007778, + -0.0229057464748621, + 0.016420641914010048, + 0.024160616099834442, + -0.10750921070575714, + -0.010134756565093994, + 0.026874780654907227, + 0.007151094265282154, + 0.06304068863391876, + -0.11811652034521103, + -0.12590888142585754, + 0.031846947968006134, + -0.06898463517427444, + 0.03395693376660347, + -0.00010166154243052006, + -0.19019480049610138, + 0.06616076827049255, + -0.035927142947912216, + 0.08526375889778137, + 0.0015017242403700948, + -0.009137739427387714, + 0.04529058188199997, + -0.23621641099452972, + 0.02148340456187725, + -0.02741178683936596, + -0.20779411494731903, + ] + value = numpy.array(list_value, dtype=numpy.float32).reshape((64, 64, 1, 1)) + tensor = numpy_helper.from_array(value, name="onnx::Conv_504") + + initializers.append(tensor) + + list_value = [ + 5.195802688598633, + 0.940099835395813, + -7.016428470611572, + 5.185446739196777, + -4.134859085083008, + 2.0121846199035645, + 5.215719223022461, + 3.371406078338623, + 3.7616095542907715, + -3.6593239307403564, + 15.99945068359375, + 3.306276321411133, + 5.790191173553467, + 6.33050537109375, + 3.4512906074523926, + 2.5531861782073975, + 4.278702259063721, + 4.350361347198486, + 8.025779724121094, + -2.8830037117004395, + 
2.915111541748047, + 3.592482805252075, + 5.810481071472168, + 3.4743332862854004, + 3.5245680809020996, + 1.8243598937988281, + 8.069726943969727, + 1.401036024093628, + 5.110081672668457, + -12.873579978942871, + 10.977816581726074, + 5.909627437591553, + -0.4007779359817505, + -20.147268295288086, + 6.649413585662842, + 3.325921058654785, + 5.84471321105957, + 4.47447395324707, + 3.754193067550659, + -5.167671203613281, + 3.2778055667877197, + -9.067073822021484, + 2.6243438720703125, + 1.7002031803131104, + 5.476454734802246, + 2.510835647583008, + 3.856968402862549, + 2.3172807693481445, + 12.462139129638672, + 7.355924129486084, + 4.140628814697266, + 4.807559967041016, + 5.7524309158325195, + 4.128836154937744, + 11.4532470703125, + -12.482564926147461, + 5.590144157409668, + 0.9172697067260742, + 4.356811046600342, + 0.9934853315353394, + -4.3548994064331055, + 15.853201866149902, + -5.241130828857422, + 5.9644365310668945, + ] + value = numpy.array(list_value, dtype=numpy.float32) + tensor = numpy_helper.from_array(value, name="onnx::Conv_505") + + initializers.append(tensor) + + # inputs + + inputs.append(make_tensor_value_info("input", 1, ["batch_size", 3, 32, 32])) + + # outputs + + outputs.append(make_tensor_value_info("/layer1/layer1.0/relu/Relu_output_0", 1, ["batch_size", 64, 8, 8])) + + # nodes + + node = make_node( + "Conv", + ["input", "onnx::Conv_501", "onnx::Conv_502"], + ["/conv1/Conv_output_0"], + name="/conv1/Conv", + dilations=[1, 1], + group=1, + kernel_shape=[7, 7], + pads=[3, 3, 3, 3], + strides=[2, 2], + domain="", + ) + nodes.append(node) + + node = make_node("Relu", ["/conv1/Conv_output_0"], ["/relu/Relu_output_0"], name="/relu/Relu", domain="") + nodes.append(node) + + node = make_node( + "MaxPool", + ["/relu/Relu_output_0"], + ["/maxpool/MaxPool_output_0"], + name="/maxpool/MaxPool", + ceil_mode=0, + kernel_shape=[3, 3], + pads=[1, 1, 1, 1], + strides=[2, 2], + domain="", + ) + nodes.append(node) + + node = make_node( + "Conv", + ["/maxpool/MaxPool_output_0", "onnx::Conv_504", "onnx::Conv_505"], + ["/layer1/layer1.0/conv1/Conv_output_0"], + name="/layer1/layer1.0/conv1/Conv", + dilations=[1, 1], + group=1, + kernel_shape=[1, 1], + pads=[0, 0, 0, 0], + strides=[1, 1], + domain="", + ) + nodes.append(node) + + node = make_node( + "Relu", + ["/layer1/layer1.0/conv1/Conv_output_0"], + ["/layer1/layer1.0/relu/Relu_output_0"], + name="/layer1/layer1.0/relu/Relu", + domain="", + ) + nodes.append(node) + + # opsets + opset_imports = [make_opsetid(domain, 1 if version is None else version) for domain, version in opsets.items()] + + # graph + graph = make_graph(nodes, "torch_jit", inputs, outputs, initializers) + # '7' + + onnx_model = make_model(graph, opset_imports=opset_imports, functions=functions) + onnx_model.ir_version = 7 + onnx_model.producer_name = "pytorch" + onnx_model.producer_version = "" + onnx_model.domain = "" + onnx_model.model_version = 0 + onnx_model.doc_string = "" + set_model_props(onnx_model, {}) + + return onnx_model diff --git a/onnxruntime/test/python/quantization/test_quantize_static_resnet.py b/onnxruntime/test/python/quantization/test_quantize_static_resnet.py new file mode 100644 index 0000000000000..1efa283af6881 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_quantize_static_resnet.py @@ -0,0 +1,138 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+import os
+import random
+import tempfile
+import unittest
+
+import numpy as np
+import onnx
+from numpy.testing import assert_allclose
+from onnx.numpy_helper import to_array
+from resnet_code import create_model
+
+from onnxruntime import InferenceSession
+from onnxruntime import __version__ as ort_version
+from onnxruntime.quantization import QuantFormat, QuantType, quantize_static
+from onnxruntime.quantization.calibrate import CalibrationDataReader, CalibrationMethod
+
+
+class FakeResnetCalibrationDataReader(CalibrationDataReader):
+    def __init__(self, batch_size: int = 16):
+        super().__init__()
+        self.dataset = [
+            (np.random.rand(1, 3, 32, 32).astype(np.float32), random.randint(0, 9)) for _ in range(batch_size)
+        ]
+        self.iterator = iter(self.dataset)
+
+    def get_next(self) -> dict:
+        try:
+            return {"input": next(self.iterator)[0]}
+        except Exception:
+            return None
+
+
+class TestStaticQuantizationResNet(unittest.TestCase):
+    def test_quantize_static_resnet(self):
+        kwargs = {
+            "activation_type": QuantType.QUInt8,
+            "weight_type": QuantType.QInt8,
+            "calibrate_method": CalibrationMethod.Percentile,
+            "extra_options": {
+                "ActivationSymmetric": False,
+                "EnableSubgraph": False,
+                "ForceQuantizeNoInputCheck": False,
+                "MatMulConstBOnly": False,
+                "WeightSymmetric": True,
+                "extra.Sigmoid.nnapi": False,
+            },
+            "nodes_to_exclude": None,
+            "nodes_to_quantize": None,
+            "op_types_to_quantize": None,
+            "per_channel": True,
+            "quant_format": QuantFormat.QDQ,
+            "reduce_range": False,
+        }
+
+        proto = create_model()
+
+        with tempfile.TemporaryDirectory() as temp:
+            model = os.path.join(temp, "resnet_first_nodes.onnx")
+            with open(model, "wb") as f:
+                f.write(proto.SerializeToString())
+
+            for per_channel in [True, False]:
+                kwargs["per_channel"] = per_channel
+                dataloader = FakeResnetCalibrationDataReader(16)
+                with self.subTest(per_channel=per_channel):
+                    qdq_file = os.path.join(
+                        temp, f"preprocessed-small-qdq-{1 if per_channel else 0}-ort-{ort_version}.onnx"
+                    )
+                    quantize_static(
+                        model_input=model,
+                        model_output=qdq_file,
+                        calibration_data_reader=dataloader,
+                        use_external_data_format=False,
+                        **kwargs,
+                    )
+
+                    # With onnxruntime==1.15.1, the initializer 'onnx::Conv_504_zero_point' is:
+                    # * uint8(128) if per_channel is False
+                    # * int8([0, 0, ....]) if per_channel is True
+                    # With onnxruntime>1.16.0:
+                    # * uint8(128) if per_channel is False
+                    # * uint8([128, 128, ..., 127, ...]) if per_channel is True
+                    # QLinearConv : zero point of per-channel filter must be same.
+                    # That's why the quantization forces a symmetric quantization into INT8.
+                    # zero_point is guaranteed to be zero whatever the channel is.
+
+                    with open(qdq_file, "rb") as f:
+                        onx = onnx.load(f)
+                    for init in onx.graph.initializer:
+                        arr = to_array(init)
+                        if (
+                            arr.dtype == np.int8
+                            and "zero_point" not in init.name
+                            and not init.name.endswith("quantized")
+                        ):
+                            raise AssertionError(
+                                f"Initializer {init.name!r} has type {arr.dtype} and "
+                                f"shape {arr.shape} but should be {np.uint8}."
+                            )
+
+                    sess = InferenceSession(qdq_file, providers=["CPUExecutionProvider"])
+                    shape = (1, 3, 32, 32)
+                    size = np.prod(shape)
+                    dummy = (np.arange(size) / float(size)).astype(np.float32).reshape(shape)
+                    got = sess.run(None, {"input": dummy})
+                    self.assertEqual(got[0].shape, (1, 64, 8, 8))
+                    self.assertEqual(got[0].dtype, np.float32)
+                    if per_channel:
+                        expected = np.array(
+                            [
+                                [[1.0862497091293335, 0.9609132409095764], [1.0862497091293335, 0.9191343784332275]],
+                                [[0.7520190477371216, 1.0026921033859253], [1.0444709062576294, 1.0862497091293335]],
+                                [[0.0, 0.0], [0.0, 0.0]],
+                                [[0.0, 0.0], [0.9609132409095764, 0.7937979102134705]],
+                            ],
+                            dtype=np.float32,
+                        )
+                        assert_allclose(expected, got[0][0, :4, :2, :2], atol=0.2)
+                    else:
+                        expected = np.array(
+                            [
+                                [[1.428238868713379, 1.2602107524871826], [1.3442248106002808, 1.2182037830352783]],
+                                [[0.8821475505828857, 1.0921826362609863], [1.1341897249221802, 1.1761966943740845]],
+                                [[0.0, 0.0], [0.0, 0.0]],
+                                [[0.0, 0.0], [1.2182037830352783, 1.050175666809082]],
+                            ],
+                            dtype=np.float32,
+                        )
+                        assert_allclose(expected, got[0][0, :4, :2, :2], atol=0.2)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)

From fcfc2391b818af5e28ded7f669a2f466f9276cb1 Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Mon, 25 Sep 2023 12:20:56 -0700
Subject: [PATCH 077/121] [JSEP] allow JsCustomAllocator to deal with zero sized input (#17660)

### Description

Allow JsCustomAllocator to deal with zero-sized input. This is a standalone fix for zero-sized tensor handling in JsCustomAllocator. Other JSEP components that do not support zero-sized tensors still need to be fixed.
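A zero-sized input is not exotic: any model with a dynamic batch dimension can receive an empty batch. A minimal sketch of how such a tensor reaches the JSEP allocator from onnxruntime-web; the model path and input name are placeholders, not part of this change:

```ts
import { InferenceSession, Tensor } from 'onnxruntime-web';

// 'model.onnx' and 'input' are hypothetical; any model with a dynamic batch
// dimension can produce this case. (Run inside an async context.)
const session = await InferenceSession.create('model.onnx', {
  executionProviders: ['webgpu'],
});

// A tensor with 0 elements: the EP asks JsCustomAllocator for 0 bytes,
// which previously still went through Module.jsepAlloc.
const empty = new Tensor('float32', new Float32Array(0), [0, 3, 32, 32]);
const results = await session.run({ input: empty });
```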
---
 onnxruntime/core/providers/js/allocator.cc | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/core/providers/js/allocator.cc b/onnxruntime/core/providers/js/allocator.cc
index c1d0aa9abbf6b..574c507222a5c 100644
--- a/onnxruntime/core/providers/js/allocator.cc
+++ b/onnxruntime/core/providers/js/allocator.cc
@@ -10,6 +10,10 @@ namespace onnxruntime {
 namespace js {

 void* JsCustomAllocator::Alloc(size_t size) {
+  if (size == 0) {
+    return nullptr;
+  }
+
   void* p = EM_ASM_PTR({ return Module.jsepAlloc($0); }, size);
   stats_.num_allocs++;
   stats_.bytes_in_use += size;
@@ -17,8 +21,10 @@ void* JsCustomAllocator::Alloc(size_t size) {
 }

 void JsCustomAllocator::Free(void* p) {
-  size_t size = (size_t)(void*)EM_ASM_PTR({ return Module.jsepFree($0); }, p);
-  stats_.bytes_in_use -= size;
+  if (p != nullptr) {
+    size_t size = (size_t)(void*)EM_ASM_PTR({ return Module.jsepFree($0); }, p);
+    stats_.bytes_in_use -= size;
+  }
 }

 void JsCustomAllocator::GetStats(AllocatorStats* stats) {

From f50fa46fe09819734550a7e8725db999210e2b97 Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Mon, 25 Sep 2023 12:21:20 -0700
Subject: [PATCH 078/121] [JSEP] allow DataTransfer to deal with zero sized input (#17661)

### Description

Allow DataTransfer to deal with zero-sized input. This is a standalone fix for zero-sized tensor handling in the JSEP DataTransfer. Other JSEP components that do not support zero-sized tensors still need to be fixed.
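Note the interaction with the previous commit: once `Alloc` returns `nullptr` for zero-sized buffers, the copy path should not hand those pointers to the JS side at all. A simplified TypeScript model of the new control flow; the real change is the C++ diff below, and the names here are illustrative only:

```ts
// Illustrative model of CopyTensor, not the actual JSEP implementation.
// The point is the guard: compute the byte count first and skip every
// backend call when it is zero, since a zero-sized tensor may have no
// valid data pointer.
function copyTensor(src: { sizeInBytes: number }, dst: unknown): void {
  const bytes = src.sizeInBytes;
  if (bytes === 0) {
    return; // zero-sized tensor: nothing to transfer
  }
  // ...dispatch to the GPU-to-GPU, CPU-to-GPU, or GPU-to-CPU copy here...
}
```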
---
 .../core/providers/js/data_transfer.cc        | 34 ++++++++++---------
 1 file changed, 18 insertions(+), 16 deletions(-)

diff --git a/onnxruntime/core/providers/js/data_transfer.cc b/onnxruntime/core/providers/js/data_transfer.cc
index c62362d90867f..ebea041b80128 100644
--- a/onnxruntime/core/providers/js/data_transfer.cc
+++ b/onnxruntime/core/providers/js/data_transfer.cc
@@ -20,23 +20,25 @@ bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_dev
 common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const {
   size_t bytes = src.SizeInBytes();
-  const void* src_data = src.DataRaw();
-  void* dst_data = dst.MutableDataRaw();
-
-  auto& src_device = src.Location().device;
-  auto& dst_device = dst.Location().device;
-
-  if (dst_device.Type() == OrtDevice::GPU) {
-    if (src_device.Type() == OrtDevice::GPU) {
-      // copy from GPU to GPU
-      EM_ASM({ Module.jsepCopy($0, $1, $2, true); }, src_data, dst_data, bytes);
-    } else {
-      // copy from CPU to GPU
-      EM_ASM({ Module.jsepCopy($0, $1, $2); }, src_data, dst_data, bytes);
+  if (bytes > 0) {
+    const void* src_data = src.DataRaw();
+    void* dst_data = dst.MutableDataRaw();
+
+    auto& src_device = src.Location().device;
+    auto& dst_device = dst.Location().device;
+
+    if (dst_device.Type() == OrtDevice::GPU) {
+      if (src_device.Type() == OrtDevice::GPU) {
+        // copy from GPU to GPU
+        EM_ASM({ Module.jsepCopy($0, $1, $2, true); }, src_data, dst_data, bytes);
+      } else {
+        // copy from CPU to GPU
+        EM_ASM({ Module.jsepCopy($0, $1, $2); }, src_data, dst_data, bytes);
+      }
+    } else /* if (src_device.Type() == OrtDevice::GPU) */ {
+      // copy from GPU to CPU
+      jsepDownload(src_data, dst_data, bytes);
     }
-  } else /* if (src_device.Type() == OrtDevice::GPU) */ {
-    // copy from GPU to CPU
-    jsepDownload(src_data, dst_data, bytes);
   }

   return Status::OK();

From b2b140860817d974b5f9dd0b8943c49a725a5edd Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Mon, 25 Sep 2023 12:24:46 -0700
Subject: [PATCH 079/121] [js/web] update browser launch cmd flags (#17658)

### Description

Update the Chromium browser launch command-line flags. Canary already uses DXC, so there is no need to specify '--enable-dawn-features=use_dxc' separately for Canary.
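The net effect is that every Chromium-based browser, stable or Canary, now receives the same Dawn feature list. A sketch of the resulting arguments; the actual assembly happens in test-runner-cli.ts, shown in the diff below:

```ts
// Sketch: the single flag set now used for all Chromium browsers, and how
// the runner forwards it to karma as repeated --chromium-flags arguments.
const chromiumFlags = ['--enable-dawn-features=allow_unsafe_apis,use_dxc'];
const karmaArgs = chromiumFlags.map(flag => `--chromium-flags=${flag}`);
// => ['--chromium-flags=--enable-dawn-features=allow_unsafe_apis,use_dxc']
```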
+      chromiumFlags.push('--enable-dawn-features=allow_unsafe_apis,use_dxc');
     }
     if (webnn) {
       chromiumFlags.push('--enable-experimental-web-platform-features');
     }
-    if (config.options.globalEnvFlags?.webgpu?.profilingMode === 'default') {
-      chromiumFlags.push('--disable-dawn-features=disallow_unsafe_apis');
-    }
     karmaArgs.push(`--bundle-mode=${args.bundleMode}`);
     karmaArgs.push(...chromiumFlags.map(flag => `--chromium-flags=${flag}`));
     if (browser.startsWith('Edge')) {

From a942bbf4897688fd1bd88bb1db73d2f8d648bcbc Mon Sep 17 00:00:00 2001
From: Changming Sun
Date: Mon, 25 Sep 2023 14:12:11 -0700
Subject: [PATCH 080/121] Update nodejs to 18.x (#17657)

1. Upgrade Node.js from 16.x to 18.x for the Windows pipelines.
2. Avoid using the Azure DevOps "NodeTool" task on Linux. The task installs Node.js from the internet or a local disk cache, but all Linux tests have already been moved to Docker, so the installer is no longer needed.
3. Remove some other unused code.
---
 .../azure-pipelines/linux-ci-pipeline.yml     |  4 --
 .../linux-cpu-aten-pipeline.yml               |  4 --
 .../linux-multi-gpu-tensorrt-ci-pipeline.yml  |  2 -
 .../linux-openvino-ci-pipeline.yml            |  2 -
 .../orttraining-linux-ci-pipeline.yml         | 47 +------------------
 .../orttraining-linux-gpu-ci-pipeline.yml     | 11 +----
 .../templates/jobs/win-ci-vs-2022-job.yml     |  2 +-
 .../azure-pipelines/templates/linux-ci.yml    | 40 +---------------
 .../templates/linux-wasm-ci.yml               |  2 +-
 .../templates/mac-cpu-packing-jobs.yml        |  2 +-
 .../templates/react-native-ci.yml             |  2 +-
 .../templates/web-browserstack-ci.yml         |  2 +-
 .../azure-pipelines/templates/web-ci.yml      |  2 +-
 .../azure-pipelines/templates/win-ci.yml      |  2 +-
 .../azure-pipelines/templates/win-wasm-ci.yml |  2 +-
 .../azure-pipelines/templates/win-web-ci.yml  |  2 +-
 .../templates/win-web-multi-browsers.yml      |  2 +-
 .../azure-pipelines/win-ci-fuzz-testing.yml   |  2 +-
 18 files changed, 15 insertions(+), 117 deletions(-)

diff --git a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml
index 33fc9d94bac09..395c190ce9e11 100644
--- a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml
@@ -56,10 +56,6 @@ stages:
         clean: true
         submodules: none

-    - task: NodeTool@0
-      inputs:
-        versionSpec: '16.x'
-
     - task: UsePythonVersion@0
       inputs:
        versionSpec: '3.8'
diff --git a/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml
index 2c5a69e216d14..146186e9eeaf5 100644
--- a/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml
@@ -53,10 +53,6 @@ jobs:
       clean: true
       submodules: recursive

-    - task: NodeTool@0
-      inputs:
-        versionSpec: '16.x'
-
     - template: templates/get-docker-image-steps.yml
       parameters:
         Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_aten_cpu
diff --git a/tools/ci_build/github/azure-pipelines/linux-multi-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-multi-gpu-tensorrt-ci-pipeline.yml
index 0a7dc0e456a95..e4441853240e5 100644
--- a/tools/ci_build/github/azure-pipelines/linux-multi-gpu-tensorrt-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-multi-gpu-tensorrt-ci-pipeline.yml
@@ -36,5 +36,3 @@ jobs:
     JobName: 'Linux_CI_Multi_GPU_TensorRT_Dev'
     # The latest TensorRT container only supports ubuntu20.04 and python 3.8
     RunDockerBuildArgs: '-o ubuntu20.04 -d tensorrt -x "--enable_multi_device_test"'
-
DoNugetPack: 'false' - ArtifactName: 'drop-linux' diff --git a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml index 93ee17b4cc7e6..c92fc93abba37 100644 --- a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml @@ -33,6 +33,4 @@ jobs: AgentPool : 'Linux-CPU-2019' JobName: 'Linux_CI_Dev' RunDockerBuildArgs: '-o ubuntu20.04 -d openvino -v 2023.0.0 -x "--use_openvino CPU_FP32 --build_wheel"' - DoNugetPack: 'false' - ArtifactName: 'drop-linux' TimeoutInMinutes: 120 diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml index 007630edb25be..018672e0b2dea 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml @@ -54,10 +54,6 @@ jobs: clean: true submodules: none - - task: NodeTool@0 - inputs: - versionSpec: '16.x' - - task: UsePythonVersion@0 inputs: versionSpec: '3.8' @@ -88,6 +84,7 @@ jobs: mkdir -p $(Pipeline.Workspace)/ccache docker run --rm \ --volume /data/onnx:/data/onnx:ro \ + --volume /data/models:/build/models:ro \ --volume $(Build.SourcesDirectory):/onnxruntime_src \ --volume $(Build.BinariesDirectory):/build \ --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ @@ -109,51 +106,11 @@ jobs: --build_wheel \ --enable_onnx_tests \ --enable_training \ - --use_cache \ - --update --build; \ + --use_cache; \ ccache -sv; \ ccache -z" workingDirectory: $(Build.SourcesDirectory) - - task: CmdLine@2 - displayName: 'Install python deps' - inputs: - script: | - set -e -x - python3 -m pip uninstall -y ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml ort-nightly-directml onnx -qq - cp $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt $(Build.BinariesDirectory)/requirements.txt - # Test ORT with the latest ONNX release. 
- sed -i "s/git+http:\/\/github\.com\/onnx\/onnx.*/onnx/" $(Build.BinariesDirectory)/requirements.txt - python3 -m pip install -r $(Build.BinariesDirectory)/requirements.txt - mkdir $(Build.BinariesDirectory)/requirements_torch_cpu/ - cp $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt $(Build.BinariesDirectory)/requirements_torch_cpu/requirements.txt - python3 -m pip install -r $(Build.BinariesDirectory)/requirements_torch_cpu/requirements.txt - - - task: CmdLine@2 - displayName: 'Install Release python package' - inputs: - script: | - rm -rf $(Build.BinariesDirectory)/Release/onnxruntime $(Build.BinariesDirectory)/Release/pybind11 - python3 -m pip install $(Build.BinariesDirectory)/Release/dist/*.whl - - - task: PythonScript@0 - displayName: 'Run Release unit tests' - inputs: - scriptPath: $(Build.SourcesDirectory)/tools/ci_build/build.py - workingDirectory: $(Build.BinariesDirectory)/Release - arguments: >- - --build_dir $(Build.BinariesDirectory) - --cmake_generator Ninja - --config Release - --test - --skip_submodule_sync - --build_shared_lib - --parallel - --build_wheel - --enable_onnx_tests - --enable_training - --ctest_path "" - - task: PublishTestResults@2 displayName: 'Publish unit test results' inputs: diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml index adf5695bd76eb..2d2719fef8f3d 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml @@ -32,7 +32,6 @@ jobs: parameters: AgentPool : 'Onnxruntime-Linux-GPU-NC6sv3' JobName: 'Onnxruntime_Linux_GPU_Training' - SubmoduleCheckoutMode: 'recursive' RunDockerBuildArgs: > -o ubuntu20.04 -d gpu -t onnxruntime_orttraining_ortmodule_tests_image @@ -40,24 +39,16 @@ jobs: -e -x " --enable_training - --config $(buildConfig) + --config Release --use_cuda --cuda_version=11.8 --cuda_home=/usr/local/cuda-11.8 --cudnn_home=/usr/local/cuda-11.8 --build_wheel --enable_nvtx_profile --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=70 " - DoNugetPack: 'false' RunInjectedPipeline: 'true' InjectedPipeline: 'orttraining-linux-gpu-test-ci-pipeline.yml' DockerImageTag: 'onnxruntime_orttraining_ortmodule_tests_image' - BuildConfig: $(buildConfig) - ArtifactName: 'drop-linux' TimeoutInMinutes: 140 # Enable unreleased onnx opsets in CI builds # This facilitates testing the implementation for the new opsets AllowReleasedOpsetOnly: '0' - Strategy: - maxParallel: 2 - matrix: - Release: - buildConfig: Release diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml index 46f2ae7b97acc..3b1fde6cb6e4f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml @@ -97,7 +97,7 @@ jobs: - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' force32bit: ${{ parameters.isX86 }} # Our build machine doesn't have java x86 diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml index 05b2dee77e689..7b9788d90b17d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml 
@@ -1,23 +1,14 @@ parameters: AgentPool : 'onnxruntime-Ubuntu2004-AMD-CPU' StageName : 'Linux_CI_Dev' - SubmoduleCheckoutMode: '' RunDockerBuildArgs: '-o ubuntu20.04 -d cpu -x "--build_wheel"' - DoNodejsPack: 'false' - DoNugetPack: 'false' NuPackScript: '' RunInjectedPipeline: 'false' InjectedPipeline: '' DockerImageTag: '' - BuildConfig: '' - ArtifactName: 'drop-linux' TimeoutInMinutes: 120 # Controls whether unreleased onnx opsets are allowed. Default is set to 1 AllowReleasedOpsetOnly: '1' - # to inject strategy, you need to pass in the whole yaml structure - - # https://docs.microsoft.com/en-us/azure/devops/pipelines/yaml-schema?view=azure-devops&tabs=schema#strategies - # see example in orttraining-linux-gpu-ci-pipeline.yml - Strategy: '' jobs: - job: ${{ parameters.StageName }} @@ -28,16 +19,8 @@ jobs: ALLOW_RELEASED_ONNX_OPSET_ONLY: ${{ parameters.AllowReleasedOpsetOnly }} skipComponentGovernanceDetection: true pool: ${{ parameters.AgentPool }} - ${{ if ne(parameters.Strategy, '') }}: - strategy: - ${{ parameters.Strategy }} steps: - checkout: self - ${{ if ne(parameters.SubmoduleCheckoutMode, '') }}: - submodules: ${{ parameters.SubmoduleCheckoutMode }} - - task: NodeTool@0 - inputs: - versionSpec: '16.x' - template: run-docker-build-steps.yml parameters: RunDockerBuildArgs: '${{ parameters.RunDockerBuildArgs }}' @@ -48,31 +31,10 @@ jobs: searchFolder: '$(Build.BinariesDirectory)' testRunTitle: 'Unit Test Run' condition: succeededOrFailed() - - ${{ if eq(parameters['DoNugetPack'], 'true') }}: - - script: | - ${{ parameters.NuPackScript }} - displayName: 'Create Artifacts' - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline Artifact' - inputs: - artifactName: ${{ parameters.ArtifactName }} - targetPath: '$(Build.ArtifactStagingDirectory)' - - ${{ if eq(parameters['DoNodejsPack'], 'true') }}: - - script: | - npm pack - cp $(Build.SourcesDirectory)/js/node/onnxruntime-*.tgz $(Build.ArtifactStagingDirectory) - cp -R $(Build.SourcesDirectory)/js/node/prebuilds $(Build.ArtifactStagingDirectory)/prebuilds - workingDirectory: '$(Build.SourcesDirectory)/js/node' - displayName: 'Create NPM Package' - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline Artifact: ${{ parameters.ArtifactName }}' - inputs: - artifactName: ${{ parameters.ArtifactName }} - targetPath: '$(Build.ArtifactStagingDirectory)' - ${{ if eq(parameters['RunInjectedPipeline'], 'true') }}: - template: | ${{ parameters.InjectedPipeline }} parameters: DockerImageTag: ${{ parameters.DockerImageTag }} - BuildConfig: ${{ parameters.BuildConfig }} + BuildConfig: Release - template: clean-agent-build-directory-step.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index 0e584b550f562..96a0ebd753d8e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -83,7 +83,7 @@ jobs: architecture: $(buildArch) - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: download-deps.yml - task: PythonScript@0 diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml index adfcd98e37230..f5e5435cfaca0 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml @@ -50,7 +50,7 @@ jobs: 
versionSpec: 3.11 - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: set-version-number-variables-step.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml index 8c54e71448992..e63939ae0114c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml @@ -80,7 +80,7 @@ stages: - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - script: brew install coreutils ninja npm yarn diff --git a/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml b/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml index 4494fd36b336e..96e6ff89cd4f1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml @@ -29,7 +29,7 @@ jobs: - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - task: DownloadPipelineArtifact@2 inputs: patterns: 'Release_*/**/*' diff --git a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml index 0b7bd3f645442..3f6c6af753a98 100644 --- a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml @@ -74,7 +74,7 @@ stages: displayName: 'Checkout submodule onnx' - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: linux-web-init-and-check.yml - task: Bash@3 displayName: 'Extract commit SHA and save to __commit.txt' diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index 80d285f3fd3fb..8d28b4ce580b4 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -101,7 +101,7 @@ stages: - task: NodeTool@0 condition: and(succeeded(), eq('${{ parameters.buildNodejs}}', true)) inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: jobs/set-winenv.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml index 9d36e2dbe4944..406683af80222 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml @@ -74,7 +74,7 @@ jobs: architecture: $(buildArch) - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: download-deps.yml - task: PythonScript@0 diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index bad7448715936..d737376eb99b5 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -72,7 +72,7 @@ jobs: displayName: 'Testing: force EOL to lf on windows for /js/**' - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - task: DownloadPipelineArtifact@2 inputs: patterns: '${{ parameters.BuildConfig }}_*/**/*' diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml index 723567389579d..f7876f15029c1 100644 --- 
a/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml @@ -33,7 +33,7 @@ jobs: displayName: 'Checkout submodule onnx' - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - task: DownloadPipelineArtifact@2 inputs: patterns: 'Release_*/**/*' diff --git a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml index f3a5728d6519b..98f1bf7ea1a16 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml @@ -32,7 +32,7 @@ jobs: - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - task: NuGetToolInstaller@0 displayName: Use Nuget 5.7.0 From 95e8dfaea51f6e6dda89b66d95c672927f578f21 Mon Sep 17 00:00:00 2001 From: aimilefth <60664743+aimilefth@users.noreply.github.com> Date: Tue, 26 Sep 2023 01:56:03 +0300 Subject: [PATCH 081/121] Update quant_utils.py/write_calibration_table (#17314) --- onnxruntime/python/tools/quantization/quant_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index 74e54c3f1fa37..739e399042bf3 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -505,7 +505,7 @@ def apply_plot(hist, hist_edges): plt.show() -def write_calibration_table(calibration_cache): +def write_calibration_table(calibration_cache, dir="."): """ Helper function to write calibration table to files. """ @@ -519,7 +519,7 @@ def write_calibration_table(calibration_cache): logging.info(f"calibration cache: {calibration_cache}") - with open("calibration.json", "w") as file: + with open(os.path.join(dir, "calibration.json"), "w") as file: file.write(json.dumps(calibration_cache)) # use `json.loads` to do the reverse # Serialize data using FlatBuffers @@ -551,7 +551,7 @@ def write_calibration_table(calibration_cache): builder.Finish(cal_table) buf = builder.Output() - with open("calibration.flatbuffers", "wb") as file: + with open(os.path.join(dir, "calibration.flatbuffers"), "wb") as file: file.write(buf) # Deserialize data (for validation) @@ -564,7 +564,7 @@ def write_calibration_table(calibration_cache): logging.info(key_value.Value()) # write plain text - with open("calibration.cache", "w") as file: + with open(os.path.join(dir, "calibration.cache"), "w") as file: for key in sorted(calibration_cache.keys()): value = calibration_cache[key] s = key + " " + str(max(abs(value[0]), abs(value[1]))) From ccb73fd827d29f69b1b5dfcbd4c27188b2364f0d Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Mon, 25 Sep 2023 20:03:24 -0700 Subject: [PATCH 082/121] [On-Device Training] Expose Parameters through the Training API (#17364) --- .../Training/CheckpointState.shared.cs | 133 +++++++---- .../Training/NativeTrainingMethods.shared.cs | 34 +++ .../Training/TrainingSession.shared.cs | 55 ++--- .../TrainingTest.cs | 128 ++++++++-- .../python/orttraining_pybind_state.cc | 80 ++++++- .../python/training/api/checkpoint_state.py | 220 ++++++++++++++++-- .../orttraining_test_python_bindings.py | 71 +++++- .../training_api/core/training_capi_tests.cc | 102 ++++++++ .../training_api/checkpoint_property.h | 10 +- .../include/onnxruntime_training_c_api.h | 61 ++++- .../include/onnxruntime_training_cxx_api.h | 36 ++- 
.../include/onnxruntime_training_cxx_inline.h | 12 + .../orttraining/training_api/module.cc | 59 +++++ orttraining/orttraining/training_api/module.h | 5 +- .../onnxruntime_training_c_api.cc | 79 ++++++- .../training_api/ort_training_apis.h | 10 + 16 files changed, 936 insertions(+), 159 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs index 659c6303702ac..6889112acb385 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs @@ -40,20 +40,16 @@ internal enum PropertyType : long String = 2 } - private void AddPropertyImpl(string propertyName, PropertyType propertyType, T propertyValue) + private void AddPropertyImpl(string propertyName, PropertyType propertyType, T propertyValue) where T : unmanaged { var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); - T[] value = new T[1]; - value[0] = propertyValue; - Memory memory = value; - using (var memHandle = memory.Pin()) + T[] value = { propertyValue }; + unsafe { - IntPtr memPtr; - unsafe + fixed (T* memPtr = value) { - memPtr = (IntPtr)memHandle.Pointer; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, (IntPtr)memPtr)); } - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, memPtr)); } } @@ -103,13 +99,13 @@ public static void SaveCheckpoint(CheckpointState state, string checkpointPath, } ///

- /// Adds the given int property to the checkpoint state. + /// Adds or updates the given int property to/in the checkpoint state. /// - /// Runtime properties that are ints such as epoch, training step, and others can be added to the checkpoint - /// state by the user if they desire by calling this function with the appropriate property name and - /// value. The given property name must be unique to be able to successfully add the property. + /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + /// state by the user by calling this function with the corresponding property name and value. + /// The given property name must be unique to be able to successfully add the property. /// - /// Unique name of the property being added. + /// Name of the property being added or updated. /// Property value associated with the given name. public void AddProperty(string propertyName, long propertyValue) { @@ -117,13 +113,13 @@ public void AddProperty(string propertyName, long propertyValue) } /// - /// Adds the given float property to the checkpoint state. + /// Adds or updates the given float property to/in the checkpoint state. /// - /// Runtime properties that are floats such as loss, best score, and others can be added to the checkpoint - /// state by the user if they desire by calling this function with the appropriate property name and - /// value. The given property name must be unique to be able to successfully add the property. + /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + /// state by the user by calling this function with the corresponding property name and value. + /// The given property name must be unique to be able to successfully add the property. /// - /// Unique name of the property being added. + /// Name of the property being added or updated. /// Property value associated with the given name. public void AddProperty(string propertyName, float propertyValue) { @@ -131,28 +127,25 @@ public void AddProperty(string propertyName, float propertyValue) } /// - /// Adds the given string property to the checkpoint state. + /// Adds or updates the given string property to/in the checkpoint state. /// - /// Runtime properties that are strings such as parameter names, custom strings, and others can be added - /// to the checkpoint state by the user if they desire by calling this function with the appropriate property - /// name and value. The given property name must be unique to be able to successfully add the property. + /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + /// state by the user by calling this function with the corresponding property name and value. + /// The given property name must be unique to be able to successfully add the property. /// - /// Unique name of the property being added. + /// Name of the property being added or updated. /// Property value associated with the given name. 
public void AddProperty(string propertyName, string propertyValue) { var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); var propertyValueUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyValue); - IntPtr unmanagedPointer = Marshal.AllocHGlobal(propertyValueUtf8.Length); - try - { - Marshal.Copy(propertyValueUtf8, 0, unmanagedPointer, propertyValueUtf8.Length); - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, unmanagedPointer)); - } - finally + unsafe { - Marshal.FreeHGlobal(unmanagedPointer); + fixed (byte* p = propertyValueUtf8) + { + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, (IntPtr)p)); + } } } @@ -162,34 +155,86 @@ public void AddProperty(string propertyName, string propertyValue) /// Gets the property value from an existing entry in the checkpoint state. The property must /// exist in the checkpoint state to be able to retrieve it successfully. /// - /// Unique name of the property being retrieved. + /// Name of the property being retrieved. /// Property value associated with the given property name. public object GetProperty(string propertyName) { var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); var allocator = OrtAllocator.DefaultInstance; IntPtr propertyValue = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out propertyValue)); - if (propertyType == PropertyType.Int) + try { - var longPropertyValue = Marshal.ReadInt64(propertyValue); - allocator.FreeMemory(propertyValue); - return longPropertyValue; + if (propertyType == PropertyType.Int) + { + Int64 value; + unsafe + { + value = *(Int64*)propertyValue; + } + return value; + } + else if (propertyType == PropertyType.Float) + { + float value; + unsafe + { + value = *(float*)propertyValue; + } + return value; + } + else if (propertyType == PropertyType.String) + { + return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue); + } + + throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString()); } - else if (propertyType == PropertyType.Float) + finally { - float[] value = new float[1]; - Marshal.Copy(propertyValue, value, 0, 1); allocator.FreeMemory(propertyValue); - return value[0]; } - else if (propertyType == PropertyType.String) + } + + /// + /// Updates the data associated with the model parameter in the checkpoint state for the given parameter name. + /// + /// This function updates a model parameter in the checkpoint state with the given parameter data. + /// The training session must be already created with the checkpoint state that contains the parameter + /// being updated. The given parameter is copied over to the registered device for the training session. + /// The parameter must exist in the checkpoint state to be able to update it successfully. + /// + /// Name of the parameter being updated. + /// The parameter data that should replace the existing parameter data. + public void UpdateParameter(string parameterName, OrtValue parameter) + { + if (parameter.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR) { - return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue, allocator); + throw new ArgumentException("Incorrect buffer received. 
Expected a tensor parameter.");
+            }
+
+            var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName);
+            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtUpdateParameter(handle, parameterNameUtf8, parameter.Handle));
+        }
+
+        ///
+        /// Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
+        ///
+        /// This function retrieves the model parameter data from the checkpoint state for the given parameter name.
+        /// The parameter is copied over to the provided OrtValue. The training session must be already created
+        /// with the checkpoint state that contains the parameter being retrieved.
+        /// The parameter must exist in the checkpoint state to be able to retrieve it successfully.
+        ///
+        /// Name of the parameter being retrieved.
+        /// The parameter data that is retrieved from the checkpoint state.
+        public OrtValue GetParameter(string parameterName)
+        {
+            var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName);
+            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(handle, parameterNameUtf8, OrtAllocator.DefaultInstance.Pointer, out IntPtr parameterHandle));
+
+            return new OrtValue(parameterHandle);
+        }

         #region SafeHandle
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs
index 1868ff509bfc3..68a399f8b9671 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs
@@ -42,6 +42,9 @@ public struct OrtTrainingApi
         public IntPtr AddProperty;
         public IntPtr GetProperty;
         public IntPtr LoadCheckpointFromBuffer;
+        public IntPtr GetParameterTypeAndShape;
+        public IntPtr UpdateParameter;
+        public IntPtr GetParameter;
     }

     internal static class NativeTrainingMethods
@@ -97,6 +100,9 @@ static NativeTrainingMethods()
             OrtGetEvalModelInputName = (DOrtGetEvalModelInputName)Marshal.GetDelegateForFunctionPointer(trainingApi_.TrainingSessionGetEvalModelInputName, typeof(DOrtGetEvalModelInputName));
             OrtAddProperty = (DOrtAddProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.AddProperty, typeof(DOrtAddProperty));
             OrtGetProperty = (DOrtGetProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetProperty, typeof(DOrtGetProperty));
+            OrtGetParameterTypeAndShape = (DOrtGetParameterTypeAndShape)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameterTypeAndShape, typeof(DOrtGetParameterTypeAndShape));
+            OrtUpdateParameter = (DOrtUpdateParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.UpdateParameter, typeof(DOrtUpdateParameter));
+            OrtGetParameter = (DOrtGetParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameter, typeof(DOrtGetParameter));
         }
     }

@@ -359,6 +365,34 @@ out UIntPtr inputCount

         public static DOrtGetProperty OrtGetProperty;

+        [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+        public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameterTypeAndShape(
+                                        IntPtr /*(OrtCheckpointState*)*/ checkpointState,
+                                        byte[] /*(const char*)*/ parameterName,
+                                        out IntPtr /*(OrtTensorTypeAndShapeInfo**)*/ parameterTypeAndShape
+        );
+
+        public static DOrtGetParameterTypeAndShape OrtGetParameterTypeAndShape;
+
+        [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+        public delegate IntPtr
/*(OrtStatus*)*/ DOrtUpdateParameter( + IntPtr /*(OrtCheckpointState*)*/ checkpointState, + byte[] /*(const char*)*/ parameterName, + IntPtr /*(OrtValue*)*/ parameter + ); + + public static DOrtUpdateParameter OrtUpdateParameter; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameter( + IntPtr /*(OrtCheckpointState*)*/ checkpointState, + byte[] /*(const char*)*/ parameterName, + IntPtr /*(OrtAllocator*)*/ allocator, + out IntPtr /*(OrtValue**)*/ parameter + ); + + public static DOrtGetParameter OrtGetParameter; + #endregion TrainingSession API public static bool TrainingEnabled() diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs index 33993c2be135b..877677dcad57b 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs @@ -358,13 +358,14 @@ public void EvalStep( IReadOnlyCollection inputValues, IReadOnlyCollection outputValues) { - if (!_evalOutputCount.Equals(outputValues.Count)) + if (_evalOutputCount != (ulong)outputValues.Count()) { - throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of train model ({_trainOutputCount})."); + throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of eval model ({_evalOutputCount})."); } - IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true); + const bool isInput = true; + IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, isInput); - IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false); /* pointers to Pre-allocated OrtValue instances */ + IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, !isInput); /* pointers to Pre-allocated OrtValue instances */ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count, inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray)); } @@ -509,18 +510,17 @@ public void ExportModelForInferencing(string inferenceModelPath, IReadOnlyCollec /// Returns a contiguous buffer that holds a copy of all training state parameters /// /// Whether to only copy trainable parameters or to copy all parameters. 
- public FixedBufferOnnxValue ToBuffer(bool onlyTrainable) + public OrtValue ToBuffer(bool onlyTrainable) { UIntPtr bufferSize = UIntPtr.Zero; NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, onlyTrainable)); float[] bufferMemory = new float[bufferSize.ToUInt64()]; - var memInfo = OrtMemoryInfo.DefaultInstance; // CPU - var shape = new long[] { (long)bufferSize.ToUInt64() }; - var buffer = FixedBufferOnnxValue.CreateFromMemory(memInfo, bufferMemory, Tensors.TensorElementType.Float, shape, (long)bufferSize.ToUInt64() * sizeof(float)); + var shape = new long[] { (long)bufferSize }; + var buffer = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, Tensors.TensorElementType.Float, shape); - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Value.Handle, onlyTrainable)); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Handle, onlyTrainable)); return buffer; } @@ -528,45 +528,30 @@ public FixedBufferOnnxValue ToBuffer(bool onlyTrainable) /// /// Loads the training session model parameters from a contiguous buffer /// - /// Contiguous buffer to load the parameters from. - public void FromBuffer(FixedBufferOnnxValue buffer) + /// Contiguous buffer to load the parameters from. + /// Whether to only load trainable parameters or to load all parameters. + public void FromBuffer(OrtValue ortValue, bool onlyTrainable) { - if (buffer.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR) + if (ortValue.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR) { throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer."); } - IntPtr typeAndShapeInfo = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(buffer.Value.Handle, out typeAndShapeInfo)); - UIntPtr numDimensions = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(typeAndShapeInfo, out numDimensions)); - if (numDimensions.ToUInt64() != 1) + var tensorInfo = ortValue.GetTensorTypeAndShape(); + if (tensorInfo.ElementDataType != Tensors.TensorElementType.Float) { - string errorMessage = "Incorrect buffer shape received. Expected a contiguous tensor buffer. Expected number of dimensions: 1, Actual: " + numDimensions.ToString(); - throw new ArgumentException(errorMessage); - } - - // Here buffer size represents the number of elements in the buffer - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out UIntPtr bufferSize)); - - // OrtGetParametersSize returns the total number of elements in the model's parameters. - UIntPtr numElementsTrainingOnly = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElementsTrainingOnly, true)); - if ((ulong)bufferSize == (ulong)numElementsTrainingOnly) - { - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, true)); - return; + throw new ArgumentException("Incorrect buffer received. 
Expected a tensor buffer of type float."); } UIntPtr numElements = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, false)); - if ((ulong)bufferSize != (ulong)numElements) + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, onlyTrainable)); + if ((ulong)tensorInfo.ElementCount != (ulong)numElements) { - string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString(); + string errorMessage = "Incorrect buffer size received. Expected size to be " + numElements.ToString() + ". Actual size: " + tensorInfo.ElementCount.ToString(); throw new ArgumentException(errorMessage); } - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, false)); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, ortValue.Handle, onlyTrainable)); } /// diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs index ea2b6d7dbc118..68b1d5bcc6147 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs @@ -484,20 +484,23 @@ public void TestEvalModelOutputNames() public void TestToBuffer() { string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); - using (var cleanUp = new DisposableListTest()) + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) { - var state = CheckpointState.LoadCheckpoint(checkpointPath); - cleanUp.Add(state); Assert.NotNull(state); - string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); - string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); - string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); - var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); - cleanUp.Add(trainingSession); - - var buffer = trainingSession.ToBuffer(true); - cleanUp.Add(buffer); + using (var buffer = trainingSession.ToBuffer(true)) + { + Assert.NotNull(buffer); + var typeShape = buffer.GetTensorTypeAndShape(); + Assert.Equal(1, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(397510, fetchedShape[0]); + } } } @@ -505,22 +508,25 @@ public void TestToBuffer() public void TestFromBuffer() { string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); - using (var cleanUp = new DisposableListTest()) + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = 
new TrainingSession(state, trainingPath, evalPath, optimizerPath)) { - var state = CheckpointState.LoadCheckpoint(checkpointPath); - cleanUp.Add(state); Assert.NotNull(state); - string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); - string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); - string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); - - var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); - cleanUp.Add(trainingSession); - var buffer = trainingSession.ToBuffer(true); - cleanUp.Add(buffer); + using (var buffer = trainingSession.ToBuffer(true)) + { + Assert.NotNull(buffer); + var typeShape = buffer.GetTensorTypeAndShape(); + Assert.Equal(1, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(397510, fetchedShape[0]); - trainingSession.FromBuffer(buffer); + trainingSession.FromBuffer(buffer, true); + } } } @@ -530,6 +536,82 @@ public void TestSetSeed() TrainingUtils.SetSeed(8888); } + [Fact(DisplayName = "TestGetParameter")] + public void TestGetParameter() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) + using (var parameter = state.GetParameter("fc1.weight")) + { + Assert.NotNull(state); + Assert.NotNull(parameter); + + var typeShape = parameter.GetTensorTypeAndShape(); + Assert.Equal(2, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(500, fetchedShape[0]); + Assert.Equal(784, fetchedShape[1]); + } + } + + [Fact(DisplayName = "TestUpdateParameter")] + public void TestUpdateParameter() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) + { + Assert.NotNull(state); + + using (var parameter = state.GetParameter("fc1.weight")) + { + Assert.NotNull(parameter); + var typeShape = parameter.GetTensorTypeAndShape(); + + Assert.Equal(2, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(500, fetchedShape[0]); + Assert.Equal(784, fetchedShape[1]); + + float maxVal = 20; + Random randNum = new Random(); + float[] updated_parameter_buffer = Enumerable + .Repeat(0, 500 * 784) + .Select(i => maxVal * (float)randNum.NextDouble()) + .ToArray(); + + using (var updated_parameter = OrtValue.CreateTensorValueFromMemory(updated_parameter_buffer, fetchedShape)) + { + state.UpdateParameter("fc1.weight", updated_parameter); + using (var current_parameter = state.GetParameter("fc1.weight")) + { + var current_parameter_tensor = current_parameter.GetTensorDataAsSpan().ToArray(); + Assert.Equal(updated_parameter_buffer, 
current_parameter_tensor); + Assert.NotEqual(parameter.GetTensorDataAsSpan().ToArray(), current_parameter_tensor); + } + + state.UpdateParameter("fc1.weight", parameter); + + using (var current_parameter = state.GetParameter("fc1.weight")) + { + var current_parameter_tensor = current_parameter.GetTensorDataAsSpan().ToArray(); + Assert.Equal(parameter.GetTensorDataAsSpan().ToArray(), current_parameter_tensor); + Assert.NotEqual(updated_parameter_buffer, current_parameter_tensor); + } + } + } + } + } + internal class FloatComparer : IEqualityComparer { private float atol = 1e-3f; diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 3f3aa396e6ca0..35d9755ba0ba7 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -1065,17 +1065,60 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn checkpoint_state(m, "CheckpointState", R"pbdoc(CheckpointState.)pbdoc"); checkpoint_state .def(py::init()) - .def("add_property", [](onnxruntime::training::api::CheckpointState* state, - const std::string& property_name, - const std::variant& property_value) { - state->property_bag.AddProperty(property_name, property_value); - }) - .def("get_property", [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { - return state->property_bag.GetProperty(property_name); - }) - .def("has_property", [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { - return state->property_bag.HasProperty(property_name); - }); + .def("add_property", + [](onnxruntime::training::api::CheckpointState* state, + const std::string& property_name, + const std::variant& property_value) { + state->property_bag.AddProperty(property_name, property_value); + }) + .def("get_property", + [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { + return state->property_bag.GetProperty(property_name); + }) + .def("has_property", + [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { + return state->property_bag.HasProperty(property_name); + }) + .def("copy_parameter_from", + [](onnxruntime::training::api::CheckpointState* state, + const std::string& parameter_name, OrtValue& value) -> void { + auto it = state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == state->module_checkpoint_state.named_parameters.end()) { + ORT_THROW("Parameter with name ", parameter_name, " does not exist."); + } + ORT_THROW_IF_ERROR(it->second->CopyFrom( + state->module_checkpoint_state.train_session_data_transfer_mgr, value)); + }) + .def("get_parameter", + [](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) { + auto it = state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == state->module_checkpoint_state.named_parameters.end()) { + ORT_THROW("Parameter with name ", parameter_name, " does not exist."); + } + return it->second; + }) + .def("has_parameter", + [](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) { + return state->module_checkpoint_state.named_parameters.count(parameter_name); + }) + .def("parameter_names", + [](onnxruntime::training::api::CheckpointState* state) { + std::vector names; + for ([[maybe_unused]] auto& [name, value] : state->module_checkpoint_state.named_parameters) { + 
names.push_back(name); + } + std::sort(names.begin(), names.end()); + return names; + }) + .def("property_names", + [](onnxruntime::training::api::CheckpointState* state) { + std::vector names; + for ([[maybe_unused]] auto& [name, value] : state->property_bag) { + names.push_back(name); + } + std::sort(names.begin(), names.end()); + return names; + }); py::class_ training_optimizer(m, "Optimizer", R"pbdoc(Training Optimizer.)pbdoc"); @@ -1111,6 +1154,21 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ORT_THROW_IF_ERROR(scheduler->Step()); }); + py::class_> + parameter(m, "Parameter"); + parameter + .def_property_readonly("name", &onnxruntime::training::api::Parameter::Name) + .def_property_readonly("data", &onnxruntime::training::api::Parameter::Data) + .def_property_readonly("grad", &onnxruntime::training::api::Parameter::Gradient) + .def_property_readonly("requires_grad", &onnxruntime::training::api::Parameter::RequiresGrad) + .def("copy_from", + [](onnxruntime::training::api::Parameter* parameter, + onnxruntime::training::api::CheckpointState* state, + OrtValue& value) -> void { + ORT_THROW_IF_ERROR(parameter->CopyFrom(state->module_checkpoint_state.train_session_data_transfer_mgr, value)); + }); + m.def( "save_checkpoint", [](const std::vector& trainable_tensor_protos_pybytes, diff --git a/orttraining/orttraining/python/training/api/checkpoint_state.py b/orttraining/orttraining/python/training/api/checkpoint_state.py index 285264bbed744..ba95cd04fce7e 100644 --- a/orttraining/orttraining/python/training/api/checkpoint_state.py +++ b/orttraining/orttraining/python/training/api/checkpoint_state.py @@ -5,70 +5,171 @@ import os +import numpy as np + from onnxruntime.capi import _pybind_state as C +from onnxruntime.capi.onnxruntime_inference_collection import OrtValue -class CheckpointState: - """Class that holds the state of the training session +class Parameter: + """Class that represents a model parameter - This class holds all the state information of the training session such as the model parameters, - its gradients, the optimizer state and user defined properties. + This class represents a model parameter and provides access to its data, + gradient and other properties. This class is not expected to be instantiated directly. + Instead, it is returned by the `CheckpointState` object. + + Args: + parameter: The C.Parameter object that holds the underlying parameter data. + state: The C.CheckpointState object that holds the underlying session state. + """ + + def __init__(self, parameter: C.Parameter, state: C.CheckpointState): + self._parameter = parameter + self._state = state - User defined properties can be indexed by name from the `CheckpointState` object. + @property + def name(self) -> str: + """The name of the parameter""" + return self._parameter.name - To create the `CheckpointState`, use the `CheckpointState.load_checkpoint` method. 
+ @property + def data(self) -> np.ndarray: + """The data of the parameter""" + return self._parameter.data.numpy() + + @data.setter + def data(self, value: np.ndarray) -> None: + """Sets the data of the parameter""" + self._parameter.copy_from(self._state, OrtValue.ortvalue_from_numpy(value)._ortvalue) + + @property + def grad(self) -> np.ndarray: + """The gradient of the parameter""" + return self._parameter.grad.numpy() if self._parameter.grad.has_value() else None + + @property + def requires_grad(self) -> bool: + """Whether or not the parameter requires its gradient to be computed""" + return self._parameter.requires_grad + + def __repr__(self) -> str: + """Returns a string representation of the parameter""" + return f"Parameter(name={self.name}, requires_grad={self.requires_grad})" + + +class Parameters: + """Class that holds all the model parameters + + This class holds all the model parameters and provides access to them. + This class is not expected to be instantiated directly. Instead, it is returned by the + `CheckpointState`'s parameters attribute. + This class behaves like a dictionary and provides access to the parameters by name. Args: - state: The C.Checkpoint state object that holds the underlying session state. + state: The C.CheckpointState object that holds the underlying session state. """ def __init__(self, state: C.CheckpointState): - if not isinstance(state, C.CheckpointState): - raise TypeError(f"Invalid argument for CheckpointState received {type(state)}") self._state = state - @classmethod - def load_checkpoint(cls, checkpoint_uri: str | os.PathLike) -> CheckpointState: - """Loads the checkpoint state from the checkpoint file + def __getitem__(self, name: str) -> Parameter: + """Gets the parameter associated with the given name + + Searches for the name in the parameters of the checkpoint state. Args: - checkpoint_uri: The path to the checkpoint file. + name: The name of the parameter Returns: - CheckpointState: The checkpoint state object. + The value of the parameter + + Raises: + KeyError: If the parameter is not found """ - return cls(C.load_checkpoint(os.fspath(checkpoint_uri))) - @classmethod - def save_checkpoint( - cls, state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False - ) -> None: - """Saves the checkpoint state to the checkpoint file + if name not in self: + raise KeyError(f"Parameter {name} not found.") + + return Parameter(self._state.get_parameter(name), self._state) + + def __setitem__(self, name: str, value: np.ndarray) -> None: + """Sets the parameter value for the given name + + Searches for the name in the parameters of the checkpoint state. + If the name is found in parameters, the value is updated. Args: - state: The checkpoint state object. - checkpoint_uri: The path to the checkpoint file. - include_optimizer_state: If True, the optimizer state is also saved to the checkpoint file. 
+ name: The name of the parameter + value: The value of the parameter as a numpy array + + Raises: + KeyError: If the parameter is not found """ - C.save_checkpoint(state._state, os.fspath(checkpoint_uri), include_optimizer_state) + if name not in self: + raise KeyError(f"Parameter {name} not found.") + + self._state.copy_parameter_from(name, OrtValue.ortvalue_from_numpy(value)._ortvalue) + + def __contains__(self, name: str) -> bool: + """Checks if the parameter exists in the state + + Args: + name: The name of the parameter + + Returns: + True if the name is a parameter False otherwise + """ + + return self._state.has_parameter(name) + + def __iter__(self): + """Returns an iterator over the properties""" + for parameter_name in self._state.parameter_names(): + yield parameter_name, Parameter(self._state.get_parameter(parameter_name), self._state) + + def __repr__(self) -> str: + """Returns a string representation of the parameters""" + return self._state.parameter_names() + + def __len__(self) -> int: + """Returns the number of parameters""" + return len(self._state.parameter_names()) + + +class Properties: + def __init__(self, state: C.CheckpointState): + self._state = state def __getitem__(self, name: str) -> int | float | str: """Gets the property associated with the given name + Searches for the name in the properties of the checkpoint state. + Args: name: The name of the property Returns: The value of the property + + Raises: + KeyError: If the property is not found """ + + if name not in self: + raise KeyError(f"Property {name} not found.") + return self._state.get_property(name) def __setitem__(self, name: str, value: int | float | str) -> None: """Sets the property value for the given name + Searches for the name in the properties of the checkpoint state. + The value is added or updated in the properties. + Args: name: The name of the property value: The value of the property + Properties only support int, float and str values. """ self._state.add_property(name, value) @@ -79,6 +180,75 @@ def __contains__(self, name: str) -> bool: name: The name of the property Returns: - True if the property exists, False otherwise + True if the name is a property, False otherwise """ + return self._state.has_property(name) + + def __iter__(self): + """Returns an iterator over the properties""" + for property_name in self._state.property_names(): + yield property_name, self._state.get_property(property_name) + + def __repr__(self) -> str: + """Returns a string representation of the properties""" + return self._state.property_names() + + def __len__(self) -> int: + """Returns the number of properties""" + return len(self._state.property_names()) + + +class CheckpointState: + """Class that holds the state of the training session + + This class holds all the state information of the training session such as the model parameters, + its gradients, the optimizer state and user defined properties. + + To create the `CheckpointState`, use the `CheckpointState.load_checkpoint` method. + + Args: + state: The C.Checkpoint state object that holds the underlying session state. 
+ """ + + def __init__(self, state: C.CheckpointState): + if not isinstance(state, C.CheckpointState): + raise TypeError(f"Invalid argument for CheckpointState received {type(state)}") + self._state = state + self._parameters = Parameters(self._state) + self._properties = Properties(self._state) + + @classmethod + def load_checkpoint(cls, checkpoint_uri: str | os.PathLike) -> CheckpointState: + """Loads the checkpoint state from the checkpoint file + + Args: + checkpoint_uri: The path to the checkpoint file. + + Returns: + CheckpointState: The checkpoint state object. + """ + return cls(C.load_checkpoint(os.fspath(checkpoint_uri))) + + @classmethod + def save_checkpoint( + cls, state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False + ) -> None: + """Saves the checkpoint state to the checkpoint file + + Args: + state: The checkpoint state object. + checkpoint_uri: The path to the checkpoint file. + include_optimizer_state: If True, the optimizer state is also saved to the checkpoint file. + """ + C.save_checkpoint(state._state, os.fspath(checkpoint_uri), include_optimizer_state) + + @property + def parameters(self) -> Parameters: + """Returns the model parameters from the checkpoint state""" + return self._parameters + + @property + def properties(self) -> Properties: + """Returns the properties from the checkpoint state""" + return self._properties diff --git a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py index 56338ddbaffef..d5c37b3e36ee7 100644 --- a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py +++ b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py @@ -360,14 +360,18 @@ def test_add_get_property(property_value): if isinstance(property_value, float): property_value = float(np.float32(property_value)) - state["property"] = property_value - assert "property" in state - assert state["property"] == property_value + assert len(state.properties) == 0 + + state.properties["property"] = property_value + assert "property" in state.properties + assert state.properties["property"] == property_value + assert len(state.properties) == 1 CheckpointState.save_checkpoint(state, checkpoint_file_path) new_state = CheckpointState.load_checkpoint(checkpoint_file_path) - assert "property" in new_state - assert new_state["property"] == property_value + assert "property" in new_state.properties + assert new_state.properties["property"] == property_value + assert len(new_state.properties) == 1 def test_get_input_output_names(): @@ -563,3 +567,60 @@ def test_eval_step_with_ort_values(): fetches = model(inputs, labels) assert isinstance(fetches, OrtValue) assert fetches + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_get_and_set_parameter_values(device): + with tempfile.TemporaryDirectory() as temp_dir: + ( + checkpoint_file_path, + training_model_file_path, + eval_model_file_path, + _, + pt_model, + ) = _create_training_artifacts( + temp_dir, requires_grad=["fc2.weight", "fc2.bias"], frozen_params=["fc1.weight", "fc1.bias"] + ) + + state = CheckpointState.load_checkpoint(checkpoint_file_path) + + model = Module(training_model_file_path, state, eval_model_file_path, device=device) + + state_dict = pt_model.state_dict() + assert len(state_dict) == len(state.parameters) + for parameter_name, _ in state.parameters: + assert parameter_name in state_dict + + for name, pt_param in pt_model.named_parameters(): + 
ort_param = state.parameters[name] + assert ort_param.name == name + assert np.allclose(pt_param.detach().cpu().numpy(), ort_param.data) + if name in ["fc1.weight", "fc1.bias"]: + assert ort_param.requires_grad is False + assert ort_param.grad is None + else: + assert ort_param.requires_grad is True + assert np.allclose(ort_param.grad, np.zeros_like(ort_param.data, dtype=np.float32)) + + original_param = state.parameters["fc1.weight"].data + state.parameters["fc1.weight"].data = np.ones_like(state.parameters["fc1.weight"].data, dtype=np.float32) + updated_param = state.parameters["fc1.weight"].data + assert np.allclose(updated_param, np.ones_like(updated_param, dtype=np.float32)) + + model.train() + inputs = torch.randn(64, 784).numpy() + labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() + loss = model(inputs, labels) + assert loss is not None + for name, _ in pt_model.named_parameters(): + ort_param = state.parameters[name] + assert ort_param.name == name + if name in ["fc1.weight", "fc1.bias"]: + assert ort_param.requires_grad is False + assert ort_param.grad is None + else: + assert ort_param.requires_grad is True + assert ort_param.grad.any() + + state.parameters["fc1.weight"] = original_param + assert np.allclose(state.parameters["fc1.weight"].data, original_param) diff --git a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc index d734be8e3474b..e46952d87c2bf 100644 --- a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc @@ -318,4 +318,106 @@ TEST(TrainingCApiTest, LoadModelsFromBufferThrows) { testing::HasSubstr("Training Session Creation failed. 
Train model data cannot be NULL."));
   }
 }
+
+TEST(TrainingCApiTest, GetParameter) {
+  auto model_uri = MODEL_FOLDER "training_model.onnx";
+
+  Ort::Env env;
+  Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
+  Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri);
+
+  Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight");
+  auto tensor_info = parameter.GetTensorTypeAndShapeInfo();
+  auto shape = tensor_info.GetShape();
+  ASSERT_EQ(shape.size(), 2U);
+  ASSERT_EQ(shape.front(), static_cast<int64_t>(500));
+  ASSERT_EQ(shape.back(), static_cast<int64_t>(784));
+}
+
+TEST(TrainingCApiTest, UpdateParameter) {
+  auto model_uri = MODEL_FOLDER "training_model.onnx";
+
+  Ort::Env env;
+  Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
+  Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri);
+
+  Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight");
+  auto tensor_info = parameter.GetTensorTypeAndShapeInfo();
+  auto shape = tensor_info.GetShape();
+  ASSERT_EQ(shape.size(), 2U);
+  ASSERT_EQ(shape.front(), static_cast<int64_t>(500));
+  ASSERT_EQ(shape.back(), static_cast<int64_t>(784));
+
+  OrtValue* updated_param_value = std::make_unique<OrtValue>().release();
+  GenerateRandomInput(std::array<int64_t, 2>{500, 784}, *updated_param_value);
+  Ort::Value updated_parameter{updated_param_value};
+  checkpoint_state.UpdateParameter("fc1.weight", updated_parameter);
+
+  Ort::Value current_parameter = checkpoint_state.GetParameter("fc1.weight");
+  gsl::span<float> actual = gsl::span<float>(current_parameter.GetTensorMutableData<float>(),
+                                             current_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+  gsl::span<float> expected = gsl::span<float>(updated_parameter.GetTensorMutableData<float>(),
+                                               updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+  gsl::span<float> not_expected = gsl::span<float>(parameter.GetTensorMutableData<float>(),
+                                                   parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+  ASSERT_EQ(actual, expected);
+  ASSERT_NE(actual, not_expected);
+
+  checkpoint_state.UpdateParameter("fc1.weight", parameter);
+  current_parameter = checkpoint_state.GetParameter("fc1.weight");
+  actual = gsl::span<float>(current_parameter.GetTensorMutableData<float>(),
+                            current_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+  expected = gsl::span<float>(parameter.GetTensorMutableData<float>(),
+                              parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+  not_expected = gsl::span<float>(updated_parameter.GetTensorMutableData<float>(),
+                                  updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+  ASSERT_EQ(actual, expected);
+  ASSERT_NE(actual, not_expected);
+}
+
+#ifdef USE_CUDA
+TEST(TrainingCApiTest, UpdateParameterDifferentDevices) {
+  auto model_uri = MODEL_FOLDER "training_model.onnx";
+
+  Ort::Env env;
+  Ort::SessionOptions session_options;
+  Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
+  Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
+  Ort::TrainingSession training_session = Ort::TrainingSession(env, session_options, checkpoint_state, model_uri);
+
+  Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight");
+  auto tensor_info = parameter.GetTensorTypeAndShapeInfo();
+  auto shape = tensor_info.GetShape();
+  ASSERT_EQ(shape.size(), 2U);
+  ASSERT_EQ(shape.front(), static_cast<int64_t>(500));
+  ASSERT_EQ(shape.back(), static_cast<int64_t>(784));
+
+  OrtValue* 
updated_param_value = std::make_unique<OrtValue>().release();
+  GenerateRandomInput(std::array<int64_t, 2>{500, 784}, *updated_param_value);
+  Ort::Value updated_parameter{updated_param_value};
+  checkpoint_state.UpdateParameter("fc1.weight", updated_parameter);
+
+  Ort::Value current_parameter = checkpoint_state.GetParameter("fc1.weight");
+  gsl::span<float> actual = gsl::span<float>(current_parameter.GetTensorMutableData<float>(),
+                                             current_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+  gsl::span<float> expected = gsl::span<float>(updated_parameter.GetTensorMutableData<float>(),
+                                               updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+  gsl::span<float> not_expected = gsl::span<float>(parameter.GetTensorMutableData<float>(),
+                                                   parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+  ASSERT_EQ(actual, expected);
+  ASSERT_NE(actual, not_expected);
+
+  checkpoint_state.UpdateParameter("fc1.weight", parameter);
+  current_parameter = checkpoint_state.GetParameter("fc1.weight");
+  actual = gsl::span<float>(current_parameter.GetTensorMutableData<float>(),
+                            current_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+  expected = gsl::span<float>(parameter.GetTensorMutableData<float>(),
+                              parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+  not_expected = gsl::span<float>(updated_parameter.GetTensorMutableData<float>(),
+                                  updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+  ASSERT_EQ(actual, expected);
+  ASSERT_NE(actual, not_expected);
+}
+#endif
+
 }  // namespace onnxruntime::training::test
diff --git a/orttraining/orttraining/training_api/checkpoint_property.h b/orttraining/orttraining/training_api/checkpoint_property.h
index d7b1e295df53e..3c38c99b3152f 100644
--- a/orttraining/orttraining/training_api/checkpoint_property.h
+++ b/orttraining/orttraining/training_api/checkpoint_property.h
@@ -22,10 +22,12 @@ struct PropertyBag {
   PropertyBag() = default;

   void AddProperty(const std::string& name, const PropertyDataType& val) {
-    ORT_ENFORCE(named_properties_.find(name) == named_properties_.end(),
-                "Duplicated property named ", name);
-
-    named_properties_.insert({name, val});
+    auto it = named_properties_.find(name);
+    if (it == named_properties_.end()) {
+      named_properties_.insert({name, val});
+    } else {
+      it->second = val;
+    }
   }

   template <typename T>
diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h
index 0af737074964d..0e8544a7639ba 100644
--- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h
+++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h
@@ -608,14 +608,14 @@ struct OrtTrainingApi {
   /// \name Accessing The Training Session State
   /// @{

-  /** \brief Adds the given property to the checkpoint state.
+  /** \brief Adds or updates the given property in the checkpoint state.
    *
    * Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
-   * state by the user if they desire by calling this function with the appropriate property name and
-   * value. The given property name must be unique to be able to successfully add the property.
+   * state by the user by calling this function with the corresponding property name and value.
+   * If a property with the given name already exists, its value is updated.
    *
    * \param[in] checkpoint_state The checkpoint state which should hold the property.
-   * \param[in] property_name Unique name of the property being added.
+   * \param[in] property_name Name of the property being added or updated.
* \param[in] property_type Type of the property associated with the given name.
    * \param[in] property_value Property value associated with the given name.
    *
@@ -632,7 +632,7 @@ struct OrtTrainingApi {
    * exist in the checkpoint state to be able to retrieve it successfully.
    *
    * \param[in] checkpoint_state The checkpoint state that is currently holding the property.
-   * \param[in] property_name Unique name of the property being retrieved.
+   * \param[in] property_name Name of the property being retrieved.
    * \param[in] allocator Allocator used to allocate the memory for the property_value.
    * \param[out] property_type Type of the property associated with the given name.
    * \param[out] property_value Property value associated with the given name.
@@ -669,6 +669,57 @@ struct OrtTrainingApi {
   ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
                   _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state);

+  /** \brief Retrieves the type and shape information of the parameter associated with the given parameter name.
+   *
+   * This function retrieves the type and shape of the parameter associated with the given parameter name.
+   * The parameter must exist in the checkpoint state to be able to retrieve its type and shape information successfully.
+   *
+   * \param[in] checkpoint_state The checkpoint state.
+   * \param[in] parameter_name Name of the parameter being retrieved.
+   * \param[out] parameter_type_and_shape The type and shape of the parameter being retrieved.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   */
+  ORT_API2_STATUS(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state,
+                  _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape);
+
+  /** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
+   *
+   * This function updates a model parameter in the checkpoint state with the given parameter data.
+   * The training session must already be created with the checkpoint state that contains the parameter
+   * being updated. The given parameter is copied over to the registered device for the training session.
+   * The parameter must exist in the checkpoint state to be able to update it successfully.
+   *
+   * \param[in] checkpoint_state The checkpoint state.
+   * \param[in] parameter_name Name of the parameter being updated.
+   * \param[in] parameter The parameter data that should replace the existing parameter data.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   */
+  ORT_API2_STATUS(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state,
+                  _In_ const char* parameter_name, _In_ OrtValue* parameter);
+
+  /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
+   *
+   * This function retrieves the model parameter data from the checkpoint state for the given parameter name.
+   * The parameter is copied over and returned as an OrtValue. The training session must already be created
+   * with the checkpoint state that contains the parameter being retrieved.
+   * The parameter must exist in the checkpoint state to be able to retrieve it successfully.
+   *
+   * \param[in] checkpoint_state The checkpoint state.
+   * \param[in] parameter_name Name of the parameter being retrieved.
+   * \param[in] allocator Allocator used to allocate the memory for the parameter.
+   * \param[out] parameter The parameter data that is retrieved from the checkpoint state. 
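+   *
+   * A minimal usage sketch (error handling elided; `training_api`, `state`, `allocator`,
+   * and the parameter name are assumptions, not part of this declaration):
+   * \code
+   *   OrtValue* parameter = NULL;
+   *   training_api->GetParameter(state, "fc1.weight", allocator, &parameter);
+   *   // ... read the copied parameter data ...
+   *   ort_api->ReleaseValue(parameter);
+   * \endcode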
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   */
+  ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
+                  _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
+                  _Outptr_ OrtValue** parameter);
+
   /// @}
 };
diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h
index 0edef20ba6da8..218bef524200c 100644
--- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h
+++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h
@@ -112,13 +112,13 @@ class CheckpointState : public detail::Base<OrtCheckpointState> {
                               const std::basic_string<ORTCHAR_T>& path_to_checkpoint,
                               const bool include_optimizer_state = false);

-  /** \brief Adds the given property to the checkpoint state.
+  /** \brief Adds or updates the given property in the checkpoint state.
    *
    * Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
-   * state by the user if they desire by calling this function with the appropriate property name and
-   * value. The given property name must be unique to be able to successfully add the property.
+   * state by the user by calling this function with the corresponding property name and value.
+   * If a property with the given name already exists, its value is updated.
    *
-   * \param[in] property_name Unique name of the property being added.
+   * \param[in] property_name Name of the property being added or updated.
    * \param[in] property_value Property value associated with the given name.
    *
    */
@@ -129,12 +129,38 @@ class CheckpointState : public detail::Base<OrtCheckpointState> {
    *
    * Gets the property value from an existing entry in the checkpoint state. The property must
    * exist in the checkpoint state to be able to retrieve it successfully.
    *
-   * \param[in] property_name Unique name of the property being retrieved.
+   * \param[in] property_name Name of the property being retrieved.
    * \return Property value associated with the given property name.
    *
    */
   Property GetProperty(const std::string& property_name);

+  /** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
+   *
+   * This function updates a model parameter in the checkpoint state with the given parameter data.
+   * The training session must already be created with the checkpoint state that contains the parameter
+   * being updated. The given parameter is copied over to the registered device for the training session.
+   * The parameter must exist in the checkpoint state to be able to update it successfully.
+   *
+   * \param[in] parameter_name Name of the parameter being updated.
+   * \param[in] parameter The parameter data that should replace the existing parameter data.
+   *
+   */
+  void UpdateParameter(const std::string& parameter_name, const Value& parameter);
+
+  /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
+   *
+   * This function retrieves the model parameter data from the checkpoint state for the given parameter name.
+   * The parameter is copied over to the provided OrtValue. The training session must already be created
+   * with the checkpoint state that contains the parameter being retrieved.
+   * The parameter must exist in the checkpoint state to be able to retrieve it successfully.
+   *
+   * \param[in] parameter_name Name of the parameter being retrieved. 
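+   *
+   * A minimal usage sketch (the checkpoint path and parameter name are assumptions):
+   * \code
+   *   Ort::CheckpointState state = Ort::CheckpointState::LoadCheckpoint("checkpoint.ckpt");
+   *   Ort::Value weight = state.GetParameter("fc1.weight");
+   *   // inspect or modify the copied value, then write it back
+   *   state.UpdateParameter("fc1.weight", weight);
+   * \endcode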
+   * \return The parameter data that is retrieved from the checkpoint state.
+   *
+   */
+  Value GetParameter(const std::string& parameter_name);
+
   /// @}
 };
diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h
index c0048458ddf4d..7d1326a10f8f8 100644
--- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h
+++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h
@@ -279,4 +279,16 @@ inline Property CheckpointState::GetProperty(const std::string& property_name) {
   return property;
 }

+inline void CheckpointState::UpdateParameter(const std::string& parameter_name, const Value& parameter) {
+  ThrowOnError(GetTrainingApi().UpdateParameter(p_, parameter_name.c_str(), parameter));
+}
+
+inline Value CheckpointState::GetParameter(const std::string& parameter_name) {
+  AllocatorWithDefaultOptions allocator;
+  OrtValue* parameter;
+  ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), allocator, &parameter));
+
+  return Value{parameter};
+}
+
 }  // namespace Ort
diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc
index d1775e358163c..cf49a01517d6b 100644
--- a/orttraining/orttraining/training_api/module.cc
+++ b/orttraining/orttraining/training_api/module.cc
@@ -119,6 +119,61 @@ Status TransformModelInputsForInference(Graph& inference_graph,
 #endif
 }  // namespace

+Status Parameter::CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const {
+  ORT_ENFORCE(data.IsAllocated(), "Given parameter data is not allocated. Cannot copy the checkpoint parameter to it.");
+  ORT_ENFORCE(data.IsTensor(), "Parameter data should be of tensor type.");
+  ORT_ENFORCE(data.Get<Tensor>().Shape() == data_.Get<Tensor>().Shape(),
+              "Parameter data shape mismatch. Expected: ", data_.Get<Tensor>().Shape().ToString(),
+              ", Got: ", data.Get<Tensor>().Shape().ToString());
+#ifdef ENABLE_STRIDED_TENSORS
+  auto data_strides = data.Get<Tensor>().Strides();
+  auto param_strides = data_.Get<Tensor>().Strides();
+  ORT_ENFORCE(data_strides.size() == param_strides.size(),
+              "Parameter data stride mismatch. Expected strides of size: ", param_strides.size(),
+              ", Got: ", data_strides.size());
+  ORT_ENFORCE(std::equal(data_strides.begin(), data_strides.end(), param_strides.begin()),
+              "Parameter data stride value mismatch.");
+#endif
+  ORT_ENFORCE(data.Get<Tensor>().DataType() == data_.Get<Tensor>().DataType(),
+              "Parameter data type mismatch. Expected: ", data_.Get<Tensor>().DataType(),
+              ", Got: ", data.Get<Tensor>().DataType());
+  ORT_ENFORCE(data_transfer_manager != nullptr,
+              "Data transfer manager must be provided to copy data from the parameter. "
+              "Please create the TrainingSession before trying to retrieve the parameter.");
+
+  ORT_THROW_IF_ERROR(data_transfer_manager->CopyTensor(data_.Get<Tensor>(), *data.GetMutable<Tensor>()));
+
+  return Status::OK();
+}
+
+Status Parameter::CopyFrom(const DataTransferManager* data_transfer_manager, const OrtValue& data) {
+  ORT_ENFORCE(data_.IsAllocated(),
+              "The checkpoint parameter is not allocated. Cannot copy the given parameter data to it.");
+  ORT_ENFORCE(data.IsTensor(), "Parameter data should be of tensor type.");
+  ORT_ENFORCE(data.Get<Tensor>().Shape() == data_.Get<Tensor>().Shape(),
+              "Parameter data shape mismatch. 
Expected: ", data_.Get().Shape().ToString(), + ", Got: ", data.Get().Shape().ToString()); +#ifdef ENABLE_STRIDED_TENSORS + auto data_strides = data.Get().Strides(); + auto param_strides = data_.Get().Strides(); + ORT_ENFORCE(data_strides.size() == param_strides.size(), + "Parameter data stride mismatch. Expected strides of size: ", param_strides.size(), + ", Got: ", data_strides.size()); + ORT_ENFORCE(std::equal(data_strides.begin(), data_strides.end(), param_strides.begin()), + "Parameter data stride value mismatch."); +#endif + ORT_ENFORCE(data.Get().DataType() == data_.Get().DataType(), + "Parameter data type mismatch. Expected: ", data_.Get().DataType(), + ", Got: ", data.Get().DataType()); + ORT_ENFORCE(data_transfer_manager != nullptr, + "Data transfer manager must be provided to copy data to the parameter. " + "Please create the TrainingSession before trying to update the parameter."); + + ORT_THROW_IF_ERROR(data_transfer_manager->CopyTensor(data.Get(), *data_.GetMutable())); + + return Status::OK(); +} + Status Parameter::SetGrad(const std::string& gradient_name, const OrtValue& param_grad) { // assert param is allocated ORT_ENFORCE(data_.IsAllocated(), "Parameter data should be allocated before allocating gradient."); @@ -334,6 +389,10 @@ Module::Module(const ModelIdentifiers& model_identifiers, } } +Module::~Module() { + state_->module_checkpoint_state.train_session_data_transfer_mgr = nullptr; +} + size_t Module::GetTrainingModelOutputCount() const noexcept { return train_output_names_.size(); } diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h index adb633343263e..f323e6be72d49 100644 --- a/orttraining/orttraining/training_api/module.h +++ b/orttraining/orttraining/training_api/module.h @@ -21,6 +21,8 @@ struct Parameter { // Return the mutable data. OrtValue& Data() { return data_; } + Status CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const; + Status CopyFrom(const DataTransferManager* data_transfer_manager, const OrtValue& data); const std::string& Name() const { return name_; } // Returns whether this parameter is trainable or not. @@ -34,7 +36,6 @@ struct Parameter { // Reset and release the gradient buffer of this Parameter greedily. Status ResetGrad(); - protected: Status SetGrad(const std::string& gradient_name, const OrtValue& param_grad); private: @@ -83,6 +84,8 @@ struct Module { const std::vector>& providers, gsl::span op_domains = gsl::span()); + ~Module(); + // Return the trainable/nontrainable parameters std::vector> Parameters() const; diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index 6693bba348648..38a9aad9640ea 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -333,6 +333,10 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::LoadCheckpointFromBuffer, _In_ const void* _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state) { API_IMPL_BEGIN + if (checkpoint_buffer == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid checkpoint buffer. 
Actual: nullptr."); + } + *checkpoint_state = nullptr; auto chkpt_state = std::make_unique(); const auto* checkpoint_bytes = reinterpret_cast(checkpoint_buffer); @@ -559,6 +563,76 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::GetProperty, _In_ const OrtCheckpointState* API_IMPL_END } +ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape) { + API_IMPL_BEGIN + + auto chkpt_state = reinterpret_cast(checkpoint_state); + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { + std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + + return OrtApis::GetTensorTypeAndShape(&it->second->Data(), parameter_type_and_shape); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtTrainingApis::UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _In_ OrtValue* parameter) { + API_IMPL_BEGIN + if (parameter == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter. Actual: nullptr."); + } + + auto chkpt_state = reinterpret_cast(checkpoint_state); + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { + std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + ORT_API_RETURN_IF_STATUS_NOT_OK(it->second->CopyFrom( + chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, *parameter)); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameter, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** parameter) { + API_IMPL_BEGIN + + if (parameter == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter. Actual: nullptr."); + } + + auto chkpt_state = reinterpret_cast(checkpoint_state); + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { + std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + + if (!it->second->Data().IsTensor()) { + return OrtApis::CreateStatus(ORT_FAIL, "Expected a tensor type for the parameter. Found a non-tensor type."); + } + const auto& parameter_tensor = it->second->Data().Get(); + ORT_API_RETURN_IF_ERROR(OrtApis::CreateTensorAsOrtValue( + allocator, parameter_tensor.Shape().GetDims().data(), parameter_tensor.Shape().NumDimensions(), + ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, parameter)); + + auto status = it->second->CopyTo( + chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, **parameter); + if (!status.IsOK()) { + OrtApis::ReleaseValue(*parameter); + return onnxruntime::ToOrtStatus(status); + } + + return nullptr; + API_IMPL_END +} + static constexpr OrtTrainingApi ort_training_api = { // NOTE: The C# bindings depend on the API order within this struct. 
Since Training APIs are not officially // released, it is OK to change the order here, however a corresponding matching change should also be done in the @@ -592,7 +666,10 @@ static constexpr OrtTrainingApi ort_training_api = { &OrtTrainingApis::TrainingSessionGetEvalModelInputName, &OrtTrainingApis::AddProperty, &OrtTrainingApis::GetProperty, - &OrtTrainingApis::LoadCheckpointFromBuffer}; + &OrtTrainingApis::LoadCheckpointFromBuffer, + &OrtTrainingApis::GetParameterTypeAndShape, + &OrtTrainingApis::UpdateParameter, + &OrtTrainingApis::GetParameter}; ORT_API(const OrtTrainingApi*, OrtTrainingApis::GetTrainingApi, uint32_t) { // No constraints on the API version yet. diff --git a/orttraining/orttraining/training_api/ort_training_apis.h b/orttraining/orttraining/training_api/ort_training_apis.h index c87108957c975..2a8c1e30361c6 100644 --- a/orttraining/orttraining/training_api/ort_training_apis.h +++ b/orttraining/orttraining/training_api/ort_training_apis.h @@ -94,4 +94,14 @@ ORT_API_STATUS_IMPL(GetProperty, _In_ const OrtCheckpointState* checkpoint_state ORT_API_STATUS_IMPL(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer, _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state); +ORT_API_STATUS_IMPL(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape); + +ORT_API_STATUS_IMPL(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _In_ OrtValue* parameter); + +ORT_API_STATUS_IMPL(GetParameter, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** parameter); + } // namespace OrtTrainingApis From aed43f429a961e2fee543b801da50899a11b43cb Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 26 Sep 2023 04:49:13 -0400 Subject: [PATCH 083/121] [java] Enable output pinning in OrtSession and OrtTrainingSession (#16835) --- .../main/java/ai/onnxruntime/OrtSession.java | 202 +++++++++++++++--- .../ai/onnxruntime/OrtTrainingSession.java | 172 ++++++++++++--- .../main/native/ai_onnxruntime_OrtSession.c | 38 +++- .../ai_onnxruntime_OrtTrainingSession.c | 109 ++++++---- .../java/ai/onnxruntime/InferenceTest.java | 152 ++++++++++++- .../java/ai/onnxruntime/ModelGenerators.java | 96 +++++++++ .../test/java/ai/onnxruntime/TestHelpers.java | 6 + .../java/ai/onnxruntime/TrainingTest.java | 12 +- .../resources/java-three-output-matmul.onnx | Bin 0 -> 530 bytes 9 files changed, 668 insertions(+), 119 deletions(-) create mode 100644 java/src/test/resources/java-three-output-matmul.onnx diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 435f86daa5fe2..fbea13d155507 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -239,7 +239,7 @@ public Result run(Map inputs, RunOptions runOp */ public Result run(Map inputs, Set requestedOutputs) throws OrtException { - return run(inputs, requestedOutputs, null); + return run(inputs, requestedOutputs, Collections.emptyMap(), null); } /** @@ -259,17 +259,90 @@ public Result run( Set requestedOutputs, RunOptions runOptions) throws OrtException { + return run(inputs, requestedOutputs, Collections.emptyMap(), runOptions); + } + + /** + * Scores an input feed dict, returning the map of pinned outputs. + * + *

The outputs are sorted based on the supplied map traversal order. + * + *
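+   * <p>A hypothetical usage sketch (the output name and shape are illustrative, not part of
+   * this API):
+   *
+   * <pre>{@code
+   * try (OnnxTensor pinned = OnnxTensor.createTensor(env, new float[3][5]);
+   *     Result r = session.run(inputs, Collections.singletonMap("output", pinned))) {
+   *   // pinned is written in place by ORT and is not closed when r is closed
+   * }
+   * }</pre>
+   *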

Note: pinned outputs are not owned by the {@link Result} object, and are not closed
+   * when the result object is closed.
+   *
+   * @param inputs The inputs to score.
+   * @param pinnedOutputs The requested outputs which the user has allocated.
+   * @return The inferred outputs.
+   * @throws OrtException If there was an error in native code, the input or output names are
+   *     invalid, or if there are zero or too many inputs or outputs.
+   */
+  public Result run(
+      Map<String, ? extends OnnxTensorLike> inputs, Map<String, ? extends OnnxValue> pinnedOutputs)
+      throws OrtException {
+    return run(inputs, Collections.emptySet(), pinnedOutputs, null);
+  }
+
+  /**
+   * Scores an input feed dict, returning the map of requested and pinned outputs.
+   *
+   * 

The outputs are sorted based on the supplied set traversal order with pinned outputs first, + * then requested outputs. An {@link IllegalArgumentException} is thrown if the same output name + * appears in both the requested outputs and the pinned outputs. + * + *
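+   * <p>A hypothetical usage sketch mixing the two (the output names and {@code logitsTensor}
+   * are illustrative):
+   *
+   * <pre>{@code
+   * Set<String> requested = Collections.singleton("probabilities");
+   * Map<String, OnnxValue> pinned = Collections.singletonMap("logits", logitsTensor);
+   * try (Result r = session.run(inputs, requested, pinned)) {
+   *   // "probabilities" is allocated and owned by r; logitsTensor is written in place
+   * }
+   * }</pre>
+   *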

Note: pinned outputs are not owned by the {@link Result} object, and are not closed
+   * when the result object is closed.
+   *
+   * @param inputs The inputs to score.
+   * @param requestedOutputs The requested outputs which ORT will allocate.
+   * @param pinnedOutputs The requested outputs which the user has allocated.
+   * @return The inferred outputs.
+   * @throws OrtException If there was an error in native code, the input or output names are
+   *     invalid, or if there are zero or too many inputs or outputs.
+   */
+  public Result run(
+      Map<String, ? extends OnnxTensorLike> inputs,
+      Set<String> requestedOutputs,
+      Map<String, ? extends OnnxValue> pinnedOutputs)
+      throws OrtException {
+    return run(inputs, requestedOutputs, pinnedOutputs, null);
+  }
+
+  /**
+   * Scores an input feed dict, returning the map of requested and pinned outputs.
+   *
+   * 

The outputs are sorted based on the supplied set traversal order with pinned outputs first, + * then requested outputs. An {@link IllegalArgumentException} is thrown if the same output name + * appears in both the requested outputs and the pinned outputs. + * + *

Note: pinned outputs are not owned by the {@link Result} object, and are not closed
+   * when the result object is closed.
+   *
+   * @param inputs The inputs to score.
+   * @param requestedOutputs The requested outputs which ORT will allocate.
+   * @param pinnedOutputs The requested outputs which the user has allocated.
+   * @param runOptions The RunOptions to control this run.
+   * @return The inferred outputs.
+   * @throws OrtException If there was an error in native code, the input or output names are
+   *     invalid, or if there are zero or too many inputs or outputs.
+   */
+  public Result run(
+      Map<String, ? extends OnnxTensorLike> inputs,
+      Set<String> requestedOutputs,
+      Map<String, ? extends OnnxValue> pinnedOutputs,
+      RunOptions runOptions)
+      throws OrtException {
     if (!closed) {
       if ((inputs.isEmpty() && (numInputs != 0)) || (inputs.size() > numInputs)) {
         throw new OrtException(
             "Unexpected number of inputs, expected [1," + numInputs + ") found " + inputs.size());
       }
-      if (requestedOutputs.isEmpty() || (requestedOutputs.size() > numOutputs)) {
+      int totalOutputs = requestedOutputs.size() + pinnedOutputs.size();
+      if ((totalOutputs == 0) || (totalOutputs > numOutputs)) {
         throw new OrtException(
-            "Unexpected number of requestedOutputs, expected [1,"
+            "Unexpected number of requestedOutputs & pinnedOutputs, expected [1,"
                 + numOutputs
                 + ") found "
-                + requestedOutputs.size());
+                + totalOutputs);
       }
       String[] inputNamesArray = new String[inputs.size()];
       long[] inputHandles = new long[inputs.size()];
@@ -284,20 +357,41 @@ public Result run(
             "Unknown input name " + t.getKey() + ", expected one of " + inputNames.toString());
         }
       }
-      String[] outputNamesArray = new String[requestedOutputs.size()];
+      String[] outputNamesArray = new String[requestedOutputs.size() + pinnedOutputs.size()];
+      OnnxValue[] outputValues = new OnnxValue[outputNamesArray.length];
+      long[] outputHandles = new long[outputNamesArray.length];
       i = 0;
+      for (Map.Entry<String, ? extends OnnxValue> e : pinnedOutputs.entrySet()) {
+        if (outputNames.contains(e.getKey())) {
+          outputNamesArray[i] = e.getKey();
+          outputValues[i] = e.getValue();
+          outputHandles[i] = getHandle(e.getValue());
+          i++;
+        } else {
+          throw new OrtException(
+              "Unknown output name " + e.getKey() + ", expected one of " + outputNames.toString());
+        }
+      }
       for (String s : requestedOutputs) {
         if (outputNames.contains(s)) {
-          outputNamesArray[i] = s;
-          i++;
+          if (!pinnedOutputs.containsKey(s)) {
+            outputNamesArray[i] = s;
+            // outputValues and outputHandles can be null/0 for these outputs as ORT will allocate
+            // them.
+            i++;
+          } else {
+            throw new OrtException(
+                "Output '"
+                    + s
+                    + "' was found in both the requested outputs and the pinned outputs");
+          }
         } else {
           throw new OrtException(
               "Unknown output name " + s + ", expected one of " + outputNames.toString());
         }
       }
       long runOptionsHandle = runOptions == null ? 0 : runOptions.getNativeHandle();
-
-      OnnxValue[] outputValues =
+      boolean[] ownedByResult =
           run(
               OnnxRuntime.ortApiHandle,
               nativeHandle,
@@ -307,13 +401,40 @@ public Result run(
               inputNamesArray.length,
               outputNamesArray,
               outputNamesArray.length,
+              outputValues,
+              outputHandles,
               runOptionsHandle);
-      return new Result(outputNamesArray, outputValues);
+      return new Result(outputNamesArray, outputValues, ownedByResult);
     } else {
       throw new IllegalStateException("Trying to score a closed OrtSession.");
     }
   }

+  /**
+   * Pulls out the native handle by casting it to the appropriate type.
+   *
+   * @param v The OnnxValue.
+   * @return The native handle. 
+ */ + static long getHandle(OnnxValue v) { + /* + * Note this method exists as interface methods are all public, but we do not want users to be + * able to access the native pointer via a public API so can't add a method to OnnxValue which + * exposes it. + */ + if (v instanceof OnnxTensorLike) { + return ((OnnxTensorLike) v).nativeHandle; + } else if (v instanceof OnnxSequence) { + return ((OnnxSequence) v).nativeHandle; + } else if (v instanceof OnnxMap) { + return ((OnnxMap) v).nativeHandle; + } else { + throw new IllegalArgumentException( + "Unexpected OnnxValue subclass, should be {OnnxTensorLike, OnnxSequence, OnnxMap}, found " + + v.getClass()); + } + } + /** * Gets the metadata for the currently loaded model. * @@ -409,8 +530,9 @@ private native NodeInfo[] getOutputInfo(long apiHandle, long nativeHandle, long throws OrtException; /** - * The native run call. runOptionsHandle can be zero (i.e. the null pointer), but all other - * handles must be valid pointers. + * The native run call. runOptionsHandle can be zero (i.e. the null pointer), outputValues can + * contain null entries, and outputHandles can contain zero values (i.e. the null pointer), but + * all other handles must be valid pointers. * * @param apiHandle The pointer to the api. * @param nativeHandle The pointer to the session. @@ -419,12 +541,14 @@ private native NodeInfo[] getOutputInfo(long apiHandle, long nativeHandle, long * @param inputs The input tensors. * @param numInputs The number of inputs. * @param outputNamesArray The requested output names. + * @param outputValues The OnnxValue output array. + * @param outputHandles The OrtValue output pointer array. * @param numOutputs The number of requested outputs. * @param runOptionsHandle The (possibly null) pointer to the run options. - * @return The OnnxValues produced by this run. + * @return A boolean array representing if the OnnxValues were allocated by this run call. * @throws OrtException If the native call failed in some way. */ - private native OnnxValue[] run( + private native boolean[] run( long apiHandle, long nativeHandle, long allocatorHandle, @@ -433,6 +557,8 @@ private native OnnxValue[] run( long numInputs, String[] outputNamesArray, long numOutputs, + OnnxValue[] outputValues, + long[] outputHandles, long runOptionsHandle) throws OrtException; @@ -1417,9 +1543,13 @@ private native void addRunConfigEntry( /** * An {@link AutoCloseable} wrapper around a {@link Map} containing {@link OnnxValue}s. * - *

When this is closed it closes all the {@link OnnxValue}s inside it. If you maintain a - * reference to a value after this object has been closed it will throw an {@link + *

When this is closed it closes all the {@link OnnxValue}s owned by the result object. If you + * maintain a reference to a value after this object has been closed it will throw an {@link * IllegalStateException} upon access. + * + *

{@link OnnxValue}s which are supplied as pinned outputs to a {@code run} call are not closed + * by the {@link Result#close()} method. Ownership of each output can be checked with {@link + * Result#isResultOwner(int)}. */ public static class Result implements AutoCloseable, Iterable> { @@ -1429,6 +1559,8 @@ public static class Result implements AutoCloseable, Iterable list; + private final boolean[] ownedByResult; + private boolean closed; /** @@ -1437,21 +1569,23 @@ public static class Result implements AutoCloseable, IterableThrows {@link IllegalStateException} if the container has been closed, and {@link + * ArrayIndexOutOfBoundsException} if the index is invalid. + * + * @param index The index to lookup. + * @return Is that value owned by this result object? + */ + public boolean isResultOwner(int index) { + if (!closed) { + return ownedByResult[index]; + } else { + throw new IllegalStateException("Result is closed"); + } + } + /** * Returns the number of outputs in this Result. * diff --git a/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java b/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java index 8c03c5b80433c..49ddf29c22335 100644 --- a/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java @@ -418,7 +418,7 @@ private static native void setSeed(long apiHandle, long trainingHandle, long see */ public OrtSession.Result trainStep(Map inputs) throws OrtException { - return trainStep(inputs, trainOutputNames, null); + return trainStep(inputs, trainOutputNames, Collections.emptyMap(), null); } /** @@ -432,7 +432,7 @@ public OrtSession.Result trainStep(Map inputs) public OrtSession.Result trainStep( Map inputs, OrtSession.RunOptions runOptions) throws OrtException { - return trainStep(inputs, trainOutputNames, runOptions); + return trainStep(inputs, trainOutputNames, Collections.emptyMap(), runOptions); } /** @@ -446,14 +446,41 @@ public OrtSession.Result trainStep( public OrtSession.Result trainStep( Map inputs, Set requestedOutputs) throws OrtException { - return trainStep(inputs, requestedOutputs, null); + return trainStep(inputs, requestedOutputs, Collections.emptyMap(), null); } /** * Performs a single step of training, accumulating the gradients. * + *

The outputs are sorted based on the supplied map traversal order. + * + *
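+   * <p>A hypothetical usage sketch ({@code lossTensor}, {@code batchInputs}, and the "loss"
+   * output name are assumptions, not part of this API):
+   *
+   * <pre>{@code
+   * Map<String, OnnxValue> pinned = Collections.singletonMap("loss", lossTensor);
+   * try (OrtSession.Result r = trainingSession.trainStep(batchInputs, pinned)) {
+   *   // lossTensor is written in place and is not closed when r is closed
+   * }
+   * }</pre>
+   *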

Note: pinned outputs are not owned by the {@link OrtSession.Result} object, and are
+   * not closed when the result object is closed.
+   *
+   * @param inputs The inputs (must include both the features and the target).
+   * @param pinnedOutputs The requested outputs which the user has allocated.
+   * @return Requested outputs produced by the training step.
+   * @throws OrtException If the native call failed.
+   */
+  public OrtSession.Result trainStep(
+      Map<String, ? extends OnnxTensorLike> inputs, Map<String, ? extends OnnxValue> pinnedOutputs)
+      throws OrtException {
+    return trainStep(inputs, Collections.emptySet(), pinnedOutputs, null);
+  }
+
+  /**
+   * Performs a single step of training, accumulating the gradients.
+   *
+   * 

The outputs are sorted based on the supplied set traversal order with pinned outputs first, + * then requested outputs. An {@link IllegalArgumentException} is thrown if the same output name + * appears in both the requested outputs and the pinned outputs. + * + *

Note: pinned outputs are not owned by the {@link OrtSession.Result} object, and are
+   * not closed when the result object is closed.
+   *
+   * @param inputs The inputs (must include both the features and the target).
+   * @param requestedOutputs The requested outputs which ORT will allocate.
+   * @param pinnedOutputs The requested outputs which the user has allocated.
    * @param runOptions Run options for controlling this specific call.
    * @return Requested outputs produced by the training step.
    * @throws OrtException If the native call failed.
    */
   public OrtSession.Result trainStep(
       Map<String, ? extends OnnxTensorLike> inputs,
       Set<String> requestedOutputs,
+      Map<String, ? extends OnnxValue> pinnedOutputs,
       OrtSession.RunOptions runOptions)
       throws OrtException {
     checkClosed();
@@ -472,12 +500,14 @@ public OrtSession.Result trainStep(
             + ") found "
             + inputs.size());
     }
-    if (requestedOutputs.isEmpty() || (requestedOutputs.size() > trainOutputNames.size())) {
+    int numTrainOutputs = trainOutputNames.size();
+    int totalOutputs = requestedOutputs.size() + pinnedOutputs.size();
+    if ((totalOutputs == 0) || (totalOutputs > numTrainOutputs)) {
       throw new OrtException(
-          "Unexpected number of requestedOutputs, expected [1,"
-              + trainOutputNames.size()
+          "Unexpected number of requestedOutputs & pinnedOutputs, expected [1,"
+              + numTrainOutputs
               + ") found "
-              + requestedOutputs.size());
+              + totalOutputs);
     }
     String[] inputNamesArray = new String[inputs.size()];
     long[] inputHandles = new long[inputs.size()];
@@ -492,12 +522,35 @@ public OrtSession.Result trainStep(
         "Unknown input name " + t.getKey() + ", expected one of " + trainInputNames);
       }
     }
-    String[] outputNamesArray = new String[requestedOutputs.size()];
+    String[] outputNamesArray = new String[requestedOutputs.size() + pinnedOutputs.size()];
+    OnnxValue[] outputValues = new OnnxValue[outputNamesArray.length];
+    long[] outputHandles = new long[outputNamesArray.length];
     i = 0;
+    for (Map.Entry<String, ? extends OnnxValue> e : pinnedOutputs.entrySet()) {
+      if (trainOutputNames.contains(e.getKey())) {
+        outputNamesArray[i] = e.getKey();
+        outputValues[i] = e.getValue();
+        outputHandles[i] = OrtSession.getHandle(e.getValue());
+        i++;
+      } else {
+        throw new OrtException(
+            "Unknown output name "
+                + e.getKey()
+                + ", expected one of "
+                + trainOutputNames.toString());
+      }
+    }
     for (String s : requestedOutputs) {
       if (trainOutputNames.contains(s)) {
-        outputNamesArray[i] = s;
-        i++;
+        if (!pinnedOutputs.containsKey(s)) {
+          outputNamesArray[i] = s;
+          // outputValues and outputHandles can be null/0 for these outputs as ORT will allocate
+          // them.
+          i++;
+        } else {
+          throw new OrtException(
+              "Output '" + s + "' was found in both the requested outputs and the pinned outputs");
+        }
       } else {
         throw new OrtException(
             "Unknown output name " + s + ", expected one of " + trainOutputNames.toString());
       }
     }
     long runOptionsHandle = runOptions == null ? 
0 : runOptions.getNativeHandle(); - OnnxValue[] outputValues = + boolean[] ownedByResult = trainStep( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, @@ -516,8 +569,10 @@ public OrtSession.Result trainStep( inputNamesArray.length, outputNamesArray, outputNamesArray.length, + outputValues, + outputHandles, runOptionsHandle); - return new OrtSession.Result(outputNamesArray, outputValues); + return new OrtSession.Result(outputNamesArray, outputValues, ownedByResult); } /* @@ -540,7 +595,7 @@ public OrtSession.Result trainStep( * run_options, size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, size_t * outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); */ - private native OnnxValue[] trainStep( + private native boolean[] trainStep( long apiHandle, long trainingApiHandle, long nativeHandle, @@ -550,6 +605,8 @@ private native OnnxValue[] trainStep( long numInputs, String[] outputNamesArray, long numOutputs, + OnnxValue[] outputValues, + long[] outputHandles, long runOptionsHandle); /** @@ -561,7 +618,7 @@ private native OnnxValue[] trainStep( */ public OrtSession.Result evalStep(Map inputs) throws OrtException { - return evalStep(inputs, evalOutputNames, null); + return evalStep(inputs, evalOutputNames, Collections.emptyMap(), null); } /** @@ -575,7 +632,7 @@ public OrtSession.Result evalStep(Map inputs) public OrtSession.Result evalStep( Map inputs, OrtSession.RunOptions runOptions) throws OrtException { - return evalStep(inputs, evalOutputNames, runOptions); + return evalStep(inputs, evalOutputNames, Collections.emptyMap(), runOptions); } /** @@ -589,14 +646,41 @@ public OrtSession.Result evalStep( public OrtSession.Result evalStep( Map inputs, Set requestedOutputs) throws OrtException { - return evalStep(inputs, requestedOutputs, null); + return evalStep(inputs, requestedOutputs, Collections.emptyMap(), null); } /** * Performs a single evaluation step using the supplied inputs. * - * @param inputs The model inputs. - * @param requestedOutputs The requested output names. + *

The outputs are sorted based on the supplied map traversal order. + * + *
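+   * <p>A hypothetical sketch reusing one pinned buffer across evaluation steps
+   * ({@code outputBuf}, {@code batches}, and the "output" name are assumptions):
+   *
+   * <pre>{@code
+   * Map<String, OnnxValue> pinned = Collections.singletonMap("output", outputBuf);
+   * for (Map<String, OnnxTensor> batch : batches) {
+   *   trainingSession.evalStep(batch, pinned).close();
+   *   // outputBuf holds this step's output; it is not closed by the Result
+   * }
+   * }</pre>
+   *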

Note: pinned outputs are not owned by the {@link OrtSession.Result} object, and are
+   * not closed when the result object is closed.
+   *
+   * @param inputs The inputs to score.
+   * @param pinnedOutputs The requested outputs which the user has allocated.
+   * @return The requested outputs.
+   * @throws OrtException If the native call failed.
+   */
+  public OrtSession.Result evalStep(
+      Map<String, ? extends OnnxTensorLike> inputs, Map<String, ? extends OnnxValue> pinnedOutputs)
+      throws OrtException {
+    return evalStep(inputs, Collections.emptySet(), pinnedOutputs, null);
+  }
+
+  /**
+   * Performs a single evaluation step using the supplied inputs.
+   *
+   * 

The outputs are sorted based on the supplied set traversal order with pinned outputs first, + * then requested outputs. An {@link IllegalArgumentException} is thrown if the same output name + * appears in both the requested outputs and the pinned outputs. + * + *

Note: pinned outputs are not owned by the {@link OrtSession.Result} object, and are
+   * not closed when the result object is closed.
+   *
+   * @param inputs The inputs to score.
+   * @param requestedOutputs The requested outputs which ORT will allocate.
+   * @param pinnedOutputs The requested outputs which the user has allocated.
    * @param runOptions Run options for controlling this specific call.
    * @return The requested outputs.
    * @throws OrtException If the native call failed.
    */
   public OrtSession.Result evalStep(
       Map<String, ? extends OnnxTensorLike> inputs,
       Set<String> requestedOutputs,
+      Map<String, ? extends OnnxValue> pinnedOutputs,
       OrtSession.RunOptions runOptions)
       throws OrtException {
     checkClosed();
@@ -615,12 +700,14 @@ public OrtSession.Result evalStep(
             + ") found "
             + inputs.size());
     }
-    if (requestedOutputs.isEmpty() || (requestedOutputs.size() > evalOutputNames.size())) {
+    int numEvalOutputs = evalOutputNames.size();
+    int totalOutputs = requestedOutputs.size() + pinnedOutputs.size();
+    if ((totalOutputs == 0) || (totalOutputs > numEvalOutputs)) {
       throw new OrtException(
-          "Unexpected number of requestedOutputs, expected [1,"
-              + evalOutputNames.size()
+          "Unexpected number of requestedOutputs & pinnedOutputs, expected [1,"
+              + numEvalOutputs
               + ") found "
-              + requestedOutputs.size());
+              + totalOutputs);
     }
     String[] inputNamesArray = new String[inputs.size()];
     long[] inputHandles = new long[inputs.size()];
@@ -635,12 +722,35 @@ public OrtSession.Result evalStep(
         "Unknown input name " + t.getKey() + ", expected one of " + evalInputNames.toString());
       }
     }
-    String[] outputNamesArray = new String[requestedOutputs.size()];
+    String[] outputNamesArray = new String[requestedOutputs.size() + pinnedOutputs.size()];
+    OnnxValue[] outputValues = new OnnxValue[outputNamesArray.length];
+    long[] outputHandles = new long[outputNamesArray.length];
     i = 0;
+    for (Map.Entry<String, ? extends OnnxValue> e : pinnedOutputs.entrySet()) {
+      if (evalOutputNames.contains(e.getKey())) {
+        outputNamesArray[i] = e.getKey();
+        outputValues[i] = e.getValue();
+        outputHandles[i] = OrtSession.getHandle(e.getValue());
+        i++;
+      } else {
+        throw new OrtException(
+            "Unknown output name "
+                + e.getKey()
+                + ", expected one of "
+                + evalOutputNames.toString());
+      }
+    }
     for (String s : requestedOutputs) {
       if (evalOutputNames.contains(s)) {
-        outputNamesArray[i] = s;
-        i++;
+        if (!pinnedOutputs.containsKey(s)) {
+          outputNamesArray[i] = s;
+          // outputValues and outputHandles can be null/0 for these outputs as ORT will allocate
+          // them.
+          i++;
+        } else {
+          throw new OrtException(
+              "Output '" + s + "' was found in both the requested outputs and the pinned outputs");
+        }
       } else {
         throw new OrtException(
             "Unknown output name " + s + ", expected one of " + evalOutputNames.toString());
       }
     }
     long runOptionsHandle = runOptions == null ? 
0 : runOptions.getNativeHandle(); - OnnxValue[] outputValues = + boolean[] ownedByResult = evalStep( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, @@ -659,8 +769,10 @@ public OrtSession.Result evalStep( inputNamesArray.length, outputNamesArray, outputNamesArray.length, + outputValues, + outputHandles, runOptionsHandle); - return new OrtSession.Result(outputNamesArray, outputValues); + return new OrtSession.Result(outputNamesArray, outputValues, ownedByResult); } /* @@ -682,7 +794,7 @@ public OrtSession.Result evalStep( * run_options, size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, size_t * outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); */ - private native OnnxValue[] evalStep( + private native boolean[] evalStep( long apiHandle, long trainingApiHandle, long nativeHandle, @@ -692,6 +804,8 @@ private native OnnxValue[] evalStep( long numInputs, String[] outputNamesArray, long numOutputs, + OnnxValue[] outputValues, + long[] outputHandles, long runOptionsHandle) throws OrtException; diff --git a/java/src/main/native/ai_onnxruntime_OrtSession.c b/java/src/main/native/ai_onnxruntime_OrtSession.c index 6f4e34648cf81..f4d5ab080cd31 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession.c @@ -316,14 +316,19 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getOutputInfo(JNIE /* * Class: ai_onnxruntime_OrtSession * Method: run - * Signature: (JJJ[Ljava/lang/String;[JJ[Ljava/lang/String;JJ)[Lai/onnxruntime/OnnxValue; - * private native OnnxValue[] run(long apiHandle, long nativeHandle, long allocatorHandle, String[] inputNamesArray, long[] inputs, long numInputs, String[] outputNamesArray, long numOutputs) + * Signature: (JJJ[Ljava/lang/String;[JJ[Ljava/lang/String;J[Lai/onnxruntime/OnnxValue;[JJ)[Z + * private native boolean[] run(long apiHandle, long nativeHandle, long allocatorHandle, + * String[] inputNamesArray, long[] inputs, long numInputs, + * String[] outputNamesArray, long numOutputs, + * OnnxValue[] outputValues, long[] outputHandles, + * long runOptionsHandle) throws OrtException; */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, +JNIEXPORT jbooleanArray JNICALL Java_ai_onnxruntime_OrtSession_run(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle, jobjectArray inputNamesArr, jlongArray tensorArr, jlong numInputs, jobjectArray outputNamesArr, - jlong numOutputs, jlong runOptionsHandle) { + jlong numOutputs, jobjectArray outputValuesArr, + jlongArray outputHandlesArr, jlong runOptionsHandle) { (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*)apiHandle; @@ -331,7 +336,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run(JNIEnv* jniEnv OrtSession* session = (OrtSession*)sessionHandle; OrtRunOptions* runOptions = (OrtRunOptions*)runOptionsHandle; - jobjectArray outputArray = NULL; + jbooleanArray outputArray = NULL; // Create the buffers for the Java input & output strings, and the input pointers const char** inputNames = allocarray(numInputs, sizeof(char*)); @@ -376,13 +381,19 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run(JNIEnv* jniEnv // Release the java array copy of pointers to the tensors. 
(*jniEnv)->ReleaseLongArrayElements(jniEnv, tensorArr, inputValueLongs, JNI_ABORT); + // Extract a C array of longs which are pointers to the output tensors. + jlong* outputHandleLongs = (*jniEnv)->GetLongArrayElements(jniEnv, outputHandlesArr, NULL); + // Extract the names of the output values. for (int i = 0; i < numOutputs; i++) { javaOutputStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv, outputNamesArr, i); outputNames[i] = (*jniEnv)->GetStringUTFChars(jniEnv, javaOutputStrings[i], NULL); - outputValues[i] = NULL; + outputValues[i] = (OrtValue*)outputHandleLongs[i]; } + // Release the java array copy of pointers to the outputs. + (*jniEnv)->ReleaseLongArrayElements(jniEnv, outputHandlesArr, outputHandleLongs, JNI_ABORT); + // Actually score the inputs. // ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess, _In_ OrtRunOptions* run_options, // _In_ const char* const* input_names, _In_ const OrtValue* const* input, size_t input_len, @@ -394,21 +405,26 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run(JNIEnv* jniEnv goto cleanup_output_values; } - // Construct the output array of ONNXValues - jclass onnxValueClass = (*jniEnv)->FindClass(jniEnv, ORTJNI_OnnxValueClassName); - outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_int64_to_jsize(numOutputs), onnxValueClass, NULL); + // Create the output boolean array denoting if ORT owns the memory for each output. + // Java boolean arrays are initialized to false. + outputArray = (*jniEnv)->NewBooleanArray(jniEnv, safecast_int64_to_jsize(numOutputs)); + jboolean* boolArr = (*jniEnv)->GetBooleanArrayElements(jniEnv, outputArray, NULL); // Convert the output tensors into ONNXValues for (int i = 0; i < numOutputs; i++) { - if (outputValues[i] != NULL) { + if (outputValues[i] != NULL && (*jniEnv)->GetObjectArrayElement(jniEnv, outputValuesArr, i) == NULL) { jobject onnxValue = convertOrtValueToONNXValue(jniEnv, api, allocator, outputValues[i]); if (onnxValue == NULL) { break; // go to cleanup, exception thrown } - (*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, i, onnxValue); + boolArr[i] = 1; + (*jniEnv)->SetObjectArrayElement(jniEnv, outputValuesArr, i, onnxValue); } } + // Write the output array back to Java. + (*jniEnv)->ReleaseBooleanArrayElements(jniEnv, outputArray, boolArr, 0); + // Note these gotos are in a specific order so they mirror the allocation pattern above. // They must be changed if the allocation code is rearranged. cleanup_output_values: diff --git a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c index b3b530a8b15aa..9f7b8d3a3dcfc 100644 --- a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. 
*/ #include @@ -330,12 +330,12 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_lazyResetGrad /* * Class: ai_onnxruntime_OrtTrainingSession * Method: trainStep - * Signature: (JJJJ[Ljava/lang/String;[JJ[Ljava/lang/String;JJ)[Lai/onnxruntime/OnnxValue; + * Signature: (JJJJ[Ljava/lang/String;[JJ[Ljava/lang/String;J[Lai/onnxruntime/OnnxValue;[JJ)[Z */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep +JNIEXPORT jbooleanArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong trainApiHandle, jlong nativeHandle, jlong allocatorHandle, jobjectArray inputNamesArr, jlongArray inputHandles, jlong numInputs, - jobjectArray outputNamesArr, jlong numOutputs, jlong runOptionsHandle) { + jobjectArray outputNamesArr, jlong numOutputs, jobjectArray outputValuesArr, jlongArray outputHandlesArr, jlong runOptionsHandle) { (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*)apiHandle; const OrtTrainingApi* trainApi = (const OrtTrainingApi*)trainApiHandle; @@ -343,31 +343,31 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep OrtTrainingSession* trainSession = (OrtTrainingSession*)nativeHandle; OrtRunOptions* runOptions = (OrtRunOptions*)runOptionsHandle; - jobjectArray outputArray = NULL; + jbooleanArray outputArray = NULL; // Create the buffers for the Java input & output strings, and the input pointers - const char** inputNames = malloc(sizeof(char*) * numInputs); + const char** inputNames = allocarray(numInputs, sizeof(char*)); if (inputNames == NULL) { // Nothing to cleanup, return and throw exception return outputArray; } - const char** outputNames = malloc(sizeof(char*) * numOutputs); + const char** outputNames = allocarray(numOutputs, sizeof(char*)); if (outputNames == NULL) { goto cleanup_input_names; } - jobject* javaInputStrings = malloc(sizeof(jobject) * numInputs); + jobject* javaInputStrings = allocarray(numInputs, sizeof(jobject)); if (javaInputStrings == NULL) { goto cleanup_output_names; } - jobject* javaOutputStrings = malloc(sizeof(jobject) * numOutputs); + jobject* javaOutputStrings = allocarray(numOutputs, sizeof(jobject)); if (javaOutputStrings == NULL) { goto cleanup_java_input_strings; } - const OrtValue** inputValuePtrs = malloc(sizeof(OrtValue*) * numInputs); + const OrtValue** inputValuePtrs = allocarray(numInputs, sizeof(OrtValue*)); if (inputValuePtrs == NULL) { goto cleanup_java_output_strings; } - OrtValue** outputValues = malloc(sizeof(OrtValue*) * numOutputs); + OrtValue** outputValues = allocarray(numOutputs, sizeof(OrtValue*)); if (outputValues == NULL) { goto cleanup_input_values; } @@ -388,13 +388,19 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep // Release the java array copy of pointers to the tensors. (*jniEnv)->ReleaseLongArrayElements(jniEnv, inputHandles, inputValueLongs, JNI_ABORT); + // Extract a C array of longs which are pointers to the output tensors. + jlong* outputHandleLongs = (*jniEnv)->GetLongArrayElements(jniEnv, outputHandlesArr, NULL); + // Extract the names of the output values. 
   for (int i = 0; i < numOutputs; i++) {
     javaOutputStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv, outputNamesArr, i);
     outputNames[i] = (*jniEnv)->GetStringUTFChars(jniEnv, javaOutputStrings[i], NULL);
-    outputValues[i] = NULL;
+    outputValues[i] = (OrtValue*)outputHandleLongs[i];
   }
 
+  // Release the java array copy of pointers to the outputs.
+  (*jniEnv)->ReleaseLongArrayElements(jniEnv, outputHandlesArr, outputHandleLongs, JNI_ABORT);
+
   // Actually score the inputs.
   //ORT_API2_STATUS(TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
   //                size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
@@ -406,24 +412,29 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep
     goto cleanup_output_values;
   }
 
-  // Construct the output array of ONNXValues
-  jclass onnxValueClass = (*jniEnv)->FindClass(jniEnv, "ai/onnxruntime/OnnxValue");
-  outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_int64_to_jsize(numOutputs), onnxValueClass, NULL);
+  // Create the output boolean array denoting if ORT owns the memory for each output.
+  // Java boolean arrays are initialized to false.
+  outputArray = (*jniEnv)->NewBooleanArray(jniEnv, safecast_int64_to_jsize(numOutputs));
+  jboolean* boolArr = (*jniEnv)->GetBooleanArrayElements(jniEnv, outputArray, NULL);
 
   // Convert the output tensors into ONNXValues
   for (int i = 0; i < numOutputs; i++) {
-    if (outputValues[i] != NULL) {
+    if (outputValues[i] != NULL && (*jniEnv)->GetObjectArrayElement(jniEnv, outputValuesArr, i) == NULL) {
       jobject onnxValue = convertOrtValueToONNXValue(jniEnv, api, allocator, outputValues[i]);
       if (onnxValue == NULL) {
         break;  // go to cleanup, exception thrown
       }
-      (*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, i, onnxValue);
+      boolArr[i] = 1;
+      (*jniEnv)->SetObjectArrayElement(jniEnv, outputValuesArr, i, onnxValue);
     }
   }
 
+  // Write the output array back to Java.
+  (*jniEnv)->ReleaseBooleanArrayElements(jniEnv, outputArray, boolArr, 0);
+
   // Note these gotos are in a specific order so they mirror the allocation pattern above.
   // They must be changed if the allocation code is rearranged.
-  cleanup_output_values:
+cleanup_output_values:
   free(outputValues);
 
   // Release the Java output strings
@@ -437,15 +448,15 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep
   }
 
   // Release the buffers
-  cleanup_input_values:
+cleanup_input_values:
   free((void*)inputValuePtrs);
-  cleanup_java_output_strings:
+cleanup_java_output_strings:
   free(javaOutputStrings);
-  cleanup_java_input_strings:
+cleanup_java_input_strings:
   free(javaInputStrings);
-  cleanup_output_names:
+cleanup_output_names:
   free((void*)outputNames);
-  cleanup_input_names:
+cleanup_input_names:
   free((void*)inputNames);
 
   return outputArray;
@@ -454,12 +465,12 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep
 /*
  * Class:     ai_onnxruntime_OrtTrainingSession
  * Method:    evalStep
- * Signature: (JJJJ[Ljava/lang/String;[JJ[Ljava/lang/String;JJ)[Lai/onnxruntime/OnnxValue;
+ * Signature: (JJJJ[Ljava/lang/String;[JJ[Ljava/lang/String;J[Lai/onnxruntime/OnnxValue;[JJ)[Z
  */
-JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep
+JNIEXPORT jbooleanArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep
   (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong trainApiHandle, jlong nativeHandle, jlong allocatorHandle,
    jobjectArray inputNamesArr, jlongArray inputHandles, jlong numInputs,
-   jobjectArray outputNamesArr, jlong numOutputs, jlong runOptionsHandle) {
+   jobjectArray outputNamesArr, jlong numOutputs, jobjectArray outputValuesArr, jlongArray outputHandlesArr, jlong runOptionsHandle) {
   (void)jobj;  // Required JNI parameter not needed by functions which don't need to access their host object.
   const OrtApi* api = (const OrtApi*)apiHandle;
   const OrtTrainingApi* trainApi = (const OrtTrainingApi*)trainApiHandle;
@@ -467,31 +478,31 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep
   OrtTrainingSession* trainSession = (OrtTrainingSession*)nativeHandle;
   OrtRunOptions* runOptions = (OrtRunOptions*)runOptionsHandle;
-  jobjectArray outputArray = NULL;
+  jbooleanArray outputArray = NULL;
 
   // Create the buffers for the Java input & output strings, and the input pointers
-  const char** inputNames = malloc(sizeof(char*) * numInputs);
+  const char** inputNames = allocarray(numInputs, sizeof(char*));
   if (inputNames == NULL) {
     // Nothing to cleanup, return and throw exception
     return outputArray;
   }
-  const char** outputNames = malloc(sizeof(char*) * numOutputs);
+  const char** outputNames = allocarray(numOutputs, sizeof(char*));
   if (outputNames == NULL) {
     goto cleanup_input_names;
   }
-  jobject* javaInputStrings = malloc(sizeof(jobject) * numInputs);
+  jobject* javaInputStrings = allocarray(numInputs, sizeof(jobject));
   if (javaInputStrings == NULL) {
     goto cleanup_output_names;
   }
-  jobject* javaOutputStrings = malloc(sizeof(jobject) * numOutputs);
+  jobject* javaOutputStrings = allocarray(numOutputs, sizeof(jobject));
   if (javaOutputStrings == NULL) {
     goto cleanup_java_input_strings;
   }
-  const OrtValue** inputValuePtrs = malloc(sizeof(OrtValue*) * numInputs);
+  const OrtValue** inputValuePtrs = allocarray(numInputs, sizeof(OrtValue*));
   if (inputValuePtrs == NULL) {
     goto cleanup_java_output_strings;
   }
-  OrtValue** outputValues = malloc(sizeof(OrtValue*) * numOutputs);
+  OrtValue** outputValues = allocarray(numOutputs, sizeof(OrtValue*));
   if (outputValues == NULL) {
     goto cleanup_input_values;
   }
@@ -512,11 +523,14 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep
   // Release the java array copy of pointers to the tensors.
   (*jniEnv)->ReleaseLongArrayElements(jniEnv, inputHandles, inputValueLongs, JNI_ABORT);
 
+  // Extract a C array of longs which are pointers to the output tensors.
+  jlong* outputHandleLongs = (*jniEnv)->GetLongArrayElements(jniEnv, outputHandlesArr, NULL);
+
   // Extract the names of the output values.
   for (int i = 0; i < numOutputs; i++) {
     javaOutputStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv, outputNamesArr, i);
     outputNames[i] = (*jniEnv)->GetStringUTFChars(jniEnv, javaOutputStrings[i], NULL);
-    outputValues[i] = NULL;
+    outputValues[i] = (OrtValue*)outputHandleLongs[i];
   }
 
   // Actually score the inputs.
@@ -530,24 +544,29 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep
     goto cleanup_output_values;
   }
 
-  // Construct the output array of ONNXValues
-  jclass onnxValueClass = (*jniEnv)->FindClass(jniEnv, "ai/onnxruntime/OnnxValue");
-  outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_int64_to_jsize(numOutputs), onnxValueClass, NULL);
+  // Create the output boolean array denoting if ORT owns the memory for each output.
+  // Java boolean arrays are initialized to false.
+  outputArray = (*jniEnv)->NewBooleanArray(jniEnv, safecast_int64_to_jsize(numOutputs));
+  jboolean* boolArr = (*jniEnv)->GetBooleanArrayElements(jniEnv, outputArray, NULL);
 
   // Convert the output tensors into ONNXValues
   for (int i = 0; i < numOutputs; i++) {
-    if (outputValues[i] != NULL) {
+    if (outputValues[i] != NULL && (*jniEnv)->GetObjectArrayElement(jniEnv, outputValuesArr, i) == NULL) {
       jobject onnxValue = convertOrtValueToONNXValue(jniEnv, api, allocator, outputValues[i]);
       if (onnxValue == NULL) {
         break;  // go to cleanup, exception thrown
       }
-      (*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, i, onnxValue);
+      boolArr[i] = 1;
+      (*jniEnv)->SetObjectArrayElement(jniEnv, outputValuesArr, i, onnxValue);
    }
   }
 
+  // Write the output array back to Java.
+  (*jniEnv)->ReleaseBooleanArrayElements(jniEnv, outputArray, boolArr, 0);
+
   // Note these gotos are in a specific order so they mirror the allocation pattern above.
   // They must be changed if the allocation code is rearranged.
-  cleanup_output_values:
+cleanup_output_values:
   free(outputValues);
 
   // Release the Java output strings
@@ -561,15 +580,15 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep
   }
 
   // Release the buffers
-  cleanup_input_values:
+cleanup_input_values:
   free((void*)inputValuePtrs);
-  cleanup_java_output_strings:
+cleanup_java_output_strings:
   free(javaOutputStrings);
-  cleanup_java_input_strings:
+cleanup_java_input_strings:
   free(javaInputStrings);
-  cleanup_output_names:
+cleanup_output_names:
   free((void*)outputNames);
-  cleanup_input_names:
+cleanup_input_names:
   free((void*)inputNames);
 
   return outputArray;
diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java
index 08d2a5698d579..e975117fb75bd 100644
--- a/java/src/test/java/ai/onnxruntime/InferenceTest.java
+++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java
@@ -6,11 +6,14 @@
 import static org.junit.jupiter.api.Assertions.assertArrayEquals;
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNotSame;
 import static org.junit.jupiter.api.Assertions.assertSame;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.junit.jupiter.api.Assertions.fail;
 
+import ai.onnxruntime.OrtException.OrtErrorCode;
 import ai.onnxruntime.OrtSession.Result;
 import ai.onnxruntime.OrtSession.SessionOptions;
 import ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode;
@@ -31,6 +34,8 @@
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -71,7 +76,7 @@ public void environmentTest() {
   @Test
   public void testVersion() {
     String version = env.getVersion();
-    Assertions.assertFalse(version.isEmpty());
+    assertFalse(version.isEmpty());
   }
 
   @Test
@@ -749,6 +754,151 @@ public void testOverridingInitializer() throws OrtException {
     }
   }
 
+  @Test
+  public void testPinnedOutputs() throws OrtException {
+    String modelPath = TestHelpers.getResourcePath("/java-three-output-matmul.onnx").toString();
+    FloatBuffer outputABuf =
+        ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer();
+    FloatBuffer outputBBuf =
+        ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer();
+    FloatBuffer outputCBuf =
+        ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer();
+    FloatBuffer tooSmallBuf =
+        ByteBuffer.allocateDirect(4 * 2).order(ByteOrder.nativeOrder()).asFloatBuffer();
+    FloatBuffer tooBigBuf =
+        ByteBuffer.allocateDirect(4 * 6).order(ByteOrder.nativeOrder()).asFloatBuffer();
+    FloatBuffer wrongShapeBuf =
+        ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer();
+    LongBuffer wrongTypeBuf =
+        ByteBuffer.allocateDirect(8 * 4).order(ByteOrder.nativeOrder()).asLongBuffer();
+
+    try (SessionOptions options = new SessionOptions()) {
+      try (OrtSession session = env.createSession(modelPath, options);
+          OnnxTensor t = OnnxTensor.createTensor(env, new float[][] {{1, 2, 3, 4}});
+          OnnxTensor outputA = OnnxTensor.createTensor(env, outputABuf, new long[] {1, 4});
+          OnnxTensor outputB = OnnxTensor.createTensor(env, outputBBuf, new long[] {1, 4});
+          OnnxTensor outputC = OnnxTensor.createTensor(env, outputCBuf, new long[] {1, 4});
+          OnnxTensor tooSmall = OnnxTensor.createTensor(env, tooSmallBuf, new long[] {1, 2});
+          OnnxTensor tooBig = OnnxTensor.createTensor(env, tooBigBuf, new long[] {1, 6});
+          OnnxTensor wrongShape = OnnxTensor.createTensor(env, wrongShapeBuf, new long[] {2, 2});
+          OnnxTensor wrongType = OnnxTensor.createTensor(env, wrongTypeBuf, new long[] {1, 4})) {
+        Map<String, OnnxTensor> inputMap = Collections.singletonMap("input", t);
+        Set<String> requestedOutputs = new LinkedHashSet<>();
+        Map<String, OnnxValue> pinnedOutputs = new LinkedHashMap<>();
+
+        // Test that all outputs can be pinned
+        pinnedOutputs.put("output-0", outputA);
+        pinnedOutputs.put("output-1", outputB);
+        pinnedOutputs.put("output-2", outputC);
+        try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) {
+          assertEquals(3, r.size());
+          assertSame(outputA, r.get(0));
+          assertSame(outputB, r.get(1));
+          assertSame(outputC, r.get(2));
+          assertFalse(r.isResultOwner(0));
+          assertFalse(r.isResultOwner(1));
+          assertFalse(r.isResultOwner(2));
+          // More tests
+        }
+        TestHelpers.zeroBuffer(outputABuf);
+        TestHelpers.zeroBuffer(outputBBuf);
+        TestHelpers.zeroBuffer(outputCBuf);
+        requestedOutputs.clear();
+        pinnedOutputs.clear();
+
+        // Test a single pinned output
+        pinnedOutputs.put("output-1", outputB);
+        try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) {
+          assertEquals(1, r.size());
+          assertSame(outputB, r.get(0));
+          assertSame(outputB, r.get("output-1").get());
+          assertFalse(r.isResultOwner(0));
+          // More tests
+        }
+        TestHelpers.zeroBuffer(outputABuf);
+        TestHelpers.zeroBuffer(outputBBuf);
+        TestHelpers.zeroBuffer(outputCBuf);
+        requestedOutputs.clear();
+        pinnedOutputs.clear();
+
+        // Test a mixture of pinned and generated outputs
+        requestedOutputs.add("output-0");
+        requestedOutputs.add("output-2");
+        pinnedOutputs.put("output-1", outputB);
+        try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) {
+          assertEquals(3, r.size());
+          // pinned outputs are first
+          assertSame(outputB, r.get(0));
+          assertSame(outputB, r.get("output-1").get());
+          // requested outputs are different
+          assertNotSame(outputA, r.get("output-0").get());
+          assertNotSame(outputC, r.get("output-2").get());
+          // check ownership.
+          assertFalse(r.isResultOwner(0));
+          assertTrue(r.isResultOwner(1));
+          assertTrue(r.isResultOwner(2));
+          // More tests
+        }
+        TestHelpers.zeroBuffer(outputABuf);
+        TestHelpers.zeroBuffer(outputBBuf);
+        TestHelpers.zeroBuffer(outputCBuf);
+        requestedOutputs.clear();
+        pinnedOutputs.clear();
+
+        // Test that overlapping names causes an error
+        requestedOutputs.add("output-1");
+        pinnedOutputs.put("output-1", outputB);
+        try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) {
+          fail("Should have thrown OrtException");
+        } catch (OrtException e) {
+          assertEquals(OrtErrorCode.ORT_JAVA_UNKNOWN, e.getCode());
+        }
+        requestedOutputs.clear();
+        pinnedOutputs.clear();
+
+        // Test that a tensor of the wrong type causes an error
+        pinnedOutputs.put("output-0", wrongType);
+        try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) {
+          fail("Should have thrown OrtException");
+        } catch (OrtException e) {
+          assertEquals(OrtErrorCode.ORT_INVALID_ARGUMENT, e.getCode());
+        }
+        requestedOutputs.clear();
+        pinnedOutputs.clear();
+
+        // Test that a tensor of the wrong shape (but right capacity) causes an error.
+        pinnedOutputs.put("output-1", wrongShape);
+        try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) {
+          fail("Should have thrown OrtException");
+        } catch (OrtException e) {
+          assertEquals(OrtErrorCode.ORT_INVALID_ARGUMENT, e.getCode());
+        }
+        requestedOutputs.clear();
+        pinnedOutputs.clear();
+
+        // Test that a tensor which is too small causes an error
+        pinnedOutputs.put("output-1", tooSmall);
+        try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) {
+          fail("Should have thrown OrtException");
+        } catch (OrtException e) {
+          assertEquals(OrtErrorCode.ORT_INVALID_ARGUMENT, e.getCode());
+        }
+        requestedOutputs.clear();
+        pinnedOutputs.clear();
+
+        // Test that a tensor which is too large causes an error
+        pinnedOutputs.put("output-1", tooBig);
+        try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) {
+          fail("Should have thrown OrtException");
+        } catch (OrtException e) {
+          assertEquals(OrtErrorCode.ORT_INVALID_ARGUMENT, e.getCode());
+        }
+        requestedOutputs.clear();
+        pinnedOutputs.clear();
+      }
+    }
+  }
+
   private static File getTestModelsDir() throws IOException {
     // get build directory, append downloaded models location
     String cwd = System.getProperty("user.dir");
diff --git a/java/src/test/java/ai/onnxruntime/ModelGenerators.java b/java/src/test/java/ai/onnxruntime/ModelGenerators.java
index 90fda4c5cf610..7bf7cef43208a 100644
--- a/java/src/test/java/ai/onnxruntime/ModelGenerators.java
+++ b/java/src/test/java/ai/onnxruntime/ModelGenerators.java
@@ -182,6 +182,102 @@ public void generateMatMul() throws IOException {
     }
   }
 
+  public void generateThreeOutputMatmul() throws IOException {
+    OnnxMl.GraphProto.Builder graph = OnnxMl.GraphProto.newBuilder();
+    graph.setName("ort-test-three-matmul");
+
+    // Add placeholders
+    OnnxMl.ValueInfoProto.Builder input = OnnxMl.ValueInfoProto.newBuilder();
+    input.setName("input");
+    OnnxMl.TypeProto inputType =
+        buildTensorTypeNode(
+            new long[] {-1, 4},
+            new String[] {"batch_size", null},
+            OnnxMl.TensorProto.DataType.FLOAT);
+    input.setType(inputType);
+    graph.addInput(input);
+    OnnxMl.ValueInfoProto.Builder outputA = OnnxMl.ValueInfoProto.newBuilder();
+    outputA.setName("output-0");
+    OnnxMl.TypeProto outputType =
+        buildTensorTypeNode(
+            new long[] {-1, 4},
+            new String[] {"batch_size", null},
+            OnnxMl.TensorProto.DataType.FLOAT);
+    outputA.setType(outputType);
+    graph.addOutput(outputA);
+    OnnxMl.ValueInfoProto.Builder outputB = OnnxMl.ValueInfoProto.newBuilder();
+    outputB.setName("output-1");
+    outputB.setType(outputType);
+    graph.addOutput(outputB);
+    OnnxMl.ValueInfoProto.Builder outputC = OnnxMl.ValueInfoProto.newBuilder();
+    outputC.setName("output-2");
+    outputC.setType(outputType);
+    graph.addOutput(outputC);
+
+    // Add initializers
+    OnnxMl.TensorProto.Builder tensor = OnnxMl.TensorProto.newBuilder();
+    tensor.addDims(4);
+    tensor.addDims(4);
+    Float[] floats =
+        new Float[] {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f, 12f, 13f, 14f, 15f, 16f};
+    tensor.addAllFloatData(Arrays.asList(floats));
+    tensor.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber());
+    tensor.setName("tensor");
+    graph.addInitializer(tensor);
+    OnnxMl.TensorProto.Builder addInit = OnnxMl.TensorProto.newBuilder();
+    addInit.addDims(4);
+    Float[] addFloats = new Float[] {1f, 2f, 3f, 4f};
+    addInit.addAllFloatData(Arrays.asList(addFloats));
+    addInit.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber());
+    addInit.setName("add-init");
+    graph.addInitializer(addInit);
+
+    // Add operations
+    OnnxMl.NodeProto.Builder matmul = OnnxMl.NodeProto.newBuilder();
+    matmul.setName("matmul-0");
+    matmul.setOpType("MatMul");
+    matmul.addInput("input");
+    matmul.addInput("tensor");
+    matmul.addOutput("matmul-output");
+    graph.addNode(matmul);
+
+    OnnxMl.NodeProto.Builder id = OnnxMl.NodeProto.newBuilder();
+    id.setName("id-1");
+    id.setOpType("Identity");
+    id.addInput("matmul-output");
+    id.addOutput("output-0");
+    graph.addNode(id);
+
+    OnnxMl.NodeProto.Builder add = OnnxMl.NodeProto.newBuilder();
+    add.setName("add-2");
+    add.setOpType("Add");
+    add.addInput("matmul-output");
+    add.addInput("add-init");
+    add.addOutput("output-1");
+    graph.addNode(add);
+
+    OnnxMl.NodeProto.Builder log = OnnxMl.NodeProto.newBuilder();
+    log.setName("log-3");
+    log.setOpType("Log");
+    log.addInput("matmul-output");
+    log.addOutput("output-2");
+    graph.addNode(log);
+
+    // Build model
+    OnnxMl.ModelProto.Builder model = OnnxMl.ModelProto.newBuilder();
+    model.setGraph(graph);
+    model.setDocString("ORT three output matmul test");
+    model.setModelVersion(0);
+    model.setIrVersion(8);
+    model.setDomain("ai.onnxruntime.test");
+    model.addOpsetImport(OnnxMl.OperatorSetIdProto.newBuilder().setVersion(18).build());
+    try (OutputStream os =
+        Files.newOutputStream(
+            Paths.get("src", "test", "resources", "java-three-output-matmul.onnx"))) {
+      model.build().writeTo(os);
+    }
+  }
+
   private static void genCast(
       String name,
       OnnxMl.TensorProto.DataType inputDataType,
diff --git a/java/src/test/java/ai/onnxruntime/TestHelpers.java b/java/src/test/java/ai/onnxruntime/TestHelpers.java
index 7d41918b1c6c7..55d8169434d48 100644
--- a/java/src/test/java/ai/onnxruntime/TestHelpers.java
+++ b/java/src/test/java/ai/onnxruntime/TestHelpers.java
@@ -262,6 +262,12 @@ public static Path getResourcePath(String path) {
     return new File(TestHelpers.class.getResource(path).getFile()).toPath();
   }
 
+  public static void zeroBuffer(FloatBuffer buf) {
+    for (int i = 0; i < buf.capacity(); i++) {
+      buf.put(i, 0.0f);
+    }
+  }
+
   public static float[] loadTensorFromFile(Path filename) {
     return loadTensorFromFile(filename, true);
   }
diff --git a/java/src/test/java/ai/onnxruntime/TrainingTest.java b/java/src/test/java/ai/onnxruntime/TrainingTest.java
index a02f5a88b2ac5..eaa7da1fc6a16 100644
--- a/java/src/test/java/ai/onnxruntime/TrainingTest.java
+++ b/java/src/test/java/ai/onnxruntime/TrainingTest.java
@@ -16,7 +16,6 @@
 import java.util.Map;
 import java.util.Set;
 import org.junit.jupiter.api.Assertions;
-import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.condition.EnabledIfSystemProperty;
 
@@ -69,8 +68,6 @@ public void testCreateTrainingSessionWithEval() throws OrtException {
     }
   }
 
-  // this test is not enabled as ORT Java doesn't support supplying an output buffer
-  @Disabled
   @Test
   public void testTrainingSessionTrainStep() throws OrtException {
     String checkpointPath = TestHelpers.getResourcePath("/checkpoint.ckpt").toString();
@@ -99,14 +96,11 @@ public void testTrainingSessionTrainStep() throws OrtException {
           ByteBuffer.allocateDirect(4 * expectedOutput.length)
              .order(ByteOrder.nativeOrder())
              .asFloatBuffer();
-      OnnxTensor outputTensor =
-          OnnxTensor.createTensor(env, output, new long[expectedOutput.length]);
+      OnnxTensor outputTensor = OnnxTensor.createTensor(env, output, new long[0]);
       outputMap.put("onnx::loss::21273", outputTensor);
-      /* Disabled as we haven't implemented this yet
-      try (trainingSession.trainStep(pinnedInputs, outputMap)) {
-        Assertions.assertArrayEquals(expectedOutput, (float[]) outputTensor.getValue(), 1e-3f);
+      try (OrtSession.Result r = trainingSession.trainStep(pinnedInputs, outputMap)) {
+        Assertions.assertEquals(expectedOutput[0], (float) outputTensor.getValue(), 1e-3f);
       }
-      */
     } finally {
       OnnxValue.close(outputMap);
       OnnxValue.close(pinnedInputs);
diff --git a/java/src/test/resources/java-three-output-matmul.onnx b/java/src/test/resources/java-three-output-matmul.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..fed0bbca460cfed900548ed1bfcc5018d8a927b3
GIT binary patch
literal 530
[base85-encoded binary payload of the generated test model; not recoverable from this copy]

Date: Tue, 26 Sep 2023 09:28:17 -0700
Subject: [PATCH 084/121] [TensorRT EP] Back out the PerThreadContext (#17690)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The current TRT EP's PerThreadContext allows more than one IExecutionContext
instance to be created from a single engine instance. But it's possible to hit
an error caused by the TRT API context.setBindingDimensions() in our TRT EP code
here
(https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc#L2775)
when the input shape changes (meaning the engine is rebuilt) under
multithreading.

From the doc
(https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_execution_context.html#ada050e88320bcc40987b0acadc2ef962)
and the discussion (https://github.com/NVIDIA/TensorRT/issues/846), it seems
each IExecutionContext should have its own OptimizationProfile, which our
current TRT EP doesn't support regardless of whether the PerThreadContext
implementation is used.

Back out the PerThreadContext until we completely solve this issue.
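To make the failure mode above concrete, here is a minimal illustrative sketch (not code from this patch) of the locking discipline the revert restores, assuming the TensorRT 8.x API; FusedNodeState, Compute, and the member names are hypothetical stand-ins for the EP's per-node state:

// Illustrative sketch only -- not the EP's actual implementation.
// Assumes TensorRT 8.x (setBindingDimensions/enqueueV2) and the CUDA runtime.
#include <mutex>

#include <cuda_runtime_api.h>
#include <NvInfer.h>

struct FusedNodeState {                            // hypothetical stand-in for the EP's per-node state
  nvinfer1::ICudaEngine* engine = nullptr;         // one engine shared by all threads
  nvinfer1::IExecutionContext* context = nullptr;  // one context, no longer one per thread
  std::mutex mu;                                   // serializes the whole compute call
};

bool Compute(FusedNodeState& s, nvinfer1::Dims input_dims, void** bindings, cudaStream_t stream) {
  // The entire call is the critical section: without this lock, one thread could
  // rebuild the engine and execution context (on an input-shape change) while
  // another thread is inside setBindingDimensions() on the stale context, which
  // is the error described above.
  std::lock_guard<std::mutex> lock(s.mu);
  if (!s.context->setBindingDimensions(0, input_dims)) {
    return false;  // dynamic shape rejected by the context
  }
  return s.context->enqueueV2(bindings, stream, nullptr);
}

With a single context guarded this way, a context rebuild and the subsequent binding calls can never interleave across threads, at the cost of serializing inference for that fused node.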
---
 .../tensorrt/tensorrt_execution_provider.cc | 120 ++++++------
 .../tensorrt/tensorrt_execution_provider.h  |  17 +++
 2 files changed, 55 insertions(+), 82 deletions(-)

diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index 96893f63b4540..55204abc80187 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -1143,46 +1143,35 @@ bool TensorrtExecutionProvider::IsGraphCaptureEnabled() const {
   return cuda_graph_enable_;
 }
 
-bool TensorrtExecutionProvider::IsGraphCaptured() const {
-  return GetPerThreadContext().IsGraphCaptured();
-}
-
-Status TensorrtExecutionProvider::ReplayGraph() {
-  return GetPerThreadContext().ReplayGraph();
-}
-
-void TensorrtExecutionProvider::PerThreadContext::SetGraphStream(cudaStream_t stream) {
-  cuda_graph_.SetStream(stream);
-}
-
-bool TensorrtExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const {
+bool TensorrtExecutionProvider::IsGraphCaptureAllowed() const {
   return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_;
 }
 
-void TensorrtExecutionProvider::PerThreadContext::CaptureBegin() {
+void TensorrtExecutionProvider::CaptureBegin() {
   cuda_graph_.Reset();
   cuda_graph_.CaptureBegin();
 }
 
-void TensorrtExecutionProvider::PerThreadContext::CaptureEnd() {
+void TensorrtExecutionProvider::CaptureEnd() {
   cuda_graph_.CaptureEnd();
   is_graph_captured_ = true;
 }
 
-bool TensorrtExecutionProvider::PerThreadContext::IsGraphCaptured() const {
+bool TensorrtExecutionProvider::IsGraphCaptured() const {
   return is_graph_captured_;
 }
 
-Status TensorrtExecutionProvider::PerThreadContext::ReplayGraph() {
+Status TensorrtExecutionProvider::ReplayGraph() {
   ORT_ENFORCE(IsGraphCaptured());
   // Please note that CUDAGraph::Replay() is not thread safe.
-  // The cuda graph object is maintained by a per thread basis,
+  // ORT TRT calls ReplayGraph() in compute_func() where synchronization is enforced by lock_guard(),
   // therefore calling CUDAGraph::Replay() here is guaranteed to be thread safe.
   return cuda_graph_.Replay();
 }
 
-void TensorrtExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() {
-  // The cuda graph object is maintained by a per thread basis,
+void TensorrtExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() {
+  // Please note that this function is not thread safe.
+  // ORT TRT calls this function in compute_func() where synchronization is enforced by lock_guard(),
   // therefore following increment is guaranteed to be thread safe.
   ++regular_run_count_before_graph_capture_;
 }
@@ -1213,18 +1202,6 @@ Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) {
   if (sync_stream && external_stream_) {
     CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_));
   }
-
-  // The reason of !IsGraphCaptureEnabled():
-  // If cuda graph is enabled, the per thread context will not be released
-  // because the per thread cuda graph needs to be maintained and replayed for
-  // the next run.
-  // The reason of PerThreadContextCache()->find(this) != PerThreadContextCache()->end():
-  // In extreme cases (e.g., 1-op graph and that op fallbacks to CPU),
-  // PerThreadContext won't be created and there is nothing to release.
-  if (!IsGraphCaptureEnabled() &&
-      PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) {
-    ReleasePerThreadContext();
-  }
   return Status::OK();
 }
 
@@ -2384,6 +2361,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
           context->allocate_func, context->release_func, context->allocator_handle, context->node_name,
-          &parsers_[context->node_name], &engines_[context->node_name], &builders_[context->node_name],
+          &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name],
           &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
           input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
           dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_,
@@ -2445,6 +2415,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
       auto& shape_ranges = trt_state->input_shape_ranges;
       auto trt_builder = trt_state->builder->get();
       auto trt_engine = trt_state->engine->get();
+      auto trt_context = trt_state->context->get();
       auto trt_profiles = trt_state->profiles;
       auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr;
       int num_inputs = static_cast<int>(input_indexes.size());
@@ -2502,7 +2473,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
           trt_state->engine->reset();
           *(trt_state->engine) = std::unique_ptr<nvinfer1::ICudaEngine>(
               trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
-          if (*(trt_state->engine) == nullptr) {
+          if (!(*(trt_state->engine))) {
             return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
           }
           LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
@@ -2527,7 +2498,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
           trt_state->engine->reset();
           *(trt_state->engine) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
-          if (*(trt_state->engine) == nullptr) {
+          if (!(*(trt_state->engine))) {
             return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                                    "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path);
           }
@@ -2556,10 +2527,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
+        trt_state->context->reset();
         trt_state->engine->reset();
         auto trt_config = std::unique_ptr<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
         trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr));
@@ -2660,7 +2628,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
                     << trt_state->trt_node_name_with_precision << " took: " << std::chrono::duration_cast<std::chrono::milliseconds>(engine_build_stop - engine_build_start).count() << "ms" << std::endl;
           }
         }
-        if (*(trt_state->engine) == nullptr) {
+        if (!(*(trt_state->engine))) {
           return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
         }
         trt_engine = trt_state->engine->get();
@@ -2706,32 +2674,20 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
-      std::unique_ptr<nvinfer1::IExecutionContext> new_context;
+      if (context_update) {
         if (trt_state->context_memory_sharing_enable) {
-          new_context.reset(trt_state->engine->get()->createExecutionContextWithoutDeviceMemory());
+          *(trt_state->context) = std::unique_ptr<nvinfer1::IExecutionContext>(
+              trt_state->engine->get()->createExecutionContextWithoutDeviceMemory());
         } else {
-          new_context.reset(trt_state->engine->get()->createExecutionContext());
+          *(trt_state->context) = std::unique_ptr<nvinfer1::IExecutionContext>(
+              trt_state->engine->get()->createExecutionContext());
         }
-        auto context_status = GetPerThreadContext().UpdateTensorRTContext(fused_node_name, std::move(new_context));
-        if (!context_status) {
+        if (!(*(trt_state->context))) {
           return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
         }
-        GetPerThreadContext().UpdateProfileShapes(fused_node_name, shape_ranges);
+        trt_context = trt_state->context->get();
       }
 
-      // Get the reference to the IExecutionContext object that is maintained on a per thread basis.
-      nvinfer1::IExecutionContext& trt_context = GetPerThreadContext().GetTensorRTContext(fused_node_name);
-
       // Get input and output binding names
       int total_bindings = trt_engine->getNbBindings();
       std::vector<void*> buffers(total_bindings);
@@ -2767,12 +2723,12 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
           if (trt_engine->isShapeBinding(binding_index)) {
-            trt_context.setInputShapeBinding(binding_index, &tensor_shape_values[input_name][0]);
+            trt_context->setInputShapeBinding(binding_index, &tensor_shape_values[input_name][0]);
           } else {
             for (int j = 0, end = nb_dims; j < end; ++j) {
               dimensions.d[j] = static_cast<int32_t>(tensor_shapes[j]);
             }
-            const bool status = trt_context.setBindingDimensions(binding_index, dimensions);
+            const bool status = trt_context->setBindingDimensions(binding_index, dimensions);
             if (!status) {
               ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                                                  "TensorRT EP cannot set the dynamic dimensions of a binding"));
@@ -2911,7 +2867,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
            output_index = index_iter->second;
          }
-          nvinfer1::Dims dimensions = trt_context.getBindingDimensions(static_cast<int>(binding_index));
+          nvinfer1::Dims dimensions = trt_context->getBindingDimensions(static_cast<int>(binding_index));
          int nb_dims = dimensions.nbDims;
          std::vector<int64_t> output_shapes(nb_dims);
          for (int j = 0, end = nb_dims; j < end; ++j) {
@@ -3045,20 +3001,20 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
         if (mem_size > *max_context_mem_size_ptr) {
           *max_context_mem_size_ptr = mem_size;
         }
-        trt_context.setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, *max_context_mem_size_ptr).get());
+        trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, *max_context_mem_size_ptr).get());
       }
 
       // Start CUDA graph capture.
       // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because
       // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream.
-      if (cuda_graph_enable_ && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) {
+      if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) {
         LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model";
-        GetPerThreadContext().SetGraphStream(stream);
-        GetPerThreadContext().CaptureBegin();
+        cuda_graph_.SetStream(stream);
+        CaptureBegin();
       }
 
       // Run TRT inference
-      if (!trt_context.enqueueV2(&buffers[0], stream, nullptr)) {
+      if (!trt_context->enqueueV2(&buffers[0], stream, nullptr)) {
         return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
       }
 
@@ -3089,14 +3045,14 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
       // End CUDA graph capture.
-      if (cuda_graph_enable_ && !GetPerThreadContext().IsGraphCaptured()) {
-        if (GetPerThreadContext().IsGraphCaptureAllowed()) {
-          GetPerThreadContext().CaptureEnd();
+      if (cuda_graph_enable_ && !IsGraphCaptured()) {
+        if (IsGraphCaptureAllowed()) {
+          CaptureEnd();
           // CUDA work issued to a capturing stream doesn't actually run on the GPU,
           // so run the captured graph here to actually execute the work.
-          ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph());
+          ORT_RETURN_IF_ERROR(ReplayGraph());
         } else {
-          GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture();
+          IncrementRegularRunCountBeforeGraphCapture();
         }
       }
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ ... @@ struct TensorrtFuncState {
   tensorrt_ptr::unique_pointer<nvonnxparser::IParser>* parser = nullptr;
   std::unique_ptr<nvinfer1::ICudaEngine>* engine = nullptr;
+  std::unique_ptr<nvinfer1::IExecutionContext>* context = nullptr;
   std::unique_ptr<nvinfer1::IBuilder>* builder = nullptr;
   std::unique_ptr<nvinfer1::INetworkDefinition>* network = nullptr;
   std::vector<std::unordered_map<std::string, size_t>> input_info;
@@ -246,6 +247,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
   // For those non thread safe operations, TRT EP uses (1) lock_guard or (2) PerThreadContext to make sure synchronization.
   std::unordered_map<std::string, std::unique_ptr<nvonnxparser::IParser>> parsers_;
   std::unordered_map<std::string, std::unique_ptr<nvinfer1::ICudaEngine>> engines_;
+  std::unordered_map<std::string, std::unique_ptr<nvinfer1::IExecutionContext>> contexts_;
   std::unordered_map<std::string, std::unique_ptr<nvinfer1::IBuilder>> builders_;
   std::unordered_map<std::string, std::unique_ptr<nvinfer1::INetworkDefinition>> networks_;
   std::unordered_map<std::string, std::vector<std::unordered_map<std::string, size_t>>> input_info_;
@@ -256,6 +258,21 @@ class TensorrtExecutionProvider : public IExecutionProvider {
   std::unordered_map input_shape_ranges_;  // The profile shape ranges that the engine is built with
   std::unordered_map<std::string, std::vector<nvinfer1::IOptimizationProfile*>> profiles_;
+  // for external stream, we need to create its cudnn/cublas handle before cuda EP enables cuda graph capture
+  cudnnHandle_t external_cudnn_handle_ = nullptr;
+  cublasHandle_t external_cublas_handle_ = nullptr;
+
+  CUDAGraph cuda_graph_;
+  bool is_graph_captured_ = false;
+  int regular_run_count_before_graph_capture_ = 0;
+  // There is a chance (currently only happens in CUDA EP) that the second regular run allocates GPU memory for causes like:
+  // (1) memory pattern is enabled. (2) arena allocation for stream.
+  // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs
+  // to allocate enough memory in Arena before graph capturing.
+  const int min_num_runs_before_cuda_graph_capture_ = 1;  // required min regular runs before graph capture for the necessary memory allocations.
+
+  // [Note] We don't use PerThreadContext for now since it has an issue with multithreading
+  //
   // TRT or CUDA objects that must be maintained on a per thread basis will be put under this PerThreadContext data structure.
   // For example, TensorRT execution context and CUDA graph are the ones to be put here.
   class PerThreadContext final {

From 1c245e6775201fec7a1d269a71deb0f3af6f7bea Mon Sep 17 00:00:00 2001
From: RandySheriffH <48490400+RandySheriffH@users.noreply.github.com>
Date: Tue, 26 Sep 2023 09:46:30 -0700
Subject: [PATCH 085/121] Stop throwing exception on python binding when
 multiple EP available (#17659)

Stop throwing the exception when the provider list is empty but there are
multiple available EPs. The other language bindings throw no exception at all
in this case, so this change brings the Python binding in line with them.

---------

Co-authored-by: Randy Shuai
---
 .../onnxruntime_inference_collection.py | 12 ++------
 .../test/python/onnxruntime_test_python.py | 30 +++++--------------
 2 files changed, 9 insertions(+), 33 deletions(-)

diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py
index 4124822adef1f..bcc6f15129231 100644
--- a/onnxruntime/python/onnxruntime_inference_collection.py
+++ b/onnxruntime/python/onnxruntime_inference_collection.py
@@ -438,7 +438,7 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi
         # Tensorrt can fall back to CUDA if it's explicitly assigned. All others fall back to CPU.
         if "TensorrtExecutionProvider" in available_providers:
-            if any(
+            if providers and any(
                 provider == "CUDAExecutionProvider"
                 or (isinstance(provider, tuple) and provider[0] == "CUDAExecutionProvider")
                 for provider in providers
@@ -448,7 +448,7 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi
                 self._fallback_providers = ["CPUExecutionProvider"]
         # MIGraphX can fall back to ROCM if it's explicitly assigned. All others fall back to CPU.
elif "MIGraphXExecutionProvider" in available_providers: - if any( + if providers and any( provider == "ROCMExecutionProvider" or (isinstance(provider, tuple) and provider[0] == "ROCMExecutionProvider") for provider in providers @@ -463,14 +463,6 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi providers, provider_options = check_and_normalize_provider_args( providers, provider_options, available_providers ) - if not providers and len(available_providers) > 1: - self.disable_fallback() - raise ValueError( - f"This ORT build has {available_providers} enabled. " - "Since ORT 1.9, you are required to explicitly set " - "the providers parameter when instantiating InferenceSession. For example, " - f"onnxruntime.InferenceSession(..., providers={available_providers}, ...)" - ) session_options = self._sess_options if self._sess_options else C.get_default_session_options() if self._model_path: diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 59f7781bb4f8a..1d954fe4370ad 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -80,11 +80,7 @@ def test_model_serialization(self): so.log_severity_level = 1 so.logid = "TestModelSerialization" so.optimized_model_filepath = "./PythonApiTestOptimizedModel.onnx" - onnxrt.InferenceSession( - get_name("mul_1.onnx"), - sess_options=so, - providers=["CPUExecutionProvider"], - ) + onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=so) self.assertTrue(os.path.isfile(so.optimized_model_filepath)) os.remove(so.optimized_model_filepath) except Fail as onnxruntime_error: @@ -107,11 +103,7 @@ def test_model_serialization_with_external_initializers(self): "session.optimized_model_external_initializers_file_name", external_initializers_file ) so.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "100") - onnxrt.InferenceSession( - get_name("mnist.onnx"), - sess_options=so, - providers=["CPUExecutionProvider"], - ) + onnxrt.InferenceSession(get_name("mnist.onnx"), sess_options=so) self.assertTrue(os.path.isfile(so.optimized_model_filepath)) self.assertTrue(os.path.isfile(external_initializers_file)) os.remove(so.optimized_model_filepath) @@ -137,7 +129,7 @@ def test_model_serialization_with_external_initializers_to_directory(self): "session.optimized_model_external_initializers_file_name", external_initializers_file ) so.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "100") - onnxrt.InferenceSession(get_name("mnist.onnx"), sess_options=so, providers=["CPUExecutionProvider"]) + onnxrt.InferenceSession(get_name("mnist.onnx"), sess_options=so) self.assertTrue(os.path.isfile(so.optimized_model_filepath)) self.assertTrue(os.path.isfile(os.path.join(directory, external_initializers_file))) os.remove(so.optimized_model_filepath) @@ -163,9 +155,7 @@ def test_model_serialization_with_original_external_initializers_to_directory(se "session.optimized_model_external_initializers_file_name", external_initializers_file ) so.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "100") - onnxrt.InferenceSession( - get_name("model_with_orig_ext_data.onnx"), sess_options=so, providers=["CPUExecutionProvider"] - ) + onnxrt.InferenceSession(get_name("model_with_orig_ext_data.onnx"), sess_options=so) self.assertTrue(os.path.isfile(so.optimized_model_filepath)) 
             self.assertTrue(os.path.isfile(os.path.join(directory, external_initializers_file)))
             os.remove(so.optimized_model_filepath)
@@ -198,9 +188,7 @@ def test_model_serialization_with_original_external_initializers_to_current_dire
         # still refers to the original external data file. We shall fix this issue so that the
         # optimized model only refers to one external data file.
         so.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "10")
-        session1 = onnxrt.InferenceSession(
-            get_name("model_with_orig_ext_data.onnx"), sess_options=so, providers=["CPUExecutionProvider"]
-        )
+        session1 = onnxrt.InferenceSession(get_name("model_with_orig_ext_data.onnx"), sess_options=so)
         del session1
         self.assertTrue(os.path.isfile(optimized_model_filepath))
         self.assertTrue(os.path.isfile(external_initializers_file))
@@ -216,9 +204,7 @@ def test_model_serialization_with_original_external_initializers_to_current_dire
 
         # verify that we can load the optimized model with external data in current directory and save
         # optimized model with external data to current directory.
-        session2 = onnxrt.InferenceSession(
-            optimized_model_filepath, sess_options=so2, providers=["CPUExecutionProvider"]
-        )
+        session2 = onnxrt.InferenceSession(optimized_model_filepath, sess_options=so2)
         del session2
         self.assertTrue(os.path.isfile(optimized_model_filepath_2))
         self.assertTrue(os.path.isfile(external_initializers_file_2))
@@ -227,9 +213,7 @@ def test_model_serialization_with_original_external_initializers_to_current_dire
         os.remove(optimized_model_filepath)
         os.remove(external_initializers_file)
 
-        session3 = onnxrt.InferenceSession(
-            optimized_model_filepath_2, sess_options=onnxrt.SessionOptions(), providers=["CPUExecutionProvider"]
-        )
+        session3 = onnxrt.InferenceSession(optimized_model_filepath_2, sess_options=onnxrt.SessionOptions())
         del session3
         os.remove(optimized_model_filepath_2)

From f43acf2d33ca4c2f87b0927929123ebfaed82b1a Mon Sep 17 00:00:00 2001
From: Kaz Nishimura
Date: Wed, 27 Sep 2023 01:51:13 +0900
Subject: [PATCH 086/121] Close the JSON object in settings.json (#17583)

### Description
This patch adds a closing curly bracket at the end of `settings.json`.

### Motivation and Context
`settings.json` is simply not closed; the closing bracket was accidentally
removed at 4e6ea730d633756e9e04df8968304d11a575dde4.
---
 .vscode/settings.json | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.vscode/settings.json b/.vscode/settings.json
index b7a1292efb2c6..fd28e2d7b335c 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -40,3 +40,4 @@
     "-build/include_subdir",
     "-runtime/references"
   ]
+}

From b8e348145cbcfaa72172aa78b8c21f1544d3854c Mon Sep 17 00:00:00 2001
From: Vadym Stupakov
Date: Tue, 26 Sep 2023 19:57:01 +0300
Subject: [PATCH 087/121] Fixed #16873 (#16932)

---
 tools/perf_view/ort_perf_view.html | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tools/perf_view/ort_perf_view.html b/tools/perf_view/ort_perf_view.html
index e00e38702d342..509fe5593f6a1 100644
--- a/tools/perf_view/ort_perf_view.html
+++ b/tools/perf_view/ort_perf_view.html
@@ -5,7 +5,7 @@
-
+