From f917dde71740982c4520febc0ced1bff58b0068d Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sat, 13 Jan 2024 23:04:02 -0800 Subject: [PATCH 01/39] [web] remove xnnpack from web backends (#19116) ### Description XNNPACK is already disabled in web assembly build. This change removes the xnnpack backend registration in JS. --- js/common/lib/inference-session.ts | 2 +- js/web/lib/index.ts | 7 ++----- js/web/lib/wasm/session-options.ts | 3 --- js/web/script/test-runner-cli-args.ts | 7 +++---- js/web/test/test-runner.ts | 4 ++-- .../github/azure-pipelines/templates/win-web-ci.yml | 6 +++--- .../azure-pipelines/templates/win-web-multi-browsers.yml | 6 +++--- 7 files changed, 14 insertions(+), 21 deletions(-) diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index edc32535fc64d..1221b52cd4985 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -181,7 +181,7 @@ export declare namespace InferenceSession { // Currently, we have the following backends to support execution providers: // Backend Node.js binding: supports 'cpu' and 'cuda'. - // Backend WebAssembly: supports 'cpu', 'wasm', 'xnnpack' and 'webnn'. + // Backend WebAssembly: supports 'cpu', 'wasm', 'webgpu' and 'webnn'. // Backend ONNX.js: supports 'webgl'. // Backend React Native: supports 'cpu', 'xnnpack', 'coreml' (iOS), 'nnapi' (Android). interface ExecutionProviderOptionMap { diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 4f1a3943de69a..baf45e74addea 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -26,11 +26,8 @@ if (!BUILD_DEFS.DISABLE_WASM) { } registerBackend('cpu', wasmBackend, 10); registerBackend('wasm', wasmBackend, 10); - if (BUILD_DEFS.DISABLE_TRAINING) { - registerBackend('xnnpack', wasmBackend, 9); - if (!BUILD_DEFS.DISABLE_WEBNN) { - registerBackend('webnn', wasmBackend, 9); - } + if (!BUILD_DEFS.DISABLE_WEBNN) { + registerBackend('webnn', wasmBackend, 9); } } diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 45ea48a2df209..41ab2d52ca209 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -60,9 +60,6 @@ const setExecutionProviders = // check EP name switch (epName) { - case 'xnnpack': - epName = 'XNNPACK'; - break; case 'webnn': epName = 'WEBNN'; if (typeof ep !== 'string') { diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index fc74adfed1fee..8f6c5f6f04122 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -36,7 +36,6 @@ Options: webgl webgpu wasm - xnnpack webnn -e=<...>, --env=<...> Specify the environment to run the test. 
Should be one of the following: chrome (default) @@ -111,7 +110,7 @@ Examples: export declare namespace TestRunnerCliArgs { type Mode = 'suite0'|'suite1'|'model'|'unittest'|'op'; - type Backend = 'cpu'|'webgl'|'webgpu'|'wasm'|'onnxruntime'|'xnnpack'|'webnn'; + type Backend = 'cpu'|'webgl'|'webgpu'|'wasm'|'onnxruntime'|'webnn'; type Environment = 'chrome'|'edge'|'firefox'|'electron'|'safari'|'node'|'bs'; type BundleMode = 'dev'|'perf'; type IOBindingMode = 'none'|'gpu-tensor'|'gpu-location'; @@ -378,13 +377,13 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs } // Option: -b=<...>, --backend=<...> - const browserBackends = ['webgl', 'webgpu', 'wasm', 'xnnpack', 'webnn']; + const browserBackends = ['webgl', 'webgpu', 'wasm', 'webnn']; // TODO: remove this when Chrome support WebNN. // we need this for now because Chrome does not support webnn yet, // and ChromeCanary is not in CI. - const defaultBrowserBackends = ['webgl', 'webgpu', 'wasm', 'xnnpack' /*, 'webnn'*/]; + const defaultBrowserBackends = ['webgl', 'webgpu', 'wasm' /*, 'webnn'*/]; const nodejsBackends = ['cpu', 'wasm']; const backendArgs = args.backend || args.b; const backend = (typeof backendArgs !== 'string') ? (env === 'node' ? nodejsBackends : defaultBrowserBackends) : diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 3492c8f3780ea..442cb1bcf1f34 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -96,7 +96,7 @@ async function loadTensors( const outputs: Test.NamedTensor[] = []; let dataFileType: 'none'|'pb'|'npy' = 'none'; - const allowInt64 = ['wasm', 'xnnpack', 'webgpu', 'webnn'].includes(backendName); + const allowInt64 = ['wasm', 'webgpu', 'webnn'].includes(backendName); for (const dataFile of testCase.dataFiles) { const ext = extname(dataFile); @@ -317,7 +317,7 @@ export class TensorResultValidator { } else if (backend === 'webgpu') { this.absoluteThreshold = WEBGPU_THRESHOLD_ABSOLUTE_ERROR; this.relativeThreshold = WEBGPU_THRESHOLD_RELATIVE_ERROR; - } else if (backend === 'wasm' || backend === 'xnnpack' || backend === 'webnn') { + } else if (backend === 'wasm' || backend === 'webnn') { this.absoluteThreshold = WASM_THRESHOLD_ABSOLUTE_ERROR; this.relativeThreshold = WASM_THRESHOLD_RELATIVE_ERROR; } else if (backend === 'onnxruntime') { diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index 8d4efc79eaca8..8ba3517530edd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -169,12 +169,12 @@ jobs: errorActionPreference: stop displayName: 'Pack NPM packages' - script: | - npm test -- -e=chrome -b=webgl,wasm,xnnpack + npm test -- -e=chrome -b=webgl,wasm workingDirectory: '$(Build.SourcesDirectory)\js\web' - displayName: 'Run ort-web tests (wasm,webgl,xnnpack backend)' + displayName: 'Run ort-web tests (wasm,webgl backend)' condition: eq('${{ parameters.RunWebGpuTests }}', 'false') - script: | - npm test -- -e=chrome -b=webgl,wasm,xnnpack,webgpu $(webgpuCommandlineExtraFlags) + npm test -- -e=chrome -b=webgl,wasm,webgpu $(webgpuCommandlineExtraFlags) workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (ALL backends)' condition: eq('${{ parameters.RunWebGpuTests }}', 'true') diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml 
b/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml index f7876f15029c1..31ee488318a0b 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml @@ -68,15 +68,15 @@ jobs: workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'npm ci /js/web/' - script: | - npm test -- suite0 -b=wasm,webgl,xnnpack --wasm-init-timeout=30000 --file-cache + npm test -- suite0 -b=wasm,webgl --wasm-init-timeout=30000 --file-cache workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'npm test (Suite0, Chrome)' - script: | - npm test -- suite0 -b=wasm,webgl,xnnpack --env=firefox --wasm-init-timeout=30000 --file-cache + npm test -- suite0 -b=wasm,webgl --env=firefox --wasm-init-timeout=30000 --file-cache workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'npm test (Suite0, Firefox)' - script: | - npm test -- suite0 -b=wasm,webgl,xnnpack --env=edge --wasm-init-timeout=30000 --file-cache + npm test -- suite0 -b=wasm,webgl --env=edge --wasm-init-timeout=30000 --file-cache workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'npm test (Suite0, Edge)' - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 From bb4011b2b14cb2702a4922ccd0b070d9ecc49a93 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Sun, 14 Jan 2024 11:36:49 -0800 Subject: [PATCH 02/39] Set default flags nvcc and do not set default compile flags for ROCM EP (#19124) ### Description Set default flags nvcc and do not set the flags for ROCM EP. ### Motivation and Context 1. To meet a BinSkim requirement for CUDA EP. https://github.com/microsoft/binskim/blob/main/docs/BinSkimRules.md#rule-BA2024EnableSpectreMitigations 2. The ROCM EP's pipeline is broken since PR #19073 . Unit tests failed to load the EP with the following error message: Failed to load library libonnxruntime_providers_rocm.so with error: /build/Release/libonnxruntime_providers_rocm.so: undefined symbol: vtable for onnxruntime::InsertMaxPoolOutput . This PR is a hot fix to bring the pipeline back. So far I don't know why the error happened. The symbol "InsertMaxPoolOutput" is in onnxruntime_optimizers. I don't see any EP code references it directly. --- tools/ci_build/build.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 315b9a237b1c4..0da4adb51767d 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1474,15 +1474,18 @@ def generate_build_tree( cflags = None cxxflags = None ldflags = None + cudaflags = [] for config in configs: # Setup default values for cflags/cxxflags/ldflags. # The values set here are purely for security and compliance purposes. ONNX Runtime should work fine without these flags. if ( "CFLAGS" not in os.environ and "CXXFLAGS" not in os.environ + and (not args.use_cuda or "CUDAFLAGS" not in os.environ) and not args.ios and not args.android and not args.build_wasm + and not args.use_rocm and not (is_linux() and platform.machine() != "aarch64" and platform.machine() != "x86_64") ): if is_windows(): @@ -1515,9 +1518,19 @@ def generate_build_tree( cxxflags = cflags.copy() if not args.disable_exceptions: cxxflags += ["/EHsc"] + if args.use_cuda: + # On Windows, nvcc passes /EHsc to the host compiler by default. 
+ cuda_compile_flags_str = "" + for compile_flag in cflags: + if compile_flag.startswith("/D"): + cudaflags.append(compile_flag) + else: + cuda_compile_flags_str = cuda_compile_flags_str + " " + compile_flag + if len(cuda_compile_flags_str) != 0: + cudaflags.append('-Xcompiler="%s"' % cuda_compile_flags_str) elif is_linux() or is_macOS(): if is_linux(): - ldflags = ["-Wl,-Bsymbolic-functions", "-Wl,-z,relro", "-Wl,-z,now"] + ldflags = ["-Wl,-Bsymbolic-functions", "-Wl,-z,relro", "-Wl,-z,now", "-Wl,-z,noexecstack"] else: ldflags = [] if config == "Release": @@ -1560,7 +1573,8 @@ def generate_build_tree( # The following flags needs GCC 8 and newer cflags += ["-fstack-clash-protection", "-fcf-protection"] cxxflags = cflags.copy() - + if args.use_cuda: + cudaflags = cflags.copy() config_build_dir = get_config_build_dir(build_dir, config) os.makedirs(config_build_dir, exist_ok=True) if args.use_tvm: @@ -1580,6 +1594,8 @@ def generate_build_tree( "-DCMAKE_C_FLAGS=%s" % (" ".join(cflags)), "-DCMAKE_CXX_FLAGS=%s" % (" ".join(cxxflags)), ] + if cudaflags is not None and len(cudaflags) != 0: + temp_cmake_args += ["-DCMAKE_CUDA_FLAGS_INIT=%s" % (" ".join(cudaflags))] if ldflags is not None and len(ldflags) != 0: temp_cmake_args += [ "-DCMAKE_EXE_LINKER_FLAGS_INIT=%s" % (" ".join(ldflags)), From 76797127d6a3125fc59e605670809957a2183cbe Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Sun, 14 Jan 2024 14:37:26 -0500 Subject: [PATCH 03/39] Always download cuda and trt libraries from Azure blob (#19118) ### Description This way, we will not need to update the windows images constantly and allow more flexibility to choose the cuda version in the future. --- .../c-api-noopenmp-packaging-pipelines.yml | 2 ++ .../jobs/download_win_gpu_library.yml | 36 +++++++++++-------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 93d3b7f37008b..f80b035582f18 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -1172,6 +1172,7 @@ stages: ArtifactSuffix: 'GPU' StageSuffix: 'GPU' Skipx86Tests: 'true' + CudaVersion: ${{ parameters.CudaVersion }} SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} @@ -1183,6 +1184,7 @@ stages: StageSuffix: 'GPU' MoreSuffix: '_Windows' Skipx86Tests: 'true' + CudaVersion: ${{ parameters.CudaVersion }} SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml index b7ae9ffa3c219..538cccd3c903b 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml @@ -20,31 +20,37 @@ steps: - powershell: | Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}\bin;$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}\extras\CUPTI\lib64" displayName: 'Append CUDA SDK Directory to PATH' + - task: CmdLine@2 inputs: script: | echo %PATH% - displayName: 'Print PATH' + displayName: 'Print PATH after download CUDA SDK' - ${{ if eq(parameters.DownloadTRT, true) }}: - ${{ if eq(parameters.CudaVersion, '11.8') }}: - - 
powershell: | - azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/local/TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8 $(Agent.TempDirectory) - displayName: 'Download TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8' - - powershell: | - Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8\lib" - displayName: 'Append TensorRT Directory to PATH' - + - bash: | + echo "##vso[task.setvariable variable=trtCudaVersion]11.8" + displayName: Set trtCudaVersion - ${{ if eq(parameters.CudaVersion, '12.2') }}: - - powershell: | - azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/local/TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0 $(Agent.TempDirectory) - displayName: 'Download TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0' - - powershell: | - Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0\lib" - displayName: 'Append TensorRT Directory to PATH' + - bash: | + echo "##vso[task.setvariable variable=trtCudaVersion]12.0" + displayName: Set trtCudaVersion + + - bash: | + echo $(trtCudaVersion) + displayName: Get trtCudaVersion + + - powershell: | + azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/local/TensorRT-8.6.1.6.Windows10.x86_64.cuda-$(trtCudaVersion) $(Agent.TempDirectory) + displayName: 'Download TensorRT-8.6.1.6.Windows10.x86_64.cuda-$(trtCudaVersion)' + + - powershell: | + Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-$(trtCudaVersion)\lib" + displayName: 'Append TensorRT Directory to PATH' - task: CmdLine@2 inputs: script: | echo %PATH% - displayName: 'Print PATH' \ No newline at end of file + displayName: 'Print PATH after download TensorRT' \ No newline at end of file From c3ce9df80c2cfc7013445f8b44213f3e75cac753 Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Sun, 14 Jan 2024 17:51:00 -0500 Subject: [PATCH 04/39] Disabling python3.12 on training python packaging pipleines (#19123) --- .../templates/py-packaging-training-cuda-stage.yml | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage.yml index e7b935712ac6c..158037661f072 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage.yml @@ -98,12 +98,13 @@ stages: OpsetVersion: ${{ parameters.opset_version }} CudaVersion: ${{ parameters.cuda_version }} UploadWheel: ${{ parameters.upload_wheel }} - Python312: - PythonVersion: '3.12' - TorchVersion: ${{ parameters.torch_version }} - OpsetVersion: ${{ parameters.opset_version }} - CudaVersion: ${{ parameters.cuda_version }} - UploadWheel: ${{ parameters.upload_wheel }} +# TODO: enable this when we have torch support pyton 3.12 +# Python312: +# PythonVersion: '3.12' +# TorchVersion: ${{ parameters.torch_version }} +# OpsetVersion: ${{ parameters.opset_version }} +# CudaVersion: ${{ parameters.cuda_version }} +# UploadWheel: ${{ parameters.upload_wheel }} steps: - task: CmdLine@2 From 71657d1eb8b0a24a4b6584d9e904506a0b4e1521 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sun, 14 Jan 2024 17:53:26 -0500 Subject: [PATCH 05/39] [java] Fix double close (#19133) ### Description The `OnnxValue` and `OrtProviderOptions` implementations now check to see if they've been closed before 
accessing the native pointer, and also before close is called. ### Motivation and Context Before they could be closed twice which SIGSEGV'd the JVM. Fixes #19125. --- .../src/main/java/ai/onnxruntime/OnnxMap.java | 27 +++++++++++++-- .../java/ai/onnxruntime/OnnxSequence.java | 27 +++++++++++++-- .../java/ai/onnxruntime/OnnxSparseTensor.java | 18 ++++++++-- .../main/java/ai/onnxruntime/OnnxTensor.java | 24 +++++++++++--- .../java/ai/onnxruntime/OnnxTensorLike.java | 16 +++++++++ .../main/java/ai/onnxruntime/OnnxValue.java | 9 ++++- .../ai/onnxruntime/OrtProviderOptions.java | 30 ++++++++++++++++- .../ai/onnxruntime/OrtTrainingSession.java | 33 +++++++++++++++++-- .../StringConfigProviderOptions.java | 1 + .../java/ai/onnxruntime/InferenceTest.java | 2 ++ .../java/ai/onnxruntime/OnnxTensorTest.java | 27 +++++++++++++-- .../test/java/ai/onnxruntime/TestHelpers.java | 12 +++++++ 12 files changed, 208 insertions(+), 18 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OnnxMap.java b/java/src/main/java/ai/onnxruntime/OnnxMap.java index 354ebec61274d..cde9f0de4ff0a 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxMap.java +++ b/java/src/main/java/ai/onnxruntime/OnnxMap.java @@ -8,6 +8,7 @@ import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import java.util.logging.Logger; /** * A container for a map returned by {@link OrtSession#run(Map)}. @@ -16,6 +17,7 @@ * values: String, Long, Float, Double. */ public class OnnxMap implements OnnxValue { + private static final Logger logger = Logger.getLogger(OnnxMap.class.getName()); static { try { @@ -107,6 +109,8 @@ public static OnnxMapValueType mapFromOnnxJavaType(OnnxJavaType type) { private final OnnxMapValueType valueType; + private boolean closed; + /** * Constructs an OnnxMap containing a reference to the native map along with the type information. * @@ -122,6 +126,7 @@ public static OnnxMapValueType mapFromOnnxJavaType(OnnxJavaType type) { this.info = info; this.stringKeys = info.keyType == OnnxJavaType.STRING; this.valueType = OnnxMapValueType.mapFromOnnxJavaType(info.valueType); + this.closed = false; } /** @@ -146,6 +151,7 @@ public OnnxValueType getType() { */ @Override public Map getValue() throws OrtException { + checkClosed(); Object[] keys = getMapKeys(); Object[] values = getMapValues(); HashMap map = new HashMap<>(OrtUtil.capacityFromSize(keys.length)); @@ -222,10 +228,27 @@ public String toString() { return "ONNXMap(size=" + size() + ",info=" + info.toString() + ")"; } + @Override + public synchronized boolean isClosed() { + return closed; + } + /** Closes this map, releasing the native memory backing it and it's elements. */ @Override - public void close() { - close(OnnxRuntime.ortApiHandle, nativeHandle); + public synchronized void close() { + if (!closed) { + close(OnnxRuntime.ortApiHandle, nativeHandle); + closed = true; + } else { + logger.warning("Closing an already closed map."); + } + } + + /** Checks if the OnnxValue is closed, if so throws {@link IllegalStateException}. 
*/ + protected void checkClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed OnnxValue"); + } } private native String[] getStringKeys(long apiHandle, long nativeHandle, long allocatorHandle) diff --git a/java/src/main/java/ai/onnxruntime/OnnxSequence.java b/java/src/main/java/ai/onnxruntime/OnnxSequence.java index 93e1be21588b4..7722514b913b6 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxSequence.java +++ b/java/src/main/java/ai/onnxruntime/OnnxSequence.java @@ -8,6 +8,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.logging.Logger; /** * A sequence of {@link OnnxValue}s all of the same type. @@ -24,6 +25,7 @@ * */ public class OnnxSequence implements OnnxValue { + private static final Logger logger = Logger.getLogger(OnnxSequence.class.getName()); static { try { @@ -40,6 +42,8 @@ public class OnnxSequence implements OnnxValue { private final SequenceInfo info; + private boolean closed; + /** * Creates the wrapper object for a native sequence. * @@ -53,6 +57,7 @@ public class OnnxSequence implements OnnxValue { this.nativeHandle = nativeHandle; this.allocatorHandle = allocatorHandle; this.info = info; + this.closed = false; } @Override @@ -76,6 +81,7 @@ public OnnxValueType getType() { */ @Override public List getValue() throws OrtException { + checkClosed(); if (info.sequenceOfMaps) { OnnxMap[] maps = getMaps(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle); return Collections.unmodifiableList(Arrays.asList(maps)); @@ -110,10 +116,27 @@ public String toString() { return "OnnxSequence(info=" + info.toString() + ")"; } + @Override + public synchronized boolean isClosed() { + return closed; + } + /** Closes this sequence, releasing the native memory backing it and it's elements. */ @Override - public void close() { - close(OnnxRuntime.ortApiHandle, nativeHandle); + public synchronized void close() { + if (!closed) { + close(OnnxRuntime.ortApiHandle, nativeHandle); + closed = true; + } else { + logger.warning("Closing an already closed sequence."); + } + } + + /** Checks if the OnnxValue is closed, if so throws {@link IllegalStateException}. */ + protected void checkClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed OnnxValue"); + } } private native OnnxMap[] getMaps(long apiHandle, long nativeHandle, long allocatorHandle) diff --git a/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java b/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java index 53bd4c7f9b3e6..804fe742ad624 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java @@ -14,6 +14,7 @@ import java.nio.LongBuffer; import java.nio.ShortBuffer; import java.util.Arrays; +import java.util.logging.Logger; /** * A Java object wrapping an OnnxSparseTensor. @@ -22,6 +23,7 @@ * different static inner class representing each type. */ public final class OnnxSparseTensor extends OnnxTensorLike { + private static final Logger logger = Logger.getLogger(OnnxSparseTensor.class.getName()); private final SparseTensorType sparseTensorType; // Held to prevent deallocation while used in native code. 
@@ -198,6 +200,7 @@ public OnnxValueType getType() { @Override public SparseTensor getValue() throws OrtException { + checkClosed(); Buffer buffer = getValuesBuffer(); long[] indicesShape = getIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle); switch (sparseTensorType) { @@ -234,8 +237,13 @@ public SparseTensor getValue() throws OrtException { } @Override - public void close() { - close(OnnxRuntime.ortApiHandle, nativeHandle); + public synchronized void close() { + if (!closed) { + close(OnnxRuntime.ortApiHandle, nativeHandle); + closed = true; + } else { + logger.warning("Closing an already closed OnnxSparseTensor."); + } } /** @@ -257,6 +265,7 @@ public SparseTensorType getSparseTensorType() { * @return The indices. */ public Buffer getIndicesBuffer() { + checkClosed(); switch (sparseTensorType) { case COO: case CSRC: @@ -295,6 +304,7 @@ public Buffer getIndicesBuffer() { * @return The inner indices. */ public LongBuffer getInnerIndicesBuffer() { + checkClosed(); if (sparseTensorType == SparseTensorType.CSRC) { LongBuffer buf = getInnerIndicesBuffer(OnnxRuntime.ortApiHandle, nativeHandle) @@ -320,6 +330,7 @@ public LongBuffer getInnerIndicesBuffer() { * @return The data buffer. */ public Buffer getValuesBuffer() { + checkClosed(); ByteBuffer buffer = getValuesBuffer(OnnxRuntime.ortApiHandle, nativeHandle).order(ByteOrder.nativeOrder()); switch (info.type) { @@ -396,6 +407,7 @@ public Buffer getValuesBuffer() { * @return The indices shape. */ public long[] getIndicesShape() { + checkClosed(); return getIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle); } @@ -405,6 +417,7 @@ public long[] getIndicesShape() { * @return The indices shape. */ public long[] getInnerIndicesShape() { + checkClosed(); if (sparseTensorType == SparseTensorType.CSRC) { return getInnerIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle); } else { @@ -420,6 +433,7 @@ public long[] getInnerIndicesShape() { * @return The values shape. */ public long[] getValuesShape() { + checkClosed(); return getValuesShape(OnnxRuntime.ortApiHandle, nativeHandle); } diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index 0078adb6402f8..e1ee2c14fd9d1 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -14,12 +14,14 @@ import java.nio.LongBuffer; import java.nio.ShortBuffer; import java.util.Optional; +import java.util.logging.Logger; /** * A Java object wrapping an OnnxTensor. Tensors are the main input to the library, and can also be * returned as outputs. */ public class OnnxTensor extends OnnxTensorLike { + private static final Logger logger = Logger.getLogger(OnnxTensor.class.getName()); /** * This reference is held for OnnxTensors backed by a java.nio.Buffer to ensure the buffer does @@ -97,6 +99,7 @@ public OnnxValueType getType() { */ @Override public Object getValue() throws OrtException { + checkClosed(); if (info.isScalar()) { switch (info.type) { case FLOAT: @@ -144,16 +147,21 @@ public Object getValue() throws OrtException { @Override public String toString() { - return "OnnxTensor(info=" + info.toString() + ")"; + return "OnnxTensor(info=" + info.toString() + ",closed=" + closed + ")"; } /** - * Closes the tensor, releasing it's underlying memory (if it's not backed by an NIO buffer). If - * it is backed by a buffer then the memory is released when the buffer is GC'd. + * Closes the tensor, releasing its underlying memory (if it's not backed by an NIO buffer). 
If it + * is backed by a buffer then the memory is released when the buffer is GC'd. */ @Override - public void close() { - close(OnnxRuntime.ortApiHandle, nativeHandle); + public synchronized void close() { + if (!closed) { + close(OnnxRuntime.ortApiHandle, nativeHandle); + closed = true; + } else { + logger.warning("Closing an already closed tensor."); + } } /** @@ -165,6 +173,7 @@ public void close() { * @return A ByteBuffer copy of the OnnxTensor. */ public ByteBuffer getByteBuffer() { + checkClosed(); if (info.type != OnnxJavaType.STRING) { ByteBuffer buffer = getBuffer(OnnxRuntime.ortApiHandle, nativeHandle); ByteBuffer output = ByteBuffer.allocate(buffer.capacity()); @@ -183,6 +192,7 @@ public ByteBuffer getByteBuffer() { * @return A FloatBuffer copy of the OnnxTensor. */ public FloatBuffer getFloatBuffer() { + checkClosed(); if (info.type == OnnxJavaType.FLOAT) { // if it's fp32 use the efficient copy. FloatBuffer buffer = getBuffer().asFloatBuffer(); @@ -212,6 +222,7 @@ public FloatBuffer getFloatBuffer() { * @return A DoubleBuffer copy of the OnnxTensor. */ public DoubleBuffer getDoubleBuffer() { + checkClosed(); if (info.type == OnnxJavaType.DOUBLE) { DoubleBuffer buffer = getBuffer().asDoubleBuffer(); DoubleBuffer output = DoubleBuffer.allocate(buffer.capacity()); @@ -230,6 +241,7 @@ public DoubleBuffer getDoubleBuffer() { * @return A ShortBuffer copy of the OnnxTensor. */ public ShortBuffer getShortBuffer() { + checkClosed(); if ((info.type == OnnxJavaType.INT16) || (info.type == OnnxJavaType.FLOAT16) || (info.type == OnnxJavaType.BFLOAT16)) { @@ -250,6 +262,7 @@ public ShortBuffer getShortBuffer() { * @return An IntBuffer copy of the OnnxTensor. */ public IntBuffer getIntBuffer() { + checkClosed(); if (info.type == OnnxJavaType.INT32) { IntBuffer buffer = getBuffer().asIntBuffer(); IntBuffer output = IntBuffer.allocate(buffer.capacity()); @@ -268,6 +281,7 @@ public IntBuffer getIntBuffer() { * @return A LongBuffer copy of the OnnxTensor. */ public LongBuffer getLongBuffer() { + checkClosed(); if (info.type == OnnxJavaType.INT64) { LongBuffer buffer = getBuffer().asLongBuffer(); LongBuffer output = LongBuffer.allocate(buffer.capacity()); diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensorLike.java b/java/src/main/java/ai/onnxruntime/OnnxTensorLike.java index c2989fe296dc2..bbfd4e981ece2 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensorLike.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensorLike.java @@ -28,6 +28,9 @@ public abstract class OnnxTensorLike implements OnnxValue { /** The size and shape information for this tensor. */ protected final TensorInfo info; + /** Is this value closed? */ + protected boolean closed; + /** * Constructs a tensor-like (the base class of OnnxTensor and OnnxSparseTensor). * @@ -39,6 +42,7 @@ public abstract class OnnxTensorLike implements OnnxValue { this.nativeHandle = nativeHandle; this.allocatorHandle = allocatorHandle; this.info = info; + this.closed = false; } /** @@ -59,4 +63,16 @@ long getNativeHandle() { public TensorInfo getInfo() { return info; } + + @Override + public synchronized boolean isClosed() { + return closed; + } + + /** Checks if the OnnxValue is closed, if so throws {@link IllegalStateException}. 
*/ + protected void checkClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed OnnxValue"); + } + } } diff --git a/java/src/main/java/ai/onnxruntime/OnnxValue.java b/java/src/main/java/ai/onnxruntime/OnnxValue.java index 752a0e74267d3..e829bc80f09f6 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxValue.java +++ b/java/src/main/java/ai/onnxruntime/OnnxValue.java @@ -64,7 +64,14 @@ public enum OnnxValueType { */ public ValueInfo getInfo(); - /** Closes the OnnxValue, freeing it's native memory. */ + /** + * Checks if this value is closed (i.e., the native object has been released). + * + * @return True if the value is closed and the native object has been released. + */ + public boolean isClosed(); + + /** Closes the OnnxValue, freeing its native memory. */ @Override public void close(); diff --git a/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java b/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java index 39a5121fad7a2..70af10ff8cd79 100644 --- a/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java +++ b/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java @@ -5,11 +5,14 @@ package ai.onnxruntime; import java.io.IOException; +import java.util.logging.Logger; /** An abstract base class for execution provider options classes. */ // Note this lives in ai.onnxruntime to allow subclasses to access the OnnxRuntime.ortApiHandle // package private field. public abstract class OrtProviderOptions implements AutoCloseable { + private static final Logger logger = Logger.getLogger(OrtProviderOptions.class.getName()); + static { try { OnnxRuntime.init(); @@ -21,6 +24,9 @@ public abstract class OrtProviderOptions implements AutoCloseable { /** The native pointer. */ protected final long nativeHandle; + /** Is the native object closed? */ + protected boolean closed; + /** * Constructs a OrtProviderOptions wrapped around a native pointer. * @@ -28,6 +34,7 @@ public abstract class OrtProviderOptions implements AutoCloseable { */ protected OrtProviderOptions(long nativeHandle) { this.nativeHandle = nativeHandle; + this.closed = false; } /** @@ -46,9 +53,30 @@ protected static long getApiHandle() { */ public abstract OrtProvider getProvider(); + /** + * Is the native object closed? + * + * @return True if the native object has been released. + */ + public synchronized boolean isClosed() { + return closed; + } + @Override public void close() { - close(OnnxRuntime.ortApiHandle, nativeHandle); + if (!closed) { + close(OnnxRuntime.ortApiHandle, nativeHandle); + closed = true; + } else { + logger.warning("Closing an already closed tensor."); + } + } + + /** Checks if the OrtProviderOptions is closed, if so throws {@link IllegalStateException}. */ + protected void checkClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed OrtProviderOptions"); + } } /** diff --git a/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java b/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java index 49ddf29c22335..eeede3a1bed0b 100644 --- a/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java @@ -12,6 +12,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.logging.Logger; /** * Wraps an ONNX training model and allows training and inference calls. @@ -1049,8 +1050,12 @@ private native void exportModelForInference( /** Wrapper class for the checkpoint state. 
*/ static final class OrtCheckpointState implements AutoCloseable { + private static final Logger logger = Logger.getLogger(OrtCheckpointState.class.getName()); + final long nativeHandle; + private boolean closed; + /** * Wraps an object around the checkpoint native handle. * @@ -1058,6 +1063,7 @@ static final class OrtCheckpointState implements AutoCloseable { */ OrtCheckpointState(long nativeHandle) { this.nativeHandle = nativeHandle; + this.closed = false; } /** @@ -1097,6 +1103,7 @@ static OrtCheckpointState loadCheckpoint(String checkpoint) throws OrtException * @throws OrtException If the checkpoint failed to save. */ public void saveCheckpoint(Path outputPath, boolean saveOptimizer) throws OrtException { + checkClosed(); Objects.requireNonNull(outputPath, "checkpoint path must not be null"); String outputStr = outputPath.toString(); saveCheckpoint( @@ -1115,6 +1122,7 @@ public void saveCheckpoint(Path outputPath, boolean saveOptimizer) throws OrtExc * @throws OrtException If the call failed. */ public void addProperty(String name, float value) throws OrtException { + checkClosed(); addProperty( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, nativeHandle, name, value); } @@ -1127,6 +1135,7 @@ public void addProperty(String name, float value) throws OrtException { * @throws OrtException If the call failed. */ public void addProperty(String name, int value) throws OrtException { + checkClosed(); addProperty( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, nativeHandle, name, value); } @@ -1139,6 +1148,7 @@ public void addProperty(String name, int value) throws OrtException { * @throws OrtException If the call failed. */ public void addProperty(String name, String value) throws OrtException { + checkClosed(); addProperty( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, nativeHandle, name, value); } @@ -1152,6 +1162,7 @@ public void addProperty(String name, String value) throws OrtException { * @throws OrtException If the property does not exist, or is of the wrong type. */ public float getFloatProperty(OrtAllocator allocator, String name) throws OrtException { + checkClosed(); return getFloatProperty( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, @@ -1169,6 +1180,7 @@ public float getFloatProperty(OrtAllocator allocator, String name) throws OrtExc * @throws OrtException If the property does not exist, or is of the wrong type. */ public int getIntProperty(OrtAllocator allocator, String name) throws OrtException { + checkClosed(); return getIntProperty( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, @@ -1186,6 +1198,7 @@ public int getIntProperty(OrtAllocator allocator, String name) throws OrtExcepti * @throws OrtException If the property does not exist, or is of the wrong type. */ public String getStringProperty(OrtAllocator allocator, String name) throws OrtException { + checkClosed(); return getStringProperty( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, @@ -1194,9 +1207,25 @@ public String getStringProperty(OrtAllocator allocator, String name) throws OrtE name); } + /** Checks if the OrtCheckpointState is closed, if so throws {@link IllegalStateException}. 
*/ + private void checkClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed OrtCheckpointState"); + } + } + + public synchronized boolean isClosed() { + return closed; + } + @Override - public void close() { - close(OnnxRuntime.ortTrainingApiHandle, nativeHandle); + public synchronized void close() { + if (!closed) { + close(OnnxRuntime.ortTrainingApiHandle, nativeHandle); + closed = true; + } else { + logger.warning("Closing a checkpoint twice"); + } } /* diff --git a/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java b/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java index 02207b2949e54..961163035c9a6 100644 --- a/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java +++ b/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java @@ -32,6 +32,7 @@ protected StringConfigProviderOptions(long nativeHandle) { * @throws OrtException If the addition failed. */ public void add(String key, String value) throws OrtException { + checkClosed(); Objects.requireNonNull(key, "Key must not be null"); Objects.requireNonNull(value, "Value must not be null"); options.put(key, value); diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index e975117fb75bd..f6f9da1829402 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -69,7 +69,9 @@ public void environmentTest() { // Checks that the environment instance is the same. OrtEnvironment otherEnv = OrtEnvironment.getEnvironment(); assertSame(env, otherEnv); + TestHelpers.quietLogger(OrtEnvironment.class); otherEnv = OrtEnvironment.getEnvironment("test-name"); + TestHelpers.loudLogger(OrtEnvironment.class); assertSame(env, otherEnv); } diff --git a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java index a5f285ba86a14..c060cf73ecf14 100644 --- a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java +++ b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java @@ -4,6 +4,10 @@ */ package ai.onnxruntime; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; + import ai.onnxruntime.platform.Fp16Conversions; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -97,8 +101,8 @@ public void testBufferCreation() throws OrtException { float[] arrValues = new float[] {0, 1, 2, 3, 4}; try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) { // array creation isn't backed by buffers - Assertions.assertFalse(t.ownsBuffer()); - Assertions.assertFalse(t.getBufferRef().isPresent()); + assertFalse(t.ownsBuffer()); + assertFalse(t.getBufferRef().isPresent()); FloatBuffer buf = t.getFloatBuffer(); float[] output = new float[arrValues.length]; buf.get(output); @@ -146,7 +150,7 @@ public void testBufferCreation() throws OrtException { directBuffer.rewind(); try (OnnxTensor t = OnnxTensor.createTensor(env, directBuffer, new long[] {1, 5})) { // direct buffers don't trigger a copy - Assertions.assertFalse(t.ownsBuffer()); + assertFalse(t.ownsBuffer()); // tensors backed by buffers can get the buffer ref back out Assertions.assertTrue(t.getBufferRef().isPresent()); FloatBuffer buf = t.getFloatBuffer(); @@ -428,4 +432,21 @@ public void testBf16RoundTrip() { } } } + + @Test + public void testClose() throws OrtException 
{ + OrtEnvironment env = OrtEnvironment.getEnvironment(); + long[] input = new long[] {1, 2, 3, 4, 5}; + OnnxTensor value = OnnxTensor.createTensor(env, input); + assertFalse(value.isClosed()); + long[] output = (long[]) value.getValue(); + assertArrayEquals(input, output); + value.close(); + // check use after close throws + assertThrows(IllegalStateException.class, value::getValue); + // check double close doesn't crash (emits warning) + TestHelpers.quietLogger(OnnxTensor.class); + value.close(); + TestHelpers.loudLogger(OnnxTensor.class); + } } diff --git a/java/src/test/java/ai/onnxruntime/TestHelpers.java b/java/src/test/java/ai/onnxruntime/TestHelpers.java index 55d8169434d48..c13cdf222b15b 100644 --- a/java/src/test/java/ai/onnxruntime/TestHelpers.java +++ b/java/src/test/java/ai/onnxruntime/TestHelpers.java @@ -22,6 +22,8 @@ import java.util.Comparator; import java.util.List; import java.util.Map; +import java.util.logging.Level; +import java.util.logging.Logger; import java.util.regex.Pattern; import org.junit.jupiter.api.Assertions; @@ -258,6 +260,16 @@ static void flattenStringBase(String[] input, List output) { output.addAll(Arrays.asList(input)); } + static void loudLogger(Class loggerClass) { + Logger l = Logger.getLogger(loggerClass.getName()); + l.setLevel(Level.INFO); + } + + static void quietLogger(Class loggerClass) { + Logger l = Logger.getLogger(loggerClass.getName()); + l.setLevel(Level.OFF); + } + public static Path getResourcePath(String path) { return new File(TestHelpers.class.getResource(path).getFile()).toPath(); } From b2ce3eedb9f3d9cee82525c9f29c2d1f42ba58c7 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Mon, 15 Jan 2024 15:09:49 +1000 Subject: [PATCH 06/39] Fix build error for CoreML Split op (#19099) ### Description The `split` input of the Split op is int64_t. Fixing that resolves a type mismatch build error on Windows when CoreML is enabled (for debugging the partitioning code). ### Motivation and Context Fix build error --------- Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- .../core/providers/coreml/builders/impl/split_op_builder.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc index 815f68128ffaf..56c87c883156b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc @@ -139,8 +139,8 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar } const auto& splits_tensor = *initializers.at(input_defs[1]->Name()); Initializer unpacked_tensor(splits_tensor); - auto splits_span = unpacked_tensor.DataAsSpan(); - int sum_of_splits = std::accumulate(splits_span.begin(), splits_span.end(), 0); + auto splits_span = unpacked_tensor.DataAsSpan(); + int64_t sum_of_splits = std::accumulate(splits_span.begin(), splits_span.end(), int64_t{0}); if (sum_of_splits != split_dims_at_axis) { LOGS(logger, VERBOSE) << "Mismatch between the sum of 'split'. Expected: " << split_dims_at_axis From 922a2f00e3855fdc9852ed1bfe7f6f0a88e40a24 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Mon, 15 Jan 2024 14:37:22 +0800 Subject: [PATCH 07/39] Extend timeout in Nuget-CUDA-Packaging-Pipeline (#19138) ### Description ### Motivation and Context Linux_GPU_x64 job in the pipeline has been canceled due to timeout since 0112. 
--- .../azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml
index fbdd67bb5de22..48a6e0e8529e6 100644
--- a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml
+++ b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml
@@ -15,7 +15,7 @@ stages: - job: workspace: clean: all - timeoutInMinutes: 120 + timeoutInMinutes: 150 pool: 'Onnxruntime-Linux-GPU' variables: - name: CUDA_VERSION_MAJOR

From a97199c62de4a96939624ba511313d0f81014f56 Mon Sep 17 00:00:00 2001
From: Ben Niu
Date: Mon, 15 Jan 2024 14:29:19 -0800
Subject: [PATCH 08/39] Fix Arm64EC build for test_q4qdq.cpp (#18523)

### Description
Fix the ifdef guards in test_q4qdq.cpp so that the code blocks intended only for native x64 compilation are excluded when building for Arm64EC, instead of being compiled for both x64 and Arm64EC.

--- onnxruntime/test/mlas/unittest/test_q4qdq.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/test/mlas/unittest/test_q4qdq.cpp b/onnxruntime/test/mlas/unittest/test_q4qdq.cpp
index 955c3b1201989..c317395bee970 100644
--- a/onnxruntime/test/mlas/unittest/test_q4qdq.cpp
+++ b/onnxruntime/test/mlas/unittest/test_q4qdq.cpp
@@ -19,7 +19,7 @@ Module Name: #include "test_util.h" #include "mlas_q4.h" -#if (defined(_M_AMD64) || defined(__x86_64__)) +#if ((defined(_M_AMD64) && !defined(_M_ARM64EC)) || defined(__x86_64__)) /** * @brief For testing purpose, @@ -93,7 +93,7 @@ class MlasQ4dqTest : public MlasTestBase { << K << "] QType: " << qtype; } -#if (defined(_M_AMD64) || defined(__x86_64__)) +#if ((defined(_M_AMD64) && !defined(_M_ARM64EC)) || defined(__x86_64__)) /* Test MlasBlkQ4DequantSgemmPackB, make sure we can reuse SGEMM kernel as it rearrange B the same way as sgemm pack B*/ const size_t AlignedN = (N + 15) & ~15;

From 191525301f2b30fa4ff7337cd40c5f3f94834488 Mon Sep 17 00:00:00 2001
From: Adam Pocock
Date: Mon, 15 Jan 2024 17:42:50 -0500
Subject: [PATCH 09/39] [java] Updating TensorInfo so it contains the named dimensions (#18962)

### Description
The Java `TensorInfo` object, which is used to describe a tensor's shape along with a model's input and output placeholders, couldn't show any symbolic/named dimensions of that tensor. Now this information is stored in Java strings on construction and included in the `toString` output.

### Motivation and Context
Setting symbolic dimensions required external information in Java; the names were not discoverable from within the API.

--- .../main/java/ai/onnxruntime/TensorInfo.java | 63 ++++++++++++++++--- java/src/main/native/OrtJniUtil.c | 26 ++++++-- .../java/ai/onnxruntime/InferenceTest.java | 6 ++ 3 files changed, 83 insertions(+), 12 deletions(-)

diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java
index 69ccb954e8afe..1c21387b50455 100644
--- a/java/src/main/java/ai/onnxruntime/TensorInfo.java
+++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java
@@ -7,6 +7,7 @@ import java.lang.reflect.Array; import java.nio.Buffer; import java.util.Arrays; +import java.util.stream.Collectors; /** Describes an {@link OnnxTensor}, including its size, shape and element type. */ public class TensorInfo implements ValueInfo { @@ -159,6 +160,12 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) { /** The shape of the tensor.
*/ final long[] shape; + /** The names of the unbound dimensions. */ + final String[] dimensionNames; + + /** If there are non-empty dimension names */ + private final boolean hasNames; + /** The Java type of this tensor. */ public final OnnxJavaType type; @@ -177,6 +184,9 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) { */ TensorInfo(long[] shape, OnnxJavaType type, OnnxTensorType onnxType) { this.shape = shape; + this.dimensionNames = new String[shape.length]; + Arrays.fill(dimensionNames, ""); + this.hasNames = false; this.type = type; this.onnxType = onnxType; this.numElements = elementCount(shape); @@ -188,10 +198,20 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) { *

Called from JNI. * * @param shape The tensor shape. + * @param names The dimension names. * @param typeInt The native type int. */ - TensorInfo(long[] shape, int typeInt) { + TensorInfo(long[] shape, String[] names, int typeInt) { this.shape = shape; + this.dimensionNames = names; + boolean hasNames = false; + for (String s : names) { + if (!s.isEmpty()) { + hasNames = true; + break; + } + } + this.hasNames = hasNames; this.onnxType = OnnxTensorType.mapFromInt(typeInt); this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType); this.numElements = elementCount(shape); @@ -206,15 +226,42 @@ public long[] getShape() { return Arrays.copyOf(shape, shape.length); } + /** + * Get a copy of the tensor's named dimensions. + * + * @return A copof the tensor's named dimensions. + */ + public String[] getDimensionNames() { + return Arrays.copyOf(dimensionNames, dimensionNames.length); + } + @Override public String toString() { - return "TensorInfo(javaType=" - + type.toString() - + ",onnxType=" - + onnxType.toString() - + ",shape=" - + Arrays.toString(shape) - + ")"; + String output = + "TensorInfo(javaType=" + + type.toString() + + ",onnxType=" + + onnxType.toString() + + ",shape=" + + Arrays.toString(shape); + if (hasNames) { + output = + output + + ",dimNames=[" + + Arrays.stream(dimensionNames) + .map( + a -> { + if (a.isEmpty()) { + return "\"\""; + } else { + return a; + } + }) + .collect(Collectors.joining(",")) + + "]"; + } + output = output + ")"; + return output; } /** diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index 879ba8a310618..7b26291581395 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -342,7 +342,6 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT if (code != ORT_OK) { return NULL; } - //printf("numDim %d\n",numDim); int64_t* dimensions = (int64_t*) malloc(sizeof(int64_t)*numDim); code = checkOrtStatus(jniEnv, api, api->GetDimensions(info, dimensions, numDim)); if (code != ORT_OK) { @@ -358,12 +357,31 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT free(dimensions); dimensions = NULL; + // Create the string array for the names. + const char** dimensionNames = (const char**) malloc(sizeof(char*)*numDim); + if (dimensionNames == NULL) { + throwOrtException(jniEnv, 1, "Not enough memory"); + return NULL; + } + code = checkOrtStatus(jniEnv, api, api->GetSymbolicDimensions(info, dimensionNames, numDim)); + if (code != ORT_OK) { + // extraction failed, exception has been thrown, return to Java. 
+ free(dimensionNames); + return NULL; + } + jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String"); + jobjectArray names = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(numDim), stringClazz, NULL); + for (size_t i = 0; i < numDim; i++) { + jobject javaName = (*jniEnv)->NewStringUTF(jniEnv, dimensionNames[i]); + (*jniEnv)->SetObjectArrayElement(jniEnv, names, safecast_size_t_to_jsize(i), javaName); + } + free(dimensionNames); + // Create the TensorInfo object static const char *tensorInfoClassName = "ai/onnxruntime/TensorInfo"; jclass clazz = (*jniEnv)->FindClass(jniEnv, tensorInfoClassName); - jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "<init>", "([JI)V"); - //printf("TensorInfo class %p, methodID %p\n",clazz,tensorInfoConstructor); - jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, onnxTypeInt); + jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "<init>", "([J[Ljava/lang/String;I)V"); + jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, names, onnxTypeInt); return tensorInfo; }

diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java
index f6f9da1829402..7fef2dc784b7b 100644
--- a/java/src/test/java/ai/onnxruntime/InferenceTest.java
+++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java
@@ -590,6 +590,12 @@ public void testSymbolicDimensionAssignment() throws OrtException { Map<String, NodeInfo> infoMap = session.getInputInfo(); TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo(); assertArrayEquals(new long[] {-1, 2}, aInfo.shape); + assertEquals(2, aInfo.dimensionNames.length); + assertEquals("n", aInfo.dimensionNames[0]); + assertEquals("", aInfo.dimensionNames[1]); + TensorInfo bInfo = (TensorInfo) infoMap.get("B").getInfo(); + assertEquals(1, bInfo.dimensionNames.length); + assertEquals("m", bInfo.dimensionNames[0]); } } // Check that when the options are assigned it overrides the symbolic dimension

From 1150b1f81ea7e46a840212acf194422af7f764a3 Mon Sep 17 00:00:00 2001
From: pengwa
Date: Tue, 16 Jan 2024 08:57:37 +0800
Subject: [PATCH 10/39] ORTModule memory improvement (#18924)

## Dependency
https://github.com/microsoft/onnxruntime/pull/19007

## ORTModule memory efficient gradient management
Previously I tried to solve the coarse-grained gradient accumulation/update problem in ORTModule with https://github.com/microsoft/onnxruntime/pull/8979, but that resolution was not fully validated with DDP or with user hooks registered on the gradient accumulation of torch parameters. This PR addresses the problem with a similar approach to PR 8979, i.e. triggering gradient accumulation once ORT has computed the grad, but instead of using an AccumulateGrad op, this time it is done with the ONNX operator PythonOp, which internally calls param.backward(grad) and therefore handles all related hooks correctly.

## Design
Check the details in https://microsoftapc-my.sharepoint.com/:p:/g/personal/pengwa_microsoft_com/EaaBq4EzsFhOmsDEXCG7Ba4Bb9bwd0O2sFV_JXJ4jBLYLA?e=7Sz2g8&nav=eyJzSWQiOjI3MSwiY0lkIjozMjE4NzI1NDIzfQ

## Convergence Validation:
![image](https://github.com/microsoft/onnxruntime/assets/10530022/ccf3a213-e815-4b23-b759-165033b2d9fe)
Differences are mostly on the order of 0.000x, sometimes 0.00x, which may come from gradient application happening in a different order before vs. after this change (on DeepSpeed ZeRO stage 2).

## TODO
Consolidate the logic with Stage3's similar logic.
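## Illustration
A minimal sketch of the accumulation idea above (illustrative only, not the actual ORTModule implementation; `accumulate_external_grad` is a made-up helper name): a gradient computed outside PyTorch autograd, e.g. by ORT, is handed to `param.backward(grad)`, which accumulates it into `param.grad` through the normal path, so hooks registered on the parameter still fire and the buffer holding the external gradient can be released immediately.

```python
# Minimal sketch, assuming `grad` was computed outside PyTorch autograd (e.g. by ORT).
# Feeding it through param.backward(grad) accumulates it into param.grad via the
# normal accumulation path, so hooks registered on the parameter still fire.
import torch


def accumulate_external_grad(param: torch.nn.Parameter, grad: torch.Tensor) -> None:
    # `param` is a leaf tensor, so backward(grad) simply accumulates `grad` into
    # param.grad; the buffer that produced `grad` can be freed right afterwards.
    param.backward(grad)


p = torch.nn.Parameter(torch.zeros(3))
p.register_hook(lambda g: print("hook saw grad:", g))

accumulate_external_grad(p, torch.ones(3))  # p.grad == [1., 1., 1.]
accumulate_external_grad(p, torch.ones(3))  # p.grad == [2., 2., 2.]
print(p.grad)
```

In ORTModule itself this call happens when the `PythonOpGrad` operator for a parameter runs, so the parameter's ORT-managed gradient buffer can be released as soon as its accumulation finishes, rather than only after the whole backward pass completes.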
--- docs/ORTModule_Training_Guidelines.md | 10 + onnxruntime/core/framework/execution_frame.cc | 3 +- .../python/tools/symbolic_shape_infer.py | 9 +- .../ortmodule/_graph_execution_manager.py | 109 ++++++-- .../ortmodule/_mem_efficient_grad_mgmt.py | 246 ++++++++++++++++++ .../python/training/ortmodule/_onnx_models.py | 1 + .../training/ortmodule/_pythonop_helper.py | 240 +++++++++++++++++ .../training/ortmodule/_training_manager.py | 27 +- .../python/training/ortmodule/options.py | 12 + .../utils/hooks/_zero_offload_subscriber.py | 2 +- .../python/orttraining_test_ortmodule_api.py | 2 +- .../torch_custom_function_kernel_base.cc | 5 +- ...-linux-nightly-ortmodule-test-pipeline.yml | 2 +- 13 files changed, 638 insertions(+), 30 deletions(-) create mode 100644 orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py create mode 100644 orttraining/orttraining/python/training/ortmodule/_pythonop_helper.py diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index bede16204d420..91057d3dfb120 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -293,6 +293,16 @@ A classical usage of disabling the deep copy: when the deep copy before module e export ORTMODULE_MEMORY_OPT_LEVEL=0 ``` +### ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT + +- **Feature Area**: *ORTMODULE/Optimizations* +- **Description**: By default, the memory-efficient gradient management is turned off. The gradient after it is computed in ONNX Runtime, will trigger the corresponding parameter's backward function through `PythonOpGrad` operator. This would help release the gradient buffer managed in ONNX Runtime, which originally is released once all backward computation finishes. + + ```bash + export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=1 # Enable + export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=0 # Disable + ``` + ### 2.2 Memory Optimization Q: *Want to run a bigger batch size?* diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index d9c49dc6bea1d..8c08152986cf6 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -223,7 +223,8 @@ void IExecutionFrame::Init(gsl::span feed_mlvalue_idxs, gsl::span& initializers, const std::function& is_initializer_sparse_func, gsl::span fetches) { - ORT_ENFORCE(feeds.size() == feed_mlvalue_idxs.size()); + ORT_ENFORCE(feeds.size() == feed_mlvalue_idxs.size(), "Get feed size: ", feeds.size(), " but expected feed size: ", + feed_mlvalue_idxs.size()); ORT_ENFORCE(fetches.empty() || fetches.size() == fetch_mlvalue_idxs_.size()); // Need this for sparse conversions in host memory diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index e90eea553c185..ef4c4ae906243 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -2415,9 +2415,9 @@ def _infer_RotaryEmbedding(self, node): # noqa: N802 def _infer_PythonOp(self, node): # noqa: N802 output_tensor_types = get_attribute(node, "output_tensor_types") - assert output_tensor_types + assert output_tensor_types, f"PythonOp '{node.name}' has no output_tensor_types attribute." output_tensor_ranks = get_attribute(node, "output_tensor_ranks") - assert output_tensor_ranks + assert output_tensor_ranks, f"PythonOp '{node.name}' has no output_tensor_ranks attribute." 
from onnxruntime.capi._pybind_state import get_shape_inference_function @@ -2438,7 +2438,10 @@ def _infer_PythonOp(self, node): # noqa: N802 input_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type input_dtypes.append(input_dtype) output_shapes, output_dtypes = shape_inferer(node, input_shapes, input_dtypes) - assert len(output_shapes) == len(output_dtypes) == (len(node.output) - 1) + assert len(output_shapes) == len(output_dtypes) == (len(node.output) - 1), ( + f"PythonOp '{func_name}' returned {len(output_shapes)} shapes and {len(output_dtypes)} dtypes, " + f"but expected {len(node.output) - 1} outputs." + ) for i in range(len(node.output) - 1): output_index = i + 1 vi = self.known_vi_[node.output[output_index]] diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 853eab61b4bd6..779b6bfe50422 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -36,7 +36,6 @@ from ._io import _FlattenedModule, _InputInfo from ._runtime_inspector import RuntimeInspector from ._utils import check_function_has_param, get_rank -from ._zero_stage3_compatibility import stage3_export_context from .options import DebugOptions, LogLevel, _MemoryOptimizationLevel, _RuntimeOptions from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension @@ -148,6 +147,10 @@ def __init__( configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="ort_output", stats_overwrite=True) + # Will be reset everytime we re-initialize the graph builder. + # Be noted, we will never enable this feature for inference mode. + self._mem_efficient_grad_management_is_enabled = False + def _get_torch_gpu_allocator_function_addresses(self): if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available(): # CPP extension to get torch GPU allocator's alloc and free function addresses @@ -388,6 +391,8 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu assert self._export_mode is not None, "Please use a concrete instance of ExecutionManager" try: + from ._zero_stage3_compatibility import stage3_export_context + with torch.no_grad(), stage3_export_context(self._runtime_options.enable_zero_stage3_support, self): required_export_kwargs = { "input_names": self._input_info.names, @@ -496,9 +501,35 @@ def _get_graph_transformer_config(self) -> C.TrainingGraphTransformerConfigurati def _initialize_graph_builder(self): """Creates a new OrtModuleGraphBuilder, initializes it and saves it to self._graph_builder""" + self._mem_efficient_grad_management_is_enabled = ( + self._export_mode != torch.onnx.TrainingMode.EVAL + and self._runtime_options.enable_mem_efficient_grad_management + ) + + # We post process the exported model because the trainable parame might be changed, so this path is + # re-triggered by reinitialize_graph_builder. + exported_model = copy.deepcopy(self._onnx_models.exported_model) + self._onnx_models.processed_exported_model = exported_model + + if self._mem_efficient_grad_management_is_enabled: + from ._mem_efficient_grad_mgmt import post_processing_enable_mem_efficient_training + + # Override the options if model is not modified. 
+ ( + self._mem_efficient_grad_management_is_enabled, + exported_model, + ) = post_processing_enable_mem_efficient_training(exported_model, self._flattened_module.named_parameters()) + + if self._runtime_options.run_symbolic_shape_infer: + exported_model = SymbolicShapeInference.infer_shapes( + exported_model, auto_merge=True, guess_output_rank=True + ) + # All initializer names along with user inputs are a part of the onnx graph inputs # since the onnx model was exported with the flag keep_initializers_as_inputs=True - onnx_initializer_names = {p.name for p in self._onnx_models.exported_model.graph.input} + # We need to use the raw exported model here since the graph inputs include both user inputrs and + # parameters. + onnx_initializer_names = {p.name for p in exported_model.graph.input} # TODO: PyTorch exporter bug: changes the initializer order in ONNX model initializer_names = [ @@ -521,6 +552,13 @@ def _initialize_graph_builder(self): # Add stage3 pull weight trigger name to require_grad_names, so that it will be included in the gradient graph. input_names_require_grad.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME) + + if self._mem_efficient_grad_management_is_enabled: + from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME + + # Add mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph. + input_names_require_grad.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) + grad_builder_config.input_names_require_grad = input_names_require_grad grad_builder_config.build_gradient_graph = self._export_mode == torch.onnx.TrainingMode.TRAINING grad_builder_config.enable_caching = self._runtime_options.enable_grad_acc_optimization @@ -532,12 +570,23 @@ def _initialize_graph_builder(self): # It is assumed here that the order and names of the inputs and outputs are not modified by the backend in any way # and are kept as they appear in the exported onnx model. - self._graph_builder.initialize(self._onnx_models.exported_model.SerializeToString(), grad_builder_config) + self._graph_builder.initialize(exported_model.SerializeToString(), grad_builder_config) + + raw_onnx_initializer_names = {p.name for p in self._onnx_models.exported_model.graph.input} + + raw_initializer_names = [ + name for name, _ in self._flattened_module.named_parameters() if name in raw_onnx_initializer_names + ] + raw_initializer_names_to_train = [ + name + for name, param in self._flattened_module.named_parameters() + if param.requires_grad and name in raw_onnx_initializer_names + ] # TODO: Explore ways to make self._graph_info.initializer_names and self._graph_info.initializer_names_to_train # a set (unordered_set in the backend) that does not require a copy on each reference. - self._graph_initializer_names = set(initializer_names) - self._graph_initializer_names_to_train = set(initializer_names_to_train) + self._graph_initializer_names = set(raw_initializer_names) + self._graph_initializer_names_to_train = set(raw_initializer_names_to_train) # Initializers can be cached and used since they are expected not to be re-instantiated # between forward calls. @@ -588,7 +637,7 @@ def _enable_conditional_optimizations( # Enable data sparsity inspection if sparse optimizer is ON or user wants to print input density. 
if self._runtime_options.enable_sparse_optimizer or self._runtime_options.print_input_density: self._runtime_inspector.enable_input_inspector( - self._onnx_models.exported_model, self._graph_builder.get_graph_info().user_input_names + self._onnx_models.processed_exported_model, self._graph_builder.get_graph_info().user_input_names ) if self._runtime_options.enable_sparse_optimizer: @@ -596,11 +645,21 @@ def _enable_conditional_optimizations( inputs, kwargs ) - if self._runtime_options.enable_zero_stage3_support: + if self._runtime_options.enable_zero_stage3_support or self._mem_efficient_grad_management_is_enabled: self._append_pull_weight_trigger_as_input(kwargs, detected_device) + param_to_append_as_onnx_graph_inputs = [] + if self._mem_efficient_grad_management_is_enabled: + from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger + + param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger( + self._flattened_module.named_parameters(), self._onnx_models.exported_model + ) + else: + param_to_append_as_onnx_graph_inputs = self._graph_initializers + _, embed_sparsity_results, label_sparsity_results = _io._combine_input_buffers_initializers( - self._graph_initializers, + param_to_append_as_onnx_graph_inputs, self._graph_builder.get_graph_info().user_input_names, self._input_info, self._flattened_module.named_buffers(), @@ -632,19 +691,31 @@ def _enable_conditional_optimizations( self._runtime_inspector.disable_input_inspector() def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.device): - from ._zero_stage3_compatibility import ( - STAGE3_PULL_WEIGHT_TRIGGER_NAME, - STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, - STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, - ) + if self._runtime_options.enable_zero_stage3_support: + from ._zero_stage3_compatibility import ( + STAGE3_PULL_WEIGHT_TRIGGER_NAME, + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, + ) - kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = torch.zeros( - STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, - dtype=onnx_dtype_to_pytorch_dtype(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE), - device=device, - ).requires_grad_() + kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = torch.zeros( + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, + dtype=onnx_dtype_to_pytorch_dtype(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE), + device=device, + ).requires_grad_() + + if self._mem_efficient_grad_management_is_enabled: + from ._mem_efficient_grad_mgmt import ( + MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME, + MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE, + MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE, + ) - return kwargs + kwargs[MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME] = torch.zeros( + MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE, + dtype=onnx_dtype_to_pytorch_dtype(MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE), + device=device, + ).requires_grad_() def _log_feature_stats(self): if get_rank() != 0: diff --git a/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py b/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py new file mode 100644 index 0000000000000..4663afdaa94a0 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py @@ -0,0 +1,246 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +from __future__ import annotations + +import ctypes + +import torch +from onnx import ModelProto, NodeProto, TensorProto, helper + +from onnxruntime.training.utils import pytorch_type_to_onnx_dtype + +from ._pythonop_helper import make_pythonop_node + +MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME = "mem_efficient_pull_weight_trigger" +MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE = TensorProto.FLOAT +MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE = [1] + + +def get_params_connected_to_pull_param_trigger( + named_params: dict[str, torch.nn.parameter.Parameter], exported_model: ModelProto +): + # Be noted, some parameters might not in graph input because they are not used in forward, so we filtered them also. + onnx_initializer_names = {p.name for p in exported_model.graph.input} + return {k: v for k, v in named_params if v.requires_grad and k in onnx_initializer_names} + + +def get_params_not_connected_to_pull_param_trigger( + named_params: dict[str, torch.nn.parameter.Parameter], exported_model: ModelProto +): + # Be noted, some parameters might not in graph input because they are not used in forward, so we filtered them also. + onnx_initializer_names = {p.name for p in exported_model.graph.input} + return [v for k, v in named_params if not v.requires_grad and k in onnx_initializer_names] + + +def post_processing_enable_mem_efficient_training( + exported_model: ModelProto, + named_params: dict[str, torch.nn.parameter.Parameter], +) -> tuple[bool, ModelProto]: + """This function is used to enable zero stage3 compatibility. + + Args: + exported_model (ModelProto): The exported model. + named_params (Optional[Dict[str, torch.nn.parameter.Parameter]]): The full parameter map. + + Returns: + tuple[bool, ModelProto]: A tuple of bool and ModelProto. The bool indicates whether the model is modified. + + """ + trainable_named_params = get_params_connected_to_pull_param_trigger(named_params, exported_model) + if len(trainable_named_params) == 0: + return False, exported_model + + # Create weight retrieving function using trainable_named_params. + param_pull_trigger_func_class = _create_param_trigger_function(trainable_named_params) + param_retrieve_func_class = _create_param_retrieval_function(trainable_named_params) + + def _get_param_pull_trigger_name(param_name: str) -> str: + return f"pull_{param_name}" + + # Create weight retrieving PythonOp. + inputs = [ + helper.make_tensor_value_info( + MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME, + MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE, # Use the same data type with output for the input + MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE, + ) + ] + + outputs = [ + helper.make_tensor_value_info( + _get_param_pull_trigger_name(pname), + MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE, + MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE, + ) + for pname in trainable_named_params + ] + + weight_pull_node = make_pythonop_node( + "weight_pull_trigger", + inputs, + outputs, + param_pull_trigger_func_class, + training_mode=1, + safe_run_mode=0, + ) + + graph_inputs_to_remove = [] + input_offset = 0 + for graph_input in exported_model.graph.input: + if graph_input.name not in trainable_named_params: + continue + + graph_inputs_to_remove.append(graph_input) + + # Create the param retrieval function for this parameter. 
+ node_inputs = [ + helper.make_tensor_value_info( + _get_param_pull_trigger_name(graph_input.name), + MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE, + MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE, + ), + graph_input.name, # Second param is a string, which represents the param_name + ] + + node_outputs = [ + helper.make_tensor_value_info( + graph_input.name, # output use the same name as weight + int(pytorch_type_to_onnx_dtype(trainable_named_params[graph_input.name].dtype)), + list(trainable_named_params[graph_input.name].shape), + ), + ] + + new_node = make_pythonop_node( + f"weight_retrieval_{graph_input.name}", + node_inputs, + node_outputs, + param_retrieve_func_class, + training_mode=1, + safe_run_mode=0, + ) + exported_model.graph.node.insert(input_offset, new_node) + input_offset += 1 + + # Delete exported_model.graph.input + names_to_remove = [input.name for input in graph_inputs_to_remove] + value_infos_to_remove = [ + value_info for value_info in exported_model.graph.value_info if value_info.name in names_to_remove + ] + for value_info in value_infos_to_remove: + exported_model.graph.value_info.remove(value_info) + + for input_to_remove in graph_inputs_to_remove: + exported_model.graph.input.remove(input_to_remove) + + # Re-order graph input to make sure the weight pull trigger is the first user input. + offset = 0 # Find the first trainable param, and insert the new input before it, as part of user inputs. + for input in exported_model.graph.input: + if input.name in named_params: + break + offset += 1 + exported_model.graph.input.insert(offset, inputs[0]) + exported_model.graph.node.insert(0, weight_pull_node) + + return True, exported_model + + +_PARAM_FUNCTION_INDEX = [0] + + +def _create_param_trigger_function(trainable_named_params: dict[str, torch.nn.parameter.Parameter]): + """This function is used to create a weight retrieving function using trainable_named_params.""" + + @staticmethod + def forward(ctx, weight_in_trigger): + params = list(trainable_named_params.values()) + ctx.params = params + ctx.dtype = weight_in_trigger.dtype + ctx.device = weight_in_trigger.device + ctx.shape = weight_in_trigger.shape + return (torch.zeros(ctx.shape, device=ctx.device, dtype=ctx.dtype),) * len(params) + + @staticmethod + def backward(ctx, *grad_outputs): + return torch.zeros(ctx.shape, device=ctx.device, dtype=ctx.dtype) + + @staticmethod + def infer_shape( + node: NodeProto, + tensor_input_shapes: list[list[int | str] | None], + tensor_input_dtypes: list[torch.onnx.TensorProtoDataType], + ) -> tuple[list[list[int | str] | None], list[torch.onnx.TensorProtoDataType]]: + param_count = len(trainable_named_params.values()) + tensor_output_shapes = [ + tensor_input_shapes[0], + ] * param_count + tensor_output_dtypes = [ + tensor_input_dtypes[0], + ] * param_count + + return tensor_output_shapes, tensor_output_dtypes + + _PARAM_FUNCTION_INDEX[0] += 1 + + return type( + f"ParamTriggerFunction_{_PARAM_FUNCTION_INDEX[0]}", + (torch.autograd.Function,), + { + "forward": forward, + "backward": backward, + "infer_shape": infer_shape, + }, + ) + + +def _create_param_retrieval_function(trainable_named_params: dict[str, torch.nn.parameter.Parameter]): + """This function is used to create a weight retrieving function using trainable_named_params.""" + + @staticmethod + def forward(ctx, param_trigger, param_name): + ctx.param_name = param_name + ctx.dtype = param_trigger.dtype + ctx.device = param_trigger.device + ctx.shape = param_trigger.shape + return trainable_named_params[param_name] + + 
@staticmethod + def backward(ctx, *grad_outputs): + trainable_named_params[ctx.param_name].backward(grad_outputs[0]) + return torch.zeros(ctx.shape, device=ctx.device, dtype=ctx.dtype), None + + @staticmethod + def infer_shape( + node: NodeProto, + tensor_input_shapes: list[list[int | str] | None], + tensor_input_dtypes: list[torch.onnx.TensorProtoDataType], + ) -> tuple[list[list[int | str] | None], list[torch.onnx.TensorProtoDataType]]: + input_pointer_scalars_attr_name = "input_pointer_scalars" + found = [attr for attr in node.attribute if attr.name == input_pointer_scalars_attr_name] + + assert len(found) == 1 + input_pointer_scalars = found[0].ints + + # Restore the nn.Module from the pointer. + param_name = ctypes.cast(input_pointer_scalars[0], ctypes.py_object).value + + tensor_output_shapes = [ + list(trainable_named_params[param_name].shape), + ] + tensor_output_dtypes = [ + int(pytorch_type_to_onnx_dtype(trainable_named_params[param_name].dtype)), + ] + + return tensor_output_shapes, tensor_output_dtypes + + return type( + f"ParamRetrievalFunction_{_PARAM_FUNCTION_INDEX[0]}", + (torch.autograd.Function,), + { + "forward": forward, + "backward": backward, + "infer_shape": infer_shape, + }, + ) diff --git a/orttraining/orttraining/python/training/ortmodule/_onnx_models.py b/orttraining/orttraining/python/training/ortmodule/_onnx_models.py index d687bc24384ed..a0001a2f201f1 100644 --- a/orttraining/orttraining/python/training/ortmodule/_onnx_models.py +++ b/orttraining/orttraining/python/training/ortmodule/_onnx_models.py @@ -33,6 +33,7 @@ class ONNXModels: """ exported_model: Optional[onnx.ModelProto] = None + processed_exported_model: Optional[onnx.ModelProto] = None optimized_model: Optional[onnx.ModelProto] = None def save_exported_model(self, path, name_prefix, export_mode): diff --git a/orttraining/orttraining/python/training/ortmodule/_pythonop_helper.py b/orttraining/orttraining/python/training/ortmodule/_pythonop_helper.py new file mode 100644 index 0000000000000..32a564b27acd0 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/_pythonop_helper.py @@ -0,0 +1,240 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from __future__ import annotations + +import inspect + +import onnx +import torch + +from onnxruntime.capi._pybind_state import register_miscellaneous_const_input, register_torch_autograd_function + +from ._custom_autograd_function_exporter import register_custom_function_schema_supplementary +from ._utils import get_fully_qualified_class_name + +PYTHON_OP_DOMAIN = "com.microsoft" +PYTHON_OP_TYPE = "PythonOp" + +PYTHON_OP_ATTRIBUTE_FUNC_NAME = "func_name" +PYTHON_OP_ATTRIBUTE_SAFE_RUN_MODE = "safe_run_mode" +PYTHON_OP_ATTRIBUTE_TRAINING_MODE = "training_mode" + + +def set_safe_run_mode(model: onnx.ModelProto, allowed_unsafe_run_python_op_names: list[str]) -> onnx.ModelProto: + # Update safe_run_mode attribute for PythonOp. 
+ for node in model.graph.node: + if node.domain == PYTHON_OP_DOMAIN and node.op_type == PYTHON_OP_TYPE: + func_name = None + safe_run_mode_attr = None + for attr in node.attribute: + if attr.name == PYTHON_OP_ATTRIBUTE_FUNC_NAME: + func_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s + if attr.name == PYTHON_OP_ATTRIBUTE_SAFE_RUN_MODE: + safe_run_mode_attr = attr + + if func_name in allowed_unsafe_run_python_op_names: + if safe_run_mode_attr: + node.attribute.remove(safe_run_mode_attr) + node.attribute.append(onnx.helper.make_attribute(PYTHON_OP_ATTRIBUTE_SAFE_RUN_MODE, 0)) + + return model + + +_PYTHON_OP_INCRE_INDEX = [0] + + +def make_pythonop_node( + name_prefix: str, + inputs: list[ + onnx.ValueInfoProto | int | bool | float | tuple[int, ...] | tuple[bool, ...] | tuple[float, ...] | object + ], + outputs: list[onnx.ValueInfoProto], + func_class: torch.autograd.Function, + training_mode: int, + safe_run_mode: int, +) -> onnx.NodeProto: + assert issubclass(func_class, torch.autograd.Function), "func_class must be a subclass of torch.autograd.Function." + + assert len(inputs) > 0, f"inputs must not be empty for function {func_class}." + assert len(outputs) > 0, f"outputs must not be empty for function {func_class}." + + all_input_parameters: list[inspect.Parameter] = list(inspect.signature(func_class.forward).parameters.values()) + + # Remove the first parameter (ctx) from inspected parameter list. + assert len(inputs) == len(all_input_parameters) - 1, ( + f"The number of inputs ({len(inputs)}) must match the number of parameters " + f"({len(all_input_parameters) - 1}) of the forward function." + ) + + func_full_qual_name = get_fully_qualified_class_name(func_class) + + input_tensor_types = [] + input_tensor_ranks = [] + + input_bool_scalars = [] + input_bool_scalar_positions = [] + + input_int_scalars = [] + input_int_scalar_positions = [] + + input_float_scalars = [] + input_float_scalar_positions = [] + + input_bool_tuples = [] + input_bool_tuple_positions = [] + input_bool_tuple_begins = [] + + input_int_tuples = [] + input_int_tuple_positions = [] + input_int_tuple_begins = [] + + input_float_tuples = [] + input_float_tuple_positions = [] + input_float_tuple_begins = [] + + input_pointer_scalars = [] + input_pointer_scalar_positions = [] + + tensor_args = [] + debug_comment = "" + cconv = "" + # Encode inputs to torch.autograd.Function. + for i, arg in enumerate(inputs): + if isinstance(arg, onnx.ValueInfoProto): + # Got a tensor variable. + tensor_args.append(arg.name) + input_tensor_types.append(arg.type.tensor_type.elem_type) + input_tensor_ranks.append(len(arg.type.tensor_type.shape.dim)) + cconv += "d" + continue + + cconv += "c" + + # Got a non-tensor variable. + if isinstance(arg, float): + # A float. + input_float_scalar_positions.append(i) + input_float_scalars.append(arg) + continue + # bool check MUST be before int check since bool is a subclass of int + elif isinstance(arg, bool): + # A bool. + input_bool_scalar_positions.append(i) + input_bool_scalars.append(int(arg)) + continue + elif isinstance(arg, int): + # A int. + input_int_scalar_positions.append(i) + input_int_scalars.append(arg) + continue + + is_bool_tuple = False + is_int_tuple = False + is_float_tuple = False + if isinstance(arg, tuple) and len(arg) > 0: + # bool check MUST be before int check since bool is a subclass of int. 
+ is_bool_tuple = all(isinstance(ele, bool) for ele in arg) + is_int_tuple = not is_bool_tuple and all(isinstance(ele, int) for ele in arg) + is_float_tuple = not is_bool_tuple and not is_int_tuple and all(isinstance(ele, float) for ele in arg) + + # Only support tuple of bool, int or float, for other types, handle it as a pointer. + if is_bool_tuple: + # A tuple of bool. + input_bool_tuple_positions.append(i) + input_bool_tuple_begins.append(len(input_bool_tuples)) + input_bool_tuples.extend([int(ele) for ele in arg]) + continue + elif is_int_tuple: + # A tuple of ints. + input_int_tuple_positions.append(i) + input_int_tuple_begins.append(len(input_int_tuples)) + input_int_tuples.extend(list(arg)) + continue + elif is_float_tuple: + # A tuple of floats. + input_float_tuple_positions.append(i) + input_float_tuple_begins.append(len(input_float_tuples)) + input_float_tuples.extend(list(arg)) + continue + + from onnxruntime.training.utils.hooks._statistics_subscriber import _InspectActivation + + is_inspect_activation = func_full_qual_name == get_fully_qualified_class_name(_InspectActivation) + if is_inspect_activation and isinstance(arg, str): + # _InspectActivation is a special case where the first argument is a string + # that is used to determine the activation name to be inspected. + debug_comment += arg + + # All other inputs are accessed via "pointers". + input_pointer_scalar_positions.append(i) + input_pointer_scalars.append(id(arg)) + + # For pointer (for example, ProcessGroup passed to PythonOp) needed for PythonOp execution, + # we append it into a global store to hold a reference (in case it is released after module exported). + register_miscellaneous_const_input(arg) + + output_tensor_types = [] + output_tensor_ranks = [] + for arg in outputs: + output_tensor_types.append(arg.type.tensor_type.elem_type) + output_tensor_ranks.append(len(arg.type.tensor_type.shape.dim)) + + attrs = { + "func_name": func_full_qual_name, + "input_convention": cconv, + "input_tensor_types": input_tensor_types, + "input_tensor_ranks": input_tensor_ranks, + "output_tensor_types": output_tensor_types, + "output_tensor_ranks": output_tensor_ranks, + "training_mode": training_mode, + "safe_run_mode": safe_run_mode, + "comment": debug_comment, + } + + if len(input_bool_scalars) > 0: + attrs["input_bool_scalars"] = input_bool_scalars + attrs["input_bool_scalar_positions"] = input_bool_scalar_positions + if len(input_int_scalars) > 0: + attrs["input_int_scalars"] = input_int_scalars + attrs["input_int_scalar_positions"] = input_int_scalar_positions + if len(input_float_scalars) > 0: + attrs["input_float_scalars"] = input_float_scalars + attrs["input_float_scalar_positions"] = input_float_scalar_positions + if len(input_bool_tuples) > 0: + attrs["input_bool_tuples"] = input_bool_tuples + attrs["input_bool_tuple_positions"] = input_bool_tuple_positions + attrs["input_bool_tuple_begins"] = input_bool_tuple_begins + if len(input_int_tuples) > 0: + attrs["input_int_tuples"] = input_int_tuples + attrs["input_int_tuple_positions"] = input_int_tuple_positions + attrs["input_int_tuple_begins"] = input_int_tuple_begins + if len(input_float_tuples) > 0: + attrs["input_float_tuples"] = input_float_tuples + attrs["input_float_tuple_positions"] = input_float_tuple_positions + attrs["input_float_tuple_begins"] = input_float_tuple_begins + if len(input_pointer_scalars) > 0: + attrs["input_pointer_scalars"] = input_pointer_scalars + attrs["input_pointer_scalar_positions"] = input_pointer_scalar_positions + + # Register 
function with class names. + register_torch_autograd_function(func_full_qual_name, func_class) + + register_custom_function_schema_supplementary(func_class) + + _PYTHON_OP_INCRE_INDEX[0] += 1 + node_name = f"{name_prefix}_{_PYTHON_OP_INCRE_INDEX[0]}" + + node = onnx.helper.make_node( + PYTHON_OP_TYPE, + tensor_args, + [f"{node_name}_ctx", *[output.name for output in outputs]], + node_name, # node name + "", + PYTHON_OP_DOMAIN, + **attrs, + ) + + return node diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 5b2c673ce94cb..cc533e549db92 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -310,11 +310,22 @@ def forward(self, *inputs, **kwargs): self._gradient_accumulation_manager.maybe_update_cache_before_run() - if self._runtime_options.enable_zero_stage3_support: + if self._runtime_options.enable_zero_stage3_support or self._mem_efficient_grad_management_is_enabled: self._append_pull_weight_trigger_as_input(kwargs, self._device) + param_to_append_as_onnx_graph_inputs = [] + if self._mem_efficient_grad_management_is_enabled: + from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger + + param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger( + self._flattened_module.named_parameters(), self._onnx_models.exported_model + ) + + else: + param_to_append_as_onnx_graph_inputs = self._graph_initializers + prepared_input_list, _, _ = _io._combine_input_buffers_initializers( - self._graph_initializers, + param_to_append_as_onnx_graph_inputs, self._graph_info.user_input_names, self._input_info, self._flattened_module.named_buffers(), @@ -492,10 +503,20 @@ def _reinitialize_graph_builder(self, input_info: _InputInfo): if param.requires_grad and name in self._graph_initializer_names } + if self._mem_efficient_grad_management_is_enabled: + from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME + + # Remove the inputs we added during model post-processing. + existing_require_grad_names = [ + n for n in self._input_info.require_grad_names if n != MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME + ] + else: + existing_require_grad_names = self._input_info.require_grad_names + # If inputs requiring gradient change from forward to the next, the module_gradient_graph_builder # needs to be reinitialized so it can compute the backward output for the new inputs that require_grad if ( - input_info.require_grad_names != self._input_info.require_grad_names + input_info.require_grad_names != existing_require_grad_names or initializer_names_to_train_set_user_model != self._graph_initializer_names_to_train ): self._input_info = input_info diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index bfa38efb349ae..df3b078788d16 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -308,6 +308,9 @@ def __init__(self, logger: Logger): # Experimental features. self.enable_zero_stage3_support = False # Once enabled, cannot be disabled. + # We disable memory efficient grad management by default, will enable once it's fully validated. 
+ self.enable_mem_efficient_grad_management = False + self.deepcopy_before_model_export = True # Override the feature config if it exists in os env. @@ -397,6 +400,15 @@ def _override_from_env_vars(self): if "ORTMODULE_ENABLE_ZERO_STAGE3" in os.environ and int(os.getenv("ORTMODULE_ENABLE_ZERO_STAGE3")) == 1: self.enable_zero_stage3_support = True + if "ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT" in os.environ: + enable_grad_mgmt = int(os.getenv("ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT")) + self.enable_mem_efficient_grad_management = enable_grad_mgmt == 1 and self.enable_custom_autograd_function + if not self.enable_custom_autograd_function and enable_grad_mgmt == 1: + self._logger.warning( + "ORTModule optimization for memory efficient gradient management cannot be enabled " + "because PyTorch custom autograd function support is disabled." + ) + if "ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT" in os.environ: self.deepcopy_before_model_export = int(os.getenv("ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT")) == 1 diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py index e6004319ef5ea..d4b9768116e92 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py @@ -289,7 +289,7 @@ def backward(ctx, *grads): raise RuntimeError(f"param {p} has no grad, this should not happen.") # Param gradient accumulation is triggered here, along with the attached hooks, done by PyTorch. assert p.shape == g.shape, f"param_index: {param_index} - param shape {p.shape} != grad shape {g.shape}" - # p.backward(g) + p.backward(g) # At this point, the **real** param grads are already updated, the following grads are only used for # completing the full backward propagation, will not affect parameter updates. 
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index f944d8bc5ef42..938d33cc9a714 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -684,7 +684,7 @@ def test_input_requires_grad_saved(device): model = ORTModule(model) x = torch.randn(N, D_in, device=device, requires_grad=True) + 1 model(x) - assert model._torch_module._execution_manager(model._is_training())._input_info.require_grad_names == ["input1"] + assert "input1" in model._torch_module._execution_manager(model._is_training())._input_info.require_grad_names @pytest.mark.parametrize("device", ["cuda", "cpu"]) diff --git a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc index 3c5ac56cb139a..0a98cd959dd36 100644 --- a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc +++ b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc @@ -385,7 +385,10 @@ void PythonOpGradBase::RunBackward(OpKernelContext* context, void PythonOpGradBase::SetOutputs(OpKernelContext* context, std::vector& returned_ortvalues) const { auto* ctx_internal = reinterpret_cast(context); - ORT_ENFORCE(output_convention_.size() == returned_ortvalues.size(), "backward output count mismatch."); + ORT_ENFORCE(output_convention_.size() == returned_ortvalues.size(), "backward output count mismatch. Expected ", + output_convention_.size(), ", but got ", returned_ortvalues.size(), + ". Please check the backward function return same number of outputs as forward function's input for ", + name_, "."); int tensor_output_index = 0; for (size_t i = 0; i < returned_ortvalues.size(); ++i) { if (output_convention_[i] == 'd') { diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-nightly-ortmodule-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-nightly-ortmodule-test-pipeline.yml index 7824bf2203efe..e13ef9160bed3 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-linux-nightly-ortmodule-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-linux-nightly-ortmodule-test-pipeline.yml @@ -24,7 +24,7 @@ jobs: --volume $(Build.SourcesDirectory)/orttraining/orttraining/test/python:/onnxruntime_src \ --volume $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_nightly:/requirements_torch_nightly \ ptebic.azurecr.io/internal/aifx/acpt/nightly-ubuntu-cuda-torch-dev \ - bash -c "python3 -m pip install -r /requirements_torch_nightly/requirements.txt && python3 -m pytest -sv /onnxruntime_src/orttraining_test_ortmodule_api.py" + bash -c "python3 -m pip install -r /requirements_torch_nightly/requirements.txt && ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=0 python3 -m pytest -sv /onnxruntime_src/orttraining_test_ortmodule_api.py && ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=1 python3 -m pytest -sv /onnxruntime_src/orttraining_test_ortmodule_api.py" displayName: 'Run ORTModule Tests' condition: succeededOrFailed() timeoutInMinutes: 120 From 9f87c5c41d50fdcf30ce439617c708c964d8a050 Mon Sep 17 00:00:00 2001 From: Jeff Bloomfield <38966965+jeffbloo@users.noreply.github.com> Date: Mon, 15 Jan 2024 17:10:58 -0800 Subject: [PATCH 11/39] Fix build error due to merge with DML adapter 
enumeration macro defined (#19121) ### Description Fix build error when ENABLE_NPU_ADAPTER_ENUMERATION is defined ### Motivation and Context --- onnxruntime/core/providers/dml/dml_provider_factory.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index 73a068f3e1de2..b2688094a6d78 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -329,7 +329,6 @@ static std::optional ParseFilter(const ProviderOptions& prov static const std::string Any = "any"; static const std::string Gpu = "gpu"; #ifdef ENABLE_NPU_ADAPTER_ENUMERATION - static const std::string Any = "any"; static const std::string Npu = "npu"; #endif From 9dee543bedaed8419957afaed3a64b1ab5fa3a21 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Mon, 15 Jan 2024 18:40:38 -0800 Subject: [PATCH 12/39] fix gemm beta for fp16 (#19153) per onnx spec beta is always fp32 so we need to cast it --- js/web/lib/wasm/jsep/webgpu/ops/gemm.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts index 30754c84413b7..a0d4021516bf7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts @@ -100,8 +100,8 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt ${calculateAlpha} ${(() => { if (c != null) { - return `let cOffset = ${c.broadcastedIndicesToOffset('vec2(m, n)', output)}; value += uniforms.beta * ${ - c.getByOffset('cOffset')};`; + return `let cOffset = ${c.broadcastedIndicesToOffset('vec2(m, n)', output)}; value += ${ + dataType}(uniforms.beta) * ${c.getByOffset('cOffset')};`; } return ''; })()} From 1bab98988b4e7b6d33be0e672fce361ccbb1d397 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Tue, 16 Jan 2024 10:44:25 +0800 Subject: [PATCH 13/39] [WebNN EP] Fixed bug in int8 data type processing (#19134) --- .../core/providers/webnn/builders/helper.cc | 5 ++++- .../core/providers/webnn/builders/helper.h | 4 +++- .../webnn/builders/impl/cast_op_builder.cc | 4 +++- .../webnn/builders/impl/conv_op_builder.cc | 4 +++- .../core/providers/webnn/builders/model.cc | 18 ++++++++++++++---- .../providers/webnn/builders/model_builder.cc | 11 +++++++++-- 6 files changed, 36 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index a55145b0125a7..ef7c10dae580c 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -166,11 +166,14 @@ bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type) { // TODO: Remove legacy "type" once all browsers implement the new "dataType". 
switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: desc.set("type", emscripten::val("uint8")); desc.set("dataType", emscripten::val("uint8")); return true; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + desc.set("type", emscripten::val("int8")); + desc.set("dataType", emscripten::val("int8")); + return true; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: desc.set("type", emscripten::val("float16")); desc.set("dataType", emscripten::val("float16")); diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index f3fc7ec5cc4cd..85dafcaf66575 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -101,10 +101,12 @@ inline bool ReadScalarTensorData(const onnx::TensorProto& tensor, emscripten::va } switch (tensor.data_type()) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: scalar = emscripten::val{*reinterpret_cast(unpacked_tensor.data())}; break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + scalar = emscripten::val{*reinterpret_cast(unpacked_tensor.data())}; + break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: scalar = emscripten::val{MLFloat16::FromBits(*reinterpret_cast(unpacked_tensor.data())).ToFloat()}; break; diff --git a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc index 062f1c56061a9..3d961e4589c2e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc @@ -39,10 +39,12 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::string operand_type; switch (to_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: operand_type = "uint8"; break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + operand_type = "int8"; + break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: operand_type = "float16"; break; diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 123a9cc016515..ceacb7c2b38a3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -184,10 +184,12 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder, size_t element_size{0}; switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: element_size = sizeof(uint8_t); break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + element_size = sizeof(int8_t); + break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: element_size = sizeof(uint16_t); break; diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc index a4031fd9350c5..eaf549ef4e072 100644 --- a/onnxruntime/core/providers/webnn/builders/model.cc +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -33,11 +33,14 @@ Status Model::Predict(const InlinedHashMap& inputs, emscripten::val view = emscripten::val::undefined(); switch 
(tensor.tensor_info.data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: view = emscripten::val{emscripten::typed_memory_view(num_elements, static_cast(tensor.buffer))}; break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + static_cast(tensor.buffer))}; + break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: view = emscripten::val{emscripten::typed_memory_view(num_elements, static_cast(tensor.buffer))}; @@ -90,11 +93,14 @@ Status Model::Predict(const InlinedHashMap& inputs, emscripten::val view = emscripten::val::undefined(); switch (tensor.tensor_info.data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: view = emscripten::val{emscripten::typed_memory_view(num_elements, static_cast(tensor.buffer))}; break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + static_cast(tensor.buffer))}; + break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: view = emscripten::val{emscripten::typed_memory_view(num_elements, static_cast(tensor.buffer))}; @@ -168,10 +174,12 @@ void Model::AllocateInputOutputBuffers() { const auto data_type = input_info.data_type; switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: wnn_inputs_.set(input, emscripten::val::global("Uint8Array").new_(num_elements)); break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + wnn_inputs_.set(input, emscripten::val::global("Int8Array").new_(num_elements)); + break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: wnn_inputs_.set(input, emscripten::val::global("Uint16Array").new_(num_elements)); break; @@ -201,10 +209,12 @@ void Model::AllocateInputOutputBuffers() { const auto data_type = output_info.data_type; switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: wnn_outputs_.set(output, emscripten::val::global("Uint8Array").new_(num_elements)); break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + wnn_outputs_.set(output, emscripten::val::global("Int8Array").new_(num_elements)); + break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: wnn_outputs_.set(output, emscripten::val::global("Uint16Array").new_(num_elements)); break; diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 4e0c83db8b127..cf8a0e23db43b 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -160,12 +160,16 @@ Status ModelBuilder::RegisterInitializers() { } switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: desc.set("type", emscripten::val("uint8")); view = emscripten::val{emscripten::typed_memory_view(num_elements, reinterpret_cast(tensor_ptr))}; break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + desc.set("type", emscripten::val("int8")); + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; case 
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: view = emscripten::val{emscripten::typed_memory_view(num_elements, reinterpret_cast(tensor_ptr))}; @@ -318,11 +322,14 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer( ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type"); switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint8_t), reinterpret_cast(dest))}; break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + view = emscripten::val{emscripten::typed_memory_view(size / sizeof(int8_t), + reinterpret_cast(dest))}; + break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint16_t), reinterpret_cast(dest))}; From 8d4369b77ef8567653db3e247bbb2f48889fc457 Mon Sep 17 00:00:00 2001 From: Jeff Bloomfield <38966965+jeffbloo@users.noreply.github.com> Date: Mon, 15 Jan 2024 19:04:41 -0800 Subject: [PATCH 14/39] Update DirectML nuget version to 1.13.1 (#19122) ### Description Update DML version to 1.13.1 ### Motivation and Context --- .pipelines/nuget_config/x64/packages.config | 2 +- .pipelines/nuget_config/x86/packages.config | 2 +- cmake/external/dml.cmake | 2 +- packages.config | 2 +- tools/nuget/generate_nuspec_for_native_nuget.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.pipelines/nuget_config/x64/packages.config b/.pipelines/nuget_config/x64/packages.config index 2583e0d1b2ead..b862dec5e1c87 100644 --- a/.pipelines/nuget_config/x64/packages.config +++ b/.pipelines/nuget_config/x64/packages.config @@ -1,6 +1,6 @@  - + diff --git a/.pipelines/nuget_config/x86/packages.config b/.pipelines/nuget_config/x86/packages.config index 5ca659941c159..c348dd3e9cdad 100644 --- a/.pipelines/nuget_config/x86/packages.config +++ b/.pipelines/nuget_config/x86/packages.config @@ -1,6 +1,6 @@  - + diff --git a/cmake/external/dml.cmake b/cmake/external/dml.cmake index dfd9ad120eb98..ae7e6d3801a64 100644 --- a/cmake/external/dml.cmake +++ b/cmake/external/dml.cmake @@ -41,7 +41,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML) set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config) set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config) get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE) - set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.13.0) + set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.13.1) # Restore nuget packages, which will pull down the DirectML redist package. 
add_custom_command( diff --git a/packages.config b/packages.config index b67219d6d6913..e5b134d99dd89 100644 --- a/packages.config +++ b/packages.config @@ -1,6 +1,6 @@  - + diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index 56e50750ac153..09fe99d36cc34 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -219,7 +219,7 @@ def add_common_dependencies(xml_text, package_name, version): def generate_dependencies(xml_text, package_name, version): - dml_dependency = '' + dml_dependency = '' if package_name == "Microsoft.AI.MachineLearning": xml_text.append("") From c92f72ebebf5f4a1e63b726e6e5cec1a47250bb5 Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Tue, 16 Jan 2024 11:59:03 -0500 Subject: [PATCH 15/39] Merge Linux Nuget GPU pipeline with zip-nuget (#19120) ### Description ### Motivation and Context --- .../c-api-noopenmp-packaging-pipelines.yml | 174 ++---------------- .../nuget-linux-cuda-packaging-stage.yml | 18 +- 2 files changed, 31 insertions(+), 161 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index f80b035582f18..2169a3ce1bb9e 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -83,6 +83,16 @@ resources: variables: - name: ReleaseVersionSuffix value: '' +- name: docker_base_image + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: nvidia/cuda:12.2.2-cudnn8-devel-ubi8 +- name: linux_trt_version + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: 8.6.1.6-1.cuda11.8 + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: 8.6.1.6-1.cuda12.0 stages: - stage: Setup @@ -189,64 +199,11 @@ stages: AdditionalWinBuildFlags: '--enable_onnx_tests --enable_wcos' BuildVariant: 'default' -- stage: Linux_C_API_Packaging_GPU_x64 - dependsOn: [] - jobs: - - job: - workspace: - clean: all - timeoutInMinutes: 120 - pool: 'Onnxruntime-Linux-GPU' - variables: - - name: CUDA_VERSION_MAJOR - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: '11' - ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: '12' - - name: CUDA_VERSION - value: ${{ parameters.CudaVersion }} - steps: - - template: templates/set-version-number-variables-step.yml - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/inference/x64/default/gpu/Dockerfile - Context: tools/ci_build/github/linux/docker/inference/x64/default/gpu - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" - Repository: onnxruntimecuda$(CUDA_VERSION_MAJOR)build - - - script: $(Build.SourcesDirectory)/tools/ci_build/github/linux/build_cuda_c_api_package.sh - workingDirectory: $(Build.SourcesDirectory) - displayName: 'Build and Test' - - - template: templates/java-api-artifacts-package-and-publish-steps-posix.yml - parameters: - arch: 'linux-x64' - buildConfig: 'Release' - artifactName: 'onnxruntime-java-linux-x64-cuda' - version: '$(OnnxRuntimeVersion)' - libraryName: 'libonnxruntime.so' - nativeLibraryName: 'libonnxruntime4j_jni.so' - - - template: templates/c-api-artifacts-package-and-publish-steps-posix.yml - parameters: - buildConfig: 'Release' - artifactName: 'onnxruntime-linux-x64-cuda-$(OnnxRuntimeVersion)' - 
artifactNameNoVersionString: 'onnxruntime-linux-x64-cuda' - libraryName: 'libonnxruntime.so.$(OnnxRuntimeVersion)' - - - template: templates/component-governance-component-detection-steps.yml - parameters: - condition: 'succeeded' - - template: templates/clean-agent-build-directory-step.yml - -- template: templates/linux-gpu-tensorrt-packaging-pipeline.yml +- template: stages/nuget-linux-cuda-packaging-stage.yml parameters: - artifactName: 'onnxruntime-linux-x64-tensorrt-$(OnnxRuntimeVersion)' - artifactNameNoVersionString: 'onnxruntime-linux-x64-tensorrt' - buildJava: true - buildJavaOption: '--build_java' - buildNodejs: true - buildNodejsOption: '--build_nodejs' + CudaVersion: ${{ parameters.CudaVersion }} + docker_base_image: ${{ variables.docker_base_image }} + linux_trt_version: ${{ variables.linux_trt_version }} #CUDA without tensorrt - template: templates/win-ci.yml @@ -527,109 +484,6 @@ stages: displayName: 'Clean Agent Directories' condition: always() -- stage: Linux_Packaging_combined_GPU - dependsOn: - - Linux_C_API_Packaging_GPU_x64 - - Linux_C_API_Packaging_GPU_TensorRT_x64 - condition: succeeded() - jobs: - - job: - workspace: - clean: all - pool: 'Onnxruntime-Linux-GPU' - - steps: - - checkout: self # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime - submodules: false - - checkout: onnxruntime-inference-examples # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime-inference-examples - submodules: false - - checkout: manylinux # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/manylinux - submodules: false - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() - - - script: | - set -e -x - cd $(Build.SourcesDirectory) - mv manylinux onnxruntime - ls - - - template: templates/with-container-registry-steps.yml - parameters: - Steps: - - script: | - tools/ci_build/get_docker_image.py \ - --dockerfile tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda \ - --context tools/ci_build/github/linux/docker \ - --docker-build-args "--network=host --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 --build-arg BUILD_UID=$( id -u )" \ - --container-registry onnxruntimebuildcache \ - --multiple_repos \ - --repository onnxruntimecuda118xtrt86build - displayName: "Get onnxruntimecuda118xtrt86build image for tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda" - workingDirectory: $(Build.SourcesDirectory)/onnxruntime - ContainerRegistry: onnxruntimebuildcache - - - template: templates/set-version-number-variables-step.yml - parameters: - versionFileDirectory: '$(Build.SourcesDirectory)/onnxruntime' - workingDirectory: '$(Build.SourcesDirectory)/onnxruntime' - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact - Combined GPU' - inputs: - artifactName: 'onnxruntime-linux-x64-cuda' - targetPath: '$(Build.BinariesDirectory)/tgz-artifacts' - - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact - Combined GPU' - inputs: - artifactName: 'onnxruntime-linux-x64-tensorrt' - targetPath: '$(Build.BinariesDirectory)/tgz-artifacts' - - - task: ShellScript@2 - displayName: 'Shell Script' - inputs: - scriptPath: 'onnxruntime/tools/ci_build/github/linux/extract_and_bundle_gpu_package.sh' - args: '-a $(Build.BinariesDirectory)/tgz-artifacts' - workingDirectory: 
'$(Build.BinariesDirectory)/tgz-artifacts' - - - task: ArchiveFiles@2 - inputs: - rootFolderOrFile: '$(Build.BinariesDirectory)/tgz-artifacts/onnxruntime-linux-x64-gpu' - includeRootFolder: false - archiveType: 'tar' # Options: zip, 7z, tar, wim - tarCompression: 'gz' - archiveFile: '$(Build.ArtifactStagingDirectory)/onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz' - replaceExistingArchive: true - - - template: templates/validate-package.yml - parameters: - PackageType: 'tarball' - PackagePath: '$(Build.ArtifactStagingDirectory)' - PackageName: 'onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz' - ScriptPath: '$(Build.SourcesDirectory)/onnxruntime/tools/nuget/validate_package.py' - PlatformsSupported: 'linux-x64' - VerifyNugetSigning: false - workingDirectory: '$(Build.ArtifactStagingDirectory)' - - - - task: CmdLine@2 - displayName: 'Test C API application for GPU package' - inputs: - script: | - docker run --gpus all -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e NVIDIA_VISIBLE_DEVICES=all --rm --volume /data/models:/data/models --volume $(Build.SourcesDirectory):/src_dir \ - --volume $(Build.ArtifactStagingDirectory):/artifact_src -e NIGHTLY_BUILD onnxruntimecuda118xtrt86build \ - /src_dir/onnxruntime-inference-examples/c_cxx/squeezenet/run_capi_application.sh -o /src_dir/onnxruntime -p /artifact_src/onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz -w /src_dir/onnxruntime-inference-examples/c_cxx/squeezenet - workingDirectory: '$(Build.ArtifactStagingDirectory)' - - - task: PublishPipelineArtifact@1 - inputs: - targetPath: '$(Build.ArtifactStagingDirectory)/onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz' - artifactName: 'onnxruntime-linux-x64-gpu' - - template: templates/component-governance-component-detection-steps.yml - parameters : - condition : 'succeeded' - - stage: Windows_Packaging_combined_GPU dependsOn: diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml index 48a6e0e8529e6..dbbc9ef27e513 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml @@ -40,7 +40,16 @@ stages: - script: $(Build.SourcesDirectory)/tools/ci_build/github/linux/build_cuda_c_api_package.sh workingDirectory: $(Build.SourcesDirectory) displayName: 'Build and Test' - +# We only support Maven package for CUDA 11.8 + - ${{ if eq(parameters.CudaVersion, '11.8') }}: + - template: ../templates/java-api-artifacts-package-and-publish-steps-posix.yml + parameters: + arch: 'linux-x64' + buildConfig: 'Release' + artifactName: 'onnxruntime-java-linux-x64-cuda' + version: '$(OnnxRuntimeVersion)' + libraryName: 'libonnxruntime.so' + nativeLibraryName: 'libonnxruntime4j_jni.so' - template: ../templates/c-api-artifacts-package-and-publish-steps-posix.yml parameters: buildConfig: 'Release' @@ -82,6 +91,10 @@ stages: - checkout: manylinux # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/manylinux submodules: false + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + - script: | set -e -x cd 
$(Build.SourcesDirectory) @@ -159,3 +172,6 @@ stages: inputs: targetPath: '$(Build.ArtifactStagingDirectory)/onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz' artifactName: 'onnxruntime-linux-x64-gpu' + - template: ../templates/component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' \ No newline at end of file From e2e488d6f8bcd14f40e9e2c8e65f310ce9c0e872 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Tue, 16 Jan 2024 09:18:35 -0800 Subject: [PATCH 16/39] Revert "iOS packaging pipeline stability" (#19135) Reverts microsoft/onnxruntime#19097 because it broke the Android CI pipeline. --- .../external/onnxruntime_external_deps.cmake | 74 +++++++++---------- .../mac-ios-packaging-pipeline.yml | 2 +- .../stages/mac-ios-packaging-build-stage.yml | 7 +- 3 files changed, 42 insertions(+), 41 deletions(-) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index c79bb87fd7f5d..78f63227c8392 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -108,14 +108,41 @@ FetchContent_Declare( ) # Download a protoc binary from Internet if needed -if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) +if(CMAKE_CROSSCOMPILING AND NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) # This part of code is only for users' convenience. The code couldn't handle all cases. Users always can manually # download protoc from Protobuf's Github release page and pass the local path to the ONNX_CUSTOM_PROTOC_EXECUTABLE # variable. - if (APPLE) - # Using CMAKE_CROSSCOMPILING is not recommended for Apple target devices. - # https://cmake.org/cmake/help/v3.26/variable/CMAKE_CROSSCOMPILING.html - # To keep it simple, just download and use the universal protoc binary for Apple builds.
+ message("CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}") + if(CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") + if(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "AMD64") + FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win64} URL_HASH SHA1=${DEP_SHA1_protoc_win64}) + FetchContent_Populate(protoc_binary) + elseif(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "x86") + FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win32} URL_HASH SHA1=${DEP_SHA1_protoc_win32}) + FetchContent_Populate(protoc_binary) + endif() + if(protoc_binary_SOURCE_DIR) + message("Use prebuilt protoc") + set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc.exe) + set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) + endif() + elseif(CMAKE_HOST_SYSTEM_NAME STREQUAL "Linux") + if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") + FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x64}) + FetchContent_Populate(protoc_binary) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") + FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x86} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x86}) + FetchContent_Populate(protoc_binary) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*") + FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_aarch64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_aarch64}) + FetchContent_Populate(protoc_binary) + endif() + if(protoc_binary_SOURCE_DIR) + message("Use prebuilt protoc") + set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) + set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) + endif() + elseif ((CMAKE_SYSTEM_NAME STREQUAL "Emscripten" OR CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") AND CMAKE_HOST_SYSTEM_NAME STREQUAL "Darwin") FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_mac_universal} URL_HASH SHA1=${DEP_SHA1_protoc_mac_universal}) FetchContent_Populate(protoc_binary) if(protoc_binary_SOURCE_DIR) @@ -123,38 +150,6 @@ if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) endif() - elseif(CMAKE_CROSSCOMPILING) - message("CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}") - if(CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") - if(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "AMD64") - FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win64} URL_HASH SHA1=${DEP_SHA1_protoc_win64}) - FetchContent_Populate(protoc_binary) - elseif(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "x86") - FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win32} URL_HASH SHA1=${DEP_SHA1_protoc_win32}) - FetchContent_Populate(protoc_binary) - endif() - if(protoc_binary_SOURCE_DIR) - message("Use prebuilt protoc") - set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc.exe) - set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) - endif() - elseif(CMAKE_HOST_SYSTEM_NAME STREQUAL "Linux") - if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") - FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x64}) - FetchContent_Populate(protoc_binary) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") - FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x86} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x86}) - FetchContent_Populate(protoc_binary) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*") - FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_aarch64} URL_HASH 
SHA1=${DEP_SHA1_protoc_linux_aarch64}) - FetchContent_Populate(protoc_binary) - endif() - if(protoc_binary_SOURCE_DIR) - message("Use prebuilt protoc") - set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) - set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) - endif() - endif() endif() endif() @@ -189,9 +184,9 @@ FetchContent_Declare( ) set(protobuf_BUILD_TESTS OFF CACHE BOOL "Build protobuf tests" FORCE) -#TODO: we'd better to turn the following option off. However, it will cause +#TODO: we'd better to turn the following option off. However, it will cause # ".\build.bat --config Debug --parallel --skip_submodule_sync --update" fail with an error message: -# install(EXPORT "ONNXTargets" ...) includes target "onnx_proto" which requires target "libprotobuf-lite" that is +# install(EXPORT "ONNXTargets" ...) includes target "onnx_proto" which requires target "libprotobuf-lite" that is # not in any export set. #set(protobuf_INSTALL OFF CACHE BOOL "Install protobuf binaries and files" FORCE) set(protobuf_USE_EXTERNAL_GTEST ON CACHE BOOL "" FORCE) @@ -567,3 +562,4 @@ endif() FILE(TO_NATIVE_PATH ${CMAKE_BINARY_DIR} ORT_BINARY_DIR) FILE(TO_NATIVE_PATH ${PROJECT_SOURCE_DIR} ORT_SOURCE_DIR) + diff --git a/tools/ci_build/github/azure-pipelines/mac-ios-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-ios-packaging-pipeline.yml index 34a51649fc384..5fd15b64e03b6 100644 --- a/tools/ci_build/github/azure-pipelines/mac-ios-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/mac-ios-packaging-pipeline.yml @@ -53,7 +53,7 @@ stages: displayName: "Set common variables" pool: - vmImage: "macOS-12" # macOS-13 seems less stable. macOS-12 will work for this job. + vmImage: "macOS-13" timeoutInMinutes: 5 diff --git a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml index ed32c5d0e15be..d1dff0769e25f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml @@ -78,6 +78,10 @@ stages: pip install -r tools/ci_build/github/apple/ios_packaging.requirements.txt displayName: "Install Python requirements" + - script: | + $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/install_protobuf.sh -p $(Build.BinariesDirectory)/protobuf_install -d $(Build.SourcesDirectory)/cmake/deps.txt + displayName: "Build Host Protoc" + # create and test mobile pods - script: | python tools/ci_build/github/apple/build_and_assemble_apple_pods.py \ @@ -87,7 +91,8 @@ stages: --test \ --variant ${{ parameters.packageVariant }} \ --build-settings-file "${{ variables.buildSettingsFile }}" \ - ${{ variables.optionalIncludeOpsByConfigOption }} + ${{ variables.optionalIncludeOpsByConfigOption }} \ + -b="--path_to_protoc_exe=$(Build.BinariesDirectory)/protobuf_install/bin/protoc" displayName: "Build macOS/iOS framework and assemble pod package files" - script: | From 80f274ca6f2f4572d827edd6dc7f736d7a8c036a Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Tue, 16 Jan 2024 09:42:59 -0800 Subject: [PATCH 17/39] Fix SkipLayerNormalization shape inference (#18724) SkipLayerNorm has more than one input, so `propagateShapeAndTypeFromFirstInput` is not enough. 
--- .../core/graph/contrib_ops/bert_defs.cc | 4 +- .../contrib_ops/shape_inference_functions.cc | 39 +++++++++++++++++++ .../contrib_ops/shape_inference_functions.h | 3 +- 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index df8d0a59cb033..0317ffcfb0e31 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1285,7 +1285,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Output(3, "input_skip_bias_sum", "Sum of the input and skip inputs (and bias if it exists) with shape (batch_size, sequence_length, hidden_size).", "T", OpSchema::Optional) .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.") .TypeConstraint("U", {"tensor(float)"}, "Constrain mean and inv_std_var to float tensors.") - .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); + .TypeAndShapeInferenceFunction(SkipLayerNormalizationShapeInference)); ONNX_MS_OPERATOR_SET_SCHEMA( SkipSimplifiedLayerNormalization, 1, @@ -1334,7 +1334,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.") .TypeConstraint("U", {"tensor(float)"}, "Constrain mean and inv_std_var to float tensors.") - .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); + .TypeAndShapeInferenceFunction(SkipLayerNormalizationShapeInference)); constexpr const char* NGramRepeatBlock_ver1_doc = R"DOC( Enforce no repetition of n-grams. Scores are set to `-inf` for tokens that form a repeated n-gram if added to the back of the input_ids. 
diff --git a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc index eeef20e9dff5e..8b1812f62be25 100644 --- a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc +++ b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc @@ -114,6 +114,45 @@ void EmbedLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& c } } +void SkipLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx) { + propagateShapeAndTypeFromFirstInput(ctx); + + auto stash_type = ONNX_NAMESPACE::TensorProto_DataType_FLOAT; + if (ctx.getNumOutputs() > 1) { + auto output_type = ctx.getOutputType(1); + output_type->mutable_tensor_type()->set_elem_type(static_cast(stash_type)); + } + if (ctx.getNumOutputs() > 2) { + auto output_type = ctx.getOutputType(2); + output_type->mutable_tensor_type()->set_elem_type(static_cast(stash_type)); + } + if (ctx.getNumOutputs() > 3) { + propagateElemTypeFromInputToOutput(ctx, 0, 3); + } + if (!hasNInputShapes(ctx, 1)) { + return; + } + auto& input_shape = ctx.getInputType(0)->tensor_type().shape(); + int64_t input_ndim = input_shape.dim_size(); + int axis = static_cast(input_ndim - 1); + + if (ctx.getNumOutputs() > 1) { + auto mean_shape = ctx.getOutputType(1)->mutable_tensor_type()->mutable_shape(); + mean_shape->CopyFrom(input_shape); + mean_shape->mutable_dim(axis)->set_dim_value(1); + } + + if (ctx.getNumOutputs() > 2) { + auto inv_std_dev_shape = ctx.getOutputType(2)->mutable_tensor_type()->mutable_shape(); + inv_std_dev_shape->CopyFrom(input_shape); + inv_std_dev_shape->mutable_dim(axis)->set_dim_value(1); + } + + if (ctx.getNumOutputs() > 3) { + propagateShapeFromInputToOutput(ctx, 0, 3); + } +} + // Shape inference for Attention and QAttention void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_input_index) { // Input 0, 1, 2 are input, weights and bias. diff --git a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.h b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.h index 93cf5b304f653..6eb06af15309c 100644 --- a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.h +++ b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.h @@ -13,5 +13,6 @@ namespace onnxruntime { namespace contrib { void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_input_index); void EmbedLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx); +void SkipLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx); } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime From 8e272b9cac70a11c472fb002af755213a4dabf66 Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Tue, 16 Jan 2024 16:53:15 -0500 Subject: [PATCH 18/39] Update build.py to remove unused functions and update python to 3.8 (#19164) ### Description ### Motivation and Context --- tools/ci_build/build.py | 32 +------------------------------- 1 file changed, 1 insertion(+), 31 deletions(-) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 0da4adb51767d..1a6262edf45c9 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -56,7 +56,7 @@ def __init__(self, message): def _check_python_version(): - required_minor_version = 7 + required_minor_version = 8 if (sys.version_info.major, sys.version_info.minor) < (3, required_minor_version): raise UsageError( f"Invalid Python version. 
At least Python 3.{required_minor_version} is required. " @@ -786,11 +786,6 @@ def get_linux_distro(): return "", "" -def is_ubuntu_1604(): - dist, ver = get_linux_distro() - return dist == "Ubuntu" and ver.startswith("16.04") - - def get_config_build_dir(build_dir, config): # build directory per configuration return os.path.join(build_dir, config) @@ -844,15 +839,6 @@ def update_submodules(source_dir): run_subprocess(["git", "submodule", "update", "--init", "--recursive"], cwd=source_dir) -def is_docker(): - path = "/proc/self/cgroup" - return ( - os.path.exists("/.dockerenv") - or os.path.isfile(path) - and any("docker" in line for line in open(path)) # noqa: SIM115 - ) - - def install_python_deps(numpy_version=""): dep_packages = ["setuptools", "wheel", "pytest"] dep_packages.append(f"numpy=={numpy_version}" if numpy_version else "numpy>=1.16.6") @@ -2401,16 +2387,6 @@ def run_csharp_tests(source_dir, build_dir, use_cuda, use_openvino, use_tensorrt run_subprocess(cmd_args, cwd=csharp_source_dir) -def is_cross_compiling_on_apple(args): - if not is_macOS(): - return False - if args.ios: - return True - if args.osx_arch != platform.machine(): - return True - return False - - def generate_documentation(source_dir, build_dir, configs, validate): # Randomly choose one build config config = next(iter(configs)) @@ -2725,12 +2701,6 @@ def main(): log.info("Activating emsdk...") run_subprocess([emsdk_file, "activate", emsdk_version], cwd=emsdk_dir) - if is_ubuntu_1604(): - if args.arm or args.arm64: - raise BuildError("Only Windows ARM(64) cross-compiled builds supported currently through this script") - if not is_docker() and not args.use_acl and not args.use_armnn: - install_python_deps() - if args.enable_pybind and is_windows(): install_python_deps(args.numpy_version) From c935c8fbd2e463a3e0153145140a8efd780dfabc Mon Sep 17 00:00:00 2001 From: moyo1997 <54333118+moyo1997@users.noreply.github.com> Date: Tue, 16 Jan 2024 16:24:37 -0800 Subject: [PATCH 19/39] remove unnecessary environment variable (#19166) remove unnecessary environment variable when building as arm64x --- build_arm64x.bat | 1 - 1 file changed, 1 deletion(-) diff --git a/build_arm64x.bat b/build_arm64x.bat index fbcdd373086a9..1ed268ae94a43 100644 --- a/build_arm64x.bat +++ b/build_arm64x.bat @@ -5,7 +5,6 @@ setlocal set PATH=C:\Program Files\Git\usr\bin;%PATH% -set LINK_REPRO_NAME=/mylink.rsp rem Requires a Python install to be available in your PATH python "%~dp0\tools\ci_build\build.py" --arm64 --buildasx --build_dir "%~dp0\build\arm64-x" %* From e61861b0a121bca1d60e5d4a3722e52b6820c430 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Tue, 16 Jan 2024 16:36:28 -0800 Subject: [PATCH 20/39] Clean up generated files in QNN UTs (#19127) ### Description Clean up generated files in QNN UTs --- onnxruntime/test/providers/qnn/simple_op_htp_test.cc | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 8ff65c08e8633..c4244fe532456 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -815,7 +815,8 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheNonEmbedModeTest) { // Check the Onnx skeleton file is generated EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); // Check the Qnn context cache binary file is generated - 
EXPECT_TRUE(std::filesystem::exists("qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin")); + std::string qnn_ctx_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; + EXPECT_TRUE(std::filesystem::exists(qnn_ctx_bin)); // 2nd run loads and run from QDQ model + Onnx skeleton file + Qnn context cache binary file TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), @@ -837,6 +838,10 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheNonEmbedModeTest) { QDQTolerance(), logging::Severity::kERROR, context_binary_file); + + // Clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + ASSERT_EQ(std::remove(qnn_ctx_bin.c_str()), 0); } // Run QDQ model on HTP 2 times @@ -898,6 +903,9 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCache_InvalidGraph) { ASSERT_STATUS_OK(session_object.Load(qnn_ctx_model_data.data(), static_cast(qnn_ctx_model_data.size()))); // Verify the return status with code INVALID_GRAPH ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); + + // Clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } // Run QDQ model on HTP with 2 inputs @@ -955,6 +963,8 @@ TEST_F(QnnHTPBackendTests, ContextBinary2InputsTest) { QDQTolerance(), logging::Severity::kERROR, context_binary_file); + // Clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } TEST_F(QnnHTPBackendTests, QuantAccuracyTest) { From 81d363045ba273b16a3ec654c53a15217a2d2a36 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Tue, 16 Jan 2024 17:25:18 -0800 Subject: [PATCH 21/39] Upgrade Ubuntu machine pool from 20.04 to 22.04 (#19117) ### Description Upgrade Ubuntu machine pool from 20.04 to 22.04 --- .../build-perf-test-binaries-pipeline.yml | 2 +- .../c-api-noopenmp-packaging-pipelines.yml | 2 +- ...lean-build-docker-image-cache-pipeline.yml | 10 +-------- .../cuda-packaging-pipeline.yml | 2 +- .../azure-pipelines/linux-ci-pipeline.yml | 4 ++-- .../linux-cpu-aten-pipeline.yml | 2 +- .../linux-cpu-eager-pipeline.yml | 2 +- .../azure-pipelines/linux-gpu-ci-pipeline.yml | 2 +- .../linux-migraphx-ci-pipeline.yml | 2 +- .../npm-packaging-pipeline.yml | 4 ++-- .../nuget/templates/test_linux.yml | 2 +- .../orttraining-linux-ci-pipeline.yml | 2 +- .../orttraining-pai-ci-pipeline.yml | 4 ++-- .../orttraining-py-packaging-pipeline-cpu.yml | 2 +- .../azure-pipelines/post-merge-jobs.yml | 6 ++--- .../py-package-test-pipeline.yml | 2 +- .../stages/py-cuda-packaging-stage.yml | 2 +- .../stages/py-cuda-publishing-stage.yml | 2 +- .../templates/android-java-api-aar.yml | 2 +- .../templates/build-linux-wasm-step.yml | 22 +++++++++---------- .../azure-pipelines/templates/c-api-cpu.yml | 4 ++-- .../templates/c-api-linux-cpu.yml | 2 +- .../azure-pipelines/templates/linux-ci.yml | 2 +- .../linux-cpu-packaging-pipeline.yml | 2 +- .../templates/linux-wasm-ci.yml | 2 +- ...device-training-cpu-packaging-pipeline.yml | 2 +- .../py-packaging-selectable-stage.yml | 2 +- .../templates/py-packaging-stage.yml | 4 ++-- .../github/azure-pipelines/templates/rocm.yml | 2 +- .../azure-pipelines/web-ci-pipeline.yml | 2 +- .../linux/build_linux_python_package.sh | 6 ++--- .../ci_build/github/linux/run_python_tests.sh | 2 +- 32 files changed, 50 insertions(+), 60 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml b/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml index 3ddc167bc0a61..d37e9bdc5da4c 100644 --- 
a/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml @@ -28,7 +28,7 @@ stages: artifactName: 'onnxruntime-android-full-aar' job_name_suffix: 'Full' publish_executables: '1' - pool_name: 'onnxruntime-Ubuntu2004-AMD-CPU' + pool_name: 'onnxruntime-Ubuntu2204-AMD-CPU' # build Python packages # Linux GPU only diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 2169a3ce1bb9e..3803333bd880a 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -246,7 +246,7 @@ stages: workspace: clean: all timeoutInMinutes: 120 - pool: onnxruntime-Ubuntu2004-AMD-CPU + pool: onnxruntime-Ubuntu2204-AMD-CPU variables: RocmVersion: '5.6' steps: diff --git a/tools/ci_build/github/azure-pipelines/clean-build-docker-image-cache-pipeline.yml b/tools/ci_build/github/azure-pipelines/clean-build-docker-image-cache-pipeline.yml index 24086b6166fe4..43e668eef8d00 100644 --- a/tools/ci_build/github/azure-pipelines/clean-build-docker-image-cache-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/clean-build-docker-image-cache-pipeline.yml @@ -19,8 +19,7 @@ variables: jobs: - job: Clean_Build_Docker_Image_Cache - pool: - vmImage: 'ubuntu-20.04' + pool: onnxruntime-Ubuntu2204-AMD-CPU timeoutInMinutes: 30 @@ -29,13 +28,6 @@ jobs: submodules: false fetchDepth: 1 - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.9' - addToPath: true - architecture: 'x64' - displayName: "Use Python 3.9" - - task: AzureCLI@2 inputs: azureSubscription: 'AIInfraBuild' diff --git a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml index df7b5f59d28fc..1d2ba88652f48 100644 --- a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml @@ -126,7 +126,7 @@ stages: BaseImage: 'registry.access.redhat.com/ubi8/ubi' OnnxruntimeArch: 'x64' OnnxruntimeNodejsBindingArch: 'x64' - PoolName: 'onnxruntime-Ubuntu2004-AMD-CPU' + PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' PackageJava: false PackageNodeJS: false # Nuget Packaging diff --git a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml index 07f672c75d029..cff7c96aa9253 100644 --- a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml @@ -46,7 +46,7 @@ stages: skipComponentGovernanceDetection: true ORT_CACHE_DIR: $(Agent.TempDirectory)/ort_ccache TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - pool: onnxruntime-Ubuntu2004-AMD-CPU + pool: onnxruntime-Ubuntu2204-AMD-CPU steps: - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' @@ -123,7 +123,7 @@ stages: skipComponentGovernanceDetection: true ORT_CACHE_DIR: $(Agent.TempDirectory)/ort_ccache TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - pool: onnxruntime-Ubuntu2004-AMD-CPU + pool: onnxruntime-Ubuntu2204-AMD-CPU steps: - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' diff --git a/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml 
b/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml index 146186e9eeaf5..090ce97296687 100644 --- a/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml @@ -43,7 +43,7 @@ jobs: variables: CCACHE_DIR: $(Agent.TempDirectory)/ccache TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - pool: onnxruntime-Ubuntu2004-AMD-CPU + pool: onnxruntime-Ubuntu2204-AMD-CPU steps: - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' diff --git a/tools/ci_build/github/azure-pipelines/linux-cpu-eager-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-cpu-eager-pipeline.yml index a5c08e95b7efc..d3d13cc5344da 100644 --- a/tools/ci_build/github/azure-pipelines/linux-cpu-eager-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-cpu-eager-pipeline.yml @@ -51,7 +51,7 @@ jobs: timeoutInMinutes: 120 workspace: clean: all - pool: onnxruntime-Ubuntu2004-AMD-CPU + pool: onnxruntime-Ubuntu2204-AMD-CPU steps: - checkout: self clean: true diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index 0993a81a02249..5bc8c3603ee92 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -64,7 +64,7 @@ jobs: CCACHE_DIR: $(Pipeline.Workspace)/ccache workspace: clean: all - pool: onnxruntime-Ubuntu2004-AMD-CPU + pool: onnxruntime-Ubuntu2204-AMD-CPU steps: - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' diff --git a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml index f7571a3b7eab6..9cf7a3fb42397 100644 --- a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml @@ -46,7 +46,7 @@ jobs: TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] workspace: clean: all - pool: onnxruntime-Ubuntu2004-AMD-CPU + pool: onnxruntime-Ubuntu2204-AMD-CPU timeoutInMinutes: 120 steps: diff --git a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml index 7f73da23b5eb1..21fc205c72e89 100644 --- a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml @@ -41,7 +41,7 @@ stages: parameters: NpmPackagingMode: ${{ variables.NpmPackagingMode }} IsReleasePipeline: true - PoolName: 'onnxruntime-Ubuntu2004-AMD-CPU' + PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' PackageName: 'onnxruntime-web' ExtraBuildArgs: '' UseWebPoolName: true @@ -54,7 +54,7 @@ stages: parameters: NpmPackagingMode: ${{ variables.NpmPackagingMode }} BuildConfig: 'Release' - PoolName: 'onnxruntime-Ubuntu2004-AMD-CPU' + PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' PackageName: 'onnxruntime-react-native' BuildAndroidAARStageDependsOn: 'Precheck_and_extract_commit' diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml index f44106c145228..2567bec9fdfc2 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml @@ -1,5 +1,5 @@ 
parameters: - AgentPool: 'onnxruntime-Ubuntu2004-AMD-CPU' + AgentPool: 'onnxruntime-Ubuntu2204-AMD-CPU' ArtifactSuffix: '' NugetPackageName : '' StageSuffix: 'CPU' diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml index 018672e0b2dea..26fd5e1ec0b5d 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml @@ -44,7 +44,7 @@ jobs: skipComponentGovernanceDetection: true CCACHE_DIR: $(Pipeline.Workspace)/ccache TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - pool: onnxruntime-Ubuntu-2004-Training-CPU + pool: onnxruntime-Ubuntu-2204-Training-CPU steps: - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' diff --git a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml index a53f91fb317cb..71b224b65964f 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml @@ -37,7 +37,7 @@ jobs: TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] workspace: clean: all - pool: onnxruntime-Ubuntu2004-AMD-CPU + pool: onnxruntime-Ubuntu2204-AMD-CPU timeoutInMinutes: 120 steps: @@ -132,7 +132,7 @@ jobs: TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] workspace: clean: all - pool: onnxruntime-Ubuntu2004-AMD-CPU + pool: onnxruntime-Ubuntu2204-AMD-CPU timeoutInMinutes: 120 steps: diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml index 817ace0571837..a44a8c215939f 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml @@ -16,7 +16,7 @@ stages: timeoutInMinutes: 180 workspace: clean: all - pool: onnxruntime-Ubuntu2004-AMD-CPU + pool: onnxruntime-Ubuntu2204-AMD-CPU strategy: matrix: diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml index 5ee39876733e2..3ec5400dacc65 100644 --- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml @@ -4,7 +4,7 @@ stages: parameters: NpmPackagingMode: 'dev' IsReleasePipeline: true - PoolName: 'onnxruntime-Ubuntu2004-AMD-CPU' + PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' BuildStaticLib: true ExtraBuildArgs: '' UseWebPoolName: true @@ -367,7 +367,7 @@ stages: timeoutInMinutes: 150 variables: skipComponentGovernanceDetection: true - pool: 'onnxruntime-Ubuntu2004-AMD-CPU' + pool: 'onnxruntime-Ubuntu2204-AMD-CPU' steps: - template: templates/set-version-number-variables-step.yml @@ -413,7 +413,7 @@ stages: - job: AndroidCustomBuildScript workspace: clean: all - pool: 'onnxruntime-Ubuntu2004-AMD-CPU' + pool: 'onnxruntime-Ubuntu2204-AMD-CPU' variables: dockerImageTag: onnxruntime-android-custom-build steps: diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml index 55d3150f21aa3..04f555deb1a22 100644 --- a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ 
b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -18,7 +18,7 @@ stages: - template: templates/py-packaging-linux-test-cpu.yml parameters: arch: 'x86_64' - machine_pool: 'onnxruntime-Ubuntu2004-AMD-CPU' + machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' base_image: 'registry.access.redhat.com/ubi8/ubi' devtoolset_rootpath: /opt/rh/gcc-toolset-12/root ld_library_path_arg: /opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64 diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml index e6d8ee35e75e3..f82c80d4d7e93 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml @@ -105,7 +105,7 @@ stages: - template: ../templates/py-linux-gpu.yml parameters: arch: 'x86_64' - machine_pool: 'onnxruntime-Ubuntu2004-AMD-CPU' + machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} docker_base_image: ${{ variables.docker_base_image }} diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cuda-publishing-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cuda-publishing-stage.yml index 4f440e0f61b3d..2a4debcf9fba5 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cuda-publishing-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cuda-publishing-stage.yml @@ -20,7 +20,7 @@ stages: dependsOn: [] jobs: - job: - pool: 'onnxruntime-Ubuntu2004-AMD-CPU' + pool: 'onnxruntime-Ubuntu2204-AMD-CPU' steps: - checkout: none - task: DownloadPipelineArtifact@2 diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml index 5e61f88b4aa18..509fea45ebe53 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml @@ -33,7 +33,7 @@ parameters: - name: pool_name displayName: Pool name type: string - default: 'onnxruntime-Ubuntu2004-AMD-CPU' + default: 'onnxruntime-Ubuntu2204-AMD-CPU' - name: packageName # now we can build onnxruntime or onnxruntime-mobile for Android, need specify it here diff --git a/tools/ci_build/github/azure-pipelines/templates/build-linux-wasm-step.yml b/tools/ci_build/github/azure-pipelines/templates/build-linux-wasm-step.yml index e664cf69dec76..e77b1a4008b7c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/build-linux-wasm-step.yml +++ b/tools/ci_build/github/azure-pipelines/templates/build-linux-wasm-step.yml @@ -24,19 +24,17 @@ parameters: type: string steps: - - task: Cache@2 - inputs: - ${{if eq(variables['Build.SourceBranchName'], 'merge')}}: - key: ' "${{parameters.TODAY}}" | ${{parameters.AdditionalKey}} | merge ' - ${{else}}: - key: '"${{parameters.TODAY}}" | ${{parameters.AdditionalKey}} | $(Build.SourceVersion) ' - path: ${{parameters.CacheDir}} - restoreKeys: | - "${{parameters.TODAY}}" | ${{parameters.AdditionalKey}} - displayName: Cache Task - condition: eq('${{parameters.WithCache}}', true) - - ${{if eq(parameters.WithCache, true)}}: + - task: Cache@2 + inputs: + ${{if eq(variables['Build.SourceBranchName'], 'merge')}}: + key: ' "${{parameters.TODAY}}" | 
${{parameters.AdditionalKey}} | merge ' + ${{else}}: + key: '"${{parameters.TODAY}}" | ${{parameters.AdditionalKey}} | $(Build.SourceVersion) ' + path: ${{parameters.CacheDir}} + restoreKeys: | + "${{parameters.TODAY}}" | ${{parameters.AdditionalKey}} + displayName: Cache Task - script: | set -e -x pushd '$(Build.SourcesDirectory)/cmake/external/emsdk' diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 81319e07c6b17..168602a17910b 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -759,7 +759,7 @@ stages: - template: ../nuget/templates/test_linux.yml parameters: - AgentPool : onnxruntime-Ubuntu2004-AMD-CPU + AgentPool : onnxruntime-Ubuntu2204-AMD-CPU NugetPackageName : 'Microsoft.ML.OnnxRuntime' ArtifactSuffix: 'CPU' SpecificArtifact: ${{ parameters.SpecificArtifact }} @@ -796,7 +796,7 @@ stages: OS: Linux BuildId: ${{ parameters.BuildId }} SpecificArtifact: ${{ parameters.SpecificArtifact }} - PoolName: 'onnxruntime-Ubuntu2004-AMD-CPU' + PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' - template: final-jar-testing.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml index 8538f15e93753..cf470b3fa2448 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml @@ -19,7 +19,7 @@ parameters: - name: PoolName type: string - default: 'onnxruntime-Ubuntu2004-AMD-CPU' + default: 'onnxruntime-Ubuntu2204-AMD-CPU' - name: ArtifactNamePrefix type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml index 7b9788d90b17d..15165e3cb0950 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml @@ -1,5 +1,5 @@ parameters: - AgentPool : 'onnxruntime-Ubuntu2004-AMD-CPU' + AgentPool : 'onnxruntime-Ubuntu2204-AMD-CPU' StageName : 'Linux_CI_Dev' RunDockerBuildArgs: '-o ubuntu20.04 -d cpu -x "--build_wheel"' NuPackScript: '' diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml index 6ad5f9f38a4db..8972d55f6e190 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml @@ -32,7 +32,7 @@ stages: BaseImage: 'registry.access.redhat.com/ubi8/ubi' OnnxruntimeArch: 'x64' OnnxruntimeNodejsBindingArch: 'x64' - PoolName: 'onnxruntime-Ubuntu2004-AMD-CPU' + PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' ArtifactNamePrefix: ${{ parameters.ArtifactNamePrefix }} PackageJava: ${{ parameters.PackageJava }} PackageNodeJS: ${{ parameters.PackageNodeJS }} diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index e6693a6f6d26a..d279e667f9091 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -13,7 +13,7 @@ parameters: - name: PoolName type: string - default: 'onnxruntime-Ubuntu2004-AMD-CPU' + default: 'onnxruntime-Ubuntu2204-AMD-CPU' - name: 
SkipPublish type: boolean diff --git a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml index 51583a25f63ac..cf39be23cbdaf 100644 --- a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml @@ -336,7 +336,7 @@ stages: - template: ../nuget/templates/test_linux.yml parameters: - AgentPool : onnxruntime-Ubuntu2004-AMD-CPU + AgentPool : onnxruntime-Ubuntu2204-AMD-CPU NugetPackageName : 'Microsoft.ML.OnnxRuntime.Training' ArtifactSuffix: 'Training-CPU' StageSuffix: 'Training_CPU' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml index 00ba5ea4a475a..01cab936aa529 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml @@ -48,7 +48,7 @@ stages: timeoutInMinutes: 90 workspace: clean: all - pool: onnxruntime-Ubuntu2004-AMD-CPU + pool: onnxruntime-Ubuntu2204-AMD-CPU strategy: matrix: ${{ each PythonVersion in parameters.python_version }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index abe06e80f4f19..8669a883c31f1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -430,7 +430,7 @@ stages: - template: py-linux.yml parameters: arch: 'x86_64' - machine_pool: 'onnxruntime-Ubuntu2004-AMD-CPU' + machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' base_image: 'registry.access.redhat.com/ubi8/ubi' devtoolset_rootpath: /opt/rh/gcc-toolset-12/root ld_library_path_arg: /opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64 @@ -443,6 +443,6 @@ stages: - template: py-linux-gpu.yml parameters: arch: 'x86_64' - machine_pool: 'onnxruntime-Ubuntu2004-AMD-CPU' + machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} diff --git a/tools/ci_build/github/azure-pipelines/templates/rocm.yml b/tools/ci_build/github/azure-pipelines/templates/rocm.yml index 2e9e6c6b35a2e..43a80aa4fd4e3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/rocm.yml +++ b/tools/ci_build/github/azure-pipelines/templates/rocm.yml @@ -14,7 +14,7 @@ jobs: workspace: clean: all timeoutInMinutes: 180 - pool: Ubuntu-2004-rocm-aiinfra + pool: Ubuntu-2204-rocm-aiinfra variables: - name: PythonVersion value: ${{ parameters.PythonVersion }} diff --git a/tools/ci_build/github/azure-pipelines/web-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/web-ci-pipeline.yml index e352a04068ee8..24809ccfdec1f 100644 --- a/tools/ci_build/github/azure-pipelines/web-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/web-ci-pipeline.yml @@ -53,7 +53,7 @@ stages: parameters: NpmPackagingMode: ${{ variables.NpmPackagingMode }} IsReleasePipeline: false - PoolName: 'onnxruntime-Ubuntu2004-AMD-CPU' + PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' BuildStaticLib: true ExtraBuildArgs: 
$(ExtraBuildArgs) WASMTemplate: linux-wasm-ci.yml diff --git a/tools/ci_build/github/linux/build_linux_python_package.sh b/tools/ci_build/github/linux/build_linux_python_package.sh index 1059dd5047477..933d1f3d5874a 100755 --- a/tools/ci_build/github/linux/build_linux_python_package.sh +++ b/tools/ci_build/github/linux/build_linux_python_package.sh @@ -7,9 +7,9 @@ mkdir -p /build/dist EXTRA_ARG="" -# Put 3.8 at the last because Ubuntu 20.04 use python 3.8 and we will upload the intermediate build files of this -# config to Azure DevOps Artifacts and download them to a Ubuntu 20.04 machine to run the tests. -PYTHON_EXES=("/opt/python/cp39-cp39/bin/python3.9" "/opt/python/cp310-cp310/bin/python3.10" "/opt/python/cp311-cp311/bin/python3.11" "/opt/python/cp312-cp312/bin/python3.12" "/opt/python/cp38-cp38/bin/python3.8") +# Put 3.10 last because Ubuntu 22.04 uses python 3.10 and we will upload the intermediate build files of this +# config to Azure DevOps Artifacts and download them to a Ubuntu 22.04 machine to run the tests. +PYTHON_EXES=("/opt/python/cp38-cp38/bin/python3.8" "/opt/python/cp39-cp39/bin/python3.9" "/opt/python/cp311-cp311/bin/python3.11" "/opt/python/cp312-cp312/bin/python3.12" "/opt/python/cp310-cp310/bin/python3.10") while getopts "d:p:x:c:" parameter_Option do case "${parameter_Option}" in diff --git a/tools/ci_build/github/linux/run_python_tests.sh b/tools/ci_build/github/linux/run_python_tests.sh index 3164a10a09dfd..082c561dd17b9 100755 --- a/tools/ci_build/github/linux/run_python_tests.sh +++ b/tools/ci_build/github/linux/run_python_tests.sh @@ -15,7 +15,7 @@ c) BUILD_CONFIG=${OPTARG};; esac done -export PATH=/opt/python/cp38-cp38/bin:$PATH +export PATH=/opt/python/cp310-cp310/bin:$PATH cd /build files=(whl/*.whl) FILE_NAME="${files[0]}" From 07d3aed3aa3a054deb502cedf867f559fc690755 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Wed, 17 Jan 2024 13:35:13 +0800 Subject: [PATCH 22/39] [WebNN EP] Fixed build issue with disable_rtti (#19173) Previously, building the WebNN EP with --disable_rtti would throw unboundTypeError, since unbound type names are illegal with RTTI disabled in the Embind API. We can fix it by adding a -DEMSCRIPTEN_HAS_UNBOUND_TYPE_NAMES=0 flag. --- cmake/adjust_global_compile_flags.cmake | 5 +++++ cmake/onnxruntime_webassembly.cmake | 5 ++++- tools/ci_build/build.py | 4 ---- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 30d8cbf78fb1a..2c7bf9f1c2f5c 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -123,6 +123,11 @@ if (onnxruntime_DISABLE_RTTI) add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/GR->" "$<$<COMPILE_LANGUAGE:CXX>:/we4541>") else() add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:-fno-rtti>") + if (onnxruntime_USE_WEBNN) + # Avoid unboundTypeError for WebNN EP since unbound type names are illegal with RTTI disabled + # in Embind API, relevant issue: https://github.com/emscripten-core/emscripten/issues/7001 + add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:-DEMSCRIPTEN_HAS_UNBOUND_TYPE_NAMES=0>") + endif() endif() else() #MSVC RTTI flag /GR is not added to CMAKE_CXX_FLAGS by default. But, anyway VC++2019 treats "/GR" default on.
diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 858583e64e9df..546d50c1ca2d3 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -268,7 +268,10 @@ else() endif() if (onnxruntime_USE_WEBNN) - set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " --bind -sWASM_BIGINT") + set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " --bind -sWASM_BIGINT") + if (onnxruntime_DISABLE_RTTI) + set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " -fno-rtti -DEMSCRIPTEN_HAS_UNBOUND_TYPE_NAMES=0") + endif() endif() # Set link flag to enable exceptions support, this will override default disabling exception throwing behavior when disable exceptions. diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 1a6262edf45c9..1034a82cb2854 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1283,10 +1283,6 @@ def generate_build_tree( if args.use_webnn: if not args.build_wasm: raise BuildError("WebNN is only available for WebAssembly build.") - if args.disable_rtti: - # Avoid unboundTypeError for WebNN EP since unbound type names are illegal with RTTI disabled - # in Embind API, relevant issue: https://github.com/emscripten-core/emscripten/issues/16911 - raise BuildError("WebNN is not supported with RTTI disabled.") cmake_args += ["-Donnxruntime_USE_WEBNN=ON"] if args.use_snpe: From 9876cc7c4f5f6249e1dec8b93abf7b8dfcf5ca0c Mon Sep 17 00:00:00 2001 From: wejoncy Date: Wed, 17 Jan 2024 15:46:19 +0800 Subject: [PATCH 23/39] more inputs support for LLM exporter (#19005) ### Description ### Motivation and Context --- .../transformers/large_model_exporter.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py index 1601b1a203b9a..9e8b284bf56c7 100644 --- a/onnxruntime/python/tools/transformers/large_model_exporter.py +++ b/onnxruntime/python/tools/transformers/large_model_exporter.py @@ -224,24 +224,35 @@ def fetch_onnx_inputs_outputs_name( if not num_of_past_key: num_of_past_key = model.config.num_hidden_layers - onnx_inp_names = ("input_ids", "attention_mask") + # filter out constant inputs + onnx_inp_names = tuple( + [torch_input_names[i] for i in range(len(torch_input_names)) if isinstance(onnx_inputs[i], torch.Tensor)] + ) + assert ( + "input_ids" in onnx_inp_names and "attention_mask" in onnx_inp_names + ), "input_ids and attention_mask must exist in inputs" onnx_out_names = ("logits",) onnx_dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_len"}, "attention_mask": {0: "batch_size", 1: "seq_len"}, } + # add dynamic dimensions for the unknown inputs + for idx, name in enumerate(onnx_inp_names): + if name not in onnx_dynamic_axes: + unknown_dims = {i: f"{idx}__unknown_dims__{i}" for i in range(onnx_inputs[idx].dim())} + onnx_dynamic_axes[name] = unknown_dims if input_with_past: for i in range(num_of_past_key): - onnx_inp_names += (f"present_key.{i}",) - onnx_inp_names += (f"present_values.{i}",) + onnx_inp_names += (f"past_key_values.{i}.key",) + onnx_inp_names += (f"past_key_values.{i}.value",) onnx_dynamic_axes[onnx_inp_names[-1]] = kv_cache_axis onnx_dynamic_axes[onnx_inp_names[-2]] = kv_cache_axis if with_past or input_with_past: for i in range(num_of_past_key): - onnx_out_names += (f"past_key.{i}",) - onnx_out_names += (f"past_values.{i}",) +
onnx_out_names += (f"present.{i}.key",) + onnx_out_names += (f"present.{i}.value",) onnx_dynamic_axes[onnx_out_names[-1]] = kv_cache_axis onnx_dynamic_axes[onnx_out_names[-2]] = kv_cache_axis From 63dd605d3310f5a9540c414216f3f3b67d455c4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 17 Jan 2024 19:00:36 +0100 Subject: [PATCH 24/39] Fix untyped float values in quantization tool missing from PR #18043 (#19182) ### Description Extends the code coverage to the Entropy, Histogram and Distribution calibration methods, and fixes bugs found while doing it. ### Motivation and Context Bugs detected in [Olive](https://github.com/microsoft/OLive). --- .../python/tools/quantization/calibrate.py | 86 +++++++++++++++---- .../python/tools/quantization/quant_utils.py | 2 +- .../python/quantization/test_op_matmul.py | 66 +++++++++++++- 3 files changed, 131 insertions(+), 23 deletions(-) diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py index d0db57c392961..77b3dce9fb004 100644 --- a/onnxruntime/python/tools/quantization/calibrate.py +++ b/onnxruntime/python/tools/quantization/calibrate.py @@ -5,6 +5,7 @@ # license information. # -------------------------------------------------------------------------- import abc +import copy import itertools import os import uuid @@ -21,6 +22,48 @@ from .quant_utils import apply_plot, load_model_with_shape_infer, smooth_distribution +def rel_entr(pk: np.ndarray, qk: np.ndarray) -> np.ndarray: + """ + See https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.rel_entr.html#scipy.special.rel_entr. + Python implementation. + """ + res = np.empty(pk.shape, dtype=pk.dtype) + res[:] = pk[:] * np.log(pk[:] / qk[:]) + c2 = (pk == 0) & (qk >= 0) + res[c2] = 0 + c1 = (pk > 0) & (qk > 0) + res[~c1] = np.inf + return res + + +def entropy( + pk: np.ndarray, + qk: np.ndarray, + base: Optional[float] = None, + axis: int = 0, +) -> np.ndarray: + """ + Simplified version of entropy. + Source: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html. + This avoids taking a dependency on scipy just for this function. + """ + assert base is None or base > 0, f"base={base} must be a positive number or `None`."
+ assert qk is not None, "qk is None" + + pk = np.asarray(pk).astype(np.float32) + pk = 1.0 * pk / np.sum(pk, axis=axis, keepdims=True) + + qk = np.asarray(qk).astype(np.float32) + pk, qk = np.broadcast_arrays(pk, qk) + qk = 1.0 * qk / np.sum(qk, axis=axis, keepdims=True) + vec = rel_entr(pk, qk) + + s = np.sum(vec, axis=axis) + if base is not None: + s /= np.log(base) + return s.astype(pk.dtype) + + class TensorData: _allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges", "bins"]) _floats = frozenset(["avg", "std", "lowest", "highest", "hist_edges"]) @@ -708,8 +751,8 @@ def collect_absolute_value(self, name_to_arr): min_value = np.min(data_arr_np) max_value = np.max(data_arr_np) else: - min_value = 0 - max_value = 0 + min_value = np.array(0, dtype=data_arr_np.dtype) + max_value = np.array(0, dtype=data_arr_np.dtype) data_arr_np = np.absolute(data_arr_np) # only consider absolute value @@ -725,6 +768,8 @@ def collect_absolute_value(self, name_to_arr): old_histogram = self.histogram_dict[tensor] old_min = old_histogram[2] old_max = old_histogram[3] + assert hasattr(old_min, "dtype"), f"old_min should be a numpy array but is {type(old_min)}" + assert hasattr(old_max, "dtype"), f"old_min should be a numpy array but is {type(old_max)}" old_hist = old_histogram[0] old_hist_edges = old_histogram[1] temp_amax = np.max(data_arr_np) @@ -757,7 +802,7 @@ def collect_value(self, name_to_arr): min_value = np.array(0, dtype=data_arr.dtype) max_value = np.array(0, dtype=data_arr.dtype) - threshold = max(abs(min_value), abs(max_value)) + threshold = np.array(max(abs(min_value), abs(max_value)), dtype=data_arr.dtype) if tensor in self.histogram_dict: old_histogram = self.histogram_dict[tensor] @@ -809,7 +854,7 @@ def merge_histogram(self, old_histogram, data_arr, new_min, new_max, new_thresho def compute_collection_result(self): if not self.histogram_dict or len(self.histogram_dict) == 0: raise ValueError("Histogram has not been collected. Please run collect() first.") - print(f"Finding optimal threshold for each tensor using {self.method} algorithm ...") + print(f"Finding optimal threshold for each tensor using {self.method!r} algorithm ...") if self.method == "entropy": return self.compute_entropy() @@ -938,7 +983,14 @@ def compute_distribution(self): assert avg_coef.dtype != np.float64 assert std_coef.dtype != np.float64 assert hist_edges.dtype != np.float64 - thresholds_dict[tensor] = TensorData(avg=avg_coef, std=std_coef, hist=hist, hist_edges=hist_edges) + thresholds_dict[tensor] = TensorData( + avg=avg_coef, + std=std_coef, + hist=hist, + hist_edges=hist_edges, + lowest=hist_edges.min(), + highest=hist_edges.max(), + ) # Plot histogram for debug only if os.environ.get("QUANTIZATION_DEBUG", 0) in (1, "1"): @@ -952,18 +1004,15 @@ def get_entropy_threshold(self, histogram, num_quantized_bins): `q` is a truncated version of the original distribution. 
Ref: http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf """ - import copy - - from scipy.stats import entropy - hist = histogram[0] hist_edges = histogram[1] num_bins = hist.size zero_bin_index = num_bins // 2 num_half_quantized_bin = num_quantized_bins // 2 + dtype = histogram[1].dtype kl_divergence = np.zeros(zero_bin_index - num_half_quantized_bin + 1) - thresholds = [(0, 0) for i in range(kl_divergence.size)] + thresholds = [(np.array(0, dtype=dtype), np.array(0, dtype=dtype)) for i in range(kl_divergence.size)] # <------------ num bins ----------------> # <--- quantized bins ----> @@ -983,10 +1032,7 @@ def get_entropy_threshold(self, histogram, num_quantized_bins): start_index = zero_bin_index - i end_index = zero_bin_index + i + 1 if (zero_bin_index + i + 1) <= num_bins else num_bins - thresholds[i - num_half_quantized_bin] = ( - float(hist_edges[start_index]), - float(hist_edges[end_index]), - ) + thresholds[i - num_half_quantized_bin] = (hist_edges[start_index], hist_edges[end_index]) sliced_distribution = copy.deepcopy(hist[start_index:end_index]) @@ -1020,15 +1066,15 @@ def get_entropy_threshold(self, histogram, num_quantized_bins): norm = sum(nonzeros[start:end]) if norm != 0: - q[start:end] = float(quantized_bins[index]) / float(norm) + q[start:end] = quantized_bins[index] / norm p = smooth_distribution(p) q = smooth_distribution(q) - - if isinstance(q, np.ndarray): - kl_divergence[i - num_half_quantized_bin] = entropy(p, q) + if p is None or q is None: + div = np.array(np.inf, dtype=dtype) else: - kl_divergence[i - num_half_quantized_bin] = float("inf") + div = np.array(entropy(p, q), dtype=dtype) + kl_divergence[i - num_half_quantized_bin] = div min_kl_divergence_idx = np.argmin(kl_divergence) optimal_threshold = thresholds[min_kl_divergence_idx] @@ -1038,6 +1084,8 @@ def get_entropy_threshold(self, histogram, num_quantized_bins): optimal_threshold = (min_value, optimal_threshold[1]) if optimal_threshold[1] > max_value: optimal_threshold = (optimal_threshold[0], max_value) + assert hasattr(optimal_threshold[0], "dtype") + assert hasattr(optimal_threshold[1], "dtype") return optimal_threshold diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index 68c2b3bf79c8b..036f49b420734 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -653,7 +653,7 @@ def smooth_distribution(p, eps=0.0001): if not n_nonzeros: # raise ValueError('The discrete probability distribution is malformed. 
All entries are 0.') - return -1 + return None eps1 = eps * float(n_zeros) / float(n_nonzeros) assert eps1 < 1.0, "n_zeros=%d, n_nonzeros=%d, eps1=%f" % ( n_zeros, diff --git a/onnxruntime/test/python/quantization/test_op_matmul.py b/onnxruntime/test/python/quantization/test_op_matmul.py index 344583aa7c624..91368bd643158 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul.py +++ b/onnxruntime/test/python/quantization/test_op_matmul.py @@ -10,13 +10,39 @@ import numpy as np import onnx import packaging.version as pv +from numpy.testing import assert_almost_equal from onnx import TensorProto, helper from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type +from onnxruntime.capi.onnxruntime_pybind11_state import Fail from onnxruntime.quantization import CalibrationMethod, QuantFormat, QuantType, quantize_dynamic, quantize_static +from onnxruntime.quantization.calibrate import entropy + + +def skip_if_new_opset_exception_raised(func): + def wrapper(*args, **kwargs): + try: + func(*args, **kwargs) + except Fail as e: + if "is under development and support for this is limited" in str(e): + raise unittest.SkipTest(f"Skipped {func} due to opset under development.") # noqa: B904 + raise + + return wrapper class TestOpMatMul(unittest.TestCase): + def test_entropy(self): + try: + from scipy.stats import entropy as scipy_entropy + except ImportError: + raise unittest.SkipTest("scipy not installed.") # noqa: B904 + pk = (np.arange(10) - 5).astype(np.float32) / 10 + qk = -(np.arange(10) - 5).astype(np.float32) / 10 + ent = scipy_entropy(pk, qk) + get = entropy(pk, qk) + assert_almost_equal(ent, get) + def input_feeds(self, n, name2shape, dtype): input_data_list = [] for _i in range(n): @@ -324,10 +350,11 @@ def test_quantize_matmul_u8u8(self): @unittest.skipIf( pv.Version(onnx.__version__) < pv.Version("1.15.1"), reason="Shape inference bug, see onnx PR #5709" ) + @skip_if_new_opset_exception_raised def test_quantize_matmul_u8u8_f16(self): - self.quantize_matmul_u8u8(onnx.TensorProto.FLOAT16, 19, 9) + self.quantize_matmul_u8u8(onnx.TensorProto.FLOAT16, 21, 9) - def quantize_matmul_s8s8(self, tt, opset, ir_version): + def quantize_matmul_s8s8(self, tt, opset, ir_version, calibrate_method=CalibrationMethod.MinMax): np.random.seed(1) model_fp_path = "matmul_fp.onnx" self.construct_model_matmul(model_fp_path, tensor_type=tt, opset=opset, ir_version=ir_version) @@ -341,6 +368,7 @@ def quantize_matmul_s8s8(self, tt, opset, ir_version): activation_type=QuantType.QInt8, weight_type=QuantType.QInt8, extra_options={"ActivationSymmetric": True}, + calibrate_method=calibrate_method, ) self.static_quant_test_qdq( model_fp_path, @@ -348,6 +376,7 @@ def quantize_matmul_s8s8(self, tt, opset, ir_version): activation_type=QuantType.QInt8, weight_type=QuantType.QInt8, extra_options={"ActivationSymmetric": True}, + calibrate_method=calibrate_method, ) # dynamic quantization doesn't support activation:int8 @@ -357,11 +386,42 @@ def quantize_matmul_s8s8(self, tt, opset, ir_version): def test_quantize_matmul_s8s8(self): self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT, 18, 8) + def test_quantize_matmul_s8s8_entropy(self): + self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT, 18, 8, calibrate_method=CalibrationMethod.Entropy) + + def test_quantize_matmul_s8s8_percentile(self): + self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT, 18, 8, calibrate_method=CalibrationMethod.Percentile) + + def test_quantize_matmul_s8s8_distribution(self): + 
self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT, 18, 8, calibrate_method=CalibrationMethod.Distribution) + @unittest.skipIf( pv.Version(onnx.__version__) < pv.Version("1.15.1"), reason="Shape inference bug, see onnx PR #5709" ) + @skip_if_new_opset_exception_raised def test_quantize_matmul_s8s8_f16(self): - self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 19, 9) + self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9) + + @unittest.skipIf( + pv.Version(onnx.__version__) < pv.Version("1.15.1"), reason="Shape inference bug, see onnx PR #5709" + ) + @skip_if_new_opset_exception_raised + def test_quantize_matmul_s8s8_f16_entropy(self): + self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9, calibrate_method=CalibrationMethod.Entropy) + + @unittest.skipIf( + pv.Version(onnx.__version__) < pv.Version("1.15.1"), reason="Shape inference bug, see onnx PR #5709" + ) + @skip_if_new_opset_exception_raised + def test_quantize_matmul_s8s8_f16_percentile(self): + self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9, calibrate_method=CalibrationMethod.Percentile) + + @unittest.skipIf( + pv.Version(onnx.__version__) < pv.Version("1.15.1"), reason="Shape inference bug, see onnx PR #5709" + ) + @skip_if_new_opset_exception_raised + def test_quantize_matmul_s8s8_f16_distribution(self): + self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9, calibrate_method=CalibrationMethod.Distribution) def quantize_matmul_e4m3fn_same(self, tt, opset, ir_version): np.random.seed(1) From bd9d8fb2a545a59d87a4c23308ec543ba6e4c41d Mon Sep 17 00:00:00 2001 From: Rachel Guo <35738743+YUNQIUGUO@users.noreply.github.com> Date: Wed, 17 Jan 2024 11:18:32 -0800 Subject: [PATCH 25/39] [ORT 1.17.0 release] Bump up version to 1.18.0 (#19170) ### Description Bump up version to 1.18.0 since the release branch has been cut. 
### Motivation and Context Co-authored-by: rachguo --- VERSION_NUMBER | 2 +- .../Training/NativeTrainingMethods.shared.cs | 4 ++-- docs/python/README.rst | 5 +++++ include/onnxruntime/core/session/onnxruntime_c_api.h | 2 +- js/common/lib/version.ts | 2 +- js/common/package-lock.json | 4 ++-- js/common/package.json | 2 +- js/node/lib/version.ts | 2 +- js/node/package-lock.json | 6 +++--- js/node/package.json | 2 +- js/react_native/lib/version.ts | 2 +- js/react_native/package.json | 2 +- js/react_native/yarn.lock | 2 +- js/web/lib/version.ts | 2 +- js/web/package-lock.json | 6 +++--- js/web/package.json | 2 +- onnxruntime/__init__.py | 2 +- onnxruntime/core/session/onnxruntime_c_api.cc | 8 ++++---- 18 files changed, 31 insertions(+), 26 deletions(-) diff --git a/VERSION_NUMBER b/VERSION_NUMBER index 092afa15df4df..84cc529467b05 100644 --- a/VERSION_NUMBER +++ b/VERSION_NUMBER @@ -1 +1 @@ -1.17.0 +1.18.0 diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index 68a399f8b9671..7fe16f4156ef2 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -65,10 +65,10 @@ static NativeTrainingMethods() DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(NativeMethods.OrtGetApiBase().GetApi, typeof(DOrtGetApi)); // TODO: Make this save the pointer, and not copy the whole structure across - api_ = (OrtApi)OrtGetApi(17 /*ORT_API_VERSION*/); + api_ = (OrtApi)OrtGetApi(18 /*ORT_API_VERSION*/); OrtGetTrainingApi = (DOrtGetTrainingApi)Marshal.GetDelegateForFunctionPointer(api_.GetTrainingApi, typeof(DOrtGetTrainingApi)); - trainingApiPtr = OrtGetTrainingApi(17 /*ORT_API_VERSION*/); + trainingApiPtr = OrtGetTrainingApi(18 /*ORT_API_VERSION*/); if (trainingApiPtr != IntPtr.Zero) { trainingApi_ = (OrtTrainingApi)Marshal.PtrToStructure(trainingApiPtr, typeof(OrtTrainingApi)); diff --git a/docs/python/README.rst b/docs/python/README.rst index 32bb3729e01d0..bbc8571fe3f17 100644 --- a/docs/python/README.rst +++ b/docs/python/README.rst @@ -8,6 +8,11 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime `_ or the `Github project `_. """ -__version__ = "1.17.0" +__version__ = "1.18.0" __author__ = "Microsoft" # we need to do device version validation (for example to check Cuda version for an onnxruntime-training package). diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index d77c188f832a7..91a7f0d930b51 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2397,7 +2397,7 @@ Second example, if we wanted to add and remove some members, we'd do this: In GetApi we now make it return ort_api_3 for version 3. */ -static constexpr OrtApi ort_api_1_to_17 = { +static constexpr OrtApi ort_api_1_to_18 = { // NOTE: The ordering of these fields MUST not change after that version has shipped since existing binaries depend on this ordering. 
// Shipped as version 1 - DO NOT MODIFY (see above text for more information) @@ -2756,16 +2756,16 @@ static_assert(offsetof(OrtApi, KernelContext_GetResource) / sizeof(void*) == 265 static_assert(offsetof(OrtApi, SetUserLoggingFunction) / sizeof(void*) == 266, "Size of version 17 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: -static_assert(std::string_view(ORT_VERSION) == "1.17.0", +static_assert(std::string_view(ORT_VERSION) == "1.18.0", "ORT_Version change detected, please follow below steps to ensure OrtApi is updated properly"); // 1. Update the hardcoded version string in above static_assert to silence it -// 2. If there were any APIs added to ort_api_1_to_17 above: +// 2. If there were any APIs added to ort_api_1_to_18 above: // a. Add the 'End of version #' markers (pattern above should be obvious) // b. Add a static_assert in the directly above list of version sizes to ensure nobody adds any more functions to the just shipped API version ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) { if (version >= 1 && version <= ORT_API_VERSION) - return &ort_api_1_to_17; + return &ort_api_1_to_18; fprintf(stderr, "The requested API version [%u] is not available, only API versions [1, %u] are supported in this build." From bc219ed553fc8d4b8fa3c7b4476810a63a864d8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= <44298237+gedoensmax@users.noreply.github.com> Date: Wed, 17 Jan 2024 20:33:34 +0100 Subject: [PATCH 26/39] [TensorRT EP] Enable a minimal CUDA EP compilation without kernels (#19052) Adresses https://github.com/microsoft/onnxruntime/issues/18542. I followed the advice given by @RyanUnderhill [here](https://github.com/microsoft/onnxruntime/pull/18731#issuecomment-1848261925) and went with a minimal CUDA EP for now. --- cmake/CMakeLists.txt | 1 + cmake/onnxruntime_providers_cuda.cmake | 49 ++++++++++++++----- .../core/providers/cuda/cuda_context.h | 3 +- onnxruntime/core/providers/cuda/cuda_call.cc | 4 ++ .../core/providers/cuda/cuda_common.cc | 42 ++++++++-------- onnxruntime/core/providers/cuda/cuda_common.h | 6 ++- .../providers/cuda/cuda_execution_provider.cc | 14 +++++- onnxruntime/core/providers/cuda/cuda_pch.h | 7 +++ .../core/providers/cuda/cuda_stream_handle.cc | 4 ++ .../core/providers/cuda/cudnn_common.cc | 3 +- .../core/providers/cuda/cudnn_common.h | 3 +- 11 files changed, 97 insertions(+), 39 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index bc96218dac79e..712d5d76108aa 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -79,6 +79,7 @@ option(onnxruntime_USE_CUDA "Build with CUDA support" OFF) cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS;LINUX" OFF) option(onnxruntime_USE_CUDA_NHWC_OPS "Build CUDA with NHWC op support" OFF) +option(onnxruntime_CUDA_MINIMAL "Build CUDA without any operations apart from memcpy ops. Usefuel for a very minial TRT build" OFF) option(onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO "When building with CUDA support, generate device code line number information." 
OFF) option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF) option(onnxruntime_USE_COREML "Build with CoreML support" OFF) diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 84d1376f99d5e..9887d615c92d7 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -1,10 +1,25 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. - file(GLOB_RECURSE onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc" - ) + + if (onnxruntime_CUDA_MINIMAL) + file(GLOB onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/tunable/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/tunable/*.cc" + ) + # Remove pch files + list(REMOVE_ITEM onnxruntime_providers_cuda_cc_srcs + "${ONNXRUNTIME_ROOT}/core/providers/cuda/integer_gemm.cc" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/triton_kernel.h" + ) + else() + file(GLOB_RECURSE onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc" + ) + endif() # Remove pch files list(REMOVE_ITEM onnxruntime_providers_cuda_cc_srcs "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_pch.h" @@ -16,11 +31,16 @@ "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" ) - file(GLOB_RECURSE onnxruntime_providers_cuda_cu_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu" - "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cuh" - ) + + if (onnxruntime_CUDA_MINIMAL) + set(onnxruntime_providers_cuda_shared_srcs "") + else() + file(GLOB_RECURSE onnxruntime_providers_cuda_cu_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cuh" + ) + endif() source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs}) set(onnxruntime_providers_cuda_src ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs}) @@ -156,10 +176,15 @@ endif() add_dependencies(${target} onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) - target_link_libraries(${target} PRIVATE cublasLt cublas cudnn curand cufft ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) - if(onnxruntime_CUDNN_HOME) - target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include) - target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib) + if(onnxruntime_CUDA_MINIMAL) + target_compile_definitions(${target} PRIVATE USE_CUDA_MINIMAL) + target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) + else() + target_link_libraries(${target} PRIVATE cublasLt cublas cudnn curand cufft ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) + if(onnxruntime_CUDNN_HOME) + target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include) + target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib) + endif() endif() if (onnxruntime_USE_TRITON_KERNEL) diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h 
b/include/onnxruntime/core/providers/cuda/cuda_context.h index 9416fad5f1448..1370f5c4c5e10 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_context.h +++ b/include/onnxruntime/core/providers/cuda/cuda_context.h @@ -16,9 +16,10 @@ #include "core/providers/custom_op_context.h" #include #include +#ifndef USE_CUDA_MINIMAL #include #include - +#endif namespace Ort { namespace Custom { diff --git a/onnxruntime/core/providers/cuda/cuda_call.cc b/onnxruntime/core/providers/cuda/cuda_call.cc index 4f223041e04e3..f60684795a4bc 100644 --- a/onnxruntime/core/providers/cuda/cuda_call.cc +++ b/onnxruntime/core/providers/cuda/cuda_call.cc @@ -30,6 +30,7 @@ const char* CudaErrString(cudaError_t x) { return cudaGetErrorString(x); } +#ifndef USE_CUDA_MINIMAL template <> const char* CudaErrString(cublasStatus_t e) { cudaDeviceSynchronize(); @@ -76,6 +77,7 @@ const char* CudaErrString(cufftResult e) { return "Unknown cufft error status"; } } +#endif #ifdef ORT_USE_NCCL template <> @@ -132,6 +134,7 @@ std::conditional_t CudaCall( template Status CudaCall(cudaError retCode, const char* exprString, const char* libName, cudaError successCode, const char* msg, const char* file, const int line); template void CudaCall(cudaError retCode, const char* exprString, const char* libName, cudaError successCode, const char* msg, const char* file, const int line); +#ifndef USE_CUDA_MINIMAL template Status CudaCall(cublasStatus_t retCode, const char* exprString, const char* libName, cublasStatus_t successCode, const char* msg, const char* file, const int line); template void CudaCall(cublasStatus_t retCode, const char* exprString, const char* libName, cublasStatus_t successCode, const char* msg, const char* file, const int line); template Status CudaCall(cudnnStatus_t retCode, const char* exprString, const char* libName, cudnnStatus_t successCode, const char* msg, const char* file, const int line); @@ -140,6 +143,7 @@ template Status CudaCall(curandStatus_t retCode, const ch template void CudaCall(curandStatus_t retCode, const char* exprString, const char* libName, curandStatus_t successCode, const char* msg, const char* file, const int line); template Status CudaCall(cufftResult retCode, const char* exprString, const char* libName, cufftResult successCode, const char* msg, const char* file, const int line); template void CudaCall(cufftResult retCode, const char* exprString, const char* libName, cufftResult successCode, const char* msg, const char* file, const int line); +#endif #ifdef ORT_USE_NCCL template Status CudaCall(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line); diff --git a/onnxruntime/core/providers/cuda/cuda_common.cc b/onnxruntime/core/providers/cuda/cuda_common.cc index 33f2938940e4d..65083f89f7f77 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.cc +++ b/onnxruntime/core/providers/cuda/cuda_common.cc @@ -14,6 +14,27 @@ namespace cuda { // 0x04 - pedantic constexpr const char* kCudaGemmOptions = "ORT_CUDA_GEMM_OPTIONS"; +const char* CudaDataTypeToString(cudaDataType_t dt) { + switch (dt) { + case CUDA_R_16F: + return "CUDA_R_16F"; + case CUDA_R_16BF: + return "CUDA_R_16BF"; + case CUDA_R_32F: + return "CUDA_R_32F"; +#if !defined(DISABLE_FLOAT8_TYPES) + // Note: CUDA_R_8F_E4M3 is defined with CUDA>=11.8 + case CUDA_R_8F_E4M3: + return "CUDA_R_8F_E4M3"; + case CUDA_R_8F_E5M2: + return "CUDA_R_8F_E5M2"; +#endif + default: + return ""; + } +} + +#ifndef USE_CUDA_MINIMAL // Initialize the singleton 
instance HalfGemmOptions HalfGemmOptions::instance; @@ -54,26 +75,6 @@ const char* cublasGetErrorEnum(cublasStatus_t error) { } } -const char* CudaDataTypeToString(cudaDataType_t dt) { - switch (dt) { - case CUDA_R_16F: - return "CUDA_R_16F"; - case CUDA_R_16BF: - return "CUDA_R_16BF"; - case CUDA_R_32F: - return "CUDA_R_32F"; -#if !defined(DISABLE_FLOAT8_TYPES) - // Note: CUDA_R_8F_E4M3 is defined with CUDA>=11.8 - case CUDA_R_8F_E4M3: - return "CUDA_R_8F_E4M3"; - case CUDA_R_8F_E5M2: - return "CUDA_R_8F_E5M2"; -#endif - default: - return ""; - } -} - const char* CublasComputeTypeToString(cublasComputeType_t ct) { switch (ct) { case CUBLAS_COMPUTE_16F: @@ -92,6 +93,7 @@ const char* CublasComputeTypeToString(cublasComputeType_t ct) { return ""; } } +#endif // It must exist somewhere already. cudaDataType_t ToCudaDataType(int32_t element_type) { diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index 707099bac3ce0..e9941ce743bc3 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -22,13 +22,14 @@ namespace onnxruntime { namespace cuda { #define CUDA_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUDA_CALL(expr)) +#ifndef USE_CUDA_MINIMAL #define CUBLAS_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUBLAS_CALL(expr)) #define CUSPARSE_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUSPARSE_CALL(expr)) #define CURAND_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CURAND_CALL(expr)) #define CUDNN_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUDNN_CALL(expr)) #define CUDNN2_RETURN_IF_ERROR(expr, m) ORT_RETURN_IF_ERROR(CUDNN_CALL2(expr, m)) #define CUFFT_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUFFT_CALL(expr)) - +#endif // Type mapping for MLFloat16 to half template class ToCudaType { @@ -93,7 +94,7 @@ inline bool CalculateFdmStrides(gsl::span p, const std::vector KernelCreateInfo BuildKernelCreateInfo() { @@ -1326,6 +1332,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing BuildKernelCreateInfo, BuildKernelCreateInfo, +#ifndef USE_CUDA_MINIMAL BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2201,6 +2208,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, +#endif }; for (auto& function_table_entry : function_table) { @@ -2210,6 +2218,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { } } +#ifndef USE_CUDA_MINIMAL #ifndef DISABLE_CONTRIB_OPS ORT_RETURN_IF_ERROR(::onnxruntime::contrib::cuda::RegisterCudaContribKernels(kernel_registry)); #endif @@ -2220,6 +2229,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { #ifdef ENABLE_TRAINING_OPS ORT_RETURN_IF_ERROR(::onnxruntime::cuda::RegisterCudaTrainingKernels(kernel_registry)); +#endif #endif return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/cuda_pch.h b/onnxruntime/core/providers/cuda/cuda_pch.h index f48554e8f1286..dfe50fe0a8832 100644 --- a/onnxruntime/core/providers/cuda/cuda_pch.h +++ b/onnxruntime/core/providers/cuda/cuda_pch.h @@ -10,12 +10,19 @@ #include #include +#include +#ifndef USE_CUDA_MINIMAL #include #include #include #include #include #include +#else +typedef void* cudnnHandle_t; +typedef void* cublasHandle_t; +typedef void* cublasLtHandle_t; +#endif #ifdef ORT_USE_NCCL #include diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc 
b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc index 7c866395ecf6e..0a256394b7d99 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc @@ -69,6 +69,7 @@ CudaStream::CudaStream(cudaStream_t stream, release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream), deferred_cpu_allocator_(*this), ep_info_(ep_info) { +#ifndef USE_CUDA_MINIMAL if (own_flag) { CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_)); CUBLAS_CALL_THROW(cublasSetStream(cublas_handle_, stream)); @@ -80,10 +81,12 @@ CudaStream::CudaStream(cudaStream_t stream, cudnn_handle_ = external_cudnn_handle; CUDNN_CALL_THROW(cudnnSetStream(cudnn_handle_, stream)); } +#endif } CudaStream::~CudaStream() { ORT_IGNORE_RETURN_VALUE(CleanUpOnRunEnd()); +#ifndef USE_CUDA_MINIMAL if (own_stream_) { cublasDestroy(cublas_handle_); cudnnDestroy(cudnn_handle_); @@ -91,6 +94,7 @@ CudaStream::~CudaStream() { if (handle) cudaStreamDestroy(static_cast(handle)); } +#endif } std::unique_ptr CudaStream::CreateNotification(size_t /*num_consumers*/) { diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc index 4df59a98b12e5..c850f7b583bfc 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.cc +++ b/onnxruntime/core/providers/cuda/cudnn_common.cc @@ -9,7 +9,7 @@ #include "core/common/gsl.h" #include "shared_inc/cuda_call.h" #include "core/providers/cpu/tensor/utils.h" - +#ifndef USE_CUDA_MINIMAL namespace onnxruntime { namespace cuda { @@ -222,3 +222,4 @@ const Float8E5M2 Consts::One = Float8E5M2(1.0f, true); } // namespace cuda } // namespace onnxruntime +#endif diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h index 8a94a334ee688..fdd14dedad47e 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.h +++ b/onnxruntime/core/providers/cuda/cudnn_common.h @@ -7,7 +7,7 @@ #include #include "core/providers/cuda/cuda_common.h" - +#ifndef USE_CUDA_MINIMAL namespace onnxruntime { namespace cuda { @@ -260,3 +260,4 @@ SetPoolingNdDescriptorHelper(cudnnPoolingDescriptor_t poolingDesc, } // namespace cuda } // namespace onnxruntime +#endif From 146ebaf91e85185a0ac18c82bc69eba685ab9727 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 17 Jan 2024 15:03:43 -0800 Subject: [PATCH 27/39] [js/web] allow proxy to load model with 1GB <= size < 2GB (#19178) ### Description allow proxy to load model with 1GB <= size < 2GB resolves #19157. 
--- js/web/lib/wasm/wasm-utils-load-file.ts | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/js/web/lib/wasm/wasm-utils-load-file.ts b/js/web/lib/wasm/wasm-utils-load-file.ts index abe480a43c790..c6cdba2320bde 100644 --- a/js/web/lib/wasm/wasm-utils-load-file.ts +++ b/js/web/lib/wasm/wasm-utils-load-file.ts @@ -47,9 +47,19 @@ export const loadFile = async(file: string|Blob|ArrayBufferLike|Uint8Array): Pro } const reader = response.body.getReader(); - // use WebAssembly Memory to allocate larger ArrayBuffer - const pages = Math.ceil(fileSize / 65536); - const buffer = new WebAssembly.Memory({initial: pages, maximum: pages}).buffer; + let buffer; + try { + // try to create ArrayBuffer directly + buffer = new ArrayBuffer(fileSize); + } catch (e) { + if (e instanceof RangeError) { + // use WebAssembly Memory to allocate larger ArrayBuffer + const pages = Math.ceil(fileSize / 65536); + buffer = new WebAssembly.Memory({initial: pages, maximum: pages}).buffer; + } else { + throw e; + } + } let offset = 0; // eslint-disable-next-line no-constant-condition From f87e69801f200a34ddb312f1d39e7296f19b660b Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 17 Jan 2024 15:04:22 -0800 Subject: [PATCH 28/39] [js/web] show warning when numThreads is set but threads is not supported (#19179) ### Description show warning when numThreads is set but threads is not supported. Resolves #19148, #18933 for web: when crossOriginIsolated is false. for node: always disable. --- js/web/lib/backend-wasm.ts | 6 ++++++ js/web/lib/wasm/wasm-factory.ts | 33 +++++++++++++++++++++++++++------ 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index d9f63fec9c492..31ecffb07e40c 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -31,6 +31,12 @@ export const initializeFlags = (): void => { } if (typeof env.wasm.numThreads !== 'number' || !Number.isInteger(env.wasm.numThreads) || env.wasm.numThreads <= 0) { + // Web: when crossOriginIsolated is false, SharedArrayBuffer is not available so WebAssembly threads will not work. + // Node.js: onnxruntime-web does not support multi-threads in Node.js. + if ((typeof self !== 'undefined' && !self.crossOriginIsolated) || + (typeof process !== 'undefined' && process.versions && process.versions.node)) { + env.wasm.numThreads = 1; + } const numCpuLogicalCores = typeof navigator === 'undefined' ? cpus().length : navigator.hardwareConcurrency; env.wasm.numThreads = Math.min(4, Math.ceil((numCpuLogicalCores || 1) / 2)); } diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index 81508a253ce8b..9b9334c93b78c 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -28,13 +28,34 @@ let initialized = false; let initializing = false; let aborted = false; -const isMultiThreadSupported = (): boolean => { - try { - // If 'SharedArrayBuffer' is not available, WebAssembly threads will not work. - if (typeof SharedArrayBuffer === 'undefined') { - return false; +const isMultiThreadSupported = (numThreads: number): boolean => { + // WebAssembly threads are set to 1 (single thread). + if (numThreads === 1) { + return false; + } + + // If 'SharedArrayBuffer' is not available, WebAssembly threads will not work. 
+ if (typeof SharedArrayBuffer === 'undefined') { + if (typeof self !== 'undefined' && !self.crossOriginIsolated) { + // eslint-disable-next-line no-console + console.warn( + 'env.wasm.numThreads is set to ' + numThreads + + ', but this will not work unless you enable crossOriginIsolated mode. ' + + 'See https://web.dev/cross-origin-isolation-guide/ for more info.'); } + return false; + } + + // onnxruntime-web does not support multi-threads in Node.js. + if (typeof process !== 'undefined' && process.versions && process.versions.node) { + // eslint-disable-next-line no-console + console.warn( + 'env.wasm.numThreads is set to ' + numThreads + + ', however, currently onnxruntime-web does not support multi-threads in Node.js. ' + + 'Please consider using onnxruntime-node for performance critical scenarios.'); + } + try { // Test for transferability of SABs (for browsers. needed for Firefox) // https://groups.google.com/forum/#!msg/mozilla.dev.platform/IHkBZlHETpA/dwsMNchWEQAJ if (typeof MessageChannel !== 'undefined') { @@ -106,7 +127,7 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise const numThreads = flags.numThreads!; const simd = flags.simd!; - const useThreads = numThreads > 1 && isMultiThreadSupported(); + const useThreads = isMultiThreadSupported(numThreads); const useSimd = simd && isSimdSupported(); const wasmPaths = flags.wasmPaths; From 9da3e36138dd24377fbb0b4022d891b3baf07b84 Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Wed, 17 Jan 2024 20:20:42 -0500 Subject: [PATCH 29/39] Fix buildJava from Zip-Nuget-Java-Nodejs Packaging Pipeline (#19187) ### Description ### Motivation and Context --- .../c-api-noopenmp-packaging-pipelines.yml | 2 ++ .../stages/nuget-linux-cuda-packaging-stage.yml | 10 ++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 3803333bd880a..aa1a75bfcda45 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -204,6 +204,8 @@ stages: CudaVersion: ${{ parameters.CudaVersion }} docker_base_image: ${{ variables.docker_base_image }} linux_trt_version: ${{ variables.linux_trt_version }} + buildJava: true + buildNodejs: true #CUDA without tensorrt - template: templates/win-ci.yml diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml index dbbc9ef27e513..db9bcacbf0754 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml @@ -6,6 +6,12 @@ parameters: type: string - name: linux_trt_version type: string +- name: buildJava + type: boolean + default: false +- name: buildNodejs + type: boolean + default: false stages: # Linux CUDA without TensorRT Packaging @@ -66,9 +72,9 @@ stages: parameters: artifactName: 'onnxruntime-linux-x64-tensorrt-$(OnnxRuntimeVersion)' artifactNameNoVersionString: 'onnxruntime-linux-x64-tensorrt' - buildJava: false + buildJava: ${{ parameters.buildJava }} buildJavaOption: '--build_java' - buildNodejs: false + buildNodejs: ${{ parameters.buildNodejs }} buildNodejsOption: '--build_nodejs' CudaVersion: ${{ parameters.CudaVersion }} # Linux CUDA Combined Testing and 
Publishing From dadd3ea704243a8c2b2ded790ae01f3b57c4da53 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Thu, 18 Jan 2024 11:11:14 -0800 Subject: [PATCH 30/39] Check the ep_cache_context and don't allow access outside the directory (#19174) ### Description Check the ep_cache_context node property for EPContext node, and don't allow relative path like "../file_path" --- .../qnn/builder/onnx_ctx_model_helper.cc | 28 +++- .../test/providers/qnn/simple_op_htp_test.cc | 129 ++++++++++++++++++ 2 files changed, 155 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index b157396306d01..fd9bf200c45ef 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -88,9 +88,33 @@ Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer, qnn_model); } - std::string external_qnn_context_binary_file_name = node_helper.Get(EP_CACHE_CONTEXT, ""); std::filesystem::path folder_path = std::filesystem::path(ctx_onnx_model_path).parent_path(); - std::filesystem::path context_binary_path = folder_path.append(external_qnn_context_binary_file_name); + std::string external_qnn_ctx_binary_file_name = node_helper.Get(EP_CACHE_CONTEXT, ""); + ORT_RETURN_IF(external_qnn_ctx_binary_file_name.empty(), "The file path in ep_cache_context should not be empty."); +#ifdef _WIN32 + onnxruntime::PathString external_qnn_context_binary_path = onnxruntime::ToPathString(external_qnn_ctx_binary_file_name); + auto ctx_file_path = std::filesystem::path(external_qnn_context_binary_path.c_str()); + ORT_RETURN_IF(ctx_file_path.is_absolute(), "External mode should set ep_cache_context field with a relative path, but it is an absolute path: ", + external_qnn_ctx_binary_file_name); + auto relative_path = ctx_file_path.lexically_normal().make_preferred().wstring(); + if (relative_path.find(L"..", 0) != std::string::npos) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "The file path in ep_cache_context field has '..'. It's not allowed to point outside the directory."); + } + + std::filesystem::path context_binary_path = folder_path.append(relative_path); +#else + ORT_RETURN_IF(external_qnn_ctx_binary_file_name[0] == '/', + "External mode should set ep_cache_context field with a relative path, but it is an absolute path: ", + external_qnn_ctx_binary_file_name); + if (external_qnn_ctx_binary_file_name.find("..", 0) != std::string::npos) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "The file path in ep_cache_context field has '..'. 
It's not allowed to point outside the directory."); + } + std::filesystem::path context_binary_path = folder_path.append(external_qnn_ctx_binary_file_name); + std::string file_full_path = context_binary_path.string(); +#endif + if (!std::filesystem::is_regular_file(context_binary_path)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "The file path in ep_cache_context does not exist or is not accessible."); + } size_t buffer_size{0}; std::ifstream cache_file(context_binary_path.string().c_str(), std::ifstream::binary); diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index c4244fe532456..4ac1f5ddca643 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -908,6 +908,135 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCache_InvalidGraph) { ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } +std::string CreateQnnCtxModelWithNonEmbedMode(std::string external_bin_path) { + const std::unordered_map domain_to_version = {{"", 11}, {kMSDomain, 1}}; + auto& logging_manager = DefaultLoggingManager(); + onnxruntime::Model model("QNN_ctx_model", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + logging_manager.DefaultLogger()); + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + std::vector shape = {2, 3}; + NodeArg* graph_input = MakeTestInput(helper, TestInputDef(shape, true, {0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f})); + auto* graph_output = helper.MakeOutput(shape); + Node& ep_context_node = helper.AddNode("EPContext", {graph_input}, {graph_output}, kMSDomain); + ep_context_node.AddAttribute("embed_mode", static_cast(0)); + // The .. in the path will cause INVALID_GRAPH + ep_context_node.AddAttribute("ep_cache_context", external_bin_path); + ep_context_node.AddAttribute("partition_name", "QNNExecutionProvider_QNN_1110111000111000111_1_0"); + ep_context_node.AddAttribute("source", "QNN"); + helper.SetGraphOutputs(); + std::string model_data; + model.ToProto().SerializeToString(&model_data); + + return model_data; +} + +// Create a model with EPContext node. Set the node property ep_cache_context has ".." +// Verify that it return INVALID_GRAPH status +TEST_F(QnnHTPBackendTests, QnnContextBinaryRelativePathTest) { + std::string model_data = CreateQnnCtxModelWithNonEmbedMode("../qnn_context.bin"); + + SessionOptions so; + so.session_logid = "qnn_ctx_model_logger"; + RunOptions run_options; + run_options.run_tag = so.session_logid; + + InferenceSessionWrapper session_object{so, GetEnvironment()}; + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); + // Verify the return status with code INVALID_GRAPH + ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); +} + +// Create a model with EPContext node. 
Set the node property ep_cache_context has absolute path +// Verify that it return INVALID_GRAPH status +TEST_F(QnnHTPBackendTests, QnnContextBinaryAbsolutePathTest) { +#if defined(_WIN32) + std::string external_ctx_bin_path = "D:/qnn_context.bin"; +#else + std::string external_ctx_bin_path = "/data/qnn_context.bin"; +#endif + std::string model_data = CreateQnnCtxModelWithNonEmbedMode(external_ctx_bin_path); + + SessionOptions so; + so.session_logid = "qnn_ctx_model_logger"; + RunOptions run_options; + run_options.run_tag = so.session_logid; + + InferenceSessionWrapper session_object{so, GetEnvironment()}; + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); + // Verify the return status with code INVALID_GRAPH + ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); +} + +// Create a model with EPContext node. Set the node property ep_cache_context to a file not exist +// Verify that it return INVALID_GRAPH status +TEST_F(QnnHTPBackendTests, QnnContextBinaryFileNotExistTest) { + std::string model_data = CreateQnnCtxModelWithNonEmbedMode("qnn_context_not_exist.bin"); + + SessionOptions so; + so.session_logid = "qnn_ctx_model_logger"; + RunOptions run_options; + run_options.run_tag = so.session_logid; + + InferenceSessionWrapper session_object{so, GetEnvironment()}; + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); + // Verify the return status with code INVALID_GRAPH + ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); +} + +// Create a model with EPContext node. 
Set the node property ep_cache_context to empty string +// Verify that it return INVALID_GRAPH status +TEST_F(QnnHTPBackendTests, QnnContextBinaryFileEmptyStringTest) { + std::string model_data = CreateQnnCtxModelWithNonEmbedMode(""); + + SessionOptions so; + so.session_logid = "qnn_ctx_model_logger"; + RunOptions run_options; + run_options.run_tag = so.session_logid; + + InferenceSessionWrapper session_object{so, GetEnvironment()}; + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); + // Verify the return status with code INVALID_GRAPH + ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); +} + // Run QDQ model on HTP with 2 inputs // 1st run will generate the Qnn context cache onnx file // 2nd run will load and run from QDQ model + Qnn context cache model From dd2177c5d70b8e5b704f7ee0ddce134243eacb24 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Thu, 18 Jan 2024 13:11:47 -0800 Subject: [PATCH 31/39] enable webnn in ci build (#19163) ### Description ### Motivation and Context --- .../github/azure-pipelines/templates/linux-wasm-ci.yml | 4 ++-- .../ci_build/github/azure-pipelines/templates/win-wasm-ci.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index d279e667f9091..360e3d5ef879b 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -174,7 +174,7 @@ jobs: ${{ else }}: AdditionalKey: wasm_simd_jsep | ${{ parameters.BuildConfig }} CacheDir: $(ORT_CACHE_DIR)/wasm_simd_jsep - Arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)/wasm_simd_jsep --enable_wasm_simd --use_jsep --target onnxruntime_webassembly --skip_tests' + Arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)/wasm_simd_jsep --enable_wasm_simd --use_jsep --use_webnn --target onnxruntime_webassembly --skip_tests' DisplayName: 'Build (simd + JSEP)' WithCache: ${{ parameters.WithCache }} - template: build-linux-wasm-step.yml @@ -185,7 +185,7 @@ jobs: ${{ else }}: AdditionalKey: wasm_simd_threads_jsep | ${{ parameters.BuildConfig }} CacheDir: $(ORT_CACHE_DIR)/wasm_simd_threads_jsep - Arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)/wasm_simd_threads_jsep --enable_wasm_simd --enable_wasm_threads --use_jsep --target onnxruntime_webassembly --skip_tests' + Arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)/wasm_simd_threads_jsep --enable_wasm_simd --enable_wasm_threads --use_jsep --use_webnn --target onnxruntime_webassembly --skip_tests' DisplayName: 'Build (simd + threads + JSEP)' WithCache: ${{ parameters.WithCache }} diff --git a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml index 79647cc5699c8..f2005ec5ada39 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml @@ -127,14 +127,14 @@ jobs: displayName: 'Build (simd + JSEP)' inputs: scriptPath: 
'$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)\wasm_simd_jsep --enable_wasm_simd --use_jsep --target onnxruntime_webassembly --skip_tests' + arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)\wasm_simd_jsep --enable_wasm_simd --use_jsep --use_webnn --target onnxruntime_webassembly --skip_tests' workingDirectory: '$(Build.BinariesDirectory)' - ${{ if eq(parameters.BuildJsep, true) }}: - task: PythonScript@0 displayName: 'Build (simd + threads + JSEP)' inputs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)\wasm_simd_threads_jsep --enable_wasm_simd --enable_wasm_threads --use_jsep --target onnxruntime_webassembly --skip_tests' + arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)\wasm_simd_threads_jsep --enable_wasm_simd --enable_wasm_threads --use_jsep --use_webnn --target onnxruntime_webassembly --skip_tests' workingDirectory: '$(Build.BinariesDirectory)' - ${{ if eq(parameters.SkipPublish, false) }}: - script: |

From 459c750b031339456e4061b1c4214904e6853ccd Mon Sep 17 00:00:00 2001
From: luoyu-intel
Date: Fri, 19 Jan 2024 05:16:34 +0800
Subject: [PATCH 32/39] Update x64 template kernel library for 'sqnbitgemm' (#19016)

### Description
1. Make the JBLAS code an external module of ORT.
2. Move the q4 GEMM code to contrib_ops.
3. Update the template kernel library to the v0.1 release.

### Motivation and Context
We found that the current LLM model performance is far below our expectations. Here is some performance data collected on the Mistral-7B model with a Xeon-8480:

8 threads | prompt length=32 past_len=32 | prompt length=1 past_len=32
-- | -- | --
ORT-main | 1220ms | 263ms
Neural-speed | 564ms | 87ms
ORT-this PR | 597ms | 120ms

Although `Neural-speed` and `ORT-this PR` use the same int4 kernel code, there is a 33ms (87ms vs. 120ms) latency gap between the two frameworks. Through some statistics analysis, the summed latency of `MatMulNBits` is 86.7ms, while the summed latency of all int4 GEMMs in `Neural-speed` is 84.8ms. So other OPs introduce the extra ~30ms of latency. The performance of MatMulNBits in this PR meets our expectations.

### Remaining Issues
1. For hybrid CPUs, like the Core 12900K, the ONNXRuntime thread pool uses TaskGranularityFactor to scale its number of threads. This is not expected by our code design and may slow down hybrid CPU performance by 30~40%.
2. Prepack uses a single thread, which makes session initialization very slow.
3. MatMulNBits with zero points will fall through to COMP_FP32 even when accuracy_level=4, because our COMP_INT8 IGemmCore path with zero-point processing is not optimized yet. It will be updated in the future. So, for an int4 model with zero points, there is no difference between accuracy_level 0 and 4.
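To illustrate how the accuracy_level caveat interacts with model preparation, below is a minimal sketch of producing a 4-bit `MatMulNBits` model with symmetric quantization (no zero points), so that `accuracy_level=4` can actually take the COMP_INT8 path. The `MatMul4BitsQuantizer` import path and keyword arguments are assumptions about the Python quantization tooling, not something specified by this change; check the `matmul_4bits_quantizer` module in your ONNX Runtime version for the exact interface.

```python
# Minimal sketch (assumed API): quantize MatMul weights to 4-bit MatMulNBits.
# is_symmetric=True avoids zero points, so accuracy_level=4 is not forced back
# to the COMP_FP32 path mentioned in the remaining issues above.
import onnx
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer

model = onnx.load("model_fp32.onnx")  # hypothetical input model path

quantizer = MatMul4BitsQuantizer(
    model,
    block_size=32,       # quantization block size along the K dimension
    is_symmetric=True,   # symmetric quantization => no zero points
    accuracy_level=4,    # request the int8 compute type for the int4 GEMM
)
quantizer.process()
quantizer.model.save_model_to_file("model_int4.onnx", use_external_data_format=True)
```

With zero points (`is_symmetric=False`), the kernels in this change fall back to COMP_FP32 regardless of accuracy_level, so the setting only affects symmetric int4 models.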
--- cmake/CMakeLists.txt | 18 +- cmake/deps.txt | 2 +- cmake/external/neural_speed.cmake | 18 + cmake/onnxruntime_mlas.cmake | 13 - cmake/onnxruntime_providers_cpu.cmake | 15 + .../cpu/quantization/matmul_nbits.cc | 58 +- .../cpu/quantization/neural_speed_defs.h | 45 + .../cpu/quantization/neural_speed_gemm.cc | 438 ++ .../cpu/quantization/neural_speed_gemm.h | 129 + .../cpu/quantization/neural_speed_wrapper.h | 39 + onnxruntime/core/mlas/inc/mlas_qnbit.h | 130 - onnxruntime/core/mlas/lib/jblas_defs.h | 73 - onnxruntime/core/mlas/lib/jblas_gemm.cpp | 534 -- onnxruntime/core/mlas/lib/jblas_gemm.h | 61 - onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 128 - .../core/mlas/lib/x86_64/jblas/.clang-format | 7 - .../core/mlas/lib/x86_64/jblas/CMakeLists.txt | 33 - .../mlas/lib/x86_64/jblas/jblas/jit_base.h | 303 -- .../mlas/lib/x86_64/jblas/jblas/jit_blas.h | 96 - .../lib/x86_64/jblas/jblas/jit_blas_device.h | 277 - .../x86_64/jblas/jblas/jit_blas_epilogue.h | 329 -- .../lib/x86_64/jblas/jblas/jit_blas_gemm.h | 2699 ---------- .../x86_64/jblas/jblas/jit_blas_parallel.h | 678 --- .../x86_64/jblas/jblas/jit_blas_prologue_a.h | 214 - .../x86_64/jblas/jblas/jit_blas_prologue_b.h | 892 ---- .../lib/x86_64/jblas/jblas/jit_blas_storage.h | 665 --- .../lib/x86_64/jblas/jblas/jit_blas_utils.h | 638 --- .../lib/x86_64/jblas/jblas/jit_blas_wrapper.h | 281 - .../mlas/lib/x86_64/jblas/jblas/kernel_avx2.h | 874 --- .../x86_64/jblas/jblas/kernel_avx512_bf16.h | 92 - .../lib/x86_64/jblas/jblas/kernel_avx512f.h | 1966 ------- .../mlas/lib/x86_64/jblas/jblas/kernel_jit.h | 1375 ----- .../x86_64/jblas/jblas/kernel_jit_injector.h | 930 ---- .../mlas/lib/x86_64/jblas/jblas/kernel_ref.h | 1039 ---- .../lib/x86_64/jblas/jblas/kernel_wrapper.h | 702 --- .../mlas/lib/x86_64/jblas/jblas/xbyak/xbyak.h | 3313 ------------ .../x86_64/jblas/jblas/xbyak/xbyak_bin2hex.h | 271 - .../x86_64/jblas/jblas/xbyak/xbyak_mnemonic.h | 4728 ----------------- .../lib/x86_64/jblas/jblas/xbyak/xbyak_util.h | 1160 ---- .../test/contrib_ops/matmul_4bits_test.cc | 49 +- .../test/mlas/bench/bench_sqnbitgemm.cpp | 61 - 41 files changed, 753 insertions(+), 24620 deletions(-) create mode 100644 cmake/external/neural_speed.cmake create mode 100644 onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h create mode 100644 onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc create mode 100644 onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h create mode 100644 onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h delete mode 100644 onnxruntime/core/mlas/lib/jblas_defs.h delete mode 100644 onnxruntime/core/mlas/lib/jblas_gemm.cpp delete mode 100644 onnxruntime/core/mlas/lib/jblas_gemm.h delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_parallel.h delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_prologue_a.h delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_prologue_b.h delete mode 100644 
onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_storage.h delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_utils.h delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_wrapper.h delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/kernel_avx2.h delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/kernel_avx512_bf16.h delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/kernel_avx512f.h delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/kernel_jit.h delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/kernel_jit_injector.h delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/kernel_ref.h delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/kernel_wrapper.h delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/xbyak/xbyak.h delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/xbyak/xbyak_bin2hex.h delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/xbyak/xbyak_mnemonic.h delete mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/xbyak/xbyak_util.h diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 712d5d76108aa..7d7304630c00e 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -88,7 +88,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF) option(onnxruntime_USE_SNPE "Build with SNPE support" OFF) option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF) option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) -option(onnxruntime_USE_JBLAS "Build MLAS with JBLAS support" ON) +option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" ON) option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF) option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON) option(onnxruntime_BUILD_CSHARP "Build C# library" OFF) @@ -910,6 +910,10 @@ function(onnxruntime_set_compile_flags target_name) target_compile_definitions(${target_name} PRIVATE USE_CUTLASS) endif() + if(USE_NEURAL_SPEED) + target_compile_definitions(${target_name} PRIVATE ORT_NEURAL_SPEED) + endif() + set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR ON) if (onnxruntime_USE_CUDA) # Suppress a "conversion_function_not_usable" warning in gsl/span @@ -1194,14 +1198,10 @@ if (onnxruntime_USE_DNNL) add_compile_definitions(DNNL_OPENMP) endif() -set(USE_JBLAS FALSE) -if (onnxruntime_USE_JBLAS AND NOT onnxruntime_MINIMAL_BUILD) - if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64") - add_compile_definitions(MLAS_JBLAS) - set(USE_JBLAS TRUE) - elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64") - add_compile_definitions(MLAS_JBLAS) - set(USE_JBLAS TRUE) +if (onnxruntime_USE_NEURAL_SPEED AND NOT onnxruntime_MINIMAL_BUILD) + include(neural_speed) + if (USE_NEURAL_SPEED) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES neural_speed::bestla) endif() endif() diff --git a/cmake/deps.txt b/cmake/deps.txt index ff07803013071..fda27e5e93797 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -54,4 +54,4 @@ tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2 cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.1.0.zip;757f90a795034a89d4f48a79d1f009f7a04c8dee utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156 
extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c -composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299 +composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299 \ No newline at end of file diff --git a/cmake/external/neural_speed.cmake b/cmake/external/neural_speed.cmake new file mode 100644 index 0000000000000..e66e2acfb209a --- /dev/null +++ b/cmake/external/neural_speed.cmake @@ -0,0 +1,18 @@ +if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64") + set(USE_NEURAL_SPEED TRUE) +elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64") + set(USE_NEURAL_SPEED TRUE) +endif() + +if(USE_NEURAL_SPEED) + FetchContent_Declare( + neural_speed + URL https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip + URL_HASH SHA1=65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939 + ) + set(BTLA_USE_OPENMP OFF) + FetchContent_MakeAvailable(neural_speed) + if(NOT neural_speed_POPULATED) + FetchContent_Populate(neural_speed) + endif() +endif() diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index b995b27123218..f89d2150a6830 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -57,15 +57,6 @@ endif() set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas) -function(add_jblas) - add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas) - target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas) - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/jblas_gemm.cpp - ) - set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR OFF) -endfunction() - #TODO: set MASM flags properly function(setup_mlas_source_for_windows) @@ -622,10 +613,6 @@ else() target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) endif() -if(USE_JBLAS) - add_jblas() -endif() - foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) target_include_directories(${mlas_target} PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET}) diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake index f60faa4d39116..b81a5c79ac0cc 100644 --- a/cmake/onnxruntime_providers_cpu.cmake +++ b/cmake/onnxruntime_providers_cpu.cmake @@ -60,6 +60,15 @@ if(NOT onnxruntime_DISABLE_CONTRIB_OPS) "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op_executor.cc" ) endif() + set(onnxruntime_cpu_neural_speed_srcs + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_wrapper.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_defs.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.h" + ) + if(NOT USE_NEURAL_SPEED) + list(REMOVE_ITEM onnxruntime_cpu_contrib_ops_srcs ${onnxruntime_cpu_neural_speed_srcs}) + endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cpu_contrib_ops_srcs}) list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_contrib_ops_srcs}) @@ -144,6 +153,12 @@ if (HAS_BITWISE_INSTEAD_OF_LOGICAL) target_compile_options(onnxruntime_providers PRIVATE 
"-Wno-bitwise-instead-of-logical") endif() +if(NOT onnxruntime_DISABLE_CONTRIB_OPS) + if(USE_NEURAL_SPEED) + onnxruntime_add_include_to_target(onnxruntime_providers neural_speed::bestla) + endif() +endif() + if (MSVC) target_compile_options(onnxruntime_providers PRIVATE "/bigobj") # if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 406c73c95d444..72948c74d7877 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -9,6 +9,9 @@ #include "core/mlas/inc/mlas_q4.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" +#ifdef ORT_NEURAL_SPEED +#include "contrib_ops/cpu/quantization/neural_speed_gemm.h" +#endif namespace onnxruntime { namespace contrib { @@ -24,15 +27,17 @@ class MatMulNBits final : public OpKernel { accuracy_level_{info.GetAttr("accuracy_level")} { ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); - is_asym_ = info.GetInputCount() >= 4; +#ifdef ORT_NEURAL_SPEED const Tensor* tensor_B = nullptr; const Tensor* tensor_scale = nullptr; const Tensor* tensor_zero_point = nullptr; bool B_constant = info.TryGetConstantInput(1, &tensor_B); bool scale_constant = info.TryGetConstantInput(2, &tensor_scale); bool zero_point_constant = info.TryGetConstantInput(3, &tensor_zero_point); + is_asym_ = info.GetInputCount() >= 4; all_constant_ = B_constant && scale_constant; all_constant_ = is_asym_ ? all_constant_ && zero_point_constant : all_constant_; +#endif } Status Compute(OpKernelContext* context) const override; @@ -53,30 +58,34 @@ class MatMulNBits final : public OpKernel { const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_; size_t packed_b_size_{0}; +#ifdef ORT_NEURAL_SPEED bool is_asym_{false}; bool all_constant_{false}; +#endif }; Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; +#ifdef ORT_NEURAL_SPEED if (!all_constant_) { return Status::OK(); } - -#if defined(MLAS_JBLAS) - - auto compt_type = static_cast(accuracy_level_); MLAS_THREADPOOL* pool = NULL; + if (nbits_ != 4) { + return Status::OK(); + } + auto comp_type = static_cast(accuracy_level_); + auto nbits = static_cast(nbits_); if (input_idx == 1) { - packed_b_size_ = MlasNBitsGemmPackBSize(N_, K_, block_size_, static_cast(nbits_), is_asym_, compt_type); + packed_b_size_ = NSNBitsGemmPackBSize(N_, K_, block_size_, nbits, is_asym_, comp_type); if (packed_b_size_ == 0) return Status::OK(); auto qptr = tensor.Data(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); std::memset(packed_b_.get(), 0, packed_b_size_); - MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast(nbits_), - is_asym_, false, compt_type, pool); + NSNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, nbits, is_asym_, false, + comp_type, pool); if (prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); prepacked_weights->buffer_sizes_.push_back(packed_b_size_); @@ -85,8 +94,8 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat } if (input_idx == 2 && packed_b_ != nullptr) { auto sptr = tensor.Data(); - MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, 
K_, block_size_, static_cast(nbits_), - is_asym_, !is_asym_, compt_type, pool); + NSNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, nbits, is_asym_, !is_asym_, + comp_type, pool); if (prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); prepacked_weights->buffer_sizes_.push_back(packed_b_size_); @@ -95,8 +104,8 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat } if (input_idx == 3 && packed_b_ != nullptr) { auto zptr = tensor.Data(); - MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, static_cast(nbits_), - is_asym_, is_asym_, compt_type, pool); + NSNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, nbits, is_asym_, is_asym_, + comp_type, pool); if (prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); prepacked_weights->buffer_sizes_.push_back(packed_b_size_); @@ -104,7 +113,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat is_packed = true; } -#else // defined(MLAS_JBLAS) +#else // defined(ORT_NEURAL_SPEED) if (input_idx == 1) { packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_); @@ -119,7 +128,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat is_packed = true; } -#endif // defined(MLAS_JBLAS) +#endif // defined(ORT_NEURAL_SPEED) return Status::OK(); } @@ -127,9 +136,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; - -#if defined(MLAS_JBLAS) - +#ifdef ORT_NEURAL_SPEED // Pack three tensors into one buffer if (input_idx == 1) { used_shared_buffers = true; @@ -144,14 +151,14 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep packed_b_ = std::move(prepacked_buffers[0]); } -#else // defined(MLAS_JBLAS) +#else // defined(ORT_NEURAL_SPEED) if (input_idx == 1) { used_shared_buffers = true; packed_b_ = std::move(prepacked_buffers[0]); } -#endif // defined(MLAS_JBLAS) +#endif // defined(ORT_NEURAL_SPEED) return Status::OK(); } @@ -160,9 +167,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const Tensor* a = ctx->Input(0); const auto* a_data = a->Data(); - -#if defined(MLAS_JBLAS) - +#ifdef ORT_NEURAL_SPEED if (packed_b_.get()) { TensorShape b_shape({static_cast(N_), static_cast(K_)}); @@ -181,7 +186,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const size_t N = static_cast(helper.N()); const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(false); - std::vector gemm_params(max_len); + std::vector gemm_params(max_len); AllocatorPtr allocator; auto status = ctx->GetTempSpaceAllocator(&allocator); ORT_RETURN_IF_ERROR(status); @@ -192,15 +197,14 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { gemm_params[i].C = y_data + helper.OutputOffsets()[i]; gemm_params[i].ldc = N; } - auto ws_size = MlasSQNBitsGemmBatchPackedBWorkspaceSize(M, N, K, max_len, gemm_params.data()); + auto ws_size = NSSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data()); // workspace for activation process(dynamic quantization and others) auto ws_ptr = IAllocator::MakeUniquePtr(allocator, ws_size); - MlasSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(), - thread_pool); + NSSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), 
ws_ptr.get(), thread_pool); return Status::OK(); } -#endif // defined(MLAS_JBLAS) +#endif // defined(ORT_NEURAL_SPEED) const Tensor* scales = ctx->Input(2); const Tensor* zero_points = ctx->Input(3); diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h new file mode 100644 index 0000000000000..864abffd131fe --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h @@ -0,0 +1,45 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +--*/ + +#pragma once + +#include "contrib_ops/cpu/quantization/neural_speed_wrapper.h" + +namespace bestla { + +using tAVX512F = gemm::SCoreRowNAvx512f<48, 8>; +using tAMX_BF16 = gemm::HCoreRowNAmxbf16<64, 16>; +using tAVX512_FP16 = gemm::HCoreRowNAvx512fp16<96, 8>; +using tAVX_VNNI = gemm::ICoreRowNAvxvnni<24, 4>; +using tAVX512_VNNI = gemm::ICoreRowNAvx512vnni<48, 8>; +using tAMX_INT8_US = gemm::ICoreRowNAmxint8<64, 16>; +using tAMX_INT8_SS = gemm::ICoreRowNAmxint8SS<64, 16>; +using tAVX2 = gemm::SCoreRowNAvx2<24, 4>; +using tAVX_VNNI_KBlock = gemm::ICoreRowNAvxvnniKBlock<24, 2>; +using tAVX512_VNNI_KBlock = gemm::ICoreRowNAvx512vnniKBlock<48, 4>; +using tAMX_INT8_US_KBlock = gemm::ICoreRowNAmxint8KBlock<48, 16>; +using tAMX_INT8_SS_KBlock = gemm::ICoreRowNAmxint8SSKBlock<48, 16>; + +template +using tWeiNInt = prologue_b::gemm::WeightKBlockNInteger; +template +using tWeiNFloat = prologue_b::gemm::WeightKBlockNFloat; + +class ORTThreading : public parallel::IThreading { + public: + explicit ORTThreading(void* tp); + void parallel_for(const parallel::thread_func& func) const override; + void set_threads(int nthreads) override { + (void)(nthreads); + assert(0); + } + void sync() const override { assert(0); } + void* mTp; +}; + +} // namespace bestla diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc new file mode 100644 index 0000000000000..73aaa4ae61a6e --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc @@ -0,0 +1,438 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + neural_speed_gemm.cpp + +Abstract: + + GEMM template combinations of neural_speed. 
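+
+    Implements the ORTThreading adapter declared in neural_speed_defs.h (bridging
+    the onnxruntime thread pool to bestla's parallel::IThreading) and dispatches
+    the 4-bit GEMM kernels on the detected CPU ISA: AVX2/AVX-512F cores for fp32
+    compute, AVX-VNNI/AVX512-VNNI/AMX-INT8 KBlock cores for int8 compute.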
+--*/ + +#include "contrib_ops/cpu/quantization/neural_speed_defs.h" +#include "contrib_ops/cpu/quantization/neural_speed_gemm.h" +#include "core/platform/threadpool.h" + +using ThreadPool = onnxruntime::concurrency::ThreadPool; + +namespace bestla { + +ORTThreading::ORTThreading(void* tp) + : IThreading(ThreadPool::DegreeOfParallelism(reinterpret_cast(tp))), mTp(tp) {} + +void ORTThreading::parallel_for(const parallel::thread_func& func) const { + ThreadPool::TrySimpleParallelFor(reinterpret_cast(mTp), mThreadNum, + [&](ptrdiff_t tid) { func(static_cast(tid)); }); +} + +template +static void NSSQ4GemmCompF32(size_t M, size_t N, size_t K, const float* A, size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc, int8_t* WorkSpace, + parallel::IThreading* th) { + auto M_ = static_cast(M); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto lda_ = static_cast(lda); + auto ldc_ = static_cast(ldc); + utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize); + if (M <= 16) { + using Parallel = parallel::gemm::SchedulerKBlock; + using Launcher = + wrapper::gemm::LauncherKBlock; + static Launcher kernel; + auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); + if (B->IsAsym()) { + reduceA.assign(WorkSpace); + ORTThreading single(nullptr); + kernel.mProA.reduce({A, lda_, &reduceA}, M_, K_, B->mBlockSize, &single); + } + typename Launcher::Param args{gp, + {A, lda_, &reduceA}, + {B}, + {B->template SPtr(), B->SDtype(), B->CStep(), B->template ZPtr(), + reduceA.template RPtr(), reduceA.lda}, + {C, ldc_, nullptr}}; + parallel::GemmRun(kernel, args, th); + } else { + using Parallel = parallel::gemm::SchedulerBase; + using Launcher = + wrapper::gemm::LauncherBase; + static Launcher kernel; + typename Launcher::Param args{gp, {A, lda_}, {B}, {C, ldc_, nullptr}}; + parallel::GemmRun(kernel, args, th); + } +} + +template +static void NSSQ4GemmCompInt8(size_t M, size_t N, size_t K, const float* A, size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc, int8_t* WorkSpace, + parallel::IThreading* th) { + using Parallel = parallel::gemm::SchedulerKBlockS; + using Launcher = + wrapper::gemm::LauncherIntKBlock; + auto M_ = static_cast(M); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto lda_ = static_cast(lda); + auto ldc_ = static_cast(ldc); + static Launcher kernel; + auto quanA = kernel.mProA.createStorage(M_, K_, B->mBlockSize, B->IsAsym()); + quanA.assign(WorkSpace); + if (M <= 16) { + ORTThreading single(nullptr); + kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, &single); + } else { + kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, th); + } + utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize); + typename Launcher::Param args{gp, {A, lda_, &quanA}, {B}, {C, ldc_, nullptr}}; + parallel::GemmRun(kernel, args, th); +} + +template +static size_t NSSQ4GemmCompF32WorkspaceSize(size_t M, size_t N, size_t K, const float* A, size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc) { + auto M_ = static_cast(M); + auto K_ = static_cast(K); + (void)(A); + (void)(N); + (void)(C); + (void)(lda); + (void)(ldc); + if (M <= 16) { + using ProA = prologue_a::gemm::ActivationKBlockBaseF32; + static ProA proA; + if (B->IsAsym()) { + auto reduceA = proA.createStorage(M_, K_, B->mBlockSize); + return reduceA.mSize; + } + return 0; + } else { + // using ProA = prologue_a::gemm::ActivationBase; + return 0; + } +} + +template +static size_t NSSQ4GemmCompInt8WorkspaceSize(size_t M, size_t N, size_t K, const float* A, 
size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc) { + (void)(N); + (void)(lda); + (void)(ldc); + (void)(A); + (void)(C); + using ProA = prologue_a::gemm::ActivationF32KBlockQuantize; + static ProA proA; + auto quanA = + proA.createStorage(static_cast(M), static_cast(K), static_cast(B->mBlockSize), B->IsAsym()); + return quanA.mSize; +} + +} // namespace bestla + +using namespace bestla; + +static bool NSSQ4GemmBatchDriver(size_t M, size_t N, size_t K, size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, int8_t* WorkSpace, + void* ThreadPool) { + GetCPUDevice(); + bestla::ORTThreading orth(ThreadPool); + bool processed = true; + for (size_t i = 0; i < BatchN; i++) { + auto ptr = bestla::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); + auto uptr = std::unique_ptr(ptr); + if (ptr) { + auto NTile = gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); + auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); + auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); + auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); + if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { + auto kptr = reinterpret_cast(ptr); + auto BlkSize = kptr->mBlockSize; + if (btype == gemm::CompType::tFP32 && PackRow == 1) { + if (NTile == bestla::tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + bestla::NSSQ4GemmCompF32(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth); + } else if (NTile == bestla::tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + bestla::NSSQ4GemmCompF32(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, + DataParams[i].ldc, WorkSpace, &orth); + } + } + if (btype == gemm::CompType::tS8 && PackRow == 4) { + if (NTile == bestla::tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && + BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, + &orth); + } else if (NTile == bestla::tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && + BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, + &orth); + } else if (NTile == bestla::tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && + BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth); + } + } + } + } else { + processed = false; + break; + } + } + return processed; +} + +static size_t NSSQ4GemmBatchWorkspaceSize(size_t M, size_t N, size_t K, size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) { + GetCPUDevice(); + size_t size = 0; + for (size_t i = 0; i < BatchN; i++) { + auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); + auto uptr = std::unique_ptr(ptr); + if (ptr) { + if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { + auto kptr = reinterpret_cast(ptr); + auto NTile = + gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); + auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); + auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); + auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); + auto BlkSize = kptr->mBlockSize; + if (btype == 
gemm::CompType::tFP32 && PackRow == 1) { + if (NTile == tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + size = std::max(NSSQ4GemmCompF32WorkspaceSize(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc), + size); + } else if (NTile == tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + size = std::max(NSSQ4GemmCompF32WorkspaceSize(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc), + size); + } + } + if (btype == gemm::CompType::tS8 && PackRow == 4) { + if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + size = std::max(NSSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && + BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + size = std::max(NSSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + size = std::max(NSSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } + } + } + } + } + return size; +} + +template +static size_t NSQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym) { + static T proB; + auto stor = proB.createStorage(static_cast(N), static_cast(K), static_cast(block_size), + BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, BTLA_DTYPE::BF16, isAsym); + // TODO(Yu) support more scale dtype + return stor.mSize; +} + +static bool NSQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* ThreadPool) { + auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(PackedBuf); + auto uptr = std::unique_ptr(ptr); + ORTThreading orth(ThreadPool); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto ldb_ = static_cast(ldb); + GetCPUDevice(); + if (ptr) { + auto NTile = gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); + auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); + auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); + auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); + if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { + auto wptr = reinterpret_cast(ptr); + auto BlkSize = wptr->mBlockSize; + if (btype == gemm::CompType::tFP32 && PackRow == 1) { + if (NTile == tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } else if (NTile == tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } + } + if (btype == gemm::CompType::tS8 && PackRow == 4) { + if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && + BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, 
FpData, ldb_, &orth); + } + } + } + return true; + } + return false; +} + +template +static void NSQ4GemmPackBImpl(void* PackedBuf, size_t BlkSize, const uint8_t* QData, const float* Scale, + const uint8_t* Zp, size_t N, size_t K, bool IsAsym, bool lastCall, size_t ldb, + void* ThreadPool) { + static T proB; + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto stor = proB.createStorage(N_, K_, static_cast(BlkSize), BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, + BTLA_DTYPE::BF16, IsAsym); + stor.assign(reinterpret_cast(PackedBuf)); + ORTThreading orth(ThreadPool); + proB.packNbitsWeightQ4(N_, K_, IsAsym, QData, static_cast(ldb), Scale, Zp, &stor, &orth); + if (lastCall) { + proB.reduceWeight(&stor, &orth); + } +} + +static size_t NSQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, NS_SQNBIT_COMPUTE_TYPE CompType) { + GetCPUDevice(); + if (K % BlkSize != 0) { + return 0; + } + // from low precision to high precision + switch (CompType) { + case NSCompInt8: + if (!isAsym) { // asym int8 is not optimized, so fall through to others. + if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + } + [[fallthrough]]; + case NSCompBf16: + case NSCompFp16: + case NSCompFp32: + case NSCompUndef: + if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + [[fallthrough]]; + default: + return 0; + } +} + +static bool NSQ4GemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, + size_t K, size_t ldb, size_t BlkSize, bool isAsym, bool lastCall, + NS_SQNBIT_COMPUTE_TYPE CompType, void* ThreadPool) { + GetCPUDevice(); + // explicit statement fall through. + switch (CompType) { + case NSCompInt8: + if (!isAsym) { // asym int8 is not optimized, so fall through to others. 
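+        // Pack for the widest symmetric int8 kernel the CPU supports (AMX-INT8,
+        // then AVX512-VNNI, then AVX-VNNI) whose KTILE divides BlkSize; if none
+        // matches, or the weight is asymmetric, fall through and pack for the
+        // fp32 cores (AVX-512F, AVX2) handled by the cases below.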
+ if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + NSQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + NSQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, + K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + } + [[fallthrough]]; + case NSCompBf16: + case NSCompFp16: + case NSCompFp32: + case NSCompUndef: + if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, + lastCall, ldb, ThreadPool); + return true; + } + if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, + ldb, ThreadPool); + return true; + } + [[fallthrough]]; + default: + return false; + } +} + +size_t NSNBitsGemmPackBSize(size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, + NS_SQNBIT_COMPUTE_TYPE CompType) { + if (nbits == 4) { + auto jsize = NSQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType); + if (jsize) { + return jsize; + } + } + return 0; +} + +void NSNBitsGemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K, + size_t ldb, size_t BlkSize, int nbits, bool isAsym, bool lastCall, + NS_SQNBIT_COMPUTE_TYPE CompType, void* ThreadPool) { + if (nbits == 4) { + if (NSQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall, CompType, ThreadPool)) { + return; + } + } +} + +void NSNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* ThreadPool) { + // only nbits=4 can be packed, so not necessary to check the nbits in DataParams + if (NSQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool)) { + return; + } +} + +size_t NSSQNBitsGemmBatchWorkspaceSize(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) { + // only nbits=4 can be packed, so not necessary to check the nbits in DataParams + return NSSQ4GemmBatchWorkspaceSize(M, N, K, BatchN, DataParams); +} + +void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, void* WorkSpace, + void* ThreadPool) { + // only nbits=4 can be packed, so not necessary to check the nbits in DataParams + if (NSSQ4GemmBatchDriver(M, N, K, BatchN, DataParams, reinterpret_cast(WorkSpace), ThreadPool)) { + // PackedWeight is created by bestla + return; + } +} diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h new file mode 100644 index 0000000000000..ebcb3027a209f --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h @@ -0,0 +1,129 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + neural_speed_gemm.h + +Abstract: + + Prepack-weight GEMM APIs of neural_speed. 
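+
+    Typical use (illustrative sketch only; packed_b, q_data, scales, zero_points,
+    workspace and the thread-pool handle tp are placeholders, the NSCompInt8
+    choice is just an example, and error handling is omitted):
+
+      size_t size = NSNBitsGemmPackBSize(N, K, block_size, 4, is_asym, NSCompInt8);
+      // size == 0 means the parameter combination is not supported.
+      NSNBitsGemmPackB(packed_b, q_data, nullptr, nullptr, N, K, K, block_size, 4,
+                       is_asym, /*last_call=*/false, NSCompInt8, tp);
+      NSNBitsGemmPackB(packed_b, nullptr, scales, nullptr, N, K, K, block_size, 4,
+                       is_asym, /*last_call=*/!is_asym, NSCompInt8, tp);
+      if (is_asym)
+        NSNBitsGemmPackB(packed_b, nullptr, nullptr, zero_points, N, K, K, block_size,
+                         4, is_asym, /*last_call=*/true, NSCompInt8, tp);
+
+      NS_SQNBITS_GEMM_DATA_PACKED_PARAMS params{A, packed_b, C, lda, ldc};
+      size_t ws = NSSQNBitsGemmBatchWorkspaceSize(M, N, K, 1, &params);
+      // workspace must be at least ws bytes.
+      NSSQNBitsGemmBatchPackedB(M, N, K, 1, &params, workspace, tp);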
+--*/ + +#pragma once + +#include +#include + +/** + * @brief Define compute types of block quantization + */ +enum NS_SQNBIT_COMPUTE_TYPE { + NSCompUndef = 0, /*!< undef */ + NSCompFp32 = 1, /*!< input fp32, accumulator fp32 */ + NSCompFp16 = 2, /*!< input fp16, accumulator fp16 */ + NSCompBf16 = 3, /*!< input bf16, accumulator fp32 */ + NSCompInt8 = 4 /*!< input int8, accumulator int32 */ +}; + +/** + * @brief Data parameters for NBits GEMM routine + * C = A * B + * A, C must be a float32 matrix + * B must be a packed nbits blob + * All except C are [in] parameters + */ +struct NS_SQNBITS_GEMM_DATA_PACKED_PARAMS { + const float* A = nullptr; /**< address of A (float32 matrix)*/ + const void* B = nullptr; /**< address of B (packed nbits blob)*/ + float* C = nullptr; /**< address of result matrix */ + size_t lda = 0; /**< leading dimension of A */ + size_t ldc = 0; /**< leading dimension of C*/ +}; + +/** + * @brief Compute the byte size of the parameter combination + * + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param block_size size of the block to quantize, elements from the same block share the same + * scale and zero point + * @param nbits number of bits used for weight quantization + * @param is_asym flag for asymmetric quantization + * @param comp_type specify input data type and accumulator data type + * @return size of the packing buffer, 0 if the operation is not yet supported. + */ +size_t NSNBitsGemmPackBSize(size_t N, size_t K, size_t block_size, int nbits, bool is_asym, + NS_SQNBIT_COMPUTE_TYPE comp_type); + +/** + * @brief Prepack tensor data from n-bit quantized data, scale and zero point buffers. + * + * @param PackedBuf packed data buffer + * @param QData quantized data buffer + * @param Scale scale pointer + * @param Zp zero point pointer + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param ldb leading dimension of B + * @param block_size size of the block to quantize, elements from the same block share the same + * scale and zero point + * @param nbits number of bits used for weight quantization (default 4) + * @param is_asym flag for asymmetric quantization + * @param comp_type specify input data type and accumulator data type + * @param last_call flag to activate the epilogue process of packB. OpKernel::PrePack will query input tensor + * one by one: QData, Scale, Zp (if is_asym is true). But kernel prefers to pack all tensors into one blob data where + * they can share the common attributes like: block_size. Meanwhile, kernel has some pre-computations to speed up + * inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale + * (is_asym is false) and Zp(is_asym is true). + * @param thread_pool + */ +void NSNBitsGemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K, + size_t ldb, size_t block_size, int nbits, bool is_asym, bool last_call, + NS_SQNBIT_COMPUTE_TYPE comp_type, void* thread_pool); + +/** + * @brief Unpack and dequantize to fp32 + * + * @param FpData unpacked float32 data + * @param PackedBuf quantized and packed data + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. 
+ * @param ldb leading dimension of B + * @param thread_pool + */ +void NSNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* thread_pool); + +/** + * @brief Get the workspace size required by computation. + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @return Workspace size in bytes + */ +size_t NSSQNBitsGemmBatchWorkspaceSize(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams); + +/** + * @brief Batched GEMM: C = A * B + * A, C must be a float32 matrix + * B must be a packed nbits blob + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] WorkSpace temporary buffer + * @param[in] ThreadPool + * @return + */ +void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, void* WorkSpace, + void* ThreadPool = nullptr); diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h new file mode 100644 index 0000000000000..d3902f9bd68c7 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h @@ -0,0 +1,39 @@ +//----------------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+// +//----------------------------------------------------------------------------- +#pragma once +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wsign-compare" +#pragma GCC diagnostic ignored "-Wmissing-field-initializers" +#pragma GCC diagnostic ignored "-Wunused-variable" +#pragma GCC diagnostic ignored "-Wunused-value" +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#pragma GCC diagnostic ignored "-Wunused-function" +#pragma GCC diagnostic ignored "-Wuninitialized" +#pragma GCC diagnostic ignored "-Wclass-memaccess" +#pragma GCC diagnostic ignored "-Wunused-but-set-variable" +#pragma GCC diagnostic ignored "-Wunused-but-set-parameter" + +#elif defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4457) +#pragma warning(disable : 4189) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4267) +#pragma warning(disable : 4702) +#endif + +#include "bestla/bestla_prologue_a.h" +#include "bestla/bestla_wrapper.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#elif defined(_MSC_VER) +#pragma warning(pop) +#endif diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index bc0bfc92c85a0..047011e70bd4d 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -183,133 +183,3 @@ MlasSQNBitGemmPackQuantBData( void* PackedQuantBData, MLAS_THREADPOOL* ThreadPool = nullptr ); - -/** - * @brief Data parameters for NBits GEMM routine - * C = A * B - * A, C must be a float32 matrix - * B must be a packed nbits blob - * All except C are [in] parameters - */ -struct MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS { - const float* A = nullptr; /**< address of A (float32 matrix)*/ - const void* B = nullptr; /**< address of B (packed nbits blob)*/ - float* C = nullptr; /**< address of result matrix */ - size_t lda = 0; /**< leading dimension of A */ - size_t ldc = 0; /**< leading dimension of C*/ -}; - -/** - * @brief Compute the byte size of the parameter combination - * - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param block_size size of the block to quantize, elements from the same block share the same - * scale and zero point - * @param nbits number of bits used for weight quantization - * @param is_asym flag for asymmetric quantization - * @param comp_type specify input data type and accumulator data type - * @return size of the packing buffer, 0 if the operation is not yet supported. - */ -size_t MLASCALL -MlasNBitsGemmPackBSize( - size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_SQNBIT_COMPUTE_TYPE comp_type -); - -/** - * @brief Prepack tensor data from n-bit quantized data, scale and zero point buffers. - * - * @param PackedBuf packed data buffer - * @param QData quantized data buffer - * @param Scale scale pointer - * @param Zp zero point pointer - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param ldb leading dimension of B - * @param block_size size of the block to quantize, elements from the same block share the same - * scale and zero point - * @param nbits number of bits used for weight quantization (default 4) - * @param is_asym flag for asymmetric quantization - * @param comp_type specify input data type and accumulator data type - * @param last_call flag to activate the epilogue process of packB. 
OpKernel::PrePack will query input tensor - * one by one: QData, Scale, Zp (if is_asym is true). But kernel prefers to pack all tensors into one blob data where - * they can share the common attributes like: block_size. Meanwhile, kernel has some pre-computations to speed up - * inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale - * (is_asym is false) and Zp(is_asym is true). - * @param thread_pool - */ -void MLASCALL -MlasNBitsGemmPackB( - void* PackedBuf, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - size_t ldb, - size_t block_size, - int nbits, - bool is_asym, - bool last_call, - MLAS_SQNBIT_COMPUTE_TYPE comp_type, - MLAS_THREADPOOL* thread_pool -); - -/** - * @brief Unpack and dequantize to fp32 - * - * @param FpData unpacked float32 data - * @param PackedBuf quantized and packed data - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param ldb leading dimension of B - * @param thread_pool - */ -void MLASCALL -MlasNBitsGemmUnPackB( - float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* thread_pool -); - -/** - * @brief Get the workspace size required by computation. - * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[inout] DataParams An array (size BatchN) of parameter blocks - * @return Workspace size in bytes - */ -size_t MLASCALL -MlasSQNBitsGemmBatchPackedBWorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams -); - -/** - * @brief Batched GEMM: C = A * B - * A, C must be a float32 matrix - * B must be a packed nbits blob - * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[inout] DataParams An array (size BatchN) of parameter blocks - * @param[in] WorkSpace temporary buffer - * @param[in] ThreadPool - * @return - */ -void MLASCALL -MlasSQNBitsGemmBatchPackedB( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, - void* WorkSpace, - MLAS_THREADPOOL* ThreadPool = nullptr -); diff --git a/onnxruntime/core/mlas/lib/jblas_defs.h b/onnxruntime/core/mlas/lib/jblas_defs.h deleted file mode 100644 index 9cd1711a3ffd2..0000000000000 --- a/onnxruntime/core/mlas/lib/jblas_defs.h +++ /dev/null @@ -1,73 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - ---*/ - -#pragma once - -#include "jblas/jit_blas_prologue_b.h" -#include "jblas/jit_blas_wrapper.h" - -namespace jblas -{ - -/* -Name conversion explaination: -Fp32: comp type, determined by GemmCore, can be any jblas::gemm::SCorexxx(float GemmCore) -S4: weight dtype, determined by jblas::prologue_b::gemm::WeightKBlockS4(also support other integer and float weight -classes) -F32F32: input/output dtype, determined by jblas::prologue_a::gemm::ActivationKBlockBaseF32 and -jblas::epilogue::gemm::AccumulatorWriteBackFp32. - -Tips: jblas::epilogue::gemm::CompFp32BlockEpilogue is a fixed class for all fp32 accumulator GemmCores. 
-*/ -template -using tLauncher_Fp32_S4_F32F32 = jblas::wrapper::gemm::LauncherKBlock< - GemmCore_T::ISA, - GemmCore_T, - jblas::prologue_a::gemm::ActivationKBlockBaseF32, - jblas::prologue_b::gemm::WeightKBlockS4, - jblas::epilogue::gemm::CompFp32BlockEpilogue, - jblas::epilogue::gemm::AccumulatorWriteBackFp32>; - -/* -Name conversion explaination: -Int8: comp type, determined by GemmCore, can be any jblas::gemm::ICorexxx(integer GemmCore) -S4: weight dtype, determined by jblas::prologue_b::gemm::WeightKBlockS4(support integer weight classes only) -F32F32: input/output dtype, determined by jblas::prologue_a::gemm::ActivationKBlockBaseF32 and -jblas::epilogue::gemm::AccumulatorWriteBackFp32. - -Tips: jblas::epilogue::gemm::CompInt8BlockEpilogue is a fixed class for all int32 accumulator GemmCores. -*/ -template -using tLauncher_Int8_S4_F32F32 = jblas::wrapper::gemm::LauncherKBlock< - GemmCore_T::ISA, - GemmCore_T, - jblas::prologue_a::gemm::ActivationF32KBlockQuantize, - jblas::prologue_b::gemm::WeightKBlockS4, - jblas::epilogue::gemm::CompInt8BlockEpilogue, - jblas::epilogue::gemm::AccumulatorWriteBackFp32>; - -using tAVX512F = jblas::gemm::SCoreRowNAvx512f<48, 8>; -using tAMX_BF16 = jblas::gemm::HCoreRowNAmxbf16<64, 16>; -using tAVX512_FP16 = jblas::gemm::HCoreRowNAvx512fp16<96, 8>; -using tAVX_VNNI = jblas::gemm::ICoreRowNAvxvnni<48, 2>; // TODO(Yu) use 24x4 for higher efficiency -using tAVX512_VNNI = jblas::gemm::ICoreRowNAvx512vnni<48, 8>; -using tAMX_INT8_US = jblas::gemm::ICoreRowNAmxint8<64, 16>; -using tAMX_INT8_SS = jblas::gemm::ICoreRowNAmxint8SS<64, 16>; -using tAVX2 = jblas::gemm::SCoreRowNAvx2<48, 2>; // TODO(Yu) use 24x4 for higher efficiency - -class ORTThreading : public jblas::parallel::IThreading -{ - public: - ORTThreading(void* tp); - void parallel_for(const jblas::parallel::thread_func& func) override; - void set_threads(int nthreads) override { assert(0); } - void sync() override { assert(0); } - void* mTp; -}; - -} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/jblas_gemm.cpp b/onnxruntime/core/mlas/lib/jblas_gemm.cpp deleted file mode 100644 index f3cae3186c28e..0000000000000 --- a/onnxruntime/core/mlas/lib/jblas_gemm.cpp +++ /dev/null @@ -1,534 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - jblas_gemm.cpp - -Abstract: - - Currently only support Q4 gemm. 
---*/ - -#include "jblas_gemm.h" - -#include "jblas_defs.h" -#include "mlasi.h" - -using namespace jblas; - -jblas::ORTThreading::ORTThreading(void* tp) - : IThreading(MLAS_THREADPOOL::DegreeOfParallelism(reinterpret_cast(tp))), mTp(tp) -{ -} - -void -jblas::ORTThreading::parallel_for(const jblas::parallel::thread_func& func) -{ - MlasTrySimpleParallel(reinterpret_cast(mTp), mThreadNum, [&](ptrdiff_t tid) { - func(static_cast(tid)); - }); -} - -template -static void -JblasSQ4GemmCompF32( - const size_t M, - const size_t N, - const size_t K, - const float* A, - const size_t lda, - jblas::storage::gemm::StorageWeightKBlockS4* B, - float* C, - const size_t ldc, - int8_t* WorkSpace, - jblas::parallel::IThreading* th -) -{ - auto M_ = static_cast(M); - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto lda_ = static_cast(lda); - auto ldc_ = static_cast(ldc); - if (M <= 16) { - using Parallel = jblas::parallel::gemm::SchedulerKBlock; - using Launcher = tLauncher_Fp32_S4_F32F32; - static Launcher kernel; - auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); - if (B->mIsAsym) { - reduceA.assign(WorkSpace); - ORTThreading single(nullptr); - kernel.mProA.reduce({A, lda_}, &reduceA, M_, K_, &single); - } - typename Launcher::BEpiParam blkargs{ - B->template SPtr(), B->mScaT, B->mCStep, B->template ZPtr(), - reduceA.template get(), reduceA.lda}; - - typename Launcher::Param args{M_, N_, K_, B->mBlockSize, {A, lda_}, {B}, blkargs, {C, ldc_}}; - jblas::parallel::GemmKBlockRun(kernel, args, th); - } else { - using Parallel = jblas::parallel::gemm::SchedulerBase; - using Launcher = jblas::wrapper::gemm::LauncherBase< - GemmCore_T::ISA, GemmCore_T, jblas::prologue_a::gemm::ActivationBase, - jblas::prologue_b::gemm::WeightKBlockS4, jblas::epilogue::gemm::AccumulatorWriteBackFp32>; - static Launcher kernel; - - typename Launcher::Param args{M_, N_, K_, {A, lda_}, {B}, {C, ldc_}}; - jblas::parallel::GemmBaseRun(kernel, args, th); - } -} - -template -static void -JblasSQ4GemmCompInt8( - const size_t M, - const size_t N, - const size_t K, - const float* A, - const size_t lda, - jblas::storage::gemm::StorageWeightKBlockS4* B, - float* C, - const size_t ldc, - int8_t* WorkSpace, - jblas::parallel::IThreading* th -) -{ - using Parallel = jblas::parallel::gemm::SchedulerKBlock; - using Launcher = tLauncher_Int8_S4_F32F32; - auto M_ = static_cast(M); - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto lda_ = static_cast(lda); - auto ldc_ = static_cast(ldc); - static Launcher kernel; - auto quanA = kernel.mProA.createStorage(M_, K_, B->mBlockSize, B->mIsAsym); - quanA.assign(WorkSpace); - if (M <= 16) { - ORTThreading single(nullptr); - kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, &single); - } else { - kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, th); - } - typename Launcher::Param args{ - M_, - N_, - K_, - B->mBlockSize, - {A, lda_, &quanA}, - {B}, - {B->template SPtr(), B->mScaT, B->mCStep, quanA.template SPtr(), quanA.mCStep, - quanA.template ZPtr(), B->template RPtr(), B->mRedT, B->template ZPtr(), - quanA.template RPtr(), B->mBlockSize}, - {C, ldc_}}; - jblas::parallel::GemmKBlockRun(kernel, args, th); -} - -bool -JblasSQ4GemmBatchDriver( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, - int8_t* WorkSpace, - MLAS_THREADPOOL* ThreadPool -) -{ - GetCPUDevice(); - ORTThreading orth(ThreadPool); - bool processed = true; - for (size_t i = 0; i < BatchN; i++) { - auto ptr = 
jblas::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); - auto uptr = std::unique_ptr(ptr); - if (ptr) { - if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { - auto kptr = reinterpret_cast(ptr); - auto coretype = ptr->mCoreId; - auto NTile = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT - ); - auto CType = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT - ); - if (CType == uint32_t(gemm::CompType::COMP_FP32)) { - if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { - JblasSQ4GemmCompF32( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { - JblasSQ4GemmCompF32( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } - } - if (CType == uint32_t(gemm::CompType::COMP_INT8_US_INT32)) { - if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) { - JblasSQ4GemmCompInt8( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) { - JblasSQ4GemmCompInt8( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) { - JblasSQ4GemmCompInt8( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } - } - if (CType == uint32_t(gemm::CompType::COMP_INT8_SS_INT32)) { - if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) { - JblasSQ4GemmCompInt8( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } - } - } - } else { - processed = false; - break; - } - } - return processed; -} - -template -static size_t -JblasSQ4GemmCompF32WorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const float* A, - const size_t lda, - jblas::storage::gemm::StorageWeightKBlockS4* B, - float* C, - const size_t ldc -) -{ - auto M_ = static_cast(M); - auto K_ = static_cast(K); - (void)(N); - (void)(lda); - (void)(ldc); - if (M <= 16) { - using Launcher = tLauncher_Fp32_S4_F32F32; - static Launcher kernel; - if (B->mIsAsym) { - auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); - return reduceA.mSize; - } - return 0; - } else { - using Launcher = jblas::wrapper::gemm::LauncherBase< - GemmCore_T::ISA, GemmCore_T, jblas::prologue_a::gemm::ActivationBase, - jblas::prologue_b::gemm::WeightKBlockS4, jblas::epilogue::gemm::AccumulatorWriteBackFp32>; - static Launcher kernel; - return 0; - } - return 0; -} - -template -static size_t -JblasSQ4GemmCompInt8WorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const float* A, - const size_t lda, - jblas::storage::gemm::StorageWeightKBlockS4* B, - float* C, - const size_t ldc -) -{ - using Parallel = jblas::parallel::gemm::SchedulerKBlock; - using Launcher = tLauncher_Int8_S4_F32F32; - static Launcher kernel; - (void)(N); - (void)(lda); - (void)(ldc); - auto quanA = kernel.mProA.createStorage( - static_cast(M), static_cast(K), static_cast(B->mBlockSize), B->mIsAsym - ); - return quanA.mSize; -} - -size_t -JblasSQ4GemmBatchWorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const 
MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams -) -{ - GetCPUDevice(); - size_t size = 0; - for (size_t i = 0; i < BatchN; i++) { - auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); - auto uptr = std::unique_ptr(ptr); - if (ptr) { - if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { - auto kptr = reinterpret_cast(ptr); - auto coretype = ptr->mCoreId; - auto NTile = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT - ); - auto CType = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT - ); - if (CType == uint32_t(gemm::CompType::COMP_FP32)) { - if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { - size = std::max( - JblasSQ4GemmCompF32WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { - size = std::max( - JblasSQ4GemmCompF32WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } - } - if (CType == uint32_t(gemm::CompType::COMP_INT8_US_INT32)) { - if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) { - size = std::max( - JblasSQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) { - size = std::max( - JblasSQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) { - size = std::max( - JblasSQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } - } - if (CType == uint32_t(gemm::CompType::COMP_INT8_SS_INT32)) { - if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) { - size = std::max( - JblasSQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } - } - } - } - } - return size; -} - -template -static size_t -JblasQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym) -{ - static T launcher; - auto stor = launcher.mProB.createStorage( - static_cast(N), static_cast(K), static_cast(block_size), JBLAS_DTYPE::S4_CLIP, JBLAS_DTYPE::F32, - JBLAS_DTYPE::BF16, isAsym - ); - // TODO(Yu) support more scale dtype - return stor.mSize; -} - -size_t -JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType) -{ - GetCPUDevice(); - if (K % BlkSize != 0) { - return 0; - } - // from low precision to high precision - switch (CompType) { - case CompInt8: - if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) { - return JblasQ4BuSize>(BlkSize, N, K, isAsym); - } - if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI::KTILE == 0) { - return JblasQ4BuSize>(BlkSize, N, K, isAsym); - } - if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI::KTILE == 0) { - return JblasQ4BuSize>(BlkSize, N, K, isAsym); - } - case CompBf16: - case CompFp16: - case CompFp32: - case CompUndef: - if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { - return JblasQ4BuSize>(BlkSize, N, K, isAsym); - } - if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { - return JblasQ4BuSize>(BlkSize, N, K, isAsym); - } - break; - default: - return 0; - } - return 0; -} - -template -static void 
-JblasQ4GemmPackBImpl( - void* PackedBuf, - size_t BlkSize, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - bool IsAsym, - bool lastCall, - size_t ldb, - MLAS_THREADPOOL* ThreadPool -) -{ - static T JblasKernel; - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto stor = JblasKernel.mProB.createStorage( - N_, K_, static_cast(BlkSize), JBLAS_DTYPE::S4_CLIP, JBLAS_DTYPE::F32, JBLAS_DTYPE::BF16, IsAsym - ); - stor.assign(reinterpret_cast(PackedBuf)); - ORTThreading orth(ThreadPool); - JblasKernel.mProB.packNbitsWeight(N_, K_, IsAsym, QData, static_cast(ldb), Scale, Zp, &stor, &orth); - if (lastCall) { - JblasKernel.mProB.reduceWeight(&stor, &orth); - } -} - -bool -JblasQ4GemmPackB( - void* PackedBuf, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - size_t ldb, - size_t BlkSize, - bool isAsym, - bool lastCall, - MLAS_SQNBIT_COMPUTE_TYPE CompType, - MLAS_THREADPOOL* ThreadPool -) -{ - GetCPUDevice(); - // explicit statement fall through. - switch (CompType) { - case CompInt8: - if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) { - JblasQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI::KTILE == 0) { - JblasQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI::KTILE == 0) { - JblasQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - case CompBf16: - case CompFp16: - case CompFp32: - case CompUndef: - if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { - JblasQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { - JblasQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - default: - return false; - } - return false; -} - -bool -JblasQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool) -{ - auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(PackedBuf); - auto uptr = std::unique_ptr(ptr); - ORTThreading orth(ThreadPool); - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto ldb_ = static_cast(ldb); - GetCPUDevice(); - if (ptr) { - if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { - auto NTile = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT - ); - auto CType = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT - ); - if (CType == uint32_t(jblas::gemm::CompType::COMP_FP32)) { - if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } - } - if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_US_INT32)) { - if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } else if (NTile == 
tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } - } - if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_SS_INT32)) { - if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } - } - } - return true; - } - return false; -} diff --git a/onnxruntime/core/mlas/lib/jblas_gemm.h b/onnxruntime/core/mlas/lib/jblas_gemm.h deleted file mode 100644 index 044dc5e849a0a..0000000000000 --- a/onnxruntime/core/mlas/lib/jblas_gemm.h +++ /dev/null @@ -1,61 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - jblas_gemm.h - -Abstract: - - Currently only support Q4 gemm. ---*/ - -#pragma once - -#include "mlas_qnbit.h" - -size_t -JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType); - -bool -JblasQ4GemmPackB( - void* PackedBuf, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - size_t ldb, - size_t BlkSize, - bool isAsym, - bool lastCall, - MLAS_SQNBIT_COMPUTE_TYPE CompType, - MLAS_THREADPOOL* ThreadPool -); - -bool -JblasQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb - , MLAS_THREADPOOL* ThreadPool); - -bool -JblasSQ4GemmBatchDriver( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, - int8_t* WorkSpace, - MLAS_THREADPOOL* ThreadPool -); - -size_t -JblasSQ4GemmBatchWorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams -); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 7d877848017fe..0d8a5692359a6 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -19,10 +19,6 @@ Module Name: #include -#ifdef MLAS_JBLAS -#include "jblas_gemm.h" -#endif - namespace { @@ -694,127 +690,3 @@ MlasSQNBitGemmBatch( ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); }); } - -size_t MLASCALL -MlasNBitsGemmPackBSize( - size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType -) -{ -#ifdef MLAS_JBLAS - if (nbits == 4) { - auto jsize = JblasQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType); - if (jsize) { - return jsize; - } - } -#endif - (void)(N); - (void)(K); - (void)(BlkSize); - (void)(nbits); - (void)(isAsym); - (void)(CompType); - return 0; -} - -void MLASCALL -MlasNBitsGemmPackB( - void* PackedBuf, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - size_t ldb, - size_t BlkSize, - int nbits, - bool isAsym, - bool lastCall, - MLAS_SQNBIT_COMPUTE_TYPE CompType, - MLAS_THREADPOOL* ThreadPool -) -{ -#ifdef MLAS_JBLAS - if (nbits == 4) { - if (JblasQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall, CompType, ThreadPool)) { - return; - } - } -#endif - (void)(PackedBuf); - (void)(QData); - (void)(Scale); - (void)(Zp); - (void)(N); - (void)(K); - (void)(ldb); - (void)(BlkSize); - (void)(nbits); - 
(void)(isAsym); - (void)(lastCall); - (void)(CompType); - (void)(ThreadPool); -} - -void MLASCALL -MlasNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool) -{ -#ifdef MLAS_JBLAS - if (JblasQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool)) { - return; - } -#endif - (void)(FpData); - (void)(PackedBuf); - (void)(N); - (void)(K); - (void)(ldb); - (void)(ThreadPool); -} - -size_t MLASCALL -MlasSQNBitsGemmBatchPackedBWorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams -) -{ -#ifdef MLAS_JBLAS - return JblasSQ4GemmBatchWorkspaceSize(M, N, K, BatchN, DataParams); -#endif - (void)(M); - (void)(N); - (void)(K); - (void)(BatchN); - (void)(DataParams); - return 0; -} - -void MLASCALL -MlasSQNBitsGemmBatchPackedB( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, - void* WorkSpace, - MLAS_THREADPOOL* ThreadPool -) -{ - GetMlasPlatform(); -#ifdef MLAS_JBLAS - if (JblasSQ4GemmBatchDriver(M, N, K, BatchN, DataParams, reinterpret_cast(WorkSpace), ThreadPool)) { - // PackedWeight is created by jblas - return; - } -#endif - (void)(M); - (void)(N); - (void)(K); - (void)(BatchN); - (void)(DataParams); - (void)(WorkSpace); - (void)(ThreadPool); -} diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format b/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format deleted file mode 100644 index 84b876706161d..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format +++ /dev/null @@ -1,7 +0,0 @@ -Language: Cpp -BasedOnStyle: Google -DerivePointerAlignment: false -ColumnLimit: 120 -SpaceBeforeParens: ControlStatements -SpaceBeforeRangeBasedForLoopColon: true -SortIncludes: false diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt b/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt deleted file mode 100644 index 5d9c5edf45a96..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt +++ /dev/null @@ -1,33 +0,0 @@ -cmake_minimum_required(VERSION 3.5) - -project(jblas LANGUAGES CXX VERSION 0.1.0) - -file(GLOB headers ${PROJECT_NAME}/*.h ${PROJECT_NAME}/*.hpp) -file(GLOB xbyak_headers ${PROJECT_NAME}/xbyak/*.h ${PROJECT_NAME}/xbyak/*.hpp) - -add_library(${PROJECT_NAME} INTERFACE) -add_library(${PROJECT_NAME}::${PROJECT_NAME} ALIAS ${PROJECT_NAME}) - -target_include_directories( - ${PROJECT_NAME} INTERFACE - "$" - "$" -) - -if(WIN32) - target_compile_definitions(${PROJECT_NAME} INTERFACE _CRT_SECURE_NO_WARNINGS NOMINMAX) - target_compile_options(${PROJECT_NAME} INTERFACE /wd4068 /wd4849 /wd6262 /wd4702 /wd4100) - #4068 ignore unroll and GCC flags - #4849 ignore collapse - #6262 ignore stack too large - #4702 unreachable code(false warning on constexpr condition) - #4100 unreferenced formal parameter - - target_link_options(${PROJECT_NAME} INTERFACE /STACK:3145728) #Stack requires up to L2 cache size -endif(WIN32) - - -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CXX_STANDARD_REQUIRED ON) - -target_compile_features(${PROJECT_NAME} INTERFACE cxx_std_17) diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h deleted file mode 100644 index 143adb771760b..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h +++ /dev/null @@ -1,303 +0,0 @@ -// Copyright (c) 2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 
(the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include - -#include -#include -#include "xbyak/xbyak.h" -#include "xbyak/xbyak_util.h" - -#define OFFSET(field) offsetof(params, field) - -namespace jblas { - -namespace xbyak { -class JitBase : protected Xbyak::CodeGenerator { - protected: - JitBase(size_t size = 16 * 1024) : CodeGenerator(size) {} - - void load32(const Xbyak::Reg64& reg, const Xbyak::Address& addr) { - xor_(reg, reg); - mov(reg.cvt32(), addr); - } - - void vreg_push(const Xbyak::Reg64& baseaddr) { -#ifdef _WIN32 - for (int i = 0; i < 10; i++) { - movaps(xword[baseaddr + i * 16], Xbyak::Xmm(6 + i)); - } -#endif - } - - void vreg_pop(const Xbyak::Reg64& baseaddr) { -#ifdef _WIN32 - for (int i = 0; i < 10; i++) { - movaps(Xbyak::Xmm(6 + i), xword[baseaddr + i * 16]); - } -#endif - } - - void padto_le(const Xbyak::Reg64& _src, int padding) { - // _src=_src/padding*padding - if (padding == 1) { - return; - } - for (int i = 1; i < 16; i++) { - if ((1 << i) == padding) { - shr(_src, i); - shl(_src, i); - return; - } - } - assert(0); - } - - void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Address& _total, - const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) { - inLocalLabel(); - lea(_tmp, _total); - sub(_tmp, _pos); - cmp(_tmp, N); - jb(".maskflag"); - cmp(_tmp, 0); - jl(".zeroflag"); - uint64_t allmask = (static_cast(1) << N) - 1; - if (N == 64) { - allmask = static_cast(-1); - } - mov(_tmp, allmask); - kmovq(_msk, _tmp); - jmp(".maskend"); - L(".maskflag"); - mov(_tmp1, 1); - shlx(_tmp1, _tmp1, _tmp); - sub(_tmp1, 1); - kmovq(_msk, _tmp1); - jmp(".maskend"); - L(".zeroflag"); - mov(_tmp1, 0); - kmovq(_msk, _tmp1); - L(".maskend"); - outLocalLabel(); - } - void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Reg64& _total, - const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) { - generate_Nbitsmask(_msk, _pos, ptr[_total], _tmp, _tmp1, N); - } -}; - -class JitAvx : protected JitBase { - protected: - static int constexpr VBits = 256; - static int constexpr VecBytes = VBits / 8; - static int constexpr RegCount = 16; - typedef Xbyak::Ymm vreg_t; -}; - -class JitAvx2 : protected JitAvx { - protected: - static int constexpr VBits = 256; - typedef Xbyak::Ymm vreg_t; - void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxor(x1, x2, op); } - - void loadbf16_f32(const Xbyak::Ymm& dst, const Xbyak::Address& addr) { - vpmovzxwd(dst, addr); - vpslld(dst, dst, 16); - } -}; - -class JitAvx512f : protected JitAvx2 { - protected: - static int constexpr VBits = 512; - static int constexpr VecBytes = VBits / 8; - static int constexpr RegCount = 32; - typedef Xbyak::Zmm vreg_t; - - void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxorq(x1, x2, op); } - - void interleave_2rows_4regs(Xbyak::Zmm* src_2regs, Xbyak::Zmm* tmp_2reg) { - vpunpcklwd(tmp_2reg[0], src_2regs[0], src_2regs[1]); - vpunpckhwd(tmp_2reg[1], src_2regs[0], src_2regs[1]); - vshuff32x4(src_2regs[0], tmp_2reg[0], 
tmp_2reg[1], 0 | (1 << 2) | (0 << 4) | (1 << 6)); - vshuff32x4(src_2regs[0], src_2regs[0], src_2regs[0], 0 | (2 << 2) | (1 << 4) | (3 << 6)); - vshuff32x4(src_2regs[1], tmp_2reg[0], tmp_2reg[1], 2 | (3 << 2) | (2 << 4) | (3 << 6)); - vshuff32x4(src_2regs[1], src_2regs[1], src_2regs[1], 0 | (2 << 2) | (1 << 4) | (3 << 6)); - } - - void transpose16x16_4B(Xbyak::Zmm* src, Xbyak::Zmm* tmp, const int N = 16) { - for (int i = 0; i < 8; ++i) { - vpunpckldq(tmp[2 * i + 0], src[2 * i], src[2 * i + 1]); - vpunpckhdq(tmp[2 * i + 1], src[2 * i], src[2 * i + 1]); - } - - for (int i = 0; i < 4; ++i) { - vpunpcklqdq(src[4 * i + 0], tmp[4 * i + 0], tmp[4 * i + 2]); - vpunpckhqdq(src[4 * i + 1], tmp[4 * i + 0], tmp[4 * i + 2]); - vpunpcklqdq(src[4 * i + 2], tmp[4 * i + 1], tmp[4 * i + 3]); - vpunpckhqdq(src[4 * i + 3], tmp[4 * i + 1], tmp[4 * i + 3]); - } - - for (int i = 0; i < 2; ++i) { - vshufi32x4(tmp[8 * i + 0], src[8 * i + 0], src[8 * i + 4], 0x88); - vshufi32x4(tmp[8 * i + 1], src[8 * i + 1], src[8 * i + 5], 0x88); - vshufi32x4(tmp[8 * i + 2], src[8 * i + 2], src[8 * i + 6], 0x88); - vshufi32x4(tmp[8 * i + 3], src[8 * i + 3], src[8 * i + 7], 0x88); - vshufi32x4(tmp[8 * i + 4], src[8 * i + 0], src[8 * i + 4], 0xdd); - vshufi32x4(tmp[8 * i + 5], src[8 * i + 1], src[8 * i + 5], 0xdd); - vshufi32x4(tmp[8 * i + 6], src[8 * i + 2], src[8 * i + 6], 0xdd); - vshufi32x4(tmp[8 * i + 7], src[8 * i + 3], src[8 * i + 7], 0xdd); - } - - // last step and move out - for (int i = 0; i < N; ++i) { - vshufi32x4(src[i], tmp[i % 8], tmp[8 + i % 8], i < 8 ? 0x88 : 0xdd); - } - } - - void interleave_4rows_6regs(Xbyak::Zmm* src_4regs, Xbyak::Zmm* tmp_regs, const Xbyak::Opmask* masks) { - vpunpcklbw(tmp_regs[0], src_4regs[0], src_4regs[1]); - vpunpckhbw(tmp_regs[1], src_4regs[0], src_4regs[1]); - vpunpcklbw(tmp_regs[2], src_4regs[2], src_4regs[3]); - vpunpckhbw(tmp_regs[3], src_4regs[2], src_4regs[3]); - - vpunpcklwd(tmp_regs[4], tmp_regs[0], tmp_regs[2]); - vpunpckhwd(tmp_regs[5], tmp_regs[0], tmp_regs[2]); - vpunpcklwd(tmp_regs[0], tmp_regs[1], tmp_regs[3]); - vpunpckhwd(tmp_regs[2], tmp_regs[1], tmp_regs[3]); - vshuff32x4(tmp_regs[1], tmp_regs[4], tmp_regs[0], (4 << 4) | 4); - vshuff32x4(tmp_regs[3], tmp_regs[5], tmp_regs[2], (4 << 4) | 4); - vmovups(src_4regs[0], tmp_regs[1]); - vshuff32x4(src_4regs[0] | masks[0], tmp_regs[3], tmp_regs[3], 0 | (0 << 2) | (0 << 4) | (2 << 6)); - vmovups(src_4regs[1], tmp_regs[3]); - vshuff32x4(src_4regs[1] | masks[1], tmp_regs[1], tmp_regs[1], 1 | (0 << 2) | (3 << 4) | (0 << 6)); - vshuff32x4(tmp_regs[1], tmp_regs[4], tmp_regs[0], (14 << 4) | 14); - vshuff32x4(tmp_regs[3], tmp_regs[5], tmp_regs[2], (14 << 4) | 14); - vmovups(src_4regs[2], tmp_regs[1]); - vshuff32x4(src_4regs[2] | masks[0], tmp_regs[3], tmp_regs[3], 0 | (0 << 2) | (0 << 4) | (2 << 6)); - vmovups(src_4regs[3], tmp_regs[3]); - vshuff32x4(src_4regs[3] | masks[1], tmp_regs[1], tmp_regs[1], 1 | (0 << 2) | (3 << 4) | (0 << 6)); - } - - void cvt_fp32_bf16(const Xbyak::Ymm& _bf16, const Xbyak::Zmm& _fp32) { - vpsrld(_fp32, _fp32, 16); - vpmovdw(_bf16, _fp32); - } - - void loadbf16_f32(const Xbyak::Zmm& dst, const Xbyak::Address& addr) { - vpmovzxwd(dst, addr); - vpslld(dst, dst, 16); - } - - void broadcastbf16_f32(const Xbyak::Zmm& dst, const Xbyak::Reg64& tmp, const Xbyak::Address& addr) { - mov(tmp.cvt16(), addr); - shl(tmp.cvt32(), 16); - vpbroadcastd(dst, tmp.cvt32()); - } - - void store_fp32_bf16(const Xbyak::Zmm& _fp32, const Xbyak::Address& _add) { - auto bf16 = Xbyak::Ymm(_fp32.getIdx()); - cvt_fp32_bf16(bf16, _fp32); - 
vmovups(_add, bf16); - } -}; - -class JitAvx512_bf16 : protected JitAvx512f {}; - -class JitAvx512_fp16 : protected JitAvx512f {}; - -class JitAvx512vnni : protected JitAvx512f { - protected: - void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) { - vpdpbusds(x1, x2, op, Xbyak::EvexEncoding); - } -}; - -class JitAvxvnni : protected JitAvx2 { - protected: - void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) { - vpdpbusds(x1, x2, op, Xbyak::VexEncoding); - } -}; - -class JitAmxtile : protected JitAvx512f { - public: - struct alignas(64) tileconfig_t { - uint8_t palette_id; - uint8_t reserved[15]; - uint16_t colb[16]; - uint8_t rows[16]; - }; - static int constexpr TileCount = 8; - - typedef long long (*configure_t)(void*); - - static void generate_config(Xbyak::CodeGenerator* g) { - Xbyak::util::StackFrame st(g, 1, 0, 0); - auto& parambase = st.p[0]; - g->ldtilecfg(g->ptr[parambase]); - } - - static void configure_tiles(tileconfig_t& tc, int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum, - int CNum) { - // Filling tile configure structure. Could be done offline. - tc.palette_id = 1; - // Configure C tiles - int t = 0; - for (; t < CNum; ++t) { - tc.rows[t] = static_cast(TILE_M); - tc.colb[t] = static_cast(TILE_N * 4); - } - // Configure A tiles - for (; t < CNum + ANum; ++t) { - tc.rows[t] = static_cast(TILE_M); - tc.colb[t] = static_cast(TILE_K * elesize); - } - // Configure B tile. B effectively has 64 rows and 16 columns. - int kpack = 4 / elesize; - for (; t < CNum + ANum + BNum; ++t) { - tc.rows[t] = static_cast(TILE_K / kpack); - tc.colb[t] = static_cast(TILE_N * 4); - } - } -}; - -class JitAmxbf16 : protected JitAmxtile { - protected: - void cvt_fp32_bf16(const Xbyak::Ymm& _bf16, const Xbyak::Zmm& _fp32) { vcvtneps2bf16(_bf16, _fp32); } -}; - -class JitAmxint8 : protected JitAmxtile { - protected: - template - void _tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3); -}; -template <> -inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { - tdpbssd(x1, x2, x3); -} -template <> -inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { - tdpbsud(x1, x2, x3); -} -template <> -inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { - tdpbusd(x1, x2, x3); -} -template <> -inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { - tdpbuud(x1, x2, x3); -} -} // namespace xbyak -} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h deleted file mode 100644 index 8ecf3535c17f4..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright (c) 2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-#pragma once
-#include 
-enum JBLAS_CODE {
- JblasSuccess = 0,
- JblasInvalidParam = 1,
- JblasInvalidISA = 2,
- JblasRuntimeError = 4,
- JblasNotSupport = 8,
-};
-enum JBLAS_ISA : uint32_t {
- JblasNoSIMD = 0,
- JblasAVX,
- JblasAVX2,
- JblasAVX_VNNI,
- JblasAVX512F,
- JblasAVX512_VNNI,
- JblasAMX_BF16,
- JblasAMX_INT8,
- JblasAVX512_FP16,
- JblasAVX512_BF16,
-};
-enum class JBLAS_DTYPE : uint32_t {
- EleBitsMask = 0xff,
- EleBitsUndef = 0,
- EleBits4 = 4,
- EleBits8 = 8,
- EleBits16 = 16,
- EleBits32 = 32,
- EleBits64 = 64,
- TypeMask = 0xff00,
- TypeFloat = 0 << 8,
- TypeInt = 1 << 8,
- SubTypeMask = 0xff0000,
- SubType0 = 0 << 16,
- SubType1 = 1 << 16,
- SubType2 = 2 << 16,
- F64 = EleBits64 | TypeFloat,
- F32 = EleBits32 | TypeFloat,
- F16 = EleBits16 | TypeFloat,
- BF16 = EleBits16 | TypeFloat | SubType1,
- F8_E4M3 = EleBits8 | TypeFloat,
- F8_E5M2 = EleBits8 | TypeFloat | SubType1,
- F8_E3M4 = EleBits8 | TypeFloat | SubType2,
- S8 = EleBits8 | TypeInt,
- U8 = EleBits8 | TypeInt | SubType1,
- S4_CLIP = EleBits4 | TypeInt,
- S4_FULLRANGE = EleBits4 | TypeInt | SubType1,
- F4_E2M1 = EleBits4 | TypeFloat,
- F4_BNB = EleBits4 | TypeFloat | SubType1,
- F4_NF4 = EleBits4 | TypeFloat | SubType2,
- S32 = EleBits32 | TypeInt,
- U32 = EleBits32 | TypeInt | SubType1,
-};
-
-enum JBLAS_LAYOUT { JblasRowMajor = 101, JblasColMajor = 102 };
-enum JBLAS_TRANSPOSE {
- JblasNoTrans = 111,
- JblasTrans = 112,
- JblasConjTrans = 113,
-};
-enum JBLAS_ELTWISEOP {
- GELU,
- SWISH,
- TANH,
- EXP,
- LOW_PRECISION_EXP,
- RELU,
- LINEAR,
-};
-
-enum class JBLAS_PROLOGUEB_IDS : uint32_t {
- Undef = (uint32_t)-1,
- Begin = 0,
- NormalBegin = Begin,
- WeightPack = NormalBegin,
- NormalEnd,
- KBlockBegin = NormalEnd,
- WeightKBlockS8 = KBlockBegin,
- WeightKBlockS4,
- WeightKBlockF4,
- KBlockEnd,
- End,
-};
diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h
deleted file mode 100644
index 5cac1080bc610..0000000000000
--- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h
+++ /dev/null
@@ -1,277 +0,0 @@
-// Copyright (c) 2023 Intel Corporation
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-#pragma once -#include "jit_blas.h" -#include "xbyak/xbyak_util.h" - -namespace jblas { - -namespace device { - -struct X64_ISA { - int64_t MMX : 1; // 0 - int64_t SSE : 1; // 1 - int64_t SSE2 : 1; // 2 - int64_t SSE3 : 1; // 3 - int64_t SSSE3 : 1; // 4 - int64_t SSE41 : 1; // 5 - int64_t SSE42 : 1; // 6 - int64_t AVX : 1; // 7 - int64_t F16C : 1; // 8 - int64_t FMA : 1; // 9 - int64_t AVX2 : 1; // 10 - int64_t AVX_VNNI : 1; // 11 - int64_t AVX_VNNI_INT8 : 1; // 12 - int64_t AVX_NE_CONVERT : 1; // 13 - int64_t AVX_IFMA : 1; // 14 - int64_t AVX512F : 1; // 15 - int64_t AVX512BW : 1; // 16 - int64_t AVX512CD : 1; // 17 - int64_t AVX512DQ : 1; // 18 - int64_t AVX512ER : 1; // 19 - int64_t AVX512IFMA52 : 1; // 20 - int64_t AVX512PF : 1; // 21 - int64_t AVX512VL : 1; // 22 - int64_t AVX512VPOPCNTDQ : 1; // 23 - int64_t AVX512_4FMAPS : 1; // 24 - int64_t AVX512_4VNNIW : 1; // 25 - int64_t AVX512_BF16 : 1; // 26 - int64_t AVX512_BITALG : 1; // 27 - int64_t AVX512_VBMI : 1; // 28 - int64_t AVX512_VBMI2 : 1; // 29 - int64_t AVX512_VNNI : 1; // 30 - int64_t AVX512_VP2INTERSECT : 1; // 31 - int64_t AVX512_FP16 : 1; // 32 - int64_t AMX_TILE : 1; // 33 - int64_t AMX_BF16 : 1; // 34 - int64_t AMX_INT8 : 1; // 35 - int64_t AMX_FP16 : 1; // 36 - int64_t AMX_COMPLEX : 1; // 37 - int64_t reserved : (64 - 38); -}; - -class AVX2_Default { - public: - static constexpr bool MMX = 1; - static constexpr bool SSE = 1; - static constexpr bool SSE2 = 1; - static constexpr bool SSE3 = 1; - static constexpr bool SSSE3 = 1; - static constexpr bool SSE41 = 1; - static constexpr bool SSE42 = 1; - static constexpr bool AVX = 1; - static constexpr bool F16C = 1; - static constexpr bool FMA = 1; - static constexpr bool AVX2 = 1; - static constexpr bool AVX_VNNI = 0; - static constexpr bool AVX_VNNI_INT8 = 0; - static constexpr bool AVX_NE_CONVERT = 0; - static constexpr bool AVX_IFMA = 0; - static constexpr bool AVX512F = 0; - static constexpr bool AVX512BW = 0; - static constexpr bool AVX512CD = 0; - static constexpr bool AVX512DQ = 0; - static constexpr bool AVX512ER = 0; - static constexpr bool AVX512IFMA52 = 0; - static constexpr bool AVX512PF = 0; - static constexpr bool AVX512VL = 0; - static constexpr bool AVX512VPOPCNTDQ = 0; - static constexpr bool AVX512_4FMAPS = 0; - static constexpr bool AVX512_4VNNIW = 0; - static constexpr bool AVX512_BF16 = 0; - static constexpr bool AVX512_BITALG = 0; - static constexpr bool AVX512_VBMI = 0; - static constexpr bool AVX512_VBMI2 = 0; - static constexpr bool AVX512_VNNI = 0; - static constexpr bool AVX512_VP2INTERSECT = 0; - static constexpr bool AVX512_FP16 = 0; - static constexpr bool AMX_TILE = 0; - static constexpr bool AMX_BF16 = 0; - static constexpr bool AMX_INT8 = 0; - static constexpr bool AMX_FP16 = 0; - static constexpr bool AMX_COMPLEX = 0; -}; - -class AVX512_VNNI_Default { - public: - static constexpr bool MMX = 1; - static constexpr bool SSE = 1; - static constexpr bool SSE2 = 1; - static constexpr bool SSE3 = 1; - static constexpr bool SSSE3 = 1; - static constexpr bool SSE41 = 1; - static constexpr bool SSE42 = 1; - static constexpr bool AVX = 1; - static constexpr bool F16C = 1; - static constexpr bool FMA = 1; - static constexpr bool AVX2 = 1; - static constexpr bool AVX_VNNI = 0; - static constexpr bool AVX_VNNI_INT8 = 0; - static constexpr bool AVX_NE_CONVERT = 0; - static constexpr bool AVX_IFMA = 0; - static constexpr bool AVX512F = 1; - static constexpr bool AVX512BW = 1; - static constexpr bool AVX512CD = 1; - static constexpr bool AVX512DQ = 1; - 
static constexpr bool AVX512ER = 0; - static constexpr bool AVX512IFMA52 = 0; - static constexpr bool AVX512PF = 0; - static constexpr bool AVX512VL = 1; - static constexpr bool AVX512VPOPCNTDQ = 0; - static constexpr bool AVX512_4FMAPS = 0; - static constexpr bool AVX512_4VNNIW = 0; - static constexpr bool AVX512_BF16 = 0; - static constexpr bool AVX512_BITALG = 0; - static constexpr bool AVX512_VBMI = 0; - static constexpr bool AVX512_VBMI2 = 0; - static constexpr bool AVX512_VNNI = 1; - static constexpr bool AVX512_VP2INTERSECT = 0; - static constexpr bool AVX512_FP16 = 0; - static constexpr bool AMX_TILE = 0; - static constexpr bool AMX_BF16 = 0; - static constexpr bool AMX_INT8 = 0; - static constexpr bool AMX_FP16 = 0; - static constexpr bool AMX_COMPLEX = 0; -}; - -class SapphireRapids { - public: - static constexpr bool MMX = 1; - static constexpr bool SSE = 1; - static constexpr bool SSE2 = 1; - static constexpr bool SSE3 = 1; - static constexpr bool SSSE3 = 1; - static constexpr bool SSE41 = 1; - static constexpr bool SSE42 = 1; - static constexpr bool AVX = 1; - static constexpr bool F16C = 1; - static constexpr bool FMA = 1; - static constexpr bool AVX2 = 1; - static constexpr bool AVX_VNNI = 0; - static constexpr bool AVX_VNNI_INT8 = 0; - static constexpr bool AVX_NE_CONVERT = 0; - static constexpr bool AVX_IFMA = 0; - static constexpr bool AVX512F = 1; - static constexpr bool AVX512BW = 1; - static constexpr bool AVX512CD = 1; - static constexpr bool AVX512DQ = 1; - static constexpr bool AVX512ER = 0; - static constexpr bool AVX512IFMA52 = 0; - static constexpr bool AVX512PF = 0; - static constexpr bool AVX512VL = 1; - static constexpr bool AVX512VPOPCNTDQ = 0; - static constexpr bool AVX512_4FMAPS = 0; - static constexpr bool AVX512_4VNNIW = 0; - static constexpr bool AVX512_BF16 = 0; - static constexpr bool AVX512_BITALG = 0; - static constexpr bool AVX512_VBMI = 0; - static constexpr bool AVX512_VBMI2 = 0; - static constexpr bool AVX512_VNNI = 1; - static constexpr bool AVX512_VP2INTERSECT = 0; - static constexpr bool AVX512_FP16 = 0; - static constexpr bool AMX_TILE = 1; - static constexpr bool AMX_BF16 = 1; - static constexpr bool AMX_INT8 = 1; - static constexpr bool AMX_FP16 = 0; - static constexpr bool AMX_COMPLEX = 0; -}; - -template -class isa_base { - public: - static bool constexpr avx = ISA_T >= JblasAVX; - static bool constexpr avx2 = ISA_T >= JblasAVX2; - static bool constexpr avx512f = ISA_T >= JblasAVX512F; - static bool constexpr avx512_vnni = ISA_T >= JblasAVX512_VNNI; - static bool constexpr avx512_fp16 = ISA_T >= JblasAVX512_FP16; - static bool constexpr amx_bf16 = ISA_T >= JblasAMX_BF16; - static bool constexpr amx_int8 = ISA_T >= JblasAMX_INT8; -}; - -class CpuDevice { - public: - inline void setThreads(int _nth) { - if (_nth <= 0) { - numthreads = numcores; - } else { - numthreads = std::min(numcores, _nth); - } - } - inline int getThreads() { return numthreads; } - inline int getCores() { return numcores; } - inline uint32_t getL2CacheSize() { return L2Cache; } - inline uint32_t getL1CacheSize() { return L1Cache; } - inline bool AVX() { return mHasAVX; } - inline bool AVX2() { return mHasAVX2; } - inline bool AVX_VNNI() { return mHasAVX_VNNI; } - inline bool AVX512F() { return mHasAVX512F; } - inline bool AVX512_VNNI() { return mHasAVX512_VNNI; } - inline bool AMX_INT8() { return mHasAMX_INT8; } - inline bool AMX_BF16() { return mHasAMX_BF16; } - inline bool AVX512_BF16() { return mHasAVX512_BF16; } - inline bool AVX512_FP16() { return 
mHasAVX512_FP16; } -#define ADD_FLAG(isa) mHas##isa = _cpu.has(_cpu.t##isa) - CpuDevice() { - static Xbyak::util::Cpu _cpu; - L1Cache = _cpu.getDataCacheSize(0); - L2Cache = _cpu.getDataCacheSize(1); - ADD_FLAG(AVX); - ADD_FLAG(AVX2); - ADD_FLAG(AVX512F); - ADD_FLAG(AVX512_VNNI); - ADD_FLAG(AVX_VNNI); - ADD_FLAG(AMX_BF16); - ADD_FLAG(AMX_INT8); - ADD_FLAG(AVX512_BF16); - ADD_FLAG(AVX512_FP16); - numcores = _cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel); - numthreads = numcores; - } - - static CpuDevice* getInstance() { - static CpuDevice instance; - return &instance; - } - - void print() { - printf( - "AVX:%d AVX2:%d AVX512F:%d AVX_VNNI:%d AVX512_VNNI:%d AMX_INT8:%d AMX_BF16:%d AVX512_BF16:%d AVX512_FP16:%d\n", - mHasAVX, mHasAVX2, mHasAVX512F, mHasAVX_VNNI, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512_BF16, - mHasAVX512_FP16); - } -#undef ADD_FLAG - - protected: - uint32_t L2Cache, L1Cache; - bool mHasAVX2, mHasAVX_VNNI, mHasAVX, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512F, mHasAVX512_BF16, - mHasAVX512_FP16; - int numcores; - int numthreads; -}; - -#define GetCPUDevice() auto _cd = jblas::device::CpuDevice::getInstance(); - -class CpuBase { - public: - CpuBase() { - GetCPUDevice(); - mL2Cache = _cd->getL2CacheSize(); - mL1Cache = _cd->getL1CacheSize(); - mNumThreads = _cd->getThreads(); - } - size_t mL2Cache, mL1Cache; - int mNumThreads; -}; -} // namespace device -} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h deleted file mode 100644 index ceb7a545092d8..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h +++ /dev/null @@ -1,329 +0,0 @@ -// Copyright (c) 2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include - -#include "jit_base.h" -#include "jit_blas.h" -#include "jit_blas_utils.h" -#include "kernel_wrapper.h" - -namespace jblas { -namespace epilogue { -namespace gemm { - -template -class AccumulatorWriteBack { - public: - using SType = _SRC_T; - using DType = _DST_T; - struct Param { - DType* C; - int ldc; - void* elt_const_v; - }; - - template - JBLAS_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize, Eltops... 
ops) { - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - bool constexpr Valid = !std::is_same::value || std::is_same::value; - static_assert(Valid, "fp32 to bf16 conversion only."); - if constexpr (std::is_same::value) { - return kernel::wrapper::Memcpy2DFp32CvtBf16::template forward( - const_cast<_SRC_T*>(cacheptr), cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false); - } else if constexpr (std::is_same, std::tuple>::value) { - return kernel::wrapper::Memcpy2DFp16CvtFp32::template forward( - const_cast<_SRC_T*>(cacheptr), cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false); - } else if constexpr (sizeof(SType) == sizeof(DType)) { - return kernel::wrapper::Memcpy2D::template forward(cacheptr, cptr, M, N, cachestep, - _param.ldc, _param.elt_const_v, ops...); - } else { - assert(false); - } - } -}; - -template -class CustomAccumulatorWriteBackWithEltop { - public: - struct Param { - _DST_T* C; - int ldc; - void* elt_const_v; - }; - JBLAS_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - if constexpr (std::is_same<_SRC_T, float>::value && std::is_same<_DST_T, float>::value) { - return kernel::wrapper::Memcpy2D::template forward1(cacheptr, cptr, M, N, cachestep, - _param.ldc, _param.elt_const_v); - } else { - assert(false); - } - } -}; -template -using AccumulatorWriteBackFp32 = AccumulatorWriteBack; -template -using AccumulatorWriteBackInt32 = AccumulatorWriteBack; -template -using AccumulatorWriteBackBf16 = AccumulatorWriteBack; -template -using AccumulatorWriteBackFp16 = AccumulatorWriteBack; -template -using AccumulatorWriteBackFp16Fp32 = AccumulatorWriteBack; -template -using AccumulatorWriteBackFp32Bf16 = AccumulatorWriteBack; - -template -using AccumulatorWriteBackWithGeluFp32 = CustomAccumulatorWriteBackWithEltop; - -template -using AccumulatorWriteBackWithSwishFp32 = CustomAccumulatorWriteBackWithEltop; - -template -class AlphaBetaProcessFp32 { - public: - struct Param { - float *C, *D; - int ldc, ldd; - float alpha, beta; - }; - - JBLAS_CODE forward(const float* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { - auto DOffset = M_offset * _param.ldd + N_offset; - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - auto dptr = _param.D + DOffset; - return kernel::wrapper::AlphaBetaF32F32::template forward(_param.alpha, cacheptr, cachestep, _param.beta, - dptr, _param.ldd, cptr, _param.ldc, M, N); - } -}; - -template -class CompFp32BlockEpilogue { - public: - struct Param { - void* scales; - JBLAS_DTYPE scaledtype; - int ldsb; - int8_t* zps = nullptr; - float* reduce = nullptr; - int ldra; - }; - JBLAS_CODE forward(const float* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset, - const int K_offset, const int M, const int N, const Param& _param, void* tmpcache, - size_t cachesize) { - auto ret = JblasNotSupport; - if (_param.scaledtype == JBLAS_DTYPE::F32) { - ret = kernel::wrapper::CompFp32BlockScale::template forward( - reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr, - cachestep, M, N); - assert(ret == JblasSuccess); - if (_param.zps != nullptr) { - ret = 
kernel::wrapper::RemoveZeroPointBias::forward_wei( - dstptr, cachestep, M, N, _param.zps + K_offset * _param.ldsb + N_offset, - reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, _param.ldra, - _param.reduce + M_offset * _param.ldra + K_offset); - } - assert(ret == JblasSuccess); - return ret; - } else if (_param.scaledtype == JBLAS_DTYPE::BF16) { - ret = kernel::wrapper::CompFp32BlockScale::template forward( - reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr, - cachestep, M, N); - assert(_param.zps == nullptr); - assert(ret == JblasSuccess); - return ret; - } - return JblasNotSupport; - } -}; - -template -class DequantInt32ToFp32 { - public: - struct Param { - float* C; - int ldc; - int ldsa; - float* scalesA; - float* scalesB; - }; - JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - return kernel::wrapper::DequanS32Fp32::template forward(cacheptr, cachestep, cptr, _param.ldc, M, N, - _param.scalesA + M_offset * _param.ldsa, _param.ldsa, - _param.scalesB + N_offset); - } -}; - -template -class CompInt8BlockEpilogue { - public: - struct Param { - void* scalesB; - JBLAS_DTYPE scaleBdtype; - int ldsb; - float* scalesA; - int ldsa; - // optional if A asym - uint8_t* zpA = nullptr; - void* reduceB = nullptr; - JBLAS_DTYPE reduceBdtype = JBLAS_DTYPE::F32; - // optional if B asym - int8_t* zpB = nullptr; - float* reduceA = nullptr; - int K = 1; - }; - JBLAS_CODE forward(const int32_t* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset, - const int K_offset, const int M, const int N, const Param& _param, void* tmpcache, - size_t cachesize) { - JBLAS_CODE ret = JblasNotSupport; - float* scab = nullptr; - size_t ScaleBTmpSize = N * sizeof(float); - size_t ReduceBTmpSize = N * sizeof(float); - assert(cachesize >= (ScaleBTmpSize + ReduceBTmpSize)); - if (_param.scaleBdtype == JBLAS_DTYPE::BF16) { - auto scache = reinterpret_cast(tmpcache); - ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( - reinterpret_cast(_param.scalesB) + N_offset + K_offset * _param.ldsb, scache, 1, N, N, N, - false); - assert(ret == JblasSuccess); - scab = scache; - } else if (_param.scaleBdtype == JBLAS_DTYPE::F32) { - scab = reinterpret_cast(_param.scalesB) + N_offset + K_offset * _param.ldsb; - } - float* redb = nullptr; - if (_param.reduceB) { - if (_param.reduceBdtype == JBLAS_DTYPE::BF16) { - auto rcache = reinterpret_cast(reinterpret_cast(tmpcache) + ScaleBTmpSize); - ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( - reinterpret_cast(_param.reduceB) + N_offset + K_offset * _param.ldsb, rcache, 1, N, N, N, - false); - assert(ret == JblasSuccess); - redb = rcache; - } else if (_param.reduceBdtype == JBLAS_DTYPE::F32) { - redb = reinterpret_cast(_param.reduceB) + N_offset + K_offset * _param.ldsb; - } - } - ret = kernel::wrapper::DequanS32Fp32::template forward( - srcptr, cachestep, reinterpret_cast(const_cast(srcptr)), cachestep, M, N, - _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, scab); - assert(ret == JblasSuccess); - ret = kernel::wrapper::AccumulateFp32::template forward(reinterpret_cast(srcptr), cachestep, - dstptr, cachestep, M, N); - assert(ret == JblasSuccess); - - if (_param.zpA == nullptr) { - if (_param.zpB == nullptr) { - return ret; - } else { - 
ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei( - dstptr, cachestep, M, N, _param.zpB + N_offset + K_offset * _param.ldsb, scab, _param.ldsa, - _param.reduceA + M_offset * _param.ldsa + K_offset); - } - } else { - if (_param.zpB == nullptr) { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_act( - dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset, - _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, redb); - } else { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_both( - dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset, - _param.zpB + N_offset + K_offset * _param.ldsb, _param.scalesA + M_offset * _param.ldsa + K_offset, scab, - _param.ldsa, _param.K, _param.reduceA + M_offset * _param.ldsa + K_offset, redb); - } - } - return ret; - } -}; - -template -class ZpDequantInt32ToFp32 { - public: - struct Param { - // necessary - float* C; - int ldc; - int ldsa; - float* scalesA; - float* scalesB; - // optional if A asym - uint8_t* zpA = nullptr; - float* reduceB = nullptr; - // optional if B asym - int8_t* zpB = nullptr; - float* reduceA = nullptr; - int K = 1; - }; - JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - auto ret = kernel::wrapper::DequanS32Fp32::template forward(cacheptr, cachestep, cptr, _param.ldc, M, N, - _param.scalesA + M_offset * _param.ldsa, - _param.ldsa, _param.scalesB + N_offset); - if (ret != JblasSuccess) { - return ret; - } - if (_param.zpA == nullptr && _param.zpB == nullptr) { - return ret; - } else if (_param.zpA != nullptr && _param.zpB == nullptr) { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_act( - cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.scalesA + M_offset * _param.ldsa, - _param.ldsa, _param.reduceB + N_offset); - } else if (_param.zpA == nullptr && _param.zpB != nullptr) { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei( - cptr, _param.ldc, M, N, _param.zpB + N_offset, _param.scalesB + N_offset, _param.ldsa, - _param.reduceA + M_offset * _param.ldsa); - } else { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_both( - cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.zpB + N_offset, - _param.scalesA + M_offset * _param.ldsa, _param.scalesB + N_offset, _param.ldsa, _param.K, - _param.reduceA + M_offset * _param.ldsa, _param.reduceB + N_offset); - } - return ret; - } -}; - -template -class AlphaBetaProcessS32U8 { - public: - struct Param { - uint8_t* C; - int ldc; - float alpha; - float scaleAcc, scaleC; - int zpC; - }; - - JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - return kernel::wrapper::QuanOutS32U32::template forward(_param.alpha, cacheptr, cachestep, cptr, _param.ldc, - M, N, _param.scaleAcc, _param.scaleC, _param.zpC); - } -}; - -} // namespace gemm -} // namespace epilogue -} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h deleted file mode 100644 index 364da9223940f..0000000000000 --- 
a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h +++ /dev/null @@ -1,2699 +0,0 @@ -// Copyright (c) 2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include - -#include "jit_blas_utils.h" -#include "jit_base.h" - -namespace jblas { -namespace gemm { -enum class CompType : uint32_t { - COMP_FP32 = 0, - COMP_BF16_FP32 = 1, - COMP_FP16_FP16 = 2, - COMP_INT_START = 3, - COMP_INT8_US_INT32 = COMP_INT_START, - COMP_INT8_UU_INT32 = 4, - COMP_INT8_SS_INT32 = 5, - COMP_INT8_SU_INT32 = 6, - COMP_INT16_SS_INT32 = 7, - COMP_INT8_US_FP32 = 8, - COMP_INT8_UU_FP32 = 9, - COMP_INT8_SS_FP32 = 10, - COMP_INT8_SU_FP32 = 11, -}; - -class CoreAttr { - public: - // INT32=LSB|**8bits:NTile**||**8bits:PackRow**||**8bits:CompType**||**8bits:Reserve**| - static uint32_t constexpr NTILE_MASK = 0xff, NTILE_SHIFT = 0, PACKROW_MASK = 0xff00, PACKROW_SHIFT = 8, - COMP_MASK = 0xff0000, COMP_SHIFT = 16, ISA_MASK = 0xff000000, ISA_SHIFT = 24; - - static inline uint32_t get_mask_val(uint32_t raw, uint32_t mask, uint32_t shift) { return (raw & mask) >> shift; } - static constexpr uint32_t make_core_id(uint32_t NTile, uint32_t PackRow, uint32_t CompType, uint32_t ISA) { - return (NTile << NTILE_SHIFT) | (PackRow << PACKROW_SHIFT) | (CompType << COMP_SHIFT) | (ISA << ISA_SHIFT); - } - - static void parse_id(uint32_t id, uint32_t* vals) { - vals[0] = get_mask_val(id, NTILE_MASK, NTILE_SHIFT); - vals[1] = get_mask_val(id, PACKROW_MASK, PACKROW_SHIFT); - vals[2] = get_mask_val(id, COMP_MASK, COMP_SHIFT); - vals[3] = get_mask_val(id, ISA_MASK, ISA_SHIFT); - } - - static const char* to_str(uint32_t id) { - static char tmp[128]; - uint32_t vals[4]; - parse_id(id, vals); - sprintf(tmp, "N%d_PACK%d_COMP%d_ISA%d", vals[0], vals[1], vals[2], vals[3]); - return tmp; - } - - static inline size_t get_bsize(uint32_t id) { - auto packrow = get_mask_val(id, PACKROW_MASK, PACKROW_SHIFT); - return size_t(4 / packrow); - } -}; - -namespace code { - -template -class Avx2N8P1 : protected jblas::xbyak::JitAvx2 { - public: - static int constexpr RegLen = 8, PackRow = 1; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX2; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; - typedef float AType; - typedef float BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - 
add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f { - public: - static int constexpr RegLen = 16, PackRow = 1; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512F; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; - typedef float AType; - typedef float BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * 
BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Avx512fp16N32P1 : protected jblas::xbyak::JitAvx512_fp16 { - public: - static int constexpr RegLen = 32, PackRow = 1; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_FP16; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP16_FP16; - typedef utils::fp16 AType; - typedef utils::fp16 BType; - typedef utils::fp16 CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - 
add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vpbroadcastw(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ph(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vpbroadcastw(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ph(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Avx512bf16N16P2 : protected jblas::xbyak::JitAvx512_bf16 { - public: - static int constexpr RegLen = 16, PackRow = 2; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 2; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_BF16; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_BF16_FP32; - typedef utils::bf16 AType; - typedef utils::bf16 BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - 
add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vdpbf16ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vdpbf16ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni { - public: - static int constexpr RegLen = 16, PackRow = 4; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_VNNI; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_INT32; - typedef uint8_t AType; - typedef int8_t BType; - typedef int32_t CType; - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - private: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - - protected: - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - 
cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _kunroll) { - for (int kk = 0; kk < _kunroll; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class AvxvnniN8P4 : protected jblas::xbyak::JitAvxvnni { - public: - static int constexpr RegLen = 8, PackRow = 4; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX_VNNI; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_INT32; - typedef uint8_t AType; - typedef int8_t BType; - typedef int32_t CType; - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - private: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - protected: - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - 
add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _kunroll) { - for (int kk = 0; kk < _kunroll; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Amxbf16N16P2 : protected jblas::xbyak::JitAmxbf16 { - public: - static int constexpr RegLen = 16, PackRow = 2; - static_assert(_NTILE % RegLen == 0); - static_assert(_MTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 
1 : _MTILE / RegLen; - static_assert(NRegs * MRegs + 2 <= TileCount); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 32; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAMX_BF16; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_BF16_FP32; - typedef utils::bf16 AType; - typedef utils::bf16 BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - void* workspace; - }; - typedef long long (*func_t)(params*); - - int TmpRegCount = RegCount; - int TmpReg = 0; - int CTileCount = 0, ATileCount = 0, BTileCount = 0; - int CTile = 0, ATile = 0, BTile = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_tmp3; - Xbyak::Reg64 reg_ret = rax; - - void assign_regs() { - CTileCount = NRegs * MRegs; - auto tile_re = TileCount - CTileCount; - if (tile_re - 1 >= NRegs) { - BTileCount = NRegs; - ATileCount = tile_re - BTileCount; - } else if (tile_re - 1 >= MRegs) { - ATileCount = MRegs; - BTileCount = tile_re - ATileCount; - } else { - ATileCount = 1; - BTileCount = tile_re - ATileCount; - } - CTile = 0; - ATile = CTile + CTileCount; - BTile = ATile + ATileCount; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_tmp3 = st.t[10]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - 
generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int kunrll) { - auto& reg_Bstride = reg_tmp1; - mov(reg_Bstride, NTILE * 4); - int mtiles = _mtile / RegLen; - - for (int kk = 0; kk < kunrll; kk++) { - auto& reg_Atmp = reg_tmp2; - if (mtiles == 1) { - reg_Atmp = reg_matAptr; - } else { - mov(reg_Atmp, reg_matAptr); - } - if (BTileCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - } - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - for (int i = 0; i < NRegs; i++) { - tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i)); - } - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - } else { - if (ATileCount == mtiles) { - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - for (int mm = 0; mm < mtiles; mm++) { - tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile)); - } - } - } else { - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile)); - } - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - } - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < CTileCount; i++) { - tilezero(Xbyak::Tmm(CTile + i)); - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - int mtnum = _mtile / 16; - for (int mm = 0; mm < mtnum; mm++) { - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(CTile + mm * NRegs + i), ptr[reg_matCptr + reg_cstride + i * 64]); - } - if (mm != mtnum - 1) { - lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); - lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); - } - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_tmp, dword[parambase + OFFSET(workspace)]); - mov(reg_tmp1, NTILE * 4); - for (int mm = 0; mm < MRegs; mm++) { - for (int i = 0; i < NRegs; i++) { - tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * NRegs + i)); - } - } - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - int zunroll = TmpRegCount / NRegs; - for (int i = 0; i < _mtile; i += zunroll) { - int 
m_re = utils::remainsize(i, _mtile, zunroll); - for (int im = 0; im < m_re; im++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(TmpReg + im * NRegs + j), ptr[reg_tmp + j * 64 + (i + im) * NTILE * 4]); - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(TmpReg + im * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - } - outLocalLabel(); - } -}; - -template -class Amxint8N16P4 : protected jblas::xbyak::JitAmxint8 { - public: - static int constexpr RegLen = 16, PackRow = 4; - static_assert(_NTILE % RegLen == 0); - static_assert(_MTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 1 : _MTILE / RegLen; - static_assert(NRegs * MRegs + 2 <= TileCount); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 64; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAMX_INT8; - static uint32_t constexpr COMPUTE = - (uint32_t)(std::is_same_v - ? std::is_same_v ? CompType::COMP_INT8_SS_INT32 : CompType::COMP_INT8_SU_INT32 - : std::is_same_v ? CompType::COMP_INT8_US_INT32 - : CompType::COMP_INT8_UU_INT32); - using AType = AT; - using BType = BT; - typedef int32_t CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - void* workspace; - }; - typedef long long (*func_t)(params*); - - int TmpRegCount = RegCount; - int TmpReg = 0; - int CTileCount = 0, ATileCount = 0, BTileCount = 0; - int CTile = 0, ATile = 0, BTile = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_tmp3; - Xbyak::Reg64 reg_ret = rax; - - void assign_regs() { - CTileCount = NRegs * MRegs; - auto tile_re = TileCount - CTileCount; - if (tile_re - 1 >= NRegs) { - BTileCount = NRegs; - ATileCount = tile_re - BTileCount; - } else if (tile_re - 1 >= MRegs) { - ATileCount = MRegs; - BTileCount = tile_re - ATileCount; - } else { - ATileCount = 1; - BTileCount = tile_re - ATileCount; - } - CTile = 0; - ATile = CTile + CTileCount; - BTile = ATile + ATileCount; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_tmp3 = st.t[10]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, 
ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int kunrll) { - auto& reg_Bstride = reg_tmp1; - mov(reg_Bstride, NTILE * 4); - int mtiles = _mtile / RegLen; - - for (int kk = 0; kk < kunrll; kk++) { - auto& reg_Atmp = reg_tmp2; - if (mtiles == 1) { - reg_Atmp = reg_matAptr; - } else { - mov(reg_Atmp, reg_matAptr); - } - if (BTileCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - } - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - for (int i = 0; i < NRegs; i++) { - _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i)); - } - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - } else { - if (ATileCount == mtiles) { - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - for (int mm = 0; mm < mtiles; mm++) { - _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile)); - } - } - } else { - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile)); - } - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - } - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < CTileCount; i++) { - tilezero(Xbyak::Tmm(CTile + i)); - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - int mtnum = _mtile / 16; - for (int mm = 0; mm < mtnum; mm++) { - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(CTile + mm * NRegs + i), 
ptr[reg_matCptr + reg_cstride + i * 64]); - } - if (mm != mtnum - 1) { - lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); - lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); - } - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_tmp, dword[parambase + OFFSET(workspace)]); - mov(reg_tmp1, NTILE * 4); - for (int mm = 0; mm < MRegs; mm++) { - for (int i = 0; i < NRegs; i++) { - tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * NRegs + i)); - } - } - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - int zunroll = TmpRegCount / NRegs; - for (int i = 0; i < _mtile; i += zunroll) { - int m_re = utils::remainsize(i, _mtile, zunroll); - for (int im = 0; im < m_re; im++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(TmpReg + im * NRegs + j), ptr[reg_tmp + j * 64 + (i + im) * NTILE * 4]); - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(TmpReg + im * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - } - outLocalLabel(); - } -}; -template -using Amxint8N16P4US = Amxint8N16P4; - -template -using Amxint8N16P4SS = Amxint8N16P4; - -class AmxConfigure : protected jblas::xbyak::JitAmxtile { - public: - typedef long long (*func_t)(tileconfig_t*); - - static void configure(int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum, int CNum) { - static AmxConfigure code; - tileconfig_t cfg; - std::memset(&cfg, 0, sizeof(cfg)); - configure_tiles(cfg, TILE_M, TILE_N, TILE_K, elesize, ANum, BNum, CNum); - code.mKernel(&cfg); - } - - protected: - AmxConfigure() { - generate_config(this); - mKernel = getCode(); - } - - func_t mKernel = nullptr; -}; - -namespace kblock { -// optimize for kblock gemm, each block size in k dimension has dequant operation -// all accumulators use fp32 dtype. -template -class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f { - public: - static int constexpr RegLen = 16, PackRow = 1; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512F; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; - typedef float AType; - typedef float BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * 
BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni { - public: - static int constexpr RegLen = 16, PackRow = 4; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1 - NRegs) / (NRegs * 2) : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_VNNI; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_FP32; - typedef uint8_t AType; - typedef int8_t BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - uint8_t* zpA; - float* scaleA; - int ldsa; - float* scaleB; - float* reduceB; - int ldsb; - int k; - int n; - int kblock; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, CF32Reg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_iterkb; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_tmp3; - Xbyak::Reg64 reg_tmp4; - Xbyak::Reg64 reg_ret = rax; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = NRegs; - CReg = 0; - CF32Reg = CReg + CRegCount; - BReg = CF32Reg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg < RegCount); - TmpRegCount = RegCount - TmpReg; - assert(TmpRegCount >= 1); - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 13, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_iterkb = st.t[12]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_tmp3 = st.t[10]; - reg_tmp4 = st.t[11]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - xor_(reg_iterkb, reg_iterkb); - L(".kloop"); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vpxorq(Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j)); - } - } - xor_(reg_tmp2, reg_tmp2); - load32(reg_tmp3, ptr[parambase + OFFSET(kblock)]); - mov(reg_tmp, 
reg_tmp3); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kbloop", T_NEAR); - L(".unkbloop"); - generate_fma(_mtile, KUNROLL, reg_tmp1); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_tmp2, KUNROLL * KTILE); - cmp(reg_tmp2, reg_tmp); - jb(".unkbloop"); - cmp(reg_tmp, reg_tmp3); - jge(".kend", T_NEAR); - L(".kbloop"); - generate_fma(_mtile, 1, reg_tmp1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_tmp2, 1 * KTILE); - cmp(reg_tmp2, reg_tmp3); - jb(".kbloop"); - L(".kend"); - add(reg_iterk, reg_tmp2); - generate_f32_accumulate(_mtile); - generate_zp_correction(_mtile); - inc(reg_iterkb); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile, Xbyak::Reg64& tmp) { - for (int kk = 0; kk < _ktile; kk++) { - lea(tmp, ptr[reg_matAptr + kk * AKStepSize]); - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CF32Reg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void generate_f32_accumulate(int _mtile) { - load32(reg_tmp, ptr[parambase + OFFSET(ldsb)]); - imul(reg_tmp, reg_iterkb); - mov(reg_tmp2, ptr[parambase + OFFSET(scaleB)]); - lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp * sizeof(float)]); - lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); - - mov(reg_tmp, ptr[parambase + OFFSET(scaleA)]); - lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(float)]); - load32(reg_tmp1, ptr[parambase + OFFSET(ldsa)]); - for (int i = 0; i < NRegs; i++) { - vmovups(Xbyak::Zmm(BReg + i), ptr[reg_tmp2 + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vbroadcastss(Xbyak::Zmm(TmpReg), ptr[reg_tmp]); - lea(reg_tmp, ptr[reg_tmp + reg_tmp1 * sizeof(float)]); - for (int i = 0; i < NRegs; i++) { - vcvtdq2ps(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(CReg + mm * NRegs + i)); - vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(TmpReg), Xbyak::Zmm(BReg + i)); - vmulps(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(AReg)); - vaddps(Xbyak::Zmm(CF32Reg + mm * NRegs + i), Xbyak::Zmm(CReg + mm * NRegs + i)); - } - } - } - - void generate_zp_correction(int _mtile) { - load32(reg_tmp1, ptr[parambase + OFFSET(ldsb)]); - imul(reg_tmp1, reg_iterkb); - mov(reg_tmp2, ptr[parambase + OFFSET(reduceB)]); - lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp1 * sizeof(float)]); - lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); - auto& reg_redB = reg_tmp2; - - mov(reg_tmp, ptr[parambase + OFFSET(zpA)]); - lea(reg_tmp, ptr[reg_tmp + reg_iterkb * 
sizeof(AType)]); - auto& reg_zpA = reg_tmp; - - mov(reg_tmp1, ptr[parambase + OFFSET(scaleA)]); - lea(reg_tmp1, ptr[reg_tmp1 + reg_iterkb * sizeof(float)]); - auto& reg_scaleA = reg_tmp1; - - load32(reg_tmp3, ptr[parambase + OFFSET(ldsa)]); - auto& reg_ldsa = reg_tmp3; - for (int i = 0; i < NRegs; i++) { - vmovups(Xbyak::Zmm(BReg + i), ptr[reg_redB + i * VecBytes]); - } - - for (int i = 0; i < _mtile; i++) { - vpbroadcastb(Xbyak::Xmm(AReg), ptr[reg_zpA]); - vpmovzxbd(Xbyak::Zmm(AReg), Xbyak::Xmm(AReg)); - vcvtdq2ps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg)); - vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg), zword_b[reg_scaleA]); - for (int j = 0; j < NRegs; j++) { - vmulps(Xbyak::Zmm(CReg + j), Xbyak::Zmm(AReg), Xbyak::Zmm(BReg + j)); - vsubps(Xbyak::Zmm(CF32Reg + i * NRegs + j), Xbyak::Zmm(CReg + j)); - } - lea(reg_zpA, ptr[reg_zpA + reg_ldsa * sizeof(AType)]); - lea(reg_scaleA, ptr[reg_scaleA + reg_ldsa * sizeof(float)]); - } - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CF32Reg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -} // namespace kblock -} // namespace code -template