From beb2496748b112ba0b2525c14f1093acbd98c7aa Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Tue, 2 Jul 2024 09:24:19 +0800 Subject: [PATCH 01/13] Templatize publishing nuget package (#21199) ### Description It's the prerequisite step of reducing complexity of current zip-nuget pipeline. Some packaging tasks could be cut from the most complex nuget pipline and easily be published ### Motivation and Context --- .../github/azure-pipelines/publish-nuget.yml | 179 +++--------------- .../templates/publish-nuget-steps.yml | 136 +++++++++++++ 2 files changed, 164 insertions(+), 151 deletions(-) create mode 100644 tools/ci_build/github/azure-pipelines/templates/publish-nuget-steps.yml diff --git a/tools/ci_build/github/azure-pipelines/publish-nuget.yml b/tools/ci_build/github/azure-pipelines/publish-nuget.yml index 367977ff59192..5e827980e039c 100644 --- a/tools/ci_build/github/azure-pipelines/publish-nuget.yml +++ b/tools/ci_build/github/azure-pipelines/publish-nuget.yml @@ -10,154 +10,31 @@ resources: branch: main stages: -- stage: Publish_NuGet_Package_And_Report - jobs: - - job: Publish_NuGet_Package_And_Report - workspace: - clean: all - variables: - - name: GDN_CODESIGN_TARGETDIRECTORY - value: '$(Agent.TempDirectory)\binfiles' - pool: 'onnxruntime-Win-CPU-2022' - - steps: - # https://learn.microsoft.com/en-us/azure/devops/pipelines/yaml-schema/resources-pipelines-pipeline?view=azure-pipelines#pipeline-resource-metadata-as-predefined-variables - - script: | - echo $(resources.pipeline.build.sourceBranch) - echo $(Build.Reason) - displayName: 'Print triggering sourceBranch Name in resources' - - - checkout: self - submodules: false - - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.9' - addToPath: true - - - template: templates/set-version-number-variables-step.yml - - - script: mkdir "$(Build.BinariesDirectory)\nuget-artifact\final-package" - - - download: build - displayName: 'Download Pipeline Artifact - Signed NuGet Package' - artifact: 'drop-signed-nuget-CPU' - - - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-CPU\*" "$(Build.BinariesDirectory)\nuget-artifact\final-package" - - - template: nuget/templates/get-nuget-package-version-as-variable.yml - parameters: - packageFolder: '$(Build.BinariesDirectory)/nuget-artifact/final-package' - - - task: CmdLine@2 - displayName: 'Post binary sizes to the dashboard database using command line' - inputs: - script: | - echo changing directory to artifact download path - cd $(Build.BinariesDirectory)/nuget-artifact/final-package - echo processing nupkg - SETLOCAL EnableDelayedExpansion - FOR /R %%i IN (*.nupkg) do ( - set filename=%%~ni - IF NOT "!filename:~25,7!"=="Managed" ( - echo processing %%~ni.nupkg - copy %%~ni.nupkg %%~ni.zip - echo copied to zip - echo listing lib files in the zip - REM use a single .csv file to put the data - echo os,arch,build_config,size > $(Build.BinariesDirectory)\binary_size_data.txt - 7z.exe l -slt %%~ni.zip runtimes\linux-arm64\native\libonnxruntime.so | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo linux,aarch64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt - 7z.exe l -slt %%~ni.zip runtimes\osx-x64\native\libonnxruntime.dylib | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo osx,x64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt - 7z.exe l -slt %%~ni.zip runtimes\win-x64\native\onnxruntime.dll | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo win,x64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt - 7z.exe l -slt %%~ni.zip runtimes\win-x86\native\onnxruntime.dll | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo win,x86,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt - ) - ) - - - task: AzureCLI@2 - displayName: 'Azure CLI' - #Only report binary sizes to database if the build build was auto-triggered from the main branch - condition: and (succeeded(), and(eq(variables['resources.pipeline.build.sourceBranch'], 'refs/heads/main'), eq(variables['Build.Reason'], 'ResourceTrigger'))) - inputs: - azureSubscription: AIInfraBuildOnnxRuntimeOSS - scriptLocation: inlineScript - scriptType: batch - inlineScript: | - python.exe -m pip install -r $(Build.SourcesDirectory)\tools\ci_build\github\windows\post_to_dashboard\requirements.txt && ^ - python.exe $(Build.SourcesDirectory)\tools\ci_build\github\windows\post_binary_sizes_to_dashboard.py --commit_hash=$(Build.SourceVersion) --size_data_file=binary_size_data.txt --build_project=Lotus --build_id=$(Build.BuildId) - workingDirectory: '$(Build.BinariesDirectory)' - - - download: build - displayName: 'Download Pipeline Artifact - Signed NuGet Package' - artifact: 'drop-signed-nuget-dml' - - - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-dml\*" $(Build.BinariesDirectory)\nuget-artifact\final-package - - - download: build - displayName: 'Download Pipeline Artifact - Signed NuGet Package' - artifact: 'drop-signed-nuget-Training-CPU' - - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-Training-CPU\*" $(Build.BinariesDirectory)\nuget-artifact\final-package - - - download: build - displayName: 'Download Pipeline Artifact - Signed NuGet Package' - artifact: 'drop-signed-nuget-GPU' - - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-GPU\*" $(Build.BinariesDirectory)\nuget-artifact\final-package - - - download: build - displayName: 'Download Pipeline Artifact - Signed NuGet ROCm Package' - artifact: 'drop-signed-nuget-ROCm' - - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-ROCm\*" $(Build.BinariesDirectory)\nuget-artifact\final-package - - - download: build - displayName: 'Download Pipeline Artifact - Signed NuGet Qnn Package' - artifact: 'drop-signed-nuget-qnn' - - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-qnn\*" $(Build.BinariesDirectory)\nuget-artifact\final-package - - - script: | - dir $(Build.BinariesDirectory)\nuget-artifact\final-package - cd $(Build.BinariesDirectory)\nuget-artifact\final-package - nuget verify -Signatures *.nupkg - displayName: List Downloaded Package - - - powershell: | - New-Item -Path $(Agent.TempDirectory) -Name "binfiles" -ItemType "directory" - $base_path_name = Join-Path -Path $(Agent.TempDirectory) -ChildPath "binfiles" - Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact\final-package -Filter *.nupkg | - Foreach-Object { - $dir_name = Join-Path -Path $base_path_name -ChildPath $_.Basename - $cmd = "7z.exe x $($_.FullName) -y -o$dir_name" - Write-Output $cmd - Invoke-Expression -Command $cmd - } - dir $(Agent.TempDirectory) - tree $(Agent.TempDirectory) - workingDirectory: '$(Agent.TempDirectory)' - - - task: CodeSign@1 - displayName: 'Run Codesign Validation' - - - - task: PublishSecurityAnalysisLogs@3 - displayName: 'Publish Security Analysis Logs' - continueOnError: true - - - task: PostAnalysis@2 - inputs: - GdnBreakAllTools: true - GdnBreakPolicy: M365 - GdnBreakPolicyMinSev: Error - - #TODO: allow choosing different feeds - - task: NuGetCommand@2 - displayName: 'Copy Signed Native NuGet Package to ORT-NIGHTLY' - inputs: - command: 'push' - packagesToPush: '$(Build.BinariesDirectory)/nuget-artifact/final-package/*.nupkg' - publishVstsFeed: '2692857e-05ef-43b4-ba9c-ccf1c22c437c/7982ae20-ed19-4a35-a362-a96ac99897b7' - allowPackageConflicts: true - - - template: templates/component-governance-component-detection-steps.yml - parameters : - condition : 'succeeded' - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() + - template: templates/publish-nuget-steps.yml + parameters: + include_cpu_ep: true + download_artifacts_steps: + - download: build + displayName: 'Download Pipeline Artifact - Signed NuGet Package' + artifact: 'drop-signed-nuget-dml' + - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-dml\*" $(Build.BinariesDirectory)\nuget-artifact\final-package + + - download: build + displayName: 'Download Pipeline Artifact - Signed NuGet Package' + artifact: 'drop-signed-nuget-Training-CPU' + - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-Training-CPU\*" $(Build.BinariesDirectory)\nuget-artifact\final-package + + - download: build + displayName: 'Download Pipeline Artifact - Signed NuGet Package' + artifact: 'drop-signed-nuget-GPU' + - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-GPU\*" $(Build.BinariesDirectory)\nuget-artifact\final-package + + - download: build + displayName: 'Download Pipeline Artifact - Signed NuGet ROCm Package' + artifact: 'drop-signed-nuget-ROCm' + - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-ROCm\*" $(Build.BinariesDirectory)\nuget-artifact\final-package + + - download: build + displayName: 'Download Pipeline Artifact - Signed NuGet Qnn Package' + artifact: 'drop-signed-nuget-qnn' + - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-qnn\*" $(Build.BinariesDirectory)\nuget-artifact\final-package diff --git a/tools/ci_build/github/azure-pipelines/templates/publish-nuget-steps.yml b/tools/ci_build/github/azure-pipelines/templates/publish-nuget-steps.yml new file mode 100644 index 0000000000000..6698501e74bad --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/publish-nuget-steps.yml @@ -0,0 +1,136 @@ +parameters: +- name: include_cpu_ep + type: boolean + default: false +- name: download_artifacts_steps + type: stepList + +stages: +- stage: Publish_NuGet_Package_And_Report + jobs: + - job: Publish_NuGet_Package_And_Report + workspace: + clean: all + variables: + - name: GDN_CODESIGN_TARGETDIRECTORY + value: '$(Agent.TempDirectory)\binfiles' + pool: 'onnxruntime-Win-CPU-2022' + + steps: + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + # https://learn.microsoft.com/en-us/azure/devops/pipelines/yaml-schema/resources-pipelines-pipeline?view=azure-pipelines#pipeline-resource-metadata-as-predefined-variables + - script: | + echo $(resources.pipeline.build.sourceBranch) + echo $(Build.Reason) + displayName: 'Print triggering sourceBranch Name in resources' + + - checkout: self + submodules: false + + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.9' + addToPath: true + + - template: set-version-number-variables-step.yml + + - script: mkdir "$(Build.BinariesDirectory)\nuget-artifact\final-package" + + - template: ../nuget/templates/get-nuget-package-version-as-variable.yml + parameters: + packageFolder: '$(Build.BinariesDirectory)/nuget-artifact/final-package' + + - ${{if eq(parameters.include_cpu_ep, true)}}: + - download: build + displayName: 'Download Pipeline Artifact - Signed NuGet Package' + artifact: 'drop-signed-nuget-CPU' + + - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-CPU\*" "$(Build.BinariesDirectory)\nuget-artifact\final-package" + + - task: CmdLine@2 + displayName: 'Post binary sizes to the dashboard database using command line' + inputs: + script: | + echo changing directory to artifact download path + cd $(Build.BinariesDirectory)/nuget-artifact/final-package + echo processing nupkg + SETLOCAL EnableDelayedExpansion + FOR /R %%i IN (*.nupkg) do ( + set filename=%%~ni + IF NOT "!filename:~25,7!"=="Managed" ( + echo processing %%~ni.nupkg + copy %%~ni.nupkg %%~ni.zip + echo copied to zip + echo listing lib files in the zip + REM use a single .csv file to put the data + echo os,arch,build_config,size > $(Build.BinariesDirectory)\binary_size_data.txt + 7z.exe l -slt %%~ni.zip runtimes\linux-arm64\native\libonnxruntime.so | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo linux,aarch64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt + 7z.exe l -slt %%~ni.zip runtimes\osx-x64\native\libonnxruntime.dylib | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo osx,x64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt + 7z.exe l -slt %%~ni.zip runtimes\win-x64\native\onnxruntime.dll | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo win,x64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt + 7z.exe l -slt %%~ni.zip runtimes\win-x86\native\onnxruntime.dll | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo win,x86,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt + ) + ) + + - task: AzureCLI@2 + displayName: 'Azure CLI' + #Only report binary sizes to database if the build build was auto-triggered from the main branch + condition: and (succeeded(), and(eq(variables['resources.pipeline.build.sourceBranch'], 'refs/heads/main'), eq(variables['Build.Reason'], 'ResourceTrigger'))) + inputs: + azureSubscription: AIInfraBuildOnnxRuntimeOSS + scriptLocation: inlineScript + scriptType: batch + inlineScript: | + python.exe -m pip install -r $(Build.SourcesDirectory)\tools\ci_build\github\windows\post_to_dashboard\requirements.txt && ^ + python.exe $(Build.SourcesDirectory)\tools\ci_build\github\windows\post_binary_sizes_to_dashboard.py --commit_hash=$(Build.SourceVersion) --size_data_file=binary_size_data.txt --build_project=Lotus --build_id=$(Build.BuildId) + workingDirectory: '$(Build.BinariesDirectory)' + + - ${{ parameters.download_artifacts_steps }} + + - script: | + dir $(Build.BinariesDirectory)\nuget-artifact\final-package + cd $(Build.BinariesDirectory)\nuget-artifact\final-package + nuget verify -Signatures *.nupkg + displayName: List Downloaded Package + + - powershell: | + New-Item -Path $(Agent.TempDirectory) -Name "binfiles" -ItemType "directory" + $base_path_name = Join-Path -Path $(Agent.TempDirectory) -ChildPath "binfiles" + Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact\final-package -Filter *.nupkg | + Foreach-Object { + $dir_name = Join-Path -Path $base_path_name -ChildPath $_.Basename + $cmd = "7z.exe x $($_.FullName) -y -o$dir_name" + Write-Output $cmd + Invoke-Expression -Command $cmd + } + dir $(Agent.TempDirectory) + tree $(Agent.TempDirectory) + workingDirectory: '$(Agent.TempDirectory)' + + - task: CodeSign@1 + displayName: 'Run Codesign Validation' + + + - task: PublishSecurityAnalysisLogs@3 + displayName: 'Publish Security Analysis Logs' + continueOnError: true + + - task: PostAnalysis@2 + inputs: + GdnBreakAllTools: true + GdnBreakPolicy: M365 + GdnBreakPolicyMinSev: Error + + #TODO: allow choosing different feeds + - task: NuGetCommand@2 + displayName: 'Copy Signed Native NuGet Package to ORT-NIGHTLY' + inputs: + command: 'push' + packagesToPush: '$(Build.BinariesDirectory)/nuget-artifact/final-package/*.nupkg' + publishVstsFeed: '2692857e-05ef-43b4-ba9c-ccf1c22c437c/7982ae20-ed19-4a35-a362-a96ac99897b7' + allowPackageConflicts: true + + - template: component-governance-component-detection-steps.yml + parameters : + condition : 'succeeded' From 7be1d4aad3f984ebe2c4fb0f7db0b9ca67cc8964 Mon Sep 17 00:00:00 2001 From: Yifan Li <109183385+yf711@users.noreply.github.com> Date: Mon, 1 Jul 2024 22:55:20 -0700 Subject: [PATCH 02/13] [TensorRT EP] Update TRT10.0 deprecated api (#20989) ### Description Note: * This PR would remove C4996 suppression in tensorrt_execution_provider.cc only (according to Nvidia, places with nvinfer.h included need C4996 suppression, when /Zc:__cplusplus is enabled in ORT win build) * A follow-up PR will be raised to update deprecated TRT Plugin api usage. Here are deprecated apis to be updated in this PR: | deprecated api | Update | | ------------------------------------------------------------ | ------------------------------------------------------------ | | [kCUBLAS](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/namespacenvinfer1.html#a9e1d81e5a8bfeb38b86e22a66d5f836a) | / | | [kCUBLAS_LT](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/namespacenvinfer1.html#a9e1d81e5a8bfeb38b86e22a66d5f836a) | / | | [kCUDNN](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/namespacenvinfer1.html#a9e1d81e5a8bfeb38b86e22a66d5f836a) | / | | [reallocateOutput](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1v__1__0_1_1_i_output_allocator.html#acae6441d4029584cc1c6550917518691) | Superseded by [reallocateOutputAsync](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1v__1__0_1_1_i_output_allocator.html#aa40eeb891c1dfe4c1bbf1eabe8c705ab) with cudaStream_t argument | | [createExecutionContextWithoutDeviceMemory](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html#adc86bcc42b098204997396ef2b1093fb) | Superseded by [createExecutionContext()](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html#a35de29aa6134165a5b14a537e6d99e82) with parameter.
Check [ExecutionContextAllocationStrategy::kUSER_MANAGED](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/namespacenvinfer1.html#ac6251a050df629edfc0ce037fa366503) for more detail | ### Motivation and Context TRT deprecated api list: https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/deprecated.html --- .../tensorrt/tensorrt_execution_provider.cc | 58 +++++++++++++------ .../tensorrt/tensorrt_execution_provider.h | 5 +- 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 3ca0935b9e46c..be924d6a68268 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -169,11 +169,20 @@ nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) { nvinfer1::TacticSource source{}; t = toUpper(t); if (t == "CUBLAS") { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUBLAS is deprecated in TensorRT 10.0"; +#if NV_TENSORRT_MAJOR < 10 source = nvinfer1::TacticSource::kCUBLAS; +#endif } else if (t == "CUBLASLT" || t == "CUBLAS_LT") { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUBLAS_LT is deprecated in TensorRT 9.0"; +#if NV_TENSORRT_MAJOR < 9 source = nvinfer1::TacticSource::kCUBLAS_LT; +#endif } else if (t == "CUDNN") { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUDNN is deprecated in TensorRT 10.0"; +#if NV_TENSORRT_MAJOR < 10 source = nvinfer1::TacticSource::kCUDNN; +#endif } else if (t == "EDGE_MASK_CONVOLUTIONS") { source = nvinfer1::TacticSource::kEDGE_MASK_CONVOLUTIONS; } else if (t == "JIT_CONVOLUTIONS") { @@ -298,6 +307,25 @@ void CudaCall(cudnnStatus_t retCode, const char* exprString return g_host->CudaCall_true(retCode, exprString, libName, successCode, msg, file, line); } +#if NV_TENSORRT_MAJOR >= 10 +void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, + uint64_t /*alignment*/, cudaStream_t /*stream*/) noexcept { + // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr + // even for empty tensors, so allocate a dummy byte. + size = std::max(size, static_cast(1)); + if (size > allocated_size) { + cudaFree(outputPtr); + outputPtr = nullptr; + allocated_size = 0; + if (cudaMalloc(&outputPtr, size) == cudaSuccess) { + allocated_size = size; + } + } + // if cudaMalloc fails, returns nullptr. + return outputPtr; +} +#else +// Only override this method when TensorRT <= 8.6 void* OutputAllocator::reallocateOutput(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, uint64_t /*alignment*/) noexcept { // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr @@ -314,6 +342,7 @@ void* OutputAllocator::reallocateOutput(char const* /*tensorName*/, void* /*curr // if cudaMalloc fails, returns nullptr. return outputPtr; } +#endif void OutputAllocator::notifyShape(char const* /*tensorName*/, nvinfer1::Dims const& dims) noexcept { output_shapes.clear(); @@ -3152,14 +3181,10 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView if (mem_size > max_ctx_mem_size_) { max_ctx_mem_size_ = mem_size; } - -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) // nvinfer1::ICudaEngine::createExecutionContextWithoutDeviceMemory was deprecated -#endif +#if NV_TENSORRT_MAJOR < 10 trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); -#if defined(_MSC_VER) -#pragma warning(pop) +#else + trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); #endif } else { trt_context = std::unique_ptr(trt_engine->createExecutionContext()); @@ -3606,14 +3631,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView if (context_update) { if (trt_state->context_memory_sharing_enable) { -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) // nvinfer1::ICudaEngine::createExecutionContextWithoutDeviceMemory was deprecated -#endif +#if NV_TENSORRT_MAJOR < 10 *(trt_state->context) = std::unique_ptr( trt_state->engine->get()->createExecutionContextWithoutDeviceMemory()); -#if defined(_MSC_VER) -#pragma warning(pop) +#else + *(trt_state->context) = std::unique_ptr( + trt_state->engine->get()->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); #endif } else { *(trt_state->context) = std::unique_ptr( @@ -3830,13 +3853,10 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con if (mem_size > max_ctx_mem_size_) { max_ctx_mem_size_ = mem_size; } -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) // nvinfer1::ICudaEngine::createExecutionContextWithoutDeviceMemory was deprecated -#endif +#if NV_TENSORRT_MAJOR < 10 trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); -#if defined(_MSC_VER) -#pragma warning(pop) +#else + trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); #endif } else { trt_context = std::unique_ptr(trt_engine->createExecutionContext()); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index f4dae57487f51..ec140579569b9 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -116,8 +116,11 @@ using unique_pointer = std::unique_ptr; // class OutputAllocator : public nvinfer1::IOutputAllocator { public: +#if NV_TENSORRT_MAJOR >= 10 + void* reallocateOutputAsync(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment, cudaStream_t stream) noexcept override; +#else void* reallocateOutput(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment) noexcept override; - +#endif void notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept override; void* getBuffer() { From 7df97f1987dcdb798e0c22b3d3ae8f27dfa6a82e Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 2 Jul 2024 11:24:04 -0700 Subject: [PATCH 03/13] Add debugging helper to dump string, vector and thread id (#21224) ### Description Add some macro to help print data to console for debugging purpose. Example usage: ``` int input_id; vector some_vector; DUMP_CPU_TENSOR_INIT() DUMP_CPU_TENSOR("some vector", some_vector); DUMP_STRING("input_id=", input_id); ``` - To enable dump thread id, set environment variable `ORT_DUMP_THREAD_ID=0`. - User can disable dumping by environment variable `ORT_ENABLE_CPU_DUMP=0`. ### Motivation and Context --- .../contrib_ops/cpu/utils/console_dumper.h | 2 + .../contrib_ops/cpu/utils/debug_macros.h | 3 ++ .../contrib_ops/cpu/utils/dump_tensor.cc | 52 ++++++++++++++++++- .../contrib_ops/cpu/utils/dump_tensor.h | 11 +++- .../cuda/utils/dump_cuda_tensor.cc | 8 +++ .../contrib_ops/cuda/utils/dump_cuda_tensor.h | 2 + 6 files changed, 75 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h index 3c255879df199..2782a59d4326d 100644 --- a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h +++ b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h @@ -37,6 +37,8 @@ class IConsoleDumper { virtual void Print(const char* name, int index, bool end_line) const = 0; virtual void Print(const char* name, const std::string& value, bool end_line) const = 0; + virtual void Print(const std::string& value) const = 0; + protected: bool is_enabled_; }; diff --git a/onnxruntime/contrib_ops/cpu/utils/debug_macros.h b/onnxruntime/contrib_ops/cpu/utils/debug_macros.h index 37a9b0160ade9..d5cbaa0a3e6b7 100644 --- a/onnxruntime/contrib_ops/cpu/utils/debug_macros.h +++ b/onnxruntime/contrib_ops/cpu/utils/debug_macros.h @@ -1,4 +1,5 @@ #pragma once +#include "core/common/make_string.h" // #define DEBUG_GENERATION 1 // uncomment it for debugging generation (like beam search etc) @@ -14,9 +15,11 @@ #if DUMP_CPU_TENSOR_LEVEL > 0 #define DUMP_CPU_TENSOR_INIT() onnxruntime::contrib::CpuTensorConsoleDumper cpu_dumper #define DUMP_CPU_TENSOR(...) cpu_dumper.Print(__VA_ARGS__) +#define DUMP_STRING(...) cpu_dumper.Print(::onnxruntime::MakeString(__VA_ARGS__)) #else #define DUMP_CPU_TENSOR_INIT() #define DUMP_CPU_TENSOR(...) +#define DUMP_STRING(...) #endif #if DUMP_CPU_TENSOR_LEVEL > 1 diff --git a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc index 3a5deef35d6d6..87a9cd3965763 100644 --- a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc +++ b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc @@ -1,18 +1,38 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include #include "contrib_ops/cpu/utils/dump_tensor.h" +#include +#include +#include +#include #include "core/framework/print_tensor_utils.h" #include "contrib_ops/cpu/utils/debug_macros.h" +#include "core/platform/env_var_utils.h" namespace onnxruntime { namespace contrib { #if DUMP_CPU_TENSOR_LEVEL > 0 +// Environment variable to enable/disable dumping +constexpr const char* kEnableCpuTensorDumper = "ORT_ENABLE_CPU_DUMP"; + +// Environment variable to enable/disable dumping thread id +constexpr const char* kDumpThreadId = "ORT_DUMP_THREAD_ID"; + +// To avoid dumping at the same time from multiple threads +static std::mutex s_mutex; + +static bool s_output_thread_id = false; + template void DumpCpuTensor(const char* name, const T* tensor, int dim0, int dim1) { + std::unique_lock lock(s_mutex); + + if (s_output_thread_id) + std::cout << "Thread ID:" << std::this_thread::get_id() << std::endl; + if (nullptr != name) { std::cout << std::string(name) << std::endl; } @@ -26,6 +46,11 @@ void DumpCpuTensor(const char* name, const T* tensor, int dim0, int dim1) { template void DumpCpuTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2) { + std::unique_lock lock(s_mutex); + + if (s_output_thread_id) + std::cout << "Thread ID:" << std::this_thread::get_id() << std::endl; + if (nullptr != name) { std::cout << std::string(name) << std::endl; } @@ -93,6 +118,21 @@ void DumpCpuTensor(const char* name, const Tensor& tensor) { DumpCpuTensor(nullptr, tensor, static_cast(num_rows), static_cast(row_size)); } +CpuTensorConsoleDumper::CpuTensorConsoleDumper() { + is_enabled_ = ParseEnvironmentVariableWithDefault(kEnableCpuTensorDumper, 1) != 0; + s_output_thread_id = ParseEnvironmentVariableWithDefault(kDumpThreadId, 0) != 0; +} + +void CpuTensorConsoleDumper::Print(const std::string& value) const { + if (!is_enabled_) + return; + + std::unique_lock lock(s_mutex); + if (s_output_thread_id) + std::cout << "Thread ID:" << std::this_thread::get_id() << std::endl; + std::cout << value << std::endl; +} + void CpuTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const { if (!is_enabled_) return; @@ -185,6 +225,8 @@ void CpuTensorConsoleDumper::Print(const char* name, const OrtValue& value) cons void CpuTensorConsoleDumper::Print(const char* name, int index, bool end_line) const { if (!is_enabled_) return; + + std::unique_lock lock(s_mutex); std::cout << std::string(name) << "[" << index << "]"; if (end_line) { @@ -196,6 +238,7 @@ void CpuTensorConsoleDumper::Print(const char* name, const std::string& value, b if (!is_enabled_) return; + std::unique_lock lock(s_mutex); std::cout << std::string(name) << "=" << value; if (end_line) { @@ -204,6 +247,12 @@ void CpuTensorConsoleDumper::Print(const char* name, const std::string& value, b } #else +CpuTensorConsoleDumper::CpuTensorConsoleDumper() { +} + +void CpuTensorConsoleDumper::Print(const std::string&) const { +} + void CpuTensorConsoleDumper::Print(const char*, const float*, int, int) const { } @@ -254,7 +303,6 @@ void CpuTensorConsoleDumper::Print(const char*, int, bool) const { void CpuTensorConsoleDumper::Print(const char*, const std::string&, bool) const { } - #endif } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h index d902806fd0d18..f102eae6ec709 100644 --- a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h +++ b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h @@ -2,6 +2,7 @@ // Licensed under the MIT License. #pragma once +#include #include #include "core/framework/ort_value.h" #include "contrib_ops/cpu/utils/console_dumper.h" @@ -11,7 +12,7 @@ namespace contrib { class CpuTensorConsoleDumper : public IConsoleDumper { public: - CpuTensorConsoleDumper() = default; + CpuTensorConsoleDumper(); virtual ~CpuTensorConsoleDumper() {} void Print(const char* name, const float* tensor, int dim0, int dim1) const override; void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override; @@ -33,6 +34,14 @@ class CpuTensorConsoleDumper : public IConsoleDumper { void Print(const char* name, const OrtValue& value) const override; void Print(const char* name, int index, bool end_line) const override; void Print(const char* name, const std::string& value, bool end_line) const override; + + void Print(const std::string& value) const override; + + // Output a vector with a threshold for max number of elements to output. Default threshold 0 means no limit. + template + void Print(const char* name, const std::vector& vec, size_t max_count = 0) const { + this->Print(name, vec.data(), 1, static_cast(std::min(max_count, vec.size()))); + } }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc index fb7af3cfdd54f..e10c2ec63fd51 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc @@ -202,6 +202,10 @@ void DumpGpuTensor(const char* name, const Tensor& tensor) { DumpGpuTensor(nullptr, tensor, static_cast(num_rows), static_cast(row_size)); } +void CudaTensorConsoleDumper::Print(const std::string& value) const { + std::cout << value << std::endl; +} + void CudaTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const { if (is_enabled_) DumpGpuTensor(name, tensor, dim0, dim1, true); @@ -325,6 +329,10 @@ void CudaTensorConsoleDumper::Print(const char* name, const std::string& value, } #else + +void CudaTensorConsoleDumper::Print(const std::string&) const { +} + void CudaTensorConsoleDumper::Print(const char*, const size_t*, int, int) const { } diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h index 0f25e85bb97d7..6ad0ad9a67b75 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h @@ -46,6 +46,8 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::IConsoleDumper { void Print(const char* name, const OrtValue& value) const override; void Print(const char* name, int index, bool end_line) const override; void Print(const char* name, const std::string& value, bool end_line) const override; + + void Print(const std::string& value) const override; }; } // namespace cuda From 116398c1a43ed20b62ea506e676c38d0614e99ca Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Tue, 2 Jul 2024 15:37:50 -0700 Subject: [PATCH 04/13] onnxruntime shared lib inside python package (#21223) --- cmake/onnxruntime_python.cmake | 9 +++++++++ setup.py | 21 +++++++++++++++++---- tools/ci_build/build.py | 11 ++++++++++- 3 files changed, 36 insertions(+), 5 deletions(-) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 062cc8f9dbff3..07c65e7986b05 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -667,6 +667,15 @@ add_custom_command( $ ) +if (onnxruntime_BUILD_SHARED_LIB) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + $ + $/onnxruntime/capi/ + ) +endif() + if (onnxruntime_USE_OPENVINO) add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD diff --git a/setup.py b/setup.py index 3203993e0c4d4..5750833ce35de 100644 --- a/setup.py +++ b/setup.py @@ -297,11 +297,13 @@ def finalize_options(self): "libmklml_gnu.so", "libiomp5.so", "mimalloc.so", + "libonnxruntime.so*", ] dl_libs = ["libonnxruntime_providers_shared.so"] dl_libs.append(providers_cuda_or_rocm) dl_libs.append(providers_tensorrt_or_migraphx) dl_libs.append(providers_cann) + dl_libs.append("libonnxruntime.so*") # DNNL, TensorRT & OpenVINO EPs are built as shared libs libs.extend(["libonnxruntime_providers_shared.so"]) libs.extend(["libonnxruntime_providers_dnnl.so"]) @@ -313,7 +315,12 @@ def finalize_options(self): if nightly_build: libs.extend(["libonnxruntime_pywrapper.so"]) elif platform.system() == "Darwin": - libs = ["onnxruntime_pybind11_state.so", "libdnnl.2.dylib", "mimalloc.so"] # TODO add libmklml and libiomp5 later. + libs = [ + "onnxruntime_pybind11_state.so", + "libdnnl.2.dylib", + "mimalloc.so", + "libonnxruntime.dylib*", + ] # TODO add libmklml and libiomp5 later. # DNNL & TensorRT EPs are built as shared libs libs.extend(["libonnxruntime_providers_shared.dylib"]) libs.extend(["libonnxruntime_providers_dnnl.dylib"]) @@ -323,7 +330,13 @@ def finalize_options(self): if nightly_build: libs.extend(["libonnxruntime_pywrapper.dylib"]) else: - libs = ["onnxruntime_pybind11_state.pyd", "dnnl.dll", "mklml.dll", "libiomp5md.dll"] + libs = [ + "onnxruntime_pybind11_state.pyd", + "dnnl.dll", + "mklml.dll", + "libiomp5md.dll", + "onnxruntime.dll", + ] # DNNL, TensorRT & OpenVINO EPs are built as shared libs libs.extend(["onnxruntime_providers_shared.dll"]) libs.extend(["onnxruntime_providers_dnnl.dll"]) @@ -376,7 +389,7 @@ def finalize_options(self): dl_libs.append("plugins.xml") dl_libs.append("usb-ma2x8x.mvcmd") data = ["capi/libonnxruntime_pywrapper.so"] if nightly_build else [] - data += [path.join("capi", x) for x in dl_libs if path.isfile(path.join("onnxruntime", "capi", x))] + data += [path.join("capi", x) for x in dl_libs if glob(path.join("onnxruntime", "capi", x))] ext_modules = [ Extension( "onnxruntime.capi.onnxruntime_pybind11_state", @@ -384,7 +397,7 @@ def finalize_options(self): ), ] else: - data = [path.join("capi", x) for x in libs if path.isfile(path.join("onnxruntime", "capi", x))] + data = [path.join("capi", x) for x in libs if glob(path.join("onnxruntime", "capi", x))] ext_modules = [] # Additional examples diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index b73a17db3ce13..ae4c9b27544ba 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -2593,7 +2593,16 @@ def main(): if args.build_wheel or args.gen_doc or args.use_tvm or args.enable_training: args.enable_pybind = True - if args.build_csharp or args.build_nuget or args.build_java or args.build_nodejs: + if ( + args.build_csharp + or args.build_nuget + or args.build_java + or args.build_nodejs + or (args.enable_pybind and not args.enable_training) + ): + # If pyhon bindings are enabled, we embed the shared lib in the python package. + # If training is enabled, we don't embed the shared lib in the python package since training requires + # torch interop. args.build_shared_lib = True if args.build_nuget and cross_compiling: From 4932e040533a8d5f70e41d020aac441dfcb339ba Mon Sep 17 00:00:00 2001 From: pengwa Date: Wed, 3 Jul 2024 10:53:31 +0800 Subject: [PATCH 05/13] ORTModule GraphTransitionManager (#19007) ### Problem Currently, the codebase contains some logics pertaining to model re-export checks and graph_builder reinitialization checks. Ideally, these operations should function akin to a state machine. However, upon inspecting the implementation, it becomes apparent that certain states are checked or set in various scattered locations. This fragmentation makes it challenging to comprehend when a re-export or re-initialization will be triggered. For optimal clarity and maintainability, it is advisable to consolidate these states into a cohesive component, rather than dispersing them within the current graph execution manager. Furthermore, the process of model exports and post-export processing for stage 3 support or memory-efficient gradient management introduces considerable complexity. To enhance the codebase's structure, it would be beneficial to extract these intricate functionalities into a dedicated component, divorcing them from the current graph execution manager. As part of the effort to improve the codebase, it's essential to address inconsistencies in handling input/output flatten/unflatten operations. Currently, there are several functions performing these operations recursively, each with slightly different implementations. This inconsistency leads to varying support for input/output data types and structures in different parts of the code. To rectify this, the proposed pull request simplifies these operations into a set of primitive functions, ensuring uniformity. This not only streamlines the code but also facilitates the maintenance of consistency when introducing bug fixes or supporting new data types. One thing to mention here: input output handling is deeply bound to the graph transition mentioned above, so it is difficult to make this change separately. While acknowledging the complexity of these logics, it is reassuring that the codebase benefits from an extensive suite of unit tests that cover all possible branches. Despite the intricacies, ensuring the passage of all tests has been a time-intensive but necessary aspect of this development effort. ### Design Introduce `GraphTransitionManager` and put all model export and post-export processing logics in it. 1. Re-export check 2. Do export 3. Re-post-export process check 4. Do post-export process 5. Return `PostExportProcessedModelInfo`, which contains all the information we need, to pass to ORT to build gradient graph (currently we do the same for training or evaluating, but ideally we should not do it for evaluating, let's keep this behavior as it is now, and make the change later). ``` # Input names for the pre-gradient-build graph. # This may be different with the one in ExportedGraph since we may modify the graph inputs as needed # for example when memory efficient gradient management is enabled. self.onnx_graph_input_names: list[str] = onnx_graph_input_names # A subset of onnx_graph_input_names. # Input names that require gradients for the pre-gradient-build graph. self.onnx_graph_input_names_require_grad: list[str] = onnx_graph_input_names_require_grad # Create symbolic names for each dimension of the graph input (e.g. onnx_graph_input_names). # The key is the input name, the value is a dict of {dim_index: symbolic_dim_name} # e.g. {"input1": {0: "input1_dim0", 1: "input1_dim1"}, "input2": {0: "input2_dim0"}} self.onnx_graph_input_dynamic_axes_map: dict[str, dict[int, str]] = onnx_graph_input_dynamic_axes_map self.buffer_for_ort_runs: dict[str, torch.Tensor] = OrderedDict() self.onnx_graph_input_names_user_defined = ( onnx_graph_input_names_user_defined # The ONNX graph input names excluding the parameters, buffers. ) # The ONNX graph input names excluding the parameters, buffers. self.onnx_graph_input_names_require_grad_user_defined = onnx_graph_input_names_require_grad_user_defined self._post_export_processed_model: onnx.ModelProto | None = post_export_processed_model # A function to access the input data from the args and kwargs. # If it is not None, the length is same as onnx_graph_input_names. # For i-th input name, we can use the i-th function to get the input data from args and kwargs. self.data_accessor: list[callable] | None = data_accessor # Used for unflattening the outputs from the ORT forward run. self.module_forward_output_schema: ORTModelInputOutputSchemaType | None = module_forward_output_schema``` The `GraphTransitionManager` instance is a property of `GraphExecutionManager` (e.g. `TrainingManager` or ``InferenceManager), 1. Use 'self._graph_transition_manager.use_cache_or_reconstruct_post_processed_model(inputs, kwargs)' to check whether the PyTorch module need a re-export or re-post-export-process. 2. Use `self._graph_transition_manager._post_export_processed_model_info.construct_inputs` to construct the list of inputs used for ORT runs. 3. Use `self._graph_transition_manager._post_export_processed_model_info.restore_outputs(user_outputs)` to restore the outputs in original PyTorch output structure. ### Motivation and Context --- .../ortmodule/_graph_execution_manager.py | 570 +-------- .../ortmodule/_graph_transition_manager.py | 1058 +++++++++++++++++ .../training/ortmodule/_inference_manager.py | 36 +- .../python/training/ortmodule/_io.py | 622 +++++----- .../python/training/ortmodule/_logger.py | 38 +- .../ortmodule/_mem_efficient_grad_mgmt.py | 8 - .../python/training/ortmodule/_onnx_models.py | 11 +- .../training/ortmodule/_training_manager.py | 122 +- .../python/training/ortmodule/_utils.py | 19 +- .../ortmodule/_zero_stage3_compatibility.py | 10 +- .../python/training/ortmodule/ortmodule.py | 4 +- .../python/training/utils/torch_io_helper.py | 8 +- .../python/orttraining_test_ortmodule_api.py | 154 ++- .../orttraining_test_ortmodule_autograd.py | 10 +- .../orttraining_test_ortmodule_fallback.py | 39 +- .../orttraining_test_ortmodule_onnx_ops.py | 4 +- 16 files changed, 1642 insertions(+), 1071 deletions(-) create mode 100755 orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 8e383a5545e42..18999ce2fa1ab 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -4,13 +4,9 @@ # -------------------------------------------------------------------------- import copy -import inspect -import io import logging import os from abc import ABC, abstractmethod # noqa: F401 -from functools import partial -from hashlib import md5 as hash_fn from typing import Dict, List, Optional, Tuple import onnx @@ -19,24 +15,16 @@ import onnxruntime from onnxruntime.capi import _pybind_state as C -from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference -from onnxruntime.training.utils import ORTModelInputOutputSchemaType, PTable, onnx_dtype_to_pytorch_dtype - -from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils, export_context -from ._fallback import ( - ORTModuleDeviceException, - ORTModuleONNXModelException, - ORTModuleTorchModelException, - _FallbackManager, - _FallbackPolicy, - wrap_exception, -) +from onnxruntime.training.utils import PTable, onnx_dtype_to_pytorch_dtype + +from . import _are_deterministic_algorithms_enabled, _logger, _onnx_models, _utils +from ._fallback import ORTModuleTorchModelException, _FallbackManager, _FallbackPolicy, wrap_exception from ._gradient_accumulation_manager import GradientAccumulationManager from ._graph_execution_interface import GraphExecutionInterface -from ._io import _FlattenedModule, _InputInfo -from ._logger import LogColor -from ._runtime_inspector import FlagAndPrintDensity, RuntimeInspector -from ._utils import check_function_has_param, get_rank +from ._graph_transition_manager import GraphTransitionManager, PostExportProcessedModelInfo +from ._io import _FlattenedModule +from ._runtime_inspector import RuntimeInspector +from ._utils import get_rank from .options import DebugOptions, LogLevel, _MemoryOptimizationLevel, _RuntimeOptions from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension @@ -85,15 +73,12 @@ def __init__( # Original and flattened (transformed) output module self._flattened_module = module - # onnx models self._onnx_models = _onnx_models.ONNXModels() + self._graph_transition_manager: Optional[GraphTransitionManager] = None - # Model after inference optimization or gradient building. + # Model after inference optimization and then gradient building. self._graph_builder = None self._graph_info = None - self._graph_initializer_names = set() - self._graph_initializer_names_to_train = set() - self._graph_initializers: List[torch.nn.parameter.Parameter] = [] # TrainingAgent or InferenceAgent self._execution_agent = None @@ -107,36 +92,11 @@ def __init__( # To be instantiated in the concrete implementation of GraphExecutionManager self._export_mode = export_mode - # Exporter can take extra arguments for ORTModule extensions - # It cannot overlap with required/immutable arguments (validated in runtime) - self._export_extra_kwargs = {} - - # Input and output infos (including schema) for exported model. - self._input_info: Optional[_InputInfo] = None - self._module_output_schema: Optional[ORTModelInputOutputSchemaType] = None - - # Device where the model is placed. - self._device: Optional[torch.device] = _utils.get_device_from_module(module) - - # Forward function input parameters of the original module. - self._module_parameters: List[inspect.Parameter] = list( - inspect.signature(self._original_module.forward).parameters.values() - ) - - # TODO: remove after PyTorch ONNX exporter supports VAR_KEYWORD parameters. - for input_parameter in self._module_parameters: - if input_parameter.kind == inspect.Parameter.VAR_KEYWORD: - self._logger.info("The model's forward method has **kwargs parameter which has EXPERIMENTAL support!") - self.is_rocm_pytorch = bool(torch.version.hip is not None and ROCM_HOME is not None) # WIP feature to enable caching in Gradient accumulation scenario. self._gradient_accumulation_manager = GradientAccumulationManager() - # Flag to re-export the model due to attribute change on the original module. - # Re-export will be avoided if _skip_check is enabled. - self._original_model_has_changed = False - # Inspector for runtime information, for example input data, memory usage, etc. self._runtime_inspector = RuntimeInspector( self._logger, self._original_module, self._export_mode == torch.onnx.TrainingMode.TRAINING @@ -163,9 +123,7 @@ def __init__( configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="ort_output", stats_overwrite=True) - # Will be reset everytime we re-initialize the graph builder. - # Be noted, we will never enable this feature for inference mode. - self._mem_efficient_grad_management_is_enabled = False + self._initialize_graph_transition_manager() def _get_torch_gpu_allocator_function_addresses(self): if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available(): @@ -176,6 +134,18 @@ def _get_torch_gpu_allocator_function_addresses(self): self._torch_free = torch_gpu_allocator.gpu_caching_allocator_raw_delete_address() self._torch_empty_cache = torch_gpu_allocator.gpu_caching_allocator_empty_cache_address() + def _initialize_graph_transition_manager(self): + """Creates a new GraphTransitionManager, initializes it and saves it to self._graph_transition_manager""" + self._graph_transition_manager = GraphTransitionManager( + flatten_module=self._flattened_module, + export_mode=self._export_mode, + debug_options=self._debug_options, + runtime_options=self._runtime_options, + time_tracker=self.time_tracker, + runtime_inspector=self._runtime_inspector, + logger=self._logger, + ) + def _validate_module_type(self, module): """Raises ORTModuleTorchModelException if the module is not a torch.nn.Module""" @@ -205,7 +175,9 @@ def forward(self): def _build_graph(self, config): if self._runtime_options.use_static_shape: - self._graph_builder.build(config, self._input_info.shape) + self._graph_builder.build( + config, self._graph_transition_manager._model_info_for_export.onnx_graph_input_shapes + ) else: self._graph_builder.build(config) @@ -259,7 +231,8 @@ def _get_session_config(self): # Enable memory efficient execution order for training if 1). memory efficient grad management is enabled # or 2). memory optimizer is enabled. use_memory_efficient_topo_sort = (self._export_mode == torch.onnx.TrainingMode.TRAINING) and ( - self._mem_efficient_grad_management_is_enabled or self._runtime_options.memory_optimizer_is_enabled() + self._graph_transition_manager._post_export_processed_model_info.is_mem_efficient_grad_management_enabled + or self._runtime_options.memory_optimizer_is_enabled() ) session_options.execution_order = ( onnxruntime.ExecutionOrder.MEMORY_EFFICIENT @@ -283,266 +256,6 @@ def _get_session_config(self): return session_options, providers, provider_options - @_logger.TrackTime(_logger.ORTModuleInitPhase.EXPORT) - @_logger.SuppressLogs(_logger.ORTModuleInitPhase.EXPORT, is_ort_filter=False) - def _export_model(self, *inputs, **kwargs) -> bool: - # 1. Set the self._device from the user module - # 2. Verify input schema matches the schema used on the previous model export - # 3. Export the user model under self._export_training_flag mode - # Return True if the model needs to be exported, False if no export is required. - - # Note: Model is only exported when: - # 1. Model has never been exported before. - # 2. Model input schema has changed (changes in inputs requiring gradient, shape, boolean inputs values change, etc) - # Model is not re-exported when the model parameters change. This can happen when the model is stateful, - # or the user explicitly changed model parameters after the onnx export. - - # Record random states here and restore later in case any of them gets changed during the export, - # e.g., some sympy functions in symbolic_shape_infer will change Python's random state. - random_states = _utils.get_random_states() - - schema = _io._extract_schema({"args": copy.copy(inputs), "kwargs": copy.copy(kwargs)}, self._device) - if ( - self._onnx_models.exported_model - and schema == self._input_info.schema - and not self._original_model_has_changed - ): - # All required models have already been exported previously - return False - - self._set_device_from_module(inputs, kwargs) - embedding_hook_handles = self._add_check_embedding_sparsity_hook() - label_hook_handles = self._add_check_label_sparsity_hook() - - from onnxruntime.training.utils.hooks._subscriber_manager import no_increase_global_step - - with export_context(), no_increase_global_step(): - self._onnx_models.exported_model = self._get_exported_model(schema, *inputs, **kwargs) - - for hook in embedding_hook_handles: - hook.remove() - for hook in label_hook_handles: - hook.remove() - - if self._debug_options.save_onnx_models.save: - self._onnx_models.save_exported_model( - self._debug_options.save_onnx_models.path, - self._debug_options.save_onnx_models.name_prefix, - self._export_mode, - ) - - if self._runtime_options.run_symbolic_shape_infer: - self._onnx_models.exported_model = SymbolicShapeInference.infer_shapes( - self._onnx_models.exported_model, auto_merge=True, guess_output_rank=True - ) - - # Restore the recorded random states - _utils.set_random_states(random_states) - - return True - - def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inputs, **kwargs) -> onnx.ModelProto: - """Exports PyTorch `self._flattened_module` to ONNX for inferencing or training, - using `*inputs` and `**kwargs` as input - - TODO: How to support dynamic axes? Dimensions are determined by samples - """ - - # VERBOSE -> FULL export verbose log + FULL torch other logs from stdout and stderr (C++ backend) - # DEVINFO -> FULL export verbose log + FULL torch other logs from stdout and stderr (C++ backend) - # INFO -> [Rank 0] FULL export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend) - # WARNING/ERROR -> [Rank 0] NO export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend) - # Be noted: rank 0 log only is controlled by logger configured in _logger.py - torch_exporter_verbose_log = self._debug_options.logging.log_level <= LogLevel.INFO - - # Setup dynamic axes for onnx model - self._input_info = _io.parse_inputs_for_onnx_export(self._module_parameters, None, input_schema, inputs, kwargs) - need_deep_copy = self._runtime_options.deepcopy_before_model_export and _io.can_module_be_deep_cloned( - self._original_module, self._device - ) - if not need_deep_copy: - if self._runtime_options.deepcopy_before_model_export: - self._logger.warning( - "Since the user requested not to deep copy this model, " - "the initial weights may not be preserved and could change slightly during the forward run. " - "This could cause a minor difference between the ORTModule and the PyTorch run for the " - "first iteration. The computation will proceed as normal, but this should be noted." - ) - else: - self._logger.warning( - "Due to the limited GPU memory execution manager does not create a deep copy of this model. " - "Therefore, the initial weights might be slightly altered during the forward run. " - "This could result in a minor discrepancy between the ORTModule and the PyTorch run for the " - "first iteration. The computation will continue as usual, but this should be noted." - ) - ( - output_names, - output_dynamic_axes, - self._module_output_schema, - ) = _io.parse_outputs_for_onnx_export_and_extract_schema( - self._original_module, inputs, kwargs, self._logger, self._device, need_deep_copy - ) - self._input_info.dynamic_axes.update(output_dynamic_axes) - - # FlattenedModule needs _InputInfo to expand user input from *args to *args + **kwargs - self._flattened_module._input_info = self._input_info - - self._logger.info("Exporting the PyTorch model to ONNX...") - - # Leverage cached model if available - cache_dir = self._runtime_options.ortmodule_cache_dir - if cache_dir: - filename = os.path.join( - cache_dir, f"{hash_fn(str(self._flattened_module).encode()).hexdigest()}_{get_rank()}.onnx" - ) - if os.path.exists(cache_dir) and os.path.isfile(filename): - self._logger.warning( - f"Cached model detected! Cached model will be used to save export and initialization time." - f"If you want the model to be re-exported then DELETE {filename}." - ) - exported_model = onnx.load(filename) - return exported_model - - # Export torch.nn.Module to ONNX - f = io.BytesIO() - - # Deepcopy inputs, since input values may change after model run. - # NOTE: Inputs may contain tensors that have attributes preventing their deepcopy (example grad_fn). - # Therefore, deepcopy only the data component of the input tensors for export. - sample_inputs_copy, sample_kwargs_copy = _io.deepcopy_model_input(*inputs, **kwargs) - # NOTE: Flattening the input will change the 'input schema', resulting in a re-export - sample_inputs_as_tuple = tuple(self._input_info.flatten(sample_inputs_copy, sample_kwargs_copy, self._device)) - # Ops behaving differently under train/eval mode need to be exported with the - # correct training flag to reflect the expected behavior. - # For example, the Dropout node in a model is dropped under eval mode. - assert self._export_mode is not None, "Please use a concrete instance of ExecutionManager" - - try: - from ._zero_stage3_compatibility import stage3_export_context - - with torch.no_grad(), stage3_export_context(self._runtime_options.enable_zero_stage3_support, self): - required_export_kwargs = { - "input_names": self._input_info.names, - "output_names": output_names, - "opset_version": self._runtime_options.onnx_opset_version, - "do_constant_folding": False, - "training": self._export_mode, - "dynamic_axes": self._input_info.dynamic_axes, - "verbose": torch_exporter_verbose_log, - "export_params": False, - "keep_initializers_as_inputs": True, - } - - if check_function_has_param(torch.onnx.export, "autograd_inlining"): - # From some PyTorch version, autograd_inlining is a valid argument. - # We allow it to be True if custom autograd function is disabled (where autograd.Function - # anyway is not supported in ONNX until it can be inlined). - required_export_kwargs["autograd_inlining"] = ( - not self._runtime_options.enable_custom_autograd_function - ) - - invalid_args = self._export_extra_kwargs.keys() & required_export_kwargs.keys() - - if len(invalid_args) != 0: - error_msg = f"The following PyTorch exporter arguments cannot be specified: '{invalid_args}'." - raise RuntimeError(error_msg) - - torch.onnx.export( - self._flattened_module, - sample_inputs_as_tuple, - f, - **required_export_kwargs, - **self._export_extra_kwargs, - ) - except Exception as e: - message = _utils.get_exception_as_string(e) - - # Special handling when Huggingface transformers gradient checkpoint usage pattern found. - # For new versions of PyTorch 2, tracing torch.utils.checkpoint.checkpoint will be failed like this: - # File "microsoft/phi-2/b10c3eba545ad279e7208ee3a5d644566f001670/modeling_phi.py", line 919, in forward - # layer_outputs = self._gradient_checkpointing_func( - # File "/site-packages/torch/_compile.py", line 24, in inner - # return torch._dynamo.disable(fn, recursive)(*args, **kwargs) - # File "/site-packages/torch/_dynamo/eval_frame.py", line 470, in _fn - # raise RuntimeError( - # RuntimeError: Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment. - if ( - "_gradient_checkpointing_func" in message - and "Detected that you are using FX to torch.jit.trace a dynamo-optimized function" in message - ): - is_ckpt_activation_allowed = int(os.getenv("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", "0")) == 1 - notes = ( - " Your model is running with gradient checkpointing, yet the PyTorch exporter\n" - " failed during tracing the graph. Try to enable ORTModule's\n" - " gradient checkpointing (a.k.a. Transformer layerwise subgraph recompute)\n" - " using `export ORTMODULE_MEMORY_OPT_LEVEL=1` for similar or even better memory efficiency.\n" - ) - if is_ckpt_activation_allowed: - # If the user allows the gradient checkpointing export, we should inform the user to disable it, - # to make layerwise recompute work. - notes += ( - " We also notice your setting `export ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=1`,\n" - " which enables gradient checkpointing torch.autograd.Functions(s) to export.\n" - " To enable ORTModule's layerwise recompute, it needs to be turned OFF by\n" - " `export ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=0`.\n" - ) - - self._logger.error( - f"{LogColor.RED}\n" - "******************************** IMPORTANT NOTE *******************************\n" - f"{notes}" - "*******************************************************************************\n" - f"{LogColor.ENDC}\n" - ) - - raise wrap_exception( # noqa: B904 - ORTModuleONNXModelException, - RuntimeError(f"There was an error while exporting the PyTorch model to ONNX: \n\n{message}"), - ) - exported_model = onnx.load_model_from_string(f.getvalue()) - - if self._runtime_options.enable_custom_autograd_function: - from ._custom_autograd_function_exporter import post_process_enabling_autograd_function - - exported_model = post_process_enabling_autograd_function(exported_model) - - if self._runtime_options.enable_zero_stage3_support: - from ._zero_stage3_compatibility import post_processing_enable_zero_stage3_compat - - exported_model = post_processing_enable_zero_stage3_compat( - exported_model, - self._zero_stage3_param_map, - [name for name, _ in self._flattened_module.named_parameters()], - ) - - # Cannot append pull weight trigger name to input names as following, otherwise, the later check ( - # https://github.com/microsoft/onnxruntime/blob/068300d97eb25e5b52324e7af54a45ed1fa6a4c3/orttraining/orttraining/python/training/ortmodule/_training_manager.py#L466C18-L466C18) - # find input info mismatch, will re-initialize the graph builder. - # self._input_info.require_grad_names.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME) - - # Cache model for future runs - if cache_dir: - if not os.path.exists(cache_dir): - os.makedirs(cache_dir, exist_ok=True) - filename = os.path.join( - cache_dir, f"{hash_fn(str(self._flattened_module).encode()).hexdigest()}_{get_rank()}.onnx" - ) - self._logger.info(f"Caching model for future runs to {filename}.") - onnx.save(exported_model, filename) - - return exported_model - - def _set_device_from_module(self, inputs, kwargs): - """Get the device from the module and save it to self._device""" - - device = _utils.get_device_from_module(self._original_module) or _utils.get_device_from_inputs(inputs, kwargs) - if not self._device or self._device != device: - self._device = device - if not self._device: - raise wrap_exception( - ORTModuleDeviceException, RuntimeError("A device must be specified in the model or inputs!") - ) - def _get_graph_transformer_config(self) -> C.TrainingGraphTransformerConfiguration: graph_transformer_config = C.TrainingGraphTransformerConfiguration() graph_transformer_config.propagate_cast_ops_config = C.PropagateCastOpsConfiguration() @@ -563,68 +276,25 @@ def _get_graph_transformer_config(self) -> C.TrainingGraphTransformerConfigurati return graph_transformer_config @_logger.TrackTime(_logger.ORTModuleInitPhase.GRAPH_BUILDER_INIT) - def _initialize_graph_builder(self): + def _initialize_graph_builder(self, post_export_processed_model_info: PostExportProcessedModelInfo): """Creates a new OrtModuleGraphBuilder, initializes it and saves it to self._graph_builder""" - self._mem_efficient_grad_management_is_enabled = ( - self._export_mode != torch.onnx.TrainingMode.EVAL - and self._runtime_options.enable_mem_efficient_grad_management - ) - - # We post process the exported model because the trainable parame might be changed, so this path is - # re-triggered by reinitialize_graph_builder. - exported_model = copy.deepcopy(self._onnx_models.exported_model) - self._onnx_models.processed_exported_model = exported_model - - if self._mem_efficient_grad_management_is_enabled: - from ._mem_efficient_grad_mgmt import post_processing_enable_mem_efficient_training - - # Override the options if model is not modified. - (self._mem_efficient_grad_management_is_enabled, exported_model, self._param_trigger_grad) = ( - post_processing_enable_mem_efficient_training( - exported_model, self._flattened_module.named_parameters(), self._device - ) - ) - - if self._runtime_options.run_symbolic_shape_infer: - exported_model = SymbolicShapeInference.infer_shapes( - exported_model, auto_merge=True, guess_output_rank=True - ) - - # All initializer names along with user inputs are a part of the onnx graph inputs - # since the onnx model was exported with the flag keep_initializers_as_inputs=True - # We need to use the raw exported model here since the graph inputs include both user inputrs and - # parameters. - onnx_initializer_names = {p.name for p in exported_model.graph.input} - - # TODO: PyTorch exporter bug: changes the initializer order in ONNX model - initializer_names = [ - name for name, _ in self._flattened_module.named_parameters() if name in onnx_initializer_names - ] - initializer_names_to_train = [ - name - for name, param in self._flattened_module.named_parameters() - if param.requires_grad and name in onnx_initializer_names - ] - # Build and optimize the full graph grad_builder_config = C.OrtModuleGraphBuilderConfiguration() - grad_builder_config.initializer_names = initializer_names - grad_builder_config.initializer_names_to_train = initializer_names_to_train - - input_names_require_grad = self._input_info.require_grad_names + grad_builder_config.initializer_names = ( + post_export_processed_model_info.onnx_graph_input_names + ) # containing both user defined and buffers/parameters. + grad_builder_config.initializer_names_to_train = ( + post_export_processed_model_info.onnx_graph_input_names_require_grad + ) # containing both user defined and parameters requiring gradients. + + input_names_require_grad = post_export_processed_model_info.onnx_graph_input_names_require_grad_user_defined if self._runtime_options.enable_zero_stage3_support: from ._zero_stage3_compatibility import STAGE3_PULL_WEIGHT_TRIGGER_NAME # Add stage3 pull weight trigger name to require_grad_names, so that it will be included in the gradient graph. input_names_require_grad.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME) - if self._mem_efficient_grad_management_is_enabled: - from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME - - # Add mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph. - input_names_require_grad.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) - grad_builder_config.input_names_require_grad = input_names_require_grad grad_builder_config.build_gradient_graph = self._export_mode == torch.onnx.TrainingMode.TRAINING grad_builder_config.enable_caching = self._runtime_options.enable_grad_acc_optimization @@ -636,33 +306,9 @@ def _initialize_graph_builder(self): # It is assumed here that the order and names of the inputs and outputs are not modified by the backend in any way # and are kept as they appear in the exported onnx model. - self._graph_builder.initialize(exported_model.SerializeToString(), grad_builder_config) - - raw_onnx_initializer_names = {p.name for p in self._onnx_models.exported_model.graph.input} - - raw_initializer_names = [ - name for name, _ in self._flattened_module.named_parameters() if name in raw_onnx_initializer_names - ] - raw_initializer_names_to_train = [ - name - for name, param in self._flattened_module.named_parameters() - if param.requires_grad and name in raw_onnx_initializer_names - ] - - # TODO: Explore ways to make self._graph_info.initializer_names and self._graph_info.initializer_names_to_train - # a set (unordered_set in the backend) that does not require a copy on each reference. - self._graph_initializer_names = set(raw_initializer_names) - self._graph_initializer_names_to_train = set(raw_initializer_names_to_train) - - # Initializers can be cached and used since they are expected not to be re-instantiated - # between forward calls. - self._graph_initializers = [ - param for name, param in self._flattened_module.named_parameters() if name in self._graph_initializer_names - ] - - def signal_model_changed(self): - """Signals the execution manager to re-export the model on the next forward call""" - self._original_model_has_changed = True + self._graph_builder.initialize( + post_export_processed_model_info._post_export_processed_model.SerializeToString(), grad_builder_config + ) def __getstate__(self): state = copy.copy(self.__dict__) @@ -671,6 +317,7 @@ def __getstate__(self): "_onnx_models", "_graph_builder", "_graph_info", + "_graph_transition_manager", # Not pickled as it is re-constructed in __setstate__ "_execution_agent", "_torch_alloc", "_torch_free", @@ -686,82 +333,12 @@ def __setstate__(self, state): _utils.reinitialize_graph_execution_manager(self) - def _add_check_embedding_sparsity_hook(self): - """ - Add hook to check embedding sparsity and enable padding elimination if applicable. - 1. Iterate through all modules to find Embedding modules with padding_idx >= 0. - 2. Register forward pre hook to the Embedding module and the hook will check sparsity of the embedding input. - 3. If the sparsity is below a threshold, enable padding elimination by adding FlagAndPrintDensity after the - output. GraphTransformer of PaddingElimination will check the FlagAndPrintDensity and do the actual - padding elimination graph modification. - 4. Return the hook handles for later removal. + self._initialize_graph_transition_manager() - """ - if not self._runtime_options.enable_embedding_sparse_optimizer or self._device.type != "cuda": - return [] - - def _embedding_hook(name, module, args): - ebd_input = args[0] - if ebd_input is None or not isinstance(ebd_input, torch.Tensor): - self._logger.warning("Embedding input is not a tensor.") - return None - - valid_token = torch.count_nonzero(ebd_input - module.padding_idx) - total_token = ebd_input.numel() - embed_density = float(valid_token) / float(total_token) * 100 - - if embed_density < 90: - self._logger.info("Embedding sparsity-based optimization is ON for density: %.0f%%", embed_density) - self._runtime_inspector._embedding_module_to_padding_density_map[name] = embed_density - return FlagAndPrintDensity.apply(args[0], module.padding_idx, "embedding") - else: - self._logger.info("Embedding sparsity-based optimization is OFF for density: %.0f%%", embed_density) - return None - - embedding_hook_handles = [] - for name, sub_module in self._flattened_module.named_modules(): - if isinstance(sub_module, torch.nn.modules.sparse.Embedding): - if sub_module.padding_idx is not None and sub_module.padding_idx >= 0: - embedding_hook_handles.append(sub_module.register_forward_pre_hook(partial(_embedding_hook, name))) - - return embedding_hook_handles - - def _add_check_label_sparsity_hook(self): - """ - Add hook to check label sparsity and enable sceloss compute optimization if applicable. - 1. Register forward pre hook to the sceloss module in the model and the hook will check sparsity of the label input. - 2. If the sparsity is below a threshold, enable sceloss compute optimization by adding FlagAndPrintDensity after the - output. GraphTransformer of InsertGatherBeforeSceLoss will check the FlagAndPrintDensity and do the actual - sceloss compute optimization graph modification. - - """ - if not self._runtime_options.enable_label_sparse_optimizer: - return None - - def _label_hook(name, module, args): - label_input = args[1] - if label_input is None or not isinstance(label_input, torch.Tensor): - self._logger.warning("Label input is not a tensor.") - return None - - valid_token = torch.count_nonzero(label_input - module.ignore_index) - total_token = label_input.numel() - label_density = float(valid_token) / float(total_token) * 100 - - if label_density < 90: - self._logger.info("Label sparsity-based optimization is ON for density: %.0f%%", label_density) - self._runtime_inspector._sceloss_module_to_ignore_density_map[name] = label_density - return (args[0], FlagAndPrintDensity.apply(args[1], module.ignore_index, "label")) - else: - self._logger.info("Label sparsity-based optimization is OFF for density: %.0f%%", label_density) - return None - - label_check_hook_handles = [] - for name, sub_module in self._flattened_module.named_modules(): - if isinstance(sub_module, torch.nn.modules.loss.CrossEntropyLoss): - label_check_hook_handles.append(sub_module.register_forward_pre_hook(partial(_label_hook, name))) - - return label_check_hook_handles + @property + def _device(self): + # Graph transition manager is responsible for detecting and managing the device to use. + return self._graph_transition_manager._device @_logger.TrackTime(_logger.ORTModuleInitPhase.DETECTION) def _detect_from_inputs(self, inputs: Tuple, kwargs: Dict): @@ -775,34 +352,24 @@ def _detect_from_inputs(self, inputs: Tuple, kwargs: Dict): enable sparsity-based optimization. """ - detected_device = _utils.get_device_from_module(self._original_module) or _utils.get_device_from_inputs( - inputs, kwargs - ) - - if self._runtime_options.enable_zero_stage3_support or self._mem_efficient_grad_management_is_enabled: - self._append_pull_weight_trigger_as_input(kwargs, detected_device) - - param_to_append_as_onnx_graph_inputs = [] - if self._mem_efficient_grad_management_is_enabled: - from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger + if ( + self._runtime_options.enable_zero_stage3_support + or self._graph_transition_manager._post_export_processed_model_info.is_mem_efficient_grad_management_enabled + ): + self._append_pull_weight_trigger_as_input(kwargs, self._device) - param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger( - self._flattened_module.named_parameters(), self._onnx_models.exported_model + if ( + self._runtime_inspector.memory_ob.is_enabled() + and not self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed + ): + prepared_input_map = self._graph_transition_manager._post_export_processed_model_info.construct_inputs( + inputs, kwargs, True, self._device ) - else: - param_to_append_as_onnx_graph_inputs = self._graph_initializers - - _io._combine_input_buffers_initializers( - param_to_append_as_onnx_graph_inputs, - self._graph_builder.get_graph_info().user_input_names, - self._input_info, - self._flattened_module.named_buffers(), - inputs, - kwargs, - detected_device, - self._runtime_inspector, - self._zero_stage3_param_map, - ) + self._runtime_inspector.memory_ob.collect_symbolic_dim_values( + self._graph_transition_manager._post_export_processed_model_info.onnx_graph_input_dynamic_axes_map, + prepared_input_map, + ) + self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed = True if self._runtime_inspector._sceloss_module_to_ignore_density_map: self._runtime_options.label_sparsity_ratio = ",".join( @@ -828,19 +395,6 @@ def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.devic device=device, ).requires_grad_() - if self._mem_efficient_grad_management_is_enabled: - from ._mem_efficient_grad_mgmt import ( - MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME, - MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE, - MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE, - ) - - kwargs[MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME] = torch.zeros( - MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE, - dtype=onnx_dtype_to_pytorch_dtype(MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE), - device=device, - ).requires_grad_() - def _log_feature_stats(self): if get_rank() != 0: return diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py new file mode 100755 index 0000000000000..80bb00e0c3ac1 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -0,0 +1,1058 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from __future__ import annotations + +import copy +import inspect +import io +import logging +import os +from collections import OrderedDict +from functools import partial +from hashlib import md5 as hash_fn +from typing import Mapping, Sequence + +import onnx +import torch + +from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference +from onnxruntime.training.utils import ( + ORTModelInputOutputSchemaType, + ORTModelInputOutputType, + PrimitiveType, + onnx_dtype_to_pytorch_dtype, + unflatten_data_using_schema, +) + +from . import _io, _utils, export_context +from ._fallback import ORTModuleDeviceException, ORTModuleIOError, ORTModuleONNXModelException, wrap_exception +from ._logger import LogColor, LogLevel, ORTModuleInitPhase, SuppressLogs, TimeTracker, TrackTimeForStaticFunction +from ._onnx_models import _get_onnx_file_name, _save_model +from ._runtime_inspector import FlagAndPrintDensity, RuntimeInspector +from ._utils import check_function_has_param, get_rank +from ._zero_stage3_compatibility import stage3_export_context +from .options import DebugOptions, _RuntimeOptions + + +class ExportedModelInfo: + """Encapsulates the information of the exported model. + + After ONNX model export, the model info is collected and encapsulated in this class, including: + 1. The ONNX graph input names. + 2. Graph input requiring gradient information. + 3. The model's forward function signature and args/kwargs schema, used as a cache key to compare with the current + inputs to see if the model needs to be re-exported. + + This data structure is returned by the GraphTransitionManager._export_model method. + + """ + + def __init__( + self, + module_forward_args_schema: ORTModelInputOutputSchemaType, + module_forward_kwargs_schema: ORTModelInputOutputSchemaType, + onnx_graph_input_names: list[str], + onnx_graph_input_names_require_grad: list[str], + onnx_graph_input_names_user_defined: list[str], + onnx_graph_input_names_require_grad_user_defined: list[str], + exported_model: onnx.ModelProto, + module_forward_output_schema: ORTModelInputOutputSchemaType, + ): + # Used as a baseline to compare with the current inputs (args/kwargs) to see if the model needs to be re-exported. + self.module_forward_args_schema: ORTModelInputOutputSchemaType | None = module_forward_args_schema + self.module_forward_kwargs_schema: ORTModelInputOutputSchemaType | None = module_forward_kwargs_schema + + # Input names parsed and then flatten from the model's forward function signature + buffers + parameters (since we use + # keep_initializers_as_inputs=True for model export) + # Be noted: all inputs are used by the model for its compute. + self.onnx_graph_input_names: list[str] = copy.deepcopy(onnx_graph_input_names) + + # A subset of onnx_graph_input_names. + # Input names that require gradient parsed and then flatten from the model's forward function signature + # This should contain both the user-defined input names, the buffer names, and the parameter names (since we use + # keep_initializers_as_inputs=True for model export) + # Be noted: all inputs are used by the model for its compute. + self.onnx_graph_input_names_require_grad: list[str] = copy.deepcopy(onnx_graph_input_names_require_grad) + + # Input names parsed from the model's forward function signature. + # Be noted: all inputs are used by the model for its compute. + # The ONNX graph input names exclude the parameters, and buffers. + self.onnx_graph_input_names_user_defined = copy.deepcopy(onnx_graph_input_names_user_defined) + + # A subset of onnx_graph_input_names_user_defined. + self.onnx_graph_input_names_require_grad_user_defined = copy.deepcopy( + onnx_graph_input_names_require_grad_user_defined + ) + + # Exported model proto. + self.exported_model: onnx.ModelProto | None = exported_model + + # Used for unflattening the outputs from the ORT forward run. + self.module_forward_output_schema: ORTModelInputOutputSchemaType | None = module_forward_output_schema + + def __str__(self): + return f"""ExportedModelInfo class: + \tonnx_graph_input_names: {self.onnx_graph_input_names} + \tonnx_graph_input_names_require_grad: {self.onnx_graph_input_names_require_grad} + \tmodule_forward_args_schema: {self.module_forward_args_schema} + \tmodule_forward_kwargs_schema: {self.module_forward_kwargs_schema} + \tmodule_forward_output_schema: {self.module_forward_output_schema} + """ + + def __repr__(self): + return self.__str__() + + +class PostExportProcessedModelInfo: + """Encapsulates the information of the post-export processed model. + + After ONNX model post-export processing, the model info is collected and encapsulated in this class, including: + 1. The ONNX graph input names, dynamic axes, and input data accessor functions. + 2. Graph input requiring gradient information. + 3. The interface to construct the inputs for the ORT forward run, from original given inputs running for PyTorch. + 4. The interface to restore the outputs from the ORT forward run, back to the original data structure. + + """ + + def __init__( + self, + flatten_module: torch.nn.Module, + onnx_graph_input_names_user_defined: list[str], + onnx_graph_input_names_require_grad_user_defined: list[str], + onnx_graph_input_names: list[str], + onnx_graph_input_names_require_grad: list[str], + onnx_graph_input_dynamic_axes_map: dict[str, dict[int, str]], + module_forward_output_schema: ORTModelInputOutputSchemaType, + post_export_processed_model: onnx.ModelProto, + onnx_graph_input_data_accessor_user_defined: dict[str, callable], + onnx_graph_input_const_as_tensor: dict[str, torch.device], + enable_mem_efficient_grad_management: bool, + ): + self._flattened_module = flatten_module + + # Input names parsed from the model's forward function signature. + # Be noted: all inputs are used by the model for its compute. + # The ONNX graph input names exclude the parameters, and buffers. + self.onnx_graph_input_names_user_defined = copy.deepcopy(onnx_graph_input_names_user_defined) + + # A subset of onnx_graph_input_names_user_defined. + self.onnx_graph_input_names_require_grad_user_defined = copy.deepcopy( + onnx_graph_input_names_require_grad_user_defined + ) + + # Input names for the pre-gradient-build graph. + # This may be different with the one in ExportedGraph since we may modify the graph inputs as needed + # for example when memory efficient gradient management is enabled. + self.onnx_graph_input_names: list[str] = copy.deepcopy(onnx_graph_input_names) + + # A subset of onnx_graph_input_names. + # Input names that require gradients for the pre-gradient-build graph. + self.onnx_graph_input_names_require_grad: list[str] = copy.deepcopy(onnx_graph_input_names_require_grad) + + # Create symbolic names for each dimension of the graph input (e.g. onnx_graph_input_names). + # The key is the input name, the value is a dict of {dim_index: symbolic_dim_name} + # e.g. {"input1": {0: "input1_dim0", 1: "input1_dim1"}, "input2": {0: "input2_dim0"}} + self.onnx_graph_input_dynamic_axes_map: dict[str, dict[int, str]] = onnx_graph_input_dynamic_axes_map + + self._post_export_processed_model: onnx.ModelProto | None = post_export_processed_model + + # A function to access the input data from the args and kwargs. + # If it is not None, the length is same as onnx_graph_input_names_user_defined. + # For i-th input name, we can use the i-th function to get the input data from args and kwargs. + self.onnx_graph_input_data_accessor_user_defined: dict[str, callable] | None = ( + onnx_graph_input_data_accessor_user_defined + ) + + self.onnx_graph_input_const_as_tensor: dict[str, torch.device] | None = onnx_graph_input_const_as_tensor + + self.is_mem_efficient_grad_management_enabled = enable_mem_efficient_grad_management + + # Used for unflattening the outputs from the ORT forward run. + self.module_forward_output_schema: ORTModelInputOutputSchemaType | None = module_forward_output_schema + + # A buffer to hold the inputs for the ORT forward run. For performance, we reuse the same buffer for each run. + self._buffer_for_ort_runs: dict[str, torch.Tensor] | None = None + + def __str__(self): + return f"""PostExportProcessedModelInfo class: + \tonnx_graph_input_names: {self.onnx_graph_input_names} + \tonnx_graph_input_names_require_grad: {self.onnx_graph_input_names_require_grad} + \tonnx_graph_input_dynamic_axes_map: {self.onnx_graph_input_dynamic_axes_map} + \tonnx_graph_input_names_user_defined: {self.onnx_graph_input_names_user_defined} + \tonnx_graph_input_names_require_grad_user_defined: {self.onnx_graph_input_names_require_grad_user_defined} + \tbuffer_for_ort_runs.keys(): {self._buffer_for_ort_runs.keys() if self._buffer_for_ort_runs else None} + """ + + def __repr__(self): + return self.__str__() + + def construct_inputs( + self, + args: Sequence[ORTModelInputOutputType], + kwargs: Mapping[str, ORTModelInputOutputType], + constant_as_tensor: bool, + device: torch.device, + ): + """Constructs the inputs for the forward method + + The inputs are constructed in the order they appear in the model's forward function signature + """ + from ._mem_efficient_grad_mgmt import ( + MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME, + MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE, + MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE, + ) + + # First time construct the buffer for the ORT forward run. + if self._buffer_for_ort_runs is None: + self._buffer_for_ort_runs = OrderedDict() + + # Create the buffers for the inputs that are either parameters or buffers in the original module. + # For user inputs, fill with None for now, and will be filled dynamically during the forward run. + + parameter_names = {k: v for k, v in self._flattened_module.named_parameters()} + buffer_names = {k: v for k, v in self._flattened_module.named_buffers()} + + for input_name in self.onnx_graph_input_names: + if input_name in parameter_names: + self._buffer_for_ort_runs[input_name] = parameter_names[input_name] + elif input_name in buffer_names: + self._buffer_for_ort_runs[input_name] = buffer_names[input_name] + else: + self._buffer_for_ort_runs[input_name] = ( + None # Fill None for user input first, will be overridden later. + ) + + for name in self.onnx_graph_input_names_user_defined: + if self.is_mem_efficient_grad_management_enabled and name == MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME: + self._buffer_for_ort_runs[name] = torch.zeros( + MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE, + dtype=onnx_dtype_to_pytorch_dtype(MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE), + device=device, + ).requires_grad_() + continue + + if name in self.onnx_graph_input_data_accessor_user_defined: + assert name in self._buffer_for_ort_runs, f"{name} is not in buffer_for_ort_runs" + data = self.onnx_graph_input_data_accessor_user_defined[name](args, kwargs) + if name in self.onnx_graph_input_const_as_tensor: + data = PrimitiveType.get_tensor(data, device) + self._buffer_for_ort_runs[name] = data + else: + raise wrap_exception( + ORTModuleONNXModelException, + RuntimeError(f"Input is present in ONNX graph but not provided: {name}."), + ) + + return self._buffer_for_ort_runs + + def restore_outputs(self, ort_flatten_outputs: list[torch.Tensor]): + """Restores the outputs from the ORT forward run, back to the original data structure""" + + try: + return unflatten_data_using_schema(ort_flatten_outputs, self.module_forward_output_schema) + except TypeError as e: + raise wrap_exception( + ORTModuleIOError, + TypeError(f"ORTModule fails to unflatten user output: {e}"), + ) from None + + +class GraphTransitionManager: + """Manage the graph transition from 1). PyTorch to ONNX export and 2). ONNX to ONNX post-export processing.""" + + def __init__( + self, + flatten_module: torch.nn.Module, + export_mode: int, + debug_options: DebugOptions, + runtime_options: _RuntimeOptions, + time_tracker: TimeTracker, + runtime_inspector: RuntimeInspector, + logger: logging.Logger, + ): + self._device = _utils._get_device_from_module(flatten_module) + self._export_mode = export_mode + + self._debug_options = debug_options + self._runtime_options = runtime_options + + self._export_extra_kwargs = {} + + self._logger = logger + + # Tracker for ORTModule model export. + self._time_tracker = time_tracker + + self._runtime_inspector = runtime_inspector + + # A signal to indicate if the original model has changed and need a re-export. + self._original_model_has_changed = False + + self._flatten_module = flatten_module + + # Forward function input parameters of the original module. + self._module_forward_func_parameters: list[inspect.Parameter] = list( + inspect.signature(self._flatten_module._original_module.forward).parameters.values() + ) + # TODO: remove after PyTorch ONNX exporter supports VAR_KEYWORD parameters. + for input_parameter in self._module_forward_func_parameters: + if input_parameter.kind == inspect.Parameter.VAR_KEYWORD: + logger.info("The model's forward method has **kwargs parameter which has EXPERIMENTAL support!") + + # Model info collected from the original module's forward function signature and args/kwargs, used for ONNX export. + self._model_info_for_export: _io.ModelInfoForExport | None = None + self._exported_model_info: ExportedModelInfo | None = None + + # Model info after export and post export processing. + self._post_export_processed_model_info = None + + def get_post_processed_model( + self, args: Sequence[ORTModelInputOutputType], kwargs: Mapping[str, ORTModelInputOutputType] + ) -> tuple[bool, PostExportProcessedModelInfo]: + """Check if the post-export processed ONNX model can be reused, otherwise, reconstruct the model. + + Return True if the model can be reused, otherwise, return False. + The model can be reused when the following conditions are met: + a. The model has been exported before, and the inputs (args/outputs) schemas are the same as the previous ones. + b. If it is in training mode, the graph inputs requiring gradient are the same as the previous ones. + + """ + + if self._device is None: + device = _utils.get_device_from_module_and_inputs(self._flatten_module._original_module, args, kwargs) + if not self._device or self._device != device: + self._device = device + if not self._device: + raise wrap_exception( + ORTModuleDeviceException, RuntimeError("A device must be specified in the model or inputs!") + ) + + # Extract the schema from the args and kwargs, and compare it with the pre-exported one if already exported. + cur_model_info_for_export = _io.parse_inputs_for_onnx_export( + self._module_forward_func_parameters, + args, + kwargs, + True, + self._device, + self._export_mode, + self._logger, + self._export_extra_kwargs, + ) + + need_export_model = GraphTransitionManager._export_check( + prev_exported_model_info=self._exported_model_info, + original_model_has_changed=self._original_model_has_changed, + cur_args_schema=cur_model_info_for_export.onnx_graph_input_arg_schema, + cur_kwargs_schema=cur_model_info_for_export.onnx_graph_input_kwarg_schema, + logger=self._logger, + ) + + if need_export_model: + # Note related to the _io.FlattenedModule export!!! + # + # The _io.FlattenedModule serves as a module wrapper designed to support tuple inputs and outputs for + # PyTorch run during ONNX export. (Remember the PyTorch exporter handles tuple inputs and outputs better.) + # Internally, it facilitates the acceptance of tuple inputs and the generation of tuple outputs by invoking + # the original module's forward function. The workflow involves the following steps: + + # 1. Prior to export, both args and kwargs are flattened into a 1-D tensor list, and schemas for the + # flattened args and kwargs are generated. This schemas are essential for the subsequent un-flattening + # process. + + # 2. The flattened inputs (args + kwargs) are passed to the _io.FlattenedModule's forward run. + + # 3. The args schema and kwargs schema, etc are conveyed to the _io.FlattenedModule by setting the + # corresponding attributes. + + # 4. Within the _io.FlattenedModule's forward run, the inputs are un-flattened to the original args and + # kwargs using the associated schemas, and then they are passed to the original module's forward function. + + # 5. Upon the completion of the forward function, the outputs from the original module are flattened and + # returned to the caller. + + # 6. The 1-D flattened output tensors retain the same order as the outputs from the ONNX Runtime (ORT) + # forward run. To facilitate un-flattening during subsequent ORT runs, the output schema is saved as + # an attribute named `_output_schema` in the _io.FlattenedModule. + + copied_args = copy.copy(args) + copied_kwargs = copy.copy(kwargs) + flatten_inputs = [] + + # This looks a bit duplicated with `extract_data_and_schema` function, but this might be better to + # defined as a specialized logic that is the counter-part of `parse_inputs_for_onnx_export`, which handles + # args and kwargs separately. + for name, data_accessor in cur_model_info_for_export.onnx_graph_input_data_accessor_user_defined.items(): + d = data_accessor(copied_args, copied_kwargs) + if name in cur_model_info_for_export.onnx_graph_input_const_as_tensor: + flatten_inputs.append( + PrimitiveType.get_tensor( + d, + cur_model_info_for_export.onnx_graph_input_const_as_tensor[name], + ) + ) + else: + if isinstance(d, torch.Tensor): + flatten_inputs.append(d) + + # Ignore all other non-tensor inputs. + + self._flatten_module._device = self._device + self._flatten_module._args_schema = cur_model_info_for_export.onnx_graph_input_arg_schema + self._flatten_module._kwargs_schema = cur_model_info_for_export.onnx_graph_input_kwarg_schema + self._flatten_module._num_positionals = cur_model_info_for_export.num_positional_args + + self._logger.info(f"do_export started, model info for export: {cur_model_info_for_export}") + + ( + exported_model, + module_output_schema, # Retrieved from _io.FlattenedModule's _output_schema + onnx_graph_input_names, + onnx_graph_input_names_require_grad, + ) = GraphTransitionManager._export_model( + flattened_module=self._flatten_module, + model_info_for_export=cur_model_info_for_export, + flatten_module_inputs=flatten_inputs, + deepcopy_before_model_export=self._runtime_options.deepcopy_before_model_export, + device=self._device, + ortmodule_cache_dir=self._runtime_options.ortmodule_cache_dir, + enable_custom_autograd_function=self._runtime_options.enable_custom_autograd_function, + enable_zero_stage3_support=self._runtime_options.enable_zero_stage3_support, + enable_embedding_sparse_optimizer=self._runtime_options.enable_embedding_sparse_optimizer, + onnx_opset_version=self._runtime_options.onnx_opset_version, + stage3_param_handle=self, + debug_options=self._debug_options, + time_tracker=self._time_tracker, + runtime_inspector=self._runtime_inspector, + logger=self._logger, + ) + + # Get the intersection of all user-defined input names (parsed from forward function signature) and + # the exported model input names including both user-defined input names and training parameter/buffer names. + # It is possible some user-defined input names are not in the exported model input names, if it is not used + # by the model for its compute. + onnx_graph_input_names_user_defined = [ + input_name + for input_name in cur_model_info_for_export.onnx_graph_input_names + if input_name in onnx_graph_input_names + ] + onnx_graph_input_names_require_grad_user_defined = [ + input_name + for input_name in cur_model_info_for_export.onnx_graph_input_names_require_grad + if input_name in onnx_graph_input_names_require_grad + ] + + self._exported_model_info = ExportedModelInfo( + module_forward_args_schema=cur_model_info_for_export.onnx_graph_input_arg_schema, + module_forward_kwargs_schema=cur_model_info_for_export.onnx_graph_input_kwarg_schema, + onnx_graph_input_names=onnx_graph_input_names, + onnx_graph_input_names_require_grad=onnx_graph_input_names_require_grad, + onnx_graph_input_names_user_defined=onnx_graph_input_names_user_defined, + onnx_graph_input_names_require_grad_user_defined=onnx_graph_input_names_require_grad_user_defined, + exported_model=exported_model, + module_forward_output_schema=module_output_schema, + ) + + self._model_info_for_export = cur_model_info_for_export + + # Reset the signal to indicate the original model has changed. + self._original_model_has_changed = False + + # Save the exported model + if self._debug_options.save_onnx_models.save: + _save_model( + self._exported_model_info.exported_model, + os.path.join( + self._debug_options.save_onnx_models.path, + _get_onnx_file_name( + self._debug_options.save_onnx_models.name_prefix, "torch_exported", self._export_mode + ), + ), + ) + + self._logger.info(f"do_export completed, exported graph infos: {self._exported_model_info}") + + need_re_processed = False + if need_export_model: + need_re_processed = True + else: + need_re_processed, updated_onnx_graph_input_requires_grads = GraphTransitionManager._reprocess_check( + flatten_module=self._flatten_module, + exported_model_info=self._exported_model_info, + export_mode=self._export_mode, + model_info_for_export=self._model_info_for_export, + args=args, + kwargs=kwargs, + ) + if need_re_processed: + # Update the onnx_graph_input_names_require_grads to make it a new default baseline to compare + # using new iteration data. + self._exported_model_info.onnx_graph_input_names_require_grad = updated_onnx_graph_input_requires_grads + + if need_re_processed: + # At this point, the exported model is ready, and we can start post-export processing. + self._post_export_processed_model_info = GraphTransitionManager._post_export_process( + flatten_module=self._flatten_module, + export_mode=self._export_mode, + exported_model_info=self._exported_model_info, + model_info_for_export=self._model_info_for_export, + enable_custom_autograd_function=self._runtime_options.enable_custom_autograd_function, + enable_zero_stage3_support=self._runtime_options.enable_zero_stage3_support, + run_symbolic_shape_infer=self._runtime_options.run_symbolic_shape_infer, + stage3_param_handle=self, + enable_mem_efficient_grad_management=self._export_mode != torch.onnx.TrainingMode.EVAL + and self._runtime_options.enable_mem_efficient_grad_management, + logger=self._logger, + ) + + # Save the post_processed model + if self._debug_options.save_onnx_models.save: + _save_model( + self._post_export_processed_model_info._post_export_processed_model, + os.path.join( + self._debug_options.save_onnx_models.path, + _get_onnx_file_name( + self._debug_options.save_onnx_models.name_prefix, "post_processed", self._export_mode + ), + ), + ) + + return need_re_processed, self._post_export_processed_model_info + + @staticmethod + def _export_check( + prev_exported_model_info: ExportedModelInfo | None, + original_model_has_changed: bool, + cur_args_schema: ORTModelInputOutputSchemaType, + cur_kwargs_schema: ORTModelInputOutputSchemaType, + logger: logging.Logger, + ): + """Check if the model needs to be exported, if yes, return True. + + If either of the following conditions is met, return True: + 1. The model has never been exported before. + 2. The original_model_has_changed is True. + 3. The model input schema parsed from args and kwargs has changed. + """ + + need_export_model = prev_exported_model_info is None # never exported before + + need_export_model = need_export_model or original_model_has_changed + + need_export_model = ( + need_export_model + or cur_args_schema != prev_exported_model_info.module_forward_args_schema + or cur_kwargs_schema != prev_exported_model_info.module_forward_kwargs_schema + ) + + logger.info(f"_export_check completed - need_export_model: {need_export_model}") + + return need_export_model + + @staticmethod + def _reprocess_check( + flatten_module: _io._FlattenedModule, + exported_model_info: ExportedModelInfo, + export_mode: int, + model_info_for_export: _io.ModelInfoForExport, + args: Sequence[ORTModelInputOutputType], + kwargs: Mapping[str, ORTModelInputOutputType], + ) -> bool: + """Check if the exported model needs to be re-processed, if yes, + return True and the updated onnx_graph_input_requires_grads. + + For the following cases, return True: + 1. The export mode is TRAINING and the model's input names (including both user input and module parameters) + requiring gradient change. + """ + if export_mode == torch.onnx.TrainingMode.TRAINING: + # If inputs requiring gradient change from forward to the next, the gradient graph builder + # needs to be reinitialized so it can compute the backward output for the new inputs that require_grad + + # Reinitialize graph builder if the inputs or initializers requiring gradient have changed. + # This can happen when the user changes the model parameters after the onnx export. + # Model may have unused params dropped after export, so we only check those inputs existing in onnx graph. + + onnx_graph_input_requires_grads = [] + parameter_names = {k: v for k, v in flatten_module.named_parameters()} + for input_name in exported_model_info.onnx_graph_input_names: + if input_name in exported_model_info.onnx_graph_input_names_user_defined: + assert ( + input_name in model_info_for_export.onnx_graph_input_data_accessor_user_defined + ), f"{input_name} model_info_for_export.onnx_graph_input_data_accessor_user_defined" + # We assume the data accessor should be the same as the one used for the previous export, because + # there is args and kwargs schema check during export check phase. + if model_info_for_export.onnx_graph_input_data_accessor_user_defined[input_name]( + args, kwargs + ).requires_grad: + onnx_graph_input_requires_grads.append(input_name) + else: + assert input_name in parameter_names, f"{input_name} not exist parameter_names" + if parameter_names[input_name].requires_grad: + onnx_graph_input_requires_grads.append(input_name) + + if onnx_graph_input_requires_grads == exported_model_info.onnx_graph_input_names_require_grad: + return False, [] + return True, onnx_graph_input_requires_grads + + return False, [] + + @staticmethod + def _post_export_process( + flatten_module, + export_mode, + exported_model_info: ExportedModelInfo, + model_info_for_export: _io.ModelInfoForExport, + enable_custom_autograd_function: bool, + enable_zero_stage3_support: bool, + run_symbolic_shape_infer: bool, + stage3_param_handle: type, + enable_mem_efficient_grad_management: bool, + logger: logging.Logger, + ): + """Post process the exported model, generate the processed model which will be used for initializing graph builder.""" + + # Deepcopy the exported model, in case modification affects the exported model. + post_processed_model = copy.deepcopy(exported_model_info.exported_model) + + if enable_custom_autograd_function: + from ._custom_autograd_function_exporter import post_process_enabling_autograd_function + + post_processed_model = post_process_enabling_autograd_function(post_processed_model) + + if run_symbolic_shape_infer: + # MUST call symbolic shape inference after custom autograd function post-processing is done, + # Otherwise, there is no ctx output for PythonOp. + post_processed_model = GraphTransitionManager._infer_shapes(post_processed_model) + + if export_mode == torch.onnx.TrainingMode.TRAINING: + if enable_zero_stage3_support: + from ._zero_stage3_compatibility import post_processing_enable_zero_stage3_compat + + post_processed_model = post_processing_enable_zero_stage3_compat( + post_processed_model, + stage3_param_handle._zero_stage3_param_map, + [name for name, _ in flatten_module.named_parameters()], + ) + + onnx_graph_input_names_user_defined = copy.deepcopy(exported_model_info.onnx_graph_input_names_user_defined) + onnx_graph_input_names_require_grad_user_defined = copy.deepcopy( + exported_model_info.onnx_graph_input_names_require_grad_user_defined + ) + onnx_graph_input_names = copy.deepcopy(exported_model_info.onnx_graph_input_names) + onnx_graph_input_names_require_grad = copy.deepcopy(exported_model_info.onnx_graph_input_names_require_grad) + + if enable_mem_efficient_grad_management: + # Remove those trainable parameters from graph input, as they will be retrieved from weight pull node. + from ._mem_efficient_grad_mgmt import get_params_connected_to_pull_param_trigger + + # MUST call this before post_processing_enable_mem_efficient_training, otherwise, the onnx graph input + # will be modified. + parameter_not_as_graph_input_names = get_params_connected_to_pull_param_trigger( + flatten_module.named_parameters(), post_processed_model + ) + + if len(parameter_not_as_graph_input_names) > 0: + for k in parameter_not_as_graph_input_names: + if k in onnx_graph_input_names: + onnx_graph_input_names.remove(k) + + if k in onnx_graph_input_names_require_grad: + onnx_graph_input_names_require_grad.remove(k) + + from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME + + # Add mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph. + onnx_graph_input_names_user_defined.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) + onnx_graph_input_names_require_grad_user_defined.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) + onnx_graph_input_names.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) + onnx_graph_input_names_require_grad.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) + + from ._mem_efficient_grad_mgmt import post_processing_enable_mem_efficient_training + + # Override the options if model is not modified. + + ( + enable_mem_efficient_grad_management, # Update the flag to indicate the mem efficient grad management is enabled. + post_processed_model, + stage3_param_handle._param_trigger_grad, + ) = post_processing_enable_mem_efficient_training( + post_processed_model, flatten_module.named_parameters(), parameter_not_as_graph_input_names + ) + + if run_symbolic_shape_infer: + post_processed_model = SymbolicShapeInference.infer_shapes( + post_processed_model, auto_merge=True, guess_output_rank=True + ) + + post_export_processed_model_info = PostExportProcessedModelInfo( + flatten_module, + onnx_graph_input_names_user_defined, + onnx_graph_input_names_require_grad_user_defined, + onnx_graph_input_names, + onnx_graph_input_names_require_grad, + model_info_for_export.onnx_graph_input_dynamic_axes_map, + exported_model_info.module_forward_output_schema, + post_processed_model, + model_info_for_export.onnx_graph_input_data_accessor_user_defined, + model_info_for_export.onnx_graph_input_const_as_tensor, + enable_mem_efficient_grad_management, + ) + + logger.info( + f"_post_export_process completed, post-export processed graph infos: {post_export_processed_model_info}" + ) + + return post_export_processed_model_info + + @staticmethod + def _infer_shapes(model: onnx.ModelProto) -> onnx.ModelProto: + """Infer shapes for the exported model.""" + + model = SymbolicShapeInference.infer_shapes(model, auto_merge=True, guess_output_rank=True) + return model + + @staticmethod + @TrackTimeForStaticFunction(ORTModuleInitPhase.EXPORT) + @SuppressLogs(ORTModuleInitPhase.EXPORT, is_ort_filter=False) + def _export_model( + *, + flattened_module: torch.nn.Module, + model_info_for_export: _io.ModelInfoForExport, + flatten_module_inputs: Sequence[ORTModelInputOutputType], + deepcopy_before_model_export: bool, + device: torch.device, + ortmodule_cache_dir: str, + enable_custom_autograd_function: bool, + enable_zero_stage3_support: bool, + enable_embedding_sparse_optimizer: bool, + onnx_opset_version: int, + stage3_param_handle: type, + debug_options: DebugOptions, + time_tracker: TimeTracker, # time_tracker MUST be provided here to support TrackTimeForStaticFunction + runtime_inspector: RuntimeInspector, + logger: logging.Logger, + ) -> tuple[onnx.ModelProto, ORTModelInputOutputSchemaType, list[str], list[str]]: + + # Add hooks to check the sparsity of the embedding and label inputs during the export. + embedding_hook_handles = GraphTransitionManager._add_check_embedding_sparsity_hook( + enable_embedding_sparse_optimizer, device, logger, runtime_inspector, flattened_module + ) + label_hook_handles = GraphTransitionManager._add_check_label_sparsity_hook( + enable_embedding_sparse_optimizer, logger, runtime_inspector, flattened_module + ) + + # Record random states here and restore later in case any of them gets changed during the export, + # e.g., some sympy functions in symbolic_shape_infer will change Python's random state. + random_states = _utils.get_random_states() + + torch_exporter_verbose_log = debug_options.log_level < LogLevel.WARNING + from onnxruntime.training.utils.hooks._subscriber_manager import no_increase_global_step + + with export_context(), no_increase_global_step(): + exported_model, module_output_schema = GraphTransitionManager._get_exported_model( + flattened_module=flattened_module, + model_info_for_export=model_info_for_export, + flatten_module_inputs=flatten_module_inputs, + deepcopy_before_model_export=deepcopy_before_model_export, + device=device, + ortmodule_cache_dir=ortmodule_cache_dir, + enable_custom_autograd_function=enable_custom_autograd_function, + enable_zero_stage3_support=enable_zero_stage3_support, + onnx_opset_version=onnx_opset_version, + torch_exporter_verbose_log=torch_exporter_verbose_log, + stage3_param_handle=stage3_param_handle, + logger=logger, + ) + + onnx_graph_input_names = [input.name for input in exported_model.graph.input] + parameter_names = [name for name, _ in flattened_module.named_parameters()] + onnx_graph_input_names_require_grad = [ + input.name + for input in exported_model.graph.input + if input.name in parameter_names or input.name in model_info_for_export.onnx_graph_input_names_require_grad + ] + + # Restore the recorded random states + _utils.set_random_states(random_states) + + # Clean up all hooks. + for hook in embedding_hook_handles: + hook.remove() + + for hook in label_hook_handles: + hook.remove() + + return exported_model, module_output_schema, onnx_graph_input_names, onnx_graph_input_names_require_grad + + @staticmethod + def _get_exported_model( + flattened_module: torch.nn.Module, + model_info_for_export: _io.ModelInfoForExport, + flatten_module_inputs: Sequence[ORTModelInputOutputType], + deepcopy_before_model_export: bool, + device: torch.device, + ortmodule_cache_dir: str, + enable_custom_autograd_function: bool, + enable_zero_stage3_support: bool, + onnx_opset_version: int, + torch_exporter_verbose_log: bool, + stage3_param_handle: type, + logger: logging.Logger, + ) -> tuple[onnx.ModelProto, ORTModelInputOutputSchemaType]: + """Exports PyTorch `flattened_module` to ONNX for inferencing or training.""" + + need_deep_copy = deepcopy_before_model_export and _io.can_module_be_deep_cloned(flattened_module, device) + if not need_deep_copy: + if deepcopy_before_model_export: + logger.warning( + "Since the user requested not to deep copy this model, " + "the initial weights may not be preserved and could change slightly during the forward run. " + "This could cause a minor difference between the ORTModule and the PyTorch run for the " + "first iteration. The computation will proceed as normal, but this should be noted." + ) + else: + logger.warning( + "Due to the limited GPU memory execution manager does not create a deep copy of this model. " + "Therefore, the initial weights might be slightly altered during the forward run. " + "This could result in a minor discrepancy between the ORTModule and the PyTorch run for the " + "first iteration. The computation will continue as usual, but this should be noted." + ) + ( + output_names, + output_dynamic_axes, + module_output_schema, + ) = _io.parse_outputs_for_onnx_export_and_extract_schema( + flattened_module, flatten_module_inputs, logger, need_deep_copy + ) + + # Combine the dynamic axes from inputs and outputs + dynamic_axes = copy.deepcopy(model_info_for_export.onnx_graph_input_dynamic_axes_map) + + dynamic_axes.update(output_dynamic_axes) + + logger.info("Exporting the PyTorch model to ONNX...") + + # Leverage cached model if available + cache_dir = ortmodule_cache_dir + if cache_dir: + filename = os.path.join( + cache_dir, f"{hash_fn(str(flattened_module).encode()).hexdigest()}_{get_rank()}.onnx" + ) + if os.path.exists(cache_dir) and os.path.isfile(filename): + logger.warning( + f"Cached model detected! Cached model will be used to save export and initialization time." + f"If you want the model to be re-exported then DELETE {filename}." + ) + exported_model = onnx.load(filename) + return exported_model, module_output_schema + + # Export torch.nn.Module to ONNX + f = io.BytesIO() + + # Deepcopy inputs, since input values may change after model run. + # NOTE: Inputs may contain tensors that have attributes preventing their deepcopy (example grad_fn). + # Therefore, deepcopy only the data component of the input tensors for export. + kwargs = {} + sample_inputs_copy, sample_kwargs_copy = _io.deepcopy_model_input(*flatten_module_inputs, **kwargs) + assert len(sample_kwargs_copy) == 0, "Currently, kwargs are not supported for ONNX export." + sample_inputs_as_tuple = sample_inputs_copy + + # Ops behaving differently under train/eval mode need to be exported with the + # correct training flag to reflect the expected behavior. + # For example, the Dropout node in a model is dropped under eval mode. + assert model_info_for_export.export_mode is not None, "Please use a concrete instance of ExecutionManager" + + try: + with torch.no_grad(), stage3_export_context( + enable_zero_stage3_support, stage3_param_handle, flattened_module + ): + required_export_kwargs = { + "input_names": model_info_for_export.onnx_graph_input_names, # did not contains paramerter as its input yet + "output_names": output_names, + "opset_version": onnx_opset_version, + "do_constant_folding": False, + "training": model_info_for_export.export_mode, + "dynamic_axes": dynamic_axes, + "verbose": torch_exporter_verbose_log, + "export_params": False, + "keep_initializers_as_inputs": True, + } + + if check_function_has_param(torch.onnx.export, "autograd_inlining"): + # From some PyTorch version, autograd_inlining is a valid argument. + # We allow it to be True if custom autograd function is disabled (where autograd.Function + # anyway is not supported in ONNX until it can be inlined). + required_export_kwargs["autograd_inlining"] = not enable_custom_autograd_function + + invalid_args = model_info_for_export.export_extra_kwargs.keys() & required_export_kwargs.keys() + + if len(invalid_args) != 0: + error_msg = f"The following PyTorch exporter arguments cannot be specified: '{invalid_args}'." + raise RuntimeError(error_msg) + + torch.onnx.export( + flattened_module, + sample_inputs_as_tuple, + f, + **required_export_kwargs, + **model_info_for_export.export_extra_kwargs, + ) + except Exception as e: + message = _utils.get_exception_as_string(e) + + # Special handling when Huggingface transformers gradient checkpoint usage pattern found. + # For new versions of PyTorch 2, tracing torch.utils.checkpoint.checkpoint will be failed like this: + # File "microsoft/phi-2/b10c3eba545ad279e7208ee3a5d644566f001670/modeling_phi.py", line 919, in forward + # layer_outputs = self._gradient_checkpointing_func( + # File "/site-packages/torch/_compile.py", line 24, in inner + # return torch._dynamo.disable(fn, recursive)(*args, **kwargs) + # File "/site-packages/torch/_dynamo/eval_frame.py", line 470, in _fn + # raise RuntimeError( + # RuntimeError: Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment. + if ( + "_gradient_checkpointing_func" in message + and "Detected that you are using FX to torch.jit.trace a dynamo-optimized function" in message + ): + is_ckpt_activation_allowed = int(os.getenv("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", "0")) == 1 + notes = ( + " Your model is running with gradient checkpointing, yet the PyTorch exporter\n" + " failed during tracing the graph. Try to enable ORTModule's\n" + " gradient checkpointing (a.k.a. Transformer layerwise subgraph recompute)\n" + " using `export ORTMODULE_MEMORY_OPT_LEVEL=1` for similar or even better memory efficiency.\n" + ) + if is_ckpt_activation_allowed: + # If the user allows the gradient checkpointing export, we should inform the user to disable it, + # to make layerwise recompute work. + notes += ( + " We also notice your setting `export ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=1`,\n" + " which enables gradient checkpointing torch.autograd.Functions(s) to export.\n" + " To enable ORTModule's layerwise recompute, it needs to be turned OFF by\n" + " `export ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=0`.\n" + ) + + logger.error( + f"{LogColor.RED}\n" + "******************************** IMPORTANT NOTE *******************************\n" + f"{notes}" + "*******************************************************************************\n" + f"{LogColor.ENDC}\n" + ) + + raise wrap_exception( # noqa: B904 + ORTModuleONNXModelException, + RuntimeError( + f"There was an error while exporting the PyTorch model to ONNX: " + f"\n\n{_utils.get_exception_as_string(e)}" + ), + ) + exported_model = onnx.load_model_from_string(f.getvalue()) + + # Cache model for future runs + if cache_dir: + if not os.path.exists(cache_dir): + os.makedirs(cache_dir, exist_ok=True) + filename = os.path.join( + cache_dir, f"{hash_fn(str(flattened_module).encode()).hexdigest()}_{get_rank()}.onnx" + ) + logger.info(f"Caching model for future runs to {filename}.") + onnx.save(exported_model, filename) + + return exported_model, module_output_schema + + def signal_model_changed(self): + """Signals the execution manager to re-export the model on the next forward call""" + self._original_model_has_changed = True + + @staticmethod + def _add_check_embedding_sparsity_hook( + enable_embedding_sparse_optimizer: bool, + device: torch.device, + logger: logging.Logger, + runtime_inspector: RuntimeInspector, + flattened_module: torch.nn.Module, + ) -> list: + """ + Add hook to check embedding sparsity and enable padding elimination if applicable. + 1. Iterate through all modules to find Embedding modules with padding_idx >= 0. + 2. Register forward pre hook to the Embedding module and the hook will check sparsity of the embedding input. + 3. If the sparsity is below a threshold, enable padding elimination by adding FlagAndPrintDensity after the + output. GraphTransformer of PaddingElimination will check the FlagAndPrintDensity and do the actual + padding elimination graph modification. + 4. Return the hook handles for later removal. + + """ + if not enable_embedding_sparse_optimizer or device.type != "cuda": + return [] + + def _embedding_hook(name, module, args): + ebd_input = args[0] + if ebd_input is None or not isinstance(ebd_input, torch.Tensor): + logger.warning("Embedding input is not a tensor.") + return None + + valid_token = torch.count_nonzero(ebd_input - module.padding_idx) + total_token = ebd_input.numel() + embed_density = float(valid_token) / float(total_token) * 100 + + if embed_density < 90: + logger.info("Embedding sparsity-based optimization is ON for density: %.0f%%", embed_density) + runtime_inspector._embedding_module_to_padding_density_map[name] = embed_density + return FlagAndPrintDensity.apply(args[0], module.padding_idx, "embedding") + else: + logger.info("Embedding sparsity-based optimization is OFF for density: %.0f%%", embed_density) + return None + + embedding_hook_handles = [] + for name, sub_module in flattened_module.named_modules(): + if isinstance(sub_module, torch.nn.modules.sparse.Embedding): + if sub_module.padding_idx is not None and sub_module.padding_idx >= 0: + embedding_hook_handles.append(sub_module.register_forward_pre_hook(partial(_embedding_hook, name))) + + return embedding_hook_handles + + @staticmethod + def _add_check_label_sparsity_hook( + enable_label_sparse_optimizer: bool, + logger: logging.Logger, + runtime_inspector: RuntimeInspector, + flattened_module: torch.nn.Module, + ) -> list: + """ + Add hook to check label sparsity and enable sceloss compute optimization if applicable. + 1. Register forward pre hook to the sceloss module in the model and the hook will check sparsity of the label input. + 2. If the sparsity is below a threshold, enable sceloss compute optimization by adding FlagAndPrintDensity after the + output. GraphTransformer of InsertGatherBeforeSceLoss will check the FlagAndPrintDensity and do the actual + sceloss compute optimization graph modification. + + """ + if not enable_label_sparse_optimizer: + return None + + def _label_hook(name, module, args): + label_input = args[1] + if label_input is None or not isinstance(label_input, torch.Tensor): + logger.warning("Label input is not a tensor.") + return None + + valid_token = torch.count_nonzero(label_input - module.ignore_index) + total_token = label_input.numel() + label_density = float(valid_token) / float(total_token) * 100 + + if label_density < 90: + logger.info("Label sparsity-based optimization is ON for density: %.0f%%", label_density) + runtime_inspector._sceloss_module_to_ignore_density_map[name] = label_density + return (args[0], FlagAndPrintDensity.apply(args[1], module.ignore_index, "label")) + else: + logger.info("Label sparsity-based optimization is OFF for density: %.0f%%", label_density) + return None + + label_check_hook_handles = [] + for name, sub_module in flattened_module.named_modules(): + if isinstance(sub_module, torch.nn.modules.loss.CrossEntropyLoss): + label_check_hook_handles.append(sub_module.register_forward_pre_hook(partial(_label_hook, name))) + + return label_check_hook_handles diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index 642dc9b0f4dd6..61db462ad3bb8 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -11,11 +11,10 @@ from onnxruntime.capi import _pybind_state as C -from . import _are_deterministic_algorithms_enabled, _io, _use_deterministic_algorithms, _utils +from . import _are_deterministic_algorithms_enabled, _use_deterministic_algorithms, _utils from ._execution_agent import InferenceAgent from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPolicy from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo -from ._io import unflatten_user_output from ._logger import ORTModuleInitPhase, TrackTime from ._utils import save_tuning_results, set_tuning_results from .options import DebugOptions, _SkipCheck @@ -109,15 +108,19 @@ def forward(self, *inputs, **kwargs): build_graph = False if ( self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False - or not self._onnx_models.exported_model + or not self._graph_transition_manager._exported_model_info ): self.time_tracker.start(ORTModuleInitPhase.EndToEnd) # Exporting module to ONNX for the first time - build_graph = self._export_model(*inputs, **kwargs) + + ( + build_graph, + post_export_processed_model_info, + ) = self._graph_transition_manager.get_post_processed_model(inputs, kwargs) if build_graph: - # If model was exported, then initialize the graph builder. - self._initialize_graph_builder() + # TODO(): do we need call it for inferencing mode??? + self._initialize_graph_builder(post_export_processed_model_info) # Build the inference graph if build_graph: @@ -134,7 +137,7 @@ def forward(self, *inputs, **kwargs): self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False or not self._execution_agent ): - module_device = _utils.get_device_from_module(self._original_module) + module_device = _utils.get_device_from_module_and_inputs(self._original_module, inputs, kwargs) create_execution_session = ( build_graph @@ -144,7 +147,7 @@ def forward(self, *inputs, **kwargs): _use_deterministic_algorithms(torch.are_deterministic_algorithms_enabled()) if self._device != module_device: - self._device = module_device + self._graph_transition_manager._device = module_device if create_execution_session: # Create execution session creates the inference_session @@ -160,23 +163,15 @@ def forward(self, *inputs, **kwargs): if self._runtime_options.enable_zero_stage3_support: self._append_pull_weight_trigger_as_input(kwargs, self._device) - prepared_input_list = _io._combine_input_buffers_initializers( - self._graph_initializers, - self._graph_info.user_input_names, - self._input_info, - self._flattened_module.named_buffers(), - inputs, - kwargs, - self._device, - self._runtime_inspector, - self._zero_stage3_param_map, + prepared_input_map = self._graph_transition_manager._post_export_processed_model_info.construct_inputs( + inputs, kwargs, True, self._device ) user_outputs, _ = InferenceManager.execution_session_run_forward( self._execution_agent, self._onnx_models.optimized_model, self._device, - *prepared_input_list, + *prepared_input_map.values(), ) if ( @@ -188,7 +183,8 @@ def forward(self, *inputs, **kwargs): self._execution_agent._inference_session, False, self._runtime_options.tuning_results_path ) - return unflatten_user_output(self._module_output_schema, user_outputs) + return self._graph_transition_manager._post_export_processed_model_info.restore_outputs(user_outputs) + except ORTModuleFallbackException as e: # Exceptions subject to fallback are handled here self._fallback_manager.handle_exception(exception=e, log_level=self._debug_options.logging.log_level) diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 1ba62194bf63e..8ad3d0df3e4fa 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -7,10 +7,10 @@ import gc import inspect from collections import OrderedDict, abc +from functools import partial from logging import Logger -from typing import Dict, Iterator, List, Mapping, Optional, Sequence, Tuple +from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple -import onnx import torch from onnxruntime.training.utils import ( @@ -20,9 +20,9 @@ extract_data_and_schema, unflatten_data_using_schema, ) +from onnxruntime.training.utils.torch_io_helper import _TensorStub -from ._fallback import ORTModuleIOError, ORTModuleONNXModelException, wrap_exception -from ._runtime_inspector import RuntimeInspector +from ._fallback import ORTModuleIOError, wrap_exception class _OutputIdentityOp(torch.autograd.Function): @@ -76,195 +76,6 @@ def symbolic(g, self): return g.op("Identity", self) -def flatten_kwargs(kwargs, device): - def _flatten_kwargs(value, name): - if PrimitiveType.is_primitive_type(value): - flattened_kwargs[name] = PrimitiveType.get_tensor(value, device) - elif isinstance(value, torch.Tensor): - flattened_kwargs[name] = value - elif isinstance(value, abc.Sequence): - # If the input is a sequence (like a list), expand the list so that - # each element of the list has a corresponding entry in the flattened - # kwargs dict - for idx, val in enumerate(value): - _flatten_kwargs(val, f"{name}_{idx}") - elif isinstance(value, abc.Mapping): - # If the input is a mapping (like a dict), expand the dict so that - # each element of the dict has an entry in the flattened kwargs dict - for key, val in value.items(): - _flatten_kwargs(val, f"{name}_{key}") - - flattened_kwargs = {} - for key, value in kwargs.items(): - _flatten_kwargs(value, key) - - return flattened_kwargs - - -class _InputInfo: - def __init__( - self, - names: List[str], - shape: List[List[int]], - require_grad_names: Optional[List[str]] = None, - dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None, - schema: Optional[ORTModelInputOutputSchemaType] = None, - num_positionals=0, - ): - self.names: List[str] = names - self.shape: List[List[int]] = shape - self.require_grad_names: List[str] = require_grad_names if require_grad_names else [] - self.dynamic_axes: Dict[str, Dict[int, str]] = dynamic_axes if dynamic_axes else {} - self.schema: ORTModelInputOutputSchemaType = schema if schema else [] - self.num_positionals = num_positionals - self.kwargs = None - - def __repr__(self) -> str: - return f"""_InputInfo class: - \tNames: {self.names} - \tShape: {self.shape} - \tRequire gradient: {self.require_grad_names} - \tDynamic axes: {self.dynamic_axes} - \tSchema: {self.schema} - \t#Positionals (total): {self.num_positionals}""" - - def flatten( - self, - args: Sequence[ORTModelInputOutputType], - kwargs: Mapping[str, ORTModelInputOutputType], - device: torch.device, - ) -> Sequence[ORTModelInputOutputType]: - """Flatten args and kwargs in a single tuple of tensors with strict ordering""" - - ret = [PrimitiveType.get_tensor(arg, device) if PrimitiveType.is_primitive_type(arg) else arg for arg in args] - flattened_kwargs = flatten_kwargs(kwargs, device) - ret += [flattened_kwargs[name] for name in self.names if name in flattened_kwargs] - self.kwargs = kwargs - - # if kwargs is empty, append an empty dictionary at the end of the sample inputs to make exporter - # happy. This is because the exporter is confused with kwargs and dictionary inputs otherwise. - if not kwargs: - ret.append({}) - - return ret - - def unflatten( - self, flat_args: Sequence[ORTModelInputOutputType] - ) -> Tuple[Sequence[ORTModelInputOutputType], Mapping[str, ORTModelInputOutputType]]: - """Unflatten tuple of tensors into args and kwargs""" - - args = tuple(flat_args[: self.num_positionals]) - kwargs = self.kwargs - self.kwargs = None - return args, kwargs - - -def _combine_input_buffers_initializers( - params: List[torch.nn.parameter.Parameter], - onnx_input_names: List[str], - input_info: Optional[_InputInfo], - named_buffer: Iterator[Tuple[str, torch.Tensor]], - inputs: Sequence[ORTModelInputOutputType], - kwargs: Mapping[str, ORTModelInputOutputType], - device: torch.device, - rt_inspector: RuntimeInspector, - zero_stage3_offload_param_map: Optional[Dict[str, torch.nn.parameter.Parameter]], -): - """Creates forward `*inputs` list from user input and PyTorch initializers - - ONNX Runtime forward requires an ordered list of: - * User input: computed from forward InferenceSession - * Initializers: computed from original PyTorch model parameters. - """ - - def _expand_inputs(current_input, non_none_inputs, name=""): - # The exporter handles input lists by expanding them so that each - # element of the list is its own input. - # ORTModule must match this behavior by also expanding the inputs. - if current_input is None or isinstance(current_input, str): - # Drop all None and string inputs - return - if isinstance(current_input, abc.Sequence): - # If the input is a sequence (like a list), expand the list so that - # each element of the list is an input by itself - for i, inp in enumerate(current_input): - _expand_inputs(inp, non_none_inputs, f"{name}_{i}" if name else str(i)) - elif isinstance(current_input, abc.Mapping): - # If the input is a mapping (like a dict), expand the dict so that - # each element of the dict is an input by itself - for key, val in current_input.items(): - _expand_inputs(val, non_none_inputs, f"{name}_{key}" if name else key) - else: - # else just collect all the non none inputs within non_none_inputs - if isinstance(non_none_inputs, abc.Sequence): - non_none_inputs.append(current_input) - elif isinstance(non_none_inputs, abc.Mapping): - non_none_inputs[name] = current_input - - # User inputs - non_none_inputs = [] - _expand_inputs(inputs, non_none_inputs) - flattened_kwargs_inputs = {} - _expand_inputs(kwargs, flattened_kwargs_inputs) - buffer_names_dict = None - result = [] - onnx_input_to_value_map = OrderedDict() - - for input_idx, name in enumerate(onnx_input_names): - inp = None - if name in flattened_kwargs_inputs and flattened_kwargs_inputs[name] is not None: - # Only use keywords coming from user that are expected by ONNX model - inp = flattened_kwargs_inputs[name] - - if inp is None: - try: - # Only use positionals coming from user that are expected by ONNX model - # if input_idx >= len(input_info.names), IndexError will be thrown - if name != input_info.names[input_idx]: - # When ONNX drops unused inputs, get correct index from user input - # if name is not in input_info.names, ValueError will be thrown - input_idx = input_info.names.index(name) # noqa: PLW2901 - inp = non_none_inputs[input_idx] - except (IndexError, ValueError): - # ONNX input name is not present in input_info.names. - pass - - if inp is None: - # Registered buffers are translated to user_input+initializer in ONNX - if buffer_names_dict is None: - buffer_names_dict = {buffer_name: i for buffer_name, i in named_buffer} - try: # noqa: SIM105 - inp = buffer_names_dict[name] - except KeyError: - # ONNX input name is not present in the registered buffer dict. - pass - - if inp is not None: - if PrimitiveType.is_primitive_type(inp): - inp = PrimitiveType.get_tensor(inp, device) - - result.append(inp) - onnx_input_to_value_map[name] = inp - else: - raise wrap_exception( - ORTModuleONNXModelException, RuntimeError(f"Input is present in ONNX graph but not provided: {name}.") - ) - - # params is a list of all initializers known to the onnx graph - if zero_stage3_offload_param_map: - for p in params: - if p not in zero_stage3_offload_param_map.values(): - result.append(p) - else: - result.extend(params) - - if rt_inspector.memory_ob.is_enabled() and not rt_inspector.memory_ob.symbolic_dim_collecting_completed: - rt_inspector.memory_ob.collect_symbolic_dim_values(input_info.dynamic_axes, onnx_input_to_value_map) - rt_inspector.memory_ob.symbolic_dim_collecting_completed = True - - return result - - def deepcopy_model_input( *args, **kwargs ) -> Tuple[Sequence[ORTModelInputOutputType], Mapping[str, ORTModelInputOutputType]]: @@ -288,113 +99,153 @@ def extract_tensor(value): return sample_args_copy, sample_kwargs_copy -def unflatten_user_output(output_schema: Optional[ORTModelInputOutputSchemaType], outputs: List[torch.Tensor]): +def _extract_schema( + data: ORTModelInputOutputType, device +) -> Tuple[Sequence[ORTModelInputOutputType], ORTModelInputOutputSchemaType]: try: - return unflatten_data_using_schema(outputs, output_schema) + flatten_data, schema = extract_data_and_schema(data, constant_as_tensor=True, device=device) + return flatten_data, schema except TypeError as e: - raise wrap_exception( - ORTModuleIOError, - TypeError(f"ORTModule fails to unflatten user output: {e}"), - ) from None + raise wrap_exception(ORTModuleIOError, TypeError(f"ORTModule fails to extract schema from data: {e}")) from None -def _extract_schema(data: ORTModelInputOutputType, device) -> ORTModelInputOutputSchemaType: - try: - _, schema = extract_data_and_schema(data, constant_as_tensor=True, device=device) - return schema - except TypeError as e: - raise wrap_exception(ORTModuleIOError, TypeError(f"ORTModule fails to extract schema from data: {e}")) from None +class _FlattenedModule(torch.nn.Module): + def __init__(self, original_module: torch.nn.Module): + super().__init__() + self._original_module: torch.nn.Module = original_module + # Before ONNX export, we flatten the args and kwargs into a 1-D list of tensors to make torch.export happy. + # As a result, we need to unflatten the args and kwargs back to the original structure before calling the + # original module's forward function. + # So we need set those information that are needed to unflatten the args and kwargs, before calling the + # torch.export. + self._device: Optional[torch.device] = None + self._args_schema: Optional[ORTModelInputOutputSchemaType] = None + self._kwargs_schema: Optional[ORTModelInputOutputSchemaType] = None + self._num_positionals: Optional[int] = None + + # Similarly, to make torch.export happy, we need to flatten the original module's outputs into a 1-D list of tensors. + # Need to keep the output schema to unflatten the outputs back to the original structure. + # Then those code depends on the original structure of the outputs can work properly. + self._output_schema: Optional[ORTModelInputOutputSchemaType] = None -def _parse_outputs_and_extract_names_and_dynamic_axes(module_output) -> Tuple[List[str], Dict[str, Dict[int, str]]]: - """Parses through the module output and returns output names and dynamic axes""" + def forward(self, *args): + new_args = unflatten_data_using_schema(args[: self._num_positionals], self._args_schema) - def _populate_output_names_and_dynamic_axes( - output, output_names: List[str], output_dynamic_axes: Dict[str, Dict[int, str]], output_idx: List[int] - ): - # Depth first traversal to traverse through the entire output collecting output names and dynamic axes - - if output is None: - return - elif isinstance(output, torch.Tensor): - # Naming the outputs with a hyphen ensures that there can be no input with the same - # name, preventing collisions with other NodeArgs (for example an input to forward called output0) - output_name = f"output-{output_idx[0]}" - output_idx[0] += 1 - output_names.append(output_name) - output_dynamic_axes[output_name] = {} - for dim_idx in range(len(output.shape)): - output_dynamic_axes[output_name].update({dim_idx: f"{output_name}_dim{dim_idx}"}) - return - - if isinstance(output, abc.Sequence): - for value in output: - _populate_output_names_and_dynamic_axes(value, output_names, output_dynamic_axes, output_idx) - elif isinstance(output, abc.Mapping): - for _, value in sorted(output.items()): - _populate_output_names_and_dynamic_axes(value, output_names, output_dynamic_axes, output_idx) - else: - raise wrap_exception( - ORTModuleIOError, - TypeError(f"ORTModule does not support the following model output type {type(output)}"), - ) + new_kwargs = unflatten_data_using_schema(args[self._num_positionals :], self._kwargs_schema) - output_names: List[str] = [] - output_dynamic_axes: Dict[str, Dict[int, str]] = {} - output_idx: List[int] = [0] - _populate_output_names_and_dynamic_axes(module_output, output_names, output_dynamic_axes, output_idx) + original_outputs = self._original_module(*new_args, **new_kwargs) - return output_names, output_dynamic_axes + # Flatten the outputs + flatten_outputs, self._output_schema = _extract_schema(original_outputs, self._device) + # Append _OutputIdentityOp to the outputs to support passthrough outputs + final_flatten_outputs = [] + for output in flatten_outputs: + final_flatten_outputs.append(_OutputIdentityOp.apply(output)) -def _transform_output_to_flat_tuple(data): - """Converts the data to a flat tuple by iterating over the entire data structure""" + return final_flatten_outputs - def _flatten_data(data, flat_data): - # Recursively traverse over the data and populate the flat_data with torch.Tensors - if data is None: - return - elif isinstance(data, torch.Tensor): - identity = _OutputIdentityOp.apply - flat_data.append(identity(data)) - elif isinstance(data, abc.Sequence): - for value in data: - _flatten_data(value, flat_data) - elif isinstance(data, abc.Mapping): - for _, value in sorted(data.items()): - _flatten_data(value, flat_data) - else: - raise wrap_exception( - ORTModuleIOError, TypeError(f"ORTModule does not support the following data type {type(data)}.") - ) +class ModelInfoForExport: + def __init__( + self, + onnx_graph_input_names: List[str], + onnx_graph_input_names_require_grad: List[str], + onnx_graph_input_dynamic_axes_map: Dict[str, Dict[int, str]], + onnx_graph_input_shapes: List[List[int]], + onnx_graph_input_data_accessor_user_defined: Optional[Dict[str, callable]] = None, + onnx_graph_input_const_as_tensor: Optional[Dict[str, torch.device]] = None, + onnx_graph_input_arg_schema: Optional[Dict[str, ORTModelInputOutputSchemaType]] = None, + onnx_graph_input_kwarg_schema: Optional[Dict[str, ORTModelInputOutputSchemaType]] = None, + num_positional_args: int = 0, + export_mode: Optional[int] = None, + export_extra_kwargs: Optional[Dict[str, any]] = None, + ): + # Value can be either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL + self.export_mode = export_mode + + # Exporter can take extra arguments for ORTModule extensions + # It cannot overlap with required/immutable arguments (validated in runtime) + self.export_extra_kwargs = export_extra_kwargs + + # Input names parsed and then flatten from the model's forward function signature. + # This should contains ONLY the user defined input names + # Be noted: some of the input might not be used by the model for its compute. + self.onnx_graph_input_names: List[str] = onnx_graph_input_names + + # A subset of onnx_graph_input_names. + # Input names that require gradient parsed and then flatten from the model's forward function signature + # This should contains ONLY the user defined input names + # Be noted: some of the input might not be used by the model for its compute. + self.onnx_graph_input_names_require_grad: List[str] = onnx_graph_input_names_require_grad + + # Create symbolic names for each dimension of the graph input (e.g. onnx_graph_input_names). + # The key is the input name, the value is a dict of {dim_index: symbolic_dim_name} + # e.g. {"input1": {0: "input1_dim0", 1: "input1_dim1"}, "input2": {0: "input2_dim0"}} + self.onnx_graph_input_dynamic_axes_map: Dict[str, Dict[int, str]] = onnx_graph_input_dynamic_axes_map + + self.onnx_graph_input_shapes: List[List[int]] = onnx_graph_input_shapes + + # The input args schema for the original model's forward function. + # Only contains the schema for those inputs used by the model for its compute (e.g. as the inputs + # of the export model). + self.onnx_graph_input_arg_schema: Dict[str, ORTModelInputOutputSchemaType] = onnx_graph_input_arg_schema + + # The input kwargs schema for the original model's forward function. + # Only contains the schema for those inputs used by the model for its compute (e.g. as the inputs + # of the export model). + self.onnx_graph_input_kwarg_schema: Dict[str, ORTModelInputOutputSchemaType] = onnx_graph_input_kwarg_schema + + self.num_positional_args: int = num_positional_args + + # A function to access the input data from the args and kwargs. + # If it is not None, the length is same as onnx_graph_input_names. + # For i-th input name, we can use the i-th function to get the input data from args and kwargs. + self.onnx_graph_input_data_accessor_user_defined: Optional[Dict[str, callable]] = ( + onnx_graph_input_data_accessor_user_defined + ) + + self.onnx_graph_input_const_as_tensor: Optional[Dict[str, torch.device]] = onnx_graph_input_const_as_tensor + + def __str__(self) -> str: + return f"""ModelInfoForExport class: + \tExport mode: {self.export_mode} + \tExport extra kwargs: {self.export_extra_kwargs} + \tInput names: {self.onnx_graph_input_names} + \tInput names require grad: {self.onnx_graph_input_names_require_grad} + \tInput dynamic axes: {self.onnx_graph_input_dynamic_axes_map} + \tInput shapes: {self.onnx_graph_input_shapes} + \tInput args schema: {self.onnx_graph_input_arg_schema} + \tInput kwargs schema: {self.onnx_graph_input_kwarg_schema} + \tNum input args: {self.num_positional_args}""" - flat_data = [] - _flatten_data(data, flat_data) - return tuple(flat_data) + def __repr__(self) -> str: + return self.__str__() -class _FlattenedModule(torch.nn.Module): - def __init__(self, original_module: torch.nn.Module): - super().__init__() - self._original_module: torch.nn.Module = original_module +def _arg_access_with_index_func(arg_index, args, kwargs): + return args[arg_index] - # Before `forward` is called, _ort_module must be assigned - # Updated input info is needed to expand args into *args, **kwargs - self._input_info: Optional[_InputInfo] = None - def forward(self, *args): - new_args, new_kwargs = self._input_info.unflatten(args) - return _transform_output_to_flat_tuple(self._original_module(*new_args, **new_kwargs)) +def _kwarg_access_with_name_func(name, args, kwargs): + return kwargs[name] + + +class SkipRetValue: + """A placeholder class to indicate that the return value of a function should be skipped""" def parse_inputs_for_onnx_export( all_input_parameters: List[inspect.Parameter], - onnx_graph: Optional[onnx.ModelProto], - schema: ORTModelInputOutputSchemaType, args: Sequence[ORTModelInputOutputType], kwargs: Mapping[str, ORTModelInputOutputType], -) -> _InputInfo: + constant_as_tensor: bool, + device: torch.device, + export_mode: int, + logger: Logger, + export_extra_kwargs: Optional[Dict[str, any]] = None, +) -> ModelInfoForExport: """Parses through the model inputs and returns _InputInfo. Loop through all input parameters, try to flatten them into a 1-D list of inputs. For nested data in the inputs, @@ -414,67 +265,149 @@ def parse_inputs_for_onnx_export( Args: all_input_parameters: All inspected input parameters from the original model forward function signature. - onnx_graph: (optional) The onnx graph. - schema: The schema extracted from the positional arguments and keyword arguments of the model. args: The positional arguments of the model. kwargs: The keyword arguments of the model. + constant_as_tensor: Whether to treat constant inputs as tensors. + device: The device to be used for constant inputs. """ + arg_tensor_idx = [-1] + kwarg_tensor_idx = [-1] + def _add_dynamic_shape(name, input) -> Dict[str, Dict[int, str]]: dynamic_axes[name] = {} for dim_idx in range(len(input.shape)): dynamic_axes[name].update({dim_idx: f"{name}_dim{dim_idx}"}) return dynamic_axes - def _add_input(name, input_value, onnx_graph, onnx_graph_input_names): + def _warn_of_constant_inputs(data): + logger.info(f"Received input of type {type(data)} is treated as a constant by ORT by default.") + + def _add_input( + name: str, input_value, onnx_graph_input_names: List[str], cur_func: Callable, tensor_idx: List[int] + ): """Returns number of expanded non none inputs that _add_input processed""" - if name in input_names or input_value is None or isinstance(input_value, str): - # Drop all None and string inputs and return 0. - return + # in case the input is already handled. + if name in visited_input_names: + return SkipRetValue() + + visited_input_names.append(name) + + value = input_value + primitive_dtype = None + if value is None: + _warn_of_constant_inputs(value) + data_accessors[name] = cur_func + return value + elif isinstance(value, str): + _warn_of_constant_inputs(value) + data_accessors[name] = cur_func + return value + elif PrimitiveType.is_primitive_type(value): + if constant_as_tensor: + # This has special handling for bool type to string conversion. + primitive_dtype = PrimitiveType.get_primitive_dtype(value) + value = PrimitiveType.get_tensor(value, device) + const_to_tensor_inputs[name] = device + + else: + data_accessors[name] = cur_func + _warn_of_constant_inputs(value) + return value + elif isinstance(value, abc.Sequence): + sequence_type = type(value) + stubbed_schema = [] - if isinstance(input_value, abc.Sequence): # If the input is a sequence (like a list), expand the list so that # each element of the list is an input by itself. - for i, val in enumerate(input_value): + for i, val in enumerate(value): # Name each input with the index appended to the original name of the # argument. - _add_input(f"{name}_{i}", val, onnx_graph, onnx_graph_input_names) + + def _access_func(i, cur_func, args, kwargs): + return cur_func(args, kwargs)[i] + + input_schema = _add_input( + f"{name}_{i}", + val, + onnx_graph_input_names, + partial(_access_func, i, cur_func), + tensor_idx, + ) + + if not isinstance(input_schema, SkipRetValue): + stubbed_schema.append(input_schema) # Return here since the list by itself is not a valid input. # All the elements of the list have already been added as inputs individually. - return - elif isinstance(input_value, abc.Mapping): + + try: + # namedtuple can be created by passing the list sequence to method _make + stubbed_schema = sequence_type._make(stubbed_schema) + except AttributeError: + # If attribute error is encountered, create the sequence directly + stubbed_schema = sequence_type(stubbed_schema) + return stubbed_schema + + elif isinstance(value, abc.Mapping): + dict_type = type(value) + stubbed_schema = OrderedDict() + # If the input is a mapping (like a dict), expand the dict so that # each element of the dict is an input by itself. - for key, val in input_value.items(): - _add_input(f"{name}_{key}", val, onnx_graph, onnx_graph_input_names) + for key, val in value.items(): + + def _access_func(key, cur_func, args, kwargs): + return cur_func(args, kwargs)[key] + + input_schema = _add_input( + f"{name}_{key}", + val, + onnx_graph_input_names, + partial(_access_func, key, cur_func), + tensor_idx, + ) + + if not isinstance(input_schema, SkipRetValue): + stubbed_schema[key] = input_schema # Return here since the dict by itself is not a valid input. # All the elements of the dict have already been added as inputs individually. - return - # InputInfo should contain all the names irrespective of whether they are - # a part of the onnx graph or not. - input_names.append(name) + stubbed_schema = dict_type(**stubbed_schema) + return stubbed_schema - if (onnx_graph is None or name in onnx_graph_input_names) and isinstance(input_value, torch.Tensor): - if input_value.requires_grad: + if isinstance(value, torch.Tensor): + onnx_graph_input_names.append(name) + data_accessors[name] = cur_func + if value.requires_grad: input_names_require_grad.append(name) - dynamic_axes.update(_add_dynamic_shape(name, input_value)) - input_shape.append(list(input_value.size())) + dynamic_axes.update(_add_dynamic_shape(name, value)) + input_shape.append(list(value.size())) + tensor_idx[0] += 1 + return _TensorStub( + tensor_idx[0], + dtype=primitive_dtype if primitive_dtype else str(value.dtype), # special handle for bool primitive + shape_dims=len(value.size()), + name=name, + ) - # Ignore optional inputs explicitly specified as None - # ONNX exporter may remove unused inputs - onnx_graph_input_names: List[str] = [] - if onnx_graph is not None: - onnx_graph_input_names = {inp.name for inp in onnx_graph.graph.input} + raise ORTModuleIOError(f"ORTModule does not support input type {type(value)} for input {name}") - input_names: List[str] = [] + visited_input_names: List[str] = [] + + onnx_graph_input_names: List[str] = [] dynamic_axes: Dict[str, Dict[int, str]] = {} input_names_require_grad: List[str] = [] input_shape: List[List[int]] = [] + input_arg_schema: ORTModelInputOutputSchemaType = [] + input_kwarg_schema: ORTModelInputOutputSchemaType = OrderedDict() + data_accessors: Dict[str, Callable] = OrderedDict() + const_to_tensor_inputs: Dict[str, torch.device] = OrderedDict() + num_positional_args: int = 0 + var_positional_idx = 0 # Be noted, all_input_parameters is a list of inspect.Parameters parsed from the original module's forward method. @@ -504,7 +437,17 @@ def _add_input(name, input_value, onnx_graph, onnx_graph_input_names): name = f"{input_parameter.name}_{var_positional_idx}" var_positional_idx += 1 inp = args[args_i] - _add_input(name, inp, onnx_graph, onnx_graph_input_names) + pre_tensor_idx = arg_tensor_idx[0] + schema = _add_input( + name, + inp, + onnx_graph_input_names, + partial(_arg_access_with_index_func, args_i), + arg_tensor_idx, + ) + num_positional_args += arg_tensor_idx[0] - pre_tensor_idx + if not isinstance(schema, SkipRetValue): + input_arg_schema.append(schema) elif ( input_parameter.kind == inspect.Parameter.POSITIONAL_ONLY or input_parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD @@ -514,25 +457,51 @@ def _add_input(name, input_value, onnx_graph, onnx_graph_input_names): name = input_parameter.name inp = None input_idx += var_positional_idx # noqa: PLW2901 - if input_idx < len(args) and args[input_idx] is not None: + access_func = None + if input_idx < len(args): inp = args[input_idx] - elif name in kwargs and kwargs[name] is not None: + access_func = partial(_arg_access_with_index_func, input_idx) + pre_tensor_idx = arg_tensor_idx[0] + schema = _add_input(name, inp, onnx_graph_input_names, access_func, arg_tensor_idx) + num_positional_args += arg_tensor_idx[0] - pre_tensor_idx + if not isinstance(schema, SkipRetValue): + input_arg_schema.append(schema) + elif name in kwargs: inp = kwargs[name] - _add_input(name, inp, onnx_graph, onnx_graph_input_names) + access_func = partial(_kwarg_access_with_name_func, name) + schema = _add_input(name, inp, onnx_graph_input_names, access_func, kwarg_tensor_idx) + if not isinstance(schema, SkipRetValue): + input_kwarg_schema[name] = schema + elif input_parameter.kind == inspect.Parameter.VAR_KEYWORD: # **kwargs is always the last argument of forward() for name, inp in kwargs.items(): - _add_input(name, inp, onnx_graph, onnx_graph_input_names) - - return _InputInfo( - names=input_names, - shape=input_shape, - require_grad_names=input_names_require_grad, - dynamic_axes=dynamic_axes, - schema=schema, - num_positionals=len(args), + schema = _add_input( + name, + inp, + onnx_graph_input_names, + partial(_kwarg_access_with_name_func, name), + kwarg_tensor_idx, + ) + if not isinstance(schema, SkipRetValue): + input_kwarg_schema[name] = schema + + exported_graph = ModelInfoForExport( + onnx_graph_input_names=onnx_graph_input_names, + onnx_graph_input_names_require_grad=input_names_require_grad, + onnx_graph_input_dynamic_axes_map=dynamic_axes, + onnx_graph_input_shapes=input_shape, + onnx_graph_input_data_accessor_user_defined=data_accessors, + onnx_graph_input_const_as_tensor=const_to_tensor_inputs, + onnx_graph_input_arg_schema=input_arg_schema, + onnx_graph_input_kwarg_schema=input_kwarg_schema, + num_positional_args=num_positional_args, + export_mode=export_mode, + export_extra_kwargs=export_extra_kwargs, ) + return exported_graph + def calculate_total_parameter_size_in_bytes(module: torch.nn.Module) -> int: """Calculate the total parameter size in bytes""" @@ -568,20 +537,19 @@ def can_module_be_deep_cloned(module: torch.nn.Module, device: Optional[torch.de def parse_outputs_for_onnx_export_and_extract_schema( module, - args: Sequence[ORTModelInputOutputType], - kwargs: Mapping[str, ORTModelInputOutputType], + flatten_args: Sequence[ORTModelInputOutputType], logger: Logger, - device: Optional[torch.device], clone_module: bool, ): # Perform a forward call to grab outputs output_names = None output_dynamic_axes = None deep_copied = False + kwargs = {} logger.info("Running model forward to infer output schema and dynamic axes...") with torch.no_grad(): # Deepcopy inputs, since input values may change after model run. - sample_args_copy, sample_kwargs_copy = deepcopy_model_input(*args, **kwargs) + sample_args_copy, sample_kwargs_copy = deepcopy_model_input(*flatten_args, **kwargs) try: if clone_module: # Deepcopy model, in case model is stateful and changes after model run. @@ -600,9 +568,17 @@ def parse_outputs_for_onnx_export_and_extract_schema( sample_outputs = model_copy(*sample_args_copy, **sample_kwargs_copy) # Parse the output and extract the output_names and output_dynamic_axes to be used for onnx export - output_names, output_dynamic_axes = _parse_outputs_and_extract_names_and_dynamic_axes(sample_outputs) + output_names: List[str] = [] + output_dynamic_axes: Dict[str, Dict[int, str]] = {} + for output_idx, output in enumerate(sample_outputs): + output_name = f"output-{output_idx}" + output_names.append(output_name) + output_dynamic_axes[output_name] = {} + for dim_idx in range(len(output.shape)): + output_dynamic_axes[output_name].update({dim_idx: f"{output_name}_dim{dim_idx}"}) + + original_module_output_schema = model_copy._output_schema - output_schema = _extract_schema(sample_outputs, device) if deep_copied: del model_copy gc.collect() @@ -611,4 +587,4 @@ def parse_outputs_for_onnx_export_and_extract_schema( # Release the memory cached by torch. torch.cuda.empty_cache() # Return output names, output dynamic axes and output schema - return output_names, output_dynamic_axes, output_schema + return output_names, output_dynamic_axes, original_module_output_schema diff --git a/orttraining/orttraining/python/training/ortmodule/_logger.py b/orttraining/orttraining/python/training/ortmodule/_logger.py index 91b99d4323d6f..4d54e8e59fb50 100644 --- a/orttraining/orttraining/python/training/ortmodule/_logger.py +++ b/orttraining/orttraining/python/training/ortmodule/_logger.py @@ -165,6 +165,24 @@ def wrapper(graph_execution_manager, *args, **kwargs): return wrapper +class TrackTimeForStaticFunction: + """A function decorator to track time spent in different phases of ORT backend first-time initialization.""" + + def __init__(self, phase: ORTModuleInitPhase): + self.phase = phase + + def __call__(self, func: Callable): + def wrapper(*args, **kwargs): + if "time_tracker" not in kwargs: + raise RuntimeError("The function to be tracked must have a 'time_tracker' kwarg.") + kwargs["time_tracker"].start(self.phase) + result = func(*args, **kwargs) + kwargs["time_tracker"].end(self.phase) + return result + + return wrapper + + @contextmanager def _suppress_os_stream_output(enable=True, on_exit: Optional[Callable] = None): """Suppress output from being printed to stdout and stderr. @@ -255,27 +273,27 @@ def __init__(self, phase: ORTModuleInitPhase, is_ort_filter=True): self.is_ort_filter = is_ort_filter def __call__(self, func: Callable): - def wrapper(graph_execution_manager, *args, **kwargs): - if not hasattr(graph_execution_manager, "_logger"): - raise RuntimeError("The class of the function to be tracked must have a '_logger' attribute.") + def wrapper(*args, **kwargs): + if "logger" not in kwargs: + raise RuntimeError("The function to be tracked must have a 'logger' kwarg.") - if not hasattr(graph_execution_manager, "_debug_options"): - raise RuntimeError("The class of the function to be tracked must have a '_debug_options' attribute.") + if "debug_options" not in kwargs: + raise RuntimeError("The function to be tracked must have a 'debug_options' kwarg.") with _suppress_os_stream_output( - enable=graph_execution_manager._debug_options.log_level >= LogLevel.DEVINFO, + enable=kwargs["debug_options"].log_level >= LogLevel.DEVINFO, on_exit=partial( _log_with_filter, - graph_execution_manager._logger, + kwargs["logger"], ( - graph_execution_manager._debug_options.onnxruntime_log_filter + kwargs["debug_options"].onnxruntime_log_filter if self.is_ort_filter - else graph_execution_manager._debug_options.torch_exporter_filter + else kwargs["debug_options"].torch_exporter_filter ), self.phase.to_string(), ), ): - result = func(graph_execution_manager, *args, **kwargs) + result = func(*args, **kwargs) return result return wrapper diff --git a/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py b/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py index 241b7a5c5344b..93d151ea1217d 100644 --- a/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py +++ b/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py @@ -26,14 +26,6 @@ def get_params_connected_to_pull_param_trigger( return {k: v for k, v in named_params if v.requires_grad and k in onnx_initializer_names} -def get_params_not_connected_to_pull_param_trigger( - named_params: dict[str, torch.nn.parameter.Parameter], exported_model: ModelProto -): - # Be noted, some parameters might not in graph input because they are not used in forward, so we filtered them also. - onnx_initializer_names = {p.name for p in exported_model.graph.input} - return [v for k, v in named_params if not v.requires_grad and k in onnx_initializer_names] - - def post_processing_enable_mem_efficient_training( exported_model: ModelProto, named_params: dict[str, torch.nn.parameter.Parameter], diff --git a/orttraining/orttraining/python/training/ortmodule/_onnx_models.py b/orttraining/orttraining/python/training/ortmodule/_onnx_models.py index a0001a2f201f1..4b6011f0786ec 100644 --- a/orttraining/orttraining/python/training/ortmodule/_onnx_models.py +++ b/orttraining/orttraining/python/training/ortmodule/_onnx_models.py @@ -23,8 +23,7 @@ def _save_model(model: onnx.ModelProto, file_path: str): class ONNXModels: """Encapsulates all ORTModule onnx models. - 1. exported_model: Model that is exported by torch.onnx.export - 2. optimized_model: For eval mode it's exported_model with concrete input shapes set if needed, + 1. optimized_model: For eval mode it's exported_model with concrete input shapes set if needed, for training mode, it's an optimized model after the gradients graph has been built. In addition, ORTModule also saves two other models, to the user-provided path: a. the pre_grad_model which is the model before the gradients graph is built. @@ -32,16 +31,8 @@ class ONNXModels: It has further optimizations done by the InferenceSession and is saved by the InferenceSession. """ - exported_model: Optional[onnx.ModelProto] = None - processed_exported_model: Optional[onnx.ModelProto] = None optimized_model: Optional[onnx.ModelProto] = None - def save_exported_model(self, path, name_prefix, export_mode): - # save the ortmodule exported model - _save_model( - self.exported_model, os.path.join(path, _get_onnx_file_name(name_prefix, "torch_exported", export_mode)) - ) - def save_optimized_model(self, path, name_prefix, export_mode): # save the ortmodule optimized model _save_model( diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index f35e3f74ba60a..3708343a228fc 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -12,12 +12,12 @@ from onnxruntime.capi import _pybind_state as C from onnxruntime.capi.onnxruntime_inference_collection import get_ort_device_type -from . import _are_deterministic_algorithms_enabled, _io, _use_deterministic_algorithms, _utils +from . import _are_deterministic_algorithms_enabled, _use_deterministic_algorithms, _utils from ._execution_agent import TrainingAgent from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPolicy from ._gradient_accumulation_manager import GradientAccumulationManager from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo -from ._io import _FlattenedModule, _InputInfo, unflatten_user_output +from ._io import _FlattenedModule from ._logger import ORTModuleInitPhase, TrackTime from ._runtime_inspector import Phase from ._utils import save_tuning_results, set_tuning_results @@ -247,27 +247,17 @@ def forward(self, *inputs, **kwargs): build_gradient_graph = False if ( self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False - or not self._onnx_models.exported_model + or not self._graph_transition_manager._exported_model_info ): self.time_tracker.start(ORTModuleInitPhase.EndToEnd) - build_gradient_graph = self._export_model(*inputs, **kwargs) + ( + build_gradient_graph, + post_export_processed_model_info, + ) = self._graph_transition_manager.get_post_processed_model(inputs, kwargs) if build_gradient_graph: - # If model was exported, then initialize the graph builder - self._initialize_graph_builder() - - # Since the schema was just extracted while trying to export the model and it was either - # saved to self._input_info.schema or checked for equality with the self._input_info.schema - # it should not need to be updated again. Pass it inside parse_inputs_for_onnx_export. - input_info = _io.parse_inputs_for_onnx_export( - self._module_parameters, self._onnx_models.exported_model, self._input_info.schema, inputs, kwargs - ) - - # Reinitialize graph builder if the inputs or initializers requiring gradient have changed. - # Order of or operation is important here because we always need to call - # _reinitialize_graph_builder irrespective of the value of build_gradient_graph. - build_gradient_graph = self._reinitialize_graph_builder(input_info) or build_gradient_graph + self._initialize_graph_builder(post_export_processed_model_info) # Build the gradient graph if build_gradient_graph: @@ -284,9 +274,7 @@ def forward(self, *inputs, **kwargs): self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False or not self._execution_agent ): - device = _utils.get_device_from_module(self._original_module) or _utils.get_device_from_inputs( - inputs, kwargs - ) + device = _utils.get_device_from_module_and_inputs(self._original_module, inputs, kwargs) create_execution_session = ( build_gradient_graph or self._device != device @@ -294,7 +282,7 @@ def forward(self, *inputs, **kwargs): ) _use_deterministic_algorithms(torch.are_deterministic_algorithms_enabled()) if self._device != device: - self._device = device + self._graph_transition_manager._device = device if create_execution_session: # Create execution session creates the training_session @@ -309,36 +297,16 @@ def forward(self, *inputs, **kwargs): self._gradient_accumulation_manager.maybe_update_cache_before_run() - if self._runtime_options.enable_zero_stage3_support or self._mem_efficient_grad_management_is_enabled: + if self._runtime_options.enable_zero_stage3_support: self._append_pull_weight_trigger_as_input(kwargs, self._device) - param_to_append_as_onnx_graph_inputs = [] - if self._mem_efficient_grad_management_is_enabled: - from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger - - param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger( - self._flattened_module.named_parameters(), self._onnx_models.exported_model - ) - - else: - param_to_append_as_onnx_graph_inputs = self._graph_initializers - - prepared_input_list = _io._combine_input_buffers_initializers( - param_to_append_as_onnx_graph_inputs, - self._graph_info.user_input_names, - self._input_info, - self._flattened_module.named_buffers(), - inputs, - kwargs, - self._device, - self._runtime_inspector, - self._zero_stage3_param_map, + prepared_input_map = self._graph_transition_manager._post_export_processed_model_info.construct_inputs( + inputs, kwargs, True, self._device ) - outputs = unflatten_user_output( - self._module_output_schema, - self._forward_class.apply(*prepared_input_list), - ) + user_outputs = self._forward_class.apply(*prepared_input_map.values()) + + outputs = self._graph_transition_manager._post_export_processed_model_info.restore_outputs(user_outputs) if ( create_execution_session @@ -390,21 +358,15 @@ def _build_graph(self, graph_transformer_config): # Map each input/initializer to its gradient index in the graph output, or -1 is gradient is not required. self._gradient_map = [] - num_user_input_grads = len(self._input_info.require_grad_names) - require_grad_names_set = set(self._input_info.require_grad_names) - require_grad_names_index = 0 - for input_name in self._graph_info.user_input_names: - if input_name in require_grad_names_set: - self._gradient_map.append(require_grad_names_index) - require_grad_names_index += 1 - else: - self._gradient_map.append(-1) - initializer_index = num_user_input_grads - for initializer_name in self._graph_info.initializer_names: - if initializer_name in self._graph_initializer_names_to_train: - self._gradient_map.append(initializer_index) - initializer_index += 1 + index_for_input_requires_grad = 0 + for input_name in self._graph_transition_manager._post_export_processed_model_info.onnx_graph_input_names: + if ( + input_name + in self._graph_transition_manager._post_export_processed_model_info.onnx_graph_input_names_require_grad + ): + self._gradient_map.append(index_for_input_requires_grad) + index_for_input_requires_grad += 1 else: self._gradient_map.append(-1) @@ -414,7 +376,8 @@ def _create_execution_agent(self): session_options, providers, provider_options = self._get_session_config() fw_feed_names = [input.name for input in self._onnx_models.optimized_model.graph.input] - device_type = self._device if type(self._device) is str else self._device.type.lower() # noqa: E721 + device_type = self._device if isinstance(self._device, str) else self._device.type.lower() + if device_type == "ort": fw_outputs_device_info = [C.get_ort_device(self._device.index)] * ( len(self._graph_info.user_output_names) + len(self._graph_info.frontier_node_arg_map) @@ -491,39 +454,6 @@ def _create_execution_agent(self): self._execution_agent._inference_session, True, self._runtime_options.tuning_results_path ) - def _reinitialize_graph_builder(self, input_info: _InputInfo): - """Return true if the module graph builder was reinitialized""" - - # Model may have unused params dropped after export and not part of self._graph_initializer_names_to_train - # To see if any trainable initializers changed, compare self._graph_initializer_names_to_train - # with initializers in module named_parameters that are known to the onnx graph. - initializer_names_to_train_set_user_model = { - name - for name, param in self._flattened_module.named_parameters() - if param.requires_grad and name in self._graph_initializer_names - } - - if self._mem_efficient_grad_management_is_enabled: - from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME - - # Remove the inputs we added during model post-processing. - existing_require_grad_names = [ - n for n in self._input_info.require_grad_names if n != MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME - ] - else: - existing_require_grad_names = self._input_info.require_grad_names - - # If inputs requiring gradient change from forward to the next, the module_gradient_graph_builder - # needs to be reinitialized so it can compute the backward output for the new inputs that require_grad - if ( - input_info.require_grad_names != existing_require_grad_names - or initializer_names_to_train_set_user_model != self._graph_initializer_names_to_train - ): - self._input_info = input_info - self._initialize_graph_builder() - return True - return False - def __getstate__(self): state = super().__getstate__() diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index 5faa1c62bae4f..c299d1c5db4e7 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -153,7 +153,15 @@ def get_device_str(device: Union[str, int, torch.device]) -> str: return device -def get_device_from_module(module) -> Optional[torch.device]: +def get_device_from_module_and_inputs(module, inputs, kwargs): + """Get the device from the module and save it to self._device""" + + device = _get_device_from_module(module) or _get_device_from_inputs(inputs, kwargs) + + return device + + +def _get_device_from_module(module) -> Optional[torch.device]: """Returns the first device found in the `module`'s parameters or None Args: @@ -179,7 +187,7 @@ def get_device_from_module(module) -> Optional[torch.device]: return device -def get_device_from_inputs(args, kwargs) -> Optional[torch.device]: +def _get_device_from_inputs(args, kwargs) -> Optional[torch.device]: """Returns device from first PyTorch Tensor within args or kwargs Args: @@ -192,9 +200,12 @@ def get_device_from_inputs(args, kwargs) -> Optional[torch.device]: device = None if args: - device = torch.device(args[0].device) + if args[0] is not None and hasattr(args[0], "device"): + device = torch.device(args[0].device) elif kwargs: - device = torch.device(next(iter(kwargs.values())).device) + v = next(iter(kwargs.values())) + if v is not None and hasattr(v, "device"): + device = torch.device(v.device) return device diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py index ff110c431d300..11d978e71d8a8 100644 --- a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py +++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py @@ -395,7 +395,7 @@ def _update_python_op_input_related_attributes( @contextmanager -def stage3_export_context(enable: bool, graph_execution_manager): +def stage3_export_context(enable: bool, stage3_param_handle, flattened_module): """Context manager for stage3 specific model export. Some export functions are overridden when entering the context; the original functions are restored when exiting the context. @@ -411,9 +411,7 @@ def stage3_export_context(enable: bool, graph_execution_manager): # Delay collecting stage3 parameters here instead of in the graph execution manager, # to make sure DeepSpeed initialization is done, so that the parameters ds_status are correct. - graph_execution_manager._zero_stage3_param_map = _get_all_zero_stage3_params( - graph_execution_manager._flattened_module - ) + stage3_param_handle._zero_stage3_param_map = _get_all_zero_stage3_params(flattened_module) try: from torch.onnx._internal import _beartype @@ -428,8 +426,8 @@ def _get_tensor_rank(x) -> Optional[int]: from torch.onnx.symbolic_helper import _is_tensor input_name = x.debugName() - if input_name in graph_execution_manager._zero_stage3_param_map: - rank = len(graph_execution_manager._zero_stage3_param_map[input_name].ds_shape) + if input_name in stage3_param_handle._zero_stage3_param_map: + rank = len(stage3_param_handle._zero_stage3_param_map[input_name].ds_shape) return rank if not _is_tensor(x) or x.type() is None: diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index b5c52bdaef3c6..ba6f7c2d0c03a 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -306,7 +306,9 @@ def __setattr__(self, name: str, value) -> None: # Re-export will be avoided if _skip_check is enabled. if isinstance(self._torch_module, TorchModuleORT): for training_mode in [False, True]: - self._torch_module._execution_manager(training_mode).signal_model_changed() + self._torch_module._execution_manager( + training_mode + )._graph_transition_manager.signal_model_changed() else: # Setting any new attributes should be done on ORTModule only when 'torch_module' is not defined diff --git a/orttraining/orttraining/python/training/utils/torch_io_helper.py b/orttraining/orttraining/python/training/utils/torch_io_helper.py index 34cc1ca942a8c..4824ed7137021 100644 --- a/orttraining/orttraining/python/training/utils/torch_io_helper.py +++ b/orttraining/orttraining/python/training/utils/torch_io_helper.py @@ -5,7 +5,7 @@ import copy import warnings -from collections import abc +from collections import OrderedDict, abc from typing import List, Mapping, Optional, Sequence, Tuple, Union import torch @@ -221,8 +221,8 @@ def _flatten_from_data(data: ORTModelInputOutputType, prefix_name: str = ""): return stubbed_schema elif isinstance(data, abc.Mapping): dict_type = type(data) - stubbed_schema = {} - for key, val in sorted(data.items()): + stubbed_schema = OrderedDict() + for key, val in data.items(): stubbed_schema[key] = _flatten_from_data(val, f"{prefix_name}_{key}" if prefix_name else f"{key}") stubbed_schema = dict_type(**stubbed_schema) return stubbed_schema @@ -305,7 +305,7 @@ def _replace_stub_with_tensor_value(data_schema: ORTModelInputOutputSchemaType, return data_schema elif isinstance(data_schema, abc.Mapping): new_user_output = copy.copy(data_schema) - for key, schema_val in sorted(data_schema.items()): + for key, schema_val in data_schema.items(): new_user_output[key] = _replace_stub_with_tensor_value(schema_val, data) data_schema = new_user_output diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index f35bb47f6b41d..541473b1561db 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -30,7 +30,7 @@ import onnxruntime.training.ortmodule as ortmodule_module from onnxruntime.training.optim import AdamWMode, FusedAdam -from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule, _fallback, _io, _utils +from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule, _fallback, _utils from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient from onnxruntime.training.ortmodule.options import _SkipCheck from onnxruntime.training.utils import pytorch_type_to_onnx_dtype @@ -463,9 +463,11 @@ def test_forward_call_single_positional_argument(): N, D_in, H, D_out = 64, 784, 500, 10 # noqa: N806 model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) ort_model = ORTModule(model) + # Check that the original forward signature is preserved. assert inspect.signature(model.forward) == inspect.signature(ort_model.forward) x = torch.randn(N, D_in, device=device) + # Make sure model runs without any exception prediction = ort_model(x) assert prediction is not None @@ -699,7 +701,15 @@ def test_input_requires_grad_saved(device): model = ORTModule(model) x = torch.randn(N, D_in, device=device, requires_grad=True) + 1 model(x) - assert "input1" in model._torch_module._execution_manager(model._is_training())._input_info.require_grad_names + assert model._torch_module._execution_manager( + model._is_training() + )._graph_transition_manager._model_info_for_export.onnx_graph_input_names_require_grad == ["input1"] + assert ( + "input1" + in model._torch_module._execution_manager( + model._is_training() + )._graph_transition_manager._post_export_processed_model_info.onnx_graph_input_names_require_grad + ) @pytest.mark.parametrize("device", ["cuda", "cpu"]) @@ -841,6 +851,7 @@ def forward(self, input): device = "cuda" pt_model = NeuralNetTranspose(perm).to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) def run_step(model, x): @@ -2655,7 +2666,10 @@ def test_exception_raised_for_custom_class_return_value_module(device): # ORT backend with pytest.raises(_fallback.ORTModuleIOError) as runtime_error: ort_model(x, y, z) - assert "ORTModule does not support the following model output type" in str(runtime_error.value) + assert ( + "ORTModule fails to extract schema from data: Unsupported flatten data type: " + in str(runtime_error.value) + ) del os.environ["ORTMODULE_SKIPCHECK_POLICY"] @@ -3482,6 +3496,13 @@ def train_step(model, x): _test_helpers.assert_values_are_close(pt_out, ort_out) +def _repr_schema(ortmodule): + tm = ortmodule._torch_module._execution_manager(ortmodule._is_training())._graph_transition_manager + return repr(tm._exported_model_info.module_forward_args_schema) + repr( + tm._exported_model_info.module_forward_kwargs_schema + ) + + def test_forward_dynamic_args(): os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED" @@ -3506,21 +3527,21 @@ def test_forward_dynamic_args(): for _ in range(10): output = model(*args_size1) assert output is not None - hash_args_size1 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_args_size1 = hash(_repr_schema(model)) assert hash_args_size1 is not None # Decrease number of inputs and train some more for _ in range(10): output = model(*args_size2) assert output is not None - hash_args_size2 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_args_size2 = hash(_repr_schema(model)) assert hash_args_size2 != hash_args_size1 # Increase number of inputs and train some more for _ in range(10): output = model(*args_size3) assert output is not None - hash_args_size3 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_args_size3 = hash(_repr_schema(model)) assert hash_args_size3 != hash_args_size2 del os.environ["ORTMODULE_SKIPCHECK_POLICY"] @@ -3545,35 +3566,35 @@ def test_forward_dynamic_kwargs(): for _ in range(10): output = model(one) assert output is not None - hash_x = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_x = hash(_repr_schema(model)) assert hash_x is not None # Train with x and y as inputs for _ in range(10): output = model(one, y=one) assert output is not None - hash_x_y = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_x_y = hash(_repr_schema(model)) assert hash_x_y != hash_x # Train with x and z as inputs for _ in range(10): output = model(one, z=one) assert output is not None - hash_x_z = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_x_z = hash(_repr_schema(model)) assert hash_x_z != hash_x_y # Train with x, y and z as inputs for _ in range(10): output = model(one, y=one, z=one) assert output is not None - hash_x_y_z = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_x_y_z = hash(_repr_schema(model)) assert hash_x_y_z != hash_x_z # Return to original input with x as input for _ in range(10): output = model(one) assert output is not None - hash_x2 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_x2 = hash(_repr_schema(model)) assert hash_x2 != hash_x_y_z assert hash_x2 == hash_x @@ -4003,10 +4024,14 @@ def forward(self, input1, bool_argument): input1 = torch.randn(N, D_in, device=device) ort_model(input1, bool_arguments[0]) - exported_model1 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_models.exported_model + exported_model1 = ort_model._torch_module._execution_manager( + ort_model._is_training() + )._graph_transition_manager._exported_model_info.exported_model ort_model(input1, bool_arguments[1]) - exported_model2 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_models.exported_model + exported_model2 = ort_model._torch_module._execution_manager( + ort_model._is_training() + )._graph_transition_manager._exported_model_info.exported_model assert exported_model1 != exported_model2 @@ -4172,36 +4197,6 @@ def test_stateless_model_unspecified_device(): _test_helpers.assert_values_are_close(pt_y, ort_y) -@pytest.mark.parametrize( - "model", - [ - (UnusedBeginParameterNet(784, 500, 400, 10)), - (UnusedMiddleParameterNet(784, 500, 400, 10)), - (UnusedEndParameterNet(784, 500, 400, 10)), - ], -) -def test_unused_parameters_does_not_unnecessarily_reinitialize(model): - device = "cuda" - - N, D_in, H1, H2, D_out = 64, 784, 500, 400, 10 # noqa: F841, N806 - model = model.to(device) - ort_model = ORTModule(copy.deepcopy(model)) - training_manager = ort_model._torch_module._execution_manager(ort_model._is_training()) - - x = torch.randn(N, D_in, device=device) - _ = ort_model(x) - - input_info = _io.parse_inputs_for_onnx_export( - training_manager._module_parameters, - training_manager._onnx_models.exported_model, - training_manager._input_info.schema, - x, - {}, - ) - - assert not training_manager._reinitialize_graph_builder(input_info) - - def test_load_state_dict_for_wrapped_ortmodule(): class WrapperModule(torch.nn.Module): def __init__(self, ortmodule): @@ -4264,14 +4259,23 @@ def test_hf_save_pretrained(): assert p1.data.ne(p2.data).sum() == 0 -def test_ortmodule_string_inputs_are_ignored(): +def test_ortmodule_string_inputs_are_ignored(caplog): pt_model = MyStrNet() - target_str = "Received input of type which may be treated as a constant by ORT by default." - with pytest.warns(UserWarning, match=target_str): - ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(log_level=LogLevel.INFO)) - x = torch.randn(1, 2) - out = ort_model(x, "hello") - _test_helpers.assert_values_are_close(out, x + 1) + target_str = "Received input of type is treated as a constant by ORT by default." + + ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(log_level=LogLevel.INFO)) + x = torch.randn(1, 2) + out = ort_model(x, "hello") + _test_helpers.assert_values_are_close(out, x + 1) + + found_log = False + for record in caplog.records: + msg = record.getMessage() + if target_str in msg: + found_log = True + break + + assert found_log, f"Expected to find log message '{target_str}' in the logs, but didn't find it." def test_ortmodule_list_input(): @@ -4831,17 +4835,31 @@ def forward(self, a): ort_model = ORTModule(pt_model) _ = ort_model(torch.randn(N, D_in, device=device)) - exported_model1 = ort_model._torch_module._execution_manager(True)._onnx_models.exported_model + exported_model1 = ort_model._torch_module._execution_manager( + True + )._graph_transition_manager._exported_model_info.exported_model for training_mode in [False, True]: - assert ort_model._torch_module._execution_manager(training_mode)._original_model_has_changed is False + assert ( + ort_model._torch_module._execution_manager( + training_mode + )._graph_transition_manager._original_model_has_changed + is False + ) ort_model.input_flag = False for training_mode in [False, True]: - assert ort_model._torch_module._execution_manager(training_mode)._original_model_has_changed is True + assert ( + ort_model._torch_module._execution_manager( + training_mode + )._graph_transition_manager._original_model_has_changed + is True + ) _ = ort_model(torch.randn(N, D_in, device=device)) - exported_model2 = ort_model._torch_module._execution_manager(True)._onnx_models.exported_model + exported_model2 = ort_model._torch_module._execution_manager( + True + )._graph_transition_manager._exported_model_info.exported_model assert exported_model1 != exported_model2 @@ -4999,7 +5017,9 @@ def test_override_pytorch_exporter_kwargs(): model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) ort_model = ORTModule(model) - ort_model._torch_module._execution_manager(True)._export_extra_kwargs = {"custom_opsets": None} + ort_model._torch_module._execution_manager(True)._graph_transition_manager._export_extra_kwargs = { + "custom_opsets": None + } # Make sure model runs without any exception prediction = ort_model(x) @@ -5016,7 +5036,7 @@ def test_override_pytorch_exporter_kwargs__invalid(): model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) ort_model = ORTModule(model) - ort_model._torch_module._execution_manager(True)._export_extra_kwargs = {"verbose": False} + ort_model._torch_module._execution_manager(True)._graph_transition_manager._export_extra_kwargs = {"verbose": False} with pytest.raises(_fallback.ORTModuleONNXModelException) as type_error: _ = ort_model(x) assert "The following PyTorch exporter arguments cannot be specified: '{'verbose'}'." in str(type_error.value) @@ -5029,7 +5049,9 @@ class ORTModuleExtension(ORTModule): def __init__(self, module, debug_options=None): super().__init__(module, debug_options) for training_mode in [False, True]: - self._torch_module._execution_manager(training_mode)._export_extra_kwargs = {"verbose": None} + self._torch_module._execution_manager(training_mode)._graph_transition_manager._export_extra_kwargs = { + "verbose": None + } N, D_in, H, D_out = 64, 784, 500, 10 # noqa: N806 x = torch.randn(N, D_in, device=device) @@ -5049,7 +5071,9 @@ def __init__(self, module, debug_options=None): super().__init__(module, debug_options) # modify GraphExecutionManager internally for training_mode in [False, True]: - self._torch_module._execution_manager(training_mode)._export_extra_kwargs = {"custom_opsets": None} + self._torch_module._execution_manager(training_mode)._graph_transition_manager._export_extra_kwargs = { + "custom_opsets": None + } N, D_in, H, D_out = 64, 784, 500, 10 # noqa: N806 x = torch.randn(N, D_in, device=device) @@ -5280,11 +5304,12 @@ def run_step(model, x): ort_prediction, ort_loss = run_step(ort_model, ort_x) pt_prediction, pt_loss = run_step(pt_model, pt_x) if step == 0: - model_onx = ort_model._torch_module._execution_manager._training_manager._onnx_models - for name in ["exported_model", "optimized_model"]: - onx = getattr(model_onx, name) + for onnx_model in [ + ort_model._torch_module._execution_manager._training_manager._graph_transition_manager._exported_model_info.exported_model, + ort_model._torch_module._execution_manager._training_manager._onnx_models.optimized_model, + ]: opv = None - for op in onx.opset_import: + for op in onnx_model.opset_import: if not op.domain: opv = op.version assert opv == 13 @@ -5323,7 +5348,9 @@ def test_opset_version_change(opset_version): prediction.backward() # Check opset version on ONNX model - exported_model = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_models.exported_model + exported_model = ort_model._torch_module._execution_manager( + ort_model._is_training() + )._graph_transition_manager._exported_model_info.exported_model assert exported_model.opset_import[0].version == opset_version if original_env is not None: @@ -5334,6 +5361,7 @@ def test_serialize_ortmodule(): device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 # noqa: N806 pt_model = SerializationNet(D_in, H, D_out).to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) x_1 = torch.randn(N, D_in, device=device) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py index 99c15034cdafe..95012aa0507a5 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py @@ -1415,10 +1415,12 @@ def check_pythonop_training_mode(model, is_eval_mode): ## make sure the ort's PythonOp's training_mode is correct if is_eval_mode: onnx_nodes = ( - model._torch_module._execution_manager._inference_manager._onnx_models.exported_model.graph.node + model._torch_module._execution_manager._inference_manager._graph_transition_manager._exported_model_info.exported_model.graph.node ) else: - onnx_nodes = model._torch_module._execution_manager._training_manager._onnx_models.exported_model.graph.node + onnx_nodes = ( + model._torch_module._execution_manager._training_manager._graph_transition_manager._exported_model_info.exported_model.graph.node + ) found_pythonop = False for node in onnx_nodes: @@ -1837,7 +1839,9 @@ def forward(self, model_input): ortmodule = ORTModule(TestModel(output_size)).train() _ = ortmodule(torch.randn(output_size, dtype=torch.float)) - onnx_nodes = ortmodule._torch_module._execution_manager._training_manager._onnx_models.exported_model.graph.node + onnx_nodes = ( + ortmodule._torch_module._execution_manager._training_manager._graph_transition_manager._exported_model_info.exported_model.graph.node + ) found_pythonop = False for node in onnx_nodes: diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py index 34453c89157a8..4e0fcafecffe5 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py @@ -49,6 +49,7 @@ def test_ortmodule_fallback_forward(is_training, fallback_enabled, matching_poli class Point: x: int y: int + device: str = "cpu" # Otherwise, no device can be found from inputs, and the test will fail earlier. class UnsupportedInputModel(torch.nn.Module): def __init__(self): @@ -78,11 +79,17 @@ def forward(self, point): else: with pytest.raises(_fallback.ORTModuleFallbackException) as type_error: ort_model(inputs) - assert "ORTModule fails to extract schema from data" in str(type_error.value) + assert ( + "ORTModule does not support input type .Point'> for input point" + in str(type_error.value) + ) else: with pytest.raises(_fallback.ORTModuleFallbackException) as type_error: ort_model(inputs) - assert "ORTModule fails to extract schema from data" in str(type_error.value) + assert ( + "ORTModule does not support input type .Point'> for input point" + in str(type_error.value) + ) @pytest.mark.parametrize( @@ -250,11 +257,17 @@ def test_ortmodule_fallback_output(is_training, fallback_enabled, matching_polic else: with pytest.raises(_fallback.ORTModuleIOError) as runtime_error: ort_model(x, y, z) - assert "ORTModule does not support the following model output type" in str(runtime_error.value) + assert ( + "ORTModule fails to extract schema from data: Unsupported flatten data type: " + in str(runtime_error.value) + ) else: with pytest.raises(_fallback.ORTModuleIOError) as runtime_error: ort_model(x, y, z) - assert "ORTModule does not support the following model output type" in str(runtime_error.value) + assert ( + "ORTModule fails to extract schema from data: Unsupported flatten data type: " + in str(runtime_error.value) + ) @pytest.mark.parametrize( @@ -302,20 +315,18 @@ def __init__(self, x): with pytest.raises(_fallback.ORTModuleIOError) as ex_info: _ = ort_model(torch.randn(1, 2), CustomClass(1)) assert ( - "ORTModule fails to extract schema from data: " - "Unsupported flatten data type: " - ".CustomClass'>" in str(ex_info.value) + "ORTModule does not support input type " + ".CustomClass'> " + "for input custom_class_obj" in str(ex_info.value) ) else: with pytest.raises(_fallback.ORTModuleIOError) as ex_info: _ = ort_model(torch.randn(1, 2), CustomClass(1)) - assert ( - "ORTModule fails to extract schema from data: " - "Unsupported flatten data type: " - ".CustomClass'>" in str(ex_info.value) - ) + assert ( + "ORTModule does not support input type " + ".CustomClass'> " + "for input custom_class_obj" in str(ex_info.value) + ) @pytest.mark.parametrize( diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py index 35c5b736bd962..e1def2022d63f 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py @@ -69,7 +69,9 @@ def run_step(model, x): self.assert_values_are_close(ort_prediction, pt_prediction, **kwargs) self.assert_gradients_match_and_reset_gradient(ort_model, pt_model, **kwargs) - onnx_graph_inf = ort_model._torch_module._execution_manager._training_manager._onnx_models.exported_model + onnx_graph_inf = ( + ort_model._torch_module._execution_manager._training_manager._graph_transition_manager._exported_model_info.exported_model + ) onnx_graph_train = ort_model._torch_module._execution_manager._training_manager._onnx_models.optimized_model if debug: with open("debug_%s_ortmodule_infer.onnx" % name, "wb") as f: From f39ee14b46f7977e43c6cae3b8e8c56b968a7d98 Mon Sep 17 00:00:00 2001 From: cloudhan Date: Wed, 3 Jul 2024 14:55:31 +0800 Subject: [PATCH 06/13] Add GQA support for ROCm (#21032) --- cmake/CMakeLists.txt | 2 +- cmake/onnxruntime_rocm_hipify.cmake | 1 - .../contrib_ops/cuda/bert/attention_impl.h | 7 + .../cuda/bert/attention_strided_copy.cu | 63 ++- .../cuda/bert/group_query_attention_impl.cu | 6 +- .../cuda/bert/rotary_embedding_impl.cu | 59 +- .../cuda/bert/rotary_embedding_impl.h | 20 + .../contrib_ops/rocm/bert/attention_impl.h | 7 + .../rocm/bert/group_query_attention.cu | 526 ++++++++++++++++++ .../rocm/bert/group_query_attention.h | 38 ++ .../contrib_ops/rocm/rocm_contrib_kernels.cc | 4 + .../transformers/test_flash_attn_rocm.py | 86 +++ .../orttraining-pai-ci-pipeline.yml | 27 + .../pai/rocm-ci-pipeline-env.Dockerfile | 6 +- 14 files changed, 810 insertions(+), 42 deletions(-) create mode 100644 onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu create mode 100644 onnxruntime/contrib_ops/rocm/bert/group_query_attention.h create mode 100644 onnxruntime/test/python/transformers/test_flash_attn_rocm.py diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index c4412e0934f17..a9b0dfb30cc4e 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -241,7 +241,7 @@ option(onnxruntime_ENABLE_TRITON "Enable Triton" OFF) # composable kernel is managed automatically, unless user want to explicitly disable it, it should not be manually set option(onnxruntime_USE_COMPOSABLE_KERNEL "Enable composable kernel for ROCm EP" ON) -option(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE "Enable ck_tile for composable kernel" ON) +cmake_dependent_option(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE "Enable ck_tile for composable kernel" ON "onnxruntime_USE_COMPOSABLE_KERNEL" OFF) option(onnxruntime_USE_ROCBLAS_EXTENSION_API "Enable rocblas tuning for ROCm EP" OFF) option(onnxruntime_USE_TRITON_KERNEL "Enable triton compiled kernel" OFF) option(onnxruntime_BUILD_KERNEL_EXPLORER "Build Kernel Explorer for testing and profiling GPU kernels" OFF) diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 2be68146b5e94..2966a4624a966 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -88,7 +88,6 @@ set(contrib_ops_excluded_files "cuda_contrib_kernels.h" "inverse.cc" "fused_conv.cc" - "bert/group_query_attention_helper.h" "bert/group_query_attention.h" "bert/group_query_attention.cc" "bert/group_query_attention_impl.h" diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index 36fd7708de04b..fda7ac2784129 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -176,6 +176,13 @@ Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, const T* qkv_buffer, T* present); +template +Status LaunchStridedCopy( + cudaStream_t stream, + const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) + T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) + int max_threads_per_block); + template Status LaunchStridedCopy(cudaStream_t stream, const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu b/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu index 1466f5fcfe0be..66e56e701c558 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu @@ -12,23 +12,27 @@ namespace cuda { template __global__ void StridedCopy(const T* in, const int H, longlong4 in_strides, // coord (b,n,s,h) - T* out, longlong4 out_strides // coord (b,n,s,h) -) { + T* out, longlong4 out_strides, // coord (b,n,s,h) + const int32_t* in_seqlens_offset, const int32_t* out_seqlens_offset) { const int h = threadIdx.x; const int n = threadIdx.y; const int s = blockIdx.x; const int b = blockIdx.y; + + const int s_offset_i = in_seqlens_offset == nullptr ? 0 : in_seqlens_offset[b]; + const int s_offset_o = out_seqlens_offset == nullptr ? 0 : out_seqlens_offset[b]; + if (h < H) { - const int in_offset = b * in_strides.x + n * in_strides.y + s * in_strides.z + h * in_strides.w; - const int out_offset = b * out_strides.x + n * out_strides.y + s * out_strides.z + h * out_strides.w; + const int in_offset = b * in_strides.x + n * in_strides.y + (s + s_offset_i) * in_strides.z + h * in_strides.w; + const int out_offset = b * out_strides.x + n * out_strides.y + (s + s_offset_o) * out_strides.z + h * out_strides.w; out[out_offset] = in[in_offset]; } } template __global__ void StridedCopyLarge(const T* in, const int H, longlong4 in_strides, // coord (b,n,s,h) - T* out, longlong4 out_strides // coord (b,n,s,h) -) { + T* out, longlong4 out_strides, // coord (b,n,s,h) + const int* in_seqlens_offset, const int* out_seqlens_offset) { // Use when (H*)*num_heads > 1024 int h = threadIdx.x; const int n = threadIdx.y; @@ -37,9 +41,12 @@ __global__ void StridedCopyLarge(const T* in, const int H, longlong4 in_strides, const int h_step = blockDim.x; + const int s_offset_i = in_seqlens_offset == nullptr ? 0 : in_seqlens_offset[b]; + const int s_offset_o = out_seqlens_offset == nullptr ? 0 : out_seqlens_offset[b]; + while (h < H) { - const int in_offset = b * in_strides.x + n * in_strides.y + s * in_strides.z + h * in_strides.w; - const int out_offset = b * out_strides.x + n * out_strides.y + s * out_strides.z + h * out_strides.w; + const int in_offset = b * in_strides.x + n * in_strides.y + (s + s_offset_i) * in_strides.z + h * in_strides.w; + const int out_offset = b * out_strides.x + n * out_strides.y + (s + s_offset_o) * out_strides.z + h * out_strides.w; out[out_offset] = in[in_offset]; h += h_step; } @@ -77,10 +84,11 @@ template using ToBytes = typename ToByteType::T; template -Status LaunchStridedCopy(cudaStream_t stream, - const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) - T* out, longlong4 out_strides, // coord (b,n,s,h) - int max_threads_per_block) { +Status LaunchStridedCopy( + cudaStream_t stream, + const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) + T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) + int max_threads_per_block) { int batch_size = in_shape.x; int num_heads = in_shape.y; int sequence_length = in_shape.z; @@ -102,11 +110,13 @@ Status LaunchStridedCopy(cudaStream_t stream, if (H * num_heads <= max_threads_per_block) { const dim3 block(H, num_heads, 1); StridedCopy<<>>(reinterpret_cast(in), H, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } else { const dim3 block(max_threads_per_block / num_heads, num_heads, 1); StridedCopyLarge<<>>(reinterpret_cast(in), H, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } } else if (0 == (head_size % 2)) { // pack 2 element together using Bytes = ToBytes; @@ -120,27 +130,44 @@ Status LaunchStridedCopy(cudaStream_t stream, if (H * num_heads <= max_threads_per_block) { const dim3 block(H, num_heads, 1); StridedCopy<<>>(reinterpret_cast(in), H, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } else { const dim3 block(max_threads_per_block / num_heads, num_heads, 1); StridedCopyLarge<<>>(reinterpret_cast(in), H, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } } else { using Bytes = ToBytes; if (head_size * num_heads <= max_threads_per_block) { const dim3 block(head_size, num_heads, 1); StridedCopy<<>>(reinterpret_cast(in), head_size, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } else { const dim3 block(max_threads_per_block / num_heads, num_heads, 1); StridedCopyLarge<<>>(reinterpret_cast(in), head_size, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } } return CUDA_CALL(cudaGetLastError()); } +template +Status LaunchStridedCopy(cudaStream_t stream, + const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) + T* out, longlong4 out_strides, // coord (b,n,s,h) + int max_threads_per_block) { + const int* in_seqlens_offset = nullptr; + const int* out_seqlens_offset = nullptr; + return LaunchStridedCopy( + stream, in, in_shape, in_strides, in_seqlens_offset, + out, out_strides, out_seqlens_offset, + max_threads_per_block); +} + template Status LaunchStridedCopy( cudaStream_t stream, const float* in, int4 in_shape, longlong4 in_strides, diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index c38929697f3cb..3099b52cce13e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -577,7 +577,7 @@ Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unp } // Kernel to convert seqlens_k to position_ids -__global__ void SeqlensToPosIdsPrompt(int32_t* seqlens_k, int64_t* position_ids, const int seqlen, +__global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* position_ids, const int seqlen, const int batch_size) { int tid = blockDim.x * blockIdx.x + threadIdx.x; int b = tid / seqlen; @@ -592,7 +592,7 @@ __global__ void SeqlensToPosIdsPrompt(int32_t* seqlens_k, int64_t* position_ids, } // Kernel to convert seqlens_k to position_ids -__global__ void SeqlensToPosIdsToken(int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { +__global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { int tid = blockDim.x * blockIdx.x + threadIdx.x; if (tid < batch_size) { position_ids[tid] = seqlens_k[tid]; @@ -600,7 +600,7 @@ __global__ void SeqlensToPosIdsToken(int32_t* seqlens_k, int64_t* position_ids, } // Convert seqlens_k to position_ids -Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, +Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int32_t* seqlens_k, int64_t* position_ids, cudaStream_t stream, const int max_threads_per_block) { const int seqlen = parameters.sequence_length; const int batch_size = parameters.batch_size; diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index 1b28b288f3d7c..ad0a83c9cde65 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -25,8 +25,9 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const int64_t* position_ids, // (1) or BxS const int sequence_length, const int num_heads, const int head_size, const int rotary_embedding_dim, const int position_ids_format, - const bool interleaved, const int batch_stride, const int seq_stride, - const int head_stride) { + const bool interleaved, + int4 in_strides, int4 out_strides // strides in bnsh coord, h is always contiguous +) { // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length // Use .x in innermost loop to access global memory efficiently @@ -40,10 +41,8 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH return; } - const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; - - const T* input_data = input + block_offset; - T* output_data = output + block_offset; + const T* input_data = input + b * in_strides.x + s * in_strides.z + n * in_strides.y; + T* output_data = output + b * out_strides.x + s * out_strides.z + n * out_strides.y; if (i >= rotary_embedding_dim) { output_data[i] = input_data[i]; @@ -77,34 +76,58 @@ template Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids, const T* cos_cache, const T* sin_cache, const int batch_size, const int sequence_length, const int num_heads, const int head_size, - const int rotary_embedding_dim, const int /*max_sequence_length*/, + const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, const int max_threads_per_block, const bool is_input_bnsh_format) { + int4 in_strides; + int4 out_strides; + if (is_input_bnsh_format) { + int in_head_stride = sequence_length * head_size; + int out_head_stride = sequence_length * head_size; + in_strides = int4{num_heads * in_head_stride, in_head_stride, in_head_stride / sequence_length, 1}; + out_strides = int4{num_heads * out_head_stride, out_head_stride, out_head_stride / sequence_length, 1}; + } else { + int in_head_stride = head_size; + int out_head_stride = head_size; + in_strides = int4{sequence_length * num_heads * in_head_stride, in_head_stride, num_heads * in_head_stride, 1}; + out_strides = int4{sequence_length * num_heads * out_head_stride, out_head_stride, num_heads * out_head_stride, 1}; + } + return LaunchRotaryEmbeddingKernel( + stream, output, input, position_ids, + cos_cache, sin_cache, batch_size, + sequence_length, num_heads, head_size, + rotary_embedding_dim, max_sequence_length, + position_ids_format, interleaved, + max_threads_per_block, + in_strides, out_strides); +} + +template +Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids, + const T* cos_cache, const T* sin_cache, const int batch_size, + const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int /*max_sequence_length*/, + const int position_ids_format, const bool interleaved, + const int max_threads_per_block, + int4 in_strides, int4 out_strides // strides in bnsh coord +) { // Note: Current implementation assumes head_size <= max_threads_per_block // because head_size is currently large for LLaMA-2. For smaller head_size // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` // instead. This will require kernel changes to support. ORT_ENFORCE(head_size <= max_threads_per_block, "Rotary embedding dim must be <= max_threads_per_block"); + // strides in canonical bnsh coord, h is always contiguous (dim_stride == 1) + ORT_ENFORCE(in_strides.w == 1 && out_strides.w == 1, "head dim must contiguous"); int tpb = (head_size + 31) / 32 * 32; const dim3 block(tpb); const dim3 grid(sequence_length, batch_size, num_heads); - // Default input tensor shape is [batch, seq, hidden_size] - int head_stride = head_size; - int seq_stride = num_heads * head_stride; - int batch_stride = sequence_length * seq_stride; - if (is_input_bnsh_format) { - seq_stride = head_size; - head_stride = sequence_length * seq_stride; - batch_stride = num_heads * head_stride; - } - assert(head_size <= max_threads_per_block); RotaryEmbeddingBSNH<<>>(output, input, cos_cache, sin_cache, position_ids, sequence_length, num_heads, head_size, rotary_embedding_dim, position_ids_format, - interleaved, batch_stride, seq_stride, head_stride); + interleaved, in_strides, out_strides); return CUDA_CALL(cudaGetLastError()); } diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h index 6053814b835bb..dd0ac6a6e3274 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h @@ -28,6 +28,26 @@ Status LaunchRotaryEmbeddingKernel( const int max_threads_per_block, const bool is_input_bnsh_format); +template +Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + T* output, + const T* input, + const int64_t* position_ids, + const T* cos_cache, + const T* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int rotary_embedding_dim, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block, + int4 in_strides, + int4 out_strides); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 3164e8c211099..349df045becf2 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -169,6 +169,13 @@ Status ClassifyAttentionMode(AttentionType type, const std::vector& past, const std::vector& present); +template +Status LaunchStridedCopy( + hipStream_t stream, + const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) + T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) + int max_threads_per_block); + template Status LaunchStridedCopy(hipStream_t stream, const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu new file mode 100644 index 0000000000000..92c780d4a9d41 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -0,0 +1,526 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/shared_library/provider_api.h" +#include "core/providers/rocm/rocm_common.h" +#include "core/platform/env_var_utils.h" +#include "contrib_ops/rocm/bert/group_query_attention.h" +#include "contrib_ops/rocm/bert/group_query_attention_helper.h" +#include "contrib_ops/rocm/bert/rotary_embedding_impl.h" +#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" + +#ifdef USE_COMPOSABLE_KERNEL_CK_TILE +#include "ck_tile/core/numeric/integer.hpp" +#include "fmha_fwd.hpp" +#endif + +using namespace onnxruntime::rocm; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GroupQueryAttention, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", DataTypeImpl::GetTensorType()) \ + .MayInplace(3, 1) \ + .MayInplace(4, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 6), \ + GroupQueryAttention); + +// REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) +// REGISTER_KERNEL_TYPED(BFloat16) + +template +std::string GetCkFmhaDataTypeString(); + +template <> +std::string GetCkFmhaDataTypeString() { + return "fp16"; +} + +template <> +std::string GetCkFmhaDataTypeString() { + return "bf16"; +} + +__global__ void seqlens_inc_kernel(const int* seqlens, int* out, int num_elems, int inc) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx < num_elems) { + out[idx] = seqlens[idx] + inc; + } +} + +Status LaunchSeqlensInc(hipStream_t stream, const int* seqlens, int* out, int num_elems, int inc) { + constexpr int NumThreads = 128; + int num_blks = CeilDiv(num_elems, NumThreads); + seqlens_inc_kernel<<>>(seqlens, out, num_elems, inc); + return HIP_CALL(hipGetLastError()); +} + +__global__ void seqstart_init_kernel(int* out, int num_elems, int length_per_seq) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx < num_elems) { + out[idx] = idx * length_per_seq; + } + if (idx == 0) { + out[num_elems] = num_elems * length_per_seq; + } +} + +Status LaunchSeqStartInit(hipStream_t stream, int* out, int num_elems, int length_per_seq) { + constexpr int NumThreads = 128; + int num_blks = CeilDiv(num_elems, NumThreads); + seqstart_init_kernel<<>>(out, num_elems, length_per_seq); + return HIP_CALL(hipGetLastError()); +} + +// Kernel to convert seqlens_k to position_ids +__global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* position_ids, const int seqlen, + const int batch_size) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + int b = tid / seqlen; + int s = tid % seqlen; + if (b < batch_size) { + if (s < seqlens_k[b] + 1) { + position_ids[tid] = s; + } else { + position_ids[tid] = 1; + } + } +} + +// Kernel to convert seqlens_k to position_ids +__global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + if (tid < batch_size) { + position_ids[tid] = seqlens_k[tid]; + } +} + +// Convert seqlens_k to position_ids +Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int32_t* seqlens_k, + int64_t* position_ids, hipStream_t stream, const int max_threads_per_block) { + const int seqlen = parameters.sequence_length; + const int batch_size = parameters.batch_size; + const int threads = max_threads_per_block; + const int blocks = (batch_size * seqlen + threads - 1) / threads; + if (parameters.is_prompt) { + SeqlensToPosIdsPrompt<<>>(seqlens_k, position_ids, seqlen, batch_size); + } else { + SeqlensToPosIdsToken<<>>(seqlens_k, position_ids, batch_size); + } + return HIP_CALL(hipGetLastError()); +} + +template +GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) + : RocmKernel(info) { + int64_t num_heads = 0; + int64_t kv_num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); + num_heads_ = static_cast(num_heads); + kv_num_heads_ = static_cast(kv_num_heads); + is_past_bsnh_ = false; + is_unidirectional_ = true; + local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; + scale_ = info.GetAttrOrDefault("scale", 0.0f); +} + +template <> +std::once_flag GroupQueryAttention::arch_checking_{}; + +template <> +std::once_flag GroupQueryAttention::arch_checking_{}; + +template +Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { +#if USE_COMPOSABLE_KERNEL_CK_TILE + auto hip_stream = static_cast(ctx->GetComputeStream()->GetHandle()); + const Tensor* query = ctx->Input(0); + const Tensor* key = ctx->Input(1); + const Tensor* value = ctx->Input(2); + const Tensor* past_key = ctx->Input(3); + const Tensor* past_value = ctx->Input(4); + const Tensor* seqlens_k = ctx->Input(5); + const Tensor* total_seqlen = ctx->Input(6); + const Tensor* cos_cache = ctx->Input(7); + const Tensor* sin_cache = ctx->Input(8); + + auto& device_prop = GetDeviceProp(); + std::call_once( + arch_checking_, + [](const hipDeviceProp_t& device_prop) { + if (std::string_view(device_prop.gcnArchName).find("gfx90a") == std::string_view::npos && + std::string_view(device_prop.gcnArchName).find("gfx942") == std::string_view::npos) { + LOGS_DEFAULT(WARNING) + << "GroupQueryAttention currently only supports ck_tile fmha backend which only supports " + << "CDNA2 and CDNA3 archs."; + LOGS_DEFAULT(WARNING) + << "GroupQueryAttention running on an unsuppoted GPU may result in " + << "hipErrorNoBinaryForGpu or hipErrorSharedObjectInitFailedshared error."; + } + }, + device_prop); + + GroupQueryAttentionParameters parameters; + using HipT = typename ToHipType::MappedType; + + const int max_thr_per_blk = device_prop.maxThreadsPerBlock; + + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, + key, + value, + past_key, + past_value, + cos_cache, + sin_cache, + ¶meters, + num_heads_, + kv_num_heads_, + seqlens_k, + total_seqlen, + is_past_bsnh_, + scale_, + max_thr_per_blk)); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + + parameters.local_window_size = local_window_size_; + parameters.is_unidirectional = is_unidirectional_; + // parameters.zeros_count = kZerosCount; + // parameters.zero_ptr = zeros_.get(); + // parameters.left_padding = left_padding_; + parameters.do_rotary = do_rotary_; + parameters.rotary_interleaved = rotary_interleaved_; + + if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1"); + } + + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(batch_size); + output_shape[1] = static_cast(sequence_length); + output_shape[2] = static_cast(parameters.hidden_size); + Tensor* output = ctx->Output(0, output_shape); + Strides output_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); + + int4 past_shape; + std::vector present_dims; + Strides present_strides; + Strides past_strides; + if (past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { + past_shape = { + batch_size, parameters.seqlen_past_kv_cache, kv_num_heads, head_size}; + past_strides = Strides::BSNHMemory( + batch_size, parameters.seqlen_past_kv_cache, kv_num_heads, head_size); + present_dims = { + batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size}; + present_strides = Strides::BSNHMemory( + batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size); + } else { // BNSH + past_shape = { + batch_size, kv_num_heads, parameters.seqlen_past_kv_cache, head_size}; + past_strides = Strides::BNSHMemory( + batch_size, kv_num_heads, parameters.seqlen_past_kv_cache, head_size); + present_dims = { + batch_size, kv_num_heads, parameters.seqlen_present_kv_cache, head_size}; + present_strides = Strides::BNSHMemory( + batch_size, kv_num_heads, parameters.seqlen_present_kv_cache, head_size); + } + TensorShape present_shape(present_dims); + Tensor* present_key = ctx->Output(1, present_shape); + Tensor* present_value = ctx->Output(2, present_shape); + + Strides query_strides; + Strides key_strides; + Strides value_strides; + int4 kv_shape{batch_size, kv_num_heads, kv_sequence_length, head_size}; // BNSH coord + const HipT* query_ptr = reinterpret_cast(query->DataRaw()); + const HipT* key_ptr; + const HipT* value_ptr; + if (!parameters.is_packed_qkv) { + query_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); + key_strides = Strides::BSNHMemory(batch_size, kv_sequence_length, kv_num_heads, head_size); + value_strides = key_strides; + key_ptr = reinterpret_cast(key->DataRaw()); + value_ptr = reinterpret_cast(value->DataRaw()); + } else { + query_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size); + key_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size); + value_strides = query_strides; + const size_t key_offset = static_cast(num_heads * head_size); + const size_t value_offset = static_cast(kv_num_heads * head_size); + key_ptr = query_ptr + key_offset; + value_ptr = key_ptr + value_offset; + } + + IAllocatorUniquePtr rotary_q_tmp; + IAllocatorUniquePtr rotary_k_tmp; + if (parameters.do_rotary) { + size_t q_size = static_cast(batch_size * sequence_length * num_heads * head_size); + size_t k_size = static_cast(batch_size * sequence_length * kv_num_heads * head_size); + auto rotary_q_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); + auto rotary_k_strides = Strides::BSNHMemory(batch_size, sequence_length, kv_num_heads, head_size); + + rotary_q_tmp = GetScratchBuffer(q_size, ctx->GetComputeStream()); + rotary_k_tmp = GetScratchBuffer(k_size, ctx->GetComputeStream()); + auto rotary_position_ids_tmp = GetScratchBuffer(sequence_length * batch_size, ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(LaunchSeqlensToPosIds(parameters, + reinterpret_cast(seqlens_k->DataRaw()), + reinterpret_cast(rotary_position_ids_tmp.get()), + hip_stream, max_thr_per_blk)); + // Launch rotary embedding kernel + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(hip_stream, rotary_q_tmp.get(), query_ptr, + reinterpret_cast(rotary_position_ids_tmp.get()), + reinterpret_cast(cos_cache->DataRaw()), + reinterpret_cast(sin_cache->DataRaw()), + parameters.batch_size, parameters.sequence_length, + parameters.num_heads, parameters.head_size, + parameters.rotary_dim, parameters.seqlen_present_kv_cache, + /*position_ids_format*/ 1, parameters.rotary_interleaved, + max_thr_per_blk, + query_strides.ForBNSHCoord(), + rotary_q_strides.ForBNSHCoord())); + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(hip_stream, rotary_k_tmp.get(), key_ptr, + reinterpret_cast(rotary_position_ids_tmp.get()), + reinterpret_cast(cos_cache->DataRaw()), + reinterpret_cast(sin_cache->DataRaw()), + parameters.batch_size, parameters.sequence_length, + parameters.kv_num_heads, parameters.head_size, + parameters.rotary_dim, parameters.seqlen_present_kv_cache, + /*position_ids_format*/ 1, parameters.rotary_interleaved, + max_thr_per_blk, + key_strides.ForBNSHCoord(), + rotary_k_strides.ForBNSHCoord())); + query_ptr = reinterpret_cast(rotary_q_tmp.get()); + key_ptr = reinterpret_cast(rotary_k_tmp.get()); + query_strides = rotary_q_strides; + key_strides = rotary_k_strides; + } + + const int* seqlens_k_ptr = seqlens_k ? reinterpret_cast(seqlens_k->DataRaw()) : nullptr; + IAllocatorUniquePtr seqlens_k_tmp; + + // build present kv cache + auto* present_key_ptr = reinterpret_cast(present_key->MutableDataRaw()); + auto* present_value_ptr = reinterpret_cast(present_value->MutableDataRaw()); + if (parameters.is_prompt) { + // copy prompt kv to present kv + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), + present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, value_ptr, kv_shape, value_strides.ForBNSHCoord(), + present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + } else { + const auto* past_key_ptr = past_key == nullptr ? nullptr : reinterpret_cast(past_key->DataRaw()); + const auto* past_value_ptr = past_key == nullptr ? nullptr : reinterpret_cast(past_value->DataRaw()); + parameters.kv_share_buffer = past_key_ptr == present_key_ptr; // FIXME: + if (!parameters.kv_share_buffer) { + // copy past to present, + // NOTE: we do a low perf full buffer copy due to the seqlens_k indicate the seqlen of different seqs are + // not the same, aka, can not be as simple as strided + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_key_ptr, past_shape, past_strides.ForBNSHCoord(), + present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_value_ptr, past_shape, past_strides.ForBNSHCoord(), + present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + } else { + // In the case of share buffer + ORT_ENFORCE(past_key_ptr == nullptr || past_key_ptr == present_key_ptr); + ORT_ENFORCE(past_key_ptr == nullptr || past_value_ptr == present_value_ptr); + } + // then append new kv to present + size_t buffer_offset = seqlens_k ? 0 : present_strides.OffsetAt(0, 0, kv_sequence_length, 0); + ORT_RETURN_IF_ERROR(LaunchStridedCopy( + hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr, + present_key_ptr + buffer_offset, present_strides.ForBNSHCoord(), seqlens_k_ptr, + max_thr_per_blk)); + ORT_RETURN_IF_ERROR(LaunchStridedCopy( + hip_stream, value_ptr, kv_shape, value_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr, + present_value_ptr + buffer_offset, present_strides.ForBNSHCoord(), seqlens_k_ptr, + max_thr_per_blk)); + + // NOTE: ORT: seqlens_k Indicates past sequence lengths for token generation case. + // we should call fmha with total sequence lenghts + seqlens_k_tmp = GetScratchBuffer(batch_size * sizeof(int), ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(LaunchSeqlensInc(hip_stream, seqlens_k_ptr, seqlens_k_tmp.get(), batch_size, sequence_length)); + seqlens_k_ptr = seqlens_k_tmp.get(); + } + static_assert(std::is_same_v); + + const float scale = parameters.scale == 0.0f + ? 1.f / sqrt(static_cast(parameters.head_size)) + : parameters.scale; + bias_enum bias_type = bias_enum::no_bias; + + mask_info mask = [&]() { + if (local_window_size_ != -1) { + mask_info ret; + ret.type = mask_enum::window_generic; + ret.left = local_window_size_; + ret.right = parameters.is_unidirectional ? 0 : -1; + // ret.x = kv_sequence_length - (sequence_length - ret.left); + // ret.y = sequence_length + (ret.right - kv_sequence_length); + return ret; + } + + if (parameters.is_prompt && is_unidirectional_) { + return mask_info::decode("t", sequence_length, kv_sequence_length); + } + + return mask_info::decode("0", sequence_length, kv_sequence_length); + }(); + + auto seqstart_q_tmp = GetScratchBuffer((batch_size + 1) * sizeof(int), ctx->GetComputeStream()); + auto seqstart_k_tmp = GetScratchBuffer((batch_size + 1) * sizeof(int), ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(LaunchSeqStartInit( + hip_stream, seqstart_q_tmp.get(), batch_size, + query_strides.strides_for_bnsh_coord.x / query_strides.strides_for_bnsh_coord.z)); + ORT_RETURN_IF_ERROR(LaunchSeqStartInit( + hip_stream, seqstart_k_tmp.get(), batch_size, + present_strides.strides_for_bnsh_coord.x / present_strides.strides_for_bnsh_coord.z)); + + fmha_fwd_args args{ + query_ptr, + present_key->DataRaw(), + present_value->DataRaw(), + nullptr, // bias, alibi/element + nullptr, // lse, logsumexp buffer + output->MutableDataRaw(), + seqstart_q_tmp.get(), // seqstart_q_ptr, for group mode + seqstart_k_tmp.get(), // seqstart_k_ptr, for group mode + seqlens_k_ptr, // seqlen_k_ptr, for group mode + sequence_length, // seqlen_q, for batch mode + kv_sequence_length, // seqlen_k, for batch mode + parameters.batch_size, // batch + parameters.sequence_length, // max_seqlen_q + parameters.head_size, // hdim_q + parameters.head_size, // hdim_v + parameters.num_heads, + parameters.kv_num_heads, + scale, + 1.0f, // scale_p of squant, useless + 1.0f, // scale_o of squant, useless + static_cast(query_strides.strides_for_bnsh_coord.z), // stride_q, to be regarded as stride of dim S + static_cast(present_strides.strides_for_bnsh_coord.z), // stride_k, to be regarded as stride of dim S + static_cast(present_strides.strides_for_bnsh_coord.z), // stride_v, to be regarded as stride of dim S + batch_size, // stride_bias, if alibi, b*h need set this to h, 1*h need set this to 0 + static_cast(output_strides.strides_for_bnsh_coord.z), // stride_o, to be regarded as stride of dim S + static_cast(query_strides.strides_for_bnsh_coord.y), // nhead_stride_q, to be regarded as stride of dim N + static_cast(present_strides.strides_for_bnsh_coord.y), // nhead_stride_k, to be regarded as stride of dim N + static_cast(present_strides.strides_for_bnsh_coord.y), // nhead_stride_v, to be regarded as stride of dim N + 0, // nhead_stride_bias + batch_size, // nhead_stride_lse + static_cast(output_strides.strides_for_bnsh_coord.y), // batch_stride_o, to be regarded as stride of dim B + static_cast(query_strides.strides_for_bnsh_coord.x), // batch_stride_q, to be regarded as stride of dim B + static_cast(present_strides.strides_for_bnsh_coord.x), // batch_stride_k, to be regarded as stride of dim B + static_cast(present_strides.strides_for_bnsh_coord.x), // batch_stride_v, to be regarded as stride of dim B + 0, // batch_stride_bias + num_heads * batch_size, // batch_stride_lse + static_cast(output_strides.strides_for_bnsh_coord.x), // batch_stride_o, to be regarded as stride of dim B + mask.left, // window_size_left + mask.right, // window_size_right + static_cast(mask.type)}; + +#if 0 + std::cout + << "\n sequence_length:" << sequence_length + << "\n kv_sequence_length:" << kv_sequence_length + << "\n seqlen_past_kv_cache:" << parameters.seqlen_past_kv_cache + << "\n seqlen_present_kv_cache:" << parameters.seqlen_present_kv_cache << std::endl; + + std::cout + << "\n q_ptr:" << args.q_ptr + << "\n k_ptr:" << args.k_ptr + << "\n v_ptr:" << args.v_ptr + << "\n bias_ptr:" << args.bias_ptr + << "\n lse_ptr:" << args.lse_ptr + << "\n o_ptr:" << args.o_ptr + << "\n seqstart_q_ptr:" << args.seqstart_q_ptr + << "\n seqstart_k_ptr:" << args.seqstart_k_ptr + << "\n seqlen_k_ptr:" << args.seqlen_k_ptr + << "\n seqlen_q:" << args.seqlen_q + << "\n seqlen_k:" << args.seqlen_k + << "\n batch:" << args.batch + << "\n max_seqlen_q:" << args.max_seqlen_q + << "\n hdim_q:" << args.hdim_q + << "\n hdim_v:" << args.hdim_v + << "\n nhead_q:" << args.nhead_q + << "\n nhead_k:" << args.nhead_k + << "\n scale_s:" << args.scale_s + << "\n scale_p:" << args.scale_p + << "\n scale_o:" << args.scale_o + << "\n stride_q:" << args.stride_q + << "\n stride_k:" << args.stride_k + << "\n stride_v:" << args.stride_v + << "\n stride_bias:" << args.stride_bias + << "\n stride_o:" << args.stride_o + << "\n nhead_stride_q:" << args.nhead_stride_q + << "\n nhead_stride_k:" << args.nhead_stride_k + << "\n nhead_stride_v:" << args.nhead_stride_v + << "\n nhead_stride_bias:" << args.nhead_stride_bias + << "\n nhead_stride_lse:" << args.nhead_stride_lse + << "\n nhead_stride_o:" << args.nhead_stride_o + << "\n batch_stride_q:" << args.batch_stride_q + << "\n batch_stride_k:" << args.batch_stride_k + << "\n batch_stride_v:" << args.batch_stride_v + << "\n batch_stride_bias:" << args.batch_stride_bias + << "\n batch_stride_lse:" << args.batch_stride_lse + << "\n batch_stride_o:" << args.batch_stride_o + << "\n window_size_left:" << args.window_size_left + << "\n window_size_right:" << args.window_size_right + << "\n mask_type:" << args.mask_type + << std::endl; +#endif + + fmha_fwd_traits traits{ + parameters.head_size, + parameters.head_size, // v head size + GetCkFmhaDataTypeString(), + !parameters.is_prompt, // true, // is_group_mode + true, // is_v_rowmajor ? dim is fastest : seq is fastest + mask.type, + bias_type, + false, // has_lse + false, // do_fp8_static_quant, aka, squant + }; + + ck_tile::stream_config stream_config{ + hip_stream, + false // time_kernel + }; + + auto duration = fmha_fwd(traits, args, stream_config); + if (duration < 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "fmha_fwd internal error"); + } + HIP_RETURN_IF_ERROR(hipGetLastError()); + + return Status::OK(); +#else + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "GroupQueryAttention requires ck_tile to be enabled"); +#endif +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h new file mode 100644 index 0000000000000..ce0de1f761aa5 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "core/providers/rocm/rocm_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +using namespace onnxruntime::rocm; + +template +class GroupQueryAttention final : public RocmKernel { + public: + GroupQueryAttention(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + int num_heads_; // number of attention heads + int kv_num_heads_; // different for k and v for group query attention + int local_window_size_; + bool is_unidirectional_; + bool is_past_bsnh_; + bool do_rotary_; + bool rotary_interleaved_; + float scale_; + + private: + static std::once_flag arch_checking_; +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index 7e5e7d7ee076d..4284b4254f485 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -71,6 +71,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GroupQueryAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); @@ -227,6 +229,8 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/test/python/transformers/test_flash_attn_rocm.py b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py new file mode 100644 index 0000000000000..fe7e39722237f --- /dev/null +++ b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py @@ -0,0 +1,86 @@ +import platform +import unittest + +import torch +from parameterized import parameterized +from test_flash_attn_cuda import ( + Formats, + gqa_no_past_flash_attention_test_cases, + gqa_past_flash_attention_test_cases, + parity_check_gqa_past, + parity_check_gqa_past_no_buff, + parity_check_gqa_prompt, + parity_check_gqa_prompt_no_buff, +) + +import onnxruntime + + +class TestGQA(unittest.TestCase): + @parameterized.expand(gqa_no_past_flash_attention_test_cases()) + def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): + config.ep = "ROCMExecutionProvider" + if not torch.cuda.is_available(): + return + if platform.system() != "Linux": + return + if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): + return + print("------- FLASH ATTENTION (PROMPT CASE) --------") + + parity_check_gqa_prompt( + config, + local=local, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + rtol=0.002, + atol=0.002, + ) + parity_check_gqa_prompt_no_buff( + config, + local=local, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + rtol=0.002, + atol=0.002, + ) + + @parameterized.expand(gqa_past_flash_attention_test_cases()) + def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): + config.ep = "ROCMExecutionProvider" + if not torch.cuda.is_available(): + return + if platform.system() != "Linux": + return + if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): + return + print("------- FLASH ATTENTION (TOKEN GEN) -------") + + parity_check_gqa_past( + config, + local=local, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + rtol=0.002, + atol=0.002, + ) + parity_check_gqa_past_no_buff( + config, + local=local, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + rtol=0.002, + atol=0.002, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml index 7ada4ee6757c9..001062452644e 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml @@ -255,6 +255,33 @@ jobs: arguments: -n $(Agent.Name) -d $HIP_VISIBLE_DEVICES -r $DRIVER_RENDER displayName: 'Check ROCm Environment' + # TODO: move to use ci_build/build.py driven tests + - task: CmdLine@2 + inputs: + script: |- + docker run --rm \ + --security-opt seccomp=unconfined \ + --shm-size=1024m \ + --device=/dev/kfd \ + --device=/dev/dri/renderD$DRIVER_RENDER \ + --group-add $(video) \ + --group-add $(render) \ + --user onnxruntimedev \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + -e OPENBLAS_NUM_THREADS=1 \ + -e OPENMP_NUM_THREADS=1 \ + -e MKL_NUM_THREADS=1 \ + -e PYTHONPATH=/build/$(BuildConfig) \ + onnxruntimetrainingrocm-cibuild-rocm$(RocmVersion)-test \ + /bin/bash -c " + set -ex; \ + pip install -r /onnxruntime_src/tools/ci_build/requirements-transformers-test.txt; \ + pytest /onnxruntime_src/onnxruntime/test/python/transformers/test_flash_attn_rocm.py -v -n 4 --reruns 1" + workingDirectory: $(Build.SourcesDirectory) + displayName: 'Run tranformers tests' + condition: succeededOrFailed() + - task: CmdLine@2 inputs: script: |- diff --git a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile index 59f6c0ab2136c..b94826ae0e4bc 100644 --- a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile @@ -77,7 +77,11 @@ RUN ln -sf /usr/lib/x86_64-linux-gnu/libstdc++.so.6 ${CONDA_ENVIRONMENT_PATH}/bi RUN export MAJOR=$(cut -d '.' -f 1 <<< "$ROCM_VERSION") && \ export MINOR=$(cut -d '.' -f 2 <<< "$ROCM_VERSION") && \ export PATCH=$(cut -d '.' -f 3 <<< "$ROCM_VERSION") && \ - pip install torch==2.0.1 torchvision==0.15.2 -f https://repo.radeon.com/rocm/manylinux/rocm-rel-${MAJOR}.${MINOR}/ && \ + if (( MAJOR >= 6 )); then \ + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm${MAJOR}.${MINOR} ; \ + else \ + pip install torch==2.0.1 torchvision==0.15.2 -f https://repo.radeon.com/rocm/manylinux/rocm-rel-${MAJOR}.${MINOR}/ ; \ + fi && \ pip install torch-ort --no-dependencies ##### Install Cupy to decrease CPU utilization From 30b6e82e7d4840c2a6d7a7b7bd48836c5fab1724 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Thu, 4 Jul 2024 11:07:04 +0800 Subject: [PATCH 07/13] Make ROCm packaging stages to a single workflow (#21235) ### Description Make current ROCm packaging stages to a single workflow. Reduce the possibility of all nightly packages can't be generated by one failed stage ### Motivation and Context Our plan is to reduce the complexity of the current zip-nuget pipeline to improve the stability and performance of nightly packages generation. ROCm packaging stages has no dependencies with other packaging jobs and it's the most time-consuming route. After this change, the most used CPU/CUDA/Mobile packaging workflow duration can be reduced roughly from 3h20m to 2h30m. --- .../c-api-noopenmp-packaging-pipelines.yml | 285 -------------- .../github/azure-pipelines/publish-nuget.yml | 5 - .../rocm-nuget-packaging-pipeline.yml | 353 ++++++++++++++++++ .../stages/set_packaging_variables_stage.yml | 22 ++ 4 files changed, 375 insertions(+), 290 deletions(-) create mode 100644 tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 3aadefecaab87..abfe4df3f47ac 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -94,28 +94,6 @@ stages: PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} -- stage: Debug - dependsOn: Setup - jobs: - - job: D1 - pool: - name: 'onnxruntime-Ubuntu2204-AMD-CPU' - variables: - MyVar: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] - BuildDate: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] - BuildTime: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] - steps: - - checkout: none - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() - - bash: echo $(MyVar) - - bash: echo $(BuildTime) - - bash: echo $(BuildDate) - - template: templates/component-governance-component-detection-steps.yml - parameters : - condition : 'succeeded' - - template: stages/download-java-tools-stage.yml - template: templates/c-api-cpu.yml @@ -167,269 +145,6 @@ stages: SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} -# ROCm -- stage: Linux_C_API_Packaging_ROCm_x64 - dependsOn: [] - jobs: - - job: Linux_C_API_Packaging_ROCm_x64 - workspace: - clean: all - timeoutInMinutes: 240 - pool: onnxruntime-Ubuntu2204-AMD-CPU - variables: - RocmVersion: '5.6' - RocmVersionPatchSuffix: '' - steps: - - checkout: self # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime - submodules: recursive - - checkout: manylinux # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/manylinux, for get-docker-image-steps.yml - submodules: false - - # get-docker-image-steps.yml will move the $(Build.SourcesDirectory)/manylinux into $(Build.SourcesDirectory)/onnxruntime, - # then rename $(Build.SourcesDirectory)/onnxruntime as $(Build.SourcesDirectory) - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm - Context: tools/ci_build/github/linux/docker - DockerBuildArgs: >- - --build-arg INSTALL_DEPS_EXTRA_ARGS=-tmur - --build-arg BUILD_UID=$(id -u) - --network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 - --build-arg ROCM_VERSION=$(RocmVersion)$(RocmVersionPatchSuffix) - --build-arg DEVTOOLSET_ROOTPATH=/opt/rh/gcc-toolset-12/root - --build-arg PREPEND_PATH=/opt/rh/gcc-toolset-12/root/usr/bin: - --build-arg LD_LIBRARY_PATH_ARG=/opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64:/usr/local/lib - Repository: onnxruntimetrainingrocmbuild-rocm$(RocmVersion) - CheckOutManyLinux: true - - - template: templates/set-version-number-variables-step.yml - - - task: Bash@3 - displayName: 'Build' - inputs: - targetType: filePath - filePath: tools/ci_build/github/linux/build_rocm_c_api_package.sh - arguments: >- - -S $(Build.SourcesDirectory) - -B $(Build.BinariesDirectory) - -V $(RocmVersion) - -I onnxruntimetrainingrocmbuild-rocm$(RocmVersion) - -P python3.10 - - - script: | - set -e -x - mkdir $(Build.ArtifactStagingDirectory)/testdata - cp $(Build.BinariesDirectory)/Release/libcustom_op_library.so* $(Build.ArtifactStagingDirectory)/testdata - ls -al $(Build.ArtifactStagingDirectory) - displayName: 'Create Artifacts for CustomOp' # libcustom_op_library.so from cpu build is built with fp8, ROCm does not support it. - - - template: templates/c-api-artifacts-package-and-publish-steps-posix.yml - parameters: - buildConfig: 'Release' - artifactName: 'onnxruntime-linux-x64-rocm-$(OnnxRuntimeVersion)' - artifactNameNoVersionString: 'onnxruntime-linux-x64-rocm' - libraryName: 'libonnxruntime.so.$(OnnxRuntimeVersion)' - - - template: templates/component-governance-component-detection-steps.yml - parameters: - condition: 'succeeded' - - template: templates/clean-agent-build-directory-step.yml - - -- stage: NuGet_Packaging_ROCm - dependsOn: - - Setup - - Linux_C_API_Packaging_ROCm_x64 - condition: succeeded() - jobs: - - job: NuGet_Packaging_ROCm - workspace: - clean: all - # we need to use a 2022 pool to create the nuget package with MAUI targets. - # VS2019 has no support for net6/MAUI and we need to use msbuild (from the VS install) to do the packing - pool: 'Onnxruntime-Win-CPU-2022' - variables: - breakCodesignValidationInjection: ${{ parameters.DoEsrp }} - ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] - BuildDate : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] - BuildTime : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] - - steps: - - checkout: self - submodules: true - fetchDepth: 1 - - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline Artifact - NuGet' - ArtifactName: 'onnxruntime-linux-x64-rocm' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - task: PowerShell@2 - displayName: 'Reconstruct Build Directory' - inputs: - targetType: inline - script: | - Get-ChildItem $(Build.BinariesDirectory)\nuget-artifact -Filter *.tgz | % { - # *.tar will be created after *.tgz is extracted - $cmd = "7z.exe x $($_.FullName) -y -o$(Build.BinariesDirectory)\nuget-artifact" - Write-Output $cmd - Invoke-Expression -Command $cmd - } - - Get-ChildItem $(Build.BinariesDirectory)\nuget-artifact -Filter *.tar | % { - $cmd = "7z.exe x $($_.FullName) -y -o$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" - Write-Output $cmd - Invoke-Expression -Command $cmd - } - - $ort_dirs = Get-ChildItem -Path $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-* -Directory - foreach ($ort_dir in $ort_dirs) - { - $dirname = Split-Path -Path $ort_dir -Leaf - $dirname = $dirname.SubString(0, $dirname.LastIndexOf('-')) - Write-Output "Renaming $ort_dir to $dirname" - Rename-Item -Path $ort_dir -NewName $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\$dirname - } - - Copy-Item -Path $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-linux-x64-rocm\lib\* -Destination $(Build.BinariesDirectory)\RelWithDebInfo - - - script: | - tree /F - workingDirectory: '$(Build.BinariesDirectory)' - displayName: 'Inspect Build Binaries Directory' - - - script: | - mklink /D /J models C:\local\models - workingDirectory: '$(Build.BinariesDirectory)' - displayName: 'Create models link' - - - task: NuGetToolInstaller@0 - displayName: Use Nuget 6.10.x - inputs: - versionSpec: 6.10.x - - - task: MSBuild@1 - displayName: 'Restore NuGet Packages and create project.assets.json' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' - platform: 'Any CPU' - configuration: RelWithDebInfo - msbuildArguments: '-t:restore -p:OrtPackageId="Microsoft.ML.OnnxRuntime.ROCm"' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - task: MSBuild@1 - displayName: 'Build C# bindings' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' - platform: 'Any CPU' - configuration: RelWithDebInfo - msbuildArguments: > - -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" - -p:OrtPackageId="Microsoft.ML.OnnxRuntime.ROCm" - -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} - -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) - -p:IsLinuxBuild=true - -p:IsWindowsBuild=false - -p:IsMacOSBuild=false - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - template: templates/win-esrp-dll.yml - parameters: - FolderPath: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo' - DisplayName: 'ESRP - Sign C# dlls' - DoEsrp: ${{ parameters.DoEsrp }} - - - task: UsePythonVersion@0 - displayName: 'Use Python' - inputs: - versionSpec: 3.8 - - - task: MSBuild@1 - displayName: 'Build Nuget Packages' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' - configuration: RelWithDebInfo - platform: 'Any CPU' - msbuildArguments: > - -t:CreatePackage - -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" - -p:OrtPackageId=Microsoft.ML.OnnxRuntime.ROCm - -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} - -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) - -p:CurrentTime=$(BuildTime) - -p:CurrentDate=$(BuildDate) - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - task: CopyFiles@2 - displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' - Contents: '*.snupkg' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: CopyFiles@2 - displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' - Contents: '*.nupkg' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: CopyFiles@2 - displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo' - Contents: '*.nupkg' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - template: templates/esrp_nuget.yml - parameters: - DisplayName: 'ESRP - sign NuGet package' - FolderPath: '$(Build.ArtifactStagingDirectory)' - DoEsrp: ${{ parameters.DoEsrp }} - - - template: templates/validate-package.yml - parameters: - PackageType: 'nuget' - PackagePath: '$(Build.ArtifactStagingDirectory)' - PackageName: 'Microsoft.ML.OnnxRuntime.*nupkg' - PlatformsSupported: 'linux-x64' - VerifyNugetSigning: false - - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline NuGet Artifact' - inputs: - artifactName: 'drop-signed-nuget-ROCm' - targetPath: '$(Build.ArtifactStagingDirectory)' - - - task: MSBuild@1 - displayName: 'Clean C#' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' - platform: 'Any CPU' - configuration: RelWithDebInfo - msbuildArguments: '-t:Clean -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime.ROCm' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - template: templates/component-governance-component-detection-steps.yml - parameters : - condition : 'succeeded' - - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() - -- template: nuget/templates/test_linux.yml - parameters: - AgentPool: AMD-GPU - ArtifactSuffix: 'ROCm' - StageSuffix: 'ROCm' - NugetPackageName: 'Microsoft.ML.OnnxRuntime.ROCm' - SpecificArtifact: ${{ parameters.specificArtifact }} - CustomOpArtifactName: 'onnxruntime-linux-x64-rocm' - BuildId: ${{ parameters.BuildId }} - template: nuget/templates/dml-vs-2022.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/publish-nuget.yml b/tools/ci_build/github/azure-pipelines/publish-nuget.yml index 5e827980e039c..01957a6eec045 100644 --- a/tools/ci_build/github/azure-pipelines/publish-nuget.yml +++ b/tools/ci_build/github/azure-pipelines/publish-nuget.yml @@ -29,11 +29,6 @@ stages: artifact: 'drop-signed-nuget-GPU' - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-GPU\*" $(Build.BinariesDirectory)\nuget-artifact\final-package - - download: build - displayName: 'Download Pipeline Artifact - Signed NuGet ROCm Package' - artifact: 'drop-signed-nuget-ROCm' - - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-ROCm\*" $(Build.BinariesDirectory)\nuget-artifact\final-package - - download: build displayName: 'Download Pipeline Artifact - Signed NuGet Qnn Package' artifact: 'drop-signed-nuget-qnn' diff --git a/tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml new file mode 100644 index 0000000000000..f4022a80b0568 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml @@ -0,0 +1,353 @@ +parameters: +- name: RunOnnxRuntimeTests + displayName: Run Tests? + type: boolean + default: true + +- name: UseIncreasedTimeoutForTests + displayName: Increase timeout for tests? Set it to false if you are doing an Onnx Runtime release. + type: boolean + default: false + +- name: DoCompliance + displayName: Run Compliance Tasks? + type: boolean + default: true + +- name: DoEsrp + displayName: Run code sign tasks? Must be true if you are doing an ONNX Runtime release + type: boolean + default: true + +- name: IsReleaseBuild + displayName: Is a release build? Set it to true if you are doing an ONNX Runtime release. + type: boolean + default: false + +- name: PreReleaseVersionSuffixString + displayName: Suffix added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the type of pre-release package. + type: string + values: + - alpha + - beta + - rc + - none + default: none + +- name: PreReleaseVersionSuffixNumber + displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. + type: number + default: 0 + +# these 2 parameters are used for debugging. +- name: SpecificArtifact + displayName: Use Specific Artifact (Debugging only) + type: boolean + default: false + +- name: BuildId + displayName: Pipeline BuildId, you could find it in the URL + type: string + default: '0' + +- name: NugetPackageSuffix + displayName: Suffix to append to nuget package + type: string + default: 'NONE' + +resources: + repositories: + - repository: onnxruntime-inference-examples # The name used to reference this repository in the checkout step + type: github + endpoint: ort-examples + name: microsoft/onnxruntime-inference-examples + - repository: manylinux + type: Github + endpoint: Microsoft + name: pypa/manylinux + ref: 5eda9aded5462201e6310105728d33016e637ea7 + +variables: +- name: ReleaseVersionSuffix + value: '' + +stages: +- template: stages/set_packaging_variables_stage.yml + parameters: + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} + +# ROCm +- stage: Linux_C_API_Packaging_ROCm_x64 + dependsOn: [] + jobs: + - job: Linux_C_API_Packaging_ROCm_x64 + workspace: + clean: all + timeoutInMinutes: 240 + pool: onnxruntime-Ubuntu2204-AMD-CPU + variables: + RocmVersion: '5.6' + RocmVersionPatchSuffix: '' + steps: + - checkout: self # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime + submodules: recursive + - checkout: manylinux # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/manylinux, for get-docker-image-steps.yml + submodules: false + + # get-docker-image-steps.yml will move the $(Build.SourcesDirectory)/manylinux into $(Build.SourcesDirectory)/onnxruntime, + # then rename $(Build.SourcesDirectory)/onnxruntime as $(Build.SourcesDirectory) + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: >- + --build-arg INSTALL_DEPS_EXTRA_ARGS=-tmur + --build-arg BUILD_UID=$(id -u) + --network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 + --build-arg ROCM_VERSION=$(RocmVersion)$(RocmVersionPatchSuffix) + --build-arg DEVTOOLSET_ROOTPATH=/opt/rh/gcc-toolset-12/root + --build-arg PREPEND_PATH=/opt/rh/gcc-toolset-12/root/usr/bin: + --build-arg LD_LIBRARY_PATH_ARG=/opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64:/usr/local/lib + Repository: onnxruntimetrainingrocmbuild-rocm$(RocmVersion) + CheckOutManyLinux: true + + - template: templates/set-version-number-variables-step.yml + + - task: Bash@3 + displayName: 'Build' + inputs: + targetType: filePath + filePath: tools/ci_build/github/linux/build_rocm_c_api_package.sh + arguments: >- + -S $(Build.SourcesDirectory) + -B $(Build.BinariesDirectory) + -V $(RocmVersion) + -I onnxruntimetrainingrocmbuild-rocm$(RocmVersion) + -P python3.10 + + - script: | + set -e -x + mkdir $(Build.ArtifactStagingDirectory)/testdata + cp $(Build.BinariesDirectory)/Release/libcustom_op_library.so* $(Build.ArtifactStagingDirectory)/testdata + ls -al $(Build.ArtifactStagingDirectory) + displayName: 'Create Artifacts for CustomOp' # libcustom_op_library.so from cpu build is built with fp8, ROCm does not support it. + + - template: templates/c-api-artifacts-package-and-publish-steps-posix.yml + parameters: + buildConfig: 'Release' + artifactName: 'onnxruntime-linux-x64-rocm-$(OnnxRuntimeVersion)' + artifactNameNoVersionString: 'onnxruntime-linux-x64-rocm' + libraryName: 'libonnxruntime.so.$(OnnxRuntimeVersion)' + + - template: templates/component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' + - template: templates/clean-agent-build-directory-step.yml + +- stage: NuGet_Packaging_ROCm + dependsOn: + - Setup + - Linux_C_API_Packaging_ROCm_x64 + condition: succeeded() + jobs: + - job: NuGet_Packaging_ROCm + workspace: + clean: all + # we need to use a 2022 pool to create the nuget package with MAUI targets. + # VS2019 has no support for net6/MAUI and we need to use msbuild (from the VS install) to do the packing + pool: 'Onnxruntime-Win-CPU-2022' + variables: + breakCodesignValidationInjection: ${{ parameters.DoEsrp }} + ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] + BuildDate : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] + BuildTime : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] + + steps: + - checkout: self + submodules: true + fetchDepth: 1 + + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Pipeline Artifact - NuGet' + ArtifactName: 'onnxruntime-linux-x64-rocm' + targetPath: '$(Build.BinariesDirectory)/nuget-artifact' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} + + - task: PowerShell@2 + displayName: 'Reconstruct Build Directory' + inputs: + targetType: inline + script: | + Get-ChildItem $(Build.BinariesDirectory)\nuget-artifact -Filter *.tgz | % { + # *.tar will be created after *.tgz is extracted + $cmd = "7z.exe x $($_.FullName) -y -o$(Build.BinariesDirectory)\nuget-artifact" + Write-Output $cmd + Invoke-Expression -Command $cmd + } + + Get-ChildItem $(Build.BinariesDirectory)\nuget-artifact -Filter *.tar | % { + $cmd = "7z.exe x $($_.FullName) -y -o$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" + Write-Output $cmd + Invoke-Expression -Command $cmd + } + + $ort_dirs = Get-ChildItem -Path $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-* -Directory + foreach ($ort_dir in $ort_dirs) + { + $dirname = Split-Path -Path $ort_dir -Leaf + $dirname = $dirname.SubString(0, $dirname.LastIndexOf('-')) + Write-Output "Renaming $ort_dir to $dirname" + Rename-Item -Path $ort_dir -NewName $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\$dirname + } + + Copy-Item -Path $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-linux-x64-rocm\lib\* -Destination $(Build.BinariesDirectory)\RelWithDebInfo + + - script: | + tree /F + workingDirectory: '$(Build.BinariesDirectory)' + displayName: 'Inspect Build Binaries Directory' + + - script: | + mklink /D /J models C:\local\models + workingDirectory: '$(Build.BinariesDirectory)' + displayName: 'Create models link' + + - task: NuGetToolInstaller@0 + displayName: Use Nuget 6.10.x + inputs: + versionSpec: 6.10.x + + - task: MSBuild@1 + displayName: 'Restore NuGet Packages and create project.assets.json' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + platform: 'Any CPU' + configuration: RelWithDebInfo + msbuildArguments: '-t:restore -p:OrtPackageId="Microsoft.ML.OnnxRuntime.ROCm"' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - task: MSBuild@1 + displayName: 'Build C# bindings' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + platform: 'Any CPU' + configuration: RelWithDebInfo + msbuildArguments: > + -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" + -p:OrtPackageId="Microsoft.ML.OnnxRuntime.ROCm" + -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} + -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) + -p:IsLinuxBuild=true + -p:IsWindowsBuild=false + -p:IsMacOSBuild=false + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - template: templates/win-esrp-dll.yml + parameters: + FolderPath: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo' + DisplayName: 'ESRP - Sign C# dlls' + DoEsrp: ${{ parameters.DoEsrp }} + + - task: UsePythonVersion@0 + displayName: 'Use Python' + inputs: + versionSpec: 3.8 + + - task: MSBuild@1 + displayName: 'Build Nuget Packages' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' + configuration: RelWithDebInfo + platform: 'Any CPU' + msbuildArguments: > + -t:CreatePackage + -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" + -p:OrtPackageId=Microsoft.ML.OnnxRuntime.ROCm + -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} + -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) + -p:CurrentTime=$(BuildTime) + -p:CurrentDate=$(BuildDate) + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - task: CopyFiles@2 + displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + Contents: '*.snupkg' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: CopyFiles@2 + displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + Contents: '*.nupkg' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: CopyFiles@2 + displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo' + Contents: '*.nupkg' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - template: templates/esrp_nuget.yml + parameters: + DisplayName: 'ESRP - sign NuGet package' + FolderPath: '$(Build.ArtifactStagingDirectory)' + DoEsrp: ${{ parameters.DoEsrp }} + + - template: templates/validate-package.yml + parameters: + PackageType: 'nuget' + PackagePath: '$(Build.ArtifactStagingDirectory)' + PackageName: 'Microsoft.ML.OnnxRuntime.*nupkg' + PlatformsSupported: 'linux-x64' + VerifyNugetSigning: false + + - task: PublishPipelineArtifact@0 + displayName: 'Publish Pipeline NuGet Artifact' + inputs: + artifactName: 'drop-signed-nuget-ROCm' + targetPath: '$(Build.ArtifactStagingDirectory)' + + - task: MSBuild@1 + displayName: 'Clean C#' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + platform: 'Any CPU' + configuration: RelWithDebInfo + msbuildArguments: '-t:Clean -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime.ROCm' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - template: templates/component-governance-component-detection-steps.yml + parameters : + condition : 'succeeded' + + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + +- template: nuget/templates/test_linux.yml + parameters: + AgentPool: AMD-GPU + ArtifactSuffix: 'ROCm' + StageSuffix: 'ROCm' + NugetPackageName: 'Microsoft.ML.OnnxRuntime.ROCm' + SpecificArtifact: ${{ parameters.specificArtifact }} + CustomOpArtifactName: 'onnxruntime-linux-x64-rocm' + BuildId: ${{ parameters.BuildId }} + +- template: templates/publish-nuget-steps.yml + parameters: + download_artifacts_steps: + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Pipeline Artifact - Signed NuGet ROCm Package' + ArtifactName: 'drop-signed-nuget-ROCm' + targetPath: '$(Build.BinariesDirectory)/nuget-artifact/final-package' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} diff --git a/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml b/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml index 3e2b3b585df9a..e0f0e6f358e70 100644 --- a/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml @@ -44,3 +44,25 @@ stages: - template: ../templates/component-governance-component-detection-steps.yml parameters: condition: 'succeeded' + +- stage: Debug + dependsOn: Setup + jobs: + - job: D1 + pool: + name: 'onnxruntime-Ubuntu2204-AMD-CPU' + variables: + MyVar: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] + BuildDate: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] + BuildTime: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] + steps: + - checkout: none + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + - bash: echo $(MyVar) + - bash: echo $(BuildTime) + - bash: echo $(BuildDate) + - template: ../templates/component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' From 7d9b12a2e392b2c86cc7f7f7170d624631dca7b4 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 3 Jul 2024 21:51:57 -0700 Subject: [PATCH 08/13] [CPU] SparseAttention op (#21110) Add SparseAttention cpu implementation. - [x] Refactoring GQAAttentionBase - [x] Add SparseAttention implementation - [x] Add test cases This is unfused version. Flash attention version will be added later. --- docs/ContribOperators.md | 2 +- docs/OperatorKernels.md | 1 + .../contrib_ops/cpu/bert/attention_base.h | 1 - .../contrib_ops/cpu/bert/attention_cpu_base.h | 5 +- .../contrib_ops/cpu/bert/gqa_attention_base.h | 31 +- .../cpu/bert/group_query_attention.cc | 42 +- .../cpu/bert/group_query_attention_helper.h | 32 -- .../contrib_ops/cpu/bert/rotary_helper.h | 47 +++ .../contrib_ops/cpu/cpu_contrib_kernels.cc | 2 + .../cpu/sparse/sparse_attention.cc | 226 ++++++++++ .../contrib_ops/cpu/sparse/sparse_attention.h | 21 + .../cpu/sparse/sparse_attention_base.h | 390 ++++++++++++++++++ .../sparse/sparse_attention_helper.h | 6 +- .../cuda/sparse/sparse_attention.cc | 2 +- .../core/graph/contrib_ops/bert_defs.cc | 2 +- .../transformers/test_sparse_attention.py | 389 +++++++++++++---- 16 files changed, 1034 insertions(+), 165 deletions(-) create mode 100644 onnxruntime/contrib_ops/cpu/bert/rotary_helper.h create mode 100644 onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc create mode 100644 onnxruntime/contrib_ops/cpu/sparse/sparse_attention.h create mode 100644 onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h rename onnxruntime/contrib_ops/{cuda => cpu}/sparse/sparse_attention_helper.h (98%) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 45306c852a906..ed9e2a0567d2f 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -5646,7 +5646,7 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T : tensor(float16), tensor(bfloat16)
+
T : tensor(float), tensor(float16), tensor(bfloat16)
Constrain input and output to float tensors.
M : tensor(int32)
Constrain integer type.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 5f19c16cba616..df5897529baae 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -512,6 +512,7 @@ Do not modify directly.* |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)| |SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)| +|SparseAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* block_row_indices:**M**
*in* block_col_indices:**M**
*in* total_sequence_length:**M**
*in* key_total_sequence_lengths:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float)| |SparseToDenseMatMul|*in* A:**T**
*in* B:**T1**
*out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)
**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |Tokenizer|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(string)| |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h index af902a713eaa2..a6782daa58f1a 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h @@ -68,7 +68,6 @@ class AttentionBase { const Tensor* past_seq_len = nullptr) const; int num_heads_; // number of attention heads - int kv_num_heads_; // different for k and v for group query attention bool is_unidirectional_; // whether every token can only attend to previous tokens. std::vector qkv_hidden_sizes_; // Q, K, V hidden sizes parsed from the qkv_hidden_sizes attribute. bool require_same_hidden_size_; // whether the implementation supports different hidden sizes of Q/K/V. diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index fc4905cd31819..dd52001c2ac6b 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -3,9 +3,8 @@ #pragma once -#include "attention_base.h" -#include "attention_helper.h" - +#include "contrib_ops/cpu/bert/attention_base.h" +#include "contrib_ops/cpu/bert/attention_helper.h" #include "core/common/common.h" #include "core/common/safeint.h" #include "core/framework/op_kernel.h" diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 6b0c5f395cab0..137612a4bf902 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -3,8 +3,8 @@ #pragma once -#include "attention_base.h" -#include "attention_helper.h" +#include "contrib_ops/cpu/bert/attention_base.h" +#include "contrib_ops/cpu/bert/attention_helper.h" #include "core/common/common.h" #include "contrib_ops/cpu/bert/attention_common.h" @@ -14,14 +14,31 @@ namespace onnxruntime { namespace contrib { -class GQAAttentionBase : public AttentionBase { +class GQAAttentionBase { protected: - GQAAttentionBase(const OpKernelInfo& info, bool require_same_hidden_size) - : AttentionBase(info, require_same_hidden_size) {} + GQAAttentionBase(const OpKernelInfo& info, bool has_local) { + int64_t num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + num_heads_ = static_cast(num_heads); - int local_window_size_; - bool do_rotary_; + int64_t kv_num_heads = 0; + ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0); + kv_num_heads_ = static_cast(kv_num_heads); + + scale_ = info.GetAttrOrDefault("scale", 0.0f); + + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; + + local_window_size_ = has_local ? static_cast(info.GetAttrOrDefault("local_window_size", -1)) : -1; + } + + int num_heads_; // number of attention heads of Q + int kv_num_heads_; // number of attention heads of K or V + float scale_; // the scaling factor applied before softmax + bool do_rotary_; // whether or not to use rotary embeddings bool rotary_interleaved_; + int local_window_size_; template Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index cad9274e68149..97388a9d6bce8 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -1,11 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "group_query_attention.h" -#include "group_query_attention_helper.h" -#include "attention_utils.h" -#include "rotary_embedding.h" -#include "rotary_embedding_helper.h" +#include "contrib_ops/cpu/bert/group_query_attention.h" +#include "contrib_ops/cpu/bert/group_query_attention_helper.h" +#include "contrib_ops/cpu/bert/rotary_helper.h" +#include "contrib_ops/cpu/bert/attention_utils.h" +#include "contrib_ops/cpu/bert/rotary_embedding.h" +#include "contrib_ops/cpu/bert/rotary_embedding_helper.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/onnx_protobuf.h" @@ -33,19 +34,8 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( GroupQueryAttention); template -GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) : OpKernel(info), GQAAttentionBase(info, false) { - int64_t num_heads = 0; - int64_t kv_num_heads = 0; - ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); - ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0); - num_heads_ = static_cast(num_heads); - kv_num_heads_ = static_cast(kv_num_heads); - - mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); - local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); - do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; - rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; -} +GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) + : OpKernel(info), GQAAttentionBase(info, true) {} template Status GroupQueryAttention::Compute(OpKernelContext* context) const { @@ -174,14 +164,14 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { if (packed_qkv) { const T* v_input = k_input + kv_num_heads_ * sequence_length * head_size; T* v_rotary = k_rotary + kv_num_heads_ * sequence_length * head_size; - ORT_RETURN_IF_ERROR(group_query_attention_helper::PackVIntoRotaryQKV(tp, - parameters.batch_size, - parameters.sequence_length, - parameters.num_heads, - parameters.kv_num_heads, - parameters.head_size, - v_input, - v_rotary)); + ORT_RETURN_IF_ERROR(rotary_helper::PackVIntoRotaryQKV(tp, + parameters.batch_size, + parameters.sequence_length, + parameters.num_heads, + parameters.kv_num_heads, + parameters.head_size, + v_input, + v_rotary)); } } diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index a7de02452aa58..7ffb72fe55d25 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -263,38 +263,6 @@ Status CheckInputs(const Tensor* query, return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale); } - -template -Status PackVIntoRotaryQKV(concurrency::ThreadPool* tp, - int batch_size, - int sequence_length, - int num_heads, - int kv_num_heads, - int head_size, - const T* input, - T* output) { - int seq_stride = head_size; - int head_stride = sequence_length * seq_stride; - int batch_stride = (num_heads + 2 * kv_num_heads) * head_stride; - - const int loop_len = batch_size * sequence_length * kv_num_heads; - const double cost = static_cast(head_size); - ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { - for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { - const int b = static_cast((ptr / kv_num_heads) / sequence_length); - const int s = static_cast((ptr / kv_num_heads) % sequence_length); - const int n = static_cast(ptr % kv_num_heads); - const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; - const T* input_data = input + block_offset; - T* output_data = output + block_offset; - for (int i = 0; i < head_size; i++) { - output_data[i] = input_data[i]; - } - } - }); - return Status::OK(); -} - } // namespace group_query_attention_helper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_helper.h new file mode 100644 index 0000000000000..714d962dfb34e --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_helper.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/providers/common.h" +#include "contrib_ops/cpu/bert/attention_common.h" + +namespace onnxruntime { +namespace contrib { +namespace rotary_helper { + +template +Status PackVIntoRotaryQKV(concurrency::ThreadPool* tp, + int batch_size, + int sequence_length, + int num_heads, + int kv_num_heads, + int head_size, + const T* input, + T* output) { + int seq_stride = head_size; + int head_stride = sequence_length * seq_stride; + int batch_stride = (num_heads + 2 * kv_num_heads) * head_stride; + + const int loop_len = batch_size * sequence_length * kv_num_heads; + const double cost = static_cast(head_size); + ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { + const int b = static_cast((ptr / kv_num_heads) / sequence_length); + const int s = static_cast((ptr / kv_num_heads) % sequence_length); + const int n = static_cast(ptr % kv_num_heads); + const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; + const T* input_data = input + block_offset; + T* output_data = output + block_offset; + for (int i = 0; i < head_size; i++) { + output_data[i] = input_data[i]; + } + } + }); + return Status::OK(); +} + +} // namespace rotary_helper +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index e8ca4370135cc..90a51fda0b188 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -21,6 +21,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GroupQueryAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SparseAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM); @@ -281,6 +282,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc new file mode 100644 index 0000000000000..e337f41a8688d --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc @@ -0,0 +1,226 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/sparse/sparse_attention.h" +#include "contrib_ops/cpu/sparse/sparse_attention_helper.h" +#include "contrib_ops/cpu/bert/rotary_helper.h" +#include "contrib_ops/cpu/bert/attention_utils.h" +#include "contrib_ops/cpu/bert/rotary_embedding.h" +#include "contrib_ops/cpu/bert/rotary_embedding_helper.h" + +#include "core/framework/tensorprotoutils.h" +#include "core/graph/onnx_protobuf.h" +#include "core/common/safeint.h" +#include "core/platform/threadpool.h" + +#include +#include + +using onnxruntime::concurrency::ThreadPool; + +namespace onnxruntime { +namespace contrib { + +ONNX_OPERATOR_TYPED_KERNEL_EX( + SparseAttention, + kMSDomain, + 1, + float, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("M", DataTypeImpl::GetTensorType()), + SparseAttention); + +template +SparseAttention::SparseAttention(const OpKernelInfo& info) : OpKernel(info), SparseAttentionBase(info) { +} + +template +Status SparseAttention::Compute(OpKernelContext* context) const { + const Tensor* query = context->Input(0); + const Tensor* key = context->Input(1); + const Tensor* value = context->Input(2); + const Tensor* past_key = context->Input(3); + const Tensor* past_value = context->Input(4); + const Tensor* block_row_indices = context->Input(5); + const Tensor* block_col_indices = context->Input(6); + const Tensor* total_seq_len = context->Input(7); + const Tensor* total_key_lengths = context->Input(8); + const Tensor* cos_cache = context->Input(9); + const Tensor* sin_cache = context->Input(10); + + SparseAttentionParameters parameters = {}; + + // Parameters from node attribute shall be set before calling CheckInputs + parameters.sparse_block_size = sparse_block_size_; + parameters.num_heads = num_heads_; + parameters.kv_num_heads = kv_num_heads_; + parameters.scale = scale_; + parameters.do_rotary = do_rotary_; + parameters.rotary_interleaved = rotary_interleaved_; + ORT_RETURN_IF_ERROR(sparse_attention_helper::CheckInputs(¶meters, + query, + key, + value, + past_key, + past_value, + cos_cache, + sin_cache, + block_row_indices, + block_col_indices, + total_key_lengths, + total_seq_len)); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + int q_hidden_size = parameters.hidden_size; + + std::vector output_shape(3); + output_shape[0] = static_cast(batch_size); + output_shape[1] = static_cast(sequence_length); + output_shape[2] = static_cast(q_hidden_size); + Tensor* output = context->Output(0, output_shape); + + constexpr bool past_present_share_buffer = true; // Only supports share buffer for past and present for now. + parameters.past_present_share_buffer = past_present_share_buffer; + + int head_size = parameters.head_size; + const int cache_length = past_present_share_buffer + ? parameters.max_cache_sequence_length + : parameters.total_sequence_length; + std::vector present_k_shape({static_cast(batch_size), + static_cast(kv_num_heads_), + static_cast(cache_length), + static_cast(head_size)}); + std::vector present_v_shape({static_cast(batch_size), + static_cast(kv_num_heads_), + static_cast(cache_length), + static_cast(head_size)}); + Tensor* present_key = context->Output(1, present_k_shape); + Tensor* present_value = context->Output(2, present_v_shape); + + // Check past and present share buffer. + if (past_present_share_buffer) { + ORT_ENFORCE(past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw()); + } + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + auto element_type = DataTypeImpl::GetType(); + OrtValue Q; + OrtValue K; + OrtValue V; + + const bool packed_qkv = parameters.is_packed_qkv; + if (packed_qkv) { + ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( + allocator, batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size, query, Q)); + } else { + ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( + allocator, batch_size, num_heads_, sequence_length, head_size, query, Q)); + ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( + allocator, batch_size, kv_num_heads_, sequence_length, head_size, key, K)); + ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( + allocator, batch_size, kv_num_heads_, sequence_length, head_size, value, V)); + } + + if (do_rotary_) { + rotary_embedding_helper::RotaryParameters rotary_params = {}; + rotary_params.batch_size = batch_size; + rotary_params.sequence_length = sequence_length; + rotary_params.hidden_size = q_hidden_size; + rotary_params.head_size = head_size; + rotary_params.rotary_embedding_dim = parameters.rotary_dim; + rotary_params.num_heads = num_heads_; + rotary_params.max_sequence_length = sequence_length; // unused + rotary_params.seq_stride = head_size; + rotary_params.head_stride = sequence_length * rotary_params.seq_stride; + rotary_params.batch_stride = (packed_qkv ? (num_heads_ + 2 * kv_num_heads_) : num_heads_) * + rotary_params.head_stride; + rotary_params.position_ids_format = sequence_length == 1 ? 1 : 0; + rotary_params.transposed = true; + auto* tp = context->GetOperatorThreadPool(); + + const bool is_prompt = parameters.total_sequence_length == parameters.sequence_length; + std::vector pos_ids(is_prompt ? 1 : batch_size * sequence_length); + if (is_prompt) { + pos_ids[0] = static_cast(0); + } else if (sequence_length == 1) { + for (int b = 0; b < batch_size; b++) { + pos_ids[b] = static_cast(total_key_lengths->Data()[b]) - 1; + } + } else { + // This supports a rare case that sequence_length > 1 when it is not prompt. + for (int b = 0; b < batch_size; b++) { + for (int s = 0; s < sequence_length; s++) { + pos_ids[b * sequence_length + s] = static_cast(total_key_lengths->Data()[b]) - + (sequence_length - s); + } + } + } + + const T* q_input; + const T* k_input; + T* q_rotary; + T* k_rotary; + if (packed_qkv) { + OrtValue RotaryQKV; + TensorShape qkv_shape({batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size}); + Tensor::InitOrtValue(element_type, qkv_shape, allocator, RotaryQKV); + q_input = Q.Get().Data(); + k_input = q_input + num_heads_ * sequence_length * head_size; + q_rotary = RotaryQKV.GetMutable()->MutableData(); + k_rotary = q_rotary + num_heads_ * sequence_length * head_size; + Q = RotaryQKV; + } else { + OrtValue RotaryQ; + TensorShape q_shape({batch_size, num_heads_, sequence_length, head_size}); + Tensor::InitOrtValue(element_type, q_shape, allocator, RotaryQ); + OrtValue RotaryK; + TensorShape k_shape({batch_size, kv_num_heads_, sequence_length, head_size}); + Tensor::InitOrtValue(element_type, k_shape, allocator, RotaryK); + q_input = Q.Get().Data(); + k_input = K.Get().Data(); + q_rotary = RotaryQ.GetMutable()->MutableData(); + k_rotary = RotaryK.GetMutable()->MutableData(); + Q = RotaryQ; + K = RotaryK; + } + + ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, q_input, + pos_ids.data(), cos_cache->Data(), + sin_cache->Data(), q_rotary, rotary_interleaved_)); + + rotary_params.num_heads = kv_num_heads_; + rotary_params.hidden_size = parameters.kv_hidden_size; + if (!packed_qkv) { + rotary_params.batch_stride = kv_num_heads_ * rotary_params.head_stride; + } + ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, k_input, + pos_ids.data(), cos_cache->Data(), + sin_cache->Data(), k_rotary, rotary_interleaved_)); + if (packed_qkv) { + const T* v_input = k_input + kv_num_heads_ * sequence_length * head_size; + T* v_rotary = k_rotary + kv_num_heads_ * sequence_length * head_size; + ORT_RETURN_IF_ERROR(rotary_helper::PackVIntoRotaryQKV(tp, + parameters.batch_size, + parameters.sequence_length, + parameters.num_heads, + parameters.kv_num_heads, + parameters.head_size, + v_input, + v_rotary)); + } + } + + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + // Compute the attention score and apply the score to V + return ApplyAttention(Q.Get().Data(), packed_qkv ? nullptr : K.Get().Data(), + packed_qkv ? nullptr : V.Get().Data(), past_key, past_value, + output, present_key, present_value, + total_key_lengths, block_row_indices, block_col_indices, parameters, allocator, context); +} +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.h b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.h new file mode 100644 index 0000000000000..4267d85c0e35d --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "contrib_ops/cpu/sparse/sparse_attention_base.h" + +namespace onnxruntime { +namespace contrib { + +template +class SparseAttention final : public OpKernel, public SparseAttentionBase { + public: + SparseAttention(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h new file mode 100644 index 0000000000000..cf66bd8407126 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h @@ -0,0 +1,390 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cpu/bert/attention_helper.h" + +#include "core/common/common.h" +#include "contrib_ops/cpu/bert/attention_common.h" +#include "core/common/safeint.h" +#include "core/framework/op_kernel.h" +#include "contrib_ops/cpu/utils/dump_tensor.h" + +namespace onnxruntime { +namespace contrib { + +class SparseAttentionBase { + protected: + SparseAttentionBase(const OpKernelInfo& info) { + int64_t num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + num_heads_ = static_cast(num_heads); + + int64_t kv_num_heads = 0; + ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0); + kv_num_heads_ = static_cast(kv_num_heads); + + scale_ = info.GetAttrOrDefault("scale", 0.0f); + + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; + + int64_t sparse_block_size = 0; + ORT_ENFORCE(info.GetAttr("sparse_block_size", &sparse_block_size).IsOK()); + sparse_block_size_ = static_cast(sparse_block_size); + } + + int num_heads_; // number of attention heads of Q + int kv_num_heads_; // number of attention heads of K or V + float scale_; // the scaling factor applied before softmax + bool do_rotary_; // whether or not to use rotary embeddings + bool rotary_interleaved_; + int sparse_block_size_; + + template + Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH + const T* K, // K data with shape BxN_kvxSxH + const T* V, // V data with shape BxN_kvxSxH + const Tensor* past_key, // past K input tensor + const Tensor* past_value, // past V input tensor + Tensor* output, // output tensor + Tensor* present_key, // present K output tensor + Tensor* present_value, // present V output tensor + const Tensor* total_key_lengths, // total key lengths tensor + const Tensor* block_row_indices, // block row indices + const Tensor* block_col_indices, // block column indices + SparseAttentionParameters& parameters, // attention parameters + AllocatorPtr allocator, // allocator for temporary tensors + OpKernelContext* context) const { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int head_size = parameters.head_size; + const bool packed_qkv = parameters.is_packed_qkv; + + int past_buffer_sequence_length = static_cast(past_key->Shape().GetDims()[2]); + int present_buffer_sequence_length = static_cast(present_key->Shape().GetDims()[2]); + + // Allocate a buffer to store Softmax(QK) + size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * parameters.total_sequence_length * sizeof(T); + auto attention_probs = allocator->Alloc(bytes); + BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator)); + + bool past_present_share_buffer = parameters.past_present_share_buffer; + assert(past_present_share_buffer); + + auto* tp = context->GetOperatorThreadPool(); + + const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; + ComputeAttentionProbs( + static_cast(attention_probs), Q, k, total_key_lengths->Data(), + batch_size, sequence_length, parameters.total_sequence_length, + past_buffer_sequence_length, present_buffer_sequence_length, head_size, + past_key->Data(), present_key->MutableData(), past_present_share_buffer, packed_qkv, + block_row_indices->Data(), block_col_indices->Data(), parameters, tp); + + // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) + const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; + ComputeVxAttentionScore( + output->MutableData(), static_cast(attention_probs), v, + total_key_lengths->Data(), + batch_size, sequence_length, parameters.total_sequence_length, + past_buffer_sequence_length, present_buffer_sequence_length, head_size, parameters.hidden_size, + past_value->Data(), present_value->MutableData(), past_present_share_buffer, packed_qkv, tp); + + return Status::OK(); + } + + private: + // Helper function to compute the attention probs. It does 2 things: + // attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) + // attention_probs(B, N, S, T) = Softmax(attention_probs) + template + void ComputeAttentionProbs( + T* attention_probs, // output buffer with size BxNxSxT + const T* Q, // query start pointer + const T* K, // key start pointer + const int32_t* total_key_lengths, // total key sequence lengths (past + new) + int batch_size, // batch size + int sequence_length, // sequence length of query or new key + int total_sequence_length, // maximum past_sequence_length + sequence_length + int past_buffer_sequence_length, // sequence length of past_key or past_value + int present_buffer_sequence_length, // sequence length of present_key or present_value + int head_size, // head size of query + const T* past_key, // past key + T* present_key, // present key + bool past_present_share_buffer, // whether past_key and present_key share the buffer + bool packed_qkv, // whether Q, K, V are packed + const int32_t* block_row_indices, // block row indices + const int32_t* block_col_indices, // block column indices + SparseAttentionParameters& parameters, // parameters + ThreadPool* tp) const { // thread pool + const bool is_prompt = (total_sequence_length == sequence_length); + const ptrdiff_t packed_batch_stride = + packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size + : SafeInt(0); + const int kv_num_heads_factor = num_heads_ / kv_num_heads_; + const size_t q_input_chunk_length = static_cast(sequence_length) * head_size; // S x H + const size_t kv_input_chunk_length = q_input_chunk_length; + const size_t past_buff_chunk_length = static_cast(past_buffer_sequence_length) * head_size; + const size_t present_buff_chunk_length = static_cast(present_buffer_sequence_length) * head_size; + + const int loop_len = batch_size * num_heads_; + const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; + + TensorOpCost unit_cost; + const ptrdiff_t probs_matrix_bytes = + SafeInt(sequence_length) * total_sequence_length * sizeof(T); + unit_cost.compute_cycles = + static_cast(SafeInt(2) * sequence_length * head_size * total_sequence_length); + unit_cost.bytes_loaded = + static_cast((sequence_length + total_sequence_length) * head_size * sizeof(T)); + unit_cost.bytes_stored = static_cast(probs_matrix_bytes); + + unit_cost.bytes_loaded += static_cast(probs_matrix_bytes); + unit_cost.bytes_stored += static_cast(probs_matrix_bytes); + + // Cost to concatenate current key to cache (assume past and present share buffer). + double bytes_to_copy_key = static_cast(sizeof(T) * sequence_length * head_size); + unit_cost.bytes_loaded += bytes_to_copy_key; + unit_cost.bytes_stored += bytes_to_copy_key; + + DUMP_CPU_TENSOR_INIT(); + DUMP_CPU_TENSOR("block_row_indices", block_row_indices, parameters.num_sparse_layout, parameters.stride_row_indices); + DUMP_CPU_TENSOR("block_col_indices", block_col_indices, parameters.num_sparse_layout, parameters.stride_col_indices); + + // Check whether each layout has sparse (has zero in lower triangular) + std::vector layout_has_sparse(parameters.num_sparse_layout); + for (int layout_index = 0; layout_index < parameters.num_sparse_layout; layout_index++) { + int nonzero_elements = block_row_indices[(layout_index + 1) * parameters.stride_row_indices - 1]; + int dense_nonzero = (parameters.stride_row_indices * (parameters.stride_row_indices - 1)) / 2; + layout_has_sparse[layout_index] = nonzero_elements < dense_nonzero; + DUMP_STRING("layout_has_sparse[", layout_index, "]=", layout_has_sparse[layout_index]); + } + + ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + DUMP_STRING("batch_size=", batch_size, ",num_heads=", num_heads_, ",loop_len=", loop_len, ",begin=", begin, ",end=", end); + for (std::ptrdiff_t i = begin; i != end; ++i) { + const int batch_index = static_cast(i) / num_heads_; + const int head_index = static_cast(i) % num_heads_; + const int past_seq_len = is_prompt ? 0 : (static_cast(total_key_lengths[batch_index]) - sequence_length); + const size_t past_chunk_length = static_cast(past_seq_len) * head_size; + const int total_seq_len = total_key_lengths[batch_index]; + + const ptrdiff_t output_offset = SafeInt(i) * sequence_length * total_sequence_length; + T* output = attention_probs + output_offset; + + const T* k; + if (packed_qkv) { + k = K + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor); + } else { + k = K + kv_input_chunk_length * (i / kv_num_heads_factor); + } + + // Concatenate past_k + k -> present_k + // TODO: avoid copying mutiple times for a group. + k = ConcatStateChunkGQA(past_key, k, present_key, present_buff_chunk_length, past_buff_chunk_length, + past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, + i / kv_num_heads_factor); + + // Compute Q*K' + AttentionMask + // original transposed each iteration + // A: Q (B x N x) S x H (B x N x) S x H S x H + // B: K' (B x N x) T x H (B x N x) H x T H x T + // C: attention_probs (B x N x) S x T (B x N x) S x T S x T + const T* q; + if (packed_qkv) { + q = Q + packed_batch_stride * batch_index + q_input_chunk_length * head_index; + } else { + q = Q + q_input_chunk_length * i; + } + + DUMP_STRING("i=", i, ",batch_index=", batch_index, ",head_index=", head_index, + ",past_seq_len=", past_seq_len, ",total_seq_len=", total_seq_len, ",packed_qkv=", packed_qkv); + DUMP_CPU_TENSOR("Q", q, sequence_length, head_size); + DUMP_CPU_TENSOR("K", k, total_seq_len, head_size); + + math::GemmEx(CblasNoTrans, CblasTrans, sequence_length, total_seq_len, head_size, alpha, q, + head_size, k, head_size, 0.0f /*bata*/, output, total_seq_len, + nullptr); + + DUMP_CPU_TENSOR("QK", output, sequence_length, total_seq_len); + + // Compute Softmax for causal and output result in place. + T* output_softmax = output; + + int layout_id = head_index % parameters.num_sparse_layout; + bool is_sparse_layout = layout_has_sparse[layout_id]; + + DUMP_STRING("layout_id=", layout_id, ",is_sparse_layout=", is_sparse_layout); + + if (!is_sparse_layout) { // dense + for (int q_id = 0; q_id < sequence_length; q_id++) { + int causal_length = past_seq_len + q_id + 1; + ComputeAttentionSoftmaxInplace(output_softmax, 1, causal_length, nullptr); + for (int remain_seq_id = causal_length; remain_seq_id < total_seq_len; remain_seq_id++) { + output_softmax[remain_seq_id] = 0.f; + } + output_softmax += total_seq_len; + } + } else { // sparse + int q_id = 0; + bool has_sparse = false; + std::vector mask(parameters.max_sequence_length); + + const int32_t* layout_row_indices = block_row_indices + layout_id * parameters.stride_row_indices; + const int32_t* layout_col_indices = block_col_indices + layout_id * parameters.stride_col_indices; + do { + int q_abs_position = past_seq_len + q_id; + int causal_length = q_abs_position + 1; + + // Update mask when query token is the first or at the boundary of sparse block. + if (q_id == 0 || q_abs_position % parameters.sparse_block_size == 0) { + int row_in_sparse_layout = q_abs_position / parameters.sparse_block_size; + int start_in_col_indices = layout_row_indices[row_in_sparse_layout]; + int end_in_col_indices = layout_row_indices[row_in_sparse_layout + 1]; + int nonzero_blocks = end_in_col_indices - start_in_col_indices; + has_sparse = (nonzero_blocks != row_in_sparse_layout + 1); + + DUMP_STRING("q_id=", q_id, + ",q_abs_position=", q_abs_position, + ",sparse_block_size=", parameters.sparse_block_size, + ",row_in_sparse_layout=", row_in_sparse_layout, + ",start_in_col_indices=", start_in_col_indices, + ",end_in_col_indices=", end_in_col_indices, + ",nonzero_blocks=", nonzero_blocks, + ",has_sparse=", has_sparse); + + // Expand attention mask for current row of q_id + if (has_sparse) { + int block_aligned_length = q_abs_position / parameters.sparse_block_size * parameters.sparse_block_size + parameters.sparse_block_size; + DUMP_STRING("block_aligned_length=", block_aligned_length); + + std::fill_n(mask.begin(), block_aligned_length, 0); + for (int j = start_in_col_indices; j < end_in_col_indices; j++) { + int col_in_sparse_layout = layout_col_indices[j]; + + int offset = col_in_sparse_layout * parameters.sparse_block_size; + for (int s = 0; s < parameters.sparse_block_size; s++, offset++) { + mask[offset] = 1; + } + } + + DUMP_CPU_TENSOR("mask", mask, block_aligned_length); + } + } + + // Update inline according to attention mask. + if (has_sparse) { + for (int s = 0; s < causal_length; s++) { + if (mask[s] == 0) + output_softmax[s] = std::numeric_limits::lowest(); + } + } + ComputeAttentionSoftmaxInplace(output_softmax, 1, causal_length, nullptr); + + for (int remain_seq_id = causal_length; remain_seq_id < total_seq_len; remain_seq_id++) { + output_softmax[remain_seq_id] = 0.f; + } + + output_softmax += total_seq_len; + q_id++; + + } while (q_id < sequence_length); + } + + DUMP_CPU_TENSOR("softmax", output, sequence_length, total_seq_len); + } + }); + } + + template + void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH + const T* attention_probs, // Softmax of Q*K' with size BxNxSxT + const T* V, // v value with size BxN_kvxSxH + const int32_t* total_key_lengths, // total sequence lengths + int batch_size, // batch size + int sequence_length, // sequence length + int total_sequence_length, // maximum past_sequence_length + sequence_length + int past_buffer_sequence_length, // sequence length in past state + int present_buffer_sequence_length, // sequence length in past state + int head_size, // head size of Q, K, V + int hidden_size, // hidden size of Output + const T* past_value, // past value only + T* present_value, // present value only + bool past_present_share_buffer, // whether past_key and present_key share the buffer + bool packed_qkv, // whether Q, K, V are packed + ThreadPool* tp) const { + const bool is_prompt = sequence_length == total_sequence_length; + const ptrdiff_t packed_batch_stride = + packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size + : SafeInt(0); + const int kv_num_heads_factor = num_heads_ / kv_num_heads_; + + const int kv_input_chunk_length = sequence_length * head_size; // S x H + const size_t past_buff_chunk_length = static_cast(past_buffer_sequence_length) * head_size; + const size_t present_buff_chunk_length = static_cast(present_buffer_sequence_length) * head_size; + + // The cost of Gemm. + TensorOpCost unit_cost; + // Here we use total_sequence_length to estimate total_key_lengths[batch_index] used in GEMM. + unit_cost.compute_cycles = + static_cast(SafeInt(2) * sequence_length * head_size * total_sequence_length); + unit_cost.bytes_loaded = static_cast(SafeInt(sequence_length + head_size) * + total_sequence_length * sizeof(T)); + unit_cost.bytes_stored = static_cast(sequence_length * head_size * sizeof(T)); + + if (present_value) { + double bytes_to_copy_value = static_cast(sizeof(T) * sequence_length * head_size); + unit_cost.bytes_loaded += bytes_to_copy_value; + unit_cost.bytes_stored += bytes_to_copy_value; + } + + DUMP_CPU_TENSOR_INIT(); + + ThreadPool::TryParallelFor( + tp, SafeInt(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + DUMP_STRING("batch_size=", batch_size, ",num_heads=", num_heads_, ",begin=", begin, ",end=", end); + + for (std::ptrdiff_t i = begin; i != end; ++i) { + const int batch_index = static_cast(i / num_heads_); + const int head_index = static_cast(i % num_heads_); + const int past_seq_len = is_prompt ? 0 : (static_cast(total_key_lengths[batch_index]) - sequence_length); + const size_t past_chunk_length = static_cast(past_seq_len) * head_size; + const int total_seq_len = total_key_lengths[batch_index]; + + DUMP_STRING("i=", i, ",batch_index=", batch_index, ",head_index=", head_index, + ",past_seq_len=", past_seq_len, ",total_seq_len=", total_seq_len, ",packed_qkv=", packed_qkv); + + const T* v; + if (packed_qkv) { + v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor); + } else { + v = V + kv_input_chunk_length * (i / kv_num_heads_factor); + } + + // Concatenate past_v + v -> present_v + v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length, + past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, + i / kv_num_heads_factor); + + DUMP_CPU_TENSOR("present_value", v, total_seq_len, head_size); + + T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size; + ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * total_seq_len * i; + + DUMP_CPU_TENSOR("attention_probs", attention_probs + attention_probs_offset, sequence_length, total_seq_len); + + math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seq_len, + 1.f, /*alpha*/ + attention_probs + attention_probs_offset, total_seq_len, v, + head_size, 0.0f /*beta*/, output_current, hidden_size, nullptr); + + DUMP_CPU_TENSOR("out", attention_probs + attention_probs_offset, sequence_length, head_size); + } + }); + } +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_helper.h b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_helper.h similarity index 98% rename from onnxruntime/contrib_ops/cuda/sparse/sparse_attention_helper.h rename to onnxruntime/contrib_ops/cpu/sparse/sparse_attention_helper.h index a5f1d50e618af..ca69370b4ce17 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_helper.h @@ -21,7 +21,7 @@ Status CheckInputs(void* params, const Tensor* sin_cache, const Tensor* block_row_indices, const Tensor* block_col_indices, - const Tensor* seqlens_k_total, + const Tensor* total_key_lengths, const Tensor* total_seq_len) { // No packing for q/k/v: // query (batch_size, sequence_length, num_heads * head_size) @@ -36,7 +36,7 @@ Status CheckInputs(void* params, // past_value (batch_size, kv_num_heads, max_cache_sequence_length, head_size) // block_row_indices (num_layout, max_blocks + 1), where max_blocks = max_sequence_length / sparse_block_size // block_col_indices (num_layout, max_nnz) - // seqlens_k_total (batch_size) when do_rotary is True, optional otherwise + // total_key_lengths (batch_size) // total_seq_len (1) // cos_cache (max_rotary_sequence_length, rotary_dim / 2) when do_rotary is true. // sin_cache (max_rotary_sequence_length, rotary_dim / 2) when do_rotary is true. @@ -197,7 +197,7 @@ Status CheckInputs(void* params, } // Check the shape of total_key_sequence_lengths. We do not check the values here. - const auto& k_len_dim = seqlens_k_total->Shape().GetDims(); + const auto& k_len_dim = total_key_lengths->Shape().GetDims(); if (k_len_dim.size() != 1 && k_len_dim[0] != batch_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "key_total_sequence_lengths must have shape (batch_size)."); diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc index 7d3f6eb9295d8..865a1dc29ce47 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc @@ -3,7 +3,7 @@ #include "contrib_ops/cuda/sparse/sparse_attention_impl.h" #include "contrib_ops/cuda/sparse/sparse_attention.h" -#include "contrib_ops/cuda/sparse/sparse_attention_helper.h" +#include "contrib_ops/cpu/sparse/sparse_attention_helper.h" #include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_v1_api.h" #include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_api.h" #include "core/platform/env_var_utils.h" diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 2a14ba1db4bb7..7272a949f7218 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1254,7 +1254,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "present_value", "Updated value cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size).", "T") - .TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.") .TypeConstraint("M", {"tensor(int32)"}, "Constrain integer type.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { SparseAttentionTypeAndShapeInference(ctx, 3); diff --git a/onnxruntime/test/python/transformers/test_sparse_attention.py b/onnxruntime/test/python/transformers/test_sparse_attention.py index f33a56ee4e1f9..f18bcdba65579 100644 --- a/onnxruntime/test/python/transformers/test_sparse_attention.py +++ b/onnxruntime/test/python/transformers/test_sparse_attention.py @@ -8,14 +8,16 @@ """ import math import unittest -from typing import Optional +from typing import Optional, Union import torch +from benchmark_mha import InputFormats from onnx import TensorProto, helper +from parameterized import parameterized from torch import Tensor -from onnxruntime import InferenceSession, SessionOptions -from onnxruntime.transformers.io_binding_helper import CudaSession, GpuBindingManager +from onnxruntime import InferenceSession, SessionOptions, get_available_providers +from onnxruntime.transformers.io_binding_helper import CudaSession ENABLE_DEBUG = False @@ -34,6 +36,7 @@ def __init__( softmax_scale: Optional[float], do_rotary: bool, rotary_interleaved: bool, + provider: str = "CUDAExecutionProvider", device="cuda", dtype=torch.float16, share_buffer: bool = True, @@ -62,11 +65,13 @@ def __init__( self.do_rotary = do_rotary self.rotary_interleaved = rotary_interleaved + + self.provider = provider self.device = device + self.dtype = dtype self.share_buffer = share_buffer self.is_packed_qkv = is_packed_qkv - self.dtype = dtype def shape_dict(self): shapes = { @@ -106,7 +111,7 @@ def get_cos_sin_cache(self, dtype): def random_inputs(self): device = self.device # Since bfloat16 is not supported in ORT python I/O binding API, we always use float16 as model inputs. - dtype = torch.float16 + dtype = torch.float16 if self.dtype == torch.bfloat16 else self.dtype # Always use non-packed qkv to generate same inputs for Torch and ORT. packed = self.is_packed_qkv # Save the original value. @@ -153,7 +158,9 @@ def __init__( softmax_scale=None, do_rotary: bool = False, rotary_interleaved: bool = False, + provider: str = "CUDAExecutionProvider", device="cuda", + dtype=torch.float16, local_window_size: int = -1, attention_mask=None, is_packed_qkv=False, @@ -162,17 +169,19 @@ def __init__( ): super().__init__( "GroupQueryAttention", - batch_size, - sequence_length, - max_sequence_length, - past_sequence_length, - num_heads, - kv_num_heads, - head_size, - softmax_scale, - do_rotary, - rotary_interleaved, - device, + batch_size=batch_size, + sequence_length=sequence_length, + max_sequence_length=max_sequence_length, + past_sequence_length=past_sequence_length, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + softmax_scale=softmax_scale, + do_rotary=do_rotary, + rotary_interleaved=rotary_interleaved, + provider=provider, + device=device, + dtype=dtype, is_packed_qkv=is_packed_qkv, max_cache_sequence_length=max_cache_sequence_length, max_rotary_sequence_length=max_rotary_sequence_length, @@ -220,24 +229,28 @@ def __init__( softmax_scale=None, do_rotary: bool = False, rotary_interleaved: bool = False, + provider: str = "CUDAExecutionProvider", device="cuda", + dtype=torch.float16, is_packed_qkv=False, max_cache_sequence_length=None, max_rotary_sequence_length=None, ): super().__init__( "SparseAttention", - batch_size, - sequence_length, - max_sequence_length, - past_sequence_length, - num_heads, - kv_num_heads, - head_size, - softmax_scale, - do_rotary, - rotary_interleaved, - device, + batch_size=batch_size, + sequence_length=sequence_length, + max_sequence_length=max_sequence_length, + past_sequence_length=past_sequence_length, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + softmax_scale=softmax_scale, + do_rotary=do_rotary, + rotary_interleaved=rotary_interleaved, + provider=provider, + device=device, + dtype=dtype, is_packed_qkv=is_packed_qkv, max_cache_sequence_length=max_cache_sequence_length, max_rotary_sequence_length=max_rotary_sequence_length, @@ -288,17 +301,19 @@ def random_inputs(self): def get_comparable_ort_gqa_config(self, use_local=False) -> GroupQueryAttentionConfig: return GroupQueryAttentionConfig( - self.batch_size, - self.sequence_length, - self.max_sequence_length, - self.past_sequence_length, - self.num_heads, - self.kv_num_heads, - self.head_size, - self.softmax_scale, - self.do_rotary, - self.rotary_interleaved, - self.device, + batch_size=self.batch_size, + sequence_length=self.sequence_length, + max_sequence_length=self.max_sequence_length, + past_sequence_length=self.past_sequence_length, + num_heads=self.num_heads, + kv_num_heads=self.kv_num_heads, + head_size=self.head_size, + softmax_scale=self.softmax_scale, + do_rotary=self.do_rotary, + rotary_interleaved=self.rotary_interleaved, + provider=self.provider, + device=self.device, + dtype=self.dtype, local_window_size=self.local_blocks * self.sparse_block_size if use_local else -1, is_packed_qkv=self.is_packed_qkv, max_cache_sequence_length=self.max_cache_sequence_length, @@ -314,17 +329,19 @@ def get_comparable_torch_gqa_config(self, use_sparse=False) -> GroupQueryAttenti attention_mask = attention_mask[:, :, -self.sequence_length :, :] return GroupQueryAttentionConfig( - self.batch_size, - self.sequence_length, - self.max_sequence_length, - self.past_sequence_length, - self.num_heads, - self.kv_num_heads, - self.head_size, - self.softmax_scale, - self.do_rotary, - self.rotary_interleaved, - self.device, + batch_size=self.batch_size, + sequence_length=self.sequence_length, + max_sequence_length=self.max_sequence_length, + past_sequence_length=self.past_sequence_length, + num_heads=self.num_heads, + kv_num_heads=self.kv_num_heads, + head_size=self.head_size, + softmax_scale=self.softmax_scale, + do_rotary=self.do_rotary, + rotary_interleaved=self.rotary_interleaved, + provider=self.provider, + device=self.device, + dtype=self.dtype, attention_mask=attention_mask, is_packed_qkv=False, # torch reference implementation does not support packed qkv. max_cache_sequence_length=self.max_cache_sequence_length, @@ -375,7 +392,7 @@ def get_dense_mask(block_mask, total_seq_len, query_seq_len, block_size): def create_sparse_attention_onnx_model(config: SparseAttentionConfig): # ORT Python I/O binding API does not support bf16, so always use fp16 as graph inputs/outputs. - io_float_type = TensorProto.FLOAT16 + io_float_type = TensorProto.FLOAT if config.dtype == torch.float32 else TensorProto.FLOAT16 suffix = "_bf16" if config.dtype == torch.bfloat16 else "" nodes = [ @@ -487,9 +504,9 @@ def create_sparse_attention_onnx_model(config: SparseAttentionConfig): def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig): - assert config.dtype == torch.float16 + assert config.dtype in [torch.float16, torch.float32] - float_type = TensorProto.FLOAT16 + float_type = TensorProto.FLOAT16 if config.dtype in [torch.float16] else TensorProto.FLOAT nodes = [ helper.make_node( "GroupQueryAttention", @@ -599,7 +616,10 @@ def group_query_attention_reference( attn_output = torch.einsum("bhmn,bhnd->bhmd", attn.type_as(value), value) result = attn_output.transpose(1, 2).contiguous() - torch.cuda.synchronize() + + if torch.cuda.is_available(): + torch.cuda.synchronize() + return result @@ -671,25 +691,42 @@ def infer(self): ) +def create_ort_session( + config: Union[SparseAttentionConfig, GroupQueryAttentionConfig], session_options=None, enable_cuda_graph=False +) -> CudaSession: + if isinstance(config, SparseAttentionConfig): + onnx_model_str = create_sparse_attention_onnx_model(config) + else: + onnx_model_str = create_group_query_attention_onnx_model(config) + + if config.provider == "CUDAExecutionProvider": + device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index + provider_options = CudaSession.get_cuda_provider_options( + device_id, enable_cuda_graph=enable_cuda_graph, stream=torch.cuda.current_stream().cuda_stream + ) + providers = [(config.provider, provider_options), "CPUExecutionProvider"] + else: + providers = ["CPUExecutionProvider"] + + ort_session = InferenceSession(onnx_model_str, session_options, providers=providers) + # Note that CudaSession could work with both CUDA and CPU providers. + cuda_session = CudaSession(ort_session, config.device, enable_cuda_graph=enable_cuda_graph) + shape_dict = config.shape_dict() + cuda_session.allocate_buffers(shape_dict) + + buffer_sharing = {"past_key": "present_key", "past_value": "present_value"} + for input_name, output_name in buffer_sharing.items(): + cuda_session.set_buffer_sharing(input_name, output_name) + + return cuda_session + + class OrtGroupQueryAttention: """A wrapper of ORT GroupQueryAttention to test relevance and performance.""" def __init__(self, config: GroupQueryAttentionConfig): - cuda_provider_options = CudaSession.get_cuda_provider_options( - torch.cuda.current_device(), enable_cuda_graph=False, stream=torch.cuda.current_stream().cuda_stream - ) - onnx_model_str = create_group_query_attention_onnx_model(config) - self.ort_session = create_session(onnx_model_str, cuda_provider_options=cuda_provider_options) - self.gpu_binding_manager = GpuBindingManager( - ort_session=self.ort_session, - device=config.device, - stream=torch.cuda.current_stream().cuda_stream, - max_cuda_graphs=2, - ) - buffer_sharing = {"past_key": "present_key", "past_value": "present_value"} - self.gpu_binding = self.gpu_binding_manager.get_binding( - config.shape_dict(), use_cuda_graph=False, buffer_sharing=buffer_sharing - ) + self.session = create_ort_session(config) + self.feed_dict = config.random_inputs() if ENABLE_DEBUG and not config.is_packed_qkv: @@ -709,28 +746,14 @@ def __init__(self, config: GroupQueryAttentionConfig): print("seqlens_k (BSNH, GQA)", self.feed_dict["seqlens_k"]) def infer(self): - return self.gpu_binding.infer(self.feed_dict) + return self.session.infer(self.feed_dict) class OrtSparseAttention: """A wrapper of ORT SparseAttention to test relevance and performance.""" def __init__(self, config: SparseAttentionConfig): - cuda_provider_options = CudaSession.get_cuda_provider_options( - torch.cuda.current_device(), enable_cuda_graph=False, stream=torch.cuda.current_stream().cuda_stream - ) - onnx_model_str = create_sparse_attention_onnx_model(config) - self.ort_session = create_session(onnx_model_str, cuda_provider_options=cuda_provider_options) - self.gpu_binding_manager = GpuBindingManager( - ort_session=self.ort_session, - device=config.device, - stream=torch.cuda.current_stream().cuda_stream, - max_cuda_graphs=2, - ) - buffer_sharing = {"past_key": "present_key", "past_value": "present_value"} - self.gpu_binding = self.gpu_binding_manager.get_binding( - config.shape_dict(), use_cuda_graph=False, buffer_sharing=buffer_sharing - ) + self.session = create_ort_session(config) self.feed_dict = config.random_inputs() if ENABLE_DEBUG and not config.is_packed_qkv: @@ -753,19 +776,196 @@ def __init__(self, config: SparseAttentionConfig): print("key_total_sequence_lengths", self.feed_dict["key_total_sequence_lengths"]) def infer(self): - return self.gpu_binding.infer(self.feed_dict) + return self.session.infer(self.feed_dict) + + +def get_provider_support_info(provider: str, use_kv_cache: bool): + if provider == "CUDAExecutionProvider": + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH, InputFormats.QKV_BSN3H] + device_id = torch.cuda.current_device() + device = torch.device("cuda", device_id) + dtype = torch.float16 + else: + assert provider == "CPUExecutionProvider" + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH, InputFormats.QKV_BSN3H] + device = torch.device("cpu") + dtype = torch.float + return device, dtype, formats + + +def has_cuda_support(): + if torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers(): + major, minor = torch.cuda.get_device_capability() + sm = major * 10 + minor + return sm in [75, 80, 86, 89, 90] + + return False + + +def get_simple_test_case(provider: str, has_past_kv: bool): + """A simple test case for debugging purpose.""" + device, dtype, _formats = get_provider_support_info(provider, False) + if provider == "CPUExecutionProvider": + # A simple case for debugging purpose. + max_sequence_length = 16 + sequence_length = 15 + packed_qkv = False + config = SparseAttentionConfig( + batch_size=1, + sequence_length=1 if has_past_kv else sequence_length, + max_sequence_length=max_sequence_length, + past_sequence_length=sequence_length if has_past_kv else 0, + num_heads=4, + kv_num_heads=2, + head_size=8, + sparse_block_size=4, + num_layout=2, + local_blocks=2, + vert_stride=2, + softmax_scale=0.0, + provider=provider, + device=device, + dtype=dtype, + is_packed_qkv=packed_qkv, + max_cache_sequence_length=max_sequence_length, + ) + yield config + + +def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rotary=False): + if provider == "CUDAExecutionProvider" and not has_cuda_support(): + return + yield + + device, dtype, formats = get_provider_support_info(provider, False) + batch_sizes = [1, 2, 3] + sequence_lengths = [1, 64, 127, 128, 192, 256] + heads = [4, 8, 16] + + # SparseAttention CUDA kernel only supports head size 128 + head_sizes = [128] if provider == "CUDAExecutionProvider" else [128, 256] + + if comprehensive: + for batch_size in batch_sizes: + for sequence_length in sequence_lengths: + for num_heads in heads: + for head_size in head_sizes: + for format in formats: + packed_qkv = format == InputFormats.QKV_BSN3H + + non_prompt_len = 1 + if provider == "CPUExecutionProvider" and sequence_length > 128 and not do_rotary: + # Generate case of sequence_length > 1 when it is not prompt for CPU provider. + non_prompt_len = batch_size + + query_sequence_length = non_prompt_len if has_past_kv else sequence_length + config = SparseAttentionConfig( + batch_size=batch_size, + sequence_length=query_sequence_length, + max_sequence_length=256, + past_sequence_length=( + min(256 - query_sequence_length, sequence_length) if has_past_kv else 0 + ), + num_heads=num_heads, + kv_num_heads=num_heads // 2, + head_size=head_size, + sparse_block_size=64, + num_layout=2, + local_blocks=2, + vert_stride=2, + softmax_scale=1.8 / (128**0.5), + provider=provider, + device=device, + dtype=dtype, + is_packed_qkv=packed_qkv, + do_rotary=do_rotary, + rotary_interleaved=sequence_length <= 128, + max_cache_sequence_length=None if sequence_length >= 128 else 128, + ) + yield config + else: + test_cases = max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes)) + for i in range(test_cases): + batch_size = batch_sizes[i % len(batch_sizes)] + sequence_length = sequence_lengths[i % len(sequence_lengths)] + num_heads = heads[i % len(heads)] + head_size = head_sizes[i % len(head_sizes)] + format = formats[i % len(formats)] + packed_qkv = format == InputFormats.QKV_BSN3H + + non_prompt_len = 1 + if provider == "CPUExecutionProvider" and sequence_length > 128 and not do_rotary: + # Generate case of sequence_length > 1 when it is not prompt for CPU provider. + non_prompt_len = batch_size + + query_sequence_length = non_prompt_len if has_past_kv else sequence_length + + config = SparseAttentionConfig( + batch_size=batch_size, + sequence_length=query_sequence_length, + max_sequence_length=256, + past_sequence_length=min(256 - query_sequence_length, sequence_length) if has_past_kv else 0, + num_heads=num_heads, + kv_num_heads=num_heads // 2, + head_size=head_size, + sparse_block_size=64, + num_layout=2, + local_blocks=2, + vert_stride=2, + softmax_scale=1.8 / (128**0.5), + provider=provider, + device=device, + dtype=dtype, + is_packed_qkv=packed_qkv, + do_rotary=do_rotary, + rotary_interleaved=sequence_length <= 128, + max_cache_sequence_length=None if sequence_length >= 128 else 128, # test smaller kv cache buffer. + ) + yield config + + +# Do not run too many tests in CI pipeline. Change it to True to run all combinations in dev machine. +comprehensive_mode = False class TestSparseAttention(unittest.TestCase): - @unittest.skipUnless(torch.cuda.is_available(), "cuda not available") + @unittest.skipUnless(has_cuda_support(), "cuda not available") def test_sparse_attention(self): major, minor = torch.cuda.get_device_capability() sm = major * 10 + minor + self.run_relevance_test(sm) - if sm not in [75, 80, 86, 89, 90]: - self.skipTest("SparseAttention is not supported on this GPU") + @parameterized.expand(get_simple_test_case("CPUExecutionProvider", True), skip_on_empty=True) + def test_simple_token_cpu(self, config: SparseAttentionConfig): + self.run_one_relevance_test(config) - self.run_relevance_test(sm) + @parameterized.expand(get_simple_test_case("CPUExecutionProvider", False), skip_on_empty=True) + def test_simple_prompt_cpu(self, config: SparseAttentionConfig): + self.run_one_relevance_test(config) + + @parameterized.expand( + get_test_cases("CPUExecutionProvider", True, comprehensive_mode, do_rotary=True), skip_on_empty=True + ) + def test_sparse_att_token_cpu_rotary(self, config: SparseAttentionConfig): + # When there is rotary, we use ORT GQA as reference: ORT GQA does not support mask so here we use dense. + if config.sparse_block_size * config.local_blocks > config.total_sequence_length: + self.run_one_relevance_test(config) + + @parameterized.expand(get_test_cases("CUDAExecutionProvider", True, comprehensive_mode), skip_on_empty=True) + def test_sparse_att_token_gpu(self, config): + self.run_one_relevance_test(config) + + @parameterized.expand(get_test_cases("CPUExecutionProvider", True, comprehensive_mode), skip_on_empty=True) + def test_sparse_att_token_cpu(self, config): + self.run_one_relevance_test(config) + + @parameterized.expand(get_test_cases("CPUExecutionProvider", False, comprehensive_mode), skip_on_empty=True) + def test_sparse_att_prompt_cpu(self, config): + self.run_one_relevance_test(config) + + @parameterized.expand(get_test_cases("CUDAExecutionProvider", False, comprehensive_mode), skip_on_empty=True) + def test_sparse_att_prompt_gpu(self, config): + self.run_one_relevance_test(config) def run_one_relevance_test(self, config: SparseAttentionConfig): if (not config.do_rotary) and config.total_sequence_length <= 2048: @@ -774,6 +974,10 @@ def run_one_relevance_test(self, config: SparseAttentionConfig): obj = TorchGroupQueryAttention(gqa_config) expected_out = obj.infer() else: + if config.dtype == torch.bfloat16: + # Skip test since create_group_query_attention_onnx_model does not support bfloat16 right now. + return + # Run QGA by ORT (support packed QKV, rotary and very long sequence, but no mask so dense only). gqa_config: GroupQueryAttentionConfig = config.get_comparable_ort_gqa_config(use_local=False) obj = OrtGroupQueryAttention(gqa_config) @@ -881,6 +1085,8 @@ def run_relevance_no_past_128k(self, sm: int, device): local_blocks=2048, # use dense to compare with GQA vert_stride=8, softmax_scale=None, + provider="CUDAExecutionProvider", + dtype=torch.float16, device=device, is_packed_qkv=packed_qkv, ) @@ -907,6 +1113,8 @@ def run_relevance_past_128k(self, sm: int, device): local_blocks=2048, # use dense to compare with GQA vert_stride=8, softmax_scale=None, + provider="CUDAExecutionProvider", + dtype=torch.float16, device=device, is_packed_qkv=packed_qkv, ) @@ -921,7 +1129,8 @@ def run_relevance_test(self, sm: int): device = torch.device("cuda", device_id) with torch.no_grad(): # Test long sequence when GPU memory is enough (need about 12 GB for 128K sequence length) - if torch.cuda.get_device_properties(device_id).total_memory > 13 * 1024 * 1024 * 1024: + # The 128k tests fails randomly in T4 GPU, increase memory threshold for now. + if torch.cuda.get_device_properties(device_id).total_memory > 20 * 1024 * 1024 * 1024: self.run_relevance_no_past_128k(sm, device) self.run_relevance_past_128k(sm, device) self.run_relevance_no_past(sm, device) From 40d4b2ec75a17d48e16d2794d3d6ced39333a265 Mon Sep 17 00:00:00 2001 From: kailums <109063327+kailums@users.noreply.github.com> Date: Thu, 4 Jul 2024 14:32:28 +0800 Subject: [PATCH 09/13] exclude split3inner kernel on rocm ep (#21238) ### Description There is an issue when using split3inner kernel on rocm-6.0.3, exclude these code from rocm EP. --- onnxruntime/core/providers/cuda/tensor/split.cc | 2 ++ onnxruntime/core/providers/cuda/tensor/split_impl.cu | 2 ++ 2 files changed, 4 insertions(+) diff --git a/onnxruntime/core/providers/cuda/tensor/split.cc b/onnxruntime/core/providers/cuda/tensor/split.cc index ca82387600085..52775b2e8be7a 100644 --- a/onnxruntime/core/providers/cuda/tensor/split.cc +++ b/onnxruntime/core/providers/cuda/tensor/split.cc @@ -76,6 +76,7 @@ Status SplitKernel::ComputeInternal(OpKernelContext* ctx) const { auto input_dims = input_shape.GetDims(); auto output_dimensions{input_shape.AsShapeVector()}; +#ifndef USE_ROCM if (split_sizes.size() == 3 && ((axis + 1) == gsl::narrow_cast(input_shape.NumDimensions()))) { // we use (axis + 1) == num_dimensions to check if we are splitting on inner most axis. // only when split on inner axis and output size is 3, we can use Split3Inner. @@ -100,6 +101,7 @@ Status SplitKernel::ComputeInternal(OpKernelContext* ctx) const { output2->MutableDataRaw(), input_dims); } +#endif CudaAsyncBuffer output_ptr(this, num_outputs); gsl::span output_ptr_span = output_ptr.CpuSpan(); diff --git a/onnxruntime/core/providers/cuda/tensor/split_impl.cu b/onnxruntime/core/providers/cuda/tensor/split_impl.cu index 00f94694f83c0..e2f42e4d5855c 100644 --- a/onnxruntime/core/providers/cuda/tensor/split_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/split_impl.cu @@ -157,6 +157,7 @@ Status SplitImpl(cudaStream_t stream, const size_t element_size, const int block return Status::OK(); } +#ifndef USE_ROCM template __global__ void _Split3InnerKernel(const int64_t size0_in_byte, const int64_t size1_in_byte, @@ -263,6 +264,7 @@ Status Split3Inner(cudaStream_t stream, const size_t element_size, const int64_t return Status::OK(); } +#endif } // namespace cuda } // namespace onnxruntime From 07c429191e19678b97ec8fe818ecb9a64ac6b394 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 4 Jul 2024 00:54:13 -0700 Subject: [PATCH 10/13] Delete path.h (#21211) ### Description Delete path.h and replace all occurrences of onnxruntime::Path with std::filesystem::path. Previously we couldn't use C++17's std::filesystem because it was not supported in iOS 12(which was released in 2018). Now we dropped the support for iOS 12. ### Motivation and Context To simplify code. For example, if an EP wants to use the Path class, now it can directly use it without going through a wrapper. And the standard implementation can handle various path types better. (We didn't take much consideration on UNC path, "/" as a path separator on Windows, etc). --- include/onnxruntime/core/graph/graph.h | 2 +- onnxruntime/core/common/path.cc | 308 ------------------ onnxruntime/core/common/path.h | 106 ------ .../debug_node_inputs_outputs_utils.cc | 10 +- .../debug_node_inputs_outputs_utils.h | 5 +- onnxruntime/core/framework/tensorprotoutils.h | 1 - onnxruntime/core/graph/model.h | 1 - onnxruntime/core/optimizer/initializer.cc | 1 - onnxruntime/core/optimizer/initializer.h | 1 - .../shared_library/provider_interfaces.h | 7 - .../shared_library/provider_wrappedtypes.h | 13 - onnxruntime/core/providers/tvm/tvm_api.cc | 1 - .../providers/vitisai/imp/tensor_proto.cc | 3 +- .../core/session/provider_bridge_ort.cc | 8 +- onnxruntime/test/common/path_test.cc | 256 --------------- .../debug_node_inputs_outputs_utils_test.cc | 4 +- .../test/flatbuffers/flatbuffer_utils_test.cc | 1 - onnxruntime/test/ir/graph_test.cc | 8 +- .../test/optimizer/graph_transform_test.cc | 7 +- .../core/framework/checkpoint_common.h | 1 - .../core/framework/checkpointing.cc | 10 +- .../models/runner/training_runner.cc | 5 +- 22 files changed, 25 insertions(+), 734 deletions(-) delete mode 100644 onnxruntime/core/common/path.cc delete mode 100644 onnxruntime/core/common/path.h delete mode 100644 onnxruntime/test/common/path_test.cc diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 538cbfdcefc47..7dabe42ba0a28 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -17,12 +17,12 @@ #include "core/common/gsl.h" #include "core/common/common.h" +#include "core/common/path_string.h" #include "core/common/const_pointer_container.h" #if !defined(ORT_MINIMAL_BUILD) #include "core/common/inlined_containers.h" #endif #include "core/common/inlined_containers_fwd.h" -#include "core/common/path.h" #include "core/common/span_utils.h" #include "core/common/status.h" #include "core/common/logging/logging.h" diff --git a/onnxruntime/core/common/path.cc b/onnxruntime/core/common/path.cc deleted file mode 100644 index 8b74d2d8c9c1c..0000000000000 --- a/onnxruntime/core/common/path.cc +++ /dev/null @@ -1,308 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/common/path.h" - -#include -#include - -namespace onnxruntime { - -namespace { - -constexpr auto k_dot = ORT_TSTR("."); -constexpr auto k_dotdot = ORT_TSTR(".."); - -constexpr std::array k_valid_path_separators{ - ORT_TSTR('/'), ORT_TSTR('\\')}; - -constexpr bool IsPreferredPathSeparator(PathChar c) { - return c == k_preferred_path_separator; -} - -PathString NormalizePathSeparators(const PathString& path) { - PathString result{}; - std::replace_copy_if( - path.begin(), path.end(), std::back_inserter(result), - [](PathChar c) { - return std::find( - k_valid_path_separators.begin(), - k_valid_path_separators.end(), - c) != k_valid_path_separators.end(); - }, - k_preferred_path_separator); - return result; -} - -// parse component and trailing path separator -PathString::const_iterator ParsePathComponent( - PathString::const_iterator begin, PathString::const_iterator end, - PathString::const_iterator& component_end, bool* has_trailing_separator) { - component_end = std::find_if(begin, end, IsPreferredPathSeparator); - const auto sep_end = std::find_if_not(component_end, end, IsPreferredPathSeparator); - if (has_trailing_separator) *has_trailing_separator = sep_end != component_end; - return sep_end; -} - -#ifdef _WIN32 - -Status ParsePathRoot( - const PathString& path, - PathString& root, bool& has_root_dir, size_t& num_parsed_chars) { - // assume NormalizePathSeparators() has been called - - // drive letter - if (path.size() > 1 && - (ORT_TSTR('a') <= path[0] && path[0] <= ORT_TSTR('z') || - ORT_TSTR('A') <= path[0] && path[0] <= ORT_TSTR('Z')) && - path[1] == ORT_TSTR(':')) { - const auto root_dir_begin = path.begin() + 2; - const auto root_dir_end = std::find_if_not(root_dir_begin, path.end(), IsPreferredPathSeparator); - - root = path.substr(0, 2); - has_root_dir = root_dir_begin != root_dir_end; - num_parsed_chars = std::distance(path.begin(), root_dir_end); - return Status::OK(); - } - - // leading path separator - auto curr_it = std::find_if_not(path.begin(), path.end(), IsPreferredPathSeparator); - const auto num_initial_seps = std::distance(path.begin(), curr_it); - - if (num_initial_seps == 2) { - // "\\server_name\share_name\" - // after "\\", parse 2 path components with trailing separators - PathString::const_iterator component_end; - bool has_trailing_separator; - curr_it = ParsePathComponent(curr_it, path.end(), component_end, &has_trailing_separator); - ORT_RETURN_IF_NOT(has_trailing_separator, "Failed to parse path root: ", ToUTF8String(path)); - curr_it = ParsePathComponent(curr_it, path.end(), component_end, &has_trailing_separator); - ORT_RETURN_IF_NOT(has_trailing_separator, "Failed to parse path root: ", ToUTF8String(path)); - - root.assign(path.begin(), component_end); - has_root_dir = true; - num_parsed_chars = std::distance(path.begin(), curr_it); - } else { - // "\", "" - root.clear(); - has_root_dir = num_initial_seps > 0; - num_parsed_chars = num_initial_seps; - } - - return Status::OK(); -} - -#else // POSIX - -Status ParsePathRoot( - const PathString& path, - PathString& root, bool& has_root_dir, size_t& num_parsed_chars) { - // assume NormalizePathSeparators() has been called - auto curr_it = std::find_if_not(path.begin(), path.end(), IsPreferredPathSeparator); - const auto num_initial_seps = std::distance(path.begin(), curr_it); - - if (num_initial_seps == 2) { - // "//root_name/" - // after "//", parse path component with trailing separator - PathString::const_iterator component_end; - bool has_trailing_separator; - curr_it = ParsePathComponent(curr_it, path.end(), component_end, &has_trailing_separator); - ORT_RETURN_IF_NOT(has_trailing_separator, "Failed to parse path root: ", ToUTF8String(path)); - - root.assign(path.begin(), component_end); - has_root_dir = true; - num_parsed_chars = std::distance(path.begin(), curr_it); - } else { - // "/", "" - root.clear(); - has_root_dir = num_initial_seps > 0; - num_parsed_chars = num_initial_seps; - } - - return Status::OK(); -} - -#endif - -} // namespace - -Status Path::Parse(const PathString& original_path_str, Path& path) { - Path result{}; - - // normalize separators - const PathString path_str = NormalizePathSeparators(original_path_str); - - // parse root - size_t root_length = 0; - ORT_RETURN_IF_ERROR(ParsePathRoot( - path_str, result.root_name_, result.has_root_dir_, root_length)); - - // parse components - PathString::const_iterator component_begin = path_str.begin() + root_length; - while (component_begin != path_str.end()) { - PathString::const_iterator component_end; - PathString::const_iterator next_component_begin = ParsePathComponent( - component_begin, path_str.end(), component_end, nullptr); - result.components_.emplace_back(component_begin, component_end); - component_begin = next_component_begin; - } - - path = std::move(result); - return Status::OK(); -} - -Path Path::Parse(const PathString& path_str) { - Path path{}; - const auto status = Parse(path_str, path); - ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); - return path; -} - -PathString Path::ToPathString() const { - PathString result = GetRootPathString(); - const size_t components_size = components_.size(); - for (size_t i = 0; i < components_size; ++i) { - result += components_[i]; - if (i + 1 < components_size) result += k_preferred_path_separator; - } - return result; -} - -PathString Path::GetRootPathString() const { - return has_root_dir_ ? root_name_ + k_preferred_path_separator : root_name_; -} - -bool Path::IsEmpty() const { - return !has_root_dir_ && root_name_.empty() && components_.empty(); -} - -bool Path::IsAbsolute() const { -#ifdef _WIN32 - return has_root_dir_ && !root_name_.empty(); -#else // POSIX - return has_root_dir_; -#endif -} - -Path Path::ParentPath() const { - Path parent{*this}; - if (!parent.components_.empty()) parent.components_.pop_back(); - return parent; -} - -Path& Path::Normalize() { - if (IsEmpty()) return *this; - - // handle . and .. - std::vector normalized_components{}; - for (const auto& component : components_) { - // ignore . - if (component == k_dot) continue; - - // handle .. which backtracks over previous component - if (component == k_dotdot) { - if (!normalized_components.empty() && - normalized_components.back() != k_dotdot) { - normalized_components.pop_back(); - continue; - } - } - - normalized_components.emplace_back(component); - } - - // remove leading ..'s if root dir present - if (has_root_dir_) { - const auto first_non_dotdot_it = std::find_if( - normalized_components.begin(), normalized_components.end(), - [](const PathString& component) { return component != k_dotdot; }); - normalized_components.erase( - normalized_components.begin(), first_non_dotdot_it); - } - - // if empty at this point, add a dot - if (!has_root_dir_ && root_name_.empty() && normalized_components.empty()) { - normalized_components.emplace_back(k_dot); - } - - components_ = std::move(normalized_components); - - return *this; -} - -Path& Path::Append(const Path& other) { - if (other.IsAbsolute() || - (!other.root_name_.empty() && other.root_name_ != root_name_)) { - return *this = other; - } - - if (other.has_root_dir_) { - has_root_dir_ = true; - components_.clear(); - } - - components_.insert( - components_.end(), other.components_.begin(), other.components_.end()); - - return *this; -} - -Path& Path::Concat(const PathString& value) { - auto first_separator = std::find_if(value.begin(), value.end(), - [](PathChar c) { - return std::find( - k_valid_path_separators.begin(), - k_valid_path_separators.end(), - c) != k_valid_path_separators.end(); - }); - ORT_ENFORCE(first_separator == value.end(), - "Cannot concatenate with a string containing a path separator. String: ", ToUTF8String(value)); - - if (components_.empty()) { - components_.push_back(value); - } else { - components_.back() += value; - } - return *this; -} - -Status RelativePath(const Path& src, const Path& dst, Path& rel) { - ORT_RETURN_IF_NOT( - src.GetRootPathString() == dst.GetRootPathString(), - "Paths must have the same root to compute a relative path. ", - "src root: ", ToUTF8String(src.GetRootPathString()), - ", dst root: ", ToUTF8String(dst.GetRootPathString())); - - const Path norm_src = src.NormalizedPath(), norm_dst = dst.NormalizedPath(); - const auto& src_components = norm_src.GetComponents(); - const auto& dst_components = norm_dst.GetComponents(); - - const auto min_num_components = std::min( - src_components.size(), dst_components.size()); - - const auto mismatch_point = std::mismatch( - src_components.begin(), src_components.begin() + min_num_components, - dst_components.begin()); - - const auto& common_src_components_end = mismatch_point.first; - const auto& common_dst_components_end = mismatch_point.second; - - std::vector rel_components{}; - rel_components.reserve( - (src_components.end() - common_src_components_end) + - (dst_components.end() - common_dst_components_end)); - - std::fill_n( - std::back_inserter(rel_components), - (src_components.end() - common_src_components_end), - k_dotdot); - - std::copy( - common_dst_components_end, dst_components.end(), - std::back_inserter(rel_components)); - - rel = Path(PathString{}, false, rel_components); - return Status::OK(); -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/common/path.h b/onnxruntime/core/common/path.h deleted file mode 100644 index 732bbabe8ae3e..0000000000000 --- a/onnxruntime/core/common/path.h +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -#include "core/common/common.h" -#include "core/common/path_string.h" - -namespace onnxruntime { - -#ifdef _WIN32 -constexpr PathChar k_preferred_path_separator = ORT_TSTR('\\'); -#else // POSIX -constexpr PathChar k_preferred_path_separator = ORT_TSTR('/'); -#endif - -// Note: We should use the std::filesystem library after upgrading to C++17. - -/** A filesystem path. */ -class Path { - public: - Path() = default; - Path(const Path&) = default; - Path& operator=(const Path&) = default; - Path(Path&&) = default; - Path& operator=(Path&&) = default; - - /** Parses a path from `path_str`. */ - static Status Parse(const PathString& path_str, Path& path); - /** Parses a path from `path_str`. Throws on failure. */ - static Path Parse(const PathString& path_str); - - /** Gets a string representation of the path. */ - PathString ToPathString() const; - /** Gets a string representation of the path's root path, if any. */ - PathString GetRootPathString() const; - /** Gets the path components following the path root. */ - const std::vector& GetComponents() const { return components_; } - - /** Whether the path is empty. */ - bool IsEmpty() const; - - /** Whether the path is absolute (refers unambiguously to a file location). */ - bool IsAbsolute() const; - /** Whether the path is relative (not absolute). */ - bool IsRelative() const { return !IsAbsolute(); } - - /** Returns a copy of the path without the last component. */ - Path ParentPath() const; - - /** - * Normalizes the path. - * A normalized path is one with "."'s and ".."'s resolved. - * Note: This is a pure path computation with no filesystem access. - */ - Path& Normalize(); - /** Returns a normalized copy of the path. */ - Path NormalizedPath() const { - Path p{*this}; - return p.Normalize(); - } - - /** - * Appends `other` to the path. - * The algorithm should model that of std::filesystem::path::append(). - */ - Path& Append(const Path& other); - - /** - * Concatenates the current path and the argument string. - * Unlike with Append() or operator/=, additional directory separators are never introduced. - */ - Path& Concat(const PathString& string); - - /** Equivalent to this->Append(other). */ - Path& operator/=(const Path& other) { - return Append(other); - } - /** Returns `a` appended with `b`. */ - friend Path operator/(Path a, const Path& b) { - return a /= b; - } - - friend Status RelativePath(const Path& src, const Path& dst, Path& rel); - - private: - Path(PathString root_name, bool has_root_dir, std::vector components) - : root_name_{std::move(root_name)}, - has_root_dir_{has_root_dir}, - components_{std::move(components)} { - } - - PathString root_name_{}; - bool has_root_dir_{false}; - std::vector components_{}; -}; - -/** - * Computes the relative path from `src` to `dst`. - * Note: This is a pure path computation with no filesystem access. - */ -Status RelativePath(const Path& src, const Path& dst, Path& rel); - -} // namespace onnxruntime diff --git a/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc b/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc index ec50bb7d6a5cb..7665a90448520 100644 --- a/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc +++ b/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc @@ -76,9 +76,9 @@ PathString MakeTensorFileName(const std::string& tensor_name, const NodeDumpOpti return path_utils::MakePathString(make_valid_name(tensor_name), dump_options.file_suffix, ".tensorproto"); } -void DumpTensorToFile(const Tensor& tensor, const std::string& tensor_name, const Path& file_path) { +void DumpTensorToFile(const Tensor& tensor, const std::string& tensor_name, const std::filesystem::path& file_path) { auto tensor_proto = utils::TensorToTensorProto(tensor, tensor_name); - const PathString file_path_str = file_path.ToPathString(); + const PathString file_path_str = file_path.native(); int output_fd; ORT_THROW_IF_ERROR(Env::Default().FileOpenWr(file_path_str, output_fd)); try { @@ -302,7 +302,7 @@ void DumpCpuTensor( break; } case NodeDumpOptions::DataDestination::TensorProtoFiles: { - const Path tensor_file = dump_options.output_dir / Path::Parse(MakeTensorFileName(tensor_metadata.name, dump_options)); + const std::filesystem::path tensor_file = dump_options.output_dir / MakeTensorFileName(tensor_metadata.name, dump_options); DumpTensorToFile(tensor, tensor_metadata.name, tensor_file); break; } @@ -411,11 +411,11 @@ const NodeDumpOptions& NodeDumpOptionsFromEnvironmentVariables() { } } - opts.output_dir = Path::Parse(ToPathString(Env::Default().GetEnvironmentVar(env_vars::kOutputDir))); + opts.output_dir = ToPathString(Env::Default().GetEnvironmentVar(env_vars::kOutputDir)); std::string sqlite_db_prefix = ParseEnvironmentVariableWithDefault(env_vars::kSqliteDbPrefix, "execution-trace"); - opts.sqlite_db_prefix = Path::Parse(ToPathString(sqlite_db_prefix)); + opts.sqlite_db_prefix = ToPathString(sqlite_db_prefix); // check for confirmation for dumping data to files for all nodes const bool is_input_or_output_requested = ((opts.dump_flags & NodeDumpOptions::DumpFlags::InputData) != 0) || diff --git a/onnxruntime/core/framework/debug_node_inputs_outputs_utils.h b/onnxruntime/core/framework/debug_node_inputs_outputs_utils.h index bde005fc204c8..6090a835aa060 100644 --- a/onnxruntime/core/framework/debug_node_inputs_outputs_utils.h +++ b/onnxruntime/core/framework/debug_node_inputs_outputs_utils.h @@ -16,7 +16,6 @@ #pragma once -#include "core/common/path.h" #include "core/framework/op_kernel.h" #include "core/framework/session_state.h" #include "core/graph/graph.h" @@ -109,9 +108,9 @@ struct NodeDumpOptions { std::string file_suffix; // the output directory for dumped data files - Path output_dir; + std::filesystem::path output_dir; // the sqlite3 db to append dumped data - Path sqlite_db_prefix; + std::filesystem::path sqlite_db_prefix; // Total number of elements which trigger snippet rather than full array for Stdout. Value 0 disables snippet. int snippet_threshold; diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 2f3f942e75578..a66caf1ace33b 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -9,7 +9,6 @@ #ifndef SHARED_PROVIDER #include "core/common/common.h" -#include "core/common/path.h" #include "core/common/status.h" #include "core/common/safeint.h" #include "core/framework/endian_utils.h" diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 9c73ee16963bd..728af727ac83b 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -11,7 +11,6 @@ #include "core/common/flatbuffers.h" -#include "core/common/path.h" #include "core/graph/graph_viewer.h" #include "core/graph/ort_format_load_options.h" #include "core/session/onnxruntime_c_api.h" diff --git a/onnxruntime/core/optimizer/initializer.cc b/onnxruntime/core/optimizer/initializer.cc index 7d80e6e5d3a76..5953935203b83 100644 --- a/onnxruntime/core/optimizer/initializer.cc +++ b/onnxruntime/core/optimizer/initializer.cc @@ -7,7 +7,6 @@ #include #include "core/common/gsl.h" -#include "core/common/path.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/tensor_external_data_info.h" #include "core/platform/env.h" diff --git a/onnxruntime/core/optimizer/initializer.h b/onnxruntime/core/optimizer/initializer.h index b8ae2188beb5d..3099faed18ac3 100644 --- a/onnxruntime/core/optimizer/initializer.h +++ b/onnxruntime/core/optimizer/initializer.h @@ -10,7 +10,6 @@ #include #include "core/common/common.h" #include "core/common/narrow.h" -#include "core/common/path.h" #include "core/framework/allocator.h" #include "core/optimizer/graph_transformer.h" #include "core/framework/tensor_shape.h" diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 7454b322a310c..bc6dac1a2f27f 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -903,13 +903,6 @@ struct ProviderHost { int execution_order) noexcept = 0; virtual const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const = 0; - // Path - virtual PathString Path__ToPathString(const Path* p) noexcept = 0; - virtual const std::vector& Path__GetComponents(const Path* p) noexcept = 0; - virtual bool Path__IsEmpty(const Path* p) noexcept = 0; - virtual std::unique_ptr Path__construct() = 0; - virtual void Path__operator_delete(ONNX_NAMESPACE::Path* p) = 0; - // OpKernel virtual const Node& OpKernel__Node(const OpKernel* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 2ccd05fe9df60..fb3b274d9b80b 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -968,19 +968,6 @@ class GraphViewer final { void operator=(const GraphViewer&) = delete; }; -struct Path final { - static std::unique_ptr Create() { return g_host->Path__construct(); } - static void operator delete(void* p) { g_host->Path__operator_delete(reinterpret_cast(p)); } - - PathString ToPathString() const noexcept { return g_host->Path__ToPathString(this); } - const std::vector& GetComponents() const noexcept { return g_host->Path__GetComponents(this); } - bool IsEmpty() const noexcept { return g_host->Path__IsEmpty(this); } - - Path() = delete; - Path(const Path&) = delete; - void operator=(const Path&) = delete; -}; - struct OpKernelContext final { template const T& RequiredInput(int index) const; diff --git a/onnxruntime/core/providers/tvm/tvm_api.cc b/onnxruntime/core/providers/tvm/tvm_api.cc index 37982d0bdb551..4c46ea5ffae72 100644 --- a/onnxruntime/core/providers/tvm/tvm_api.cc +++ b/onnxruntime/core/providers/tvm/tvm_api.cc @@ -16,7 +16,6 @@ #include #include "core/common/common.h" -#include "core/common/path.h" #include "core/common/gsl.h" #include "tvm_api.h" diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc index f53894b9d1efb..d5e9c63847fbe 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc @@ -12,8 +12,7 @@ gsl::span tensor_proto_as_raw(const ONNX_NAMESPACE::TensorProto& ten auto& mut_tensor = const_cast(tensor); if (!tensor.has_raw_data()) { std::vector unpacked_tensor; - auto path = onnxruntime::Path::Create(); - auto s = onnxruntime::utils::UnpackInitializerData(tensor, *path, unpacked_tensor); + auto s = onnxruntime::utils::UnpackInitializerData(tensor, std::filesystem::path(), unpacked_tensor); mut_tensor.mutable_raw_data()->resize(unpacked_tensor.size()); mut_tensor.clear_float_data(); mut_tensor.clear_int32_data(); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 0494616a9ca0c..1bb013d0cdc10 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -5,6 +5,7 @@ // It implements onnxruntime::ProviderHost #include "core/common/inlined_containers.h" +#include "core/common/path_string.h" #include "core/framework/allocator_utils.h" #include "core/framework/config_options.h" #include "core/framework/compute_capability.h" @@ -1203,13 +1204,6 @@ struct ProviderHostImpl : ProviderHost { } const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const override { return p->GetProducerNode(node_arg_name); } - // Path (wrapped) - PathString Path__ToPathString(const Path* p) noexcept override { return p->ToPathString(); } - const std::vector& Path__GetComponents(const Path* p) noexcept override { return p->GetComponents(); } - bool Path__IsEmpty(const Path* p) noexcept override { return p->IsEmpty(); } - std::unique_ptr Path__construct() override { return std::make_unique(); } - void Path__operator_delete(ONNX_NAMESPACE::Path* p) override { delete p; }; - // OpKernel (direct) const Node& OpKernel__Node(const OpKernel* p) override { return p->OpKernel::Node(); } diff --git a/onnxruntime/test/common/path_test.cc b/onnxruntime/test/common/path_test.cc deleted file mode 100644 index d097705773568..0000000000000 --- a/onnxruntime/test/common/path_test.cc +++ /dev/null @@ -1,256 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/common/path.h" - -#include "gtest/gtest.h" - -#include "core/common/optional.h" -#include "test/util/include/asserts.h" - -namespace onnxruntime { -namespace test { - -TEST(PathTest, Parse) { - auto check_parse = - [](const std::string& path_string, - const std::string& expected_root, - const std::vector& expected_components) { - Path p{}; - ASSERT_STATUS_OK(Path::Parse(ToPathString(path_string), p)); - - std::vector expected_components_ps{}; - std::transform( - expected_components.begin(), expected_components.end(), - std::back_inserter(expected_components_ps), - [](const std::string& s) { return ToPathString(s); }); - EXPECT_EQ(p.GetComponents(), expected_components_ps); - EXPECT_EQ(p.GetRootPathString(), ToPathString(expected_root)); - }; - - check_parse( - "i/am/relative", - "", {"i", "am", "relative"}); -#ifdef _WIN32 - check_parse( - "/i/am/rooted", - R"(\)", {"i", "am", "rooted"}); - check_parse( - R"(\\server\share\i\am\rooted)", - R"(\\server\share\)", {"i", "am", "rooted"}); - check_parse( - R"(C:\i\am\rooted)", - R"(C:\)", {"i", "am", "rooted"}); - check_parse( - R"(C:i\am\relative)", - "C:", {"i", "am", "relative"}); -#else // POSIX - check_parse( - "/i/am/rooted", - "/", {"i", "am", "rooted"}); - check_parse( - "//root_name/i/am/rooted", - "//root_name/", {"i", "am", "rooted"}); -#endif -} - -TEST(PathTest, ParseFailure) { - auto check_parse_failure = - [](const std::string& path_string) { - Path p{}; - EXPECT_FALSE(Path::Parse(ToPathString(path_string), p).IsOK()); - }; - -#ifdef _WIN32 - check_parse_failure(R"(\\server_name_no_separator)"); - check_parse_failure(R"(\\server_name_no_share_name\)"); - check_parse_failure(R"(\\server_name\share_name_no_root_dir)"); -#else // POSIX - check_parse_failure("//root_name_no_root_dir"); -#endif -} - -TEST(PathTest, IsEmpty) { - auto check_empty = - [](const std::string& path_string, bool is_empty) { - Path p{}; - ASSERT_STATUS_OK(Path::Parse(ToPathString(path_string), p)); - - EXPECT_EQ(p.IsEmpty(), is_empty); - }; - - check_empty("", true); - check_empty(".", false); - check_empty("/", false); -} - -TEST(PathTest, IsAbsoluteOrRelative) { - auto check_abs_or_rel = - [](const std::string& path_string, bool is_absolute) { - Path p{}; - ASSERT_STATUS_OK(Path::Parse(ToPathString(path_string), p)); - - EXPECT_EQ(p.IsAbsolute(), is_absolute); - EXPECT_EQ(p.IsRelative(), !is_absolute); - }; - - check_abs_or_rel("relative", false); - check_abs_or_rel("", false); -#ifdef _WIN32 - check_abs_or_rel(R"(\root_relative)", false); - check_abs_or_rel(R"(\)", false); - check_abs_or_rel("C:drive_relative", false); - check_abs_or_rel("C:", false); - check_abs_or_rel(R"(C:\absolute)", true); - check_abs_or_rel(R"(C:\)", true); -#else // POSIX - check_abs_or_rel("/absolute", true); - check_abs_or_rel("/", true); -#endif -} - -TEST(PathTest, ParentPath) { - auto check_parent = - [](const std::string path_string, const std::string& expected_parent_path_string) { - Path p{}, p_expected_parent{}; - ASSERT_STATUS_OK(Path::Parse(ToPathString(path_string), p)); - ASSERT_STATUS_OK(Path::Parse(ToPathString(expected_parent_path_string), p_expected_parent)); - - EXPECT_EQ(p.ParentPath().ToPathString(), p_expected_parent.ToPathString()); - }; - - check_parent("a/b", "a"); - check_parent("/a/b", "/a"); - check_parent("", ""); - check_parent("/", "/"); -} - -TEST(PathTest, Normalize) { - auto check_normalize = - [](const std::string& path_string, - const std::string& expected_normalized_path_string) { - Path p{}, p_expected_normalized{}; - ASSERT_STATUS_OK(Path::Parse(ToPathString(path_string), p)); - ASSERT_STATUS_OK(Path::Parse(ToPathString(expected_normalized_path_string), p_expected_normalized)); - - EXPECT_EQ(p.Normalize().ToPathString(), p_expected_normalized.ToPathString()); - }; - - check_normalize("/a/b/./c/../../d/../e", "/a/e"); - check_normalize("a/b/./c/../../d/../e", "a/e"); - check_normalize("/../a/../../b", "/b"); - check_normalize("../a/../../b", "../../b"); - check_normalize("/a/..", "/"); - check_normalize("a/..", "."); - check_normalize("", ""); - check_normalize("/", "/"); - check_normalize(".", "."); -} - -TEST(PathTest, Append) { - auto check_append = - [](const std::string& a, const std::string& b, const std::string& expected_ab) { - Path p_a{}, p_b{}, p_expected_ab{}; - ASSERT_STATUS_OK(Path::Parse(ToPathString(a), p_a)); - ASSERT_STATUS_OK(Path::Parse(ToPathString(b), p_b)); - ASSERT_STATUS_OK(Path::Parse(ToPathString(expected_ab), p_expected_ab)); - - EXPECT_EQ(p_a.Append(p_b).ToPathString(), p_expected_ab.ToPathString()); - }; - - check_append("/a/b", "c/d", "/a/b/c/d"); - check_append("/a/b", "/c/d", "/c/d"); - check_append("a/b", "c/d", "a/b/c/d"); - check_append("a/b", "/c/d", "/c/d"); -#ifdef _WIN32 - check_append(R"(C:\a\b)", R"(c\d)", R"(C:\a\b\c\d)"); - check_append(R"(C:\a\b)", R"(\c\d)", R"(C:\c\d)"); - check_append(R"(C:\a\b)", R"(D:c\d)", R"(D:c\d)"); - check_append(R"(C:\a\b)", R"(D:\c\d)", R"(D:\c\d)"); - check_append(R"(C:a\b)", R"(c\d)", R"(C:a\b\c\d)"); - check_append(R"(C:a\b)", R"(\c\d)", R"(C:\c\d)"); - check_append(R"(C:a\b)", R"(D:c\d)", R"(D:c\d)"); - check_append(R"(C:a\b)", R"(D:\c\d)", R"(D:\c\d)"); -#else // POSIX - check_append("//root_0/a/b", "c/d", "//root_0/a/b/c/d"); - check_append("//root_0/a/b", "/c/d", "/c/d"); - check_append("//root_0/a/b", "//root_1/c/d", "//root_1/c/d"); -#endif -} - -TEST(PathTest, RelativePath) { - auto check_relative = - [](const std::string& src, - const std::string& dst, - const std::string& expected_rel) { - Path p_src, p_dst, p_expected_rel, p_rel; - ASSERT_STATUS_OK(Path::Parse(ToPathString(src), p_src)); - ASSERT_STATUS_OK(Path::Parse(ToPathString(dst), p_dst)); - ASSERT_STATUS_OK(Path::Parse(ToPathString(expected_rel), p_expected_rel)); - - ASSERT_STATUS_OK(RelativePath(p_src, p_dst, p_rel)); - EXPECT_EQ(p_rel.ToPathString(), p_expected_rel.ToPathString()); - }; - - check_relative( - "/a/b/c/d/e", "/a/b/c/d/e/f/g/h", - "f/g/h"); - check_relative( - "/a/b/c/d/e", "/a/b/f/g/h/i", - "../../../f/g/h/i"); - check_relative( - "a/b/../c/../d", "e/./f/../g/h", - "../../e/g/h"); -} - -TEST(PathTest, RelativePathFailure) { - auto check_relative_failure = - [](const std::string& src, - const std::string& dst) { - Path p_src, p_dst, p_rel; - ASSERT_STATUS_OK(Path::Parse(ToPathString(src), p_src)); - ASSERT_STATUS_OK(Path::Parse(ToPathString(dst), p_dst)); - - EXPECT_FALSE(RelativePath(p_src, p_dst, p_rel).IsOK()); - }; - - check_relative_failure("/rooted", "relative"); - check_relative_failure("relative", "/rooted"); -#ifdef _WIN32 - check_relative_failure("C:/a", "D:/a"); -#else // POSIX - check_relative_failure("//root_0/a", "//root_1/a"); -#endif -} - -#if !defined(ORT_NO_EXCEPTIONS) -TEST(PathTest, Concat) { - auto check_concat = - [](const optional& a, const std::string& b, const std::string& expected_a, bool expect_throw = false) { - Path p_a{}, p_expected_a{}; - if (a.has_value()) { - ASSERT_STATUS_OK(Path::Parse(ToPathString(*a), p_a)); - } - ASSERT_STATUS_OK(Path::Parse(ToPathString(expected_a), p_expected_a)); - - if (expect_throw) { - EXPECT_THROW(p_a.Concat(ToPathString(b)).ToPathString(), OnnxRuntimeException); - } else { - EXPECT_EQ(p_a.Concat(ToPathString(b)).ToPathString(), p_expected_a.ToPathString()); - } - }; - - check_concat({"/a/b"}, "c", "/a/bc"); - check_concat({"a/b"}, "cd", "a/bcd"); - check_concat({""}, "cd", "cd"); - check_concat({}, "c", "c"); -#ifdef _WIN32 - check_concat({"a/b"}, R"(c\d)", "", true /* expect_throw */); -#else - check_concat({"a/b"}, "c/d", "", true /* expect_throw */); -#endif -} -#endif - -} // namespace test -} // namespace onnxruntime diff --git a/onnxruntime/test/debug_node_inputs_outputs/debug_node_inputs_outputs_utils_test.cc b/onnxruntime/test/debug_node_inputs_outputs/debug_node_inputs_outputs_utils_test.cc index 17e26a57f5f3e..b2ab2d9a5701b 100644 --- a/onnxruntime/test/debug_node_inputs_outputs/debug_node_inputs_outputs_utils_test.cc +++ b/onnxruntime/test/debug_node_inputs_outputs/debug_node_inputs_outputs_utils_test.cc @@ -27,7 +27,7 @@ void VerifyTensorProtoFileData(const PathString& tensor_proto_path, gsl::span actual_data{}; actual_data.resize(expected_data.size()); - ASSERT_STATUS_OK(utils::UnpackTensor(tensor_proto, Path{}, actual_data.data(), actual_data.size())); + ASSERT_STATUS_OK(utils::UnpackTensor(tensor_proto, tensor_proto_path, actual_data.data(), actual_data.size())); ASSERT_EQ(gsl::span(actual_data), expected_data); } @@ -48,7 +48,7 @@ void VerifyTensorProtoFileDataInt4(const PathString& tensor_proto_path, std::vector> actual_data{}; actual_data.resize(expected_data.size()); - ASSERT_STATUS_OK(utils::UnpackTensor(tensor_proto, Path{}, actual_data.data(), num_elems)); + ASSERT_STATUS_OK(utils::UnpackTensor(tensor_proto, tensor_proto_path, actual_data.data(), num_elems)); ASSERT_EQ(actual_data.size(), expected_data.size()); diff --git a/onnxruntime/test/flatbuffers/flatbuffer_utils_test.cc b/onnxruntime/test/flatbuffers/flatbuffer_utils_test.cc index 7289f92c65663..32f2da806be3b 100644 --- a/onnxruntime/test/flatbuffers/flatbuffer_utils_test.cc +++ b/onnxruntime/test/flatbuffers/flatbuffer_utils_test.cc @@ -9,7 +9,6 @@ #include "gtest/gtest.h" #include "core/common/common.h" -#include "core/common/path.h" #include "core/graph/graph_flatbuffers_utils.h" #include "core/framework/tensorprotoutils.h" #include "core/providers/cpu/cpu_execution_provider.h" diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index 4766ef6fbc621..5fc036790b765 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -1782,8 +1782,8 @@ TEST_F(GraphTest, InjectExternalInitializedTensors) { {initializer_name, ort_value}}; // We do not need actual files there since we are not going to load it. - const auto tensor_data_dir_path = Path::Parse(ToPathString(".")); - const auto tensor_data_dir_relative_path = Path::Parse(ToPathString("external_data.bin")); + const auto tensor_data_dir_path = ORT_TSTR("."); + const auto tensor_data_dir_relative_path = ORT_TSTR("external_data.bin"); const auto tensor_proto = [&]() { @@ -1792,7 +1792,7 @@ TEST_F(GraphTest, InjectExternalInitializedTensors) { tensor_proto.add_dims(tensor_data.size()); tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); - SetTensorProtoExternalData("location", ToUTF8String(tensor_data_dir_relative_path.ToPathString()), + SetTensorProtoExternalData("location", ToUTF8String(tensor_data_dir_relative_path), tensor_proto); SetTensorProtoExternalData("offset", "0", tensor_proto); SetTensorProtoExternalData("length", std::to_string(tensor_data.size() * sizeof(int32_t)), tensor_proto); @@ -1827,7 +1827,7 @@ TEST_F(GraphTest, InjectExternalInitializedTensors) { ASSERT_FALSE(utils::HasExternalData(*with_data)); const auto& original_tensor = ort_value.Get(); Tensor replaced_tensor(original_tensor.DataType(), data_shape, std::make_shared()); - ASSERT_STATUS_OK(utils::TensorProtoToTensor(Env::Default(), tensor_data_dir_path.ToPathString().c_str(), *with_data, + ASSERT_STATUS_OK(utils::TensorProtoToTensor(Env::Default(), tensor_data_dir_path, *with_data, replaced_tensor)); ASSERT_EQ(original_tensor.GetElementType(), replaced_tensor.GetElementType()); const auto original_span = original_tensor.DataAsSpan(); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index f83fb8238ff61..2bfa57a2ceb9e 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -6215,9 +6215,10 @@ TEST_F(GraphTransformationTests, PropagateCastOpsTests) { std::make_unique(strategy, level, test_case.allow_ops), TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); - Path p = Path::Parse(test_case.model_uri); - ASSERT_FALSE(p.GetComponents().empty()); - PathString transformed_model_uri = temp_dir.Path() + GetPathSep() + ORT_TSTR("transformed_") + p.GetComponents().back(); + std::filesystem::path p = test_case.model_uri; + PathString model_filename = ORT_TSTR("transformed_"); + model_filename += p.filename(); + std::filesystem::path transformed_model_uri = std::filesystem::path(temp_dir.Path()) / model_filename; ASSERT_STATUS_OK(Model::Save(*p_model, transformed_model_uri)); // Load the transformed model to validate ASSERT_STATUS_OK(Model::Load(transformed_model_uri, p_model, nullptr, *logger_)); diff --git a/orttraining/orttraining/core/framework/checkpoint_common.h b/orttraining/orttraining/core/framework/checkpoint_common.h index 316417829e43b..5ff96fa77753d 100644 --- a/orttraining/orttraining/core/framework/checkpoint_common.h +++ b/orttraining/orttraining/core/framework/checkpoint_common.h @@ -6,7 +6,6 @@ #include "core/framework/tensorprotoutils.h" #include "core/common/logging/logging.h" #include "core/common/logging/sinks/clog_sink.h" -#include "core/common/path.h" #include "core/common/path_string.h" #include "core/common/status.h" #include "core/framework/framework_common.h" diff --git a/orttraining/orttraining/core/framework/checkpointing.cc b/orttraining/orttraining/core/framework/checkpointing.cc index 9e1aa8d17e3ee..462a05d9db562 100644 --- a/orttraining/orttraining/core/framework/checkpointing.cc +++ b/orttraining/orttraining/core/framework/checkpointing.cc @@ -10,7 +10,6 @@ #include "core/common/common.h" #include "core/common/logging/logging.h" -#include "core/common/path.h" #include "core/framework/data_transfer_utils.h" #include "core/framework/endian_utils.h" #include "core/framework/ort_value.h" @@ -260,13 +259,10 @@ Status LoadModelCheckpoint( ORT_RETURN_IF_ERROR(Env::Default().GetCanonicalPath( checkpoint_path, checkpoint_canonical_path)); - Path relative_tensors_data_path_obj{}; - ORT_RETURN_IF_ERROR(RelativePath( - Path::Parse(model_directory_canonical_path), - Path::Parse(GetCheckpointTensorsDataFilePath(checkpoint_canonical_path)), - relative_tensors_data_path_obj)); + std::filesystem::path relative_tensors_data_path_obj = std::filesystem::relative( + GetCheckpointTensorsDataFilePath(checkpoint_canonical_path), model_directory_canonical_path); ORT_RETURN_IF_ERROR(UpdateTensorsExternalDataLocations( - relative_tensors_data_path_obj.ToPathString(), loaded_tensor_protos)); + relative_tensors_data_path_obj.native(), loaded_tensor_protos)); } // read properties file diff --git a/orttraining/orttraining/models/runner/training_runner.cc b/orttraining/orttraining/models/runner/training_runner.cc index 15c74d40926ee..6421f7c81f7fb 100644 --- a/orttraining/orttraining/models/runner/training_runner.cc +++ b/orttraining/orttraining/models/runner/training_runner.cc @@ -1019,9 +1019,8 @@ Status TrainingRunner::SavePerfMetrics(const size_t number_of_batches, const siz optimizer = optimizer.substr(0, pos); perf_metrics["Optimizer"] = optimizer; - Path model_path{}; - ORT_RETURN_IF_ERROR(Path::Parse(params_.model_path, model_path)); - PathString leaf = model_path.GetComponents().back(); + std::filesystem::path model_path = params_.model_path; + PathString leaf = model_path.filename(); std::string model_name = ToUTF8String(leaf.c_str()); perf_metrics["ModelName"] = model_name; From 0bbd061a54dec206a526c2eee3566560dba6ea5a Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Thu, 4 Jul 2024 10:50:27 -0700 Subject: [PATCH 11/13] Exclude azure ep from gen_def.cc (#21250) Addresses python packaging pipeline failure. --- tools/ci_build/gen_def.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/ci_build/gen_def.py b/tools/ci_build/gen_def.py index b53fb33659120..fe47d8dbe57fe 100755 --- a/tools/ci_build/gen_def.py +++ b/tools/ci_build/gen_def.py @@ -79,6 +79,7 @@ def parse_arguments(): "cann", "dnnl", "tensorrt", + "azure", ): file.write(f"#include \n") file.write("void* GetFunctionEntryByName(const char* name){\n") From 3f6b7430d64cdad07acfbeb7b464387dc0b6707d Mon Sep 17 00:00:00 2001 From: pengwa Date: Fri, 5 Jul 2024 17:27:45 +0800 Subject: [PATCH 12/13] Use cuda memset async (#21216) ### Description ### Motivation and Context --- .../orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc index 7bd759e8976c1..f3feef4391bb5 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc +++ b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc @@ -35,7 +35,8 @@ struct PadAndUnflattenFunctor { typedef typename ToCudaType::MappedType CudaT; const CudaT* input_data = reinterpret_cast(input_tensor.Data()); - CUDA_CALL_THROW(cudaMemset(output_tensor.MutableDataRaw(), 0, output_tensor.Shape().Size() * sizeof(CudaT))); + CUDA_CALL_THROW(cudaMemsetAsync(output_tensor.MutableDataRaw(), 0, output_tensor.Shape().Size() * sizeof(CudaT), + stream)); PadAndUnflattenImpl(stream, input_element_count, output_element_stride_fdm, index_value_upper_bound, input_data, indices_tensor.Data(), reinterpret_cast(output_tensor.MutableData())); @@ -48,6 +49,7 @@ Status PadAndUnflatten::ComputeInternal(OpKernelContext* context) const { const Tensor* input_tensor = context->Input(0); const Tensor* indices_tensor = context->Input(1); const Tensor* unflatten_dims_tensor = context->Input(2); // Parse the 1-D shape tensor. + ORT_ENFORCE(unflatten_dims_tensor->Shape().NumDimensions() == 1, "unflatten_dims_tensor tensor must be 1-D.", unflatten_dims_tensor->Shape().NumDimensions()); ORT_ENFORCE(unflatten_dims_tensor->Shape().Size() == 2, From 9ef28f092f575ce15bfc40c7663e649e75809c53 Mon Sep 17 00:00:00 2001 From: KnightYao Date: Fri, 5 Jul 2024 23:11:59 +0800 Subject: [PATCH 13/13] [Fix Bug] Fp8*Fp8 Run Error (#20911) Fix fp8*fp8 when input A is e5m2, input B is e4m3 will run error ### Description ### Motivation and Context --- onnxruntime/contrib_ops/cuda/math/gemm_float8.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu index 28ab27ee33d10..07c5de2fe8d8c 100644 --- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -207,7 +207,7 @@ Status GemmFloat8::ComputeGemm( #endif case CUDA_R_8F_E4M3: case CUDA_R_8F_E5M2: - compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + compute_type = CUBLAS_COMPUTE_32F; break; #endif default: @@ -219,7 +219,7 @@ Status GemmFloat8::ComputeGemm( compute_type = CUBLAS_COMPUTE_32F_FAST_16BF; break; case CUDA_R_32F: - compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + compute_type = CUBLAS_COMPUTE_32F; break; default: ORT_THROW("Unable to determine computeType in operator GemmFloat8.");