From 37d14d78960fb1ba54c0bb2dc3be740e93d2ca15 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Mon, 22 Jan 2024 18:14:41 -0800 Subject: [PATCH 01/23] [QNN EP] Create Windows ARM64 nightly python package (#19128) ### Description Adds a job to create a nightly python package for ORT/QNN on Windows ARM64. Must build onnxruntime-qnn with python 3.11 and numpy 1.25. **Note: pipeline run may take up to 3 hrs** ### Motivation and Context Make it possible to get a nightly python package with the latest updates to QNN EP. Issue #19161 --- .../azure-pipelines/py-packaging-pipeline.yml | 8 +- .../templates/py-packaging-stage.yml | 13 ++ .../templates/py-win-arm64-qnn.yml | 165 ++++++++++++++++++ 3 files changed, 185 insertions(+), 1 deletion(-) create mode 100644 tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index 06cca0068523d..5349b1ca67ab1 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -29,6 +29,11 @@ parameters: type: boolean default: true +- name: enable_windows_arm64_qnn + displayName: 'Whether Windows ARM64 package with QNN EP is built.' + type: boolean + default: true + - name: build_py_parameters displayName: 'Specify extra build parameters' type: string @@ -64,5 +69,6 @@ stages: enable_windows_gpu: ${{ parameters.enable_windows_gpu }} enable_mac_cpu: ${{ parameters.enable_mac_cpu }} enable_linux_arm: ${{ parameters.enable_linux_arm }} + enable_windows_arm64_qnn: ${{ parameters.enable_windows_arm64_qnn }} build_py_parameters: ${{ parameters.build_py_parameters }} - cmake_build_type: ${{ parameters.cmake_build_type }} \ No newline at end of file + cmake_build_type: ${{ parameters.cmake_build_type }} diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 8669a883c31f1..297498843c38d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -35,6 +35,11 @@ parameters: type: boolean default: true +- name: enable_windows_arm64_qnn + displayName: 'Whether Windows ARM64 package with QNN EP is built.' + type: boolean + default: true + # TODO: Now the Windows jobs use a different cmake build type. Consider to merge it. 
- name: cmake_build_type type: string @@ -446,3 +451,11 @@ stages: machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} + + - ${{ if eq(parameters.enable_windows_arm64_qnn, true) }}: + - template: py-win-arm64-qnn.yml + parameters: + MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' + QNN_SDK: 'qnn-v2.18.0.240101_win' + PYTHON_VERSION: '3.11' + NUMPY_VERSION: '1.25.2' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml new file mode 100644 index 0000000000000..adf7aa9c43205 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -0,0 +1,165 @@ +parameters: + +- name: MACHINE_POOL + type: string + default: 'onnxruntime-qnn-windows-vs-2022-arm64' + +- name: QNN_SDK + displayName: QNN Windows SDK path + type: string + default: qnn-v2.18.0.240101_win + +- name: PYTHON_VERSION + type: string + default: '3.11' + +- name: NUMPY_VERSION + type: string + default: '1.25.2' + +- name: ENV_SETUP_SCRIPT + type: string + default: '' + +- name: BUILD_PY_PARAMETERS + displayName: > + Extra parameters to pass to build.py. Don't put newlines in here. + type: string + default: '' + +jobs: +- job: Win_py_arm64_qnn_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }} + timeoutInMinutes: 210 + workspace: + clean: all + pool: + name: ${{ parameters.MACHINE_POOL }} + variables: + GRADLE_OPTS: '-Dorg.gradle.daemon=false' + VSGenerator: 'Visual Studio 17 2022' + QNN_SDK_ROOTDIR: 'C:\data\qnnsdk\${{parameters.QNN_SDK}}' + steps: + - checkout: self + clean: true + submodules: recursive + + - template: telemetry-steps.yml + + - script: | + DIR C:\data\qnnsdk + displayName: Check available QNN SDKs + + - script: | + MKDIR $(Agent.ToolsDirectory)\Python\3.11.0\arm64 + XCOPY /s /y /h /e /c /q "C:\Python\Python311\*.*" $(Agent.ToolsDirectory)\Python\3.11.0\arm64\ + COPY NUL $(Agent.ToolsDirectory)\Python\3.11.0\arm64.complete + DIR $(Agent.ToolsDirectory)\Python + DIR $(Agent.ToolsDirectory)\Python\3.11.0 + DIR $(Agent.ToolsDirectory)\Python\3.11.0\arm64 + displayName: Copy python 3.11.0 version to agent tools directory + + - task: UsePythonVersion@0 + inputs: + versionSpec: ${{ parameters.PYTHON_VERSION }} + addToPath: true + architecture: 'arm64' + + - task: onebranch.pipeline.tsaoptions@1 + displayName: 'OneBranch TSAOptions' + inputs: + tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' + appendSourceBranchName: false + + - task: PythonScript@0 + inputs: + scriptSource: inline + script: | + import subprocess + subprocess.call(['pip', 'install', '-q', 'setuptools', 'wheel', 'numpy==${{parameters.NUMPY_VERSION}}']) + workingDirectory: '$(Build.BinariesDirectory)' + displayName: 'Install python modules' + + - template: set-nightly-build-option-variable-step.yml + + - task: PythonScript@0 + displayName: 'Generate cmake config' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: > + --config RelWithDebInfo + --build_dir $(Build.BinariesDirectory) + --skip_submodule_sync + --cmake_generator "$(VSGenerator)" + --use_qnn + --qnn_home $(QNN_SDK_ROOTDIR) + --enable_pybind + --parallel --update + --numpy_version ${{ parameters.NUMPY_VERSION }} + $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} + workingDirectory: '$(Build.BinariesDirectory)' + + - task: VSBuild@1 + displayName: 'Build' + inputs: + solution: 
'$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' + platform: 'arm64' + configuration: RelWithDebInfo + msbuildArchitecture: 'arm64' + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' + createLogFile: true + + # Esrp signing + - template: win-esrp-dll.yml + parameters: + FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi' + DisplayName: 'ESRP - Sign Native dlls' + DoEsrp: true + Pattern: '*.pyd,*.dll' + + - task: PythonScript@0 + displayName: 'Build wheel' + inputs: + scriptPath: '$(Build.SourcesDirectory)\setup.py' + arguments: 'bdist_wheel ${{ parameters.BUILD_PY_PARAMETERS }} $(NightlyBuildOption) --wheel_name_suffix=qnn' + workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + + - task: CopyFiles@2 + displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist' + Contents: '*.whl' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: PublishBuildArtifacts@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + ArtifactName: onnxruntime_qnn + + - script: | + 7z x *.whl + workingDirectory: '$(Build.ArtifactStagingDirectory)' + displayName: 'unzip the package' + + - task: CredScan@3 + displayName: 'Run CredScan' + inputs: + debugMode: false + continueOnError: true + + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll' + + - task: TSAUpload@2 + displayName: 'TSA upload' + condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) + inputs: + GdnPublishTsaOnboard: false + GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' + + - template: component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' From b2aec41a8309bc2dced74a991b1f3c311e037e3d Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Mon, 22 Jan 2024 19:17:04 -0800 Subject: [PATCH 02/23] [ROCm] enable hipGraph (#18382) This ports the cudaGraph support from the CUDA EP to the ROCM EP's hipGraph. 
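For anyone trying the feature out, a minimal usage sketch from Python (not part of this change): the provider-option key below is the one parsed by `ROCMExecutionProviderInfo::FromProviderOptions` in this patch, while the model path is hypothetical, and as with CUDA graphs the inputs/outputs must be bound to device memory with shapes that stay fixed across runs.

```python
import onnxruntime as ort

# Enable hip graph capture through the ROCM EP provider options.
providers = [("ROCMExecutionProvider", {"device_id": 0, "enable_hip_graph": "1"})]
session = ort.InferenceSession("model.onnx", providers=providers)  # hypothetical model path

io_binding = session.io_binding()
# ... bind device-resident inputs and outputs here ...

# Per the InferenceSession change below, the first Run performs the warm-up and
# capture runs internally; subsequent runs replay the captured hip graph.
session.run_with_iobinding(io_binding)
session.run_with_iobinding(io_binding)
```

This mirrors the C-API flow exercised by the updated `basic_cuda_graph` test in this patch.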
--- cmake/onnxruntime_unittests.cmake | 7 ++ .../core/session/onnxruntime_c_api.h | 3 + .../providers/rocm/rocm_execution_provider.cc | 77 +++++++++++- .../providers/rocm/rocm_execution_provider.h | 24 ++++ .../rocm/rocm_execution_provider_info.cc | 3 + .../rocm/rocm_execution_provider_info.h | 2 + .../providers/rocm/rocm_provider_factory.cc | 2 + onnxruntime/core/session/inference_session.cc | 52 +++++--- .../core/session/provider_bridge_ort.cc | 1 + onnxruntime/test/shared_lib/test_inference.cc | 112 +++++++++++++++--- 10 files changed, 241 insertions(+), 42 deletions(-) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index fa395802d95ff..0987d6d164dbd 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1277,6 +1277,9 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) if (onnxruntime_USE_CUDA) list(APPEND onnxruntime_shared_lib_test_LIBS cudart) endif() + if (onnxruntime_USE_ROCM) + list(APPEND onnxruntime_shared_lib_test_LIBS hip::host) + endif() if (onnxruntime_USE_TENSORRT) list(APPEND onnxruntime_shared_lib_test_LIBS ${TENSORRT_LIBRARY_INFER}) endif() @@ -1294,6 +1297,10 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_sources(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu) endif() + if (onnxruntime_USE_ROCM) + target_include_directories(onnxruntime_shared_lib_test PRIVATE ${onnxruntime_ROCM_HOME}/include) + target_compile_definitions(onnxruntime_shared_lib_test PRIVATE __HIP_PLATFORM_AMD__) + endif() if (CMAKE_SYSTEM_NAME STREQUAL "Android") target_sources(onnxruntime_shared_lib_test PRIVATE "${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc" diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 101a578ec3e1d..2ce9d361e8e56 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -496,6 +496,7 @@ typedef struct OrtROCMProviderOptions { has_user_compute_stream{}, user_compute_stream{}, default_memory_arena_cfg{}, + enable_hip_graph{false}, tunable_op_enable{false}, tunable_op_tuning_enable{false}, tunable_op_max_tuning_duration_ms{} {} @@ -548,6 +549,8 @@ typedef struct OrtROCMProviderOptions { */ OrtArenaCfg* default_memory_arena_cfg; + int enable_hip_graph; + /** \brief Enable TunableOp for using. * Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default. * This option can be overriden by environment variable ORT_ROCM_TUNABLE_OP_ENABLE. 
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
index d7c5098d9dbe4..d7bec337a6be4 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
@@ -170,6 +170,8 @@ ROCMExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId de
   MIOPEN_CALL_THROW(miopenCreate(&miopen_handle_));
   MIOPEN_CALL_THROW(miopenSetStream(miopen_handle_, stream));
+
+  hip_graph_.SetStream(stream);
 }
 
 ROCMExecutionProvider::PerThreadContext::~PerThreadContext() {
@@ -177,6 +179,33 @@ ROCMExecutionProvider::PerThreadContext::~PerThreadContext() {
   ORT_IGNORE_RETURN_VALUE(MIOPEN_CALL(miopenDestroy(miopen_handle_)));
 }
 
+bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const {
+  return regular_run_count_before_graph_capture_ >= min_num_runs_before_hip_graph_capture_;
+}
+
+void ROCMExecutionProvider::PerThreadContext::CaptureBegin() {
+  hip_graph_.Reset();
+  hip_graph_.CaptureBegin();
+}
+
+void ROCMExecutionProvider::PerThreadContext::CaptureEnd() {
+  hip_graph_.CaptureEnd();
+  is_graph_captured_ = true;
+}
+
+bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptured() const {
+  return is_graph_captured_;
+}
+
+Status ROCMExecutionProvider::PerThreadContext::ReplayGraph() {
+  ORT_ENFORCE(IsGraphCaptured());
+  return hip_graph_.Replay();
+}
+
+void ROCMExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() {
+  ++regular_run_count_before_graph_capture_;
+}
+
 void OverrideTunableOpInfoByEnv(ROCMExecutionProviderInfo& info) {
   if (auto env_tunable_op_enable = onnxruntime::ParseTestOnlyEnvironmentVariable(
           "ORT_ROCM_TUNABLE_OP_ENABLE", {"0", "1"}, "Use provider_options \"tunable_op_enable\" instead.");
@@ -219,6 +248,11 @@ ROCMExecutionProvider::ROCMExecutionProvider(const ROCMExecutionProviderInfo& in
   if (info.external_allocator_info.UseExternalAllocator()) {
     use_ep_level_unified_stream_ = true;
     stream_ = nullptr;
+  } else if (info.enable_hip_graph) {
+    // current hip graph implementation only works with single stream
+    // use EP level unified stream for all the requests
+    HIP_CALL_THROW(hipStreamCreateWithFlags(&stream_, hipStreamNonBlocking));
+    use_ep_level_unified_stream_ = true;
   } else {
     stream_ = nullptr;
   }
@@ -322,25 +356,58 @@ Status ROCMExecutionProvider::Sync() const {
 Status ROCMExecutionProvider::OnRunStart() {
   // always set ROCM device when session::Run() in case it runs in a worker thread
   HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId()));
+  if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) {
+    LOGS_DEFAULT(INFO) << "Capturing the hip graph for this model";
+    GetPerThreadContext().CaptureBegin();
+  }
   return Status::OK();
 }
 
 Status ROCMExecutionProvider::OnRunEnd(bool sync_stream) {
+  if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) {
+    if (GetPerThreadContext().IsGraphCaptureAllowed()) {
+      GetPerThreadContext().CaptureEnd();
+      // HIP work issued to a capturing stream doesn't actually run on the GPU,
+      // so run the captured graph here to actually execute the work.
+      ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph());
+    } else {
+      GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture();
+    }
+  }
+
   if (sync_stream) {
     HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast<hipStream_t>(stream_)));
   }
 
-  // In extreme cases (e.g., 1-op graph and that op fallbacks to CPU),
-  // PerThreadContext won't be created and there is nothing to
-  // release. This didn't happen before because we always call
-  // GetPerThreadContext in OnRunStart.
-  if (PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) {
+  // The reason for !IsGraphCaptureEnabled():
+  // If hip graph is enabled, the per thread context will not be released
+  // because the per thread hip graph needs to be maintained and replayed for
+  // the next run.
+  // The reason for PerThreadContextCache()->find(this) != PerThreadContextCache()->end():
+  // In extreme cases (e.g., 1-op graph and that op falls back to CPU),
+  // PerThreadContext won't be created and there is nothing to
+  // release. This didn't happen before because we always call
+  // GetPerThreadContext in OnRunStart.
+  if (!IsGraphCaptureEnabled() &&
+      PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) {
     ReleasePerThreadContext();
   }
 
   return Status::OK();
 }
 
+bool ROCMExecutionProvider::IsGraphCaptureEnabled() const {
+  return info_.enable_hip_graph;
+}
+
+bool ROCMExecutionProvider::IsGraphCaptured() const {
+  return GetPerThreadContext().IsGraphCaptured();
+}
+
+Status ROCMExecutionProvider::ReplayGraph() {
+  return GetPerThreadContext().ReplayGraph();
+}
+
 namespace rocm {
 // opset 1 to 9
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MemcpyFromHost);
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h
index c4945b9ac2481..37d5f7b42210f 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h
@@ -10,6 +10,7 @@
 #include "core/framework/execution_provider.h"
 #include "core/platform/ort_mutex.h"
 #include "core/providers/rocm/rocm_execution_provider_info.h"
+#include "core/providers/rocm/rocm_graph.h"
 #include "core/providers/rocm/rocm_pch.h"
 #include "core/providers/rocm/shared_inc/rocm_utils.h"
 #include "core/providers/rocm/shared_inc/rocm_call.h"
@@ -73,6 +74,9 @@ class ROCMExecutionProvider : public IExecutionProvider {
 
   std::unique_ptr<profiling::EpProfiler> GetProfiler() override;
 
+  bool IsGraphCaptureEnabled() const override;
+  bool IsGraphCaptured() const override;
+  Status ReplayGraph() override;
   void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override;
   OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;
   std::vector<AllocatorPtr> CreatePreferredAllocators() override;
@@ -81,6 +85,7 @@
   ROCMExecutionProviderInfo info_;
   hipDeviceProp_t device_prop_;
   bool external_stream_ = false;
+  // only used when the user sets an external stream or hip graph is enabled
   hipStream_t stream_ = nullptr;
   bool use_ep_level_unified_stream_ = false;
 
@@ -133,6 +138,13 @@
       }
     }
 
+    bool IsGraphCaptureAllowed() const;
+    void CaptureBegin();
+    void CaptureEnd();
+    bool IsGraphCaptured() const;
+    Status ReplayGraph();
+    void IncrementRegularRunCountBeforeGraphCapture();
+
    private:
     rocblas_handle rocblas_handle_ = nullptr;
     miopenHandle_t miopen_handle_ = nullptr;
@@ -141,6 +153,18 @@ class ROCMExecutionProvider : public IExecutionProvider {
    std::unique_ptr<rocm::IConstantBuffer<double>> constant_ones_double_;
    std::unique_ptr<rocm::IConstantBuffer<half>> constant_ones_half_;
    std::unique_ptr<rocm::IConstantBuffer<BFloat16>> constant_ones_bfloat16_;
+
+    // Hip graph with multiple threads will be supported in the future, so hip_graph_
+    // is put under PerThreadContext.
+    ROCMGraph hip_graph_;
+    bool is_graph_captured_ = false;
+    int regular_run_count_before_graph_capture_ = 0;
+
+    // There is a chance that the second regular run allocates GPU memory for causes like:
+    // (1) memory pattern is enabled. (2) arena allocation for stream.
+    // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs
+    // to allocate enough memory in Arena before graph capturing.
+    const int min_num_runs_before_hip_graph_capture_ = 2;  // required min regular runs before graph capture for the necessary memory allocations.
   };
 
   using PerThreadContextMap = std::unordered_map<const ROCMExecutionProvider*, std::weak_ptr<PerThreadContext>>;
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc
index 650635c153640..b557f92287f2b 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc
@@ -21,6 +21,7 @@ constexpr const char* kGpuExternalAlloc = "gpu_external_alloc";
 constexpr const char* kGpuExternalFree = "gpu_external_free";
 constexpr const char* kGpuExternalEmptyCache = "gpu_external_empty_cache";
 constexpr const char* kMiopenConvUseMaxWorkspace = "miopen_conv_use_max_workspace";
+constexpr const char* kEnableHipGraph = "enable_hip_graph";
 constexpr const char* kTunableOpEnable = "tunable_op_enable";
 constexpr const char* kTunableOpTuningEnable = "tunable_op_tuning_enable";
 constexpr const char* kTunableOpMaxTuningDurationMs = "tunable_op_max_tuning_duration_ms";
@@ -84,6 +85,7 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P
                                   info.miopen_conv_exhaustive_search)
         .AddAssignmentToReference(rocm::provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream)
         .AddAssignmentToReference(rocm::provider_option_names::kMiopenConvUseMaxWorkspace, info.miopen_conv_use_max_workspace)
+        .AddAssignmentToReference(rocm::provider_option_names::kEnableHipGraph, info.enable_hip_graph)
         .AddValueParser(
             rocm::provider_option_names::kTunableOpEnable,
             [&info](const std::string& value_str) -> Status {
@@ -121,6 +123,7 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution
       {rocm::provider_option_names::kMiopenConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)},
       {rocm::provider_option_names::kDoCopyInDefaultStream, MakeStringWithClassicLocale(info.do_copy_in_default_stream)},
       {rocm::provider_option_names::kMiopenConvUseMaxWorkspace, MakeStringWithClassicLocale(info.miopen_conv_use_max_workspace)},
+      {rocm::provider_option_names::kEnableHipGraph, MakeStringWithClassicLocale(info.enable_hip_graph)},
       {rocm::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op.enable)},
       {rocm::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op.tuning_enable)},
       {rocm::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op.max_tuning_duration_ms)},
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h
index e35c0cc0afecc..2f549cc1ac143 100644
---
a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h @@ -63,6 +63,8 @@ struct ROCMExecutionProviderInfo { // If set to false, use fix workspace size (32M) for Conv algo search, the final algo might not be the best. bool miopen_conv_use_max_workspace{true}; + bool enable_hip_graph{false}; + rocm::TunableOpInfo tunable_op{}; static ROCMExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc index 4d88c25469372..88ef666678b3e 100644 --- a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc +++ b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc @@ -185,6 +185,7 @@ struct ROCM_Provider : Provider { info.has_user_compute_stream = params->has_user_compute_stream != 0; info.user_compute_stream = params->user_compute_stream; info.default_memory_arena_cfg = params->default_memory_arena_cfg; + info.enable_hip_graph = params->enable_hip_graph; info.tunable_op.enable = params->tunable_op_enable; info.tunable_op.tuning_enable = params->tunable_op_tuning_enable; info.tunable_op.max_tuning_duration_ms = params->tunable_op_max_tuning_duration_ms; @@ -215,6 +216,7 @@ struct ROCM_Provider : Provider { rocm_options.user_compute_stream = internal_options.user_compute_stream; } rocm_options.default_memory_arena_cfg = internal_options.default_memory_arena_cfg; + rocm_options.enable_hip_graph = internal_options.enable_hip_graph; rocm_options.tunable_op_enable = internal_options.tunable_op.enable; rocm_options.tunable_op_tuning_enable = internal_options.tunable_op.tuning_enable; rocm_options.tunable_op_max_tuning_duration_ms = internal_options.tunable_op.max_tuning_duration_ms; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e8853c8824738..39f47c09f2402 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -153,7 +153,7 @@ static bool AreAllComputeNodesAssignedToCudaEp(const Graph& graph) { // Empty node provider means CPU EP if (!node_provider.empty() && - node_provider != kCudaExecutionProvider && + !(node_provider == kCudaExecutionProvider || node_provider == kRocmExecutionProvider) && node_provider != kCpuExecutionProvider) { nodes_on_cpu_and_cuda_eps_only = false; break; @@ -1715,7 +1715,8 @@ common::Status InferenceSession::Initialize() { // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve()); - // Currently CUDA graph is only considered by CUDA EP and TRT EP. + // Currently CUDA graph is only considered by CUDA EP and TRT EP, and + // HIP graph is only considered by ROCM EP. // // Check for CUDA EP: // If the CUDA EP is part of the providers list for this session AND @@ -1728,47 +1729,58 @@ common::Status InferenceSession::Initialize() { // The TRT EP is configured to do a graph capture AND // All the graph nodes have been assigned to the TRT EP, // Then the TRT EP is cached for triggering a ReplayGraph() in Run(). 
- std::vector cuda_graph_support_ep_list = {onnxruntime::kTensorrtExecutionProvider, onnxruntime::kCudaExecutionProvider}; + // + // Check for ROCM EP: + // If the ROCM EP is part of the providers list for this session AND + // The ROCM EP is configured to do a graph capture AND + // All the "compute" graph nodes have been assigned to the ROCM EP, + // Then the ROCM EP is cached for triggering a ReplayGraph() in Run(). + // + std::vector graph_support_ep_list = { + onnxruntime::kTensorrtExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kRocmExecutionProvider}; - for (auto& it : cuda_graph_support_ep_list) { + for (auto& it : graph_support_ep_list) { auto* target_ep = execution_providers_.Get(it); if (target_ep && target_ep->IsGraphCaptureEnabled()) { - // CUDA Graphs can't work with control flow nodes + // CUDA/HIP Graphs can't work with control flow nodes if (HasControlflowNodes(graph)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << "as the model has control flow nodes which can't be supported by CUDA Graphs."; + LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA/HIP Graph feature as requested by the user " + << "as the model has control flow nodes which can't be supported by CUDA/HIP Graphs."; ORT_RETURN_IF_ERROR_SESSIONID_( ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - "as the model has control flow nodes which can't be supported by CUDA Graphs.")); + "This session cannot use the CUDA/HIP Graph feature as requested by the user " + "as the model has control flow nodes which can't be supported by CUDA/HIP Graphs.")); } - if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0) { + if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0 || + strcmp(target_ep->Type().c_str(), onnxruntime::kRocmExecutionProvider) == 0) { // Ensure that all nodes have been partitioned to CUDA or CPU EP && there are no memcpy nodes // The reasoning behind this logic is that certain shape nodes will be forced onto CPU // and as long as there are no memcpy nodes this is confirmation that no compute nodes have been placed on the CPU EP // which is all we care about. if (!AreAllComputeNodesAssignedToCudaEp(graph)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << " as all compute graph nodes have not been partitioned to the CUDA EP."; + LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA/HIP Graph feature as requested by the user " + << " as all compute graph nodes have not been partitioned to the CUDA/HIP EP."; ORT_RETURN_IF_ERROR_SESSIONID_( ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - " as all compute graph nodes have not been partitioned to the CUDA EP.")); + "This session cannot use the CUDA/HIP Graph feature as requested by the user " + " as all compute graph nodes have not been partitioned to the CUDA/HIP EP.")); } // Log a warning for the user to know that there are shape subgraphs that will execute on CPU if (HasShapeSubgraphNodes(graph)) { LOGS(*session_logger_, WARNING) << "This model has shape massaging nodes that will execute on CPU. " - << "Use the CUDA Graph feature with caution. " + << "Use the CUDA/HIP Graph feature with caution. 
" << "As long as the intermediate shapes produced in the model " - << "using the representative input used to capture the CUDA graph, " + << "using the representative input used to capture the CUDA/HIP graph, " << "will match the shapes produced in the model for other inputs " << "of the same shape as the representative input (common case), " - << "it is safe to use the CUDA Graph feature."; + << "it is safe to use the CUDA/HIP Graph feature."; } } else { // Following code path is for TRT EP currently. @@ -1787,7 +1799,7 @@ common::Status InferenceSession::Initialize() { } } - LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user."; + LOGS(*session_logger_, INFO) << "This session will use the CUDA/HIP Graph feature as requested by the user."; cached_execution_provider_for_graph_replay_.SetExecutionProvider(target_ep); break; // Make sure only one ep can run CUDA graph. } @@ -2477,7 +2489,9 @@ Status InferenceSession::Run(const RunOptions& run_options, // As N+1 inference runs (N for memory allocation and 1 for graph capturing) // are needed before replaying the captured graph, here run N inference runs recursively until graph captured, // so that users just need one session run to capture the graph. - // N is defined in min_num_runs_before_cuda_graph_capture_ for CUDA EP, and the value could be different for other EP. + // N is defined in min_num_runs_before_cuda_graph_capture_ for CUDA EP, + // N is defined in min_num_runs_before_hip_graph_capture_ for ROCM EP, + // and the value could be different for other EP. if (retval.IsOK() && cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled() && !cached_execution_provider_for_graph_replay_.IsGraphCaptured()) { LOGS(*session_logger_, INFO) << "Start another run for necessary memory allocation or graph capture."; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 3269c9f0f4e4b..3178c13d30eec 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -2380,6 +2380,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateROCMProviderOptions, _Outptr_ OrtROCMProvider options->has_user_compute_stream = 0; options->user_compute_stream = nullptr; options->default_memory_arena_cfg = nullptr; + options->enable_hip_graph = false; options->tunable_op_enable = 0; options->tunable_op_tuning_enable = 0; options->tunable_op_max_tuning_duration_ms = 0; diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 6ffe72f81bd24..8dad2c8e2d10d 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -43,6 +43,10 @@ #include #endif +#ifdef USE_ROCM +#include +#endif + // Once we use C++17 this could be replaced with std::size template constexpr size_t countof(T (&)[N]) { return N; } @@ -1762,6 +1766,27 @@ TEST(CApiTest, get_allocator_cuda) { } #endif +#ifdef USE_ROCM +TEST(CApiTest, get_allocator_rocm) { + Ort::SessionOptions session_options; + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(session_options, 0)); + Ort::Session session(*ort_env, NAMED_AND_ANON_DIM_PARAM_URI, session_options); + + Ort::MemoryInfo info_rocm("Hip", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); + Ort::Allocator rocm_allocator(session, info_rocm); + + auto allocator_info = rocm_allocator.GetInfo(); + ASSERT_TRUE(info_rocm == allocator_info); + void* p = rocm_allocator.Alloc(1024); + 
ASSERT_NE(p, nullptr); + rocm_allocator.Free(p); + + auto mem_allocation = rocm_allocator.GetAllocation(1024); + ASSERT_NE(nullptr, mem_allocation.get()); + ASSERT_EQ(1024U, mem_allocation.size()); +} +#endif + TEST(CApiTest, io_binding) { Ort::SessionOptions session_options; Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(session_options, 1)); @@ -1937,7 +1962,7 @@ TEST(CApiTest, io_binding_cuda) { } #endif -#if defined(USE_CUDA) || defined(USE_TENSORRT) +#if defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_ROCM) TEST(CApiTest, basic_cuda_graph) { const auto& api = Ort::GetApi(); Ort::SessionOptions session_options; @@ -1955,7 +1980,7 @@ TEST(CApiTest, basic_cuda_graph) { ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2( static_cast(session_options), rel_trt_options.get()) == nullptr); -#else +#elif defined(USE_CUDA) // Enable cuda graph in cuda provider option. OrtCUDAProviderOptionsV2* cuda_options = nullptr; ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); @@ -1968,34 +1993,55 @@ TEST(CApiTest, basic_cuda_graph) { ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_CUDA_V2( static_cast(session_options), rel_cuda_options.get()) == nullptr); +#elif defined(USE_ROCM) + // Enable hip graph in rocm provider option. + OrtROCMProviderOptions* rocm_options = nullptr; + ASSERT_TRUE(api.CreateROCMProviderOptions(&rocm_options) == nullptr); + std::unique_ptr + rel_rocm_options(rocm_options, api.ReleaseROCMProviderOptions); + std::vector keys{"enable_hip_graph"}; + std::vector values{"1"}; + ASSERT_TRUE(api.UpdateROCMProviderOptions(rel_rocm_options.get(), keys.data(), values.data(), 1) == nullptr); + + ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_ROCM( + static_cast(session_options), + rel_rocm_options.get()) == nullptr); #endif Ort::Session session(*ort_env, MODEL_URI, session_options); - Ort::MemoryInfo info_cuda("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); +#if defined(USE_ROCM) +// local hipify +#define cudaMemcpy hipMemcpy +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost + Ort::MemoryInfo info_mem("Hip", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); +#else + Ort::MemoryInfo info_mem("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); +#endif - Ort::Allocator cuda_allocator(session, info_cuda); - auto allocator_info = cuda_allocator.GetInfo(); - ASSERT_TRUE(info_cuda == allocator_info); + Ort::Allocator allocator(session, info_mem); + auto allocator_info = allocator.GetInfo(); + ASSERT_TRUE(info_mem == allocator_info); const std::array x_shape = {3, 2}; std::array x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; - auto input_data = cuda_allocator.GetAllocation(x_values.size() * sizeof(float)); + auto input_data = allocator.GetAllocation(x_values.size() * sizeof(float)); ASSERT_NE(input_data.get(), nullptr); - cudaMemcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); + (void)cudaMemcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); // Create an OrtValue tensor backed by data on CUDA memory - Ort::Value bound_x = Ort::Value::CreateTensor(info_cuda, reinterpret_cast(input_data.get()), x_values.size(), + Ort::Value bound_x = Ort::Value::CreateTensor(info_mem, reinterpret_cast(input_data.get()), x_values.size(), x_shape.data(), x_shape.size()); const std::array expected_y_shape = {3, 2}; std::array expected_y = {1.0f, 4.0f, 
9.0f, 16.0f, 25.0f, 36.0f}; - auto output_data = cuda_allocator.GetAllocation(expected_y.size() * sizeof(float)); + auto output_data = allocator.GetAllocation(expected_y.size() * sizeof(float)); ASSERT_NE(output_data.get(), nullptr); // Create an OrtValue tensor backed by data on CUDA memory - Ort::Value bound_y = Ort::Value::CreateTensor(info_cuda, reinterpret_cast(output_data.get()), + Ort::Value bound_y = Ort::Value::CreateTensor(info_mem, reinterpret_cast(output_data.get()), expected_y.size(), expected_y_shape.data(), expected_y_shape.size()); // Create IoBinding for inputs and outputs. @@ -2008,31 +2054,37 @@ TEST(CApiTest, basic_cuda_graph) { // Check the values against the bound raw memory (needs copying from device to host first) std::array y_values; - cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); + (void)cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); ASSERT_THAT(y_values, ::testing::ContainerEq(expected_y)); // Replay the captured CUDA graph session.Run(Ort::RunOptions(), binding); - cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); + (void)cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); ASSERT_THAT(y_values, ::testing::ContainerEq(expected_y)); // Change the input and replay the CUDA graph again. x_values = {10.0f, 20.0f, 30.0f, 40.0f, 50.0f, 60.0f}; - cudaMemcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); + (void)cudaMemcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); binding.SynchronizeInputs(); session.Run(Ort::RunOptions(), binding); - cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); + (void)cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); expected_y = {10.0f, 40.0f, 90.0f, 160.0f, 250.0f, 360.0f}; ASSERT_THAT(y_values, ::testing::ContainerEq(expected_y)); // Clean up binding.ClearBoundInputs(); binding.ClearBoundOutputs(); +#if defined(USE_ROCM) +#undef cudaMemcpy +#undef cudaMemcpyHostToDevice +#undef cudaMemcpyDeviceToHost +#endif } -#ifndef REDUCED_OPS_BUILD // The following test uses some ops not supported in the reduced ops build +#ifndef REDUCED_OPS_BUILD +#if defined(USE_CUDA) || defined(USE_TENSORRT) TEST(CApiTest, cuda_graph_with_shape_nodes) { const auto& api = Ort::GetApi(); @@ -2053,10 +2105,34 @@ TEST(CApiTest, cuda_graph_with_shape_nodes) { // Successful loading of the ONNX model with shape nodes with cuda graph feature enabled Ort::Session session(*ort_env, TSTR("testdata/cuda_graph_with_shape_nodes.onnx"), session_options); } +#endif // defined(USE_CUDA) || defined(USE_TENSORRT) -#endif +#if defined(USE_ROCM) +TEST(CApiTest, hip_graph_with_shape_nodes) { + const auto& api = Ort::GetApi(); -#endif + // Enable hip graph in rocm provider option. 
+  OrtROCMProviderOptions* rocm_options = nullptr;
+  ASSERT_TRUE(api.CreateROCMProviderOptions(&rocm_options) == nullptr);
+  std::unique_ptr<OrtROCMProviderOptions, decltype(api.ReleaseROCMProviderOptions)>
+      rel_rocm_options(rocm_options, api.ReleaseROCMProviderOptions);
+  std::vector<const char*> keys{"enable_hip_graph"};
+  std::vector<const char*> values{"1"};
+  ASSERT_TRUE(api.UpdateROCMProviderOptions(rel_rocm_options.get(), keys.data(), values.data(), 1) == nullptr);
+
+  Ort::SessionOptions session_options;
+  ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_ROCM(
+                  static_cast<OrtSessionOptions*>(session_options),
+                  rel_rocm_options.get()) == nullptr);
+
+  // Successful loading of the ONNX model with shape nodes with hip graph feature enabled
+  Ort::Session session(*ort_env, TSTR("testdata/cuda_graph_with_shape_nodes.onnx"), session_options);
+}
+#endif  // defined(USE_ROCM)
+
+#endif  // REDUCED_OPS_BUILD
+
+#endif  // defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_ROCM)
 
 TEST(CApiTest, create_tensor) {
   const char* s[] = {"abc", "kmp"};
From 6ca7c1a933e57e0078d8d01eff3a1520098cfed1 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Mon, 22 Jan 2024 20:42:30 -0800
Subject: [PATCH 03/23] unet fusion for stable diffusion webui (#19227)

### Description
Update unet fusion for [stable diffusion webui extension](https://github.com/tianleiwu/Stable-Diffusion-WebUI-OnnxRuntime):
(1) Update fusion pattern to support fp16 unet model.
(2) Add progress bar.
(3) Use a cached map to speed up dtype or shape lookup in shape inference result.

### Motivation and Context

---
 .../tools/transformers/fusion_attention.py    |  14 +-
 .../transformers/fusion_attention_unet.py     | 166 ++++++++++++++++--
 .../tools/transformers/fusion_embedlayer.py   |  18 +-
 .../tools/transformers/fusion_gemmfastgelu.py |   2 +-
 .../tools/transformers/fusion_nhwc_conv.py    |  15 +-
 .../python/tools/transformers/fusion_shape.py |   8 +-
 .../python/tools/transformers/fusion_utils.py |  47 +++--
 .../python/tools/transformers/import_utils.py |  20 +++
 .../models/stable_diffusion/README.md         |   2 +-
 .../python/tools/transformers/onnx_model.py   |  98 ++++++++---
 .../tools/transformers/onnx_model_bert.py     |  16 +-
 .../tools/transformers/onnx_model_unet.py     |  71 +++++++-
 12 files changed, 395 insertions(+), 82 deletions(-)
 create mode 100644 onnxruntime/python/tools/transformers/import_utils.py

diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py
index d11cb91d98b0c..f48cabd25fc5c 100644
--- a/onnxruntime/python/tools/transformers/fusion_attention.py
+++ b/onnxruntime/python/tools/transformers/fusion_attention.py
@@ -129,6 +129,9 @@ def __init__(
         self.num_heads_warning = True
         self.hidden_size_warning = True
 
+        self.shape_infer = None
+        self.shape_infer_done = False
+
     def get_num_heads_and_hidden_size_from_concat(self, concat: NodeProto) -> Tuple[int, int]:
         """
         Detect num_heads and hidden_size from Concat node in the following subgraph:
@@ -202,12 +205,15 @@ def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]
         return num_heads, hidden_size
 
     def get_add_qk_str(self, add_qk: NodeProto):
-        shape_infer = self.model.infer_runtime_shape(update=True)
-        if shape_infer is None:
+        if not self.shape_infer_done:
+            self.shape_infer = self.model.infer_runtime_shape(update=True)
+            self.shape_infer_done = True
+
+        if self.shape_infer is None:
             return None
 
-        input_0_shape = shape_infer.get_edge_shape(add_qk.input[0])
-        input_1_shape = shape_infer.get_edge_shape(add_qk.input[1])
+        input_0_shape = self.shape_infer.get_edge_shape(add_qk.input[0])
+        input_1_shape = self.shape_infer.get_edge_shape(add_qk.input[1])
 
         if input_0_shape is None or input_1_shape is None:
             logger.debug(f"one of the inputs of {add_qk} is None")
diff --git a/onnxruntime/python/tools/transformers/fusion_attention_unet.py b/onnxruntime/python/tools/transformers/fusion_attention_unet.py
index 250ec5f3eb159..9a353e7e2d675 100644
--- a/onnxruntime/python/tools/transformers/fusion_attention_unet.py
+++ b/onnxruntime/python/tools/transformers/fusion_attention_unet.py
@@ -28,10 +28,19 @@ def __init__(
         enable_packed_qkv: bool,
         enable_packed_kv: bool,
     ):
-        super().__init__(model, "MultiHeadAttention" if is_cross_attention else "Attention", ["LayerNormalization"])
+        super().__init__(
+            model,
+            "Attention" if is_cross_attention and enable_packed_qkv else "MultiHeadAttention",
+            ["LayerNormalization"],
+        )
         self.hidden_size = hidden_size
         self.num_heads = num_heads
         self.is_cross_attention = is_cross_attention
+
+        # Note: packing Q/K/V or K/V weights into one tensor makes it harder to update initializers for LoRA.
+        # To support LoRA, it is better to use separate Q, K and V inputs in offline optimization,
+        # and the CUDA operator pre-packs those tensors to the preferred format based on available kernels.
+        # In this way, we can support LoRA and get optimal performance at the same time.
         self.enable_packed_qkv = enable_packed_qkv
         self.enable_packed_kv = enable_packed_kv
@@ -170,9 +179,7 @@ def create_attention_node(
             return None
 
         # Sometimes weights are stored in fp16
-        if q_weight.data_type == 10:
-            logger.debug("weights are in fp16. Please run fp16 conversion after optimization")
-            return None
+        float_type = q_weight.data_type
 
         qw = NumpyHelper.to_array(q_weight)
         kw = NumpyHelper.to_array(k_weight)
@@ -212,7 +219,7 @@ def create_attention_node(
             matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV")
             self.add_initializer(
                 name=matmul_node_name + "_weight",
-                data_type=TensorProto.FLOAT,
+                data_type=float_type,
                 dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
                 vals=qkv_weight,
             )
@@ -235,8 +242,11 @@ def create_attention_node(
             reshape_node = helper.make_node(
                 "Reshape",
-                inputs=[matmul_node_name + "_out", matmul_node_name + "_reshape_shape"],
-                outputs=[attention_node_name + "_input"],
+                inputs=[
+                    matmul_node_name + "_out",
+                    matmul_node_name + "_reshape_shape",
+                ],
+                outputs=[attention_node_name + "_qkv_input"],
                 name=matmul_node_name + "_reshape",
             )
             self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
@@ -251,7 +261,7 @@ def create_attention_node(
             self.add_initializer(
                 name=attention_node_name + "_qkv_weight",
-                data_type=TensorProto.FLOAT,
+                data_type=float_type,
                 dims=[qw_in_size, qkv_weight_dim],
                 vals=qkv_weight,
             )
@@ -280,7 +290,7 @@ def create_attention_node(
                 matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV")
                 self.add_initializer(
                     name=matmul_node_name + "_weight",
-                    data_type=TensorProto.FLOAT,
+                    data_type=float_type,
                     dims=[kv_weight.shape[0], kv_weight.shape[1]],
                     vals=kv_weight,
                 )
@@ -303,8 +313,11 @@ def create_attention_node(
                 reshape_node = helper.make_node(
                     "Reshape",
-                    inputs=[matmul_node_name + "_out", matmul_node_name + "_reshape_shape"],
-                    outputs=[k_matmul.output[0]],
+                    inputs=[
+                        matmul_node_name + "_out",
+                        matmul_node_name + "_reshape_shape",
+                    ],
+                    outputs=[attention_node_name + "_kv_input"],
                     name=matmul_node_name + "_reshape",
                 )
                 self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
@@ -317,7 +330,7 @@ def create_attention_node(
             self.add_initializer(
                 name=attention_node_name + "_qkv_bias",
-
data_type=TensorProto.FLOAT, + data_type=float_type, dims=[qkv_bias_dim], vals=qkv_bias, ) @@ -330,7 +343,7 @@ def create_attention_node( attention_node_name + "_qkv_bias", ] else: - attention_inputs = [attention_node_name + "_input"] + attention_inputs = [attention_node_name + "_qkv_input"] else: if not self.enable_packed_kv: attention_inputs = [ @@ -342,7 +355,7 @@ def create_attention_node( else: attention_inputs = [ q_matmul.output[0], - k_matmul.output[0], + attention_node_name + "_kv_input", ] attention_node = helper.make_node( @@ -839,6 +852,9 @@ def create_attention_node_lora( return attention_node def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + if self.fuse_a1111_fp16(normalize_node, input_name_to_nodes, output_name_to_node): + return + node_before_layernorm = self.model.match_parent(normalize_node, "Add", 0) # In SD 1.5, for self attention, LayerNorm has parent Reshape @@ -1168,3 +1184,125 @@ def match_lora_path( return (lora_mul_node, lora_matmul_1_node) return None + + def fuse_a1111_fp16(self, normalize_node, input_name_to_nodes, output_name_to_node): + """Fuse attention of fp16 UNet exported in A1111 (stable diffusion webui) extension""" + entry_path = self.model.match_parent_path(normalize_node, ["Cast", "Add"], [0, 0]) + if entry_path is None: + entry_path = self.model.match_parent_path(normalize_node, ["Cast", "Reshape"], [0, 0]) + if entry_path is None: + return False + _cast, node_before_layernorm = entry_path + + root_input = node_before_layernorm.output[0] + + children_nodes = input_name_to_nodes[root_input] + skip_add = None + for node in children_nodes: + if node.op_type == "Add": # SkipLayerNormalization fusion is not applied yet + skip_add = node + break + if skip_add is None: + return False + + match_qkv = self.match_qkv_a1111(root_input, skip_add) + if match_qkv is None: + return False + + ( + reshape_qkv, + transpose_qkv, + reshape_q, + matmul_q, + matmul_k, + matmul_v, + ) = match_qkv + + cast_q = self.model.match_parent(matmul_q, "Cast", 0) + cast_k = self.model.match_parent(matmul_k, "Cast", 0) + cast_v = self.model.match_parent(matmul_v, "Cast", 0) + if not ( + cast_q is not None + and cast_k is not None + and (cast_q == cast_k if not self.is_cross_attention else cast_q != cast_k) + and cast_k == cast_v + ): + return False + + if cast_q.input[0] != normalize_node.output[0]: + return False + + attention_last_node = reshape_qkv + + q_num_heads = self.get_num_heads(reshape_q, True) or self.get_num_heads(reshape_q, False) + if q_num_heads <= 0: + logger.debug("fuse_attention: failed to detect num_heads") + return False + + q_hidden_size = self.get_hidden_size(normalize_node) + + # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads + new_node = self.create_attention_node( + matmul_q, + matmul_k, + matmul_v, + q_num_heads, + q_hidden_size, + input=matmul_q.input[0], + output=attention_last_node.output[0], + ) + if new_node is None: + return False + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([attention_last_node, transpose_qkv]) + + # Use prune graph to remove nodes since they are shared by all attention nodes. 
+ self.prune_graph = True + return True + + def match_qkv_a1111(self, root_input, skip_add): + """Match Q, K and V paths exported by A1111 (stable diffusion webui) extension""" + another_input = 1 if skip_add.input[0] == root_input else 0 + qkv_nodes = self.model.match_parent_path( + skip_add, + ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "Einsum"], + [another_input, None, None, 0, 0, 0], + ) + + if qkv_nodes is None: + return None + + (_, _, reshape_qkv, transpose_qkv, reshape_einsum, einsum_qkv) = qkv_nodes + + v_nodes = self.model.match_parent_path(einsum_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0]) + if v_nodes is None: + logger.debug("fuse_attention: failed to match v path") + return None + (_, _, _, matmul_v) = v_nodes + + qk_nodes = self.model.match_parent_path( + einsum_qkv, ["Cast", "Cast", "Softmax", "Mul", "Einsum"], [0, 0, 0, 0, None] + ) + if qk_nodes is not None: + (_, _, _softmax_qk, _, einsum_qk) = qk_nodes + else: + logger.debug("fuse_attention: failed to match qk path") + return None + + q_nodes = self.model.match_parent_path(einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0]) + if q_nodes is None: + logger.debug("fuse_attention: failed to match q path") + return None + (_, _transpose_q, reshape_q, matmul_q) = q_nodes + + k_nodes = self.model.match_parent_path(einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0]) + if k_nodes is None: + logger.debug("fuse_attention: failed to match k path") + return None + + (_, _, _, matmul_k) = k_nodes + + return reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v diff --git a/onnxruntime/python/tools/transformers/fusion_embedlayer.py b/onnxruntime/python/tools/transformers/fusion_embedlayer.py index bc38399e3cce5..42156d9123383 100644 --- a/onnxruntime/python/tools/transformers/fusion_embedlayer.py +++ b/onnxruntime/python/tools/transformers/fusion_embedlayer.py @@ -28,7 +28,9 @@ def __init__(self, model: OnnxModel, description: str = "no mask"): description, ) self.utils = FusionUtils(model) - self.shape_infer_helper = self.model.infer_runtime_shape({}, update=True) + self.shape_infer = None + self.shape_infer_done = False + # The following will be reset in each fuse call of FusionEmbedLayerNormalization self.attention = None self.embed_node = None @@ -329,9 +331,13 @@ def check_embedding(self, word_embedding_gather, segment_embedding_gather, posit segment_ids = segment_embedding_gather.input[1] if segment_embedding_gather else None position_ids = position_embedding_gather.input[1] - if self.shape_infer_helper is not None: - input_ids_shape = self.shape_infer_helper.get_edge_shape(input_ids) - position_ids_shape = self.shape_infer_helper.get_edge_shape(position_ids) + if not self.shape_infer_done: + self.shape_infer = self.model.infer_runtime_shape(update=True) + self.shape_infer_done = True + + if self.shape_infer is not None: + input_ids_shape = self.shape_infer.get_edge_shape(input_ids) + position_ids_shape = self.shape_infer.get_edge_shape(position_ids) assert input_ids_shape and position_ids_shape if not ( len(input_ids_shape) == 2 @@ -345,11 +351,11 @@ def check_embedding(self, word_embedding_gather, segment_embedding_gather, posit ) return False - if segment_ids and not self.shape_infer_helper.compare_shape(input_ids, segment_ids): + if segment_ids and not self.shape_infer.compare_shape(input_ids, segment_ids): logger.info( "Cannot fuse EmbedLayerNormalization: input_ids and segment_ids does not have same shape: {} != {}".format( 
input_ids_shape, - self.shape_infer_helper.get_edge_shape(segment_ids), + self.shape_infer.get_edge_shape(segment_ids), ) ) return False diff --git a/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py b/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py index f1d803a3cc082..4d9913f427b37 100644 --- a/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py +++ b/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py @@ -32,7 +32,7 @@ def get_dimensions(self, input_name: str) -> Union[int, None]: return self.get_dimensions_from_tensor_proto(graph_input) if not self.shape_infer_done: - self.shape_infer = self.model.infer_runtime_shape({}, update=True) + self.shape_infer = self.model.infer_runtime_shape(update=True) self.shape_infer_done = True if self.shape_infer is not None: diff --git a/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py b/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py index 141ebb1f95a11..5233fdf272fbd 100644 --- a/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py +++ b/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py @@ -7,7 +7,8 @@ from typing import List from fusion_base import Fusion -from onnx import TensorProto, helper, numpy_helper +from fusion_utils import FusionUtils +from onnx import helper, numpy_helper from onnx_model import OnnxModel logger = getLogger(__name__) @@ -19,6 +20,7 @@ class FusionNhwcConv(Fusion): def __init__(self, model: OnnxModel, update_weight=False): super().__init__(model, "NhwcConv", ["Conv"], "NhwcConv") self.update_weight = update_weight + self.fusion_utils = FusionUtils(model) def create_transpose_node(self, input_name: str, perm: List[int], output_name=None): """Append a Transpose node after an input""" @@ -49,6 +51,15 @@ def fuse(self, conv, input_name_to_nodes, output_name_to_node): if len(weight.shape) != 4: return + dtype = self.model.get_dtype(nhwc_conv_input) + if not (dtype is not None and weight_tensor.data_type == dtype): + cast_node = self.fusion_utils.add_cast_node( + input_name=nhwc_conv_input, + to_type=weight_tensor.data_type, + output_name_to_node=output_name_to_node, + ) + nhwc_conv_input = cast_node.output[0] + if self.update_weight: # Transpose weights from NCHW to NHWC weight = weight.transpose(0, 2, 3, 1) @@ -56,7 +67,7 @@ def fuse(self, conv, input_name_to_nodes, output_name_to_node): weight_name = node_name + "_weight_NHWC" self.add_initializer( name=weight_name, - data_type=TensorProto.FLOAT, + data_type=weight_tensor.data_type, dims=list(weight.shape), vals=weight, ) diff --git a/onnxruntime/python/tools/transformers/fusion_shape.py b/onnxruntime/python/tools/transformers/fusion_shape.py index bc32d78eda66c..dfa77fc7d0221 100644 --- a/onnxruntime/python/tools/transformers/fusion_shape.py +++ b/onnxruntime/python/tools/transformers/fusion_shape.py @@ -29,12 +29,12 @@ def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[i return None def get_dimensions(self, input_name: str) -> Union[int, None]: - graph_input = self.model.find_graph_input(input_name) - if graph_input: - return self.get_dimensions_from_tensor_proto(graph_input) + shape = self.model.get_shape(input_name) + if shape is not None: + return len(shape) if not self.shape_infer_done: - self.shape_infer = self.model.infer_runtime_shape({}, update=True) + self.shape_infer = self.model.infer_runtime_shape(update=True) self.shape_infer_done = True if self.shape_infer is not None: diff --git a/onnxruntime/python/tools/transformers/fusion_utils.py 
b/onnxruntime/python/tools/transformers/fusion_utils.py
index afc968fab46c1..726c587ff7043 100644
--- a/onnxruntime/python/tools/transformers/fusion_utils.py
+++ b/onnxruntime/python/tools/transformers/fusion_utils.py
@@ -3,7 +3,7 @@
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
 from logging import getLogger
-from typing import Tuple
+from typing import Optional, Tuple
 
 import numpy
 from numpy import array_equal, ndarray
@@ -29,17 +29,7 @@ def cast_graph_input_to_int32(self, input_name: str) -> Tuple[bool, str]:
         return False, input_name
 
     def cast_input(self, input_name: str, target_type="int32"):
-        cast_output = input_name + "_" + target_type
-
-        # Avoid consequent Cast nodes.
-        inputs = [input_name]
-        output_name_to_node = self.model.output_name_to_node()
-        if input_name in output_name_to_node:
-            parent_node = output_name_to_node[input_name]
-            if parent_node and parent_node.op_type == "Cast":
-                inputs = [parent_node.input[0]]
-
-        cast_node = helper.make_node("Cast", inputs=inputs, outputs=[cast_output])
+        output_name = input_name + "_" + target_type
 
         if target_type == "int32":
             to_type = int(TensorProto.INT32)
@@ -50,10 +40,36 @@ def cast_input(self, input_name: str, target_type="int32"):
         else:
             raise ValueError("Invalid target_type: {target_type}")
 
+        cast_node = self.add_cast_node(input_name, to_type, output_name)
+
+        return output_name, cast_node
+
+    def add_cast_node(
+        self,
+        input_name: str,
+        to_type: int,
+        output_name: Optional[str] = None,
+        output_name_to_node=None,
+        graph_name: Optional[str] = None,
+    ):
+        if output_name is None:
+            output_name = input_name + f"_cast_to_{to_type}"
+
+        # Avoid consecutive Cast nodes.
+        inputs = [input_name]
+        if output_name_to_node is None:
+            output_name_to_node = self.model.output_name_to_node()
+        if input_name in output_name_to_node:
+            parent_node = output_name_to_node[input_name]
+            if parent_node and parent_node.op_type == "Cast":
+                inputs = [parent_node.input[0]]
+
+        cast_node = helper.make_node("Cast", inputs=inputs, outputs=[output_name])
+
         cast_node.attribute.extend([helper.make_attribute("to", to_type)])
-        self.model.add_node(cast_node)
+        self.model.add_node(cast_node, graph_name=graph_name)
 
-        return cast_output, cast_node
+        return cast_node
 
     def cast_input_to_int32(self, input_name: str):
         return self.cast_input(input_name, "int32")
@@ -224,9 +240,10 @@ def check_node_input_value(self, node, input_index: int, expected_value):
     def remove_identity_nodes(self):
         """Remove Identity nodes, except those right before graph output."""
         nodes_to_remove = []
+        graph_output_names = self.model.get_graphs_output_names()
         for node in self.model.nodes():
             if node.op_type == "Identity":
-                if node.output[0] not in self.model.get_graphs_output_names():
+                if node.output[0] not in graph_output_names:
                     self.model.replace_input_of_all_nodes(node.output[0], node.input[0])
                     nodes_to_remove.append(node)
diff --git a/onnxruntime/python/tools/transformers/import_utils.py b/onnxruntime/python/tools/transformers/import_utils.py
new file mode 100644
index 0000000000000..9755a26b7b004
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/import_utils.py
@@ -0,0 +1,20 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# -------------------------------------------------------------------------- +import importlib.metadata +import importlib.util + + +def is_installed(package): + try: + dist = importlib.metadata.distribution(package) + except importlib.metadata.PackageNotFoundError: + try: + spec = importlib.util.find_spec(package) + except ModuleNotFoundError: + return False + + return spec is not None + + return dist is not None diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index b10c10c87ee57..8607485bc265b 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -51,7 +51,7 @@ sh build.sh --config Release --build_shared_lib --parallel --use_cuda --cuda_ve --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 \ --allow_running_as_root python3 -m pip install --upgrade pip -python3 -m pip install build/Linux/Release/dist/onnxruntime_gpu-1.17.0-cp310-cp310-linux_x86_64.whl --force-reinstall +python3 -m pip install build/Linux/Release/dist/onnxruntime_gpu-*.whl --force-reinstall ``` If the GPU is not A100, change `CMAKE_CUDA_ARCHITECTURES=80` in the command line according to the GPU compute capacity (like 89 for RTX 4090, or 86 for RTX 3090). diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 37b39c91b5c15..9d1066b6e372b 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -40,6 +40,12 @@ def initialize(self, model): self.enable_shape_infer: bool = True self.all_graphs: Optional[List[GraphProto]] = None + # Cache of shape and data type from onnx graph to speed up optimization. + # Be careful that fusion shall not reuse node output name for different shape/type (in adding/removing nodes) + # Note that these do not cache the symbolic shape inference result. 
+ self._dtype_dict: Optional[Dict[str, int]] = None + self._shape_dict: Optional[Dict[str, List]] = None + def disable_shape_inference(self): self.enable_shape_infer = False @@ -519,20 +525,60 @@ def tensor_shape_to_list(self, tensor_type): shape_list.append("?") # shall not happen return shape_list - def get_dtype(self, input_or_output: str): - """Try get data type given a name (could be initializer, graph input or output).""" - tensor_type_map = {obj.name: obj.type for obj in self.model.graph.value_info} + def get_dtype(self, name: str, symbolic_shape_helper: Optional[SymbolicShapeInferenceHelper] = None): + """Try get data type given a name (could be initializer, input or output of graph or node).""" + + if self._dtype_dict is None: + self._dtype_dict = {} + for value_info in itertools.chain( + self.model.graph.value_info, + self.model.graph.input, + self.model.graph.output, + ): + self._dtype_dict[value_info.name] = value_info.type.tensor_type.elem_type + + for initializer in self.model.graph.initializer: + if initializer.name not in self._dtype_dict: + self._dtype_dict[initializer.name] = initializer.data_type - if input_or_output in tensor_type_map: - return tensor_type_map[input_or_output].tensor_type.elem_type + if name in self._dtype_dict: + return self._dtype_dict[name] - graph_input = self.find_graph_input(input_or_output) - if graph_input: - return graph_input.type.tensor_type.elem_type + if symbolic_shape_helper is not None and name in symbolic_shape_helper.known_vi_: + value_info = symbolic_shape_helper.known_vi_[name] + return value_info.type.tensor_type.elem_type + + return None - graph_output = self.find_graph_output(input_or_output) - if graph_output: - return graph_output.type.tensor_type.elem_type + def get_shape(self, name: str, symbolic_shape_helper: Optional[SymbolicShapeInferenceHelper] = None): + """Try get shape given a name (could be initializer, input or output of graph or node).""" + + if self._shape_dict is None: + self._shape_dict = {} + for value_info in itertools.chain( + self.model.graph.value_info, + self.model.graph.input, + self.model.graph.output, + ): + if value_info.type.tensor_type.HasField("shape"): + shape = [] + for dim in value_info.type.tensor_type.shape.dim: + if dim.dim_param: + shape.append(dim.dim_param) + else: + shape.append(dim.dim_value) + self._shape_dict[value_info.name] = shape + + for initializer in self.model.graph.initializer: + if initializer.name not in self._shape_dict: + self._shape_dict[initializer.name] = initializer.dims + + if name in self._shape_dict: + return self._shape_dict[name] + + if symbolic_shape_helper is not None and name in symbolic_shape_helper.known_vi_: + value_info = symbolic_shape_helper.known_vi_[name] + return value_info.type.tensor_type.elem_type return None @@ -566,23 +612,14 @@ def remove_cascaded_cast_nodes(self): def remove_useless_cast_nodes(self): """Remove cast nodes that are not needed: input and output has same data type.""" shape_infer = self.infer_runtime_shape(update=True) - if shape_infer is None: - logger.info("Skip removing useless cast nodes since shape inference failed.") - return - - def get_data_type(input_or_output_name): - dtype = self.get_dtype(input_or_output_name) - if dtype: - return dtype - if shape_infer.known_vi_[input_or_output_name].type.tensor_type.HasField("elem_type"): - return shape_infer.known_vi_[input_or_output_name].type.tensor_type.elem_type - return None + if self.enable_shape_infer and shape_infer is None: + logger.warning("shape inference failed which might 
impact useless cast node detection.") nodes_to_remove = [] for node in self.nodes(): if node.op_type == "Cast": - input_dtype = get_data_type(node.input[0]) - output_dtype = get_data_type(node.output[0]) + input_dtype = self.get_dtype(node.input[0], shape_infer) + output_dtype = self.get_dtype(node.output[0], shape_infer) if input_dtype and input_dtype == output_dtype: nodes_to_remove.append(node) @@ -601,7 +638,10 @@ def get_data_type(input_or_output_name): self.replace_input_of_all_nodes(node.output[0], node.input[0]) self.remove_node(node) - logger.info("Removed %d Cast nodes with output type same as input", len(nodes_to_remove)) + logger.info( + "Removed %d Cast nodes with output type same as input", + len(nodes_to_remove), + ) def convert_model_float32_to_float16(self, cast_input_output=True): logger.warning( @@ -1214,7 +1254,10 @@ def remove_duplicated_initializer(self, cache: Optional[dict] = None): continue for j in range(i + 1, initializer_count): if OnnxModel.has_same_value( - self.model.graph.initializer[i], self.model.graph.initializer[j], cache, cache + self.model.graph.initializer[i], + self.model.graph.initializer[j], + cache, + cache, ): same[j] = i @@ -1223,7 +1266,8 @@ def remove_duplicated_initializer(self, cache: Optional[dict] = None): if same[i] >= 0: count += 1 self.replace_input_of_all_nodes( - self.model.graph.initializer[i].name, self.model.graph.initializer[same[i]].name + self.model.graph.initializer[i].name, + self.model.graph.initializer[same[i]].name, ) if count > 0: diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index 51deb67ce5bf3..431e64509e3cc 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -126,7 +126,8 @@ def fuse_rotary_embeddings(self): # Remove non-MS domain functions rot_emb_nodes = list( filter( - lambda node: node.op_type == "RotaryEmbedding" and node.domain != "com.microsoft", self.model.graph.node + lambda node: node.op_type == "RotaryEmbedding" and node.domain != "com.microsoft", + self.model.graph.node, ) ) non_ms_domains_to_keep = set(map(lambda node: node.domain, rot_emb_nodes)) @@ -350,7 +351,11 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo self.attention_mask.set_mask_format(options.attention_mask_format) if options.use_multi_head_attention and not isinstance(self.attention_fusion, FusionBartAttention): self.attention_fusion = FusionAttention( - self, self.hidden_size, self.num_heads, self.attention_mask, options.use_multi_head_attention + self, + self.hidden_size, + self.num_heads, + self.attention_mask, + options.use_multi_head_attention, ) if (options is None) or options.enable_attention: @@ -415,7 +420,12 @@ def get_fused_operator_statistics(self): "SkipSimplifiedLayerNormalization", "RotaryEmbedding", ] - q_ops = ["QOrderedAttention", "QOrderedGelu", "QOrderedLayerNormalization", "QOrderedMatMul"] + q_ops = [ + "QOrderedAttention", + "QOrderedGelu", + "QOrderedLayerNormalization", + "QOrderedMatMul", + ] for op in ops + q_ops: nodes = self.get_nodes_by_op_type(op) op_count[op] = len(nodes) diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py index 4d15b9288e7b6..01298b3576eb1 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_unet.py +++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. 
# -------------------------------------------------------------------------- -from logging import getLogger +import logging from typing import Optional from fusion_attention_unet import FusionAttentionUnet @@ -14,11 +14,12 @@ from fusion_options import FusionOptions from fusion_skip_group_norm import FusionSkipGroupNorm from fusion_transpose import FusionInsertTranspose, FusionTranspose +from import_utils import is_installed from onnx import ModelProto from onnx_model import OnnxModel from onnx_model_bert import BertOnnxModel -logger = getLogger(__name__) +logger = logging.getLogger(__name__) class UnetOnnxModel(BertOnnxModel): @@ -94,14 +95,24 @@ def fuse_multi_head_attention(self, options: Optional[FusionOptions] = None): # Self Attention enable_packed_qkv = (options is None) or options.enable_packed_qkv self_attention_fusion = FusionAttentionUnet( - self, self.hidden_size, self.num_heads, False, enable_packed_qkv, False + self, + self.hidden_size, + self.num_heads, + is_cross_attention=False, + enable_packed_qkv=enable_packed_qkv, + enable_packed_kv=False, ) self_attention_fusion.apply() # Cross Attention enable_packed_kv = (options is None) or options.enable_packed_kv cross_attention_fusion = FusionAttentionUnet( - self, self.hidden_size, self.num_heads, True, False, enable_packed_kv + self, + self.hidden_size, + self.num_heads, + is_cross_attention=True, + enable_packed_qkv=False, + enable_packed_kv=enable_packed_kv, ) cross_attention_fusion.apply() @@ -110,23 +121,48 @@ def fuse_bias_add(self): fusion.apply() def optimize(self, options: Optional[FusionOptions] = None): + if is_installed("tqdm"): + import tqdm + from tqdm.contrib.logging import logging_redirect_tqdm + + with logging_redirect_tqdm(): + steps = 18 + progress_bar = tqdm.tqdm(range(0, steps), initial=0, desc="fusion") + self._optimize(options, progress_bar) + else: + logger.info("tqdm is not installed. Run optimization without progress bar") + self._optimize(options, None) + + def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None): if (options is not None) and not options.enable_shape_inference: self.disable_shape_inference() self.utils.remove_identity_nodes() + if progress_bar: + progress_bar.update(1) # Remove cast nodes that having same data type of input and output based on symbolic shape inference. 
self.utils.remove_useless_cast_nodes() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_layer_norm: self.fuse_layer_norm() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_gelu: self.fuse_gelu() + if progress_bar: + progress_bar.update(1) self.preprocess() + if progress_bar: + progress_bar.update(1) self.fuse_reshape() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_group_norm: channels_last = (options is None) or options.group_norm_channels_last @@ -135,42 +171,66 @@ def optimize(self, options: Optional[FusionOptions] = None): insert_transpose_fusion = FusionInsertTranspose(self) insert_transpose_fusion.apply() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_bias_splitgelu: bias_split_gelu_fusion = FusionBiasSplitGelu(self) bias_split_gelu_fusion.apply() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_attention: + # self.save_model_to_file("before_mha.onnx") self.fuse_multi_head_attention(options) + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_skip_layer_norm: self.fuse_skip_layer_norm() + if progress_bar: + progress_bar.update(1) self.fuse_shape() + if progress_bar: + progress_bar.update(1) # Remove reshape nodes that having same shape of input and output based on symbolic shape inference. self.utils.remove_useless_reshape_nodes() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_skip_group_norm: skip_group_norm_fusion = FusionSkipGroupNorm(self) skip_group_norm_fusion.apply() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_bias_skip_layer_norm: # Fuse SkipLayerNormalization and Add Bias before it. 
self.fuse_add_bias_skip_layer_norm() + if progress_bar: + progress_bar.update(1) if options is not None and options.enable_gelu_approximation: self.gelu_approximation() + if progress_bar: + progress_bar.update(1) if options is None or options.enable_nhwc_conv: self.convert_conv_to_nhwc() - self.merge_adjacent_transpose() + if progress_bar: + progress_bar.update(1) if options is not None and options.enable_bias_add: self.fuse_bias_add() + if progress_bar: + progress_bar.update(1) self.postprocess() + if progress_bar: + progress_bar.update(1) logger.info(f"opset version: {self.get_opset_version()}") @@ -190,6 +250,7 @@ def get_fused_operator_statistics(self): "NhwcConv", "BiasAdd", ] + for op in ops: nodes = self.get_nodes_by_op_type(op) op_count[op] = len(nodes) From 61610ff9862ad834f153ed3e70ba526dac86ae7c Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Wed, 24 Jan 2024 00:25:05 +0800 Subject: [PATCH 04/23] [js/webgpu] Add FusedConv clip test case (#18900) Bug: https://github.com/microsoft/onnxruntime/issues/18899 --- js/web/test/data/ops/fused-conv.jsonc | 34 +++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc index 812e9d7c2def0..ad1c0a72c11d3 100644 --- a/js/web/test/data/ops/fused-conv.jsonc +++ b/js/web/test/data/ops/fused-conv.jsonc @@ -108,5 +108,39 @@ ] } ] + }, + { + "name": "fused conv with clip", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "Clip", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [400.0, 600.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [400, 470, 600, 600], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] } ] From 0ea48fc73ec6bdbb8af2010483a61823fcf1a613 Mon Sep 17 00:00:00 2001 From: Heflin Stephen Raj Date: Tue, 23 Jan 2024 23:40:54 +0530 Subject: [PATCH 05/23] Modified the condition to load the optimiser model (#18891) --- java/src/main/native/ai_onnxruntime_OrtTrainingSession.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c index 9f7b8d3a3dcfc..464234c34798a 100644 --- a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c @@ -66,7 +66,7 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtTrainingSession_createTrainingSes } } wchar_t* optimizerStr = NULL; - if (optimizerPath == NULL) { + if (optimizerPath != NULL) { optimizerStr = copyAndPad(jniEnv, optimizerPath); if (optimizerStr == NULL) { // exception has been thrown in Java, go to cleanup and return null. From 54871a27736cf54cbda9c4f09bb27e931de7334e Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Wed, 24 Jan 2024 02:49:24 +0800 Subject: [PATCH 06/23] Replace T4 to A10 in Linux GPU workflow (#19205) ### Description 1. Update Linux GPU machine from T4 to A10, sm=8.6 2. update the tolerance ### Motivation and Context 1. Free more T4 and test with higher compute capability. 2. ORT enables TF32 in GEMM for A10/100. 
TF32 will cause precision loss and fail this test ``` 2024-01-19T13:27:18.8302842Z [ RUN ] ModelTests/ModelTest.Run/cuda__models_zoo_opset12_SSD_ssd12 2024-01-19T13:27:25.8438153Z /onnxruntime_src/onnxruntime/test/providers/cpu/model_tests.cc:347: Failure 2024-01-19T13:27:25.8438641Z Expected equality of these values: 2024-01-19T13:27:25.8438841Z COMPARE_RESULT::SUCCESS 2024-01-19T13:27:25.8439276Z Which is: 4-byte object <00-00 00-00> 2024-01-19T13:27:25.8439464Z ret.first 2024-01-19T13:27:25.8445514Z Which is: 4-byte object <01-00 00-00> 2024-01-19T13:27:25.8445962Z expected 0.145984 (3e157cc1), got 0.975133 (3f79a24b), diff: 0.829149, tol=0.0114598 idx=375. 20 of 388 differ 2024-01-19T13:27:25.8446198Z 2024-01-19T13:27:25.8555736Z [ FAILED ] ModelTests/ModelTest.Run/cuda__models_zoo_opset12_SSD_ssd12, where GetParam() = "cuda_../models/zoo/opset12/SSD/ssd-12.onnx" (7025 ms) 2024-01-19T13:27:25.8556077Z [ RUN ] ModelTests/ModelTest.Run/cuda__models_zoo_opset12_YOLOv312_yolov312 2024-01-19T13:27:29.3174318Z /onnxruntime_src/onnxruntime/test/providers/cpu/model_tests.cc:347: Failure 2024-01-19T13:27:29.3175144Z Expected equality of these values: 2024-01-19T13:27:29.3175389Z COMPARE_RESULT::SUCCESS 2024-01-19T13:27:29.3175812Z Which is: 4-byte object <00-00 00-00> 2024-01-19T13:27:29.3176080Z ret.first 2024-01-19T13:27:29.3176322Z Which is: 4-byte object <01-00 00-00> 2024-01-19T13:27:29.3178431Z expected 4.34958 (408b2fb8), got 4.51324 (40906c80), diff: 0.16367, tol=0.0534958 idx=9929. 22 of 42588 differ ``` 3. Some other tests like SSD throw other exceptions, so skip them ``` 2024-01-22T09:07:40.8446910Z [ RUN ] ModelTests/ModelTest.Run/cuda__models_zoo_opset12_SSD_ssd12 2024-01-22T09:07:51.5587571Z /onnxruntime_src/onnxruntime/test/providers/cpu/model_tests.cc:358: Failure 2024-01-22T09:07:51.5588512Z Expected equality of these values: 2024-01-22T09:07:51.5588870Z COMPARE_RESULT::SUCCESS 2024-01-22T09:07:51.5589467Z Which is: 4-byte object <00-00 00-00> 2024-01-22T09:07:51.5589953Z ret.first 2024-01-22T09:07:51.5590462Z Which is: 4-byte object <01-00 00-00> 2024-01-22T09:07:51.5590841Z expected 1, got 63 ``` --- .../test/global_thread_pools/test_inference.cc | 8 +++++++- onnxruntime/test/providers/cpu/model_tests.cc | 17 +++++++++++++++++ .../providers/cuda/nhwc/conv_transpose_test.cc | 6 +++++- .../azure-pipelines/linux-gpu-ci-pipeline.yml | 4 ++-- 4 files changed, 31 insertions(+), 4 deletions(-) diff --git a/onnxruntime/test/global_thread_pools/test_inference.cc b/onnxruntime/test/global_thread_pools/test_inference.cc index 4772e7de2bdd7..f553682975f11 100644 --- a/onnxruntime/test/global_thread_pools/test_inference.cc +++ b/onnxruntime/test/global_thread_pools/test_inference.cc @@ -55,9 +55,15 @@ static void RunSession(OrtAllocator& allocator, Ort::Session& session_object, // size_t total_len = type_info.GetElementCount(); ASSERT_EQ(values_y.size(), static_cast(5)); +// test inference is using onnxruntime_shared_lib_test_LIBS, so HasCudaEnvironment(800) isn't available +#ifdef USE_CUDA + const float tolerance = 1e-5f; +#else + const float tolerance = 1e-6f; +#endif OutT* f = output_tensor->GetTensorMutableData(); for (size_t i = 0; i != static_cast(5); ++i) { - ASSERT_NEAR(values_y[i], f[i], 1e-6f); + ASSERT_NEAR(values_y[i], f[i], tolerance); } } diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index 859e082716760..8128c170c5211 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++
b/onnxruntime/test/providers/cpu/model_tests.cc @@ -39,6 +39,8 @@ #include "core/providers/armnn/armnn_provider_factory.h" #endif +#include "test/common/cuda_op_test_utils.h" + // test infrastructure #include "test/onnx/testenv.h" #include "test/onnx/TestCase.h" @@ -94,6 +96,21 @@ TEST_P(ModelTest, Run) { std::unique_ptr model_info = std::make_unique(model_path.c_str()); +#if defined(__linux__) + // ORT enables TF32 in GEMM for A100. TF32 will cause precision loss and fail this test. + if (HasCudaEnvironment(800) && provider_name == "cuda") { + per_sample_tolerance = 1e-1; + if (model_path.find(ORT_TSTR("SSD")) > 0 || + model_path.find(ORT_TSTR("ssd")) > 0 || + model_path.find(ORT_TSTR("yolov3")) > 0 || + model_path.find(ORT_TSTR("mask_rcnn")) > 0 || + model_path.find(ORT_TSTR("FNS")) > 0) { + SkipTest("Skipping SSD test for big tolerance failure or other errors"); + return; + } + } +#endif + if (model_info->HasDomain(ONNX_NAMESPACE::AI_ONNX_TRAINING_DOMAIN) || model_info->HasDomain(ONNX_NAMESPACE::AI_ONNX_PREVIEW_TRAINING_DOMAIN)) { SkipTest("it has the training domain. No pipeline should need to run these tests."); diff --git a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc index 06da2a5304716..6514feadf0ff7 100644 --- a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc @@ -70,7 +70,11 @@ TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcBias) { auto op = ConvTransposeOp{.input_dims = {1, 8, 80, 80}, .kernel_shape = {5, 5}, .channels = 16, .bias = true}; - MAKE_PROVIDERS_EPS_TYPE(TypeParam) + if (HasCudaEnvironment(800)) { + MAKE_PROVIDERS_EPS(1e-2) + } else { + MAKE_PROVIDERS_EPS_TYPE(TypeParam) + } } TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcPad) { diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index 1060a0138e0b7..5779b1da3fd43 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -137,7 +137,7 @@ jobs: --enable_cuda_profiling --enable_cuda_nhwc_ops \ --enable_pybind --build_java \ --use_cache \ - --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75; \ + --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86; \ ccache -sv; \ ccache -z" workingDirectory: $(Build.SourcesDirectory) @@ -166,7 +166,7 @@ jobs: skipComponentGovernanceDetection: true workspace: clean: all - pool: Onnxruntime-Linux-GPU-T4 + pool: onnxruntime-Linux-GPU-A10 dependsOn: - Linux_Build steps: From f53068446e7e560012862e1812270bcf908fbda4 Mon Sep 17 00:00:00 2001 From: petermcaughan Date: Tue, 23 Jan 2024 13:44:34 -0800 Subject: [PATCH 07/23] Add Temperature to WhisperBeamSearch input (#19188) ### Description Add `temperature` as an input to WhisperBeamSearch op and initialize it correctly in parameter setup. ### Motivation and Context Currently, temperature is included as an attribute of the BeamSearch op, which doesn't let the model act dynamically in a single inference session.
By including this variable as an input, the temperature value can be altered in any inference call (important for 1P teams) --------- Co-authored-by: Peter McAughan Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Co-authored-by: Kunal Vaishnavi --- docs/ContribOperators.md | 4 +++- docs/OperatorKernels.md | 4 ++-- .../cpu/transformers/beam_search_parameters.cc | 14 +++++++++++++- .../contrib_ops/cuda/transformers/beam_search.cc | 1 + onnxruntime/core/graph/contrib_ops/contrib_defs.cc | 1 + 5 files changed, 20 insertions(+), 4 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 22e82443167f6..624cda1d37f73 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -5761,7 +5761,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
-#### Inputs (5 - 14)
+#### Inputs (5 - 15)
input_ids : F
@@ -5792,6 +5792,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect allits shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
extra_decoding_ids (optional) : I
Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len).In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) are treated as stop of the extra_decoding_ids for corresponding batch.
+temperature (optional) : T
+Temperature value to apply to logits processing during this execution's decoding. Shape is (1)
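For readers skimming the spec: a decoding temperature is conventionally applied by dividing the logits before softmax, so values below 1 sharpen the token distribution and values above 1 flatten it. A minimal NumPy sketch of that convention (illustrative only, not the operator's internal implementation):

```python
import numpy as np

def apply_temperature(logits: np.ndarray, temperature: float) -> np.ndarray:
    # Divide logits by the temperature, then take a numerically stable softmax.
    scaled = logits / temperature
    scaled = scaled - scaled.max(axis=-1, keepdims=True)
    probs = np.exp(scaled)
    return probs / probs.sum(axis=-1, keepdims=True)

logits = np.array([2.0, 1.0, 0.1])
print(apply_temperature(logits, 1.0))  # baseline distribution
print(apply_temperature(logits, 0.5))  # sharper: more mass on the top token
print(apply_temperature(logits, 2.0))  # flatter: closer to uniform
```

Passing the value as a tensor input, as this change does, means the same session can decode with a different temperature on every call.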
#### Outputs (1 - 5) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 9a2a7ac89bbb3..3b695af2839b6 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -499,7 +499,7 @@ Do not modify directly.* |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int64)| |Unique|*in* x:**T**
*out* y:**T**
*out* idx:**tensor(int64)**
*out* counts:**tensor(int64)**|1+|**T** = tensor(float)| -|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float)| +|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*in* temperature:**T**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float)| |WordConvEmbedding|*in* Sequence:**T**
*in* W:**T1**
*in* B:**T1**
*in* C:**T1**
*out* Y:**T1**|1+|**T** = tensor(int32)
**T1** = tensor(float)| | | | | @@ -876,7 +876,7 @@ Do not modify directly.* |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |UnfoldTensor|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)| +|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*in* temperature:**T**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)| | | | | diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index 3962486d5b5eb..bb6885c3216bc 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -123,8 +123,20 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { logits_processor = logits_processor_tensor ? static_cast(*logits_processor_tensor->Data()) : 0; ORT_ENFORCE(logits_processor >= 0, "logits_processor shall be a non-negative integer, got ", logits_processor); -} + if (this->model_type == IGenerationParameters::kModelTypeWhisper) { + auto* temperature_tensor = context->Input(14); + if (temperature_tensor) { + if (temperature_tensor->IsDataType()) { + temperature = *temperature_tensor->Data(); + } else { + temperature = static_cast(*temperature_tensor->Data()); + } + } else { + temperature = 1.0f; + } + } +} void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) { // Override vocab_size using the inferred shape from the decoder subgraph ONLY IF // the vocab_size hasn't been explicitly specified by the user (as an attribute of BeamSearch) diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc index 2a90e4911f286..08cbb145a6f65 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc @@ -49,6 +49,7 @@ ONNX_OPERATOR_KERNEL_EX( .InputMemoryType(OrtMemTypeCPUInput, 9) // 'attention_mask' needs to be on CPU .InputMemoryType(OrtMemTypeCPUInput, 10) // 'decoder_input_ids' needs to be on CPU .InputMemoryType(OrtMemTypeCPUInput, 11) // 'logits_processor' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 14) // 'temperature' needs to be on CPU .OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU .OutputMemoryType(OrtMemTypeCPUOutput, 1) // 'sequences_scores' output on CPU .TypeConstraint("T", {DataTypeImpl::GetTensorType(), diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 982e8fd834b76..27c968a59eb91 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1231,6 +1231,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, "In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) " "are treated as stop of the extra_decoding_ids for corresponding batch.", "I", OpSchema::Optional) + .Input(14, "temperature", "Temperature value to apply to logits processing during this execution's decoding. Shape is (1)", "T", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", From 532f8c642ce9c1ea2971b7d0f0ff8a4197bcb3a0 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Tue, 23 Jan 2024 14:57:30 -0800 Subject: [PATCH 08/23] Fix a backend test by using local backend (#19230) The decomposition pass (e.g., converting torch.add to aten.add) in DORT no longer exists. 
Therefore, we have to use `use_aot_autograd=True` to enable Dynamo's built-in operator decomposition. I think we need to add the decomposition pass back to DORT or remove `use_aot_autograd` (remove because it will always be `true`). --- .../orttraining/test/python/orttraining_test_dort.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/test/python/orttraining_test_dort.py b/orttraining/orttraining/test/python/orttraining_test_dort.py index f0b6b9c5fba28..573ec85d76013 100644 --- a/orttraining/orttraining/test/python/orttraining_test_dort.py +++ b/orttraining/orttraining/test/python/orttraining_test_dort.py @@ -216,7 +216,12 @@ def elementwise_model(tensor_x: torch.Tensor): tensor_q = tensor_p.relu() return tensor_q - local_backend = make_local_backend(dynamic=True, use_aot_autograd=False) + # TODO: Set use_aot_autograd=False. In order to decompose torch + # function calls to aten ops, we need to set + # use_aot_autograd=True because there is no decomposition in DORT + # anymore. A long-term fix will be bringing # decomposition pass back + # into DORT. + local_backend = make_local_backend(dynamic=True, use_aot_autograd=True) optimized_elementwise_model = torch.compile(elementwise_model, backend=local_backend, dynamic=True) def run(fun, list_x): From cbb29d80ff5ec63d3cc2289911c4420f5a9d8a2d Mon Sep 17 00:00:00 2001 From: aciddelgado <139922440+aciddelgado@users.noreply.github.com> Date: Tue, 23 Jan 2024 16:34:26 -0800 Subject: [PATCH 09/23] GQA Rotary and Packed QKV with Flash (#18906) ### Description These changes add rotary embedding and packed qkv input to gqa. As of now, the changes are only supported with Flash-Attention (SM >= 80) but should soon be supported with Memory Efficient Attention as well. ### Motivation and Context With the fusion of rotary embedding into this Attention op, we hope to observe some perf gain. The packed QKV should also provide some perf gain in the context of certain models, like Llama2, that would benefit from running ops on the fused QKV matrix, rather than the separate Q, K, and V. --------- Co-authored-by: Yufeng Li --- docs/ContribOperators.md | 16 +- docs/OperatorKernels.md | 2 +- .../contrib_ops/cpu/bert/attention_common.h | 5 + .../cuda/bert/flash_attention/flash_api.cc | 51 +- .../cuda/bert/flash_attention/flash_api.h | 6 +- .../cuda/bert/group_query_attention.cc | 26 +- .../cuda/bert/group_query_attention.h | 5 + .../cuda/bert/group_query_attention_helper.h | 150 ++-- .../cuda/bert/group_query_attention_impl.cu | 125 ++-- .../cuda/bert/group_query_attention_impl.h | 2 + .../core/graph/contrib_ops/bert_defs.cc | 34 +- .../test/python/transformers/rotary_flash.py | 693 ++++++++++++++++++ .../python/transformers/test_flash_attn.py | 668 ++++++++++++++--- tools/ci_build/build.py | 3 +- ...txt => requirements-transformers-test.txt} | 3 +- 15 files changed, 1517 insertions(+), 272 deletions(-) create mode 100644 onnxruntime/test/python/transformers/rotary_flash.py rename tools/ci_build/{requirements.txt => requirements-transformers-test.txt} (94%) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 624cda1d37f73..e7b537d6894c8 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2398,24 +2398,28 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
+do_rotary : int
+Whether to use rotary position embedding. Default value is 0.
kv_num_heads : int (required)
Number of attention heads for k and v
local_window_size : int
left_window_size for local attention (like Mistral). Default value is -1 meaning unused.
num_heads : int (required)
Number of attention heads for q
+rotary_interleaved : int
+Rotate using interleaved pattern. Default value is 0 (False).
scale : float
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
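The two rotary attributes above are easiest to see in code. Below is a reference NumPy sketch of rotary position embedding under the usual RoPE definition, where `rotary_interleaved=0` pairs element i with element i + rotary_dim/2 and `rotary_interleaved=1` pairs adjacent elements; this is an illustration, not the CUDA kernel, and it assumes cos/sin caches shaped (seq, rotary_dim / 2) like the cos_cache/sin_cache inputs listed further down.

```python
import numpy as np

def rope(x, cos, sin, interleaved=False):
    # x: (seq, rotary_dim); cos/sin: (seq, rotary_dim // 2).
    out = np.empty_like(x)
    if interleaved:
        # Rotate adjacent pairs (x0, x1), (x2, x3), ...
        e, o = x[:, 0::2], x[:, 1::2]
        out[:, 0::2] = e * cos - o * sin
        out[:, 1::2] = o * cos + e * sin
    else:
        # Rotate across halves: pair x_i with x_{i + rotary_dim // 2}.
        half = x.shape[-1] // 2
        x1, x2 = x[:, :half], x[:, half:]
        out[:, :half] = x1 * cos - x2 * sin
        out[:, half:] = x2 * cos + x1 * sin
    return out

seq, dim = 4, 8
inv_freq = 1.0 / (10000.0 ** (np.arange(0, dim, 2) / dim))
angles = np.arange(seq)[:, None] * inv_freq[None, :]
q = np.random.randn(seq, dim)
q_rot = rope(q, np.cos(angles), np.sin(angles), interleaved=False)
```

Fusing this rotation into the attention op is where the expected perf gain comes from: it avoids a separate RotaryEmbedding pass over Q and K before attention.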
-#### Inputs
+#### Inputs (7 - 9)
query : T
-Query with shape (batch_size, sequence_length, hidden_size)
-key : T
+Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape(batch_size, sequence_length, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).
+key (optional) : T
Key with shape (batch_size, kv_sequence_length, kv_hidden_size)
-value : T
+value (optional) : T
Value with shape (batch_size, kv_sequence_length, kv_hidden_size)
past_key (optional) : T
past state key with support for format BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.
@@ -2425,6 +2429,10 @@ This version of the operator has been available since version 1 of the 'com.micr
1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.
total_sequence_length : M
Scalar tensor of total sequence length (past + new).
+cos_cache (optional) : T
+2D tensor with shape (max_sequence_length, head_size / 2).
+sin_cache (optional) : T
+2D tensor with shape (max_sequence_length, head_size / 2).
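To make the packed-QKV convention above concrete: when key and value are absent, query carries Q, K and V back to back along the hidden dimension, which is also how the kernel's row strides are set up. A minimal NumPy sketch of the layout (illustrative; the actual implementation addresses the packed buffer via strides rather than copying slices):

```python
import numpy as np

def unpack_qkv(packed, num_heads, kv_num_heads, head_size):
    # packed: (batch, seq, d) with d = (num_heads + 2 * kv_num_heads) * head_size.
    q_size = num_heads * head_size
    kv_size = kv_num_heads * head_size
    assert packed.shape[-1] == q_size + 2 * kv_size
    q = packed[..., :q_size]
    k = packed[..., q_size:q_size + kv_size]
    v = packed[..., q_size + kv_size:]
    return q, k, v

# Example with 8 query heads and 2 kv heads (grouped-query), head_size 64.
packed = np.zeros((1, 16, (8 + 2 * 2) * 64), dtype=np.float16)
q, k, v = unpack_qkv(packed, num_heads=8, kv_num_heads=2, head_size=64)
print(q.shape, k.shape, v.shape)  # (1, 16, 512) (1, 16, 128) (1, 16, 128)
```

Keeping QKV in one buffer lets models such as Llama2 run the preceding projection as a single fused matmul instead of three separate ones.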
#### Outputs diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 3b695af2839b6..31cca232fde34 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -843,7 +843,7 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index da489a6901512..8afeb874750b4 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -99,10 +99,15 @@ struct GroupQueryAttentionParameters { bool is_unidirectional; // causal int local_window_size; bool kv_share_buffer; + bool is_packed_qkv; bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor + bool do_rotary; + bool rotary_interleaved; float scale; AttentionQkvFormat qkv_format; AttentionQkvFormat past_kv_format; + int zeros_count; + int* zero_ptr; }; namespace attention { diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index d6eb87228bb4a..2c296bf4f8483 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -355,13 +355,15 @@ bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, in Status mha_fwd_kvcache(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // batch_size x seqlen_q x num_heads x head_size - void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size - void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size - void* k, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size - void* v, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size + void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size + void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size void* out, // batch_size x seqlen_q x num_heads x head_size void* softmax_lse, // batch_size x num_heads x seqlen_q void* seqlens_k_, // batch_size + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) int batch_size, int num_heads, int num_heads_k, @@ -376,16 +378,15 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, int num_splits, void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded - int local_window_size) { - // if (seqlen_q == 1) { - // is_causal = false; - // } // causal=true is the same as causal=false in this case - + int local_window_size, + bool is_rotary_interleaved, + bool is_packed_qkv) { auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + // In kv-cache case, seqlen_k_max as kv sequence length Flash_fwd_params params; set_params_fprop(params, batch_size, @@ -406,15 +407,24 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, is_causal ? 
0 : -1); params.dprops = &dprops; - if (k != nullptr && v != nullptr) { + if (k_new != nullptr && v_new != nullptr) { params.seqlen_knew = seqlen_k_new; - params.knew_ptr = k; - params.vnew_ptr = v; + params.knew_ptr = k_new; + params.vnew_ptr = v_new; // All stride are in elements, not bytes. - params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size; - params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size; - params.knew_row_stride = num_heads_k * head_size; - params.vnew_row_stride = num_heads_k * head_size; + if (is_packed_qkv) { + params.q_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.q_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.knew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.vnew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.knew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.vnew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + } else { + params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.knew_row_stride = num_heads_k * head_size; + params.vnew_row_stride = num_heads_k * head_size; + } params.knew_head_stride = head_size; params.vnew_head_stride = head_size; } else { @@ -434,6 +444,13 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, params.cu_seqlens_k = static_cast(seqlens_k_); } + if (rotary_cos != nullptr) { + params.rotary_cos_ptr = rotary_cos; + params.rotary_sin_ptr = rotary_sin; + params.is_rotary_interleaved = is_rotary_interleaved; + params.rotary_dim = (head_size / 16) * 16; + } + params.num_splits = num_splits; if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) { params.softmax_lseaccum_ptr = softmax_lse_accum; @@ -444,7 +461,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, } // Only split kernel supports appending to KV cache - run_mha_fwd(params, stream, /*force_split_kernel=*/k != nullptr); + run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 3d75d6834b8e0..387d1cf9d84fe 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -87,6 +87,8 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, void* out, // batch_size x seqlen_q x num_heads x head_size void* softmax_lse, // batch_size x num_heads x seqlen_q void* seqlens_k_, // batch_size + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) int batch_size, int num_heads, int num_heads_k, @@ -101,7 +103,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, int num_splits = 0, void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded - int local_window_size = -1); + int local_window_size = -1, + bool is_rotary_interleaved = false, + bool is_packed_qkv = false); size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc 
b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index fd6fb79742cac..fe56f84f0a886 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -47,6 +47,8 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) kv_num_heads_ = static_cast(kv_num_heads); is_past_bsnh_ = false; // info.GetAttrOrDefault("is_past_bsnh", 1) == 1; local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; scale_ = info.GetAttrOrDefault("scale", 0.0f); #if USE_FLASH_ATTENTION @@ -62,6 +64,9 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) #else disable_memory_efficient_attention_ = true; #endif + if (!disable_flash_attention_) { + zeros_ = this->GetScratchBuffer(kZerosCount, nullptr); + } } template @@ -73,6 +78,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* past_value = context->Input(4); const Tensor* seqlens_k = context->Input(5); const Tensor* total_seqlen = context->Input(6); + const Tensor* cos_cache = context->Input(7); + const Tensor* sin_cache = context->Input(8); auto& device_prop = GetDeviceProp(); GroupQueryAttentionParameters parameters; @@ -84,6 +91,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { value, past_key, past_value, + cos_cache, + sin_cache, ¶meters, num_heads_, kv_num_heads_, @@ -93,7 +102,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { scale_, device_prop.maxThreadsPerBlock)); parameters.local_window_size = local_window_size_; + parameters.is_unidirectional = is_unidirectional_; + parameters.zeros_count = kZerosCount; + parameters.zero_ptr = zeros_.get(); + // parameters.left_padding = left_padding_; int sequence_length = parameters.sequence_length; + parameters.do_rotary = do_rotary_; + parameters.rotary_interleaved = rotary_interleaved_; TensorShapeVector output_shape(3); output_shape[0] = static_cast(parameters.batch_size); @@ -139,6 +154,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { !use_flash_attention && !disable_memory_efficient_attention_ && local_window_size_ == -1 && + do_rotary_ == false && + key != nullptr && (parameters.head_size & 7) == 0 && parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length && (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && @@ -182,8 +199,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { Tensor* present_value = context->Output(2, present_shape); data.query = reinterpret_cast(query->Data()); - data.key = reinterpret_cast(key->Data()); - data.value = reinterpret_cast(value->Data()); + data.key = key == nullptr ? nullptr : reinterpret_cast(key->Data()); + data.value = value == nullptr ? nullptr : reinterpret_cast(value->Data()); data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast(past_key->Data()); data.past_value = (nullptr == past_value) ? 
nullptr : reinterpret_cast(past_value->Data()); data.output = reinterpret_cast(output->MutableData()); @@ -229,6 +246,11 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { if (fmha_buffer != nullptr) { data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); } + // Rotary + if (parameters.do_rotary) { + data.cos_cache = reinterpret_cast(cos_cache->Data()); + data.sin_cache = reinterpret_cast(sin_cache->Data()); + } cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 54a8127e29e7b..15573ece166fc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -23,10 +23,15 @@ class GroupQueryAttention final : public CudaKernel { int num_heads_; // number of attention heads int kv_num_heads_; // different for k and v for group query attention int local_window_size_; + bool is_unidirectional_; bool is_past_bsnh_; + bool do_rotary_; + bool rotary_interleaved_; float scale_; bool disable_flash_attention_; bool disable_memory_efficient_attention_; + static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256) + IAllocatorUniquePtr zeros_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h index 2cb9955807f26..853e1a710cb24 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h @@ -16,6 +16,8 @@ Status CheckInputs(const Tensor* query, const Tensor* value, const Tensor* past_key, const Tensor* past_value, + const Tensor* cos_cache, + const Tensor* sin_cache, void* parameters, int num_heads, int kv_num_heads, @@ -24,19 +26,18 @@ Status CheckInputs(const Tensor* query, bool is_past_bsnh, float scale) { // Note: Here S* is past_cache_sequence_length, S- is past_sequence_length, S+ is sequence_length - // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) - // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) + // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr + // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr // no packing for q/k/v: - // query (Q) : (B, S, D) - // key (K) : (B, S, D_kv) - // value (V) : (B, S, D_kv) + // query (Q) : (B, S, D) or (B, S, (D_q + 2 D_kv)) + // key (K) : (B, S, D_kv) or nullptr + // value (V) : (B, S, D_kv) or nullptr ORT_UNUSED_PARAMETER(value); AttentionQkvFormat qkv_format = Q_K_V_BSNH; AttentionQkvFormat past_kv_format = is_past_bsnh ? Q_K_V_BSNH : Q_K_V_BNSH; - + const bool is_packed_qkv = key == nullptr; const auto& query_dims = query->Shape().GetDims(); - const auto& key_dims = key->Shape().GetDims(); if (query_dims.size() != 3) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", @@ -46,10 +47,69 @@ Status CheckInputs(const Tensor* query, int batch_size = static_cast(query_dims[0]); int sequence_length = static_cast(query_dims[1]); int q_hidden_size = static_cast(query_dims[2]); - int head_size = static_cast(q_hidden_size) / num_heads; + int head_size = 0; + + if (num_heads % kv_num_heads != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "num_heads must be a multiple of kv_num_heads. 
Got num_heads % kv_num_heads == ", + num_heads % kv_num_heads); + } - int kv_hidden_size = static_cast(key_dims[2]); + int kv_hidden_size = 0; + // Check key and value when not packed + if (!is_packed_qkv) { + head_size = static_cast(q_hidden_size) / num_heads; + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. Got head_size % 8 == ", + head_size % 8); + } + if (value == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); + } + const auto& key_dims = key->Shape().GetDims(); + if (key_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", + key_dims.size()); + } else if (query_dims[0] != key_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 0 (batch size)"); + } else if (query_dims[1] != key_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 1 (sequence length)"); + } + kv_hidden_size = static_cast(key_dims[2]); + const auto& value_dims = value->Shape().GetDims(); + if (value_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", + value_dims.size()); + } else if (query_dims[0] != value_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 0 (batch size)"); + } else if (query_dims[1] != value_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 1 (sequence length)"); + } else if (value_dims[2] != kv_hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); + } + } else { + // Check packed qkv + head_size = static_cast(q_hidden_size) / (num_heads + 2 * kv_num_heads); + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. Got head_size % 8 == ", + head_size % 8); + } + if (value != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); + } + q_hidden_size = head_size * num_heads; + kv_hidden_size = head_size * kv_num_heads; + } + // Check past-present KV int32_t past_sequence_length = 0; if (past_key != nullptr && past_value != nullptr) { const auto& past_key_dims = past_key->Shape().GetDims(); @@ -130,41 +190,6 @@ Status CheckInputs(const Tensor* query, "Input 'past_key' and 'past_value' shall be both present or both absent."); } - if (key_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", - key_dims.size()); - } - if (query_dims[0] != key_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); - } - - if (num_heads % kv_num_heads != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "num_heads must be a multiple of kv_num_heads. 
Got num_heads % kv_num_heads == ", - num_heads % kv_num_heads); - } - - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", - value_dims.size()); - } - - if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch_size)"); - } - - if (static_cast(sequence_length) != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query,' 'key,' and 'value' shall have the same dim 1 (sequence_length)"); - } - - if (value_dims[2] != kv_hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); - } - // Check seqlens_k tensor (holding past seqlen for token gen) const auto& seqlens_dim = seqlens_k->Shape().GetDims(); if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) { @@ -180,6 +205,36 @@ Status CheckInputs(const Tensor* query, int total_sequence_length = *((*total_seqlen).template Data()); int present_sequence_length = std::max(total_sequence_length, past_sequence_length); + if (cos_cache != nullptr && sin_cache != nullptr) { + const auto& cos_dims = cos_cache->Shape().GetDims(); + const auto& sin_dims = sin_cache->Shape().GetDims(); + + if (head_size % 16 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size shall be a multiple of 16. Got head_size % 16 == ", + head_size % 16); + } + if (cos_dims[0] != present_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache dimension 0 must be of present_sequence_length."); + } + if (sin_dims[0] != present_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sin_cache dimension 0 must be of present_sequence_length."); + } + if (cos_dims[1] != (head_size / 16) * 8) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); + } + if (sin_dims[1] != (head_size / 16) * 8) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); + } + } else if (cos_cache != nullptr || sin_cache != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'cos_cache' and 'sin_cache' shall be both present or both absent."); + } + bool is_prompt = sequence_length != 1; if (parameters != nullptr) { @@ -190,9 +245,10 @@ Status CheckInputs(const Tensor* query, output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors output_parameters->hidden_size = q_hidden_size; output_parameters->num_heads = num_heads; - output_parameters->head_size = q_hidden_size / num_heads; + output_parameters->head_size = head_size; output_parameters->kv_hidden_size = kv_hidden_size; output_parameters->kv_num_heads = kv_num_heads; + output_parameters->is_packed_qkv = is_packed_qkv; output_parameters->is_unidirectional = true; output_parameters->is_prompt = is_prompt; output_parameters->scale = scale; @@ -208,6 +264,8 @@ Status CheckInputs(const Tensor* query, const Tensor* value, const Tensor* past_key, const Tensor* past_value, + const Tensor* cos_cache, + const Tensor* sin_cache, void* parameters, int num_heads, int kv_num_heads, @@ -220,7 +278,7 @@ Status CheckInputs(const Tensor* query, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no 
larger than ", max_threads_per_block); } - return CheckInputs(query, key, value, past_key, past_value, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale); + return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale); } } // namespace group_query_attention_helper diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 5b0f5d0cfe601..d88e9a49fb5ee 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -151,9 +151,10 @@ template Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data, cudaStream_t stream, - const int max_threads_per_block) { + const int max_threads_per_block, + const bool past_only = false) { const int batch_size = parameters.batch_size; - const int kv_sequence_length = parameters.sequence_length; + const int kv_sequence_length = past_only ? 0 : parameters.sequence_length; const int past_sequence_length = parameters.seqlen_past_kv_cache; const int present_sequence_length = parameters.seqlen_present_kv_cache; const int kv_num_heads = parameters.kv_num_heads; @@ -441,7 +442,6 @@ Status LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters, return CUDA_CALL(cudaGetLastError()); } - __global__ void PastToTotalSeqlen(int32_t* seqlens_k, int32_t* seqlens_k_buff, const int add_seqlen) { @@ -451,7 +451,7 @@ __global__ void PastToTotalSeqlen(int32_t* seqlens_k, // Convert Past to Total sequence length tensor Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream, - const int threads_per_block) { + const int threads_per_block) { if (parameters.is_prompt) { return Status::OK(); } @@ -482,91 +482,63 @@ Status FlashAttention( const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int kv_sequence_length = parameters.sequence_length; - const int present_sequence_length = parameters.seqlen_present_kv_cache; const int num_heads = parameters.num_heads; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; AttentionQkvFormat past_kv_format = parameters.past_kv_format; - - void* query = reinterpret_cast(const_cast(data.query)); - void* key = reinterpret_cast(const_cast(data.key)); - void* value = reinterpret_cast(const_cast(data.value)); - bool is_causal = true; - bool is_bf16 = std::is_same::value; - // Note: seqlens_k is past sequence length for flash - if (parameters.is_prompt) { - // Launch kernel to copy seqlen - constexpr int thr_per_blk = 256; - int blk_in_grid = (batch_size + thr_per_blk -1) / thr_per_blk; - repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, batch_size); - } - - void* seqlens_k = reinterpret_cast(data.seqlens_k); - - if (parameters.kv_share_buffer) { - // Share buffer case - if (data.past_key == nullptr || data.past_key != data.present_key) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Past and present kv shall share the same tensor when kv_share_buffer is on."); - } - - if (parameters.is_prompt) { - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); - key = nullptr; - value = nullptr; - seqlens_k = 
reinterpret_cast<void*>(data.seqlens_k_total);
-    }
-
-    void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
-    void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
-
-    DUMP_TENSOR_INIT();
-    DUMP_TENSOR("seqlens_k", reinterpret_cast<int*>(seqlens_k), batch_size, 1);
+  void* query = reinterpret_cast<void*>(const_cast<T*>(data.query));
+  void* key;
+  void* value;
 
-    bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
-    ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
-        device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast<void*>(data.softmax_lse),
-        seqlens_k, batch_size, num_heads, kv_num_heads,
-        head_size, sequence_length, present_sequence_length, kv_sequence_length,
-        scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
-        reinterpret_cast<void*>(data.out_accum), parameters.local_window_size));
+  if (!parameters.is_packed_qkv) {
+    key = reinterpret_cast<void*>(const_cast<T*>(data.key));
+    value = reinterpret_cast<void*>(const_cast<T*>(data.value));
   } else {
-    // Not share buffer case
-    // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient
-    if (data.past_key != nullptr && data.past_key == data.present_key) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                             "Past and present kv share the same tensor but kv_share_buffer is not on.");
-    }
-
-    ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block));
+    const size_t key_offset = static_cast<size_t>(num_heads * head_size);
+    const size_t value_offset = static_cast<size_t>(kv_num_heads * head_size);
+    key = reinterpret_cast<T*>(query) + key_offset;
+    value = reinterpret_cast<T*>(key) + value_offset;
+  }
 
-    if (!parameters.is_prompt) {
-      ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256));
+  void* seqlens_k = reinterpret_cast<void*>(data.seqlens_k);
+  if (parameters.is_prompt) {
+    // set seqlens_k to zeros... flash api uses seqlens_k to indicate where to append key and value
+    // user should use seqlens_k to index into output to get new tokens
+    if (batch_size <= parameters.zeros_count) {
+      seqlens_k = parameters.zero_ptr;
+    } else {
+      // Launch kernel to create larger seqlen tensor when batch_size > 256
+      constexpr int thr_per_blk = 256;
+      int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk;
+      repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k_total, 0, batch_size);
+      seqlens_k = data.seqlens_k_total;
     }
-
-    seqlens_k = reinterpret_cast<void*>(data.seqlens_k_total);
-
-    void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
-    void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
-
-    DUMP_TENSOR_INIT();
-    DUMP_TENSOR("seqlens_k", reinterpret_cast<int*>(seqlens_k), batch_size, 1);
-    DUMP_TENSOR("Q", data.query, batch_size, sequence_length, num_heads, head_size);
-    DUMP_TENSOR("K", data.present_key, batch_size, kv_num_heads, present_sequence_length, head_size);
-    DUMP_TENSOR("V", data.present_value, batch_size, kv_num_heads, present_sequence_length, head_size);
-
-    bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
-    ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
-        device_prop, stream, query, present_key, present_value, nullptr, nullptr, data.output, reinterpret_cast<void*>(data.softmax_lse),
-        seqlens_k, batch_size, num_heads, kv_num_heads,
-        head_size, sequence_length, present_sequence_length, 0,
-        scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
-        reinterpret_cast<void*>(data.out_accum), parameters.local_window_size));
+  } else if (!parameters.kv_share_buffer) {  // copy past kv to present kv
+    ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block, true));
   }
 
+  void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
+  void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
+  void* cos_cache = reinterpret_cast<void*>(const_cast<T*>(data.cos_cache));
+  void* sin_cache = reinterpret_cast<void*>(const_cast<T*>(data.sin_cache));
+
+  bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
+  ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
+      device_prop, stream, query, present_key, present_value, key, value, data.output,
+      reinterpret_cast<void*>(data.softmax_lse), seqlens_k, cos_cache, sin_cache,
+      batch_size, num_heads, kv_num_heads, head_size, sequence_length,
+      parameters.seqlen_present_kv_cache, kv_sequence_length,
+      scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
+      reinterpret_cast<void*>(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved,
+      parameters.is_packed_qkv));
+
+  // if (parameters.left_padding && parameters.is_prompt) {
+  //   ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock));
+  // }
+
   DUMP_TENSOR_INIT();
   DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size);
 
@@ -672,7 +644,6 @@ Status EfficientAttention(
   p.has_custom_right_padding = true;
   run_memory_efficient_attention(p);
 
-  DUMP_TENSOR_INIT();
   DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size);
 
   return Status::OK();
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
index de32d7ea93163..1bf91f9c875eb 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
+++ 
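Two details of the FlashAttention changes above are worth restating. First, the packed-QKV pointer arithmetic simply walks one buffer of width (num_heads + 2 * kv_num_heads) * head_size; a torch-slicing sketch of the same layout (illustrative, hypothetical sizes):

import torch

B, S, n, kv_n, h = 2, 4, 8, 2, 64  # batch, seqlen, num_heads, kv_num_heads, head_size
packed = torch.randn(B, S, (n + 2 * kv_n) * h, dtype=torch.float16)
q = packed[..., : n * h]                 # query == base pointer
k = packed[..., n * h : (n + kv_n) * h]  # key   == query + num_heads * head_size
v = packed[..., (n + kv_n) * h :]        # value == key + kv_num_heads * head_size

Second, the prompt path passes all-zero seqlens_k because the flash kv-cache API reads it as the per-batch append offset: with zeros, the whole prompt's keys and values land at position 0 of the present cache.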
b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -21,6 +21,8 @@ struct GroupQueryAttentionData { const T* past_key = nullptr; const T* past_value = nullptr; int* seqlens_k = nullptr; + const T* cos_cache = nullptr; + const T* sin_cache = nullptr; // Flash buffers T* softmax_lse = nullptr; T* softmax_lse_accum = nullptr; diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 7f34647f1faef..8583474a1e391 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -259,13 +259,13 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& *output_shape.add_dim() = query_dims[1]; *output_shape.add_dim() = query_dims[2]; updateOutputShape(ctx, 0, output_shape); - } else { - fail_shape_inference("Missing input 2 (value)"); } } if (ctx.getNumOutputs() > 1) { // has present output if (hasInputShape(ctx, past_key_index)) { + // auto& query_shape = getInputShape(ctx, 0); + // auto& query_dims = query_shape.dim(); auto& past_shape = getInputShape(ctx, past_key_index); auto& past_dims = past_shape.dim(); if (past_dims.size() != 4) { @@ -273,8 +273,7 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& } ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, past_key_index, 1); ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); - ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, past_key_index, 1); - ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); + // TODO(aciddelgado): propagate output shapes depending if kv-share buffer is on or not } } } @@ -1015,18 +1014,29 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "left_window_size for local attention (like Mistral). Default value is -1 meaning unused.", AttributeProto::INT, static_cast(-1)) + .Attr("do_rotary", + "Whether to use rotary position embedding. Default value is 0.", + AttributeProto::INT, + OPTIONAL_VALUE) + .Attr("rotary_interleaved", + "Rotate using interleaved pattern. Default value is 0 (False).", + AttributeProto::INT, + OPTIONAL_VALUE) .Input(0, "query", - "Query with shape (batch_size, sequence_length, hidden_size)", + "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape" + "(batch_size, sequence_length, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).", "T") .Input(1, "key", "Key with shape (batch_size, kv_sequence_length, kv_hidden_size) ", - "T") + "T", + OpSchema::Optional) .Input(2, "value", "Value with shape (batch_size, kv_sequence_length, kv_hidden_size)", - "T") + "T", + OpSchema::Optional) .Input(3, "past_key", "past state key with support for format BNSH. 
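Since 'key' and 'value' become optional in this schema revision, a node omits them by naming them with empty strings, the usual ONNX convention for skipped optional inputs. A hypothetical packed-QKV node construction (illustrative sketch; all attribute values are placeholders):

from onnx import helper

node = helper.make_node(
    "GroupQueryAttention",
    inputs=["query", "", "", "past_key", "past_value",
            "seqlens_k", "total_sequence_length", "cos_cache", "sin_cache"],
    outputs=["output", "present_key", "present_value"],
    name="gqa_packed_rotary",  # hypothetical node name
    num_heads=8,
    kv_num_heads=2,
    do_rotary=1,
    rotary_interleaved=0,
    domain="com.microsoft",
)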
When past_key uses same tensor as present_key" @@ -1047,6 +1057,16 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "total_sequence_length", "Scalar tensor of total sequence length (past + new).", "M") + .Input(7, + "cos_cache", + "2D tensor with shape (max_sequence_length, head_size / 2).", + "T", + OpSchema::Optional) + .Input(8, + "sin_cache", + "2D tensor with shape (max_sequence_length, head_size / 2).", + "T", + OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", diff --git a/onnxruntime/test/python/transformers/rotary_flash.py b/onnxruntime/test/python/transformers/rotary_flash.py new file mode 100644 index 0000000000000..42bff9c92b41b --- /dev/null +++ b/onnxruntime/test/python/transformers/rotary_flash.py @@ -0,0 +1,693 @@ +# Copyright (c) 2023, Tri Dao. + + +from typing import Optional, Tuple, Union + +import torch +import triton +import triton.language as tl +from einops import rearrange, repeat + +##### TRITON KERNEL FOR ROTARY ##### + + +# @triton.autotune( +# configs=[ +# triton.Config({"block_m": 2}), +# triton.Config({"block_m": 4}), +# triton.Config({"block_m": 8}), +# triton.Config({"block_m": 16}), +# ], +# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"], +# ) +@triton.jit +def rotary_kernel( + out_, # Pointers to matrices + x_, + cos_, + sin_, + CU_SEQLENS, + SEQLEN_OFFSETS, # this could be int or a pointer + # Matrix dimensions + seqlen, + nheads, + rotary_dim, + seqlen_ro, + CACHE_KEY_SEQLEN, + # strides + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + # Meta-parameters + block_k: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + block_m: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + rotary_dim_half = rotary_dim // 2 + + if not IS_VARLEN: + x_ = x_ + pid_batch * stride_x_batch + pid_head * stride_x_nheads + out_ = out_ + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + x_ = x_ + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + out_ = out_ + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + + if pid_m * block_m >= seqlen: + return + rm = pid_m * block_m + tl.arange(0, block_m) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + rk = tl.arange(0, block_k) + rk_half = tl.arange(0, block_k // 2) + + if not INTERLEAVED: + # Load the 1st and 2nd halves of x_, do calculation, then store to 1st and 2nd halves of out_ + x_ = x_ + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim) + cos_ = cos_ + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + sin_ = sin_ + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + cos = tl.load(cos_, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0).to( + tl.float32 + ) + sin = tl.load(sin_, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0).to( + tl.float32 + ) + x0 = tl.load(x_, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32) + x1 = tl.load( + x_ + rotary_dim_half * stride_x_headdim, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + 
other=0.0, + ).to(tl.float32) + if CONJUGATE: + sin = -sin + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + # write back result + out_ = out_ + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim) + tl.store(out_, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half)) + tl.store( + out_ + rotary_dim_half * stride_out_headdim, + o1, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + ) + else: + # We don't want to load x_[0, 2, 4, ...] and x_[1, 3, 5, ...] separately since both are slow. + # Instead, we load x0 = x_[0, 1, 2, 3, ...] and x1 = x_[1, 0, 3, 2, ...]. + # Loading x0 will be fast but x1 will be slow. + # Then we load cos = cos_[0, 0, 1, 1, ...] and sin = sin_[0, 0, 1, 1, ...]. + # Then we do the calculation and use tl.where to pick put the right outputs for the even + # and for the odd indices. + rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ... + rk_repeat = tl.arange(0, block_k) // 2 + x0_ = x_ + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim) + x1_ = x_ + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim) + cos_ = cos_ + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) + sin_ = sin_ + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) + cos = tl.load( + cos_, + mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), + other=1.0, + ).to(tl.float32) + sin = tl.load( + sin_, + mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + x0 = tl.load(x0_, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(tl.float32) + x1 = tl.load(x1_, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + x0_cos = x0 * cos + x1_sin = x1 * sin + out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin) + out_ = out_ + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim) + tl.store(out_, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim)) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved=False, + inplace=False, + conjugate=False, +) -> torch.Tensor: + """ + Arguments: + x: (batch, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim). 
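The interleaved branch above avoids a strided gather of even and odd lanes by loading through a swapped index instead; the rk_swap pattern in isolation (illustrative):

import numpy as np

rk = np.arange(8)
rk_swap = rk + ((rk + 1) % 2) * 2 - 1
print(rk_swap)  # [1 0 3 2 5 4 7 6]: each even lane reads its odd partner and vice versa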
+ cos: (seqlen_ro, rotary_dim / 2) + sin: (seqlen_ro, rotary_dim / 2) + seqlen_offsets: integer or integer tensor of size (batch,) + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Returns: + y: (batch, seqlen, nheads, headdim) + """ + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed" + total_seqlen, nheads, headdim = x.shape + batch_p_1 = cu_seqlens.shape[0] + batch = batch_p_1 - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim = cos.shape + assert sin.shape == cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim, "rotary_dim must be <= headdim" + assert headdim <= 256, "Only support headdim <= 256" + assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" + + assert cos.dtype == sin.dtype, f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" + assert x.dtype == cos.dtype, f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" + + cos, sin = cos.contiguous(), sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + block_k = 32 if rotary_dim <= 32 else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256)) + grid = lambda META: (triton.cdiv(seqlen, META["block_m"]), batch, nheads) # noqa + block_m = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + + # Need this, otherwise Triton tries to launch from cuda:0 and we get + # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) + with torch.cuda.device(x.device.index): + rotary_kernel[grid]( + output, # data ptrs + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, # shapes + nheads, + rotary_dim, + seqlen_ro, + seqlen // 128, # key for triton cache (limit number of compilations) + output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 + output.stride(-3), # seqlen_stride or total_seqlen_stride + output.stride(-2), # nheads_stride + output.stride(-1), # headdim_stride + x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 + x.stride(-3), # seqlen stride or total_seqlen_stride + x.stride(-2), # nheads stride + x.stride(-1), # headdim stride + block_k, + isinstance(seqlen_offsets, torch.Tensor), + is_varlen, + interleaved, + conjugate, + block_m, + ) + return output + + +##### ROTARY API ##### + + +def rotate_half(x, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) + + +def apply_rotary_emb_torch(x, cos, sin, interleaved=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
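For the varlen path above, cu_seqlens is the usual cumulative-length encoding of a ragged batch; its semantics in two lines (illustrative lengths):

import torch

cu_seqlens = torch.tensor([0, 5, 12, 20])   # batch of 3 with lengths 5, 7, 8
lengths = cu_seqlens[1:] - cu_seqlens[:-1]  # sequence b occupies rows cu_seqlens[b]:cu_seqlens[b+1]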
1 (d 2)") + return torch.cat( + [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], + dim=-1, + ) + + +class ApplyRotaryEmb(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): + out = apply_rotary( + x, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=interleaved, + inplace=inplace, + ) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + ctx.inplace = inplace + ctx.max_seqlen = max_seqlen + return out if not inplace else x + + @staticmethod + def backward(ctx, do): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cu_seqlens = ctx.saved_tensors + # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with + # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works. + if not ctx.interleaved and not ctx.inplace: + do = do.clone() + dx = apply_rotary( + do, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=ctx.interleaved, + inplace=ctx.inplace, + conjugate=True, + ) + return dx, None, None, None, None, None, None, None + + +def apply_rotary_emb( + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +): + """ + Arguments: + x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim) + cos, sin: (seqlen_rotary, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + inplace: if True, apply rotary embedding in-place. + seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Return: + out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. 
+ """ + return ApplyRotaryEmb.apply(x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen) + + +# For backward compatibility +apply_rotary_emb_func = apply_rotary_emb + + +class ApplyRotaryEmbQKV(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + cos, + sin, + cos_k=None, + sin_k=None, + interleaved=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + ): + batch, seqlen, three, nheads, headdim = qkv.shape + assert three == 3 + if cos_k is None and sin_k is None and qkv.is_contiguous(): + # Call 1 kernel instead of 2 kernels + # We need qkv to be contiguous so that when we reshape to combine (3, nheads) + # dimensions, we get the same tensor + # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d") + qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim) + apply_rotary(qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True) + else: + cos_k = cos if cos_k is None else cos_k + sin_k = sin if sin_k is None else sin_k + q, k = qkv[:, :, 0], qkv[:, :, 1] + apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True) + apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True) + ctx.save_for_backward(cos, sin, cos_k, sin_k) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin, cos_k, sin_k) + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + return qkv + + @staticmethod + def backward(ctx, dqkv): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cos_k, sin_k = ctx.saved_tensors + if cos_k is None and sin_k is None and dqkv.is_contiguous(): + # Call 1 kernel instead of 2 kernels + # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) + # dimensions, we get the same tensor + dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d") + apply_rotary( + dqk, + cos, + sin, + seqlen_offsets=seqlen_offsets, + interleaved=ctx.interleaved, + inplace=True, + conjugate=True, + ) + else: + cos_k = cos if cos_k is None else cos_k + sin_k = sin if sin_k is None else sin_k + dq, dk = dqkv[:, :, 0], dqkv[:, :, 1] + apply_rotary(dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True) + apply_rotary( + dk, + cos_k, + sin_k, + seqlen_offsets, + interleaved=ctx.interleaved, + inplace=True, + conjugate=True, + ) + return dqkv, None, None, None, None, None, None + + +def apply_rotary_emb_qkv_( + qkv, + cos, + sin, + cos_k=None, + sin_k=None, + interleaved=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, +): + """ + Arguments: + qkv: (batch_size, seqlen, 3, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + cos_k, sin_k: (seqlen, rotary_dim / 2), optional + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of + 1st half and 2nd half (GPT-NeoX style). + seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. + Most commonly used in inference when we have KV cache. + Return: + qkv: (batch_size, seqlen, 3, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding *inplace* to the first rotary_dim of Q and K. 
+ """ + return ApplyRotaryEmbQKV.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets) + + +class ApplyRotaryEmbKV(torch.autograd.Function): + @staticmethod + def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0): + batch, seqlen, two, nheads, headdim = kv.shape + assert two == 2 + k = kv[:, :, 0] + apply_rotary(k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + return kv + + @staticmethod + def backward(ctx, dkv): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, seqlen_offsets = ctx.saved_tensors + else: + cos, sin = ctx.saved_tensors + apply_rotary( + dkv[:, :, 0], + cos, + sin, + seqlen_offsets=seqlen_offsets, + interleaved=ctx.interleaved, + inplace=True, + conjugate=True, + ) + return dkv, None, None, None, None + + +apply_rotary_emb_kv_ = ApplyRotaryEmbKV.apply + + +def apply_rotary_emb_kv_( + kv, + cos, + sin, + interleaved=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, +): + """ + Arguments: + kv: (batch_size, seqlen, 2, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of + 1st half and 2nd half (GPT-NeoX style). + seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. + Most commonly used in inference when we have KV cache. + Return: + kv: (batch_size, seqlen, 2, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding *inplace* to the first rotary_dim of K. + """ + return ApplyRotaryEmbKV.apply(kv, cos, sin, interleaved, seqlen_offsets) + + +class RotaryEmbedding(torch.nn.Module): + """ + The rotary position embeddings from RoFormer_ (Su et. al). + A crucial insight from the method is that the query and keys are + transformed by rotation matrices which depend on the relative positions. + + Other implementations are available in the Rotary Transformer repo_ and in + GPT-NeoX_, GPT-NeoX was an inspiration + + .. _RoFormer: https://arxiv.org/abs/2104.09864 + .. _repo: https://github.com/ZhuiyiTechnology/roformer + .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox + + If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). + A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 + Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py + """ + + def __init__( + self, + dim: int, + base=10000.0, + interleaved=False, + scale_base=None, + pos_idx_in_fp32=True, + device=None, + ): + """ + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32, + otherwise they might be in lower precision. + This option was added because previously (before 2023-07-02), when we construct + the position indices, we use the dtype of self.inv_freq. In most cases this would + be fp32, but if the model is trained in pure bf16 (not mixed precision), then + self.inv_freq would be bf16, and the position indices are also in bf16. + Because of the limited precision of bf16 (e.g. 
1995.0 is rounded to 2000.0), the + embeddings for some positions will coincide. + To maintain compatibility with models previously trained in pure bf16, + we add this option. + """ + super().__init__() + self.dim = dim + self.base = float(base) + self.pos_idx_in_fp32 = pos_idx_in_fp32 + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = self._compute_inv_freq(device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.interleaved = interleaved + self.scale_base = scale_base + scale = ( + (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) + if scale_base is not None + else None + ) + self.register_buffer("scale", scale, persistent=False) + + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + + def _compute_inv_freq(self, device=None): + return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)) + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): + # Reset the tables if the sequence length has changed, + # if we're on a new device (possibly due to tracing for instance), + # or if we're switching from inference mode to training + if ( + seqlen > self._seq_len_cached + or self._cos_cached is None + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + or (self.training and self._cos_cached.is_inference()) + ): + self._seq_len_cached = seqlen + # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 + # And the output of arange can be quite large, so bf16 would lose a lot of precision. + # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. + if self.pos_idx_in_fp32: + t = torch.arange(seqlen, device=device, dtype=torch.float32) + # We want fp32 here as well since inv_freq will be multiplied with t, and the output + # will be large. Having it in bf16 will lose a lot of precision and cause the + # cos & sin output to change significantly. 
+ # We want to recompute self.inv_freq if it was not loaded in fp32 + if self.inv_freq.dtype != torch.float32: + inv_freq = self._compute_inv_freq(device=device) + else: + inv_freq = self.inv_freq + else: + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + inv_freq = self.inv_freq + # Don't do einsum, it converts fp32 to fp16 under AMP + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, inv_freq) + if self.scale is None: + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + else: + power = ( + torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2 + ) / self.scale_base + scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") + # We want the multiplication by scale to happen in fp32 + self._cos_cached = (torch.cos(freqs) * scale).to(dtype) + self._sin_cached = (torch.sin(freqs) * scale).to(dtype) + self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) + self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) + + def forward( + self, + qkv: torch.Tensor, + kv: Optional[torch.Tensor] = None, + seqlen_offset: Union[int, torch.Tensor] = 0, + max_seqlen: Optional[int] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + qkv: (batch, seqlen, 3, nheads, headdim) if kv is none, + else it's just q of shape (batch, seqlen, nheads, headdim) + kv: (batch, seqlen, 2, nheads, headdim) + seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one + should pass in max_seqlen, which will update the cos / sin cache up to that length. + Apply rotary embedding *inplace* to qkv and / or kv. 
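In the unscaled case, _update_cos_sin_cache above reduces to the standard RoPE tables; the same precomputation stands alone in a few lines, and its output is shaped like the cos_cache/sin_cache inputs of the GroupQueryAttention op (illustrative, fp32 throughout as the comments recommend):

import torch

dim, base, seqlen = 64, 10000.0, 128  # rotary dim, theta base, max positions
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
t = torch.arange(seqlen, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)      # (seqlen, dim / 2)
cos, sin = torch.cos(freqs), torch.sin(freqs)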
+ """ + seqlen = qkv.shape[1] + if max_seqlen is not None: + self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) + elif isinstance(seqlen_offset, int): + self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype) + if kv is None: + if self.scale is None: + return apply_rotary_emb_qkv_( + qkv, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + else: + return apply_rotary_emb_qkv_( + qkv, + self._cos_cached, + self._sin_cached, + self._cos_k_cached, + self._sin_k_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + else: + q = qkv + q = apply_rotary_emb_func( + q, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + inplace=True, + seqlen_offsets=seqlen_offset, + ) + if self.scale is None: + kv = apply_rotary_emb_kv_( + kv, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + else: + kv = apply_rotary_emb_kv_( + kv, + self._cos_k_cached, + self._sin_k_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + return q, kv diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py index 8a839875de2a2..90d28872d3cc8 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_flash_attn.py @@ -20,6 +20,7 @@ from bert_padding import pad_input, unpad_input from einops import rearrange, repeat from onnx import TensorProto, helper +from rotary_flash import apply_rotary_emb from onnxruntime import InferenceSession, OrtValue, SessionOptions @@ -184,7 +185,13 @@ def create_multihead_attention_graph(config): def create_group_query_attention_graph_prompt( - config, past_kv_format=Formats.BSNH, share_buffer=True, local_window_size=-1 + config, + past_kv_format=Formats.BSNH, + share_buffer=True, + local_window_size=-1, + rotary=False, + rotary_interleaved=False, + packed=False, ): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length @@ -193,18 +200,22 @@ def create_group_query_attention_graph_prompt( "GroupQueryAttention", [ "query", - "key", - "value", + "key" if not packed else "", + "value" if not packed else "", "past_key" if share_buffer else "", "past_value" if share_buffer else "", "seqlens_k", "total_sequence_length", + "cos_cache" if rotary else "", + "sin_cache" if rotary else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, local_window_size=local_window_size, + do_rotary=rotary, + rotary_interleaved=rotary_interleaved, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -218,25 +229,9 @@ def create_group_query_attention_graph_prompt( [ config.batch_size, config.q_sequence_length, - config.num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "key", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "value", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads * config.head_size, + (config.num_heads * config.head_size) + if not packed + else (config.num_heads * config.head_size + 2 * 
config.kv_num_heads * config.head_size), ], ), helper.make_tensor_value_info( @@ -250,6 +245,27 @@ def create_group_query_attention_graph_prompt( [1], ), ] + if not packed: + graph_input += [ + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + ] if share_buffer: graph_input += [ helper.make_tensor_value_info( @@ -273,6 +289,25 @@ def create_group_query_attention_graph_prompt( ], ), ] + if rotary: + graph_input += [ + helper.make_tensor_value_info( + "cos_cache", + TensorProto.FLOAT16, + [ + config.buffer_sequence_length if share_buffer else config.kv_sequence_length, + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + helper.make_tensor_value_info( + "sin_cache", + TensorProto.FLOAT16, + [ + config.buffer_sequence_length if share_buffer else config.kv_sequence_length, + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + ] graph_output = [ helper.make_tensor_value_info( @@ -334,7 +369,13 @@ def create_group_query_attention_graph_prompt( def create_group_query_attention_graph_past( - config, past_kv_format=Formats.BSNH, share_buffer=True, local_window_size=-1 + config, + past_kv_format=Formats.BSNH, + share_buffer=True, + local_window_size=-1, + rotary=False, + rotary_interleaved=False, + packed=False, ): past_kv_seqlen = config.kv_sequence_length present_kv_seqlen = ( @@ -345,18 +386,22 @@ def create_group_query_attention_graph_past( "GroupQueryAttention", [ "query", - "key", - "value", + "key" if not packed else "", + "value" if not packed else "", "past_key", "past_value", "seqlens_k", "total_sequence_length", + "cos_cache" if rotary else "", + "sin_cache" if rotary else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, local_window_size=local_window_size, + do_rotary=rotary, + rotary_interleaved=rotary_interleaved, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -370,25 +415,9 @@ def create_group_query_attention_graph_past( [ config.batch_size, config.sequence_length, - config.num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "key", - TensorProto.FLOAT16, - [ - config.batch_size, - config.sequence_length, - config.kv_num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "value", - TensorProto.FLOAT16, - [ - config.batch_size, - config.sequence_length, - config.kv_num_heads * config.head_size, + (config.num_heads * config.head_size) + if not packed + else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size), ], ), helper.make_tensor_value_info( @@ -411,8 +440,6 @@ def create_group_query_attention_graph_past( config.head_size, ], ), - ] - graph_input += [ helper.make_tensor_value_info( "seqlens_k", TensorProto.INT32, @@ -424,6 +451,46 @@ def create_group_query_attention_graph_past( [1], ), ] + if not packed: + graph_input += [ + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.kv_num_heads * 
config.head_size, + ], + ), + ] + if rotary: + graph_input += [ + helper.make_tensor_value_info( + "cos_cache", + TensorProto.FLOAT16, + [ + config.kv_sequence_length + (0 if share_buffer else config.sequence_length), + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + helper.make_tensor_value_info( + "sin_cache", + TensorProto.FLOAT16, + [ + config.kv_sequence_length + (0 if share_buffer else config.sequence_length), + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + ] graph_output = [ helper.make_tensor_value_info( @@ -663,21 +730,38 @@ def mha_func(q, k, v, config): def gqa_prompt_func( - q, k, v, config, new_k, new_v, seqlens_k=None, window_size=-1, past_kv_format=Formats.BSNH, share_buffer=True + q, + k, + v, + config, + new_k, + new_v, + cos=None, + sin=None, + seqlens_k=None, + window_size=-1, + past_kv_format=Formats.BSNH, + share_buffer=True, + rotary_interleaved=False, ): onnx_model_str = create_group_query_attention_graph_prompt( - config, past_kv_format, share_buffer, local_window_size=window_size + config, + past_kv_format, + share_buffer, + local_window_size=window_size, + rotary=cos is not None, + rotary_interleaved=rotary_interleaved, + packed=new_k is None, ) q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) past_k = k.clone() if share_buffer else None past_v = v.clone() if share_buffer else None - new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) - new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) + if new_k is not None: + new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) + new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) if share_buffer: ort_inputs = { "query": q.detach().cpu().numpy(), - "key": new_k.detach().cpu().numpy(), - "value": new_v.detach().cpu().numpy(), "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0), "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), @@ -686,9 +770,17 @@ def gqa_prompt_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) io_binding.bind_input( "past_key", "cuda", 0, numpy.float16, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() ) @@ -713,17 +805,23 @@ def gqa_prompt_func( else: ort_inputs = { "query": q.detach().cpu().numpy(), - "key": new_k.detach().cpu().numpy(), - "value": new_v.detach().cpu().numpy(), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } sess_options = SessionOptions() ort_session = 
InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) io_binding.bind_output("output") @@ -737,21 +835,38 @@ def gqa_prompt_func( def gqa_past_func( - q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_format=Formats.BSNH, share_buffer=True, window_size=-1 + q, + k, + v, + config, + new_k, + new_v, + cos=None, + sin=None, + seqlens_k=None, + past_kv_format=Formats.BSNH, + share_buffer=True, + window_size=-1, + rotary_interleaved=False, ): onnx_model_str = create_group_query_attention_graph_past( - config, past_kv_format, share_buffer, local_window_size=window_size + config, + past_kv_format, + share_buffer, + local_window_size=window_size, + rotary=cos is not None, + rotary_interleaved=rotary_interleaved, + packed=new_k is None, ) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) past_k = k.clone() past_v = v.clone() - new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) - new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) + if new_k is not None: + new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) + new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) if share_buffer: ort_inputs = { "query": q.detach().cpu().numpy(), - "key": new_k.detach().cpu().numpy(), - "value": new_v.detach().cpu().numpy(), "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0), "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), @@ -763,9 +878,17 @@ def gqa_past_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) io_binding.bind_input( "past_key", "cuda", 0, numpy.float16, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() ) @@ -790,8 +913,6 @@ def gqa_past_func( else: ort_inputs = { "query": 
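All of these helpers use the same I/O-binding pattern, reduced here to its core (a sketch; onnx_model_str, query_numpy and past_k_numpy are placeholders): host-resident feeds go through bind_cpu_input, while shared KV buffers stay on CUDA as OrtValues bound by pointer.

import numpy
from onnxruntime import InferenceSession, OrtValue, SessionOptions

sess = InferenceSession(onnx_model_str, SessionOptions(), providers=["CUDAExecutionProvider"])
io_binding = sess.io_binding()
io_binding.bind_cpu_input("query", query_numpy)                   # copied host -> device
past_key = OrtValue.ortvalue_from_numpy(past_k_numpy, "cuda", 0)  # already device-resident
io_binding.bind_input("past_key", "cuda", 0, numpy.float16, past_key.shape(), past_key.data_ptr())
io_binding.bind_output("output")
sess.run_with_iobinding(io_binding)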
q.detach().cpu().numpy(), - "key": new_k.detach().cpu().numpy(), - "value": new_v.detach().cpu().numpy(), "past_key": past_k.detach().cpu().numpy(), "past_value": past_v.detach().cpu().numpy(), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), @@ -805,9 +926,17 @@ def gqa_past_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) io_binding.bind_cpu_input("past_key", ort_inputs["past_key"]) io_binding.bind_cpu_input("past_value", ort_inputs["past_value"]) io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) @@ -1029,9 +1158,12 @@ def parity_check_mha( def parity_check_gqa_prompt( config, - causal=False, + causal=True, local=False, past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, rtol=1e-3, atol=1e-3, ): @@ -1080,6 +1212,8 @@ def parity_check_gqa_prompt( dtype=torch.float16, requires_grad=False, ) + # print(k.shape) + # print(new_k.shape) window_size = (-1, -1) left_window_size = -1 @@ -1105,19 +1239,47 @@ def parity_check_gqa_prompt( # device="cuda", # ) # cache_seqlens[random.randint(0, cache_seqlens.size(dim=0) - 1)] = config.kv_sequence_length + rotary_seqlens = torch.tensor([0], device="cuda").repeat(config.batch_size) + + if rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch.float16) + sin = torch.sin(angle).to(dtype=torch.float16) + if causal or local: + q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=config.q_sequence_length, + ) + # q_ro = q + k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + else: + cos, sin = None, None + q_ro, k_ro = q, new_k + rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") arange = rearrange(torch.arange(config.buffer_sequence_length, device="cuda"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") kv_seqlens = torch.tensor([config.kv_sequence_length], device="cuda").repeat(config.batch_size) kv_seqlens_expanded = rearrange(kv_seqlens, "b -> b 1") update_mask = arange < kv_seqlens_expanded - k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...") + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") v_cache_ref[update_mask] = rearrange(new_v, "b s ... 
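The reference cache update above scatters the new keys into the first kv_seqlen slots of each batch's buffer through a boolean mask; the same move in miniature (illustrative sizes):

import torch
from einops import rearrange

B, S_buf, S_new, H, D = 2, 8, 3, 2, 4
cache = torch.zeros(B, S_buf, H, D)
new_k = torch.randn(B, S_new, H, D)
arange = rearrange(torch.arange(S_buf), "s -> 1 s")
update_mask = (arange < S_new).expand(B, S_buf)          # first S_new slots per batch
cache[update_mask] = rearrange(new_k, "b s ... -> (b s) ...")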
-> (b s) ...") k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded out_ref, _ = attention_ref( - q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1125,13 +1287,47 @@ def parity_check_gqa_prompt( v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_prompt_func( - q, k, v, config, new_k, new_v, cache_seqlens, left_window_size, past_format, True - ) + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_prompt_func( + packed_qkv, + k, + v, + config, + None, + None, + cos, + sin, + cache_seqlens, + left_window_size, + past_format, + True, + rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_prompt_func( + q, + k, + v, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + left_window_size, + past_format, + True, + rotary_interleaved, + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + # print(cache_seqlens[0]) + # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, :, 0]) + # print((out - out_ref)[0, :, 0, 0]) + # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1139,10 +1335,16 @@ def parity_check_gqa_prompt( # Compare results print( "KV-buffer", + " packed:", + packed, " causal:", causal, " local:", local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1171,9 +1373,12 @@ def parity_check_gqa_prompt( def parity_check_gqa_prompt_no_buff( config, - causal=False, + causal=True, local=False, past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, rtol=1e-3, atol=1e-3, ): @@ -1229,13 +1434,42 @@ def parity_check_gqa_prompt_no_buff( # device="cuda", # ) # cache_seqlens[random.randint(0, cache_seqlens.size(dim=0) - 1)] = config.kv_sequence_length + rotary_seqlens = torch.tensor([0], device="cuda").repeat(config.batch_size) + + if rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch.float16) + sin = torch.sin(angle).to(dtype=torch.float16) + if causal or local: + q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=config.q_sequence_length, + ) + # q_ro = q + k_ro = apply_rotary_emb(k_cache_ref, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + else: + cos, sin = None, None + q_ro, k_ro = q, k_cache_ref + k_cache_ref = 
k_ro + brange = rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") new_mask = brange < cache_seqlens_expanded k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) out_ref, _ = attention_ref( - q, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size + q_ro, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1243,9 +1477,39 @@ def parity_check_gqa_prompt_no_buff( v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_prompt_func( - q, None, None, config, new_k, new_v, cache_seqlens, left_window_size, past_format, False - ) + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_prompt_func( + packed_qkv, + None, + None, + config, + None, + None, + cos, + sin, + cache_seqlens, + left_window_size, + past_format, + False, + rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_prompt_func( + q, + None, + None, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + left_window_size, + past_format, + False, + rotary_interleaved, + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() @@ -1256,7 +1520,17 @@ def parity_check_gqa_prompt_no_buff( # Compare results print( - "KV-buffer", + "No buff", + " packed:", + packed, + " causal:", + causal, + " local:", + local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1285,9 +1559,12 @@ def parity_check_gqa_prompt_no_buff( def parity_check_gqa_past( config, - causal=False, + causal=True, local=False, past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, rtol=1e-3, atol=1e-3, ): @@ -1336,6 +1613,7 @@ def parity_check_gqa_past( dtype=torch.float16, requires_grad=False, ) + window_size = (-1, -1) left_window_size = -1 if local: @@ -1359,18 +1637,45 @@ def parity_check_gqa_past( dtype=torch.int32, device="cuda", ) + + if rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch.float16) + sin = torch.sin(angle).to(dtype=torch.float16) + if causal or local: + q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=cache_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=config.sequence_length, + ) + # q_ro = q + k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) + else: + cos, sin = None, None + q_ro, k_ro = q, new_k + arange = rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") update_mask = torch.logical_and( cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length ) - 
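In the decode-time parity checks above, seqlen_offsets=cache_seqlens is what places each appended token at its absolute rotary position rather than at zero; the position arithmetic by itself (illustrative):

import torch

cache_seqlens = torch.tensor([5, 9])  # past length per batch entry
s_new = 1                             # tokens appended this step
positions = cache_seqlens[:, None] + torch.arange(s_new)[None, :]
# q[b, i] is rotated with cos/sin row positions[b, i], matching the flash kernel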
k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...") + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length out_ref, _ = attention_ref( - q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1378,13 +1683,46 @@ def parity_check_gqa_past( v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_past_func( - q, k, v, config, new_k, new_v, cache_seqlens, past_format, True, left_window_size - ) + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_past_func( + packed_qkv, + k, + v, + config, + None, + None, + cos, + sin, + cache_seqlens, + past_format, + True, + left_window_size, + rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_past_func( + q, + k, + v, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + past_format, + True, + left_window_size, + rotary_interleaved, + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + # print(cache_seqlens[0]) + # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, cache_seqlens[0], :]) + # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1394,10 +1732,16 @@ def parity_check_gqa_past( "KV-buffer", "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", + " packed:", + packed, " causal:", causal, " local:", local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, " B:", config.batch_size, " S:", @@ -1427,6 +1771,9 @@ def parity_check_gqa_past_no_buff( causal=False, local=False, past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, rtol=1e-3, atol=1e-3, ): @@ -1503,18 +1850,47 @@ def parity_check_gqa_past_no_buff( device="cuda", ) cache_seqlens[random.randint(0, config.batch_size - 1)] = config.kv_sequence_length + + if rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = ( + torch.rand(config.kv_sequence_length + config.sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi + ) + cos = torch.cos(angle).to(dtype=torch.float16) + sin = torch.sin(angle).to(dtype=torch.float16) + if causal or local: + q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=cache_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=config.sequence_length, + ) + # q_ro = q + k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) + else: + cos, sin = None, 
None + q_ro, k_ro = q, new_k + arange = rearrange(torch.arange(config.kv_sequence_length + config.sequence_length, device="cuda"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") update_mask = torch.logical_and( cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length ) - k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...") + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length out_ref, _ = attention_ref( - q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1522,13 +1898,47 @@ def parity_check_gqa_past_no_buff( v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_past_func( - q, k, v, config, new_k, new_v, cache_seqlens, past_format, False, window_size=left_window_size - ) + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_past_func( + packed_qkv, + k, + v, + config, + None, + None, + cos, + sin, + cache_seqlens, + past_format, + False, + window_size=left_window_size, + rotary_interleaved=rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_past_func( + q, + k, + v, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + past_format, + False, + window_size=left_window_size, + rotary_interleaved=rotary_interleaved, + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + # print(cache_seqlens[0]) + # print((out - out_ref)[0]) + # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, :, 0]) + # Make sure past-present buffer updating correctly # assert numpy.allclose( # present_k[:, :, :-1, :], k_cache_ref.detach().cpu().numpy()[:, :, :-1, :], rtol=rtol, atol=atol, equal_nan=True @@ -1540,10 +1950,16 @@ def parity_check_gqa_past_no_buff( # Compare results print( "NO buff", + " packed:", + packed, " causal:", causal, " local:", local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1671,10 +2087,25 @@ def test_gqa_no_past(self): for n, n2 in num_h: for h in h_sizes: for local in [False, True]: - for past_kv_format in [Formats.BNSH]: - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - parity_check_gqa_prompt(config, local=local, past_format=past_kv_format) - parity_check_gqa_prompt_no_buff(config, local=local, past_format=past_kv_format) + for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: + for past_kv_format, packed in [(Formats.BNSH, False), (Formats.BNSH, True)]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + parity_check_gqa_prompt( + config, + local=local, + past_format=past_kv_format, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + parity_check_gqa_prompt_no_buff( + config, + local=local, + 
past_format=past_kv_format, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) def test_gqa_past(self): if not torch.cuda.is_available(): @@ -1684,7 +2115,6 @@ def test_gqa_past(self): return os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" print("-------- TEST GQA PAST (TOKEN GEN) ---------") - print("-------- MEMORY EFFICIENT (TOKEN GEN) --------") batches = [5] if pipeline_mode else [1, 3, 5] seqs = ( [(1, 128), (1, 1024), (1, 2048)] @@ -1706,6 +2136,7 @@ def test_gqa_past(self): num_h = [(32, 32), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) + print("-------- MEMORY EFFICIENT (TOKEN GEN) --------") for b in batches: for s, s2 in seqs: for n, n2 in num_h: @@ -1734,23 +2165,30 @@ def test_gqa_past(self): for n, n2 in num_h: for h in h_sizes: for local in [False, True]: - for past_kv_format in [Formats.BNSH]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - parity_check_gqa_past( - config, - local=local, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) - parity_check_gqa_past_no_buff( - config, - local=local, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) + for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: + for past_kv_format, packed in [(Formats.BNSH, False), (Formats.BNSH, True)]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + parity_check_gqa_past_no_buff( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) if __name__ == "__main__": diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 1034a82cb2854..6e5cd7b57e403 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -2046,7 +2046,8 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): numpy_init_version = numpy.__version__ pb_init_version = google.protobuf.__version__ run_subprocess( - [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], cwd=SCRIPT_DIR + [sys.executable, "-m", "pip", "install", "-r", "requirements-transformers-test.txt"], + cwd=SCRIPT_DIR, ) run_subprocess([sys.executable, "-m", "pytest", "transformers"], cwd=cwd) # Restore initial numpy/protobuf version in case other tests use it diff --git a/tools/ci_build/requirements.txt b/tools/ci_build/requirements-transformers-test.txt similarity index 94% rename from tools/ci_build/requirements.txt rename to tools/ci_build/requirements-transformers-test.txt index 57fc8f08336d2..a5279781462a7 100644 --- a/tools/ci_build/requirements.txt +++ b/tools/ci_build/requirements-transformers-test.txt @@ -3,7 +3,8 @@ packaging protobuf==3.20.2 numpy==1.24.0 ; python_version < '3.12' numpy==1.26.0 ; python_version >= '3.12' +torch coloredlogs==15.0 transformers==4.36.0 psutil -einops \ No newline at end of file +einops From 6a424ccf8c2f9cd7f191c843547d5f37ef409493 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Wed, 24 Jan 2024 03:33:49 +0000 Subject: [PATCH 10/23] Fix AMD pipeline test failures (#19250) ### Description Fix amd test failure ### Motivation and Context --- 
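For reference, the GQA parity tests above rotate `q` and `k` with `apply_rotary_emb` before comparing against the fused kernel, using per-position `cos`/`sin` tables of width `rotary_dim // 2`. Below is a minimal NumPy sketch of that reference rotation covering both the half-split and interleaved layouts; it is illustrative only, not the flash-attn implementation the tests import:

```python
import numpy as np

def rotary_ref(x, cos, sin, interleaved=False):
    """Reference rotary embedding. x: (seq, heads, dim); cos/sin: (seq, rotary_dim // 2)."""
    rot = 2 * cos.shape[-1]
    x_rot, x_pass = x[..., :rot], x[..., rot:]
    if interleaved:
        x1, x2 = x_rot[..., 0::2], x_rot[..., 1::2]  # even/odd lanes form the rotated pairs
    else:
        x1, x2 = np.split(x_rot, 2, axis=-1)         # first/second half form the rotated pairs
    c, s = cos[:, None, :], sin[:, None, :]          # broadcast over the heads axis
    o1, o2 = x1 * c - x2 * s, x1 * s + x2 * c
    if interleaved:
        out = np.empty_like(x_rot)
        out[..., 0::2], out[..., 1::2] = o1, o2
    else:
        out = np.concatenate([o1, o2], axis=-1)
    return np.concatenate([out, x_pass], axis=-1)    # the tail past rotary_dim is untouched
```

The `seqlen_offsets` argument in the tests shifts which rows of `cos`/`sin` each token uses, which is why the token-generation cases pass `cache_seqlens` as the offset.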
onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu | 5 +++--
 onnxruntime/contrib_ops/rocm/bert/multihead_attention.h  | 1 +
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu
index 6f98312e4067d..09e7d61b71db9 100644
--- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu
+++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu
@@ -68,6 +68,7 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info)
   scale_ = info.GetAttrOrDefault("scale", 0.0f);
   past_present_share_buffer_ = info.GetAttrOrDefault("past_present_share_buffer", 0LL) != 0LL;
+  is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1;

   using HipT = typename ToHipType<T>::MappedType;
   using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp;
@@ -121,8 +122,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
       query, key, value, bias,
       key_padding_mask, relative_position_bias,
       past_key, past_value, past_seq_len,
-      &attn,
-      num_heads_, mask_filter_value_, scale_,
+      &attn, num_heads_,
+      mask_filter_value_, scale_, false, /*is_unidirectional_*/
       past_present_share_buffer_, false, device_prop.maxThreadsPerBlock));

   if (attn_type_ == kDecoderMaskedMultiHeadAttention && attn.sequence_length != 1) {
diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h
index 84d8b76bbfebe..1d676d7a7bcac 100644
--- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h
+++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h
@@ -25,6 +25,7 @@ class MultiHeadAttention final : public RocmKernel {
   float mask_filter_value_;
   float scale_;
   bool past_present_share_buffer_{false};
+  bool is_unidirectional_{false};

   // type-erased GemmSoftmaxGemmPermuteTunableOp, the reason for this is:
   //   1. We don't want to include the cuh file where GemmSoftmaxGemmPermuteTunableOp is defined.

From c10be1848cafa7575ba298cbcc01e89dcd841851 Mon Sep 17 00:00:00 2001
From: Chi Lo <54722500+chilo-ms@users.noreply.github.com>
Date: Tue, 23 Jan 2024 21:30:22 -0800
Subject: [PATCH 11/23] [TensorRT EP] Avoid calling unavailable function with cpu python package (#19251)

`C.register_tensorrt_plugins_as_custom_ops()` is only available in the GPU python package. Add a condition to avoid calling it in the CPU python package.
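The guard added in the diff below boils down to: only call the TensorRT-only binding when the loaded native build actually exposes that provider. A simplified sketch of the pattern using the public Python API (not the exact session code in the patch):

```python
import onnxruntime as ort

def is_tensorrt(provider):
    # Providers may be passed as plain names or as (name, options) tuples.
    name = provider[0] if isinstance(provider, (tuple, list)) else provider
    return name == "TensorrtExecutionProvider"

available = ort.get_available_providers()  # e.g. only CPU EPs in the cpu package

requested = ["TensorrtExecutionProvider", "CPUExecutionProvider"]
for provider in requested:
    if is_tensorrt(provider) and "TensorrtExecutionProvider" in available:
        # Safe: the GPU build is loaded, so the TensorRT-only registration exists.
        pass
```

Checking membership in `get_available_providers()` first means the CPU package simply skips the registration instead of raising an attribute error.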
---
 .../python/onnxruntime_inference_collection.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py
index 1a3e22142f80e..09f768f53ea65 100644
--- a/onnxruntime/python/onnxruntime_inference_collection.py
+++ b/onnxruntime/python/onnxruntime_inference_collection.py
@@ -466,7 +466,7 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi

         session_options = self._sess_options if self._sess_options else C.get_default_session_options()

-        self._register_ep_custom_ops(session_options, providers, provider_options)
+        self._register_ep_custom_ops(session_options, providers, provider_options, available_providers)

         if self._model_path:
             sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
@@ -510,11 +510,15 @@ def _reset_session(self, providers, provider_options):
         self._sess_options = self._sess_options_initial
         self._create_inference_session(providers, provider_options)

-    def _register_ep_custom_ops(self, session_options, providers, provider_options):
+    def _register_ep_custom_ops(self, session_options, providers, provider_options, available_providers):
         for i in range(len(providers)):
-            if providers[i] == "TensorrtExecutionProvider":
+            if providers[i] in available_providers and providers[i] == "TensorrtExecutionProvider":
                 C.register_tensorrt_plugins_as_custom_ops(session_options, provider_options[i])
-            elif isinstance(providers[i], tuple) and providers[i][0] == "TensorrtExecutionProvider":
+            elif (
+                isinstance(providers[i], tuple)
+                and providers[i][0] in available_providers
+                and providers[i][0] == "TensorrtExecutionProvider"
+            ):
                 C.register_tensorrt_plugins_as_custom_ops(session_options, providers[i][1])

From d7aebf9ea8a4a651088384f219292bae9062439b Mon Sep 17 00:00:00 2001
From: Yi Zhang
Date: Wed, 24 Jan 2024 14:15:07 +0800
Subject: [PATCH 12/23] Move Nuget Test from T4 to A10 to reduce release duration (#19253)

### Description

### Motivation and Context
Running the release process is very painful and tedious because some GPU jobs have to wait for a long time.

![image](https://github.com/microsoft/onnxruntime/assets/16190118/1c5c981e-68d4-4678-9758-443fbf362802)
![image](https://github.com/microsoft/onnxruntime/assets/16190118/ba0d79ba-1554-4c7a-93dd-6ea8144c9295)
![image](https://github.com/microsoft/onnxruntime/assets/16190118/36cab833-71c1-4ff5-bca5-f4caa9aee0c9)

On the one hand, we could move some T4 machines away from the PR process since some jobs are not using T4 any more; on the other hand, we can continue to change some jobs' agents from T4 to A10 too. In the future, T4 will mainly be used for scenarios that need big GPU memory, multiple GPU cards, or some other special cases.
Test runs: https://dev.azure.com/aiinfra/Lotus/_build/results?buildId=401786&view=logs&j=8048494c-e6eb-5e47-5e87-ff0aa863325d cc @YUNQIUGUO @snnn --- .../c-api-noopenmp-packaging-pipelines.yml | 8 ++++---- .../github/azure-pipelines/cuda-packaging-pipeline.yml | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index aa1a75bfcda45..5a50a9964bead 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -1023,7 +1023,7 @@ stages: - template: nuget/templates/test_win.yml parameters: - AgentPool : 'onnxruntime-Win2022-GPU-T4' + AgentPool : 'onnxruntime-Win2022-GPU-A10' NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu' ArtifactSuffix: 'GPU' StageSuffix: 'GPU' @@ -1034,7 +1034,7 @@ stages: - template: nuget/templates/test_win.yml parameters: - AgentPool : 'onnxruntime-Win2022-GPU-T4' + AgentPool : 'onnxruntime-Win2022-GPU-A10' NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu.Windows' ArtifactSuffix: 'GPU' StageSuffix: 'GPU' @@ -1046,7 +1046,7 @@ stages: - template: nuget/templates/test_linux.yml parameters: - AgentPool : Onnxruntime-Linux-GPU + AgentPool : Onnxruntime-Linux-GPU-A10 ArtifactSuffix: 'GPU' StageSuffix: 'GPU' NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu' @@ -1055,7 +1055,7 @@ stages: - template: nuget/templates/test_linux.yml parameters: - AgentPool : Onnxruntime-Linux-GPU + AgentPool : Onnxruntime-Linux-GPU-A10 ArtifactSuffix: 'GPU' StageSuffix: 'GPU' MoreSuffix: '_Linux' diff --git a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml index 1d2ba88652f48..0c24d4897ddf1 100644 --- a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml @@ -151,7 +151,7 @@ stages: # Testing - template: nuget/templates/test_win.yml parameters: - AgentPool : 'onnxruntime-Win2022-GPU-T4' + AgentPool : 'onnxruntime-Win2022-GPU-A10' NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu' ArtifactSuffix: 'GPU' StageSuffix: 'GPU' @@ -162,7 +162,7 @@ stages: - template: nuget/templates/test_win.yml parameters: - AgentPool : 'onnxruntime-Win2022-GPU-T4' + AgentPool : 'onnxruntime-Win2022-GPU-A10' NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu.Windows' ArtifactSuffix: 'GPU' StageSuffix: 'GPU' @@ -174,7 +174,7 @@ stages: - template: nuget/templates/test_linux.yml parameters: - AgentPool : Onnxruntime-Linux-GPU + AgentPool : Onnxruntime-Linux-GPU-A10 ArtifactSuffix: 'GPU' StageSuffix: 'GPU' NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu' @@ -184,7 +184,7 @@ stages: - template: nuget/templates/test_linux.yml parameters: - AgentPool : Onnxruntime-Linux-GPU + AgentPool : Onnxruntime-Linux-GPU-A10 ArtifactSuffix: 'GPU' StageSuffix: 'GPU' MoreSuffix: '_Linux' From a39ac4a97976c9bea49be6e646ac1fac64278f65 Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Wed, 24 Jan 2024 10:06:31 -0800 Subject: [PATCH 13/23] [DirectML] Register Pad19 (#19175) ### Description Register Pad19 in DirectML --------- Co-authored-by: Sheil Kumar --- .../src/Operators/DmlOperatorPadding.cpp | 7 +++++++ .../src/Operators/OperatorRegistration.cpp | 6 ++++++ .../core/providers/dml/OperatorAuthorHelper/Attributes.h | 1 + .../providers/dml/OperatorAuthorHelper/OperatorHelper.h | 1 + 
.../providers/dml/OperatorAuthorHelper/OperatorVersions.h | 1 + 5 files changed, 16 insertions(+) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp index a014db5adbe61..b243f7e741a70 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp @@ -51,6 +51,12 @@ class DmlOperatorPadding : public DmlOperator, public PaddingHelper { mode = DML_PADDING_MODE_REFLECTION; } +#if DML_TARGET_VERSION >= 0x6300 + else if (modeString == AttrValue::Wrap) + { + mode = DML_PADDING_MODE_WRAP; + } +#endif else { ML_INVALID_ARGUMENT("Unknown Pad mode attribute."); @@ -116,5 +122,6 @@ DML_OP_DEFINE_CREATION_FUNCTION(Pad7, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Pad11, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Pad13, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Pad18, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(Pad19, VersionedKernel); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 18e29c8b99ced..7b53a1102c5a7 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -358,6 +358,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Pad7); DML_OP_EXTERN_CREATION_FUNCTION(Pad11); DML_OP_EXTERN_CREATION_FUNCTION(Pad13); DML_OP_EXTERN_CREATION_FUNCTION(Pad18); +DML_OP_EXTERN_CREATION_FUNCTION(Pad19); DML_OP_EXTERN_CREATION_FUNCTION(SpaceToDepth); DML_OP_EXTERN_CREATION_FUNCTION(DepthToSpace); DML_OP_EXTERN_CREATION_FUNCTION(Sqrt); @@ -747,6 +748,11 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 {REG_INFO_VER( 13, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 {REG_INFO_VER( 18, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*pads, value, axes*/)}, + +#if DML_TARGET_VERSION >= 0x6300 + {REG_INFO_VER( 19, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*pads, value, axes*/)}, +#endif + {REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO( 13, SpaceToDepth, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index e3df1d00b3e8a..9c5d021f52b36 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -149,5 +149,6 @@ namespace AttrValue 
static constexpr const char* NearestNeighbor = "NN"; static constexpr const char* NotSet = "NOTSET"; static constexpr const char* Reflect = "reflect"; + static constexpr const char* Wrap = "wrap"; } // namespace AttrValue diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 0d425997e6a6a..d4b44f6fa8a9d 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1589,6 +1589,7 @@ using ShapeInferenceHelper_Pad7 = VersionedOpsetHelper; using ShapeInferenceHelper_Pad11 = VersionedOpsetHelper; using ShapeInferenceHelper_Pad13 = VersionedOpsetHelper; using ShapeInferenceHelper_Pad18 = VersionedOpsetHelper; +using ShapeInferenceHelper_Pad19 = VersionedOpsetHelper; using ShapeInferenceHelper_SpaceToDepth = SpaceToDepthHelper; using ShapeInferenceHelper_DepthToSpace = DepthToSpaceHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 79efc2d2836fe..57cb009b72ebc 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -413,6 +413,7 @@ namespace OperatorHelper namespace OnnxOperatorSet19 { static const int sc_sinceVer_AveragePool = 19; + static const int sc_sinceVer_Pad = 19; static const int sc_sinceVer_Cast = 19; static const int sc_sinceVer_CastLike = 19; static const int sc_sinceVer_Constant = 19; From a33b5bd1fa5ac6d9aabb23cd8aca16b5fc3fc3c5 Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Thu, 25 Jan 2024 01:12:21 +0530 Subject: [PATCH 14/23] [JS/WebGPU] Added Uniforms to SkipLayerNorm. 
(#18788) ### Description Added Uniforms to SkipLayerNorm ### Motivation and Context Improve performance --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 4 +- .../wasm/jsep/webgpu/ops/skip-layer-norm.ts | 123 ++++++++++-------- 2 files changed, 69 insertions(+), 58 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index cc504093ca0d7..d737a28654220 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -25,7 +25,7 @@ import * as pool from './ops/pool'; import {range} from './ops/range'; import {reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; import {parseResizeAttributes, resize} from './ops/resize'; -import {parseSkipLayerNormAttributes, skipLayerNorm} from './ops/skip-layer-norm'; +import {skipLayerNorm} from './ops/skip-layer-norm'; import {parseSliceAttributes, slice} from './ops/slice'; import {parseSoftmaxAttributes, softmax} from './ops/softmax'; import {parseSplitAttributes, split} from './ops/split'; @@ -116,7 +116,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Sin', [unaryOps.sin]], ['Sinh', [unaryOps.sinh]], ['Slice', [slice, parseSliceAttributes]], - ['SkipLayerNormalization', [skipLayerNorm, parseSkipLayerNormAttributes]], + ['SkipLayerNormalization', [skipLayerNorm]], ['Split', [split, parseSplitAttributes]], ['Sqrt', [unaryOps.sqrt]], ['Softmax', [softmax, parseSoftmaxAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts index a2fda9f07d09f..509a722f4b52a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts @@ -4,10 +4,10 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {AttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType,} from './common'; +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; export interface SkipLayerNormAttributes extends AttributeWithCacheKey { epsilon: number; @@ -86,60 +86,74 @@ const createSkipLayerNormProgramInfo = const hasInputSkipBiasSumOutput = outputCount > 3; const components = getMaxComponents(hiddenSize); - const variables = [ - inputVariable('x', inputs[0].dataType, inputs[0].dims, components), - inputVariable('skip', inputs[1].dataType, inputs[1].dims, components), - inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components), - ]; - if (hasBetaInput) { - variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components)); - } - if (hasBiasInput) { - variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components)); - } - variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); - if (hasMeanOutput) { - 
variables.push(outputVariable('meanOutput', DataType.float, meanInvStdDevDim)); - } - if (hasInvStdDevOutput) { - variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim)); - } - if (hasInputSkipBiasSumOutput) { - variables.push(outputVariable('inputSkipBiasSum', inputs[0].dataType, outputShape, components)); - } - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const hiddenSize: f32 = ${hiddenSize}; - const hiddenSizeVectorized: u32 = ${hiddenSize / components}; - const epsilon: f32 = ${attributes.epsilon}; - ${shaderHelper.declareVariables(...variables)} + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: outputSize}, + {type: 'uint32', data: components}, + {type: 'uint32', data: hiddenSize}, + {type: 'float32', data: attributes.epsilon}, + ]; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const uniformsArray: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, + {name: 'components', type: 'u32'}, + {name: 'hidden_size', type: 'u32'}, + {name: 'epsilon', type: 'f32'}, + ]; + const variables = [ + inputVariable('x', inputs[0].dataType, inputs[0].dims, components), + inputVariable('skip', inputs[1].dataType, inputs[1].dims, components), + inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components), + ]; + if (hasBetaInput) { + variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components)); + } + if (hasBiasInput) { + variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components)); + } + variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); + if (hasMeanOutput) { + variables.push(outputVariable('mean_output', DataType.float, meanInvStdDevDim)); + } + if (hasInvStdDevOutput) { + variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim)); + } + if (hasInputSkipBiasSumOutput) { + variables.push(outputVariable('input_skip_bias_sum', inputs[0].dataType, outputShape, components)); + } + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + return ` + + ${shaderHelper.registerUniforms(uniformsArray).declareVariables(...variables)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize / hiddenSize)} - let offset = global_idx * hiddenSizeVectorized; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size / uniforms.hidden_size')} + let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components; + let offset = global_idx * hidden_size_vectorized; var sum = ${fillVector('f32', components)}; var squareSum = ${fillVector('f32', components)}; - for (var i: u32 = 0; i < hiddenSizeVectorized; i++) { - let skipValue = skip[offset + i]; - let biasValue = ${hasBiasInput ? 'bias[i]' : '0.0'}; - let inputValue = x[offset + i]; - let value = inputValue + skipValue + biasValue; - ${hasInputSkipBiasSumOutput ? 'inputSkipBiasSum[offset + i] = value;' : ''} + for (var i: u32 = 0; i < hidden_size_vectorized; i++) { + let skip_value = skip[offset + i]; + let bias_value = ${hasBiasInput ? 'bias[i]' : '0.0'}; + let input_value = x[offset + i]; + let value = input_value + skip_value + bias_value; + ${hasInputSkipBiasSumOutput ? 
'input_skip_bias_sum[offset + i] = value;' : ''} output[offset + i] = value; - let f32Value = ${castToF32(dataType, components, 'value')}; - sum += f32Value; - squareSum += f32Value * f32Value; + let f32_value = ${castToF32(dataType, components, 'value')}; + sum += f32_value; + squareSum += f32_value * f32_value; } - let mean = ${sumVector('sum', components)} / hiddenSize; - let invStdDev = inverseSqrt(${sumVector('squareSum', components)} / hiddenSize - mean * mean + epsilon); - ${hasMeanOutput ? 'meanOutput[global_idx] = mean;' : ''} - ${hasInvStdDevOutput ? 'invStdOutput[global_idx] = invStdDev;' : ''} - for (var i: u32 = 0; i < hiddenSizeVectorized; i++) { - output[offset + i] = (output[offset + i] - ${dataType}(mean)) * ${dataType}(invStdDev) * gamma[i] - + ${hasBetaInput ? 'beta[i]' : '0.0'}; + let mean = ${sumVector('sum', components)} / f32(uniforms.hidden_size); + let inv_std_dev = inverseSqrt(${ + sumVector('squareSum', components)} / f32(uniforms.hidden_size) - mean * mean + uniforms.epsilon); + ${hasMeanOutput ? 'mean_output[global_idx] = mean;' : ''} + ${hasInvStdDevOutput ? 'inv_std_output[global_idx] = inv_std_dev;' : ''} + for (var i: u32 = 0; i < hidden_size_vectorized; i++) { + output[offset + i] = (output[offset + i] - ${dataType}(mean)) * ${dataType}(inv_std_dev) * gamma[i] + ${ + hasBetaInput ? 'beta[i]' : '0.0'}; } }`; + }; const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; if (outputCount > 1) { outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); @@ -150,12 +164,14 @@ const createSkipLayerNormProgramInfo = if (outputCount > 3) { outputs.push({dims: inputShape, dataType: inputs[0].dataType}); } - return { name: 'SkipLayerNormalization', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: { + hint: `${components};${hasMeanOutput};${hasInvStdDevOutput};${hasInputSkipBiasSumOutput}`, + inputDependencies: inputs.map((_input, _index) => 'type') + }, getShaderSource, - getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}}), + getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}, programUniforms}), }; }; @@ -178,8 +194,3 @@ export const skipLayerNorm = (context: ComputeContext, attributes: SkipLayerNorm context.compute( createSkipLayerNormProgramInfo(context.inputs, attributes, context.outputCount, isTraining), {outputs}); }; - -export const parseSkipLayerNormAttributes = (attributes: Record): SkipLayerNormAttributes => { - const epsilon = attributes.epsilon as number; - return createAttributeWithCacheKey({epsilon}); -}; From a28abeb24100441c76a777f9ce225cb0ea3a59c3 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Wed, 24 Jan 2024 14:35:44 -0800 Subject: [PATCH 15/23] Change "#ifdef WIN32" to "#ifdef _WIN32" (#19254) ### Description `_WIN32` is a standard macro listed at https://learn.microsoft.com/en-us/cpp/preprocessor/predefined-macros?view=msvc-170 . But `WIN32` is not. 
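Because `_WIN32` is defined by the compiler itself while `WIN32` only appears when a build system chooses to define it, guards on the latter can silently drop code. A throwaway script along these lines (illustrative only, not part of the PR) can surface any remaining `WIN32` guards in a tree:

```python
import pathlib
import re

# Matches "#ifdef WIN32" and "#if defined(WIN32" but not the "_WIN32" forms.
GUARD = re.compile(r"#\s*if(?:def\s+|\s+defined\s*\(\s*)WIN32\b")

for path in pathlib.Path(".").rglob("*"):
    if not path.is_file() or path.suffix not in {".c", ".cc", ".cpp", ".cu", ".h", ".hpp"}:
        continue
    for lineno, line in enumerate(path.read_text(errors="ignore").splitlines(), 1):
        if GUARD.search(line):
            print(f"{path}:{lineno}: {line.strip()}")
```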
---
 .../main/native/ai_onnxruntime_OrtSession_SessionOptions.c | 4 ++--
 onnxruntime/core/mlas/lib/amx_common.h                     | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c
index 3a1c0d1bb8fa1..4a5e2b7ef3b1e 100644
--- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c
+++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c
@@ -8,7 +8,7 @@
 #include "onnxruntime/core/session/onnxruntime_c_api.h"
 #include "OrtJniUtil.h"
 #include "ai_onnxruntime_OrtSession_SessionOptions.h"
-#ifdef WIN32
+#ifdef _WIN32
 #include <windows.h>
 #else
 #include <dlfcn.h>
@@ -318,7 +318,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_closeC
   // Iterate the handles, calling the appropriate close function
   for (jint i = 0; i < numHandles; i++) {
-#ifdef WIN32
+#ifdef _WIN32
     FreeLibrary((void*)handles[i]);
 #else
     dlclose((void*)handles[i]);
diff --git a/onnxruntime/core/mlas/lib/amx_common.h b/onnxruntime/core/mlas/lib/amx_common.h
index 3eb0700932faa..caf94af02362d 100644
--- a/onnxruntime/core/mlas/lib/amx_common.h
+++ b/onnxruntime/core/mlas/lib/amx_common.h
@@ -18,7 +18,7 @@ Module Name:

 #include "mlasi.h"

-#ifdef WIN32
+#ifdef _WIN32
 #define tile_dpbssd(dst, src1, src2) _tile_dpbssd(dst, src1, src2)

 #define tile_dpbsud(dst, src1, src2) _tile_dpbsud(dst, src1, src2)

From bc54ad3f03d7ee333f5e0c62ebf892c32f8a51a5 Mon Sep 17 00:00:00 2001
From: Changming Sun
Date: Wed, 24 Jan 2024 14:37:39 -0800
Subject: [PATCH 16/23] Update abseil to a release tag and register neural_speed (#19255)

### Description
Update abseil to a release tag and register neural_speed to CG.

### Motivation and Context
Now we are using a non-released version of abseil. Using a tag is better.

---
 cgmanifests/generated/cgmanifest.json           | 12 +++++++++++-
 cmake/deps.txt                                  |  3 ++-
 cmake/external/abseil-cpp.cmake                 |  2 +-
 cmake/external/abseil-cpp.natvis                | 10 +++++-----
 cmake/external/neural_speed.cmake               |  9 +++------
 .../azure-pipelines/templates/download-deps.yml |  4 ++--
 6 files changed, 24 insertions(+), 16 deletions(-)

diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json
index bcd0b2a92a5c3..03e3f84547a68 100644
--- a/cgmanifests/generated/cgmanifest.json
+++ b/cgmanifests/generated/cgmanifest.json
@@ -36,7 +36,7 @@
       "component": {
         "type": "git",
         "git": {
-          "commitHash": "dcd5bd5fd593e31465af3d9ef291d26c646b0a4f",
+          "commitHash": "4a2c63365eff8823a5221db86ef490e828306f9d",
           "repositoryUrl": "https://github.com/abseil/abseil-cpp.git"
         },
         "comments": "abseil_cpp"
@@ -192,6 +192,16 @@
         "comments": "mp11"
       }
     },
+    {
+      "component": {
+        "type": "git",
+        "git": {
+          "commitHash": "c11386eb632eec7c1c2aa323142f73519f946e2a",
+          "repositoryUrl": "https://github.com/intel/neural-speed.git"
+        },
+        "comments": "neural_speed"
+      }
+    },
     {
       "component": {
         "type": "git",
diff --git a/cmake/deps.txt b/cmake/deps.txt
index fda27e5e93797..ba9c2bb73cf7a 100644
--- a/cmake/deps.txt
+++ b/cmake/deps.txt
@@ -12,7 +12,7 @@
 # NOTE: You must run deps_update_and_upload.py and generate_cgmanifest.py when ready to test your changes in a CI.
# See https://microsoft.sharepoint.com/teams/ONNX2/_layouts/OneNote.aspx?id=%2Fteams%2FONNX2%2FShared%20Documents%2FNotebooks%2FONNX%20Ecosystem%20Team%20Notebook&wd=target%28Development.one%7C63D3AB47-51D1-4A62-9965-66882234BD44%2FAdd%20or%20update%20a%20dependency%20in%20deps.txt%7C0E9ED71D-89D5-40FA-B05F-C0123289C591%2F%29 # -abseil_cpp;https://github.com/abseil/abseil-cpp/archive/dcd5bd5fd593e31465af3d9ef291d26c646b0a4f.zip;6cc204586014e189f5c0fe3274f83162fa7c700c +abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20240116.0.zip;bc2cec6baaad67fcb6c0c38972b687d4797927e9 cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0 date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159 dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b31321e5549591d78aa7f377173445 @@ -34,6 +34,7 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 +neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip;65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11 #use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035 diff --git a/cmake/external/abseil-cpp.cmake b/cmake/external/abseil-cpp.cmake index 3bcd4109e2888..57cfbee4644ef 100644 --- a/cmake/external/abseil-cpp.cmake +++ b/cmake/external/abseil-cpp.cmake @@ -19,7 +19,7 @@ if(WIN32 AND NOT Patch_FOUND) set(ABSL_ENABLE_INSTALL ON) endif() # NB! Advancing Abseil version changes its internal namespace, -# currently absl::lts_20230125 which affects abseil-cpp.natvis debugger +# currently absl::lts_20240116 which affects abseil-cpp.natvis debugger # visualization file, that must be adjusted accordingly, unless we eliminate # that namespace at build time. 
FetchContent_Declare(
diff --git a/cmake/external/abseil-cpp.natvis b/cmake/external/abseil-cpp.natvis
index 1e5a36fb9efb9..a4fb63b6a8377 100644
--- a/cmake/external/abseil-cpp.natvis
+++ b/cmake/external/abseil-cpp.natvis
@@ -1,6 +1,6 @@
-
+
@@ -24,7 +24,7 @@
-
+
@@ -51,7 +51,7 @@
-
+
     *($T1 *){value}
     (*($T1 *){value})
@@ -60,7 +60,7 @@
-
+
     *($T1 *)this
     (*($T1 *)this)
@@ -68,7 +68,7 @@
-
+
     {value.first}, {value.second}
     ({value.first}, {value.second})
diff --git a/cmake/external/neural_speed.cmake b/cmake/external/neural_speed.cmake
index e66e2acfb209a..ed711351403a7 100644
--- a/cmake/external/neural_speed.cmake
+++ b/cmake/external/neural_speed.cmake
@@ -7,12 +7,9 @@ endif()
 if(USE_NEURAL_SPEED)
   FetchContent_Declare(
     neural_speed
-    URL https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip
-    URL_HASH SHA1=65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939
+    URL ${DEP_URL_neural_speed}
+    URL_HASH SHA1=${DEP_SHA1_neural_speed}
   )
   set(BTLA_USE_OPENMP OFF)
-  FetchContent_MakeAvailable(neural_speed)
-  if(NOT neural_speed_POPULATED)
-    FetchContent_Populate(neural_speed)
-  endif()
+  onnxruntime_fetchcontent_makeavailable(neural_speed)
 endif()
diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml
index 537175f6bec73..55f6561b7a44a 100644
--- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml
@@ -11,7 +11,7 @@ steps:
     packageType: upack
     feed: '/7424c8e4-5c62-490e-95c4-79446f31017c'
     definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0'
-    version: 1.0.129
+    version: 1.0.132
     downloadPath: $(Build.BinariesDirectory)/deps

 # The private ADO project
@@ -22,7 +22,7 @@ steps:
     packageType: upack
     feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325'
     definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a'
-    version: 1.0.129
+    version: 1.0.132
     downloadPath: $(Build.BinariesDirectory)/deps

 # You can add more ADO accounts at here.

From 591f90c0b9e8d0922fcebabffed8d54b67d7a613 Mon Sep 17 00:00:00 2001
From: Yang Gu
Date: Thu, 25 Jan 2024 06:49:37 +0800
Subject: [PATCH 17/23] [js/webgpu] Fix issue of timestamp query (#19258)

When we enable the WebGPU profiling mode between session.create and session.run, the current implementation fails to create the querySet (and also the queryResolveBuffer) if we share the commandEncoder with the inputs upload. This PR fixes this by moving the querySet creation to the place where we set queryType.
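The bug is a classic lazy-initialization ordering issue: the profiling resources were created on first use of the command encoder, so an encoder that already existed (shared with the input upload) never triggered the allocation. The same fix, restated as a tiny Python sketch purely for illustration (the real code is the TypeScript in the diff below):

```python
class BackendSketch:
    """Illustrates the ordering fix; not the actual WebGPU backend."""

    def __init__(self):
        self.command_encoder = None
        self.query_set = None
        self.query_type = "none"

    def set_query_type(self, query_type):
        self.query_type = query_type
        if query_type != "none" and self.query_set is None:
            # Fix: allocate profiling resources as soon as the mode is chosen.
            self.query_set = object()  # stand-in for device.createQuerySet(...)

    def get_command_encoder(self):
        # Buggy placement: creating query_set only inside this branch meant an
        # encoder created earlier (e.g. for input upload) skipped it entirely.
        if self.command_encoder is None:
            self.command_encoder = object()
        return self.command_encoder
```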
--- js/web/lib/wasm/jsep/backend-webgpu.ts | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index afef7042a4280..8ca025d66550c 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -222,16 +222,6 @@ export class WebGpuBackend { getCommandEncoder(): GPUCommandEncoder { if (!this.commandEncoder) { this.commandEncoder = this.device.createCommandEncoder(); - - if (this.queryType !== 'none' && typeof this.querySet === 'undefined') { - this.querySet = this.device.createQuerySet({ - type: 'timestamp', - count: this.maxDispatchNumber * 2, - }); - this.queryResolveBuffer = this.device.createBuffer( - // eslint-disable-next-line no-bitwise - {size: this.maxDispatchNumber * 2 * 8, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE}); - } } return this.commandEncoder; } @@ -654,6 +644,16 @@ export class WebGpuBackend { } else if (this.device.features.has('timestamp-query')) { this.queryType = 'at-passes'; } + + if (this.queryType !== 'none' && typeof this.querySet === 'undefined') { + this.querySet = this.device.createQuerySet({ + type: 'timestamp', + count: this.maxDispatchNumber * 2, + }); + this.queryResolveBuffer = this.device.createBuffer( + // eslint-disable-next-line no-bitwise + {size: this.maxDispatchNumber * 2 * 8, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE}); + } } } onRunStart(): void { From c456f19dbaf6b23928a60e8b356a429ae76376a4 Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Wed, 24 Jan 2024 15:20:36 -0800 Subject: [PATCH 18/23] remove old quantization tool file (#19247) ### Description remove old python files ### Motivation and Context We have a new op MatMulNBits and this one is deprecated. --- .../python/tools/quantization/__init__.py | 1 - .../quantization/matmul_weight4_quantizer.py | 260 ------------------ .../python/quantization/test_op_matmulfpq4.py | 153 ----------- 3 files changed, 414 deletions(-) delete mode 100644 onnxruntime/python/tools/quantization/matmul_weight4_quantizer.py delete mode 100644 onnxruntime/test/python/quantization/test_op_matmulfpq4.py diff --git a/onnxruntime/python/tools/quantization/__init__.py b/onnxruntime/python/tools/quantization/__init__.py index 170c0928fee23..9d397499d45a4 100644 --- a/onnxruntime/python/tools/quantization/__init__.py +++ b/onnxruntime/python/tools/quantization/__init__.py @@ -5,7 +5,6 @@ MinMaxCalibrater, create_calibrator, ) -from .matmul_weight4_quantizer import MatMulWeight4Quantizer # noqa: F401 from .qdq_quantizer import QDQQuantizer # noqa: F401 from .quant_utils import QuantFormat, QuantType, write_calibration_table # noqa: F401 from .quantize import DynamicQuantConfig # noqa: F401 diff --git a/onnxruntime/python/tools/quantization/matmul_weight4_quantizer.py b/onnxruntime/python/tools/quantization/matmul_weight4_quantizer.py deleted file mode 100644 index 921e02fb69e9b..0000000000000 --- a/onnxruntime/python/tools/quantization/matmul_weight4_quantizer.py +++ /dev/null @@ -1,260 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import argparse -import struct -from pathlib import Path -from typing import List, Tuple - -import numpy as np -import numpy.typing as npt -import onnx -from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto - -from .onnx_model import ONNXModel -from .quant_utils import attribute_to_kwarg, load_model_with_shape_infer - - -def __q4_block_size(quant_type: int) -> int: - # happens to be 32 for now, but future quantization types - # may have bigger block size - return 32 - - -def __q4_blob_size(quant_type: int) -> int: - if quant_type == MatMulWeight4Quantizer.BlkQ4Sym: - # 4b each value, with one fp32 scale - blob_size = 32 // 2 + 4 - elif quant_type == MatMulWeight4Quantizer.BlkQ4Zp8: - # 4b each value, with one fp32 scale and one uint8 zero point - blob_size = 32 // 2 + 4 + 1 - else: - raise ValueError(f"Unsupported quantization type: {quant_type}") - return blob_size - - -def __q4_buf_size(quant_type: int, rows: int, cols: int) -> int: - block_size = __q4_block_size(quant_type) - blob_size = __q4_blob_size(quant_type) - k_blocks = (rows + block_size - 1) // block_size - return k_blocks * cols * blob_size - - -def int4_block_quant(quant_type: int, fp32weight: npt.ArrayLike) -> np.ndarray: - """4b quantize fp32 weight to a blob""" - - if len(fp32weight.shape) != 2: - raise ValueError("Current int4 block quantization only supports 2D tensors!") - rows, cols = fp32weight.shape - - block_size = __q4_block_size(quant_type) - blob_size = __q4_blob_size(quant_type) - k_blocks = (rows + block_size - 1) // block_size - padded_rows = k_blocks * block_size - pad_len = padded_rows - rows - if pad_len > 0: - fp32weight = np.pad(fp32weight, ((0, pad_len), (0, 0)), "constant") - - # block wise quantization, each block comes from a single column - blob_idx = 0 - packed = np.zeros((cols * k_blocks, blob_size), dtype="uint8") - for n in range(cols): - ncol = fp32weight[:, n] - blks = np.split(ncol, k_blocks) - for blk in blks: - packed_blob = packed[blob_idx] - blob_idx += 1 - - if quant_type == MatMulWeight4Quantizer.BlkQ4Sym: - amax_idx = np.argmax(np.abs(blk)) - bmax = blk[amax_idx] - scale = bmax / (-8) - zp = 8 - else: - vmin = np.min(blk) - vmax = np.max(blk) - vmin = min(vmin, 0.0) - vmax = max(vmax, 0.0) - scale = (vmax - vmin) / ((1 << 4) - 1) - zero_point_fp = vmin - if scale != 0.0: - zero_point_fp = 0.0 - vmin / scale - zp = min(15, max(0, round(zero_point_fp))) - - reciprocal_scale = 1.0 / scale if scale != 0 else 0.0 - bf = struct.pack("f", scale) - packed_blob[0] = bf[0] - packed_blob[1] = bf[1] - packed_blob[2] = bf[2] - packed_blob[3] = bf[3] - blob_offset = 4 - if quant_type == MatMulWeight4Quantizer.BlkQ4Zp8: - packed_blob[4] = zp - blob_offset = 5 - - num_segs = block_size // 32 - blk_int = np.clip(np.rint(blk * reciprocal_scale + zp), 0, 15).astype("uint8") - segs = np.split(blk_int, num_segs) - for seg in segs: - packed_blob[blob_offset : (blob_offset + 16)] = np.bitwise_or(seg[0:16], np.left_shift(seg[16:32], 4)) - blob_offset += 16 - return packed.reshape(-1) - - -class MatMulWeight4Quantizer: - """Perform 4b quantization of constant MatMul weights""" - - ################## - # quantization types, must be consistent with native code type - # MLAS_BLK_QUANT_TYPE defined in mlas_q4.h - - # 32 number block, symmetric quantization, with one fp32 as scale, zero point is always 0 - BlkQ4Sym = 0 - - # 32 number block, quantization, with one fp32 as scale, one uint8 zero point - BlkQ4Zp8 = 1 - - 
def __init__(self, model: ModelProto, quant_type: int): - self.model = ONNXModel(model) - self.quant_type = quant_type - - @staticmethod - def __get_initializer(name, graph_path: List[GraphProto]) -> Tuple[TensorProto, GraphProto]: - for gid in range(len(graph_path) - 1, -1, -1): - graph = graph_path[gid] - for tensor in graph.initializer: - if tensor.name == name: - return tensor, graph - return None, None - - def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto]) -> NodeProto: - """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" - - if node.op_type != "MatMul": - return node # only care about MatMul for now - - inputB = node.input[1] # noqa: N806 - B, Bs_graph = MatMulWeight4Quantizer.__get_initializer(inputB, graph_stack) # noqa: N806 - if B is None: - return node # only care about constant weight - - # TODO!! assume B is not used by any other node - B_array = onnx.numpy_helper.to_array(B) # noqa: N806 - if len(B_array.shape) != 2: - return node # can only process 2-D matrix - - rows, cols = B_array.shape - packed = int4_block_quant(self.quant_type, B_array) - B_quant = onnx.numpy_helper.from_array(packed) # noqa: N806 - B_quant.name = B.name + "_Q4" - Bs_graph.initializer.remove(B) - for input in Bs_graph.input: - if input.name == inputB: - Bs_graph.input.remove(input) - break - - B_shape = onnx.numpy_helper.from_array(np.array([rows, cols]).astype(np.int64)) # noqa: N806 - B_shape.name = B.name + "_shape" - Bs_graph.initializer.extend([B_quant, B_shape]) - - kwargs = {} - kwargs["blk_quant_type"] = self.quant_type - matmul_q4_node = onnx.helper.make_node( - "MatMulFpQ4", - inputs=[node.input[0], B_quant.name, B_shape.name], - outputs=[node.output[0]], - name=node.name + "_Q4" if node.name else "", - domain="com.microsoft", - **kwargs, - ) - return matmul_q4_node - - def _process_subgraph(self, graph_stack: List[GraphProto]): - new_nodes = [] - graph = graph_stack[-1] - - for node in graph.node: - graph_attrs = [ - attr - for attr in node.attribute - if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS - ] - if len(graph_attrs): - kwargs = {} - for attr in node.attribute: - if attr.type == onnx.AttributeProto.GRAPH: - # recursive call to take care of sub-graph - graph_stack.append(attr.g) - kv = {attr.name: self._process_subgraph(graph_stack)} - elif attr.type == onnx.AttributeProto.GRAPHS: - value = [] - for subgraph in attr.graphs: - # recursive call to take care of sub-graph - graph_stack.append(subgraph) - value.extend([self._process_subgraph(graph_stack)]) - kv = {attr.name: value} - else: - kv = attribute_to_kwarg(attr) - kwargs.update(kv) - node = onnx.helper.make_node( # noqa: PLW2901 - node.op_type, node.input, node.output, name=node.name, **kwargs - ) - - new_nodes.append(self._q4_matmul_node_weight(node, graph_stack)) - - graph.ClearField("node") - graph.node.extend(new_nodes) - graph_stack.pop() - return graph - - def process(self): - # use a stack to keep track of sub-graphs - graph_stack = [self.model.graph()] - opset_import = self.model.opset_import() - - has_ms_domain = False - for opset in opset_import: - if opset.domain == "com.microsoft": - has_ms_domain = True - if not has_ms_domain: - opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)]) - - self._process_subgraph(graph_stack) - - -def parse_args(): - parser = argparse.ArgumentParser( - description="""Blockwise int4 quantization for MatMul 2D weight matrices. 
- -A weight matrix is partitioned into into blocks, where each block is a -continguous subset inside each column. Each block is quantized into a -set of 4b integers with a scaling factor and an optional offset. -""" - ) - - parser.add_argument("--input_model", required=True, help="Path to the input model file") - parser.add_argument("--output_model", required=True, help="Path to the output model file") - parser.add_argument( - "--quant_bin_path", - required=True, - help="""Currently quantization code is implemented in a separate binary -(onnxruntime_mlas_q4dq) that is compiled with Onnxruntime native code. -Path to this binary needs to be provided here.""", - ) - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_args() - - input_model_path = args.input_model - output_model_path = args.output_model - q4dq_bin_path = args.quant_bin_path - - model = load_model_with_shape_infer(Path(input_model_path)) - quant = MatMulWeight4Quantizer(model, 0) - quant.process() - quant.model.save_model_to_file(output_model_path, False) diff --git a/onnxruntime/test/python/quantization/test_op_matmulfpq4.py b/onnxruntime/test/python/quantization/test_op_matmulfpq4.py deleted file mode 100644 index 170bb09a0fdeb..0000000000000 --- a/onnxruntime/test/python/quantization/test_op_matmulfpq4.py +++ /dev/null @@ -1,153 +0,0 @@ -#!/usr/bin/env python -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import tempfile -import unittest -from pathlib import Path -from typing import Dict, Tuple, Union - -import numpy as np -import onnx -from onnx import TensorProto, helper -from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count - -from onnxruntime.quantization import MatMulWeight4Quantizer, quant_utils - - -class TestOpMatMulFpQ4(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="test_matmulfpq4.") - - @classmethod - def tearDownClass(cls): - cls._tmp_model_dir.cleanup() - - def fill_int4_data(self, shape: Union[int, Tuple[int, ...]], symmetric: bool) -> np.ndarray: - line = np.zeros(shape) - line = line.reshape(-1) - - if symmetric: - v = -2.0 - for i in range(line.shape[0]): - if v == 0 or v == -3 or v == 3: - v += 1 - line[i] = v - v += 1 - if v >= 8: - v = -8 - else: - v = 0.0 - for i in range(line.shape[0]): - line[i] = v - v += 1 - if v >= 16: - v = 0 - - return line.reshape(shape) - - def input_feeds(self, n: int, name2shape: Dict[str, Union[int, Tuple[int, ...]]]) -> TestDataFeeds: - input_data_list = [] - for _i in range(n): - inputs = {} - for name, shape in name2shape.items(): - inputs.update({name: np.random.randint(-1, 2, shape).astype(np.float32)}) - input_data_list.extend([inputs]) - dr = TestDataFeeds(input_data_list) - return dr - - def construct_model_matmul(self, output_model_path: str, symmetric: bool) -> None: - # (input) - # | - # MatMul - # | - # (output) - input_name = "input" - output_name = "output" - initializers = [] - - def make_gemm(input_name, weight_shape: Union[int, Tuple[int, ...]], weight_name: str, output_name: str): - weight_data = self.fill_int4_data(weight_shape, symmetric).astype(np.float32) - initializers.append(onnx.numpy_helper.from_array(weight_data, name=weight_name)) - return 
onnx.helper.make_node( - "MatMul", - [input_name, weight_name], - [output_name], - ) - - in_features = 52 - out_features = 288 - # make MatMulFpQ4 node - matmul_node = make_gemm( - input_name, - [in_features, out_features], - "linear1.weight", - output_name, - ) - - # make graph - input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, [-1, in_features]) - output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, [-1, out_features]) - graph_name = "matmul_test" - graph = helper.make_graph( - [matmul_node], - graph_name, - [input_tensor], - [output_tensor], - initializer=initializers, - ) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - model.ir_version = 7 # use stable onnx ir version - - onnx.save(model, output_model_path) - - def quant_test( - self, - model_fp32_path: str, - data_reader: TestDataFeeds, - quantization_type: int, # 0: BlkQ4Sym, 1: BlkQ4Zp8 - ): - qtype_str = "BlkQ4Sym" if (quantization_type == 0) else "BlkQ4Zp8" - model_int4_path = str(Path(self._tmp_model_dir.name).joinpath(f"matmulfpq4_{qtype_str}.onnx").absolute()) - - # Quantize fp32 model to int4 model - model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) - quant = MatMulWeight4Quantizer(model, quantization_type) - quant.process() - quant.model.save_model_to_file(model_int4_path, False) - - quant_nodes = {"MatMulFpQ4": 1} - check_op_type_count(self, model_int4_path, **quant_nodes) - - data_reader.rewind() - - try: - check_model_correctness(self, model_fp32_path, model_int4_path, data_reader.get_next()) - except Exception as exception: - if "4b quantization not yet supported on this hardware platform!" in exception.args[0]: - # Currently we don't have int4 quantization support on all platforms, has to tolerate this exception - pass - else: - raise exception - - def test_quantize_matmul_int4_symmetric(self): - np.random.seed(13) - - model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_symmetric.onnx").absolute()) - self.construct_model_matmul(model_fp32_path, symmetric=True) - data_reader = self.input_feeds(1, {"input": [100, 52]}) - self.quant_test(model_fp32_path, data_reader, quantization_type=MatMulWeight4Quantizer.BlkQ4Sym) - - def test_quantize_matmul_int4_offsets(self): - model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) - self.construct_model_matmul(model_fp32_path, symmetric=False) - data_reader = self.input_feeds(1, {"input": [100, 52]}) - self.quant_test(model_fp32_path, data_reader, quantization_type=MatMulWeight4Quantizer.BlkQ4Zp8) - - -if __name__ == "__main__": - unittest.main() From 7252c6e747de83b65285601281a9d07aea801fba Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 25 Jan 2024 07:37:35 +0800 Subject: [PATCH 19/23] [WebNN EP] Support WebNN async API with Asyncify (#19145) --- js/web/lib/build-def.d.ts | 4 --- js/web/lib/index.ts | 4 +-- js/web/lib/wasm/binding/ort-wasm.d.ts | 2 +- js/web/lib/wasm/wasm-core-impl.ts | 4 +-- js/web/script/build.ts | 7 +--- js/web/script/test-runner-cli-args.ts | 4 --- .../core/providers/webnn/builders/model.cc | 35 ++++++++----------- .../providers/webnn/builders/model_builder.cc | 12 +++---- .../webnn/webnn_execution_provider.cc | 3 +- onnxruntime/wasm/js_internal_api.js | 4 +++ 10 files changed, 30 insertions(+), 49 deletions(-) diff --git a/js/web/lib/build-def.d.ts b/js/web/lib/build-def.d.ts index b3868871a4753..2c9cd88a375bd 100644 --- a/js/web/lib/build-def.d.ts +++ b/js/web/lib/build-def.d.ts 
@@ -21,10 +21,6 @@ interface BuildDefinitions { /** * defines whether to disable the whole WebNN backend in the build. */ - readonly DISABLE_WEBNN: boolean; - /** - * defines whether to disable the whole WebAssembly backend in the build. - */ readonly DISABLE_WASM: boolean; /** * defines whether to disable proxy feature in WebAssembly backend in the build. diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index baf45e74addea..b212c0f49df3b 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -23,12 +23,10 @@ if (!BUILD_DEFS.DISABLE_WASM) { require('./backend-wasm-training').wasmBackend; if (!BUILD_DEFS.DISABLE_WEBGPU) { registerBackend('webgpu', wasmBackend, 5); + registerBackend('webnn', wasmBackend, 5); } registerBackend('cpu', wasmBackend, 10); registerBackend('wasm', wasmBackend, 10); - if (!BUILD_DEFS.DISABLE_WEBNN) { - registerBackend('webnn', wasmBackend, 9); - } } Object.defineProperty(env.versions, 'web', {value: version, enumerable: true}); diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 68054210e79a7..24d7062c85fcb 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -31,7 +31,7 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void; - _OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): number; + _OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): Promise; _OrtReleaseSession(sessionHandle: number): void; _OrtGetInputOutputCount(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number; _OrtGetInputName(sessionHandle: number, index: number): number; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 8768643fa7257..046336dc9cac0 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -84,7 +84,7 @@ export const initRuntime = async(env: Env): Promise => { * @param epName */ export const initEp = async(env: Env, epName: string): Promise => { - if (!BUILD_DEFS.DISABLE_WEBGPU && epName === 'webgpu') { + if (!BUILD_DEFS.DISABLE_WEBGPU && (epName === 'webgpu' || epName === 'webnn')) { // perform WebGPU availability check if (typeof navigator === 'undefined' || !navigator.gpu) { throw new Error('WebGPU is not supported in current environment'); @@ -228,7 +228,7 @@ export const createSession = async( await Promise.all(loadingPromises); } - sessionHandle = wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle); + sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle); if (sessionHandle === 0) { checkLastError('Can\'t create a session.'); } diff --git a/js/web/script/build.ts b/js/web/script/build.ts index ea0c122cb51de..d3652f3820357 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -44,7 +44,6 @@ const SOURCE_ROOT_FOLDER = path.join(__dirname, '../..'); // /js/ const DEFAULT_DEFINE = { 'BUILD_DEFS.DISABLE_WEBGL': 'false', 'BUILD_DEFS.DISABLE_WEBGPU': 'false', - 'BUILD_DEFS.DISABLE_WEBNN': 'false', 'BUILD_DEFS.DISABLE_WASM': 'false', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'false', 'BUILD_DEFS.DISABLE_WASM_THREAD': 'false', @@ -364,7 +363,6 @@ async function main() { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGPU': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true', - 'BUILD_DEFS.DISABLE_WEBNN': 'true', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'true', 
'BUILD_DEFS.DISABLE_WASM_THREAD': 'true', }, @@ -397,7 +395,7 @@ async function main() { // ort.webgpu[.min].js await addAllWebBuildTasks({ outputBundleName: 'ort.webgpu', - define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true', 'BUILD_DEFS.DISABLE_WEBNN': 'true'}, + define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true'}, }); // ort.wasm[.min].js await addAllWebBuildTasks({ @@ -411,7 +409,6 @@ async function main() { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGPU': 'true', 'BUILD_DEFS.DISABLE_WASM': 'true', - 'BUILD_DEFS.DISABLE_WEBNN': 'true', }, }); // ort.wasm-core[.min].js @@ -421,7 +418,6 @@ async function main() { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGPU': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true', - 'BUILD_DEFS.DISABLE_WEBNN': 'true', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'true', 'BUILD_DEFS.DISABLE_WASM_THREAD': 'true', }, @@ -434,7 +430,6 @@ async function main() { 'BUILD_DEFS.DISABLE_TRAINING': 'false', 'BUILD_DEFS.DISABLE_WEBGPU': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true', - 'BUILD_DEFS.DISABLE_WEBNN': 'true', }, }); } diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index 8f6c5f6f04122..ed4dd76a6e315 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -396,10 +396,6 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs const globalEnvFlags = parseGlobalEnvFlags(args); - if (backend.includes('webnn') && !globalEnvFlags.wasm!.proxy) { - throw new Error('Backend webnn requires flag "wasm-enable-proxy" to be set to true.'); - } - // Options: // --log-verbose=<...> // --log-info=<...> diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc index eaf549ef4e072..ef807a8c4fa26 100644 --- a/onnxruntime/core/providers/webnn/builders/model.cc +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -70,22 +70,13 @@ Status Model::Predict(const InlinedHashMap& inputs, "The input of graph has unsupported type, name: ", name, " type: ", tensor.tensor_info.data_type); } -#ifdef ENABLE_WEBASSEMBLY_THREADS - // Copy the inputs from Wasm SharedArrayBuffer to the pre-allocated ArrayBuffers. + // Copy the inputs from Wasm ArrayBuffer to the WebNN inputs ArrayBuffer. + // As Wasm ArrayBuffer is not detachable. wnn_inputs_[name].call("set", view); -#else - wnn_inputs_.set(name, view); -#endif } -#ifdef ENABLE_WEBASSEMBLY_THREADS - // This vector uses for recording output buffers from WebNN graph compution when WebAssembly - // multi-threads is enabled, since WebNN API only accepts non-shared ArrayBufferView, - // https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews - // and at this time the 'view' defined by Emscripten is shared ArrayBufferView, the memory - // address is different from the non-shared one, additional memory copy is required here. InlinedHashMap output_views; -#endif + for (const auto& output : outputs) { const std::string& name = output.first; const struct OnnxTensorData tensor = output.second; @@ -131,21 +122,23 @@ Status Model::Predict(const InlinedHashMap& inputs, name, " type: ", tensor.tensor_info.data_type); } -#ifdef ENABLE_WEBASSEMBLY_THREADS output_views.insert({name, view}); -#else - wnn_outputs_.set(name, view); -#endif } - wnn_context_.call("computeSync", wnn_graph_, wnn_inputs_, wnn_outputs_); -#ifdef ENABLE_WEBASSEMBLY_THREADS - // Copy the outputs from pre-allocated ArrayBuffers back to the Wasm SharedArrayBuffer. 
+ emscripten::val results = wnn_context_.call( + "compute", wnn_graph_, wnn_inputs_, wnn_outputs_) + .await(); + + // Copy the outputs from pre-allocated ArrayBuffers back to the Wasm ArrayBuffer. for (const auto& output : outputs) { const std::string& name = output.first; emscripten::val view = output_views.at(name); - view.call("set", wnn_outputs_[name]); + view.call("set", results["outputs"][name]); } -#endif + // The WebNN compute() method returns the input and output buffers via the promise + // resolution. Reuse the buffers to avoid additional allocation. + wnn_inputs_ = results["inputs"]; + wnn_outputs_ = results["outputs"]; + return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index cf8a0e23db43b..56f7ead8ccf5d 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -386,7 +386,8 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { for (auto& name : output_names_) { named_operands.set(name, wnn_operands_.at(name)); } - emscripten::val wnn_graph = wnn_builder_.call("buildSync", named_operands); + + emscripten::val wnn_graph = wnn_builder_.call("build", named_operands).await(); if (!wnn_graph.as()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to build WebNN graph."); } @@ -395,13 +396,10 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { model->SetOutputs(std::move(output_names_)); model->SetScalarOutputs(std::move(scalar_outputs_)); model->SetInputOutputInfo(std::move(input_output_info_)); -#ifdef ENABLE_WEBASSEMBLY_THREADS - // Pre-allocate the input and output tensors for the WebNN graph - // when WebAssembly multi-threads is enabled since WebNN API only - // accepts non-shared ArrayBufferView. - // https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews + // The Wasm heap is not transferable, so we have to pre-allocate the MLNamedArrayBufferViews + // for inputs and outputs because they will be transferred after compute() is done. 
+ // https://webmachinelearning.github.io/webnn/#api-mlcontext-async-execution model->AllocateInputOutputBuffers(); -#endif return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 2922cf9540a8e..df7871614b267 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -42,7 +42,8 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f if (webnn_power_flags.compare("default") != 0) { context_options.set("powerPreference", emscripten::val(webnn_power_flags)); } - wnn_context_ = ml.call("createContextSync", context_options); + + wnn_context_ = ml.call("createContext", context_options).await(); if (!wnn_context_.as()) { ORT_THROW("Failed to create WebNN context."); } diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 7c70515e73eab..7e9c0a6f99c32 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -160,6 +160,10 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea }; // replace the original functions with asyncified versions + Module['_OrtCreateSession'] = jsepWrapAsync( + Module['_OrtCreateSession'], + () => Module['_OrtCreateSession'], + v => Module['_OrtCreateSession'] = v); Module['_OrtRun'] = runAsync(jsepWrapAsync( Module['_OrtRun'], () => Module['_OrtRun'], From 0c2f0ba90da11ad53c63810e5f3e6fda4e295899 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 25 Jan 2024 07:53:10 +0800 Subject: [PATCH 20/23] [WebNN EP] Support conv1d by reshaping with prepended 1's (#18857) WebNN only supports 4-D inputs for conv2d and convTranspose2d, so this PR adds support for 3-D inputs (i.e. conv1d) by prepending a dimension of size 1 and inserting several reshape operations, as sketched below. 
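For readers unfamiliar with the trick, here is a minimal numpy sketch of the idea (illustrative only: the function name, the NCL/OCK layout, and the naive sliding-window loop are assumptions for this note, not ORT or WebNN code). Appending a spatial dimension of size 1 turns a 1-D convolution into an equivalent 2-D one:

```python
# Hypothetical sketch of the conv1d-as-conv2d reshape trick (NCHW-style layout).
# Not ORT/WebNN API code; shapes and names are illustrative assumptions.
import numpy as np

def conv1d_via_conv2d(x, w):
    # (N, C, L) -> (N, C, L, 1) and (O, C, K) -> (O, C, K, 1): the appended
    # width-1 axis makes both tensors valid 4-D conv2d operands.
    x4 = x[..., np.newaxis]
    w4 = w[..., np.newaxis]
    n, _, length, _ = x4.shape
    o, _, k, _ = w4.shape
    # A conv2d whose second spatial axis has extent 1 degenerates to conv1d.
    out = np.zeros((n, o, length - k + 1, 1), dtype=x.dtype)
    for i in range(length - k + 1):
        out[:, :, i, 0] = np.tensordot(x4[:, :, i:i + k, 0], w4[..., 0],
                                       axes=([1, 2], [1, 2]))
    # Reshape the result back to 3-D, mirroring the reshape this PR inserts
    # after the WebNN conv2d call.
    return out[..., 0]
```

The actual change performs the same dance with WebNN reshape operations around conv2d/convTranspose2d, choosing the index at which the 1-sized dimension is inserted based on whether the preferred layout is NCHW or NHWC.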
--- .../core/providers/webnn/builders/helper.h | 9 + .../webnn/builders/impl/conv_op_builder.cc | 221 +++++++++++------- 2 files changed, 141 insertions(+), 89 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 85dafcaf66575..92aa9abc9fdf7 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -54,6 +54,15 @@ std::string GetShapeString(std::vector& shape) { return shape_info.str(); } +inline std::vector GetVecUint32FromVecInt64(const std::vector& int64_vec) { + std::vector uint32_vec; + uint32_vec.reserve(int64_vec.size()); + std::transform(int64_vec.begin(), int64_vec.end(), + std::back_inserter(uint32_vec), + [](int64_t val) -> uint32_t { return SafeInt(val); }); + return uint32_vec; +} + template bool ReadIntArrayFrom1DTensor(const onnx::TensorProto& tensor, std::vector& array, const logging::Logger& logger) { std::vector unpacked_tensor; diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index ceacb7c2b38a3..c74545479e466 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -42,72 +42,61 @@ void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Nod // Helper functions common::Status SetConvBaseOptions(ModelBuilder& model_builder, const Node& node, emscripten::val& options, - const std::vector& strides, - const std::vector& dilations, - std::vector& pads, + const std::vector input_shape, + const std::vector weight_shape, + const std::vector& strides, + const std::vector& dilations, + std::vector& pads, + const bool is_nhwc, + const bool is_conv1d, const logging::Logger& logger) { NodeAttrHelper helper(node); - const auto group = helper.Get("group", static_cast(1)); const auto& input_defs = node.InputDefs(); - std::vector weight_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[1], weight_shape, logger), "Cannot get weight shape"); - options.set("strides", emscripten::val::array(strides)); - options.set("dilations", emscripten::val::array(dilations)); - options.set("groups", group); + // Add Padding. - std::vector input_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); if (node.OpType() == "Conv") { // Calculate explicit padding for autoPad. if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { std::vector pads_out; ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3], - helper.Get("pads", std::vector{0, 0, 0, 0}), - helper.Get("strides", std::vector{1, 1}), - helper.Get("dilations", std::vector{1, 1}), - auto_pad_type, - pads_out, - model_builder.GetPreferredLayout() == DataLayout::NCHW)); - std::transform(pads_out.begin(), pads_out.end(), pads.begin(), - [](int64_t pad) -> int32_t { return static_cast(pad); }); + pads, strides, dilations, auto_pad_type, pads_out, !is_nhwc)); + pads = pads_out; } } else if (node.OpType() == "ConvTranspose") { // When the 'output_shape' is specificed, the 'output_padding' values // in options.outputPadding are ignored. 
- std::vector dim; - std::vector output_padding{0, 0}; + std::vector dims; + std::vector output_padding{0, 0}; if (helper.HasAttr("output_shape")) { - // Default value of 'output_shape' will be ignore as we already check if - // it's existed. - dim = helper.Get("output_shape", std::vector{-1, -1}); + // Default value of 'output_shape' will be ignored as we already check if it existed. + dims = helper.Get("output_shape", std::vector{-1, -1}); // Extract the height and width. - std::vector output_shape; - if (dim.size() == 2) { - output_shape = dim; - } else if (dim.size() == 4) { - output_shape = {dim[2], dim[3]}; + std::vector output_shape; + if (dims.size() == 1 && is_conv1d) { // ConvTranspose 1d + output_shape = {dims[0], 1}; + } else if (dims.size() == 2 && !is_conv1d) { + output_shape = dims; } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid output shape"); } // Padding values are auto generated. if (helper.HasAttr("kernel_shape")) { - std::vector kernel_shape = helper.Get("kernel_shape", std::vector{-1, -1}); - std::vector total_padding(2); - std::vector input_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + std::vector kernel_shape = helper.Get("kernel_shape", std::vector{-1, -1}); + if (is_conv1d) { // ConvTranspose 1d + kernel_shape.push_back(1); + } + std::vector total_padding(2); for (size_t i = 0; i < 2; i++) { // Get the dimensions of H and W. // For NHWC layout, the dimensions of H and W correspond to index 1 and 2. // For NCHW layout, the dimensions of H and W correspond to index 2 and 3. - if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { - total_padding[i] = strides[i] * (narrow(input_shape[i + 1]) - 1) + - output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]; + if (is_nhwc) { + total_padding[i] = strides[i] * (input_shape[i + 1] - 1) + output_padding[i] + + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]; } else { - ORT_RETURN_IF_NOT(model_builder.GetPreferredLayout() == DataLayout::NCHW, - "WebNN GPU backend preferred layout should be NCHW."); - total_padding[i] = strides[i] * (narrow(input_shape[i + 2]) - 1) + - output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]; + total_padding[i] = strides[i] * (input_shape[i + 2] - 1) + output_padding[i] + + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]; } } AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); @@ -122,18 +111,27 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, } } } - options.set("outputSizes", emscripten::val::array(output_shape)); + options.set("outputSizes", emscripten::val::array(GetVecUint32FromVecInt64(output_shape))); } else { - output_padding = helper.Get("output_padding", std::vector{0, 0}); - options.set("outputPadding", emscripten::val::array(output_padding)); + output_padding = helper.Get("output_padding", std::vector{0, 0}); + if (output_padding.size() == 1 && is_conv1d) { // ConvTranspose 1d + output_padding.push_back(0); + } + options.set("outputPadding", emscripten::val::array(GetVecUint32FromVecInt64(output_padding))); } } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "conv_op_builder only supports Op Conv and ConvTranspose."); } + + const auto group = helper.Get("group", static_cast(1)); + options.set("groups", group); + options.set("strides", emscripten::val::array(GetVecUint32FromVecInt64(strides))); + options.set("dilations", 
emscripten::val::array(GetVecUint32FromVecInt64(dilations))); + // Permute the ONNX's pads, which is [beginning_height, beginning_width, ending_height, ending_width], // while WebNN's padding is [beginning_height, ending_height, beginning_width, ending_width]. - const std::vector padding{pads[0], pads[2], pads[1], pads[3]}; - options.set("padding", emscripten::val::array(padding)); + const std::vector padding{pads[0], pads[2], pads[1], pads[3]}; + options.set("padding", emscripten::val::array(GetVecUint32FromVecInt64(padding))); // Add bias if present. if (input_defs.size() > 2) { @@ -151,7 +149,8 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, // Both depthwise Conv and ConvTranspose share the same logic to add the layout. Status AddInitializerInNewLayout(ModelBuilder& model_builder, const std::string& name, - bool is_conv) { + bool is_conv, + bool is_conv1d) { const auto& tensor = *model_builder.GetInitializerTensors().at(name); auto data_type = tensor.data_type(); if (!IsSupportedDataType(data_type, model_builder.GetWebnnDeviceType())) { @@ -161,13 +160,13 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder, } const auto& shape = tensor.dims(); - std::vector dims; - std::transform(shape.cbegin(), shape.cend(), - std::back_inserter(dims), - [](int64_t dim) -> int32_t { return SafeInt(dim); }); + std::vector dims = GetVecUint32FromVecInt64(std::vector(std::begin(shape), std::end(shape))); + + if (is_conv1d) { + // Support conv1d by prepending a 1 size dimension. + dims.push_back(1); + } - ORT_RETURN_IF_NOT(dims.size() == 4, - "The initializer is not 4D: ", name, " actual dim ", dims.size()); const uint8_t* src = nullptr; Initializer unpacked_tensor(tensor, model_builder.GetGraphViewer().ModelPath()); src = unpacked_tensor.DataAsByteSpan().data(); @@ -257,57 +256,101 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); emscripten::val output = emscripten::val::object(); - NodeAttrHelper helper(node); - const auto strides = helper.Get("strides", std::vector{1, 1}); - const auto dilations = helper.Get("dilations", std::vector{1, 1}); - auto pads = helper.Get("pads", std::vector{0, 0, 0, 0}); + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); + std::vector weight_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[1], weight_shape, logger), "Cannot get weight shape"); const auto& weight_name = input_defs[1]->Name(); + + NodeAttrHelper helper(node); + auto strides = helper.Get("strides", std::vector{1, 1}); + auto dilations = helper.Get("dilations", std::vector{1, 1}); + auto pads = helper.Get("pads", std::vector{0, 0, 0, 0}); + + const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC; + const bool is_conv1d = input_shape.size() == 3 && weight_shape.size() == 3; + // Support conv1d by prepending a 1 or 2 size dimensions. + if (is_conv1d) { + // Reshape input. + if (is_nhwc) { + // For NHWC preferred layout, the input has been transposed. + // For conv1d it is NCD1 -> ND1C, so we need to prepend 1 to the index 2. + input_shape.insert(input_shape.begin() + 2, 1); + } else { + input_shape.push_back(1); + } + std::vector new_shape = GetVecUint32FromVecInt64(input_shape); + input = model_builder.GetBuilder().call("reshape", input, emscripten::val::array(new_shape)); + + weight_shape.resize(4, 1); // Ensure 4D by appending 1's if needed. 
+ strides.resize(2, 1); // Ensure 2D by appending 1's if needed. + dilations.resize(2, 1); // Ensure 2D by appending 1's if needed. + if (pads.size() == 2) { + pads.insert(pads.begin() + 1, 0); + pads.push_back(0); + } + } + emscripten::val options = emscripten::val::object(); - ORT_RETURN_IF_ERROR(SetConvBaseOptions(model_builder, node, options, strides, dilations, pads, logger)); + ORT_RETURN_IF_ERROR(SetConvBaseOptions( + model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_nhwc, is_conv1d, logger)); if (op_type == "Conv" || op_type == "ConvInteger") { int groups = options["groups"].as(); - std::vector input_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); - if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { + if (is_nhwc) { bool depthwise = (groups == input_shape[3] && groups != 1); options.set("inputLayout", emscripten::val("nhwc")); - ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, !depthwise)); + ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, !depthwise, is_conv1d)); if (!depthwise) { options.set("filterLayout", emscripten::val("ohwi")); } else { options.set("filterLayout", emscripten::val("ihwo")); } } - emscripten::val filter = model_builder.GetOperand(weight_name); - if (op_type == "Conv") { - output = model_builder.GetBuilder().call("conv2d", input, filter, options); - } else { - emscripten::val x_zero_point = emscripten::val::null(); - emscripten::val w_zero_point = emscripten::val::null(); - if (input_defs.size() >= 3) { - x_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); - } else { - x_zero_point = model_builder.GetZeroConstant("uint8"); - } - if (input_defs.size() >= 4) { - w_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name()); - } else { - w_zero_point = model_builder.GetZeroConstant("uint8"); - } - output = model_builder.GetBuilder().call("conv2dInteger", - input, x_zero_point, filter, w_zero_point, options); - } - - } else { - if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { + } else { // ConvTranspose + if (is_nhwc) { options.set("inputLayout", emscripten::val("nhwc")); options.set("filterLayout", emscripten::val("ohwi")); - ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, false)); + ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, false, is_conv1d)); } - emscripten::val filter = model_builder.GetOperand(input_defs[1]->Name()); + } + + emscripten::val filter = model_builder.GetOperand(weight_name); + if (!is_nhwc && is_conv1d) { + // Reshape weight to 4D for conv1d with NCHW preferred layout. 
+ std::vector new_shape = GetVecUint32FromVecInt64(weight_shape); + filter = model_builder.GetBuilder().call("reshape", filter, emscripten::val::array(new_shape)); + } + + if (op_type == "Conv") { + output = model_builder.GetBuilder().call("conv2d", input, filter, options); + } else if (op_type == "ConvInteger") { + emscripten::val x_zero_point = emscripten::val::null(); + emscripten::val w_zero_point = emscripten::val::null(); + if (input_defs.size() >= 3) { + x_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); + } else { + x_zero_point = model_builder.GetZeroConstant("uint8"); + } + if (input_defs.size() >= 4) { + w_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name()); + } else { + w_zero_point = model_builder.GetZeroConstant("uint8"); + } + output = model_builder.GetBuilder().call("conv2dInteger", + input, x_zero_point, filter, w_zero_point, options); + } else { output = model_builder.GetBuilder().call("convTranspose2d", input, filter, options); } + // If it's a conv1d, reshape it back. + if (is_conv1d) { + const auto& output_defs = node.OutputDefs(); + std::vector output_shape; + ORT_RETURN_IF_NOT(GetShape(*output_defs[0], output_shape, logger), "Cannot get output shape"); + std::vector new_shape = GetVecUint32FromVecInt64(output_shape); + output = model_builder.GetBuilder().call("reshape", output, emscripten::val::array(new_shape)); + } + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); } @@ -329,9 +372,9 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, } const auto input_size = input_shape.size(); - if (input_size != 4) { + if (input_size != 4 && input_size != 3) { LOGS(logger, VERBOSE) << op_type << " [" << name << "]'s input dimension: " << input_size - << ". Only conv 2d is supported."; + << ". Only conv 1d / 2d is supported."; return false; } @@ -342,9 +385,9 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, } const auto weight_size = weight_shape.size(); - if (weight_size != 4) { + if (weight_size != 4 && weight_size != 3) { LOGS(logger, VERBOSE) << op_type << " [" << name << "]'s weight dimension: " << weight_size - << ". Only conv 2d is supported."; + << ". Only conv 1d / 2d is supported."; return false; } From 4477f57ee3151287a9759bd09d269f0e258a9eda Mon Sep 17 00:00:00 2001 From: Phoebe Chen Date: Thu, 25 Jan 2024 08:27:05 +0800 Subject: [PATCH 21/23] Enable RISC-V 64-bit Cross-Compiling Support for ONNX Runtime on Linux (#19238) ### Description This pull request introduces the necessary changes to enable RISC-V 64-bit cross-compiling support for the ONNX Runtime on Linux. The RISC-V architecture has gained popularity as an open standard instruction set architecture, and this contribution aims to extend ONNX Runtime's compatibility to include RISC-V, thereby broadening the reach of ONNX models to a wider range of devices. ### Motivation and Context RISC-V is a free and open-source instruction set architecture (ISA) based on established RISC principles. It is provided under open licenses without fees. Due to its extensibility and freedom in both software and hardware, RISC-V is poised for widespread adoption in the future, especially in applications related to AI, parallel computing, and data centers. 
### Example Build Command ``` ./build.sh --parallel --config Debug --rv64 --riscv_toolchain_root=/path/to/toolchain/root --skip_tests ``` ### Documentation Updates Relevant sections of the documentation will be updated to reflect the newly supported RISC-V 64-bit cross-compilation feature. https://github.com/microsoft/onnxruntime/pull/19239 --------- Signed-off-by: Phoebe Chen --- cmake/external/xnnpack.cmake | 6 +- cmake/onnxruntime_common.cmake | 4 +- cmake/riscv64.toolchain.cmake | 35 +++++++++ tools/ci_build/build.py | 35 ++++++++- tools/scripts/build_riscv64.sh | 129 +++++++++++++++++++++++++++++++++ 5 files changed, 206 insertions(+), 3 deletions(-) create mode 100644 cmake/riscv64.toolchain.cmake create mode 100755 tools/scripts/build_riscv64.sh diff --git a/cmake/external/xnnpack.cmake b/cmake/external/xnnpack.cmake index e661aa51bfc17..41f02ce6f22bc 100644 --- a/cmake/external/xnnpack.cmake +++ b/cmake/external/xnnpack.cmake @@ -6,10 +6,14 @@ set(FP16_BUILD_BENCHMARKS OFF CACHE INTERNAL "") set(PTHREADPOOL_BUILD_TESTS OFF CACHE INTERNAL "") set(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE INTERNAL "") +if(CMAKE_SYSTEM_PROCESSOR MATCHES "^riscv64.*") + set(XNNPACK_USE_SYSTEM_LIBS OFF) +endif() + # BF16 instructions cause ICE in Android NDK compiler if(CMAKE_ANDROID_ARCH_ABI STREQUAL armeabi-v7a) set(XNNPACK_ENABLE_ARM_BF16 OFF) -ENDIF() +endif() # fp16 depends on psimd FetchContent_Declare(psimd URL ${DEP_URL_psimd} URL_HASH SHA1=${DEP_SHA1_psimd}) diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake index 43d5fa9bdee34..6b8c2560b1714 100644 --- a/cmake/onnxruntime_common.cmake +++ b/cmake/onnxruntime_common.cmake @@ -189,6 +189,8 @@ elseif(NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set(ARM TRUE) elseif(dumpmachine_output MATCHES "^aarch64.*") set(ARM64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^riscv64.*") + set(RISCV64 TRUE) elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") set(X86 TRUE) elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") @@ -198,7 +200,7 @@ elseif(NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") endif() -if (ARM64 OR ARM OR X86 OR X64 OR X86_64) +if (RISCV64 OR ARM64 OR ARM OR X86 OR X64 OR X86_64) if((WIN32 AND NOT CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib) OR ((ARM64 OR ARM) AND MSVC)) # msvc compiler report syntax error with cpuinfo arm source files # and cpuinfo does not have code for getting arm uarch info under windows diff --git a/cmake/riscv64.toolchain.cmake b/cmake/riscv64.toolchain.cmake new file mode 100644 index 0000000000000..0fda239f9a628 --- /dev/null +++ b/cmake/riscv64.toolchain.cmake @@ -0,0 +1,35 @@ +# Copyright (c) 2024 SiFive, Inc. All rights reserved. +# Copyright (c) 2024, Phoebe Chen +# Licensed under the MIT License. + +set(CMAKE_SYSTEM_NAME Linux) +set(CMAKE_SYSTEM_PROCESSOR riscv64) + +list(APPEND CMAKE_TRY_COMPILE_PLATFORM_VARIABLES RISCV_TOOLCHAIN_ROOT) + +if(NOT RISCV_TOOLCHAIN_ROOT) + message(FATAL_ERROR "RISCV_TOOLCHAIN_ROOT is not defined. 
Please set the RISCV_TOOLCHAIN_ROOT variable.") +endif() + +set(CMAKE_C_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-gcc") +set(CMAKE_ASM_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-gcc") +set(CMAKE_CXX_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-g++") + +set(CMAKE_FIND_ROOT_PATH ${RISCV_TOOLCHAIN_ROOT}) +set(CMAKE_SYSROOT "${RISCV_TOOLCHAIN_ROOT}/sysroot") +set(CMAKE_INCLUDE_PATH "${RISCV_TOOLCHAIN_ROOT}/sysroot/usr/include/") +set(CMAKE_LIBRARY_PATH "${RISCV_TOOLCHAIN_ROOT}/sysroot/usr/lib/") +set(CMAKE_PROGRAM_PATH "${RISCV_TOOLCHAIN_ROOT}/sysroot/usr/bin/") + +if(RISCV_QEMU_PATH) + message(STATUS "RISCV_QEMU_PATH=${RISCV_QEMU_PATH} is defined during compilation.") + set(CMAKE_CROSSCOMPILING_EMULATOR "${RISCV_QEMU_PATH};-L;${CMAKE_SYSROOT}") +endif() + +set(CMAKE_CROSSCOMPILING TRUE) + +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) + diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 6e5cd7b57e403..186bb699ad209 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -328,6 +328,12 @@ def convert_arg_line_to_args(self, arg_line): help="[cross-compiling] Create Windows x86 makefiles. Requires --update and no existing cache " "CMake setup. Delete CMakeCache.txt if needed", ) + parser.add_argument( + "--rv64", + action="store_true", + help="[cross-compiling] Create riscv64 makefiles. Requires --update and no existing cache " + "CMake setup. Delete CMakeCache.txt if needed", + ) parser.add_argument( "--arm", action="store_true", @@ -351,6 +357,18 @@ def convert_arg_line_to_args(self, arg_line): action="store_true", help="[cross-compiling] Create ARM64X Binary.", ) + parser.add_argument( + "--riscv_toolchain_root", + type=str, + default="", + help="Path to RISC-V toolchain root dir. e.g. --riscv_toolchain_root=$HOME/riscv-tools/", + ) + parser.add_argument( + "--riscv_qemu_path", + type=str, + default="", + help="Path to RISC-V qemu. e.g. --riscv_qemu_path=$HOME/qemu-dir/qemu-riscv64", + ) parser.add_argument("--msvc_toolset", help="MSVC toolset to use. e.g. 14.11") parser.add_argument("--windows_sdk_version", help="Windows SDK version to use. e.g. 10.0.19041.0") parser.add_argument("--android", action="store_true", help="Build for Android") @@ -1077,6 +1095,19 @@ def generate_build_tree( "-Donnxruntime_DISABLE_OPTIONAL_TYPE=" + ("ON" if disable_optional_type else "OFF"), ] + if args.rv64: + add_default_definition(cmake_extra_defines, "onnxruntime_CROSS_COMPILING", "ON") + if not args.riscv_toolchain_root: + raise BuildError("The --riscv_toolchain_root option is required to build for riscv64.") + if not args.skip_tests and not args.riscv_qemu_path: + raise BuildError("The --riscv_qemu_path option is required for testing riscv64.") + + cmake_args += [ + "-DRISCV_TOOLCHAIN_ROOT:PATH=" + args.riscv_toolchain_root, + "-DRISCV_QEMU_PATH:PATH=" + args.riscv_qemu_path, + "-DCMAKE_TOOLCHAIN_FILE=" + os.path.join(source_dir, "cmake", "riscv64.toolchain.cmake"), + ] + # By default on Windows we currently support only cross compiling for ARM/ARM64 # (no native compilation supported through this script). 
if args.arm64 or args.arm64ec or args.arm: @@ -1553,7 +1584,9 @@ def generate_build_tree( ] if is_linux() and platform.machine() == "x86_64": # The following flags needs GCC 8 and newer - cflags += ["-fstack-clash-protection", "-fcf-protection"] + cflags += ["-fstack-clash-protection"] + if not args.rv64: + cflags += ["-fcf-protection"] cxxflags = cflags.copy() if args.use_cuda: cudaflags = cflags.copy() diff --git a/tools/scripts/build_riscv64.sh b/tools/scripts/build_riscv64.sh new file mode 100755 index 0000000000000..65681c0b6307d --- /dev/null +++ b/tools/scripts/build_riscv64.sh @@ -0,0 +1,129 @@ +#!/bin/bash +# Copyright (c) 2024 SiFive, Inc. All rights reserved. +# Copyright (c) 2024, Phoebe Chen +# Licensed under the MIT License. + + +# The script is a sample for RISC-V 64-bit cross compilation in +# GNU/Linux, and you should ensure that your environment meets +# ORT requirements. You may need to make changes before using it. + +set -e +set -o pipefail + +# Get directory this script is in +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +OS=$(uname -s) + +if [ "$OS" == "Linux" ]; then + LINUX_DISTRO=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"') + if [[ "${LINUX_DISTRO}" == "ubuntu" ]] ;then + DIR_OS="Linux" + else + echo "${LINUX_DISTRO} is not supported" + return 1 + fi +else + echo "$OS is not supported" + return 1 +fi + +function cleanup { + if [ -d "$WORK_DIR" ]; then + rm -rf "$WORK_DIR" + fi +} + +# The riscv toolchain, qemu and other platform related settings. +ORT_ROOT_DIR=$DIR/../.. + +PREBUILT_DIR="${ORT_ROOT_DIR}/riscv_tools" + +read -rp "Enter the riscv tools root path(press enter to use default path:${PREBUILT_DIR}): " INPUT_PATH +if [[ "${INPUT_PATH}" ]]; then + PREBUILT_DIR=${INPUT_PATH} +fi +echo "The riscv tool prefix path: ${PREBUILT_DIR}" + +WORK_DIR=$DIR/.prebuilt + +# The prebuit toolchain download from riscv-collab works with Ubuntu. +RISCV_GNU_TOOLCHAIN_URL="https://github.com/riscv-collab/riscv-gnu-toolchain/releases/download" +TOOLCHAIN_VERSION="2023.11.20" +RISCV_TOOLCHAIN_FILE_NAME="riscv64-glibc-ubuntu-22.04-llvm-nightly-2023.11.20-nightly.tar.gz" +RISCV_TOOLCHAIN_FILE_SHA="98d6531b757fac01e065460c19abe8974976c607a8d88631cc5c1529d90ba7ba" + +TOOLCHAIN_PATH_PREFIX=${PREBUILT_DIR} + +execute () { + if ! eval "$1"; then + echo "command:\"$1\" error" + exit 1 + fi +} + +execute "mkdir -p $WORK_DIR" + +# Call the cleanup function when this tool exits. +trap cleanup EXIT + +# Download and install the toolchain from +# https://github.com/riscv-collab/riscv-gnu-toolchain/releases/download +download_file() { + local file_name="$1" + local install_path="$2" + local file_sha="$3" + + echo "Install $1 to $2" + if [[ "$(ls -A "$2")" ]]; then + read -rp "The file already exists. Keep it (y/n)? " replaced + case ${replaced:0:1} in + y|Y ) + echo "Skip download $1." + return + ;; + * ) + rm -rf "$2" + ;; + esac + fi + + echo "Download ${file_name} ..." + mkdir -p "$install_path" + wget --progress=bar:force:noscroll --directory-prefix="${WORK_DIR}" \ + "${RISCV_GNU_TOOLCHAIN_URL}/${TOOLCHAIN_VERSION}/${file_name}" && \ + echo "${file_sha} ${WORK_DIR}/${file_name}" | sha256sum -c - + echo "Extract ${file_name} ..." + tar -C "${install_path}" -xf "${WORK_DIR}/${file_name}" --no-same-owner \ + --strip-components=1 +} + + +read -rp "Install RISCV toolchain(y/n)? 
" answer +case ${answer:0:1} in + y|Y ) + download_file "${RISCV_TOOLCHAIN_FILE_NAME}" \ + "${TOOLCHAIN_PATH_PREFIX}" \ + "${RISCV_TOOLCHAIN_FILE_SHA}" + ;; + * ) + echo "Skip install RISCV toolchain." + ;; +esac +echo "download finished." + + +# RISC-V cross compilation in GNU/Linux +RISCV_TOOLCHAIN_ROOT=${TOOLCHAIN_PATH_PREFIX} +RISCV_QEMU_PATH=${TOOLCHAIN_PATH_PREFIX}/bin/qemu-riscv64 +python3 "${ORT_ROOT_DIR}"/tools/ci_build/build.py \ + --build_dir "${ORT_ROOT_DIR}/build/${DIR_OS}" \ + --rv64 \ + --parallel \ + --skip_tests \ + --config RelWithDebInfo \ + --cmake_generator=Ninja \ + --riscv_qemu_path="${RISCV_QEMU_PATH}" \ + --riscv_toolchain_root="${RISCV_TOOLCHAIN_ROOT}" "$@" + + From 7dd1f4b8e27f38b55f2430f84ddaae1128bef9f4 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Wed, 24 Jan 2024 18:12:04 -0800 Subject: [PATCH 22/23] Pad-18 Cuda implementation (#19211) ### Description Implement Pad-18 for Cuda. ### Motivation and Context Latest models converted by Dynamo fall back on CPU for Pad with performance degradation. This contributes to https://github.com/microsoft/onnx-rewriter/issues/126 --- docs/OperatorKernels.md | 3 +- .../core/providers/cpu/cpu_provider_shared.cc | 8 +- .../core/providers/cpu/cpu_provider_shared.h | 8 +- onnxruntime/core/providers/cpu/tensor/pad.cc | 252 +++++++++--------- .../core/providers/cpu/tensor/padbase.h | 77 +++++- .../providers/cuda/cuda_execution_provider.cc | 38 +-- onnxruntime/core/providers/cuda/tensor/pad.cc | 37 ++- .../providers/rocm/rocm_execution_provider.cc | 26 +- .../provider_bridge_provider.cc | 9 +- 9 files changed, 287 insertions(+), 171 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 31cca232fde34..9d9b266355335 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -682,7 +682,8 @@ Do not modify directly.* |PRelu|*in* X:**T**
*in* slope:**T**<br> *out* Y:**T**|16+|**T** = tensor(double), tensor(float), tensor(float16)|
|||[9, 15]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
-|Pad|*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *in* axes:**Tind**<br> *out* output:**T**<br><br>or<br><br> *in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *out* output:**T**<br><br>or<br><br> *in* data:**T**<br> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
+|Pad|*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *in* axes:**Tind**<br> *out* output:**T**<br><br>or<br><br> *in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *out* output:**T**<br><br>or<br><br> *in* data:**T**<br> *out* output:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
+|||[13, 17]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[2, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
|ParametricSoftplus|*in* X:**T**<br>
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index 9c55d37f550f4..bf73c59fb78ca 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -87,7 +87,13 @@ struct ProviderHostCPUImpl : ProviderHostCPU { const TensorShape& indice_shape, const TensorShape& update_shape) override { return ScatterND::ValidateShapes(input_shape, indice_shape, update_shape); } // From cpu/tensor/padbase.h (direct) - Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) override { return PadBase::HandleDimValueZero(mode, input_shape, output_shape); } + Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) override { return PadBase::HandleDimValueZero(mode, input_shape, output_shape); } + + void PadBase__ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span pads_data, + PadsVector& pads) override { + PadBase::ComputePads(ctx, data_rank, pads_data, pads); + } + // From cpu/tensor/split.h (direct) Status SplitBase__PrepareForCompute(const SplitBase* p, const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index 8dee1cd620282..f33eec4b93e98 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -25,6 +25,8 @@ class UnsqueezeBase__Prepare; // Directly maps to UnsqueezeBase::Pr class contrib__AdamWOptimizerBase__Prepare; class contrib__SGDOptimizerV2Base__Prepare; +using PadsVector = InlinedVector; + struct ProviderHostCPU { // From cpu/tensor/gatherbase.h virtual Status GatherBase__PrepareForCompute(const GatherBase* p, OpKernelContext* context, GatherBase__Prepare& prepare) = 0; @@ -44,7 +46,11 @@ struct ProviderHostCPU { const TensorShape& indice_shape, const TensorShape& update_shape) = 0; // From cpu/tensor/padbase.h - virtual Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) = 0; + virtual Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) = 0; + + virtual void PadBase__ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span pads_data, + PadsVector& pads) = 0; + // From cpu/tensor/split.h virtual Status SplitBase__PrepareForCompute(const SplitBase* p, const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, diff --git a/onnxruntime/core/providers/cpu/tensor/pad.cc b/onnxruntime/core/providers/cpu/tensor/pad.cc index fe5267f20712b..912280687e229 100644 --- a/onnxruntime/core/providers/cpu/tensor/pad.cc +++ b/onnxruntime/core/providers/cpu/tensor/pad.cc @@ -9,6 +9,8 @@ #include "core/providers/op_kernel_type_control.h" #include "core/util/math.h" +#include + // there's no way to use a raw pointer as the copy destination with std::copy_n // (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset // without generating a 4996 warning. going through an iterator is way too much overhead so turn off the warning. 
@@ -167,47 +169,7 @@ ONNX_CPU_OPERATOR_KERNEL( using PadsVector = PadBase::PadsVector; -// This is the general padding method to n-dimensionally do edge or reflection padding (based on the inputDelta values) -template -static void PadAxis(T* output, T* input, ptrdiff_t input_delta, ptrdiff_t input_pitch, - size_t block_size, size_t block_count) { - for (size_t block_index = 0; block_index < block_count; block_index++) { - for (size_t i = 0; i < block_size; i++) { - *output++ = *input; - input += input_delta; - } - input += input_pitch; - } -} - -// These are optimizations of PadAxis. The inner loop is removed since the innermost axis has a blockSize of 1, -// and inputPitch and inputDelta are just a single value added each iteration. -template -static void PadInnermostAxis(T* output, T* input, ptrdiff_t input_delta, size_t block_count) { - for (size_t block_index = 0; block_index < block_count; block_index++) { - *output++ = *input; - input += input_delta; - } -} - -// For constant padding, there is no input, just a size to write the constant to -template -static void PadAxisConstant(T* output, T constant, size_t size) { - if (size == 1) { - *output = constant; - } else if (size == 2) { - *output = constant; - *(output + 1) = constant; - } else { - // This would be faster with SSE instructions. - // That would mean to have an implementation for each type (uint8, uint32, uint64). - T* end = output + size; - for (; output != end;) - *output++ = constant; - } -} - -Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) { +Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) { switch (mode) { case Mode::Constant: { // default behavior is fine @@ -242,34 +204,66 @@ Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_sh return Status::OK(); } -// special handling for edge case where the input has one or more dims with value of 0 -template -static Status PadInputWithDimValueOfZero(OpKernelContext* ctx, - const Mode& mode, - const TensorShape& input_shape, - TensorShapeVector& output_dims, - T value) { - TensorShape output_shape(output_dims); - ORT_RETURN_IF_ERROR(PadBase::HandleDimValueZero(mode, input_shape, output_shape)); - - auto& output_tensor = *ctx->Output(0, output_shape); - - // we need to add pads if mode is constant, otherwise the output has one or more dim values of 0 so is empty - if (mode == Mode::Constant) { - // we add pads with the default value to all dims including those with a value of 0 - auto* output = reinterpret_cast(output_tensor.MutableDataRaw()); - std::fill_n(output, output_shape.Size(), value); +static void ComputePadWithAxes( + gsl::span pads_tensor_raw_data, + std::function get_axis, + size_t axes_size, + size_t data_rank, + PadsVector& pads) { + for (size_t i = 0; i < axes_size; ++i) { + const size_t axis = onnxruntime::narrow(HandleNegativeAxis(get_axis(i), data_rank)); + pads[axis] = pads_tensor_raw_data[i]; // xi_begin + pads[data_rank + axis] = pads_tensor_raw_data[axes_size + i]; // xi_end } +} - return Status::OK(); +void PadBase::ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span pads_data, + PadsVector& pads) { + pads.reserve(2 * data_rank); + const Tensor* axes_tensor = ctx.Input(3); + if (axes_tensor) { + const size_t num_axes_dims = axes_tensor->Shape().NumDimensions(); + ORT_ENFORCE(num_axes_dims == 1, "Axes tensor should be a 1D tensor "); + + const int64_t num_axes = axes_tensor->Shape().Size(); 
+ ORT_ENFORCE(pads_data.size() == narrow(2 * num_axes), + "Pads tensor size should be equal to twice the number of explicitly provided axes."); + + pads.resize(2 * data_rank, 0); + if (axes_tensor->IsDataType()) { + auto axes_data = axes_tensor->DataAsSpan(); + ComputePadWithAxes( + pads_data, + [axes_data](size_t idx) -> int64_t { + return axes_data[idx]; + }, + axes_data.size(), + data_rank, + pads); + } else if (axes_tensor->IsDataType()) { + auto axes_data = axes_tensor->DataAsSpan(); + ComputePadWithAxes( + pads_data, + [axes_data](size_t idx) { + return axes_data[idx]; + }, + axes_data.size(), + data_rank, + pads); + } + } else { + ORT_ENFORCE(pads_data.size() == 2 * data_rank, + "Pads tensor size should be equal to twice the input dimension count "); + pads.assign(pads_data.begin(), pads_data.end()); + } } // Flatten no padding inner most Axis, so one memcpy cover multiple Axis. // For example, for a shape of [1,224,224,3] with padding [0,3,3,0,0,3,3,0], can be flatten as // [1,224,224*3] with padding [0,3,3*3,0,3,3*3]. -static void FlattenInnerShape(const TensorShapeVector& input_dims, const PadsVector& pads, - const PadsVector& slices, TensorShapeVector& reshaped_dims) { - size_t dims_count = input_dims.size(); +void PadBase::FlattenInnerShape(gsl::span input_dims, gsl::span pads, + gsl::span slices, TensorShapeVector& reshaped_dims) { + const size_t dims_count = input_dims.size(); size_t inner_axis = dims_count - 1; size_t inner_size = 1; @@ -288,14 +282,14 @@ static void FlattenInnerShape(const TensorShapeVector& input_dims, const PadsVec } while (inner_axis-- > 0); reshaped_dims.reserve(inner_axis + 1); - std::copy(input_dims.cbegin(), input_dims.cbegin() + inner_axis + 1, std::back_inserter(reshaped_dims)); + std::copy(input_dims.begin(), input_dims.begin() + inner_axis + 1, std::back_inserter(reshaped_dims)); // Flatten inner axis. 
reshaped_dims[inner_axis] = inner_size; } -static void ReshapePads(const PadsVector& src_pad, size_t src_dim_count, size_t new_dim_count, - size_t inner_no_pad_size, PadsVector& reshaped_pad) { +void PadBase::ReshapePads(gsl::span src_pad, size_t src_dim_count, size_t new_dim_count, + size_t inner_no_pad_size, PadsVector& reshaped_pad) { size_t inner_axis = new_dim_count - 1; std::copy(src_pad.begin(), src_pad.begin() + inner_axis, reshaped_pad.begin()); std::copy(src_pad.begin() + src_dim_count, src_pad.begin() + src_dim_count + inner_axis, @@ -306,6 +300,68 @@ static void ReshapePads(const PadsVector& src_pad, size_t src_dim_count, size_t reshaped_pad[inner_axis + new_dim_count] = src_pad[inner_axis + src_dim_count] * inner_no_pad_size; } +// special handling for edge case where the input has one or more dims with value of 0 +template +static Status PadInputWithDimValueOfZero(OpKernelContext* ctx, + const Mode& mode, + const TensorShape& input_shape, + TensorShapeVector& output_dims, + T value) { + TensorShape output_shape(output_dims); + ORT_RETURN_IF_ERROR(PadBase::HandleDimValueZero(mode, input_shape, output_shape)); + + auto& output_tensor = *ctx->Output(0, output_shape); + + // we need to add pads if mode is constant, otherwise the output has one or more dim values of 0 so is empty + if (mode == Mode::Constant) { + // we add pads with the default value to all dims including those with a value of 0 + auto* output = reinterpret_cast(output_tensor.MutableDataRaw()); + std::fill_n(output, output_shape.Size(), value); + } + + return Status::OK(); +} + +// This is the general padding method to n-dimensionally do edge or reflection padding (based on the inputDelta values) +template +static void PadAxis(T* output, T* input, ptrdiff_t input_delta, ptrdiff_t input_pitch, + size_t block_size, size_t block_count) { + for (size_t block_index = 0; block_index < block_count; block_index++) { + for (size_t i = 0; i < block_size; i++) { + *output++ = *input; + input += input_delta; + } + input += input_pitch; + } +} + +// These are optimizations of PadAxis. The inner loop is removed since the innermost axis has a blockSize of 1, +// and inputPitch and inputDelta are just a single value added each iteration. +template +static void PadInnermostAxis(T* output, T* input, ptrdiff_t input_delta, size_t block_count) { + for (size_t block_index = 0; block_index < block_count; block_index++) { + *output++ = *input; + input += input_delta; + } +} + +// For constant padding, there is no input, just a size to write the constant to +template +static void PadAxisConstant(T* output, T constant, size_t size) { + if (size == 1) { + *output = constant; + } else if (size == 2) { + *output = constant; + *(output + 1) = constant; + } else { + // This would be faster with SSE instructions. + // That would mean to have an implementation for each type (uint8, uint32, uint64). + T* end = output + size; + for (; output != end;) + *output++ = constant; + } +} + template static Status PadImpl(OpKernelContext* ctx, const PadsVector& pads, @@ -327,7 +383,7 @@ static Status PadImpl(OpKernelContext* ctx, // Reshape input dims TensorShapeVector reshaped_input_dims; - FlattenInnerShape(output_dims, pads, slices, reshaped_input_dims); + PadBase::FlattenInnerShape(output_dims, pads, slices, reshaped_input_dims); // Reshape padding size_t new_dims_count = reshaped_input_dims.size(); @@ -336,8 +392,8 @@ static Status PadImpl(OpKernelContext* ctx, ? 
reshaped_input_dims[inner_axis] / output_dims[inner_axis] : 0); PadsVector reshaped_pad(2 * new_dims_count), reshaped_slice(2 * new_dims_count); - ReshapePads(pads, data_rank, new_dims_count, inner_no_pad_size, reshaped_pad); - ReshapePads(slices, data_rank, new_dims_count, inner_no_pad_size, reshaped_slice); + PadBase::ReshapePads(pads, data_rank, new_dims_count, inner_no_pad_size, reshaped_pad); + PadBase::ReshapePads(slices, data_rank, new_dims_count, inner_no_pad_size, reshaped_slice); TensorShapeVector reshaped_output_dims = reshaped_input_dims; TensorShapeVector input_starts; @@ -575,20 +631,6 @@ static PadValue PadValueFromFloat(float value, MLDataType data_type) { return result; } -template -void ComputePadWithAxes( - gsl::span pads_tensor_raw_data, - gsl::span axes_tensor_raw_data, - size_t data_rank, - PadsVector& pads) { - size_t axes_size = axes_tensor_raw_data.size(); - for (size_t i = 0; i < axes_size; ++i) { - int64_t axis = HandleNegativeAxis(onnxruntime::narrow(axes_tensor_raw_data[i]), data_rank); - pads[onnxruntime::narrow(axis)] = pads_tensor_raw_data[i]; // xi_begin - pads[data_rank + onnxruntime::narrow(axis)] = pads_tensor_raw_data[axes_size + i]; // xi_end - } -} - Status Pad::Compute(OpKernelContext* ctx) const { const Tensor& input_tensor = *ctx->Input(0); MLDataType data_type = input_tensor.DataType(); @@ -608,48 +650,14 @@ Status Pad::Compute(OpKernelContext* ctx) const { ORT_ENFORCE(pads_tensor_dims.size() == 1 || (pads_tensor_dims.size() == 2 && pads_tensor_dims[0] == 1), "Pads tensor should be a 1D tensor of shape [2 * num_axes] " "or a 2D tensor of shape [1, 2 * num_axes]"); - const int64_t* pads_tensor_raw_data = pads_tensor.Data(); - size_t pads_size = static_cast(pads_tensor.Shape().Size()); - pads.reserve(2 * data_rank); - - const Tensor* axes_tensor = ctx->Input(3); - if (axes_tensor) { - const auto& axes_tensor_dims = axes_tensor->Shape().GetDims(); - ORT_ENFORCE(axes_tensor_dims.size() == 1, "Axes tensor should be a 1D tensor "); - int64_t axes_size = axes_tensor_dims[0]; - - pads.resize(2 * data_rank, 0); - if (axes_tensor->IsDataType()) { - const int32_t* axes_tensor_raw_data = axes_tensor->Data(); - ComputePadWithAxes( - {pads_tensor_raw_data, onnxruntime::narrow(2 * axes_size)}, - {axes_tensor_raw_data, onnxruntime::narrow(axes_size)}, - data_rank, - pads); - } else if (axes_tensor->IsDataType()) { - const int64_t* axes_tensor_raw_data = axes_tensor->Data(); - ComputePadWithAxes( - {pads_tensor_raw_data, onnxruntime::narrow(2 * axes_size)}, - {axes_tensor_raw_data, onnxruntime::narrow(axes_size)}, - data_rank, - pads); - } - } else { - ORT_ENFORCE(pads_size == 2 * data_rank, - "Pads tensor size should be equal to twice the input dimension count "); - for (size_t i = 0; i < pads_size; ++i) { - pads.push_back(pads_tensor_raw_data[i]); - } - } + + const auto pads_data = pads_tensor.DataAsSpan(); + + // Compute Pads by applying axes if specified otherwise copy the supplied pads. 
+ PadBase::ComputePads(*ctx, data_rank, pads_data, pads); // Separate out any negative pads into the slices array - slices.assign(pads.size(), 0); - for (size_t index = 0; index < pads.size(); index++) { - if (pads[index] < 0) { - slices[index] = pads[index]; - pads[index] = 0; - } - } + PadBase::SeparateNegativeToSlices(pads, slices); value.u64 = 0U; const Tensor* value_tensor = ctx->Input(2); diff --git a/onnxruntime/core/providers/cpu/tensor/padbase.h b/onnxruntime/core/providers/cpu/tensor/padbase.h index d869ed1a6dda2..43f9cbfc9f9a4 100644 --- a/onnxruntime/core/providers/cpu/tensor/padbase.h +++ b/onnxruntime/core/providers/cpu/tensor/padbase.h @@ -19,9 +19,80 @@ class PadBase { // Pads and slices are usually about twice the shapes involved using PadsVector = InlinedVector; - // Update the output_shape to make it consistent with numpy handling where there are one or more dimensions - // in the input_shape with a value of zero. - static Status HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape); + // The following several functions are shared among the providers + + /// + /// Handle the case when the input shape has zero dim values. + /// Depending on the mode, the input dim with zero value must match the output dim value. + /// + /// + /// Padding mode enum value + /// actual input shape + /// output_shape + /// Error if current mode padding can not be achieved with zero dim values + static Status HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape); + + /// + /// Compute Pads by applying axes if specified otherwise copy the supplied pads. + /// + /// The function queries optional axes input (since version 18) and if present, + /// applies it as a mask to the pads. If axes is not present, the pads are copied as is. + /// If axes are present, they are used as a mask over pads, so only those axes are being padded. + /// + /// kernel context to query axes input + /// input rank + /// pads data from pads input + /// resulting pads + static void ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span pads_data, + PadsVector& pads); + + /// + /// Separates negative pad values to slices and zeros them out in original pads. + /// Leaving the rest of slices values as zero. + /// + /// This function is used inline in the Pad CUDA implementation and is not exposed via a provider + /// interfaces. + /// + /// pad values + /// slices output + static void SeparateNegativeToSlices(gsl::span pads, PadsVector& slices) { + slices.assign(pads.size(), 0); + for (size_t index = 0, lim = pads.size(); index < lim; index++) { + if (pads[index] < 0) { + slices[index] = pads[index]; + pads[index] = 0; + } + } + } + + // End provider shared + + /// + /// Flatten no padding inner most Axis, so one memcpy cover multiple Axis. + /// For example, for a shape of [1,224,224,3] with padding [0,3,3,0,0,3,3,0], can be flatten as + /// [1,224,224*3] with padding [0,3,3*3,0,3,3*3]. + /// + /// This is a helper function pads are expected to be twice the rank + /// + /// original input dims + /// pad values + /// slices + /// result dims + static void FlattenInnerShape(gsl::span input_dims, gsl::span pads, + gsl::span slices, TensorShapeVector& reshaped_dims); + + /// + /// Used after the inner shape is flattened, so we can apply this function to pads and slices + /// to reshape them as well. + /// + /// pads + /// original dim count + /// expected flattended dim count + /// is the left most dimension that was flattened. 
diff --git a/onnxruntime/core/providers/cpu/tensor/padbase.h b/onnxruntime/core/providers/cpu/tensor/padbase.h
index d869ed1a6dda2..43f9cbfc9f9a4 100644
--- a/onnxruntime/core/providers/cpu/tensor/padbase.h
+++ b/onnxruntime/core/providers/cpu/tensor/padbase.h
@@ -19,9 +19,80 @@ class PadBase {
   // Pads and slices are usually about twice the shapes involved
   using PadsVector = InlinedVector<int64_t, kTensorShapeSmallBufferElementsSize * 2>;

-  // Update the output_shape to make it consistent with numpy handling where there are one or more dimensions
-  // in the input_shape with a value of zero.
-  static Status HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape);
+  // The following several functions are shared among the providers
+
+  /// <summary>
+  /// Handle the case when the input shape has zero dim values.
+  /// Depending on the mode, the input dim with zero value must match the output dim value.
+  /// </summary>
+  /// <param name="mode">padding mode enum value</param>
+  /// <param name="input_shape">actual input shape</param>
+  /// <param name="output_shape">output shape to validate</param>
+  /// <returns>Error if the current mode's padding can not be achieved with zero dim values</returns>
+  static Status HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape);
+
+  /// <summary>
+  /// Compute pads by applying axes if specified; otherwise copy the supplied pads.
+  ///
+  /// The function queries the optional axes input (available since opset 18) and, if present,
+  /// uses it as a mask over the pads so that only the listed axes are padded.
+  /// If axes is not present, the pads are copied as is.
+  /// </summary>
+  /// <param name="ctx">kernel context to query the axes input</param>
+  /// <param name="data_rank">input rank</param>
+  /// <param name="pads_data">pads data from the pads input</param>
+  /// <param name="pads">resulting pads</param>
+  static void ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span<const int64_t> pads_data,
+                          PadsVector& pads);
+
+  /// <summary>
+  /// Separates negative pad values into slices and zeros them out in the original pads;
+  /// the remaining slices values stay zero.
+  ///
+  /// This function is used inline in the Pad CUDA implementation and is therefore not exposed
+  /// via the provider interfaces.
+  /// </summary>
+  /// <param name="pads">pad values</param>
+  /// <param name="slices">slices output</param>
+  static void SeparateNegativeToSlices(gsl::span<int64_t> pads, PadsVector& slices) {
+    slices.assign(pads.size(), 0);
+    for (size_t index = 0, lim = pads.size(); index < lim; index++) {
+      if (pads[index] < 0) {
+        slices[index] = pads[index];
+        pads[index] = 0;
+      }
+    }
+  }
+
+  // End provider shared
+
+  /// <summary>
+  /// Flattens the innermost axes that have no padding, so one memcpy covers multiple axes.
+  /// For example, a shape of [1,224,224,3] with padding [0,3,3,0,0,3,3,0] can be flattened to
+  /// [1,224,224*3] with padding [0,3,3*3,0,3,3*3].
+  ///
+  /// This is a helper function; pads are expected to be twice the rank.
+  /// </summary>
+  /// <param name="input_dims">original input dims</param>
+  /// <param name="pads">pad values</param>
+  /// <param name="slices">slices</param>
+  /// <param name="reshaped_dims">result dims</param>
+  static void FlattenInnerShape(gsl::span<const int64_t> input_dims, gsl::span<const int64_t> pads,
+                                gsl::span<const int64_t> slices, TensorShapeVector& reshaped_dims);
+
+  /// <summary>
+  /// Used after the inner shape is flattened, so we can apply this function to pads and slices
+  /// to reshape them as well.
+  /// </summary>
+  /// <param name="src_pad">pads</param>
+  /// <param name="src_dim_count">original dim count</param>
+  /// <param name="new_dim_count">expected flattened dim count</param>
+  /// <param name="inner_no_pad_size">product of the no-pad inner dimensions that were flattened.
+  /// In the example above, that would be 3, reverse computed from (224*3)/224</param>
+  /// <param name="reshaped_pad">resulting reshaped pads or slices</param>
+  static void ReshapePads(gsl::span<const int64_t> src_pad, size_t src_dim_count, size_t new_dim_count,
+                          size_t inner_no_pad_size, PadsVector& reshaped_pad);

 protected:
   PadBase(const OpKernelInfo& info) : value_(info.GetAttrOrDefault("value", 0.f)) {
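The flatten/reshape pair documented in this header can be traced with its own [1,224,224,3] example. The following Python sketch mirrors the documented behavior; it illustrates the technique and is not the C++ implementation:

```python
# Illustrative sketch of inner-shape flattening: trailing axes with zero
# padding are merged into the last padded axis, whose pads then scale by
# the merged (no-pad) size.

def flatten_inner_shape(input_dims, pads):
    rank = len(input_dims)
    inner = rank - 1
    inner_no_pad_size = 1
    # Walk inward while both the begin and end pads of the axis are zero.
    while inner > 0 and pads[inner] == 0 and pads[rank + inner] == 0:
        inner_no_pad_size *= input_dims[inner]
        inner -= 1
    reshaped = list(input_dims[:inner]) + [input_dims[inner] * inner_no_pad_size]
    return reshaped, inner_no_pad_size

def reshape_pads(pads, rank, new_rank, inner_no_pad_size):
    out = [0] * (2 * new_rank)
    for i in range(new_rank - 1):          # outer axes copy through
        out[i] = pads[i]
        out[new_rank + i] = pads[rank + i]
    # The merged innermost axis scales by the flattened no-pad size.
    out[new_rank - 1] = pads[new_rank - 1] * inner_no_pad_size
    out[2 * new_rank - 1] = pads[rank + new_rank - 1] * inner_no_pad_size
    return out

dims, inner = flatten_inner_shape([1, 224, 224, 3], [0, 3, 3, 0, 0, 3, 3, 0])
# dims == [1, 224, 672], inner == 3
new_pads = reshape_pads([0, 3, 3, 0, 0, 3, 3, 0], 4, 3, 3)
# new_pads == [0, 3, 9, 0, 3, 9], i.e. [0,3,3*3,0,3,3*3] from the doc comment
```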
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index 644bcaaa24cd4..3fc4ed355a12b 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1121,10 +1121,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Identity);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterND);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Pad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Pad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, bool, Pad);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, Sign);
@@ -1269,6 +1269,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad);

 // Opset 19
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Cast);
@@ -2008,10 +2012,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Pad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Pad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, bool, Pad)>,
      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
@@ -2091,13 +2095,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
-      BuildKernelCreateInfo,
-      BuildKernelCreateInfo,
-      BuildKernelCreateInfo,
-      BuildKernelCreateInfo,
-      BuildKernelCreateInfo,
-      BuildKernelCreateInfo,
-      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
@@ -2150,11 +2147,22 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
      // Opset 18
      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad)>,

      // Opset 19
      BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cuda/tensor/pad.cc b/onnxruntime/core/providers/cuda/tensor/pad.cc
index 4584e5fd8272c..bdd6567d2ef34 100644
--- a/onnxruntime/core/providers/cuda/tensor/pad.cc
+++ b/onnxruntime/core/providers/cuda/tensor/pad.cc
@@ -29,15 +29,27 @@ namespace cuda {
           .InputMemoryType(OrtMemTypeCPUInput, 2)                    \
           .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()),    \
       Pad);                                                          \
+  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(                           \
+      Pad,                                                           \
+      kOnnxDomain,                                                   \
+      13, 17,                                                        \
+      T,                                                             \
+      kCudaExecutionProvider,                                        \
+      (*KernelDefBuilder::Create())                                  \
+          .InputMemoryType(OrtMemTypeCPUInput, 1)                    \
+          .InputMemoryType(OrtMemTypeCPUInput, 2)                    \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()),    \
+      Pad);                                                          \
   ONNX_OPERATOR_TYPED_KERNEL_EX(                                     \
       Pad,                                                           \
       kOnnxDomain,                                                   \
-      13,                                                            \
+      18,                                                            \
       T,                                                             \
       kCudaExecutionProvider,                                        \
       (*KernelDefBuilder::Create())                                  \
           .InputMemoryType(OrtMemTypeCPUInput, 1)                    \
           .InputMemoryType(OrtMemTypeCPUInput, 2)                    \
+          .InputMemoryType(OrtMemTypeCPUInput, 3)                    \
           .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()),    \
       Pad);
@@ -94,28 +106,15 @@ Status Pad<T>::ComputeInternal(OpKernelContext* ctx) const {
   if (is_dynamic_) {
     const Tensor& pads_tensor = *ctx->Input<Tensor>(1);
     const auto pads_tensor_dims = pads_tensor.Shape().GetDims();
-    ORT_ENFORCE(utils::IsPrimitiveDataType<int64_t>(pads_tensor.DataType()),
-                "Pads tensor should be an INT64 tensor");
     ORT_ENFORCE(pads_tensor_dims.size() == 1 || (pads_tensor_dims.size() == 2 && pads_tensor_dims[0] == 1),
-                "Pads tensor should be a 1D tensor of shape [2 * input_rank] or a 2D tensor of shape [1, 2 * input_rank]");
+                "Pads tensor should be a 1D tensor of shape [2 * num_axes] or a 2D tensor of shape [1, 2 * num_axes]");

-    const int64_t* pads_tensor_raw_data = pads_tensor.Data<int64_t>();
-    size_t pads_size = static_cast<size_t>(pads_tensor.Shape().Size());
-    ORT_ENFORCE(pads_size == 2 * static_cast<size_t>(dimension_count),
-                "Pads tensor size should be equal to twice the input dimension count ");
+    const auto pads_data = pads_tensor.DataAsSpan<int64_t>();
+
+    PadBase::ComputePads(*ctx, input_shape.NumDimensions(), pads_data, pads);

-    pads.reserve(2LL * dimension_count);
-    for (size_t i = 0; i < pads_size; ++i) {
-      pads.push_back(pads_tensor_raw_data[i]);
-    }
     // Separate out any negative pads into the slices array
-    slices.resize(pads.size(), 0);
-    for (size_t index = 0; index < pads.size(); index++) {
-      if (pads[index] < 0) {
-        slices[index] = pads[index];
-        pads[index] = 0;
-      }
-    }
+    PadBase::SeparateNegativeToSlices(pads, slices);

     T raw_value{};
     const Tensor* value_tensor = ctx->Input<Tensor>(2);
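To see end to end what the widened registrations enable, here is an illustrative opset-18 model that exercises Pad's optional `axes` input, which the new 18+ kernels accept as a CPU-memory input. It assumes the `onnx` Python package; the tensor names are arbitrary:

```python
# Build a Pad-18 node that pads only axis 2, supplying the optional `axes`
# input. The empty string omits the optional constant_value input.
import onnx
from onnx import TensorProto, helper

x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 224, 224])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, None)

# With axes = [2], pads holds [x2_begin, x2_end] only: 2 entries, not 8.
pads = helper.make_tensor("pads", TensorProto.INT64, [2], [3, 3])
axes = helper.make_tensor("axes", TensorProto.INT64, [1], [2])

node = helper.make_node("Pad", ["x", "pads", "", "axes"], ["y"], mode="constant")
graph = helper.make_graph([node], "pad18_axes", [x], [y], initializer=[pads, axes])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
onnx.checker.check_model(model)
```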
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
index d7bec337a6be4..fff3d14b763d5 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
@@ -1158,10 +1158,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Identity);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterND);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Pad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Pad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, bool, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, float, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, double, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, bool, Pad);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, SpaceToDepth);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, DepthToSpace);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign);
@@ -1298,6 +1298,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, MLFloat16, LayerNormalization);

 // Opset 18
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad);
+
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split);

 // Opset 19
@@ -2088,10 +2093,10 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Pad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Pad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, bool, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, float, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, double, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, bool, Pad)>,
      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
@@ -2228,6 +2233,11 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
      BuildKernelCreateInfo,

      // Opset 18
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad)>,
+
      BuildKernelCreateInfo,

      // Opset 19
diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
index a3155fe6b86cf..e1d0e310425c5 100644
--- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
+++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
@@ -547,7 +547,14 @@ Status ScatterND::ValidateShapes(const TensorShape& input_shape,
                                  const TensorShape& indice_shape, const TensorShape& update_shape) {
   return g_host_cpu.ScatterNDBase__ValidateShapes(input_shape, indice_shape, update_shape);
 }

-Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) { return g_host_cpu.PadBase__HandleDimValueZero(mode, input_shape, output_shape); }
+Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) {
+  return g_host_cpu.PadBase__HandleDimValueZero(mode, input_shape, output_shape);
+}
+
+void PadBase::ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span<const int64_t> pads_data,
+                          PadsVector& pads) {
+  g_host_cpu.PadBase__ComputePads(ctx, data_rank, pads_data, pads);
+}

 Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, const ConcatBase::InlinedTensorsVector& input_tensors,
                                      Prepare& p) const {

From 2b87dd373a3567c2c426e2f090b201b8b051a346 Mon Sep 17 00:00:00 2001
From: Vincent Wang
Date: Thu, 25 Jan 2024 10:16:41 +0800
Subject: [PATCH 23/23] [ORTModule] Remove Mod from Hash to Avoid Conflict for
 Triton Code-gen (#19256)

Remove the mod (10**8) from the hashes used for Triton code-gen keys to
avoid conflicts between generated kernels.
---
 .../python/training/ort_triton/kernel/_mm.py  | 20 +++++++++----------
 .../training/ort_triton/triton_op_executor.py |  2 +-
 2 files changed, 11 insertions(+), 11 deletions(-)
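A quick birthday-bound estimate shows why truncating a 64-bit hash to 10**8 buckets was fragile; collisions become likely well within a realistic number of generated kernels:

```python
# P(at least one collision) among n keys over b buckets is approximately
# 1 - exp(-n * (n - 1) / (2 * b)); with b = 10**8 this climbs fast.
import math

def collision_probability(n_keys: int, buckets: int = 10**8) -> float:
    return 1.0 - math.exp(-n_keys * (n_keys - 1) / (2.0 * buckets))

for n in (1_000, 10_000, 50_000, 100_000):
    print(f"{n:>7} keys -> {collision_probability(n):.1%} chance of a collision")
# ~0.5% at 1k keys, ~39% at 10k, ~100% by 100k; hence dropping the mod.
```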
diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py b/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py
index ed92923589d48..a3681a13699a0 100644
--- a/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py
+++ b/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py
@@ -11,7 +11,7 @@
 import torch

 from .._cache import ModuleCache, PyCodeCache
-from .._utils import next_power_of_2
+from .._utils import gen_unique_name, next_power_of_2

 _DEBUG_MODE = "ORTMODULE_TRITON_DEBUG" in os.environ and int(os.getenv("ORTMODULE_TRITON_DEBUG")) == 1

@@ -305,18 +305,18 @@ def _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name):


 def _gen_mm_key(dtype: torch.dtype, m: int, n: int, k: int, trans_a: bool, trans_b: bool, alpha: float) -> int:
-    return hash(f"mm|{dtype}|{m}|{n}|{k}|{trans_a}|{trans_b}|{alpha}") % (10**8)
+    return hash(f"mm|{dtype}|{m}|{n}|{k}|{trans_a}|{trans_b}|{alpha}")


 def _gen_mm_module(
     dtype: torch.dtype, m: int, n: int, k: int, trans_a: bool, trans_b: bool, alpha: float
 ) -> Tuple[str, ModuleType]:
-    func_name = f"mm_{_gen_mm_key(dtype, m, n, k, trans_a, trans_b, alpha)}"
+    func_name = gen_unique_name("mm")
     kwargs = _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name)
     src_code = _MM_TEMPLATE.format(**kwargs)
     if _DEBUG_MODE:
         os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True)
-        with open(f"triton_debug/{func_name}.py", "w") as f:
+        with open(f"triton_debug/{func_name}.py", "w", encoding="utf-8") as f:
             f.write(src_code)
     return func_name, PyCodeCache().load(src_code)

@@ -333,7 +333,7 @@ def _gen_gemm_key(
     alpha: float,
     beta: float,
 ) -> int:
-    return hash(f"gemm|{dtype}|{m}|{n}|{k}|{stride_cm}|{stride_cn}|{trans_a}|{trans_b}|{alpha}|{beta}") % (10**8)
+    return hash(f"gemm|{dtype}|{m}|{n}|{k}|{stride_cm}|{stride_cn}|{trans_a}|{trans_b}|{alpha}|{beta}")


 def _gen_gemm_module(
@@ -348,7 +348,7 @@ def _gen_gemm_module(
     alpha: float,
     beta: float,
 ) -> Tuple[str, ModuleType]:
-    func_name = f"gemm_{_gen_gemm_key(dtype, m, n, k, stride_cm, stride_cn, trans_a, trans_b, alpha, beta)}"
+    func_name = gen_unique_name("gemm")
     kwargs = _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name)
     kwargs["stride_cm"] = stride_cm
     kwargs["stride_cn"] = stride_cn
@@ -356,7 +356,7 @@ def _gen_gemm_module(
     src_code = _GEMM_TEMPLATE.format(**kwargs)
     if _DEBUG_MODE:
         os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True)
-        with open(f"triton_debug/{func_name}.py", "w") as f:
+        with open(f"triton_debug/{func_name}.py", "w", encoding="utf-8") as f:
             f.write(src_code)
     return func_name, PyCodeCache().load(src_code)

@@ -364,13 +364,13 @@ def _gen_gemm_module(
 def _gen_bmm_key(
     dtype: torch.dtype, m: int, n: int, k: int, batch_a: int, batch_b: int, trans_a: bool, trans_b: bool, alpha: float
 ) -> int:
-    return hash(f"bmm|{dtype}|{m}|{n}|{k}|{batch_a}|{batch_b}|{trans_a}|{trans_b}|{alpha}") % (10**8)
+    return hash(f"bmm|{dtype}|{m}|{n}|{k}|{batch_a}|{batch_b}|{trans_a}|{trans_b}|{alpha}")


 def _gen_bmm_module(
     dtype: torch.dtype, m: int, n: int, k: int, batch_a: int, batch_b: int, trans_a: bool, trans_b: bool, alpha: float
 ) -> Tuple[str, ModuleType]:
-    func_name = f"bmm_{_gen_bmm_key(dtype, m, n, k, batch_a, batch_b, trans_a, trans_b, alpha)}"
+    func_name = gen_unique_name("bmm")
     kwargs = _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name)
     batch = batch_a if batch_a >= batch_b else batch_b
     kwargs["stride_aq"] = m * k if batch_a == batch else 0
@@ -379,7 +379,7 @@ def _gen_bmm_module(
     src_code = _BMM_TEMPLATE.format(**kwargs)
     if _DEBUG_MODE:
         os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True)
-        with open(f"triton_debug/{func_name}.py", "w") as f:
+        with open(f"triton_debug/{func_name}.py", "w", encoding="utf-8") as f:
             f.write(src_code)
     return func_name, PyCodeCache().load(src_code)

diff --git a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py
index 1fe61750e651e..f16abc71251ed 100644
--- a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py
+++ b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py
@@ -67,7 +67,7 @@ def get_shape(cls, onnx_key: int, shapes: List[List[int]]) -> List[List[Union[in


 def _gen_key(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> int:  # pylint: disable=unused-argument
-    return hash(f"{onnx_key}|{str(shapes).replace(' ', '')}") % (10**8)
+    return hash(f"{onnx_key}|{str(shapes).replace(' ', '')}")


 def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> Tuple[str, ModuleType]:
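The diff imports `gen_unique_name` from `.._utils` without showing its body. A helper along the following lines would provide the collision-free names this change relies on; this is a hypothetical sketch, and the actual `_utils` implementation may differ:

```python
# Hypothetical sketch of a gen_unique_name-style helper (the real one lives
# in ort_triton/_utils): a process-wide counter guarantees distinct kernel
# function names, with no hash truncation left to collide.
import itertools

_id_iter = itertools.count()

def gen_unique_name(prefix: str) -> str:
    # e.g. gen_unique_name("mm") -> "mm_0", then "mm_1", ...
    return f"{prefix}_{next(_id_iter)}"
```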