diff --git a/.pipelines/windowsai-steps.yml b/.pipelines/windowsai-steps.yml index 45ebf889c5da1..292ce60c6b6cf 100644 --- a/.pipelines/windowsai-steps.yml +++ b/.pipelines/windowsai-steps.yml @@ -84,7 +84,7 @@ jobs: 7z x cmake-3.26.3-windows-x86_64.zip set PYTHONHOME=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools set PYTHONPATH=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools - $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos $(BuildFlags) --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.19041.0 --cmake_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\ctest.exe + $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" CMAKE_SYSTEM_VERSION=10.0.19041.0 --cmake_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\ctest.exe workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Generate cmake config' diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 7c5cfee61116f..7494035e4784e 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1258,13 +1258,7 @@ if (onnxruntime_USE_OPENVINO) endif() # Check OpenVINO version for support - if (${VER} MATCHES "2022.1" OR $ENV{INTEL_OPENVINO_DIR} MATCHES "2022.1") - set(OPENVINO_VERSION "2022.1") - add_definitions(-DOPENVINO_2022_1=1) - elseif (${VER} MATCHES "2022.2" OR $ENV{INTEL_OPENVINO_DIR} MATCHES "2022.2") - set(OPENVINO_VERSION "2022.2") - add_definitions(-DOPENVINO_2022_2=1) - elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "2022.3") + if ($ENV{INTEL_OPENVINO_DIR} MATCHES "2022.3") set(OPENVINO_VERSION "2022.3") add_definitions(-DOPENVINO_2022_3=1) elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "2023.0") @@ -1273,9 +1267,12 @@ if (onnxruntime_USE_OPENVINO) elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "2023.1") set(OPENVINO_VERSION "2023.1") add_definitions(-DOPENVINO_2023_1=1) - elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "openvino") - set(OPENVINO_VERSION "2023.1") + elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "2023.2") + set(OPENVINO_VERSION "2023.2") add_definitions(-DOPENVINO_2023_1=1) + elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "openvino") + set(OPENVINO_VERSION "2023.2") + add_definitions(-DOPENVINO_2023_2=1) else() message(FATAL_ERROR "Unsupported OpenVINO version: ${INTEL_OPENVINO_DIR}") endif() diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index b93ccf77d52a2..61922961588b2 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -453,6 +453,9 @@ file(GLOB onnxruntime_python_quantization_operators_src CONFIGURE_DEPENDS file(GLOB onnxruntime_python_quantization_cal_table_flatbuffers_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/quantization/CalTableFlatBuffers/*.py" ) +file(GLOB onnxruntime_python_quantization_fusions_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/quantization/fusions/*.py" +) file(GLOB onnxruntime_python_quantization_ep_qnn_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/quantization/execution_providers/qnn/*.py" ) @@ -550,6 +553,7 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/operators COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/CalTableFlatBuffers + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/fusions COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/execution_providers COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/execution_providers/qnn COMMAND ${CMAKE_COMMAND} -E make_directory $/quantization @@ -622,6 +626,9 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_quantization_cal_table_flatbuffers_src} $/onnxruntime/quantization/CalTableFlatBuffers/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_quantization_fusions_src} + $/onnxruntime/quantization/fusions/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_quantization_ep_qnn_src} $/onnxruntime/quantization/execution_providers/qnn/ diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj index 0c74a23204d4f..1d15383239baf 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj @@ -6,7 +6,7 @@ true - netstandard2.0 + netstandard2.0;netcoreapp3.1;net6.0 diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 201c9d4b209db..8e1ec782079be 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -23,7 +23,7 @@ import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi import {pad, parsePadAttributes} from './ops/pad'; import * as pool from './ops/pool'; import {range} from './ops/range'; -import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; +import {reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; import {parseResizeAttributes, resize} from './ops/resize'; import {parseSkipLayerNormAttributes, skipLayerNorm} from './ops/skip-layer-norm'; import {parseSliceAttributes, slice} from './ops/slice'; @@ -99,16 +99,16 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Pow', [binaryOps.pow]], ['Range', [range]], ['Reciprocal', [unaryOps.reciprocal]], - ['ReduceMin', [reduceMin, parseReduceAttributes]], - ['ReduceMean', [reduceMean, parseReduceAttributes]], - ['ReduceMax', [reduceMax, parseReduceAttributes]], - ['ReduceSum', [reduceSum, parseReduceAttributes]], - ['ReduceProd', [reduceProd, parseReduceAttributes]], - ['ReduceL1', [reduceL1, parseReduceAttributes]], - ['ReduceL2', [reduceL2, parseReduceAttributes]], - ['ReduceLogSum', [reduceLogSum, parseReduceAttributes]], - ['ReduceLogSumExp', [reduceLogSumExp, parseReduceAttributes]], - ['ReduceSumSquare', [reduceSumSquare, parseReduceAttributes]], + ['ReduceMin', [reduceMin]], + ['ReduceMean', [reduceMean]], + ['ReduceMax', [reduceMax]], + ['ReduceSum', [reduceSum]], + ['ReduceProd', [reduceProd]], + ['ReduceL1', [reduceL1]], + ['ReduceL2', [reduceL2]], + ['ReduceLogSum', [reduceLogSum]], + ['ReduceLogSumExp', [reduceLogSumExp]], + ['ReduceSumSquare', [reduceSumSquare]], ['Relu', [unaryOps.relu]], ['Resize', [resize, parseResizeAttributes]], ['Sigmoid', [unaryOps.sigmoid]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index a8f296ea0c865..47ec16a296712 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -510,11 +510,7 @@ export const createMatmulProgramInfo = name: 'MatMul', shaderCache: { hint: activationAttributes.activationCacheKey + `${elementsPerThread}` + - `${activationAttributes.activation}` + - `${activationAttributes.clipMax}` + - `${activationAttributes.clipMin}` + `${isVec4}` + - `${hasBias}` + `${isChannelsLast}`, inputDependencies }, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts index b6c6853c8f222..1f27525f370f3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts @@ -33,23 +33,23 @@ export const argMin = (context: ComputeContext, attributes: ArgMinMaxAttributes) const idxZero = []; for (let k = 0; k < input.rank; k++) { if (axes.indexOf(k) >= 0 || axes.length === 0) { - idxZero.push(`inputIndices[${k}] = 0;`); // first element + idxZero.push(`input_indices[${k}] = 0;`); // first element } } return [ - `${idxZero.join('\n')}`, `var value = ${input.getByOffset('inputOffset')};\nvar bestIndex : i32 = 0;`, - `if (${input.getByOffset('inputOffset')} ${attributes.selectLastIndex > 0 ? '<=' : '<'} value) { - value = ${input.getByOffset('inputOffset')}; - bestIndex = i32(lastIndex); + `${idxZero.join('\n')}`, `var value = ${input.getByIndices('input_indices')};\nvar best_index : i32 = 0;`, + `if (${input.getByIndices('input_indices')} ${attributes.selectLastIndex > 0 ? '<=' : '<'} value) { + value = ${input.getByIndices('input_indices')}; + best_index = i32(last_index); }`, - '', output.setByOffset('global_idx', 'bestIndex') + '', output.setByOffset('global_idx', 'best_index') ]; }; context.compute( createReduceProgramInfo( - 'ArgMin', {hint: attributes.cacheKey}, [context.inputs[0]], argMinMaxOp, [attributes.axis], DataType.int64, - attributes.keepDims), + 'ArgMin', {hint: attributes.cacheKey, inputDependencies: ['rank']}, [context.inputs[0]], argMinMaxOp, + [attributes.axis], DataType.int64, attributes.keepDims), {inputs: [0]}); }; @@ -59,23 +59,23 @@ export const argMax = (context: ComputeContext, attributes: ArgMinMaxAttributes) const idxZero = []; for (let k = 0; k < input.rank; k++) { if (axes.indexOf(k) >= 0 || axes.length === 0) { - idxZero.push(`inputIndices[${k}] = 0;`); // first element + idxZero.push(`input_indices[${k}] = 0;`); // first element } } return [ - `${idxZero.join('\n')}`, `var value = ${input.getByOffset('inputOffset')};\nvar bestIndex : i32 = 0;`, - `if (${input.getByOffset('inputOffset')} ${attributes.selectLastIndex > 0 ? '>=' : '>'} value) { - value = ${input.getByOffset('inputOffset')}; - bestIndex = i32(lastIndex); + `${idxZero.join('\n')}`, `var value = ${input.getByIndices('input_indices')};\nvar best_index : i32 = 0;`, + `if (${input.getByIndices('input_indices')} ${attributes.selectLastIndex > 0 ? '>=' : '>'} value) { + value = ${input.getByIndices('input_indices')}; + best_index = i32(last_index); }`, - '', output.setByOffset('global_idx', 'bestIndex') + '', output.setByOffset('global_idx', 'best_index') ]; }; context.compute( createReduceProgramInfo( - 'argMax', {hint: attributes.cacheKey}, [context.inputs[0]], argMinMaxOp, [attributes.axis], DataType.int64, - attributes.keepDims), + 'argMax', {hint: attributes.cacheKey, inputDependencies: ['rank']}, [context.inputs[0]], argMinMaxOp, + [attributes.axis], DataType.int64, attributes.keepDims), {inputs: [0]}); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index c7ea0cffe51c3..33a5db7ff6b25 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -10,6 +10,7 @@ import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; import {createGroupedConvProgramInfo} from './conv-grouped'; import {InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils'; +import {createNaiveMatmulProgramInfo} from './matmul'; import {createTransposeProgramInfo} from './transpose'; export const calculateOutputShape = @@ -195,9 +196,19 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut if (hasBias) { matmulInputs.push(inputs[2]); } - context.compute( - createMatmulProgramInfo(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast), - {inputs: matmulInputs}); + const N = matmulOutputShape[2]; + const K = matmulInputs[0].dims[matmulInputs[0].dims.length - 1]; + // Tune the threshold. + if (N < 8 && K < 8) { + context.compute( + createNaiveMatmulProgramInfo( + matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast), + {inputs: matmulInputs}); + } else { + context.compute( + createMatmulProgramInfo(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast), + {inputs: matmulInputs}); + } return; } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts index 85682f0b47220..2ff909c30e62e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts @@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, getElementAt, inputVariable, outputVariable, ShaderHelper} from './common'; export interface CumSumAttributes extends AttributeWithCacheKey { @@ -26,7 +26,7 @@ const createCumsumProgramInfo = const axis = ShapeUtil.normalizeAxis(axisValue, rank); const getShaderSource = (shaderHelper: ShaderHelper) => { const index = ` i32(${input.indicesGet('inputIndices', 'uniforms.axis')}) `; - const max = rank === 1 ? 'i32(uniforms.input_shape)' : 'i32(uniforms.input_shape[uniforms.axis])'; + const max = getElementAt('uniforms.input_shape', 'uniforms.axis', rank); const lowerLimit = attributes.reverse ? index + (attributes.exclusive ? ' + 1' : '') : '0'; const upperLimit = attributes.reverse ? max : index + (attributes.exclusive ? '' : ' + 1'); return ` diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index 19ca4ac5358ae..de9309d1e436f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -2,10 +2,150 @@ // Licensed under the MIT License. import {TensorView} from '../../tensor-view'; -import {BroadcastUtil} from '../../util'; -import {ComputeContext} from '../types'; +import {BroadcastUtil, ShapeUtil} from '../../util'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; +import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper,} from './common'; +import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; + +export const createNaiveMatmulProgramInfo = + (inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, outputShape: readonly number[], + reshapedOutputShape?: readonly number[], + isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => { + const aShape = inputs[0].dims; + const bShape = inputs[1].dims; + + const M = aShape[aShape.length - 2]; + const N = bShape[bShape.length - 1]; + const K = aShape[aShape.length - 1]; + const components = getMaxComponents(N); + const aComponents = getMaxComponents(K); + const outputNumber = getMaxComponents(M); + const outputSize = ShapeUtil.size(outputShape) / components / outputNumber; + const hasBias = inputs.length > 2; + const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); + const batchSize = ShapeUtil.size(outerDims); + const outputShapeInShader = [batchSize, M, N]; + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, + {type: 'uint32', data: K}, ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape), + ...createTensorShapeVariables(bShape) + ]; + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + } + programUniforms.push(...createTensorShapeVariables(outputShapeInShader)); + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const batchDims = internalVariable('batch_dims', inputs[0].dataType, outerDims.length); + const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents); + const b = inputVariable('b', inputs[1].dataType, bShape.length, components); + const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); + const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); + const inputVariables = [a, b]; + let processBias = ''; + if (hasBias) { + const biasComponents = isChannelsLast ? components : 1; + inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); + processBias = `${ + isChannelsLast ? `value += bias[col / ${biasComponents}];` : + `value += ${output.type.value}(bias[row + i]);`}`; + } + + const outerDimsA = aShape.slice(0, -2); + const outerDimsB = bShape.slice(0, -2); + const broadCastADims = getBroadcastDims(outerDimsA, outerDims); + const broadCastBDims = getBroadcastDims(outerDimsB, outerDims); + const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => { + const rank = variable.rank; + const name = variable.name; + if (rank === 2) { + return `var ${name}_indices = ${variable.type.indices}(0u, 0u);`; + } + const batchRank = batchDims.rank; + let resStr = `var ${name}_indices: ${variable.type.indices};`; + for (let i = rank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) { + resStr += `\n${name}_indices[${i}] = ${batchRank > 1 ? `batch_indices[${j}]` : 'batch_indices'};`; + } + broadCastDims.forEach(i => { + resStr += `\n${name}_indices[${i}] = 0;`; + }); + resStr += `${name}_indices[${rank - 2}] = 0u; + ${name}_indices[${rank - 1}] = 0u;`; + return resStr; + }; + + const calcResult = (): string => { + let calcStr = `var a_data: ${a.type.value};`; + for (let i = 0; i < aComponents; i++) { + calcStr += ` + let b_data${i} = b[(b_offset + (k + ${i}) * uniforms.N + col) / ${components}];`; + } + for (let i = 0; i < outputNumber; i++) { + calcStr += `a_data = a[(a_offset + (row + ${i}) * uniforms.K + k) / ${aComponents}];`; + + for (let j = 0; j < aComponents; j++) { + calcStr += ` + values[${i}] = fma(${b.type.value}(a_data${aComponents === 1 ? '' : `[${j}]`}), b_data${j}, values[${ + i}]);\n`; + } + } + return calcStr; + }; + + return ` + ${ + shaderHelper.registerUniform('outputSize', 'u32') + .registerUniform('M', 'u32') + .registerUniform('N', 'u32') + .registerUniform('K', 'u32') + .registerInternalVariables(batchDims) + .declareVariables(...inputVariables, output)} + ${activationFunction} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} + let col = (global_idx % (uniforms.N / ${components})) * ${components}; + var index1 = global_idx / (uniforms.N / ${components}); + let stride1 = uniforms.M / ${outputNumber}; + let row = (index1 % stride1) * ${outputNumber}; + let batch = index1 / stride1; + + ${outputShape.length === 2 ? '' : `let batch_indices = ${batchDims.offsetToIndices('batch')};`} + ${getIndices(a, broadCastADims)} + let a_offset = ${a.indicesToOffset('a_indices')}; + ${getIndices(b, broadCastBDims)} + let b_offset = ${b.indicesToOffset('b_indices')}; + var values: array<${output.type.value}, ${outputNumber}>; + for (var k: u32 = 0u; k < uniforms.K; k = k + ${aComponents}) { + ${calcResult()} + } + for (var i = 0u; i < ${outputNumber}u; i++) { + var value = values[i]; + ${processBias} + ${applyActivation} + let cur_indices = ${output.type.indices}(batch, row + i, col); + let offset = ${output.indicesToOffset('cur_indices')}; + ${output.setByOffset(`offset / ${components}`, 'value')}; + } + } + `; + }; + return { + name: 'MatMulNaive', + shaderCache: { + hint: `${activationAttributes.activationCacheKey}_${components}_${aComponents}_${outputNumber}_${ + isChannelsLast}`, + inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'] + }, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms + }), + getShaderSource + }; + }; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 2) { @@ -23,5 +163,12 @@ export const matMul = (context: ComputeContext): void => { if (!outputShape) { throw new Error('Can\'t use matmul on the given tensors'); } - context.compute(createMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape)); + const N = outputShape[outputShape.length - 1]; + const K = context.inputs[0].dims[context.inputs[0].dims.length - 1]; + if (N < 8 && K < 8) { + context.compute( + createNaiveMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape)); + } else { + context.compute(createMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape)); + } }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts index b5c956e57a9b1..e8851ac546942 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts @@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo, ProgramShaderCacheInfo} from '../types'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; import {reduceL1Shared, reduceL2Shared, reduceLogSumExpShared, reduceLogSumShared, reduceMaxShared, reduceMeanShared, reduceMinShared, reduceProdShared, reduceSumShared, reduceSumSquareShared} from './reduce-shared'; const validateInputs = (inputs: readonly TensorView[]): void => { @@ -30,14 +30,14 @@ export type ReduceOp = (input: IndicesHelper, output: IndicesHelper, axes: readonly number[]) => [string, string, string, string, ...string[]]; -const noOp: ReduceOp = (input) => ['', '', `var value = ${input.getByOffset('inputOffset')};`, '']; +const noOp: ReduceOp = (input) => ['', '', `var value = ${input.getByIndices('input_indices')};`, '']; export const createReduceProgramInfo = (name: string, shaderCache: ProgramShaderCacheInfo, inputs: readonly TensorView[], reduceOp: ReduceOp, axesInput: number[], outputDataType: DataType, keepDims = false, noopWithEmptyAxes = false): ProgramInfo => { const outputShape: number[] = []; const inputShape = inputs[0].dims; - - const axes = ShapeUtil.normalizeAxes(axesInput, inputs[0].dims.length); + const inputRank = inputShape.length; + const axes = ShapeUtil.normalizeAxes(axesInput, inputRank); const reduceOnAllAxes = !noopWithEmptyAxes && axes.length === 0; inputShape.forEach((d, i) => { if (reduceOnAllAxes || axes.indexOf(i) >= 0) { @@ -48,53 +48,50 @@ export const createReduceProgramInfo = outputShape.push(d); } }); - - const idxCopy: string[] = []; // copy output indexes to input indexes - - const input = inputVariable('_A', inputs[0].dataType, inputShape); - const output = outputVariable('output', outputDataType, outputShape); - const ops = reduceOp(input, output, axes); - const inputOffsetAssignment = `inputOffset = ${input.indicesToOffset('inputIndices')};`; - const initinputOffsetLet = `let ${inputOffsetAssignment};`; - const initinputOffsetVar = `var ${inputOffsetAssignment};`; - const initinputOffset = (ops[1] === '') ? '' : initinputOffsetVar; - let reduceOps = ((ops[1] === '') ? initinputOffsetLet : inputOffsetAssignment) + '\n' + ops[2]; - - for (let k = 0, l = 0; k < inputs[0].dims.length; k++) { - // if this axis is reduced - if (reduceOnAllAxes || axes.indexOf(k) >= 0) { - if (keepDims) { + const outputRank = outputShape.length; + const outputSize = ShapeUtil.size(outputShape); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const idxCopy: string[] = []; // copy output indexes to input indexes + + const input = inputVariable('_A', inputs[0].dataType, inputRank); + const output = outputVariable('output', outputDataType, outputRank); + const ops = reduceOp(input, output, axes); + let reduceOps = ops[2]; + + for (let k = 0, l = 0; k < inputRank; k++) { + // if this axis is reduced + if (reduceOnAllAxes || axes.indexOf(k) >= 0) { + if (keepDims) { + l++; + } + // loop over the d-th axis + reduceOps = `for(var j${k}: u32 = 0; j${k} < ${inputShape[k]}; j${k}++) { + ${ops[2].includes('last_index') ? `let last_index = j${k};` : ''} + ${input.indicesSet('input_indices', k, `j${k}`)} + ${reduceOps} + }`; + } else { + idxCopy.push(`${input.indicesSet('input_indices', k, output.indicesGet('output_indices', l))};`); l++; } - // loop over the d-th axis - reduceOps = `for(var j${k}: u32 = 0; j${k} < ${inputs[0].dims[k]}; j${k}++) { - ${ops[2].includes('lastIndex') ? `let lastIndex = j${k};` : ''} - ${input.indicesSet('inputIndices', k, `j${k}`)} - ${reduceOps} - }`; - } else { - idxCopy.push(`${input.indicesSet('inputIndices', k, output.indicesGet('outputIndices', l))};`); - l++; } - } + return ` - const outputSize = ShapeUtil.size(outputShape); - const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(input, output)} + ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - var inputIndices: ${input.type.indices}; - let outputIndices = ${output.offsetToIndices('global_idx')}; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + var input_indices: ${input.type.indices}; + let output_indices = ${output.offsetToIndices('global_idx')}; ${idxCopy.join('\n')} ${ops[0]} // init ops for reduce max/min - ${initinputOffset} ${ops[1]} ${reduceOps} ${ops[3]} ${ops.length === 4 ? output.setByOffset('global_idx', 'value') : ops.slice(4).join('\n')} }`; + }; return { name, @@ -102,7 +99,11 @@ export const createReduceProgramInfo = getShaderSource, getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms: [ + {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputShape), + ...createTensorShapeVariables(outputShape) + ] }), }; }; @@ -125,7 +126,7 @@ const runReduceProgram = context.compute( createReduceProgramInfo( - name, {hint: updatedAttributes.cacheKey}, [inputs[0]], + name, {hint: updatedAttributes.cacheKey, inputDependencies: ['rank']}, [inputs[0]], updatedAttributes.noopWithEmptyAxes && updatedAttributes.axes.length === 0 ? noOp : reduceOp, updatedAttributes.axes, inputs[0].dataType, updatedAttributes.keepDims, updatedAttributes.noopWithEmptyAxes), @@ -137,7 +138,7 @@ const reduceLogSumNaive = (context: ComputeContext, attributes: ReduceAttributes const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, '', - `value += ${input.getByOffset('inputOffset')};`, + `value += ${input.getByIndices('input_indices')};`, 'value = log(value);', ]; runReduceProgram(context, 'ReduceLogSum', attributes, reduceOp); @@ -148,7 +149,7 @@ const reduceL1Naive = (context: ComputeContext, attributes: ReduceAttributes): v const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, '', - `value += abs(${input.getByOffset('inputOffset')});`, + `value += abs(${input.getByIndices('input_indices')});`, '', ]; runReduceProgram(context, 'ReduceL1', attributes, reduceOp); @@ -159,7 +160,7 @@ const reduceL2Naive = (context: ComputeContext, attributes: ReduceAttributes): v const reduceOp: ReduceOp = (input, output) => [`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`, '', - `t = ${input.getByOffset('inputOffset')}; value += (t * t);`, + `t = ${input.getByIndices('input_indices')}; value += (t * t);`, 'value = sqrt(value);', ]; runReduceProgram(context, 'ReduceL2', attributes, reduceOp); @@ -170,7 +171,7 @@ const reduceLogSumExpNaive = (context: ComputeContext, attributes: ReduceAttribu const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, '', - `value += exp(${input.getByOffset('inputOffset')});`, + `value += exp(${input.getByIndices('input_indices')});`, 'value = log(value);', ]; runReduceProgram(context, 'ReduceLogSumExp', attributes, reduceOp); @@ -182,14 +183,14 @@ const reduceMaxNaive = (context: ComputeContext, attributes: ReduceAttributes): const idxZero = []; for (let k = 0; k < input.rank; k++) { if (axes.indexOf(k) >= 0 || axes.length === 0) { - idxZero.push(input.indicesSet('inputIndices', k, 0)); + idxZero.push(input.indicesSet('input_indices', k, 0)); } } return [ `${idxZero.join('\n')}`, - `var value = ${input.getByOffset('inputOffset')};`, - `value = max(value, ${input.getByOffset('inputOffset')});`, + `var value = ${input.getByIndices('input_indices')};`, + `value = max(value, ${input.getByIndices('input_indices')});`, '', ]; }; @@ -210,7 +211,7 @@ const reduceMeanNaive = (context: ComputeContext, attributes: ReduceAttributes): return [ 'var sum = f32(0);', '', - `sum += f32(${input.getByOffset('inputOffset')});`, + `sum += f32(${input.getByIndices('input_indices')});`, `let value = ${output.type.value}(sum / ${size});`, ]; }; @@ -223,14 +224,14 @@ const reduceMinNaive = (context: ComputeContext, attributes: ReduceAttributes): const idxZero = []; for (let k = 0; k < input.rank; k++) { if (axes.indexOf(k) >= 0 || axes.length === 0) { - idxZero.push(`inputIndices[${k}] = 0;`); // first element + idxZero.push(`input_indices[${k}] = 0;`); // first element } } return [ `${idxZero.join('\n')}`, - `var value = ${input.getByOffset('inputOffset')};`, - `value = min(value, ${input.getByOffset('inputOffset')});`, + `var value = ${input.getByIndices('input_indices')};`, + `value = min(value, ${input.getByIndices('input_indices')});`, '', ]; }; @@ -242,7 +243,7 @@ const reduceProdNaive = (context: ComputeContext, attributes: ReduceAttributes): const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(1);`, '', - `value *= ${input.getByOffset('inputOffset')};`, + `value *= ${input.getByIndices('input_indices')};`, '', ]; runReduceProgram(context, 'ReduceProd', attributes, reduceOp); @@ -253,7 +254,7 @@ const reduceSumNaive = (context: ComputeContext, attributes: ReduceAttributes): const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, '', - `value += ${input.getByOffset('inputOffset')};`, + `value += ${input.getByIndices('input_indices')};`, '', ]; runReduceProgram(context, 'ReduceSum', attributes, reduceOp); @@ -264,7 +265,7 @@ const reduceSumSquareNaive = (context: ComputeContext, attributes: ReduceAttribu const reduceOp: ReduceOp = (input, output) => [`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`, '', - `t = ${input.getByOffset('inputOffset')}; value += t * t;`, + `t = ${input.getByIndices('input_indices')}; value += t * t;`, '', ]; runReduceProgram(context, 'ReduceSumSquare', attributes, reduceOp); @@ -273,7 +274,7 @@ const reduceSumSquareNaive = (context: ComputeContext, attributes: ReduceAttribu const useNaiveReduceMethod = (shape: readonly number[], axes: readonly number[], noopWithEmptyAxes: boolean): boolean => { if (axes.length === 0) { - return noopWithEmptyAxes ? true : false; + return noopWithEmptyAxes; } let outputSize = 1; @@ -289,7 +290,7 @@ const useNaiveReduceMethod = // The condition data is very rough, although considering the count of Execution Unit (EU), the potential // work groups in a EU and the counts of loops in the naive and shared methods, also doing experiments // on some machines. - return reduceSize < 32 && outputSize > 1024 ? true : false; + return reduceSize < 32 && outputSize > 1024; }; export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes): void => { @@ -371,6 +372,3 @@ export const reduceLogSum = (context: ComputeContext, attributes: ReduceAttribut reduceLogSumShared(context, attributes); } }; - -export const parseReduceAttributes = (attributes: Record): ReduceAttributes => - createAttributeWithCacheKey(attributes as Omit); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index 973a607f9377e..e1369c2c2b43b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; type CoordinateTransformMode = 'half_pixel'|'asymmetric'|'pytorch_half_pixel'|'tf_half_pixel_for_nn'|'align_corners'| 'tf_crop_and_resize'|'half_pixel_symmetric'; @@ -245,69 +245,67 @@ const adjustOutputShape = (inputShape: readonly number[], scales: number[], attr }; const calculateOriginalIndicesFromOutputIndices = - (output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], scales: readonly number[], - roi: readonly number[]): string => ` - fn calculateOriginalIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> array<${ + (output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], scalesLength: number, + roiLength: number): string => ` + fn calculateOriginalIndicesFromOutputIndices(output_indices: ${output.type.indices}) -> array<${ output.type.value}, ${outputShape.length}> { - const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); - const outputShape = array(${outputShape.map(i => `${i}u`).join(',')}); - const scales = array<${output.type.value}, ${scales.length}>(${scales.map(i => `${i}f`).join(',')}); - const roi = array<${output.type.value}, ${roi.length}>(${roi.map(i => `${i}f`).join(',')}); - var originalIndices: array<${output.type.value}, ${outputShape.length}>; + var original_indices: array<${output.type.value}, ${outputShape.length}>; for (var i:u32 = 0; i < ${outputShape.length}; i++) { - var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'}; - if (scales[i] == 1.0) { - originalIndices[i] = ${output.type.value}(outputIndex); + var output_index = ${output.type.value}(${output.indicesGet('output_indices', 'i')}); + var scale = ${getElementAt('uniforms.scales', 'i', scalesLength)}; + var roi_low = ${getElementAt('uniforms.roi', 'i', roiLength)}; + var roi_hi = ${getElementAt('uniforms.roi', `i + ${inputShape.length}`, roiLength)}; + if (scale == 1.0) { + original_indices[i] = output_index; } else { - originalIndices[i] = getOriginalCoordinateFromResizedCoordinate(${output.type.value}(outputIndex), scales[i], - ${output.type.value}(outputShape[i]), ${output.type.value}(inputShape[i]), roi[i], roi[i + ${ - inputShape.length}]); + var input_shape_i = ${output.type.value}(${getElementAt('uniforms.input_shape', 'i', inputShape.length)}); + var output_shape_i = ${output.type.value}(${getElementAt('uniforms.output_shape', 'i', outputShape.length)}); + original_indices[i] = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i, + input_shape_i, roi_low, roi_hi); } } - return originalIndices; + return original_indices; }`; const calculateInputIndicesFromOutputIndices = (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], - scales: readonly number[], roi: readonly number[], useExtrapolation: boolean): string => ` - fn calculateInputIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} { - const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); - const outputShape = array(${outputShape.map(i => `${i}u`).join(',')}); - const scales = array<${input.type.value}, ${scales.length}>(${scales.map(i => `${i}`).join(',')}); - const roi = array<${input.type.value}, ${roi.length}>(${roi.map(i => `${i}`).join(',')}); - var inputIndices: ${input.type.indices}; - for (var i:u32 = 0; i < ${outputShape.length}; i++) { - var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'}; - var inputIndex: u32; - if (scales[i] == 1.0) { - inputIndex = outputIndex; - } else { - var original_idx = getOriginalCoordinateFromResizedCoordinate(${input.type.value}(outputIndex), scales[i], - ${input.type.value}(outputShape[i]), ${input.type.value}(inputShape[i]), roi[i], roi[i + ${ - inputShape.length}]); - if (!${useExtrapolation} || (original_idx >= 0 && original_idx < ${input.type.value}(inputShape[i]))) { - if (original_idx < 0) { - inputIndex = 0; - } else if (original_idx > (${input.type.value}(inputShape[i]) - 1)) { - inputIndex = inputShape[i] - 1; - } else { - inputIndex = u32(getNearestPixelFromOriginal(original_idx, scales[i] < 1)); - } + scalesLength: number, roiLength: number, useExtrapolation: boolean): string => ` + fn calculateInputIndicesFromOutputIndices(output_indices: ${output.type.indices}) -> ${input.type.indices} { + var input_indices: ${input.type.indices}; + for (var i:u32 = 0; i < ${outputShape.length}; i++) { + var output_index = ${output.type.value}(${output.indicesGet('output_indices', 'i')}); + var input_index: u32; + var scale = ${getElementAt('uniforms.scales', 'i', scalesLength)}; + if (scale == 1.0) { + input_index = u32(output_index); + } else { + var roi_low = ${getElementAt('uniforms.roi', 'i', roiLength)}; + var roi_hi = ${getElementAt('uniforms.roi', `i + ${inputShape.length}`, roiLength)}; + var input_shape_i = ${output.type.value}(${getElementAt('uniforms.input_shape', 'i', inputShape.length)}); + var output_shape_i = ${output.type.value}(${getElementAt('uniforms.output_shape', 'i', outputShape.length)}); + var original_idx = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i, + input_shape_i, roi_low, roi_hi); + if (!${useExtrapolation} || (original_idx >= 0 && original_idx < input_shape_i)) { + if (original_idx < 0) { + input_index = 0; + } else if (original_idx > (input_shape_i - 1)) { + input_index = u32(input_shape_i) - 1; } else { - inputIndex = u32(original_idx); + input_index = u32(getNearestPixelFromOriginal(original_idx, scale < 1)); } + } else { + input_index = u32(original_idx); } - ${input.indicesSet('inputIndices', 'i', 'inputIndex')} } - return inputIndices; + ${input.indicesSet('input_indices', 'i', ' input_index')} + } + return input_indices; }`; - const checkInputIndices = (input: IndicesHelper, inputShape: readonly number[]): string => ` - fn checkInputIndices(inputIndices: ${input.type.indices}) -> bool { - const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); + fn checkInputIndices(input_indices: ${input.type.indices}) -> bool { for (var i:u32 = 0; i < ${inputShape.length}; i++) { - var inputIndex = ${inputShape.length === 1 ? 'inputIndices' : 'inputIndices[i]'}; - if (inputIndex < 0 || inputIndex >= inputShape[i]) { + var input_index = ${input.indicesGet('input_indices', 'i')}; + if (input_index < 0 || input_index >= ${getElementAt('uniforms.input_shape', 'i', inputShape.length)}) { return false; } } @@ -322,18 +320,18 @@ const bilinearInterpolation = const dType = input.type.value; return ` fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> ${dType} { - var inputIndices: ${input.type.indices}; - inputIndices[${heightIdx}] = max(0, min(row, ${inputShape[heightIdx]} - 1)); - inputIndices[${widthIdx}] = max(0, min(col, ${inputShape[widthIdx]} - 1)); + var input_indices: ${input.type.indices}; + ${input.indicesSet('input_indices', heightIdx, `max(0, min(row, ${inputShape[heightIdx]} - 1))`)}; + ${input.indicesSet('input_indices', widthIdx, `max(0, min(col, ${inputShape[widthIdx]} - 1))`)}; if (${inputShape.length} > 2) { - inputIndices[${channelIdx}] = channel; - inputIndices[${batchIdx}] = batch; + ${input.indicesSet('input_indices', channelIdx, 'channel')}; + ${input.indicesSet('input_indices', batchIdx, 'batch')}; }; - return input[${input.indicesToOffset('inputIndices')}]; + return ${input.getByIndices('input_indices')}; } - fn bilinearInterpolation(outputIndices: ${output.type.indices}) -> ${dType} { - var originalIndices = calculateOriginalIndicesFromOutputIndices(outputIndices); + fn bilinearInterpolation(output_indices: ${output.type.indices}) -> ${dType} { + var originalIndices = calculateOriginalIndicesFromOutputIndices(output_indices); var row:${dType} = originalIndices[${heightIdx}]; var col:${dType} = originalIndices[${widthIdx}]; if (${useExtrapolation} && (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > ${ @@ -373,10 +371,10 @@ const bicubicInterpolation = const createCubicInterpolationFunction = (idx: number): string => { const direction = idx === heightIdx ? 'row' : 'col'; return ` - fn ${direction}CubicInterpolation(inputIndices: ${input.type.indices}, outputIndices: ${ + fn ${direction}CubicInterpolation(input_indices: ${input.type.indices}, output_indices: ${ output.type.indices}) -> ${dType} { - var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : `outputIndices[${idx}]`}; - var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(${dType}(outputIndex), ${scales[idx]}, + var output_index = ${output.indicesGet('output_indices', idx)}; + var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(${dType}(output_index), ${scales[idx]}, ${dType}(${outputShape[idx]}), ${dType}(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length}); var fractOriginalIdx: ${dType} = originalIdx - floor(originalIdx); var coefs = getCubicInterpolationCoefs(fractOriginalIdx); @@ -397,10 +395,11 @@ const bicubicInterpolation = ${direction} = max(0, min(${direction}, ${inputShape[idx]} - 1)); } } - var inputIndicesCopy: ${input.type.indices} = inputIndices; - inputIndicesCopy[${idx}] = u32(${direction}); - data[i + 1] = ${idx === heightIdx ? `input[${input.indicesToOffset('inputIndicesCopy')}];` : ` - rowCubicInterpolation(inputIndicesCopy, outputIndices);`} + var input_indices_copy: ${input.type.indices} = input_indices; + ${input.indicesSet('input_indices_copy', idx, `u32(${direction})`)}; + data[i + 1] = ${ + idx === heightIdx ? input.getByIndices('input_indices_copy') : + 'rowCubicInterpolation(input_indices_copy, output_indices)'}; } return cubicInterpolation1D(data, coefs); }`; @@ -429,9 +428,9 @@ const bicubicInterpolation = return (x[0] * coefs[0] + x[1] * coefs[1]+ x[2] * coefs[2]+ x[3] * coefs[3]) / coefsSum; } - fn bicubicInterpolation(outputIndices: ${output.type.indices}) -> ${dType} { - var inputIndices: ${input.type.indices} = outputIndices; - return colCubicInterpolation(inputIndices, outputIndices); + fn bicubicInterpolation(output_indices: ${output.type.indices}) -> ${dType} { + var input_indices: ${input.type.indices} = output_indices; + return colCubicInterpolation(input_indices, output_indices); } `; }; @@ -450,8 +449,8 @@ const createResizeProgramInfo = outputShape = adjustOutputShape(inputShape, scales, attributes); } } - const output = outputVariable('output', inputTensor.dataType, outputShape); - const input = inputVariable('input', inputTensor.dataType, inputShape); + const output = outputVariable('output', inputTensor.dataType, outputShape.length); + const input = inputVariable('input', inputTensor.dataType, inputShape.length); const outputSize = ShapeUtil.size(outputShape); const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]); const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize'; @@ -467,11 +466,11 @@ const createResizeProgramInfo = ${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion, dataType)}; ${ calculateInputIndicesFromOutputIndices( - input, output, inputShape, outputShape, scales, roi, useExtrapolation)}; + input, output, inputShape, outputShape, scales.length, roi.length, useExtrapolation)}; `; case 'linear': return ` - ${calculateOriginalIndicesFromOutputIndices(output, inputShape, outputShape, scales, roi)}; + ${calculateOriginalIndicesFromOutputIndices(output, inputShape, outputShape, scales.length, roi.length)}; ${ bilinearInterpolation( input, output, inputShape, scales, useExtrapolation, attributes.extrapolationValue)}; @@ -488,25 +487,29 @@ const createResizeProgramInfo = } })()}; `} - ${shaderHelper.declareVariables(input, output)} + ${ + shaderHelper.registerUniform('output_size', 'u32') + .registerUniform('scales', 'f32', scales.length) + .registerUniform('roi', 'f32', roi.length) + .declareVariables(input, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} ${noScale ? 'output[global_idx] = input[global_idx];' : ` - let outputIndices = ${output.offsetToIndices('global_idx')}; - var inputIndices: ${input.type.indices}; + let output_indices = ${output.offsetToIndices('global_idx')}; + var input_indices: ${input.type.indices}; ${(() => { switch (attributes.mode) { case 'nearest': - return `inputIndices = calculateInputIndicesFromOutputIndices(outputIndices); - if (checkInputIndices(inputIndices)) { - output[global_idx] = input[${input.indicesToOffset('inputIndices')}]; + return `input_indices = calculateInputIndicesFromOutputIndices(output_indices); + if (checkInputIndices(input_indices)) { + output[global_idx] = ${input.getByIndices('input_indices')}; } else { output[global_idx] = ${attributes.extrapolationValue}; }`; case 'linear': - return 'output[global_idx] = bilinearInterpolation(outputIndices);'; + return 'output[global_idx] = bilinearInterpolation(output_indices);'; case 'cubic': - return 'output[global_idx] = bicubicInterpolation(outputIndices);'; + return 'output[global_idx] = bicubicInterpolation(output_indices);'; default: throw Error(`Unsupported resize mode: ${attributes.mode}`); } @@ -518,12 +521,20 @@ const createResizeProgramInfo = name: 'Resize', shaderCache: { hint: `${attributes.cacheKey}|${opsetVersion}|${scales.length > 0 ? scales : ''}|${ - sizes.length > 0 ? sizes : ''}|${noScale}` + sizes.length > 0 ? sizes : ''}|${roi.length > 0 ? roi : ''}|${noScale}`, + inputDependencies: ['rank'] }, getShaderSource, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputTensor.dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms: [ + {type: 'uint32', data: outputSize}, + {type: 'float32', data: scales}, + {type: 'float32', data: roi}, + ...createTensorShapeVariables(inputShape), + ...createTensorShapeVariables(outputShape), + ] }) }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index 43d4e5356d1d9..5212c6475dce0 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -77,25 +77,25 @@ const fixStartEndValues = }; const calculateInputIndicesImpl = - (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[]): - string => `fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} { - var inputIndices: ${input.type.indices}; + (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[]): string => + `fn calculateInputIndices(output_indices: ${output.type.indices}) -> ${input.type.indices} { + var input_indices: ${input.type.indices}; var carry = 0u; for (var i = ${inputShape.length}; i >= 0; i--) { let input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)}; let steps_i = ${getElementAt('uniforms.steps', 'i', inputShape.length)}; let signs_i = ${getElementAt('uniforms.signs', 'i', inputShape.length)}; let starts_i = ${getElementAt('uniforms.starts', 'i', inputShape.length)}; - var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'}; - var inputIndex = outputIndex * steps_i + starts_i + carry; - carry = inputIndex / input_shape_i; - inputIndex = inputIndex % input_shape_i; + var output_index = ${output.indicesGet('output_indices', 'i')}; + var input_index = output_index * steps_i + starts_i + carry; + carry = input_index / input_shape_i; + input_index = input_index % input_shape_i; if (signs_i < 0) { - inputIndex = input_shape_i - inputIndex - 1u + starts_i; + input_index = input_shape_i - input_index - 1u + starts_i; } - ${inputShape.length === 1 ? 'inputIndices' : 'inputIndices[i]'} = inputIndex; + ${input.indicesSet('input_indices', 'i', 'input_index')}; } - return inputIndices; + return input_indices; }`; const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: SliceAttributes): ProgramInfo => { @@ -162,12 +162,12 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)} - ${calculateInputIndicesImpl(input, output, inputShape, outputShape)} + ${calculateInputIndicesImpl(input, output, inputShape)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} - let outputIndices = ${output.offsetToIndices('global_idx')}; - let inputIndices = calculateInputIndices(outputIndices); - ${output.setByOffset('global_idx', input.getByIndices('inputIndices'))} + let output_indices = ${output.offsetToIndices('global_idx')}; + let input_indices = calculateInputIndices(output_indices); + ${output.setByOffset('global_idx', input.getByIndices('input_indices'))} }`; return { name: 'Slice', diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index fd60d81b87ae1..b8582614fa214 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -4,9 +4,9 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo, TensorInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface SplitAttributes extends AttributeWithCacheKey { readonly axis: number; @@ -34,7 +34,7 @@ const createSplitAttributesFromInputs = const calculateOutputIndexImpl = (numberOfTensors: number): string => ` fn calculateOutputIndex(index: u32) -> u32 { for (var i: u32 = 0u; i < ${numberOfTensors}u; i += 1u ) { - if (index < sizeInConcatAxis[i]) { + if (index < ${getElementAt('uniforms.size_in_split_axis', 'i', numberOfTensors)}) { return i; } } @@ -48,15 +48,15 @@ const writeBufferDataImpl = (outputs: readonly IndicesHelper[]) => { if (numberOfTensors === 1) { codeLines.push(returnSnippet); } else if (i === 0) { - codeLines.push(`if (outputNumber == ${i}u) { ${returnSnippet} }`); + codeLines.push(`if (output_number == ${i}u) { ${returnSnippet} }`); } else if (i === numberOfTensors - 1) { codeLines.push(`else { ${returnSnippet} }`); } else { - codeLines.push(`else if (outputNumber == ${i}) { ${returnSnippet} }`); + codeLines.push(`else if (output_number == ${i}) { ${returnSnippet} }`); } } return ` - fn writeBufferData(outputNumber: u32, indices: ${outputs[0].type.indices}, global_idx: u32) { + fn writeBufferData(output_number: u32, indices: ${outputs[0].type.indices}, global_idx: u32) { ${codeLines.join('\n')} }`; }; @@ -65,48 +65,54 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split const inputShape = inputs[0].dims; const inputSize = ShapeUtil.size(inputShape); const dataType = inputs[0].dataType; - const rank = inputShape.length; - const axis = attributes.axis; - const adjustedAxis = (axis < 0) ? inputShape.length + axis : axis; + const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length); const outputs = new Array(attributes.numOutputs); const input = inputVariable('input', dataType, inputShape); - const sizeInConcatAxis = new Array(attributes.numOutputs); + const sizeInSplitAxis = new Array(attributes.numOutputs); const outputsTensorInfo: TensorInfo[] = []; const outputShapes: number[][] = []; let previousSum = 0; + const programUniforms: ProgramUniform[] = [{type: 'uint32', data: inputSize}]; for (let i = 0; i < attributes.numOutputs; i++) { previousSum += attributes.splitSizes[i]; - sizeInConcatAxis[i] = previousSum; + sizeInSplitAxis[i] = previousSum; const outputShape = inputShape.slice(); outputShape[attributes.axis] = attributes.splitSizes[i]; outputShapes.push(outputShape); - outputs[i] = outputVariable(`output${i}`, dataType, outputShapes[i]); + outputs[i] = outputVariable(`output${i}`, dataType, outputShape); outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType}); } - const indicesAxis = rank < 2 ? 'indices' : `indices[${adjustedAxis}]`; + programUniforms.push({type: 'uint32', data: sizeInSplitAxis}); + programUniforms.push(...createTensorShapeVariables(inputShape)); + outputShapes.forEach((outputShape) => programUniforms.push(...createTensorShapeVariables(outputShape))); const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(input, ...outputs)} - const sizeInConcatAxis = array(${sizeInConcatAxis.map(i => `${i}u`).join(',')}); - ${calculateOutputIndexImpl(sizeInConcatAxis.length)} + ${ + shaderHelper.registerUniform('input_size', 'u32') + .registerUniform('size_in_split_axis', 'u32', sizeInSplitAxis.length) + .declareVariables(input, ...outputs)} + ${calculateOutputIndexImpl(sizeInSplitAxis.length)} ${writeBufferDataImpl(outputs)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(inputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.input_size')} var indices = ${input.offsetToIndices('global_idx')}; - let outputNumber = calculateOutputIndex(${indicesAxis}); - if (outputNumber != 0) { - ${indicesAxis} -= sizeInConcatAxis[outputNumber - 1u]; + var index = ${input.indicesGet('indices', axis)}; + let output_number = calculateOutputIndex(index); + if (output_number != 0) { + index -= ${getElementAt('uniforms.size_in_split_axis', 'output_number - 1u', sizeInSplitAxis.length)}; + ${input.indicesSet('indices', axis, 'index')}; } - writeBufferData(outputNumber, indices, global_idx); + writeBufferData(output_number, indices, global_idx); }`; return { name: 'Split', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {hint: attributes.cacheKey, inputDependencies: ['rank']}, getShaderSource, getRunData: () => ({ outputs: outputsTensorInfo, dispatchGroup: {x: Math.ceil(inputSize / 64 /* workgroup size */)}, + programUniforms }) }; }; diff --git a/onnxruntime/contrib_ops/cpu/math/sparse_dense_matmul.cc b/onnxruntime/contrib_ops/cpu/math/sparse_dense_matmul.cc index b00b10ad649b1..46a8b70d289b7 100644 --- a/onnxruntime/contrib_ops/cpu/math/sparse_dense_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/math/sparse_dense_matmul.cc @@ -47,7 +47,6 @@ struct ComputeCtx { float alpha; }; -#if !defined(__i386__) && !defined(_M_IX86) && !defined(__wasm__) && !defined(__ANDROID__) template inline void SparseDenseMatMulImpl(const ComputeCtx& ctx, const ConstSparseMatrixMap& map_A, const ConstEigenMatrixMapRowMajor& map_B, EigenMatrixMapRowMajor& output_map) { @@ -64,7 +63,8 @@ inline void SparseDenseMatMulImpl(const ComputeCtx& ctx, const ConstSparseMatrix template <> inline void SparseDenseMatMulImpl(const ComputeCtx& ctx, const ConstSparseMatrixMap& map_A, - const ConstEigenMatrixMapRowMajor& map_B, EigenMatrixMapRowMajor& output_map) { + const ConstEigenMatrixMapRowMajor& map_B, + EigenMatrixMapRowMajor& output_map) { if (ctx.trans_A && ctx.trans_B) { output_map = map_A.transpose() * ctx.alpha * map_B.transpose(); } else if (ctx.trans_A && !ctx.trans_B) { @@ -84,21 +84,47 @@ struct SparseToDenseCsr { const auto& b_dims = B.Shape().GetDims(); const auto& out_dims = output.Shape().GetDims(); auto csr_view = A.AsCsr(); - - ConstSparseMatrixMap map_A(a_dims[0], a_dims[1], A.NumValues(), - csr_view.Outer().Data(), - csr_view.Inner().Data(), + const Eigen::Index* inner_index_pointer = nullptr; + const Eigen::Index* outer_index_pointer = nullptr; + // For auto-release the above two pointers when they are not NULL. + std::unique_ptr buffer_holder_inner, buffer_holder_outer; + if constexpr (std::is_integral::value && + std::is_signed::value && + (sizeof(Eigen::Index) == sizeof(int64_t))) { + // On macOS the following reinterpret_cast is necessary because Eigen::Index is an alias of `long` but int64_t is + // `long long`. Though they have the same size, compilers still do not allow an implicit casting between them. + inner_index_pointer = reinterpret_cast(csr_view.Inner().Data()); + outer_index_pointer = reinterpret_cast(csr_view.Outer().Data()); + } else { + // In a 32-bit build we need to cast the following two tensors to 32 bits + gsl::span inner_data = csr_view.Inner().DataAsSpan(); + gsl::span outer_data = csr_view.Outer().DataAsSpan(); + buffer_holder_inner.reset(new Eigen::Index[inner_data.size()]); + buffer_holder_outer.reset(new Eigen::Index[outer_data.size()]); + inner_index_pointer = buffer_holder_inner.get(); + outer_index_pointer = buffer_holder_outer.get(); + + std::transform(inner_data.begin(), inner_data.end(), + buffer_holder_inner.get(), [](int64_t v) -> Eigen::Index { + return narrow(v); + }); + std::transform(outer_data.begin(), outer_data.end(), + buffer_holder_outer.get(), [](int64_t v) -> Eigen::Index { + return narrow(v); + }); + } + ConstSparseMatrixMap map_A(narrow(a_dims[0]), narrow(a_dims[1]), + narrow(A.NumValues()), outer_index_pointer, inner_index_pointer, A.Values().Data()); - ConstEigenMatrixMapRowMajor map_B(B.Data(), b_dims[0], b_dims[1]); - EigenMatrixMapRowMajor output_map(output.MutableData(), out_dims[0], out_dims[1]); + ConstEigenMatrixMapRowMajor map_B(B.Data(), narrow(b_dims[0]), narrow(b_dims[1])); + EigenMatrixMapRowMajor output_map(output.MutableData(), narrow(out_dims[0]), + narrow(out_dims[1])); // XXX: Consider re-writing it as a parallel loop as Eigen requires it to use OpenMP // XXX: Consider vectorization SparseDenseMatMulImpl(ctx, map_A, map_B, output_map); } }; -#endif //! defined(__i386__) && !defined(_M_IX86) && !defined(__wasm__) && !defined(__ANDROID__) - template inline T Mul(T a_value, float, T b_value) { return a_value * b_value; @@ -121,9 +147,11 @@ struct SparseToDenseCoo { auto coo_view = A.AsCoo(); const auto& ind_dims = coo_view.Indices().Shape().GetDims(); ORT_RETURN_IF_NOT(ind_dims.size() == 2, "COO indices must be 2-D, got: ", ind_dims.size()); - ConstEigenMatrixMapRowMajor a_indicies_map(coo_view.Indices().Data(), narrow(ind_dims[0]), narrow(ind_dims[1])); + ConstEigenMatrixMapRowMajor a_indicies_map(coo_view.Indices().Data(), narrow(ind_dims[0]), + narrow(ind_dims[1])); ConstEigenMatrixMapRowMajor map_b(B.Data(), narrow(b_dims[0]), narrow(b_dims[1])); - EigenMatrixMapRowMajor output_map(output.MutableData(), narrow(out_dims[0]), narrow(out_dims[1])); + EigenMatrixMapRowMajor output_map(output.MutableData(), narrow(out_dims[0]), + narrow(out_dims[1])); output_map.setZero(); const auto rhs_right = (ctx.trans_B) ? b_dims[0] : b_dims[1]; @@ -140,7 +168,8 @@ struct SparseToDenseCoo { ORT_RETURN_IF_NOT(m < out_left, "COO m index: ", m, " is out of bounds of out_left: ", out_left); const T a_value = a_values[i]; for (int64_t n = 0; n < rhs_right; ++n) { - const T b_value = (ctx.trans_B) ? map_b(narrow(n), narrow(k)) : map_b(narrow(k), narrow(n)); + const T b_value = + (ctx.trans_B) ? map_b(narrow(n), narrow(k)) : map_b(narrow(k), narrow(n)); output_map(narrow(m), narrow(n)) += Mul(a_value, ctx.alpha, b_value); } } @@ -170,8 +199,9 @@ Status SparseToDenseMatMul::Compute(OpKernelContext* ctx) const { const auto inner_B = (trans_b_attr_) ? b_dims[1] : b_dims[0]; const auto outer_B = (trans_b_attr_) ? b_dims[0] : b_dims[1]; - ORT_RETURN_IF_NOT(inner_A == inner_B, "Can not multiply A and B as inner dimension does not match. inner_A: ", - inner_A, " vs inner_B: ", inner_B); + ORT_RETURN_IF_NOT(inner_A == inner_B, + "Can not multiply A and B as inner dimension does not match. inner_A: ", inner_A, + " vs inner_B: ", inner_B); TensorShape output_shape{outer_A, outer_B}; auto* output = ctx->Output(0, output_shape); @@ -184,12 +214,10 @@ Status SparseToDenseMatMul::Compute(OpKernelContext* ctx) const { auto coo_view = A->AsCoo(); const auto num_dims = coo_view.Indices().Shape().NumDimensions(); ORT_RETURN_IF_NOT(num_dims == 2, "Expecting COO 2-D indices shape"); - ORT_RETURN_IF_NOT(A->Values().Shape().Size() * 2 == coo_view.Indices().Shape().Size(), "Expecting 2xValues == indices"); + ORT_RETURN_IF_NOT(A->Values().Shape().Size() * 2 == coo_view.Indices().Shape().Size(), + "Expecting 2xValues == indices"); auto status = t_disp.InvokeRet(compute_ctx, *A, *B, *output); ORT_RETURN_IF_ERROR(status); -// Eigen has a bug in x86 where it calculates reallocation size as -1 -// and throws bad_alloc -#if !defined(__i386__) && !defined(_M_IX86) && !defined(__wasm__) && !defined(__ANDROID__) } else if (A->Format() == SparseFormat::kCsrc) { auto csr_view = A->AsCsr(); ORT_RETURN_IF_NOT(A->Values().Shape().Size() == csr_view.Inner().Shape().Size(), @@ -199,11 +227,6 @@ Status SparseToDenseMatMul::Compute(OpKernelContext* ctx) const { } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Currently support only COO and CSR(x64) formats"); } -#else - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "WASM and 32-bit builds support only COO format"); - } -#endif //! defined(__i386__) && !defined(_M_IX86) && !defined(__wasm__) && !defined(__ANDROID__) return Status::OK(); } @@ -211,4 +234,4 @@ Status SparseToDenseMatMul::Compute(OpKernelContext* ctx) const { } // namespace contrib } // namespace onnxruntime -#endif //! defined(DISABLE_SPARSE_TENSORS) \ No newline at end of file +#endif //! defined(DISABLE_SPARSE_TENSORS) diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh index ea9040aa7875f..992bba0fc5e6b 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh @@ -31,6 +31,7 @@ namespace internal { #ifdef USE_COMPOSABLE_KERNEL using onnxruntime::rocm::CKDataTypeAdaptor; +using onnxruntime::rocm::CKBlasOpAdaptor; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -39,9 +40,11 @@ using Nop = ck::tensor_operation::element_wise::PassThrough; using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; using FastGelu = ck::tensor_operation::element_wise::FastGelu; -template +template auto GetCKGemmAddFastGeluTypeStringAndOps() { using CKDataType = typename CKDataTypeAdaptor::type; + using ALayout = typename CKBlasOpAdaptor::type; + using BLayout = typename CKBlasOpAdaptor::type; using DeviceGemmAddFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD< ALayout, BLayout, ck::Tuple, Row, CKDataType, CKDataType, ck::Tuple, CKDataType, @@ -76,9 +79,11 @@ auto GetCKGemmAddFastGeluTypeStringAndOps() { return ret; } -template +template auto GetCKGemmFastGeluTypeStringAndOps() { using CKDataType = typename CKDataTypeAdaptor::type; + using ALayout = typename CKBlasOpAdaptor::type; + using BLayout = typename CKBlasOpAdaptor::type; using DeviceGemmFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD< ALayout, BLayout, ck::Tuple<>, Row, CKDataType, CKDataType, ck::Tuple<>, CKDataType, diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu index 294e7be91e883..8d7e64b1015be 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu @@ -49,16 +49,16 @@ inline GEMMFASTGELU(T, ScalarT) { if (tuning_ctx->IsTunableOpEnabled()) { if (opa == BlasOp::N && opb == BlasOp::N) { - static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; return gemm_fast_gelu(¶ms); } else if (opa == BlasOp::T && opb == BlasOp::N) { - static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; return gemm_fast_gelu(¶ms); } else if (opa == BlasOp::N && opb == BlasOp::T) { - static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; return gemm_fast_gelu(¶ms); } else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ { - static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; return gemm_fast_gelu(¶ms); } } diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh index 229f868a215fd..e157aa57f8c43 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh @@ -51,24 +51,24 @@ Status GemmFastGeluUnfused(const GemmFastGeluParams* params) { params->c); } -template +template class GemmFastGeluTunableOp : public TunableOp> { public: GemmFastGeluTunableOp() { this->RegisterOp(GemmFastGeluUnfused); #ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKGemmAddFastGeluTypeStringAndOps()) { + for (auto&& [_, op] : GetCKGemmAddFastGeluTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } - for (auto&& [_, op] : GetCKGemmFastGeluTypeStringAndOps()) { + for (auto&& [_, op] : GetCKGemmFastGeluTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } #endif #ifdef USE_HIPBLASLT - for (auto&& [_, op] : GetHipBlasLtGemmFastGeluTypeStringAndOps()) { + for (auto&& [_, op] : GetHipBlasLtGemmFastGeluTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } diff --git a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp index 830a3a6a492db..1fed8af21b31c 100644 --- a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp +++ b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp @@ -86,11 +86,11 @@ Return Value: if constexpr (std::is_same_v || std::is_same_v) { auto CharVector = vec_pack(ShortVector0, ShortVector1); - vec_xst(CharVector, 0, Output); + vec_xst(CharVector, 0, (int8_t *)Output); } else { static_assert(std::is_same_v || std::is_same_v); - vec_xst(ShortVector0, 0, Output); - vec_xst(ShortVector1, 0, &Output[8]); + vec_xst(ShortVector0, 0, (int16_t *)Output); + vec_xst(ShortVector1, 0, (int16_t *)&Output[8]); } Output += 16; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index d1b3f19100942..8bfa66710e2fc 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -872,6 +872,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "QLinearConv", "QLinearMatMul", "QuantizeLinear", + "DynamicQuantizeLinear", "RandomNormal", "RandomNormalLike", "RandomUniform", diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 7e4c0dc8d7267..b2a7028f49e55 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -74,17 +74,19 @@ BackendManager::BackendManager(const onnxruntime::Node& fused_node, LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims"; if (GetGlobalContext().device_type.find("CPU") != std::string::npos || GetGlobalContext().device_type.find("GPU") != std::string::npos) { - LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Starting backend initialization. " - << "Creating backend Dynamic Shapes"; - try { - concrete_backend_ = BackendFactory::MakeBackend(*model_proto_, - GetGlobalContext(), - subgraph_context_); - } catch (std::string const& msg) { - throw msg; + if (!GetGlobalContext().disable_dynamic_shapes) { + LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Starting backend initialization. " + << "Creating backend Dynamic Shapes"; + try { + concrete_backend_ = BackendFactory::MakeBackend(*model_proto_, + GetGlobalContext(), + subgraph_context_); + } catch (std::string const& msg) { + throw msg; + } + LOGS_DEFAULT(INFO) << "[OpenVINO-EP] " + << "Backend created for graph " << subgraph_context_.subgraph_name; } - LOGS_DEFAULT(INFO) << "[OpenVINO-EP] " - << "Backend created for graph " << subgraph_context_.subgraph_name; } } else { LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has concrete input dims. " @@ -260,7 +262,7 @@ void BackendManager::Compute(OrtKernelContext* context) { } #endif bool use_dynamic_backend = true; - if (subgraph_context_.has_dynamic_input_shape && + if (!GetGlobalContext().disable_dynamic_shapes && subgraph_context_.has_dynamic_input_shape && (GetGlobalContext().device_type.find("CPU") != std::string::npos || GetGlobalContext().device_type.find("GPU") != std::string::npos)) { concrete_backend_->Infer(context); diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc index d47c91dd46622..5092fffcfc111 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.cc +++ b/onnxruntime/core/providers/openvino/backend_utils.cc @@ -54,7 +54,7 @@ CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext } const std::string model = model_proto.SerializeAsString(); try { - auto cnn_network = global_context.ie_core.ReadModel(model); + auto cnn_network = global_context.ie_core.ReadModel(model, global_context.onnx_model_path_name); if ((subgraph_context.precision == "FP16") && (global_context.device_type.find("NPU") == std::string::npos)) { // FP16 transformations @@ -95,7 +95,7 @@ CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext } } #ifndef NDEBUG -#if defined(OPENVINO_2022_3) || (OPENVINO_2023_0) || (OPENVINO_2023_1) +#if defined(OPENVINO_2022_3) || (OPENVINO_2023_0) || (OPENVINO_2023_1) || (OPENVINO_2023_2) if (IsDebugEnabled()) { std::string name = cnn_network->get_friendly_name(); ov::pass::Serialize serializer(name + ".xml", name + ".bin"); diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 09e1322ff59fb..2280d853e30f4 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -40,6 +40,9 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto, // Enable streams; default=1 unless ovverriden by user config EnableStreams(); + // Set the inference_num_threads property of the CPU + SetNumThreads(device_config); + #ifndef NDEBUG if (IsDebugEnabled()) { std::string file_name = subgraph_context.subgraph_name + "_static.onnx"; @@ -67,8 +70,8 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto, LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } #else -#if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) - if (!subgraph_context_.has_dynamic_input_shape && dev_prec != "CPU_FP16") { +#if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) || (OPENVINO_2023_2) + if (global_context_.disable_dynamic_shapes && dev_prec != "CPU_FP16") { const std::string model = model_proto.SerializeAsString(); exe_network_ = global_context_.ie_core.LoadNetwork( model, hw_target, device_config, subgraph_context_.subgraph_name); @@ -96,16 +99,7 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto, throw(msg); } - // The infer_requests_ pool will be intialized with a default value of 8 infer_request's - // The nireq value can also be configured to any num_of_threads during runtime - size_t nireq = global_context_.num_of_threads; - LOGS_DEFAULT(INFO) << log_tag << "The value of nireq being used is: " << nireq; -#ifndef NDEBUG - if (openvino_ep::backend_utils::IsDebugEnabled()) { - std::cout << "The value of nireq being used is: " << nireq << std::endl; - } -#endif - inferRequestsQueue_ = std::unique_ptr(new InferRequestsQueue(exe_network_, nireq)); + inferRequestsQueue_ = std::unique_ptr(new InferRequestsQueue(exe_network_, 1)); } bool BasicBackend::ValidateSubgraph(std::map>& const_outputs_map) { @@ -132,7 +126,7 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { device_config.emplace(ov::enable_profiling(true)); } #endif -#if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) +#if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) || (OPENVION_2023_2) if (global_context_.device_type.find("NPU") != std::string::npos) { std::pair device_property; device_property = std::make_pair("NPU_COMPILER_TYPE", "DRIVER"); @@ -168,7 +162,24 @@ void BasicBackend::EnableGPUThrottling(ov::AnyMap& device_config) { } void BasicBackend::EnableStreams() { - global_context_.ie_core.SetStreams(global_context_.device_type, global_context_.num_streams); + // Streams can be set only if the device is not one of AUTO, MULTI, or HETERO + // Throw an exception if the user tries to set num_streams for these devices + if ((global_context_.device_type.find("MULTI") != std::string::npos) || + (global_context_.device_type.find("HETERO") != std::string::npos) || + (global_context_.device_type.find("AUTO") != std::string::npos)) { + if (global_context_.num_streams != 1) { + throw(log_tag + "Cannot set NUM_STREAMS to " + std::to_string(global_context_.num_streams) + " for device " + global_context_.device_type); + } + // Do nothing + } else { + global_context_.ie_core.SetStreams(global_context_.device_type, global_context_.num_streams); + } +} + +void BasicBackend::SetNumThreads(ov::AnyMap& device_config) { + // inference_num_threads is applicable only for the CPU device + if (global_context_.device_type.find("CPU") != std::string::npos) + device_config.emplace(ov::inference_num_threads(global_context_.num_of_threads)); } // Starts an asynchronous inference request for data in slice indexed by batch_slice_idx on @@ -199,6 +210,7 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque } size_t batch_slice_idx = 0; if (subgraph_context_.has_dynamic_input_shape && + !global_context_.disable_dynamic_shapes && (global_context_.device_type.find("CPU") != std::string::npos || global_context_.device_type.find("GPU") != std::string::npos)) { auto tensor = context.GetInput(subgraph_context_.input_names.at(input_name)); diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index 6eda641451a72..aa96dadbf0e2d 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -37,6 +37,7 @@ class BasicBackend : public IBackend { void EnableCaching(); void EnableGPUThrottling(ov::AnyMap& device_config); void EnableStreams(); + void SetNumThreads(ov::AnyMap& device_config); void StartAsyncInference(Ort::KernelContext& context, std::shared_ptr infer_request); #ifdef IO_BUFFER_ENABLED diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index 29233e72c33b9..5f19c71683f24 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -17,7 +17,7 @@ struct GlobalContext { bool is_wholly_supported_graph = false; bool enable_npu_fast_compile = false; bool enable_opencl_throttling = false; - bool enable_dynamic_shapes = false; + bool disable_dynamic_shapes = false; size_t num_of_threads; std::string device_type; std::string precision_str; diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index a4c6b0f851c04..aa389f6297d80 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -22,17 +22,9 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProv openvino_ep::BackendManager::GetGlobalContext().num_streams = info.num_streams_; openvino_ep::BackendManager::GetGlobalContext().context = info.context_; openvino_ep::BackendManager::GetGlobalContext().enable_opencl_throttling = info.enable_opencl_throttling_; - openvino_ep::BackendManager::GetGlobalContext().enable_dynamic_shapes = info.enable_dynamic_shapes_; - - if (static_cast(info.num_of_threads_) <= 0) { - openvino_ep::BackendManager::GetGlobalContext().num_of_threads = 8; - } else if (static_cast(info.num_of_threads_) > 8) { - std::string err_msg = std::string("\n [ERROR] num_of_threads configured during runtime is: ") + - std::to_string(info.num_of_threads_) + "\nnum_of_threads configured should be >0 and <=8.\n"; - ORT_THROW(err_msg); - } else { - openvino_ep::BackendManager::GetGlobalContext().num_of_threads = info.num_of_threads_; - } + openvino_ep::BackendManager::GetGlobalContext().disable_dynamic_shapes = info.disable_dynamic_shapes_; + openvino_ep::BackendManager::GetGlobalContext().num_of_threads = info.num_of_threads_; + // to check if target device is available // using ie_core capability GetAvailableDevices to fetch list of devices plugged in if (info.cache_dir_.empty()) { @@ -120,15 +112,7 @@ OpenVINOExecutionProvider::GetCapability(const GraphViewer& graph_viewer, openvino_ep::BackendManager::GetGlobalContext().onnx_opset_version = graph_viewer.DomainToVersionMap().at(kOnnxDomain); -#if defined(OPENVINO_2022_1) - openvino_ep::GetCapability obj(graph_viewer, - openvino_ep::BackendManager::GetGlobalContext().device_type, "V_2022_1"); - result = obj.Execute(); -#elif defined(OPENVINO_2022_2) - openvino_ep::GetCapability obj(graph_viewer, - openvino_ep::BackendManager::GetGlobalContext().device_type, "V_2022_2"); - result = obj.Execute(); -#elif defined(OPENVINO_2022_3) +#if defined(OPENVINO_2022_3) openvino_ep::GetCapability obj(graph_viewer, openvino_ep::BackendManager::GetGlobalContext().device_type, "V_2022_3"); result = obj.Execute(); @@ -140,6 +124,10 @@ OpenVINOExecutionProvider::GetCapability(const GraphViewer& graph_viewer, openvino_ep::GetCapability obj(graph_viewer, openvino_ep::BackendManager::GetGlobalContext().device_type, "V_2023_1"); result = obj.Execute(); +#elif defined(OPENVINO_2023_2) + openvino_ep::GetCapability obj(graph_viewer, + openvino_ep::BackendManager::GetGlobalContext().device_type, "V_2023_2"); + result = obj.Execute(); #endif return result; diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h index 3b56b54410e40..7cc2fb9b1ea98 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h @@ -69,12 +69,12 @@ struct OpenVINOExecutionProviderInfo { int num_streams_; void* context_; bool enable_opencl_throttling_; - bool enable_dynamic_shapes_; + bool disable_dynamic_shapes_; explicit OpenVINOExecutionProviderInfo(std::string dev_type, bool enable_npu_fast_compile, std::string dev_id, size_t num_of_threads, std::string cache_dir, int num_streams, void* context, bool enable_opencl_throttling, - bool enable_dynamic_shapes) + bool disable_dynamic_shapes) : enable_npu_fast_compile_(enable_npu_fast_compile), device_id_(dev_id), num_of_threads_(num_of_threads), @@ -82,7 +82,7 @@ struct OpenVINOExecutionProviderInfo { num_streams_(num_streams), context_(context), enable_opencl_throttling_(enable_opencl_throttling), - enable_dynamic_shapes_(enable_dynamic_shapes) { + disable_dynamic_shapes_(disable_dynamic_shapes) { if (dev_type == "") { LOGS_DEFAULT(INFO) << "[OpenVINO-EP]" << "No runtime device selection option provided."; diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index fbb89710c8008..749907da18354 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -11,13 +11,13 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { OpenVINOProviderFactory(const char* device_type, bool enable_npu_fast_compile, const char* device_id, size_t num_of_threads, const char* cache_dir, int num_streams, void* context, - bool enable_opencl_throttling, bool enable_dynamic_shapes) + bool enable_opencl_throttling, bool disable_dynamic_shapes) : enable_npu_fast_compile_(enable_npu_fast_compile), num_of_threads_(num_of_threads), num_streams_(num_streams), context_(context), enable_opencl_throttling_(enable_opencl_throttling), - enable_dynamic_shapes_(enable_dynamic_shapes) { + disable_dynamic_shapes_(disable_dynamic_shapes) { device_type_ = (device_type == nullptr) ? "" : device_type; device_id_ = (device_id == nullptr) ? "" : device_id; cache_dir_ = (cache_dir == nullptr) ? "" : cache_dir; @@ -36,13 +36,13 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { int num_streams_; void* context_; bool enable_opencl_throttling_; - bool enable_dynamic_shapes_; + bool disable_dynamic_shapes_; }; std::unique_ptr OpenVINOProviderFactory::CreateProvider() { OpenVINOExecutionProviderInfo info(device_type_, enable_npu_fast_compile_, device_id_, num_of_threads_, cache_dir_, num_streams_, context_, enable_opencl_throttling_, - enable_dynamic_shapes_); + disable_dynamic_shapes_); return std::make_unique(info); } @@ -67,7 +67,7 @@ struct OpenVINO_Provider : Provider { bool enable_npu_fast_compile = false; // [enable_npu_fast_compile]: Fast-compile may be optionally enabled to // speeds up the model's compilation to NPU device specific format. const char* device_id = ""; // [device_id]: Selects a particular hardware device for inference. - int num_of_threads = 8; // [num_of_threads]: Overrides the accelerator default value of number of + int num_of_threads = 0; // [num_of_threads]: Overrides the accelerator default value of number of // threads with this value at runtime. const char* cache_dir = ""; // [cache_dir]: specify the path to // dump and load the blobs for the model caching/kernel caching (GPU) @@ -78,7 +78,7 @@ struct OpenVINO_Provider : Provider { // with this value at runtime. bool enable_opencl_throttling = false; // [enable_opencl_throttling]: Enables OpenCL queue throttling for GPU // device (Reduces CPU Utilization when using GPU) - bool enable_dynamic_shapes = false; // [enable_dynamic_shapes]: Enables Dynamic Shapes feature for CPU device) + bool disable_dynamic_shapes = false; // [disable_dynamic_shapes]: Execute model with default static shape for optimal performance. void* context = nullptr; if (provider_options_map.find("device_type") != provider_options_map.end()) { @@ -147,12 +147,12 @@ struct OpenVINO_Provider : Provider { bool_flag = ""; } - if (provider_options_map.find("enable_dynamic_shapes") != provider_options_map.end()) { - bool_flag = provider_options_map.at("enable_dynamic_shapes"); + if (provider_options_map.find("disable_dynamic_shapes") != provider_options_map.end()) { + bool_flag = provider_options_map.at("disable_dynamic_shapes"); if (bool_flag == "true" || bool_flag == "True") - enable_dynamic_shapes = true; + disable_dynamic_shapes = true; else if (bool_flag == "false" || bool_flag == "False") - enable_dynamic_shapes = false; + disable_dynamic_shapes = false; } return std::make_shared(const_cast(device_type.c_str()), enable_npu_fast_compile, @@ -162,7 +162,7 @@ struct OpenVINO_Provider : Provider { num_streams, context, enable_opencl_throttling, - enable_dynamic_shapes); + disable_dynamic_shapes); } void Initialize() override { diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index d2ce378c97e02..31952e5b15e37 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -6,6 +6,7 @@ #define ORT_API_MANUAL_INIT #include "core/session/onnxruntime_cxx_api.h" #include "core/providers/shared_library/provider_api.h" +#include "backend_utils.h" #if defined(OV_API_20) using Exception = ov::Exception; @@ -18,10 +19,22 @@ namespace onnxruntime { namespace openvino_ep { const std::string log_tag = "[OpenVINO-EP] "; -std::shared_ptr OVCore::ReadModel(const std::string& model) const { +std::shared_ptr OVCore::ReadModel(const std::string& model, const std::string& model_path) const { try { - OVTensor weights; - return oe.read_model(model, weights); + std::istringstream modelStringStream(model); + std::istream& modelStream = modelStringStream; + // Try to load with FrontEndManager + ov::frontend::FrontEndManager manager; + ov::frontend::FrontEnd::Ptr FE; + ov::frontend::InputModel::Ptr inputModel; + + ov::AnyVector params{&modelStream, model_path}; + + FE = manager.load_by_model(params); + if (FE) { + inputModel = FE->load(params); + } + return FE->convert(inputModel); } catch (const Exception& e) { throw std::string(log_tag + "[OpenVINO-EP] Exception while Reading network: " + std::string(e.what())); } catch (...) { @@ -36,6 +49,35 @@ OVExeNetwork OVCore::LoadNetwork(std::shared_ptr& ie_cnn_network, ov::CompiledModel obj; try { obj = oe.compile_model(ie_cnn_network, hw_target, device_config); + +#ifndef NDEBUG + if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) { + // output of the actual settings that the device selected + auto supported_properties = obj.get_property(ov::supported_properties); + std::cout << "Model:" << std::endl; + for (const auto& cfg : supported_properties) { + if (cfg == ov::supported_properties) + continue; + auto prop = obj.get_property(cfg); + if (cfg == ov::device::properties) { + auto devices_properties = prop.as(); + for (auto& item : devices_properties) { + std::cout << " " << item.first << ": " << std::endl; + for (auto& item2 : item.second.as()) { + OPENVINO_SUPPRESS_DEPRECATED_START + if (item2.first == ov::supported_properties || item2.first == "SUPPORTED_CONFIG_KEYS)" || + item2.first == "SUPPORTED_METRICS") + continue; + OPENVINO_SUPPRESS_DEPRECATED_END + std::cout << " " << item2.first << ": " << item2.second.as() << std::endl; + } + } + } else { + std::cout << " " << cfg << ": " << prop.as() << std::endl; + } + } + } +#endif OVExeNetwork exe(obj); return exe; } catch (const Exception& e) { @@ -45,7 +87,7 @@ OVExeNetwork OVCore::LoadNetwork(std::shared_ptr& ie_cnn_network, } } -#if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) +#if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) || (OPENVINO_2023_2) OVExeNetwork OVCore::LoadNetwork(const std::string& model, std::string& hw_target, ov::AnyMap& device_config, diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 935ac8f68411d..690e91742beed 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -6,10 +6,11 @@ #include #include -#if defined(OPENVINO_2022_1) || (OPENVINO_2022_2) || (OPENVINO_2022_3) || (OPENVINO_2023_0) || (OPENVINO_2023_1) +#if defined(OPENVINO_2022_3) || (OPENVINO_2023_0) || (OPENVINO_2023_1) || (OPENVINO_2023_2) #define OV_API_20 #include "openvino/openvino.hpp" #include "openvino/pass/convert_fp32_to_fp16.hpp" +#include "openvino/frontend/manager.hpp" #else #include #endif @@ -43,12 +44,12 @@ class OVCore { ov::Core oe; public: - std::shared_ptr ReadModel(const std::string& model_stream) const; + std::shared_ptr ReadModel(const std::string& model_stream, const std::string& model_path) const; OVExeNetwork LoadNetwork(std::shared_ptr& ie_cnn_network, std::string& hw_target, ov::AnyMap& device_config, std::string name); -#if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) +#if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) || (OPENVINO_2023_2) OVExeNetwork LoadNetwork(const std::string& model_stream, std::string& hw_target, ov::AnyMap& device_config, diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 454f3dd5eb3cc..4494bb8ab2d60 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -26,18 +26,16 @@ namespace openvino_ep { GetCapability::GetCapability(const GraphViewer& graph_viewer_param, std::string device_type_param, const std::string version_param) : graph_viewer_(graph_viewer_param), device_type_(device_type_param) { - if (version_param == "V_2022_1") { - data_ops_ = new DataOps(graph_viewer_, V_2022_1, device_type_); - } else if (version_param == "V_2022_2") { - data_ops_ = new DataOps(graph_viewer_, V_2022_2, device_type_); - } else if (version_param == "V_2022_3") { + if (version_param == "V_2022_3") { data_ops_ = new DataOps(graph_viewer_, V_2022_3, device_type_); } else if (version_param == "V_2023_0") { data_ops_ = new DataOps(graph_viewer_, V_2023_0, device_type_); } else if (version_param == "V_2023_1") { data_ops_ = new DataOps(graph_viewer_, V_2023_1, device_type_); + } else if (version_param == "V_2023_2") { + data_ops_ = new DataOps(graph_viewer_, V_2023_2, device_type_); } else { - data_ops_ = new DataOps(graph_viewer_, V_2023_1, device_type_); + data_ops_ = new DataOps(graph_viewer_, V_2023_2, device_type_); } } diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index a5a0faa3a8f24..8749885660314 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -146,7 +146,7 @@ std::vector supported_op_mode = { {"Dropout", V_2023_0, {"NPU"}}, {"Elu", V_2020_4, {"CPU", "GPU"}}, {"Elu", V_2023_0, {"NPU"}}, - // {"Einsum", V_2023_0, {"CPU", "GPU"}}, + {"Einsum", V_2023_1, {"CPU", "GPU"}}, {"Equal", V_2020_4, {"CPU", "GPU"}}, {"Equal", V_2023_0, {"NPU"}}, // Added for whisper decoder model. {"Erf", V_2020_4, {"CPU", "GPU"}}, @@ -705,7 +705,7 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"PRelu", obj}); } { - UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0, V_2023_1}, + UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0, V_2023_1, V_2023_2}, [this](const Node* node, const InitializedTensorSet&) { const auto& input_arg = node->InputDefs()[1]; auto shape = input_arg->Shape(); @@ -820,7 +820,7 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"Squeeze", obj}); } { - UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0, V_2023_1}, + UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0, V_2023_1, V_2023_2}, [this](const Node* node, const InitializedTensorSet&) { // If the operator is unsqueeze // If axes is an input, then we cannot produce a static graph. @@ -835,7 +835,7 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"Unsqueeze", obj}); } { - UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0, V_2023_1}, + UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0, V_2023_1, V_2023_2}, [this](const Node* node, const InitializedTensorSet&) { // check for attributes auto& upsample_attr = node->GetAttributes(); diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h index a5aa3f825602c..f6ad2dd5c9d60 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h @@ -25,6 +25,7 @@ enum versionNum { V_2022_3, V_2023_0, V_2023_1, + V_2023_2 }; using VersionNum = enum versionNum; diff --git a/onnxruntime/core/providers/rocm/tunable/gemm.cu b/onnxruntime/core/providers/rocm/tunable/gemm.cu index 3d96916a5edda..b4b7eb47bed2f 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm.cu +++ b/onnxruntime/core/providers/rocm/tunable/gemm.cu @@ -53,16 +53,16 @@ inline GEMM(T, ScalarT) { if (tuning_ctx->IsTunableOpEnabled()) { if (opa == BlasOp::N && opb == BlasOp::N) { - static internal::GemmTunableOp gemm{}; + static internal::GemmTunableOp gemm{}; return gemm(¶ms); } else if (opa == BlasOp::T && opb == BlasOp::N) { - static internal::GemmTunableOp gemm{}; + static internal::GemmTunableOp gemm{}; return gemm(¶ms); } else if (opa == BlasOp::N && opb == BlasOp::T) { - static internal::GemmTunableOp gemm{}; + static internal::GemmTunableOp gemm{}; return gemm(¶ms); } else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ { - static internal::GemmTunableOp gemm{}; + static internal::GemmTunableOp gemm{}; return gemm(¶ms); } } @@ -94,16 +94,16 @@ inline BATCHED_GEMM(T, ScalarT) { if (tuning_ctx->IsTunableOpEnabled()) { if (opa == BlasOp::N && opb == BlasOp::N) { - static internal::BatchedGemmTunableOp gemm{}; + static internal::BatchedGemmTunableOp gemm{}; return gemm(¶ms); } else if (opa == BlasOp::T && opb == BlasOp::N) { - static internal::BatchedGemmTunableOp gemm{}; + static internal::BatchedGemmTunableOp gemm{}; return gemm(¶ms); } else if (opa == BlasOp::N && opb == BlasOp::T) { - static internal::BatchedGemmTunableOp gemm{}; + static internal::BatchedGemmTunableOp gemm{}; return gemm(¶ms); } else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ { - static internal::BatchedGemmTunableOp gemm{}; + static internal::BatchedGemmTunableOp gemm{}; return gemm(¶ms); } } @@ -138,16 +138,16 @@ inline STRIDED_BATCHED_GEMM(T, ScalarT) { if (tuning_ctx->IsTunableOpEnabled()) { if (opa == BlasOp::N && opb == BlasOp::N) { - static internal::StridedBatchedGemmTunableOp gemm{}; + static internal::StridedBatchedGemmTunableOp gemm{}; return gemm(¶ms); } else if (opa == BlasOp::T && opb == BlasOp::N) { - static internal::StridedBatchedGemmTunableOp gemm{}; + static internal::StridedBatchedGemmTunableOp gemm{}; return gemm(¶ms); } else if (opa == BlasOp::N && opb == BlasOp::T) { - static internal::StridedBatchedGemmTunableOp gemm{}; + static internal::StridedBatchedGemmTunableOp gemm{}; return gemm(¶ms); } else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ { - static internal::StridedBatchedGemmTunableOp gemm{}; + static internal::StridedBatchedGemmTunableOp gemm{}; return gemm(¶ms); } } diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh b/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh index 2518f45e0995e..b342bd6bc8a72 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh +++ b/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh @@ -36,9 +36,11 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; using Nop = ck::tensor_operation::element_wise::PassThrough; -template +template auto GetCKGemmTypeStringAndOps() { using CKDataType = typename CKDataTypeAdaptor::type; + using ALayout = typename CKBlasOpAdaptor::type; + using BLayout = typename CKBlasOpAdaptor::type; using DeviceGemm = ck::tensor_operation::device::DeviceGemm< ALayout, BLayout, Row, CKDataType, CKDataType, CKDataType, @@ -70,9 +72,11 @@ auto GetCKGemmTypeStringAndOps() { return ret; } -template +template auto GetCKStreamKGemmTypeStringAndOps() { using CKDataType = typename CKDataTypeAdaptor::type; + using ALayout = typename CKBlasOpAdaptor::type; + using BLayout = typename CKBlasOpAdaptor::type; using DeviceGemm = ck::tensor_operation::device::DeviceGemmStreamK< ALayout, BLayout, Row, CKDataType, CKDataType, CKDataType, @@ -104,9 +108,11 @@ auto GetCKStreamKGemmTypeStringAndOps() { return ret; } -template +template auto GetCKSplitKGemmTypeStringAndOps() { using CKDataType = typename CKDataTypeAdaptor::type; + using ALayout = typename CKBlasOpAdaptor::type; + using BLayout = typename CKBlasOpAdaptor::type; using DeviceGemm = ck::tensor_operation::device::DeviceGemmSplitK< ALayout, BLayout, Row, CKDataType, CKDataType, CKDataType, @@ -144,9 +150,11 @@ auto GetCKSplitKGemmTypeStringAndOps() { return ret; } -template +template auto GetCKStridedBatchedGemmTypeStringAndOps() { using CKDataType = typename CKDataTypeAdaptor::type; + using ALayout = typename CKBlasOpAdaptor::type; + using BLayout = typename CKBlasOpAdaptor::type; using DeviceStridedBatchedGemm = ck::tensor_operation::device::DeviceBatchedGemm< ALayout, BLayout, Row, CKDataType, CKDataType, CKDataType, diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h index 776dabd757af4..6554ed977cef6 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h +++ b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h @@ -59,9 +59,9 @@ constexpr hipblasltDatatype_t HipBlasDataTypeFor() { return HIPBLASLT_R_64F; } -template -constexpr hipblasOperation_t MapCKLayoutToHipBlasLt() { - if constexpr (std::is_same_v) { +template +constexpr hipblasOperation_t MapBlasOpToHipBlasLt() { + if constexpr (Op == BlasOp::NonTrans) { return HIPBLAS_OP_N; } return HIPBLAS_OP_T; @@ -101,13 +101,13 @@ std::string TypeStringFor() { return "UnknownType"; } -template +template auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationType::NONE) { hipblasLtHandle_t handle; HIPBLASLT_CALL_THROW(hipblasLtCreate(&handle)); - hipblasOperation_t trans_a = MapCKLayoutToHipBlasLt(); - hipblasOperation_t trans_b = MapCKLayoutToHipBlasLt(); + hipblasOperation_t trans_a = MapBlasOpToHipBlasLt(); + hipblasOperation_t trans_b = MapBlasOpToHipBlasLt(); hipblasltDatatype_t in_out_datatype = HipBlasDataTypeFor(); std::vector heuristic_result; @@ -266,19 +266,19 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp return ret; } -template +template auto GetHipBlasLtGemmTypeStringAndOps() { - return GetHipBlasLtTypeStringAndOps>(); + return GetHipBlasLtTypeStringAndOps>(); } -template +template auto GetHipBlasLtStridedBatchedGemmTypeStringAndOps() { - return GetHipBlasLtTypeStringAndOps>(); + return GetHipBlasLtTypeStringAndOps>(); } -template +template auto GetHipBlasLtGemmFastGeluTypeStringAndOps() { - return GetHipBlasLtTypeStringAndOps>(ActivationType::GELU); + return GetHipBlasLtTypeStringAndOps>(ActivationType::GELU); } #endif // USE_HIPBLASLT diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh b/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh index dbef772f8cd96..9228287fbbb89 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh +++ b/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh @@ -33,14 +33,14 @@ bool IsZero(half v) { return __half2float(v) == 0.0f; } -template +template class GemmTunableOp : public TunableOp> { public: GemmTunableOp() { this->RegisterOp(RocBlasGemmOp); #ifdef USE_HIPBLASLT - for (auto&& [_, op] : GetHipBlasLtGemmTypeStringAndOps()) { + for (auto&& [_, op] : GetHipBlasLtGemmTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } @@ -54,16 +54,16 @@ class GemmTunableOp : public TunableOp> { #endif #ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKGemmTypeStringAndOps()) { + for (auto&& [_, op] : GetCKGemmTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } - for (auto&& [_, op] : GetCKStreamKGemmTypeStringAndOps()) { + for (auto&& [_, op] : GetCKStreamKGemmTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } - for (auto&& [_, op] : GetCKSplitKGemmTypeStringAndOps()) { + for (auto&& [_, op] : GetCKSplitKGemmTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } @@ -96,7 +96,7 @@ class GemmTunableOp : public TunableOp> { } }; -template +template class BatchedGemmTunableOp : public TunableOp> { public: BatchedGemmTunableOp() { @@ -146,14 +146,14 @@ class BatchedGemmTunableOp : public TunableOp> { } }; -template +template class StridedBatchedGemmTunableOp : public TunableOp> { public: StridedBatchedGemmTunableOp() { this->RegisterOp(RocBlasStridedBatchedGemmOp); #ifdef USE_HIPBLASLT - for (auto&& [_, op] : GetHipBlasLtStridedBatchedGemmTypeStringAndOps()) { + for (auto&& [_, op] : GetHipBlasLtStridedBatchedGemmTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } @@ -167,7 +167,7 @@ class StridedBatchedGemmTunableOp : public TunableOp #endif #ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKStridedBatchedGemmTypeStringAndOps()) { + for (auto&& [_, op] : GetCKStridedBatchedGemmTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index df4dd55417755..e3b8dea90a898 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1449,8 +1449,12 @@ ProviderOptions OrtOpenVINOProviderOptionsToOrtOpenVINOProviderOptionsV2(const O ov_options_converted_map["context"] = context_string.str(); ov_options_converted_map["enable_opencl_throttling"] = legacy_ov_options->enable_opencl_throttling; - ov_options_converted_map["enable_dynamic_shapes"] = legacy_ov_options->enable_dynamic_shapes; - + std::string enable_dynamic_shapes = reinterpret_cast(legacy_ov_options->enable_dynamic_shapes); + if (enable_dynamic_shapes == "true" || enable_dynamic_shapes == "True") { + ov_options_converted_map["disable_dynamic_shapes"] = "false"; + } else if (enable_dynamic_shapes == "false" || enable_dynamic_shapes == "False") { + ov_options_converted_map["disable_dynamic_shapes"] = "true"; + } // Add new provider option below ov_options_converted_map["num_streams"] = "1"; return ov_options_converted_map; diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 81e58c9dd02d0..2e9af9f1f9bb2 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -104,6 +104,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, #else status = create_not_supported_status(); #endif + } else if (strcmp(provider_name, "SNPE") == 0) { #if defined(USE_SNPE) options->provider_factories.push_back(SNPEProviderFactoryCreator::Create(provider_options)); diff --git a/onnxruntime/core/util/math_cpuonly.h b/onnxruntime/core/util/math_cpuonly.h index f4fa3aa54b2ca..73caf9f86180d 100644 --- a/onnxruntime/core/util/math_cpuonly.h +++ b/onnxruntime/core/util/math_cpuonly.h @@ -93,7 +93,7 @@ template using ConstEigenMatrixMap = Eigen::Map>; template -using ConstSparseMatrixMap = Eigen::Map>; +using ConstSparseMatrixMap = Eigen::Map>; template using ConstEigenArrayMap = Eigen::Map>; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 27fbf19084d77..6f383d733edbd 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -903,10 +903,10 @@ std::unique_ptr CreateExecutionProviderInstance( ORT_THROW("Invalid value passed for enable_opencl_throttling: ", option.second); } OV_provider_options_map[option.first] = option.second; - } else if (option.first == "enable_dynamic_shapes") { + } else if (option.first == "disable_dynamic_shapes") { if (!(option.second == "True" || option.second == "true" || option.second == "False" || option.second == "false")) { - ORT_THROW("Invalid value passed for enable_dynamic_shapes: ", option.second); + ORT_THROW("Invalid value passed for disable_dynamic_shapes: ", option.second); } OV_provider_options_map[option.first] = option.second; } else if (option.first == "device_id") { diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ck.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ck.cu index 6707892cca50e..6c6bc147bd2a0 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ck.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ck.cu @@ -23,7 +23,7 @@ namespace py = pybind11; namespace onnxruntime { #ifdef USE_COMPOSABLE_KERNEL -template +template class CKGemm : public IKernelExplorer { public: CKGemm(BlasOp opa, BlasOp opb, @@ -34,9 +34,7 @@ class CKGemm : public IKernelExplorer { double beta, DeviceArray& c, int64_t ldc) : params_{} { - auto supports_a = opa == BlasOp::N ? std::is_same_v : std::is_same_v; - auto supports_b = opb == BlasOp::N ? std::is_same_v : std::is_same_v; - ORT_ENFORCE(supports_a && supports_b); + ORT_ENFORCE(opa == OpA && opb == OpB); params_.tuning_ctx = TuningContext(); params_.stream = Stream(); @@ -56,15 +54,15 @@ class CKGemm : public IKernelExplorer { params_.c = static_cast(c.ptr()); params_.ldc = ldc; - for (auto&& [type_string, op] : GetCKGemmTypeStringAndOps()) { + for (auto&& [type_string, op] : GetCKGemmTypeStringAndOps()) { type_strings_.emplace_back(std::move(type_string)); ops_.emplace_back(std::move(op)); } - for (auto&& [type_string, op] : GetCKStreamKGemmTypeStringAndOps()) { + for (auto&& [type_string, op] : GetCKStreamKGemmTypeStringAndOps()) { type_strings_.emplace_back(std::move(type_string)); ops_.emplace_back(std::move(op)); } - for (auto&& [type_string, op] : GetCKSplitKGemmTypeStringAndOps()) { + for (auto&& [type_string, op] : GetCKSplitKGemmTypeStringAndOps()) { type_strings_.emplace_back(std::move(type_string)); ops_.emplace_back(std::move(op)); } @@ -100,7 +98,7 @@ class CKGemm : public IKernelExplorer { size_t selected_op_{}; }; -template +template class CKStridedBatchedGemm : public IKernelExplorer { public: CKStridedBatchedGemm( @@ -113,9 +111,7 @@ class CKStridedBatchedGemm : public IKernelExplorer { DeviceArray& c, int64_t ldc, int64_t stride_c, int64_t batch) : params_{} { - auto supports_a = opa == BlasOp::N ? std::is_same_v : std::is_same_v; - auto supports_b = opb == BlasOp::N ? std::is_same_v : std::is_same_v; - ORT_ENFORCE(supports_a && supports_b); + ORT_ENFORCE(opa == OpA && opb == OpB); params_.tuning_ctx = TuningContext(); params_.stream = Stream(); @@ -139,7 +135,7 @@ class CKStridedBatchedGemm : public IKernelExplorer { params_.stride_c = stride_c; params_.batch = batch; - for (auto&& [type_string, op] : GetCKStridedBatchedGemmTypeStringAndOps()) { + for (auto&& [type_string, op] : GetCKStridedBatchedGemmTypeStringAndOps()) { type_strings_.emplace_back(std::move(type_string)); ops_.emplace_back(std::move(op)); } @@ -175,44 +171,44 @@ class CKStridedBatchedGemm : public IKernelExplorer { size_t selected_op_{}; }; -#define REGISTER_OP_COMMON(type, dtype, alayout, blayout, layout_string) \ - py::class_>(m, #type "_" #dtype "_" layout_string) \ - .def("SetRepeats", &type::SetRepeats) \ - .def("Profile", &type::Profile) \ - .def("Run", &type::Run) \ - .def("ListOps", &type::ListOps) \ - .def("SelectOp", &type::SelectOp) - -#define REGISTER_CKGEMM(dtype, alayout, blayout, layout_string) \ - REGISTER_OP_COMMON(CKGemm, dtype, alayout, blayout, layout_string) \ - .def(py::init>(m, #type "_" #dtype "_" layout_string) \ + .def("SetRepeats", &type::SetRepeats) \ + .def("Profile", &type::Profile) \ + .def("Run", &type::Run) \ + .def("ListOps", &type::ListOps) \ + .def("SelectOp", &type::SelectOp) + +#define REGISTER_CKGEMM(dtype, opa, opb, layout_string) \ + REGISTER_OP_COMMON(CKGemm, dtype, opa, opb, layout_string) \ + .def(py::init()); -#define REGISTER_CKGEMM_FOR_ALL_TRANSAB(dtype) \ - REGISTER_CKGEMM(dtype, Row, Row, "NN"); \ - REGISTER_CKGEMM(dtype, Row, Col, "NT"); \ - REGISTER_CKGEMM(dtype, Col, Row, "TN"); \ - REGISTER_CKGEMM(dtype, Col, Col, "TT"); - -#define REGISTER_CKSTRIDEDBATCHEDGEMM(dtype, alayout, blayout, layout_string) \ - REGISTER_OP_COMMON(CKStridedBatchedGemm, dtype, alayout, blayout, layout_string) \ - .def(py::init()); -#define REGISTER_CKSTRIDEDBATCHEDGEMM_FOR_ALL_TRANSAB(dtype) \ - REGISTER_CKSTRIDEDBATCHEDGEMM(dtype, Row, Row, "NN"); \ - REGISTER_CKSTRIDEDBATCHEDGEMM(dtype, Row, Col, "NT"); \ - REGISTER_CKSTRIDEDBATCHEDGEMM(dtype, Col, Row, "TN"); \ - REGISTER_CKSTRIDEDBATCHEDGEMM(dtype, Col, Col, "TT"); +#define REGISTER_CKSTRIDEDBATCHEDGEMM_FOR_ALL_TRANSAB(dtype) \ + REGISTER_CKSTRIDEDBATCHEDGEMM(dtype, BlasOp::N, BlasOp::N, "NN"); \ + REGISTER_CKSTRIDEDBATCHEDGEMM(dtype, BlasOp::N, BlasOp::T, "NT"); \ + REGISTER_CKSTRIDEDBATCHEDGEMM(dtype, BlasOp::T, BlasOp::N, "TN"); \ + REGISTER_CKSTRIDEDBATCHEDGEMM(dtype, BlasOp::T, BlasOp::T, "TT"); KE_REGISTER(m) { REGISTER_CKGEMM_FOR_ALL_TRANSAB(float); diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.cu index 78446aa2b2008..ec7083186b977 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.cu @@ -23,7 +23,7 @@ namespace py = pybind11; namespace onnxruntime { #ifdef USE_COMPOSABLE_KERNEL -template +template class CKGemmFastGelu : public IKernelExplorer { public: CKGemmFastGelu(BlasOp opa, BlasOp opb, @@ -35,9 +35,7 @@ class CKGemmFastGelu : public IKernelExplorer { double beta, DeviceArray& c, int64_t ldc) : params_{} { - auto supports_a = opa == BlasOp::N ? std::is_same_v : std::is_same_v; - auto supports_b = opb == BlasOp::N ? std::is_same_v : std::is_same_v; - ORT_ENFORCE(supports_a && supports_b); + ORT_ENFORCE(opa == OpA && opb == OpB); params_.tuning_ctx = TuningContext(); params_.stream = Stream(); @@ -58,11 +56,11 @@ class CKGemmFastGelu : public IKernelExplorer { params_.c = static_cast(c.ptr()); params_.ldc = ldc; - for (auto&& [type_string, op] : GetCKGemmAddFastGeluTypeStringAndOps()) { + for (auto&& [type_string, op] : GetCKGemmAddFastGeluTypeStringAndOps()) { type_strings_.emplace_back(std::move(type_string)); ops_.emplace_back(std::move(op)); } - for (auto&& [type_string, op] : GetCKGemmFastGeluTypeStringAndOps()) { + for (auto&& [type_string, op] : GetCKGemmFastGeluTypeStringAndOps()) { type_strings_.emplace_back(std::move(type_string)); ops_.emplace_back(std::move(op)); } @@ -97,26 +95,26 @@ class CKGemmFastGelu : public IKernelExplorer { size_t selected_op_{}; }; -#define REGISTER_OP(type, alayout, blayout, layout_string) \ - py::class_>(m, "CKGemmFastGelu_" #type "_" layout_string) \ - .def(py::init()) \ - .def("SetRepeats", &CKGemmFastGelu::SetRepeats) \ - .def("Profile", &CKGemmFastGelu::Profile) \ - .def("Run", &CKGemmFastGelu::Run) \ - .def("ListOps", &CKGemmFastGelu::ListOps) \ - .def("SelectOp", &CKGemmFastGelu::SelectOp); - -#define REGISTER_OP_FOR_ALL_TRANSAB(type) \ - REGISTER_OP(type, Row, Row, "NN"); \ - REGISTER_OP(type, Row, Col, "NT"); \ - REGISTER_OP(type, Col, Row, "TN"); \ - REGISTER_OP(type, Col, Col, "TT"); +#define REGISTER_OP(type, opa, opb, layout_string) \ + py::class_>(m, "CKGemmFastGelu_" #type "_" layout_string) \ + .def(py::init()) \ + .def("SetRepeats", &CKGemmFastGelu::SetRepeats) \ + .def("Profile", &CKGemmFastGelu::Profile) \ + .def("Run", &CKGemmFastGelu::Run) \ + .def("ListOps", &CKGemmFastGelu::ListOps) \ + .def("SelectOp", &CKGemmFastGelu::SelectOp); + +#define REGISTER_OP_FOR_ALL_TRANSAB(type) \ + REGISTER_OP(type, BlasOp::N, BlasOp::N, "NN"); \ + REGISTER_OP(type, BlasOp::N, BlasOp::T, "NT"); \ + REGISTER_OP(type, BlasOp::T, BlasOp::N, "TN"); \ + REGISTER_OP(type, BlasOp::T, BlasOp::T, "TT"); KE_REGISTER(m) { REGISTER_OP_FOR_ALL_TRANSAB(float); diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_hipblaslt.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_hipblaslt.cu index 3a73984f53d49..4d8ecfc34219e 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_hipblaslt.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_hipblaslt.cu @@ -23,7 +23,7 @@ namespace onnxruntime { using namespace rocm::tunable::blas::internal; -template +template class GemmFastGeluHipBlasLt : public IKernelExplorer { public: GemmFastGeluHipBlasLt(BlasOp opa, BlasOp opb, @@ -53,7 +53,7 @@ class GemmFastGeluHipBlasLt : public IKernelExplorer { params_.c = static_cast(c.ptr()); params_.ldc = ldc; - for (auto&& [type_string, op] : GetHipBlasLtGemmFastGeluTypeStringAndOps()) { + for (auto&& [type_string, op] : GetHipBlasLtGemmFastGeluTypeStringAndOps()) { type_strings_.emplace_back(std::move(type_string)); ops_.emplace_back(std::move(op)); } @@ -89,26 +89,26 @@ class GemmFastGeluHipBlasLt : public IKernelExplorer { size_t selected_op_{}; }; -#define REGISTER_OP(type, alayout, blayout, layout_string) \ - py::class_>(m, "GemmFastGeluHipBlasLt_" #type "_" layout_string) \ - .def(py::init()) \ - .def("SetRepeats", &GemmFastGeluHipBlasLt::SetRepeats) \ - .def("Profile", &GemmFastGeluHipBlasLt::Profile) \ - .def("Run", &GemmFastGeluHipBlasLt::Run) \ - .def("ListOps", &GemmFastGeluHipBlasLt::ListOps) \ - .def("SelectOp", &GemmFastGeluHipBlasLt::SelectOp); - -#define REGISTER_OP_FOR_ALL_TRANSAB(type) \ - REGISTER_OP(type, Row, Row, "NN"); \ - REGISTER_OP(type, Row, Col, "NT"); \ - REGISTER_OP(type, Col, Row, "TN"); \ - REGISTER_OP(type, Col, Col, "TT"); +#define REGISTER_OP(type, opa, opb, layout_string) \ + py::class_>(m, "GemmFastGeluHipBlasLt_" #type "_" layout_string) \ + .def(py::init()) \ + .def("SetRepeats", &GemmFastGeluHipBlasLt::SetRepeats) \ + .def("Profile", &GemmFastGeluHipBlasLt::Profile) \ + .def("Run", &GemmFastGeluHipBlasLt::Run) \ + .def("ListOps", &GemmFastGeluHipBlasLt::ListOps) \ + .def("SelectOp", &GemmFastGeluHipBlasLt::SelectOp); + +#define REGISTER_OP_FOR_ALL_TRANSAB(type) \ + REGISTER_OP(type, BlasOp::N, BlasOp::N, "NN"); \ + REGISTER_OP(type, BlasOp::N, BlasOp::T, "NT"); \ + REGISTER_OP(type, BlasOp::T, BlasOp::N, "TN"); \ + REGISTER_OP(type, BlasOp::T, BlasOp::T, "TT"); KE_REGISTER(m) { REGISTER_OP_FOR_ALL_TRANSAB(float); diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.cu index 7ecb87828acdc..3f375c67acf85 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.cu @@ -17,7 +17,7 @@ using namespace onnxruntime::contrib::rocm::blas::internal; namespace py = pybind11; namespace onnxruntime { -template +template class GemmFastGeluTunable : public IKernelExplorer { public: GemmFastGeluTunable(BlasOp opa, BlasOp opb, @@ -72,29 +72,29 @@ class GemmFastGeluTunable : public IKernelExplorer { using ParamsT = GemmFastGeluParams; ParamsT params_{}; rocblas_handle rocblas_handle_; - GemmFastGeluTunableOp op_{}; + GemmFastGeluTunableOp op_{}; }; -#define REGISTER_OP(type, alayout, blayout, layout_string) \ - py::class_>(m, "GemmFastGeluTunable_" #type "_" layout_string) \ - .def(py::init()) \ - .def("SetRepeats", &GemmFastGeluTunable::SetRepeats) \ - .def("Profile", &GemmFastGeluTunable::Profile) \ - .def("Run", &GemmFastGeluTunable::Run) \ - .def("ListOps", &GemmFastGeluTunable::ListOps) \ - .def("SelectOp", &GemmFastGeluTunable::SelectOp); - -#define REGISTER_OP_FOR_ALL_TRANSAB(type) \ - REGISTER_OP(type, Row, Row, "NN"); \ - REGISTER_OP(type, Row, Col, "NT"); \ - REGISTER_OP(type, Col, Row, "TN"); \ - REGISTER_OP(type, Col, Col, "TT"); +#define REGISTER_OP(type, opa, opb, layout_string) \ + py::class_>(m, "GemmFastGeluTunable_" #type "_" layout_string) \ + .def(py::init()) \ + .def("SetRepeats", &GemmFastGeluTunable::SetRepeats) \ + .def("Profile", &GemmFastGeluTunable::Profile) \ + .def("Run", &GemmFastGeluTunable::Run) \ + .def("ListOps", &GemmFastGeluTunable::ListOps) \ + .def("SelectOp", &GemmFastGeluTunable::SelectOp); + +#define REGISTER_OP_FOR_ALL_TRANSAB(type) \ + REGISTER_OP(type, BlasOp::N, BlasOp::N, "NN"); \ + REGISTER_OP(type, BlasOp::N, BlasOp::T, "NT"); \ + REGISTER_OP(type, BlasOp::T, BlasOp::N, "TN"); \ + REGISTER_OP(type, BlasOp::T, BlasOp::T, "TT"); KE_REGISTER(m) { REGISTER_OP_FOR_ALL_TRANSAB(float); diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_hipblaslt.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_hipblaslt.cu index 7ab6e5ae81847..c0658dff193ae 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_hipblaslt.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_hipblaslt.cu @@ -25,7 +25,7 @@ namespace onnxruntime { using namespace rocm::tunable::blas::internal; -template +template class GemmHipBlasLt : public IKernelExplorer { public: GemmHipBlasLt(BlasOp opa, BlasOp opb, @@ -54,7 +54,7 @@ class GemmHipBlasLt : public IKernelExplorer { params_.c = static_cast(c.ptr()); params_.ldc = ldc; - for (auto&& [type_string, op] : GetHipBlasLtGemmTypeStringAndOps()) { + for (auto&& [type_string, op] : GetHipBlasLtGemmTypeStringAndOps()) { type_strings_.emplace_back(std::move(type_string)); ops_.emplace_back(std::move(op)); } @@ -90,7 +90,7 @@ class GemmHipBlasLt : public IKernelExplorer { size_t selected_op_{}; }; -template +template class StridedBatchedGemmHipBlasLt : public IKernelExplorer { public: StridedBatchedGemmHipBlasLt( @@ -125,7 +125,7 @@ class StridedBatchedGemmHipBlasLt : public IKernelExplorer { params_.stride_c = stride_c; params_.batch = batch; - for (auto&& [type_string, op] : GetHipBlasLtStridedBatchedGemmTypeStringAndOps()) { + for (auto&& [type_string, op] : GetHipBlasLtStridedBatchedGemmTypeStringAndOps()) { type_strings_.emplace_back(std::move(type_string)); ops_.emplace_back(std::move(op)); } @@ -161,44 +161,44 @@ class StridedBatchedGemmHipBlasLt : public IKernelExplorer { size_t selected_op_{}; }; -#define REGISTER_OP_COMMON(type, dtype, alayout, blayout, layout_string) \ - py::class_>(m, #type "_" #dtype "_" layout_string) \ - .def("SetRepeats", &type::SetRepeats) \ - .def("Profile", &type::Profile) \ - .def("Run", &type::Run) \ - .def("ListOps", &type::ListOps) \ - .def("SelectOp", &type::SelectOp) - -#define REGISTER_GEMM_HIPBLASLT(dtype, alayout, blayout, layout_string) \ - REGISTER_OP_COMMON(GemmHipBlasLt, dtype, alayout, blayout, layout_string) \ - .def(py::init>(m, #type "_" #dtype "_" layout_string) \ + .def("SetRepeats", &type::SetRepeats) \ + .def("Profile", &type::Profile) \ + .def("Run", &type::Run) \ + .def("ListOps", &type::ListOps) \ + .def("SelectOp", &type::SelectOp) + +#define REGISTER_GEMM_HIPBLASLT(dtype, opa, opb, layout_string) \ + REGISTER_OP_COMMON(GemmHipBlasLt, dtype, opa, opb, layout_string) \ + .def(py::init()); -#define REGISTER_GEMM_HIPBLASLT_FOR_ALL_TRANSAB(dtype) \ - REGISTER_GEMM_HIPBLASLT(dtype, Row, Row, "NN"); \ - REGISTER_GEMM_HIPBLASLT(dtype, Row, Col, "NT"); \ - REGISTER_GEMM_HIPBLASLT(dtype, Col, Row, "TN"); \ - REGISTER_GEMM_HIPBLASLT(dtype, Col, Col, "TT"); - -#define REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, alayout, blayout, layout_string) \ - REGISTER_OP_COMMON(StridedBatchedGemmHipBlasLt, dtype, alayout, blayout, layout_string) \ - .def(py::init()); -#define REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT_FOR_ALL_TRANSAB(dtype) \ - REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, Row, Row, "NN"); \ - REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, Row, Col, "NT"); \ - REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, Col, Row, "TN"); \ - REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, Col, Col, "TT"); +#define REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT_FOR_ALL_TRANSAB(dtype) \ + REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, BlasOp::N, BlasOp::N, "NN"); \ + REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, BlasOp::N, BlasOp::T, "NT"); \ + REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, BlasOp::T, BlasOp::N, "TN"); \ + REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, BlasOp::T, BlasOp::T, "TT"); KE_REGISTER(m) { REGISTER_GEMM_HIPBLASLT_FOR_ALL_TRANSAB(float); diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_tunable.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_tunable.cu index d1786f94b1a3b..e1d9b5de20e00 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_tunable.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_tunable.cu @@ -19,7 +19,7 @@ using namespace onnxruntime::rocm::tunable::blas::internal; namespace onnxruntime { -template +template class GemmTunable : public IKernelExplorer { public: GemmTunable(BlasOp opa, BlasOp opb, @@ -73,11 +73,11 @@ class GemmTunable : public IKernelExplorer { ParamsT params_; // tunable is stateful, store it as an instance - GemmTunableOp op_{}; + GemmTunableOp op_{}; rocblas_handle rocblas_handle_; }; -template +template class BatchedGemmTunable : public IBatchedGemmKernelExplorer { public: BatchedGemmTunable(BlasOp opa, BlasOp opb, @@ -135,11 +135,11 @@ class BatchedGemmTunable : public IBatchedGemmKernelExplorer { ParamsT params_; // tunable is stateful, store it as an instance - BatchedGemmTunableOp op_{}; + BatchedGemmTunableOp op_{}; rocblas_handle rocblas_handle_; }; -template +template class StridedBatchedGemmTunable : public IKernelExplorer { public: StridedBatchedGemmTunable(BlasOp opa, BlasOp opb, @@ -198,64 +198,64 @@ class StridedBatchedGemmTunable : public IKernelExplorer { ParamsT params_; // tunable is stateful, store it as an instance - StridedBatchedGemmTunableOp op_{}; + StridedBatchedGemmTunableOp op_{}; rocblas_handle rocblas_handle_; }; -#define REGISTER_OP_COMMON(type, dtype, alayout, blayout, layout_string) \ - py::class_>(m, #type "_" #dtype "_" layout_string) \ - .def("SetRepeats", &type::SetRepeats) \ - .def("Profile", &type::Profile) \ - .def("Run", &type::Run) \ - .def("ListOps", &type::ListOps) \ - .def("SelectOp", &type::SelectOp) - -#define REGISTER_GEMM(dtype, alayout, blayout, layout_string) \ - REGISTER_OP_COMMON(GemmTunable, dtype, alayout, blayout, layout_string) \ - .def(py::init>(m, #type "_" #dtype "_" layout_string) \ + .def("SetRepeats", &type::SetRepeats) \ + .def("Profile", &type::Profile) \ + .def("Run", &type::Run) \ + .def("ListOps", &type::ListOps) \ + .def("SelectOp", &type::SelectOp) + +#define REGISTER_GEMM(dtype, opa, opb, layout_string) \ + REGISTER_OP_COMMON(GemmTunable, dtype, opa, opb, layout_string) \ + .def(py::init()) -#define REGISTER_GEMM_FOR_ALL_TRANSAB(dtype) \ - REGISTER_GEMM(dtype, Row, Row, "NN"); \ - REGISTER_GEMM(dtype, Row, Col, "NT"); \ - REGISTER_GEMM(dtype, Col, Row, "TN"); \ - REGISTER_GEMM(dtype, Col, Col, "TT"); - -#define REGISTER_BATCHED_GEMM(dtype, alayout, blayout, layout_string) \ - REGISTER_OP_COMMON(BatchedGemmTunable, dtype, alayout, blayout, layout_string) \ - .def(py::init&, int64_t, \ - std::vector&, int64_t, \ - double, \ - std::vector&, int64_t, \ +#define REGISTER_GEMM_FOR_ALL_TRANSAB(dtype) \ + REGISTER_GEMM(dtype, BlasOp::N, BlasOp::N, "NN"); \ + REGISTER_GEMM(dtype, BlasOp::N, BlasOp::T, "NT"); \ + REGISTER_GEMM(dtype, BlasOp::T, BlasOp::N, "TN"); \ + REGISTER_GEMM(dtype, BlasOp::T, BlasOp::T, "TT"); + +#define REGISTER_BATCHED_GEMM(dtype, opa, opb, layout_string) \ + REGISTER_OP_COMMON(BatchedGemmTunable, dtype, opa, opb, layout_string) \ + .def(py::init&, int64_t, \ + std::vector&, int64_t, \ + double, \ + std::vector&, int64_t, \ int64_t>()) -#define REGISTER_BATCHED_GEMM_FOR_ALL_TRANSAB(dtype) \ - REGISTER_BATCHED_GEMM(dtype, Row, Row, "NN"); \ - REGISTER_BATCHED_GEMM(dtype, Row, Col, "NT"); \ - REGISTER_BATCHED_GEMM(dtype, Col, Row, "TN"); \ - REGISTER_BATCHED_GEMM(dtype, Col, Col, "TT"); - -#define REGISTER_STRIDED_BATCHED_GEMM(dtype, alayout, blayout, layout_string) \ - REGISTER_OP_COMMON(StridedBatchedGemmTunable, dtype, alayout, blayout, layout_string) \ - .def(py::init()) -#define REGISTER_STRIDED_BATCHED_GEMM_FOR_ALL_TRANSAB(dtype) \ - REGISTER_STRIDED_BATCHED_GEMM(dtype, Row, Row, "NN"); \ - REGISTER_STRIDED_BATCHED_GEMM(dtype, Row, Col, "NT"); \ - REGISTER_STRIDED_BATCHED_GEMM(dtype, Col, Row, "TN"); \ - REGISTER_STRIDED_BATCHED_GEMM(dtype, Col, Col, "TT"); +#define REGISTER_STRIDED_BATCHED_GEMM_FOR_ALL_TRANSAB(dtype) \ + REGISTER_STRIDED_BATCHED_GEMM(dtype, BlasOp::N, BlasOp::N, "NN"); \ + REGISTER_STRIDED_BATCHED_GEMM(dtype, BlasOp::N, BlasOp::T, "NT"); \ + REGISTER_STRIDED_BATCHED_GEMM(dtype, BlasOp::T, BlasOp::N, "TN"); \ + REGISTER_STRIDED_BATCHED_GEMM(dtype, BlasOp::T, BlasOp::T, "TT"); KE_REGISTER(m) { REGISTER_GEMM_FOR_ALL_TRANSAB(float); diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/__init__.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/__init__.py index c5f0b27f7576a..61a264c275a13 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/__init__.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/__init__.py @@ -1 +1,2 @@ +from .preprocess import qnn_preprocess_model # noqa: F401 from .quant_config import get_qnn_qdq_config # noqa: F401 diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py new file mode 100644 index 0000000000000..9ebf400498e0e --- /dev/null +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py @@ -0,0 +1,127 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import annotations + +import onnx + +from ...fusions import Fusion +from ...onnx_model import ONNXModel + + +class FusionLpNormalization(Fusion): + def __init__(self, model: ONNXModel, epsilon: float = 1e-12): + super().__init__(model, "LpNormalization", "ReduceL2") + self.epsilon = epsilon + + def fuse( + self, + reduce_node: onnx.NodeProto, + input_name_to_nodes: dict[str, list[onnx.NodeProto]], + output_name_to_node: dict[str, onnx.NodeProto], + ): + """ + Interface function that tries to fuse a node sequence containing a ReduceL2 node into a single + LpNormalization node. + + Pattern 1: + [root] --> ReduceL2 -----> Clip --> Expand ----> Div --> + | (axis=-1) (min=epsilon) (shape=root) ^ + | (keepdims=True) | + | | + +-----------------------------------------------+ + Notes: + - ReduceL2 must use the last axis, and keepdims == True + - Clip must only have a min attribute that is ~1e-12 + - Expand must restore the shape to root.shape + - The output of Expand must be the second input to Div. + """ + if reduce_node.output[0] not in input_name_to_nodes: + return + + # ReduceL2 must have one Clip child + children = input_name_to_nodes[reduce_node.output[0]] + if len(children) != 1 or children[0].op_type != "Clip": + return + + # ReduceL2 must have keepdims == True + keepdims = self.get_node_attribute(reduce_node, "keepdims") + if not keepdims: + return + + # ReduceL2 axes must refer only to the last dimension. + # Axes became an input in opset 18. Before then, axes was an attribute + reduce_input_ttype = self.model.get_tensor_type(reduce_node.input[0]) + if not reduce_input_ttype: + return + + reduce_input_shape = self.tensor_shape_to_list(reduce_input_ttype) + if not reduce_input_shape: + return + + axes = self.get_node_attribute(reduce_node, "axes") + if not axes and len(reduce_node.input) > 1: + axes = self.model.get_constant_value(reduce_node.input[1]) + + if not axes or len(axes) != 1: + return + + last_dim = len(reduce_input_shape) - 1 + if axes[0] != -1 and axes[0] != last_dim: + return + + # Clip node must have a min attribute approximately equal to 1e-12 + clip_node = children[0] + clip_min = self.get_node_attribute(clip_node, "min") + if clip_min is None and len(clip_node.input) > 1: + clip_min = self.model.get_constant_value(clip_node.input[1]) + + clip_max = self.get_node_attribute(clip_node, "max") # TODO: clip_max could be FLOAT_MAX + if clip_max is None and len(clip_node.input) > 2: + clip_max = self.model.get_constant_value(clip_node.input[2]) + + if not (clip_max is None and clip_min is not None and clip_min > 0 and abs(clip_min - self.epsilon) < 1e-13): + return + + if clip_node.output[0] not in input_name_to_nodes: + return + + # Clip must have a single Expand child. + children = input_name_to_nodes[clip_node.output[0]] + if len(children) != 1 or children[0].op_type != "Expand": + return + + expand_node = children[0] + if expand_node.output[0] not in input_name_to_nodes: + return + + # Expand must have a single Div child + children = input_name_to_nodes[expand_node.output[0]] + if len(children) != 1 or children[0].op_type != "Div": + return + + div_node = children[0] + + # The first input to Div must be the root of the subgraph (i.e., reduce_node.input[0]) + # The second input to Div must be the output of the Expand. + # As long as these two inputs go to the same Div node, then ONNX validation will ensure that + # their shapes match. + if div_node.input[0] != reduce_node.input[0]: + return + if div_node.input[1] != expand_node.output[0]: + return + + subgraph_input = reduce_node.input[0] + subgraph_output = div_node.output[0] + + subgraph_nodes = [reduce_node, clip_node, expand_node, div_node] + if not self.is_safe_to_fuse_nodes(subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node): + return + + self.nodes_to_remove.extend(subgraph_nodes) + fused_node = onnx.helper.make_node( + self.fused_op_type, inputs=[subgraph_input], outputs=[subgraph_output], p=2, axis=-1 + ) + self.nodes_to_add.append(fused_node) diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py new file mode 100644 index 0000000000000..becbaceab184e --- /dev/null +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py @@ -0,0 +1,51 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import logging +from pathlib import Path + +import onnx + +from ...fusions import FusionGelu, FusionLayerNormalization +from ...onnx_model import ONNXModel +from .fusion_lpnorm import FusionLpNormalization + + +def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm: bool = False) -> bool: + modified = False + model = onnx.load_model(model_input) + onnx_model = ONNXModel(model) + + # Fuse Erf sequence into a single Gelu + fusion_gelu = FusionGelu(onnx_model) + if fusion_gelu.apply(): + modified = True + + # Fuse ReduceL2 sequence into a single LpNormalization node with p == 2. + fusion_lpnorm = FusionLpNormalization(onnx_model) + if fusion_lpnorm.apply(): + modified = True + + # Optionally, fuse ReduceMean sequence into a single LayerNormalization node. + if fuse_layernorm: + onnx_opset = next(x for x in model.opset_import if x.domain == "" or x.domain == "ai.onnx") + + # Need opset >= 17 to use LayerNormalization. + if onnx_opset.version < 17: + logging.warning( + "Unable to fuse ReduceMean sequence into a LayerNormalization node. " + "ONNX model must use an opset >= 17 in order to use LayerNormalization, " + f"but found version {onnx_opset.version}. Please use onnx.version_converter to update your model." + ) + else: + fusion_layernorm = FusionLayerNormalization(onnx_model) + if fusion_layernorm.apply(): + modified = True + + if modified: + onnx_model.topological_sort() + onnx.save_model(model, model_output) + + return modified diff --git a/onnxruntime/python/tools/quantization/fusions/__init__.py b/onnxruntime/python/tools/quantization/fusions/__init__.py new file mode 100644 index 0000000000000..f1576240a2ee3 --- /dev/null +++ b/onnxruntime/python/tools/quantization/fusions/__init__.py @@ -0,0 +1,3 @@ +from .fusion import Fusion # noqa: F401 +from .fusion_gelu import FusionGelu # noqa: F401 +from .fusion_layernorm import FusionLayerNormalization # noqa: F401 diff --git a/onnxruntime/python/tools/quantization/fusions/fusion.py b/onnxruntime/python/tools/quantization/fusions/fusion.py new file mode 100644 index 0000000000000..456a75eec2f8c --- /dev/null +++ b/onnxruntime/python/tools/quantization/fusions/fusion.py @@ -0,0 +1,298 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import annotations + +from collections import deque + +import onnx + +from ..onnx_model import ONNXModel + + +class Fusion: + """ + Base class for fusions. + """ + + def __init__(self, model: ONNXModel, fused_op_type: str, search_op_type: str): + self.search_op_type: str = search_op_type + self.fused_op_type: str = fused_op_type + self.model: ONNXModel = model + self.nodes_to_remove: list = [] + self.nodes_to_add: list = [] + + def fuse( + self, + node: onnx.NodeProto, + input_name_to_nodes: dict[str, list[onnx.NodeProto]], + output_name_to_node: dict[str, onnx.NodeProto], + ): + """ + Interface function for derived fusion classes. Tries to fuse a node sequence containing + the specified node. + """ + raise NotImplementedError + + def apply(self) -> bool: + """ + Apply graph fusion on the entire model graph. + """ + input_name_to_nodes = self.model.input_name_to_nodes() + output_name_to_node = self.model.output_name_to_node() + + for node in self.model.nodes(): + if node.op_type == self.search_op_type: + self.fuse(node, input_name_to_nodes, output_name_to_node) + + self.model.remove_nodes(self.nodes_to_remove) + self.model.add_nodes(self.nodes_to_add) + + graph_updated = bool(self.nodes_to_remove or self.nodes_to_add) + + if graph_updated: + self.model.remove_unused_constant() + + return graph_updated + + @staticmethod + def is_safe_to_fuse_nodes( + nodes_to_remove: list[onnx.NodeProto], + keep_outputs: list[str], + input_name_to_nodes: dict[str, list[onnx.NodeProto]], + output_name_to_node: dict[str, onnx.NodeProto], + ) -> bool: + for node_to_remove in nodes_to_remove: + for output_to_remove in node_to_remove.output: + if output_to_remove in keep_outputs: + continue + + if output_to_remove in input_name_to_nodes: + for impacted_node in input_name_to_nodes[output_to_remove]: + if impacted_node not in nodes_to_remove: + # Not safe to remove nodes since output is used by impacted_node + return False + return True + + @staticmethod + def get_node_attribute(node: onnx.NodeProto, attribute_name: str): + for attr in node.attribute: + if attr.name == attribute_name: + value = onnx.helper.get_attribute_value(attr) + return value + return None + + @staticmethod + def input_index(node_output: str, child_node: onnx.NodeProto) -> int: + index = 0 + for input_name in child_node.input: + if input_name == node_output: + return index + index += 1 + return -1 + + @staticmethod + def tensor_shape_to_list(tensor_type) -> list[int]: + shape_list = [] + for d in tensor_type.shape.dim: + if d.HasField("dim_value"): + shape_list.append(d.dim_value) # known dimension + elif d.HasField("dim_param"): + shape_list.append(d.dim_param) # unknown dimension with symbolic name + else: + shape_list.append("?") # shall not happen + return shape_list + + def get_constant_input(self, node: onnx.NodeProto): + for i, inp in enumerate(node.input): + value = self.model.get_constant_value(inp) + if value is not None: + return i, value + + return None, None + + def find_constant_input(self, node: onnx.NodeProto, expected_value: float, delta: float = 0.000001) -> int: + i, value = self.get_constant_input(node) + if value is not None and value.size == 1 and abs(value - expected_value) < delta: + return i + + return -1 + + def has_constant_input(self, node: onnx.NodeProto, expected_value: float, delta: float = 0.000001) -> bool: + return self.find_constant_input(node, expected_value, delta) >= 0 + + def is_constant_with_specified_rank(self, output_name: str, rank: int) -> bool: + value = self.model.get_constant_value(output_name) + if value is None: + return False # Not an initializer + + if len(value.shape) != rank: + return False # Wrong dimensions + + return True + + def match_first_parent( + self, + node: onnx.NodeProto, + parent_op_type: str, + output_name_to_node: dict[str, onnx.NodeProto] | None = None, + exclude: list[onnx.NodeProto] = [], # noqa: B006 + ) -> tuple[onnx.NodeProto | None, int | None]: + """ + Find parent node based on constraints on op_type. + + Args: + node: current node. + parent_op_type (str): constraint of parent node op_type. + output_name_to_node (dict): dictionary with output name as key, and node as value. + exclude (list): list of nodes that are excluded (not allowed to match as parent). + + Returns: + parent: The matched parent node. None if not found. + index: The input index of matched parent node. None if not found. + """ + if output_name_to_node is None: + output_name_to_node = self.model.output_name_to_node() + + for i, inp in enumerate(node.input): + if inp in output_name_to_node: + parent = output_name_to_node[inp] + if parent.op_type == parent_op_type and parent not in exclude: + return parent, i + + return None, None + + def match_parent( + self, + node: onnx.NodeProto, + parent_op_type: str, + input_index: int | None = None, + output_name_to_node: dict[str, onnx.NodeProto] | None = None, + exclude: list[onnx.NodeProto] = [], # noqa: B006 + return_indice: list[int] | None = None, + ) -> onnx.NodeProto | None: + """ + Find parent node based on constraints on op_type and index. + When input_index is None, we will find the first parent node based on constraints, + and return_indice will be appended the corresponding input index. + + Args: + node (str): current node name. + parent_op_type (str): constraint of parent node op_type. + input_index (int or None): only check the parent given input index of current node. + output_name_to_node (dict): dictionary with output name as key, and node as value. + exclude (list): list of nodes that are excluded (not allowed to match as parent). + return_indice (list): a list to append the input index when input_index is None. + + Returns: + parent: The matched parent node. + """ + assert node is not None + assert input_index is None or input_index >= 0 + + if output_name_to_node is None: + output_name_to_node = self.model.output_name_to_node() + + if input_index is None: + parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node, exclude) + if return_indice is not None: + return_indice.append(index) + return parent + + if input_index >= len(node.input): + # Input index out of bounds. + return None + + parent = self.model.get_parent(node, input_index, output_name_to_node) + if parent is not None and parent.op_type == parent_op_type and parent not in exclude: + return parent + + return None + + def match_parent_path( + self, + node: onnx.NodeProto, + parent_op_types: list[str], + parent_input_index: list[int] | None = None, + output_name_to_node: dict[str, onnx.NodeProto] | None = None, + return_indice: list[int] | None = None, + ) -> list[onnx.NodeProto] | None: + """ + Find a sequence of input edges based on constraints on parent op_type and index. + When input_index is None, we will find the first parent node based on constraints, + and return_indice will be appended the corresponding input index. + + Args: + node (str): current node name. + parent_op_types (str): constraint of parent node op_type of each input edge. + parent_input_index (list): constraint of input index of each input edge. None means no constraint. + output_name_to_node (dict): dictionary with output name as key, and node as value. + return_indice (list): a list to append the input index + When there is no constraint on input index of an edge. + + Returns: + parents: a list of matched parent node. + """ + if parent_input_index is not None: + assert len(parent_input_index) == len(parent_op_types) + + if output_name_to_node is None: + output_name_to_node = self.model.output_name_to_node() + + current_node = node + matched_parents = [] + for i, op_type in enumerate(parent_op_types): + matched_parent = self.match_parent( + current_node, + op_type, + parent_input_index[i] if parent_input_index is not None else None, + output_name_to_node, + exclude=[], + return_indice=return_indice, + ) + if matched_parent is None: + return None + + matched_parents.append(matched_parent) + current_node = matched_parent + + return matched_parents + + def match_parent_paths( + self, + node: onnx.NodeProto, + paths: list[tuple[list[str], list[int]]], + output_name_to_node: dict[str, onnx.NodeProto], + ) -> tuple[int, list[onnx.NodeProto] | None, list[int] | None]: + """ + Find a matching parent path to the given node. + """ + for i, path in enumerate(paths): + return_indice = [] + matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice) + if matched: + return i, matched, return_indice + return -1, None, None + + def find_first_child_by_type( + self, + node: onnx.NodeProto, + child_type: str, + input_name_to_nodes: dict[str, list[onnx.NodeProto]] | None = None, + recursive: bool = True, + ) -> onnx.NodeProto | None: + children = self.model.get_children(node, input_name_to_nodes) + dq = deque(children) + while len(dq) > 0: + current_node = dq.pop() + if current_node.op_type == child_type: + return current_node + + if recursive: + children = self.model.get_children(current_node, input_name_to_nodes) + for child in children: + dq.appendleft(child) + + return None diff --git a/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py b/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py new file mode 100644 index 0000000000000..a20d6dbffd7a7 --- /dev/null +++ b/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py @@ -0,0 +1,269 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import annotations + +import onnx + +from ..onnx_model import ONNXModel +from .fusion import Fusion + + +class FusionGelu(Fusion): + def __init__(self, model: ONNXModel): + super().__init__(model, "Gelu", "Erf") + + def fuse( + self, + erf_node: onnx.NodeProto, + input_name_to_nodes: dict[str, list[onnx.NodeProto]], + output_name_to_node: dict[str, onnx.NodeProto], + ): + """ + Interface function that tries to fuse a node sequence containing an Erf node into a single + Gelu node. + """ + if ( + self.fuse_1(erf_node, input_name_to_nodes, output_name_to_node) + or self.fuse_2(erf_node, input_name_to_nodes, output_name_to_node) + or self.fuse_3(erf_node, input_name_to_nodes, output_name_to_node) + ): + self.model.set_opset_import("com.microsoft", 1) + + def fuse_1( + self, + erf_node: onnx.NodeProto, + input_name_to_nodes: dict[str, list[onnx.NodeProto]], + output_name_to_node: dict[str, onnx.NodeProto], + ) -> bool: + """ + This pattern is from PyTorch model + Fuse Gelu with Erf into one node: + Pattern 1: + +-------Mul(0.5)---------------------+ + | | + | v + [root] --> Div -----> Erf --> Add --> Mul --> + (B=1.4142...) (1) + + Pattern 2: + +------------------------------------+ + | | + | v + [root] --> Div -----> Erf --> Add --> Mul -->Mul --> + (B=1.4142...) (1) (0.5) + + Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine. + """ + if erf_node.output[0] not in input_name_to_nodes: + return False + children = input_name_to_nodes[erf_node.output[0]] + if len(children) != 1 or children[0].op_type != "Add": + return False + add_after_erf = children[0] + + if not self.has_constant_input(add_after_erf, 1): + return False + + if add_after_erf.output[0] not in input_name_to_nodes: + return False + + children = input_name_to_nodes[add_after_erf.output[0]] + if len(children) != 1 or children[0].op_type != "Mul": + return False + + mul_after_erf = children[0] + + div = self.match_parent(erf_node, "Div", 0, output_name_to_node) + if div is None: + return False + + if self.find_constant_input(div, 1.4142, delta=0.001) != 1: + return False + + subgraph_input = div.input[0] + + another = 1 if mul_after_erf.input[0] == add_after_erf.output[0] else 0 + if subgraph_input == mul_after_erf.input[another]: # pattern 2 + children = input_name_to_nodes[mul_after_erf.output[0]] + if len(children) != 1 or children[0].op_type != "Mul": + return False + mul_half = children[0] + if not self.has_constant_input(mul_half, 0.5): + return False + subgraph_output = mul_half.output[0] + else: # pattern 1 + mul_half = self.match_parent(mul_after_erf, "Mul", another, output_name_to_node) + if mul_half is None: + return False + + if not self.has_constant_input(mul_half, 0.5): + return False + + if subgraph_input not in mul_half.input: + return False + + subgraph_output = mul_after_erf.output[0] + + subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul_half] + if not self.is_safe_to_fuse_nodes(subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node): + return False + + self.nodes_to_remove.extend(subgraph_nodes) + fused_node = onnx.helper.make_node("Gelu", inputs=[subgraph_input], outputs=[subgraph_output]) + fused_node.domain = "com.microsoft" + self.nodes_to_add.append(fused_node) + return True + + def fuse_2( + self, + erf_node: onnx.NodeProto, + input_name_to_nodes: dict[str, list[onnx.NodeProto]], + output_name_to_node: dict[str, onnx.NodeProto], + ) -> bool: + """ + This pattern is from Keras model + Fuse Gelu with Erf into one node: + +------------------------------------------+ + | | + | v + [root] --> Div -----> Erf --> Add --> Mul -->Mul + (B=1.4142...) (A=1) (A=0.5) + + Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine. + """ + if erf_node.output[0] not in input_name_to_nodes: + return False + children = input_name_to_nodes[erf_node.output[0]] + if len(children) != 1 or children[0].op_type != "Add": + return False + add_after_erf = children[0] + + if not self.has_constant_input(add_after_erf, 1): + return False + + if add_after_erf.output[0] not in input_name_to_nodes: + return False + children = input_name_to_nodes[add_after_erf.output[0]] + if len(children) != 1 or children[0].op_type != "Mul": + return False + mul_after_erf = children[0] + + if not self.has_constant_input(mul_after_erf, 0.5): + return False + + if mul_after_erf.output[0] not in input_name_to_nodes: + return False + children = input_name_to_nodes[mul_after_erf.output[0]] + if len(children) != 1 or children[0].op_type != "Mul": + return False + mul = children[0] + + div = self.match_parent(erf_node, "Div", 0, output_name_to_node) + if div is None: + return False + + sqrt_node = None + if self.find_constant_input(div, 1.4142, delta=0.001) != 1: + sqrt_node = self.match_parent(div, "Sqrt", 1, output_name_to_node) + if sqrt_node is None: + return False + if not self.has_constant_input(sqrt_node, 2.0): + return False + + root_node = self.model.get_parent(div, 0, output_name_to_node) + if root_node is None: + return False + + if root_node.output[0] not in mul.input: + return False + + subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul] + if sqrt_node: + subgraph_nodes.append(sqrt_node) + + if not self.is_safe_to_fuse_nodes(subgraph_nodes, [mul.output[0]], input_name_to_nodes, output_name_to_node): + return False + + self.nodes_to_remove.extend(subgraph_nodes) + fused_node = onnx.helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[mul.output[0]]) + fused_node.domain = "com.microsoft" + self.nodes_to_add.append(fused_node) + return True + + def fuse_3( + self, + erf_node: onnx.NodeProto, + input_name_to_nodes: dict[str, list[onnx.NodeProto]], + output_name_to_node: dict[str, onnx.NodeProto], + ) -> bool: + """ + This pattern is from TensorFlow model + Fuse Gelu with Erf into one node: + +----------------------------------------------+ + | | + | v + [root] --> Mul -----> Erf --> Add --> Mul -->Mul + (A=0.7071067690849304) (B=1) (B=0.5) + + Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine. + """ + + if erf_node.output[0] not in input_name_to_nodes: + return False + children = input_name_to_nodes[erf_node.output[0]] + if len(children) != 1 or children[0].op_type != "Add": + return False + add_after_erf = children[0] + + if not self.has_constant_input(add_after_erf, 1): + return False + + if add_after_erf.output[0] not in input_name_to_nodes: + return False + children = input_name_to_nodes[add_after_erf.output[0]] + if len(children) != 1 or children[0].op_type != "Mul": + return False + mul_half = children[0] + + if not self.has_constant_input(mul_half, 0.5): + return False + + first_mul = self.match_parent(erf_node, "Mul", 0, output_name_to_node) + if first_mul is None: + return False + + i = self.find_constant_input(first_mul, 0.7071067690849304, delta=0.001) + if i < 0: + return False + + root_node = self.model.get_parent(first_mul, 0 if i == 1 else 1, output_name_to_node) + if root_node is None: + return False + + if mul_half.output[0] not in input_name_to_nodes: + return False + children = input_name_to_nodes[mul_half.output[0]] + if len(children) != 1 or children[0].op_type != "Mul": + return False + last_mul = children[0] + + if not (last_mul.input[0] == root_node.output[0] or last_mul.input[1] == root_node.output[0]): + return False + + subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul] + if not self.is_safe_to_fuse_nodes( + subgraph_nodes, + [last_mul.output[0]], + input_name_to_nodes, + output_name_to_node, + ): + return False + + self.nodes_to_remove.extend(subgraph_nodes) + fused_node = onnx.helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[last_mul.output[0]]) + fused_node.domain = "com.microsoft" + self.nodes_to_add.append(fused_node) + return True diff --git a/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py b/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py new file mode 100644 index 0000000000000..d7fb89236d3d2 --- /dev/null +++ b/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py @@ -0,0 +1,134 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import annotations + +import onnx + +from ..onnx_model import ONNXModel +from .fusion import Fusion + + +class FusionLayerNormalization(Fusion): + def __init__(self, model: ONNXModel): + super().__init__(model, "LayerNormalization", "ReduceMean") + + def fuse( + self, + reduce_mean_node: onnx.NodeProto, + input_name_to_nodes: dict[str, list[onnx.NodeProto]], + output_name_to_node: dict[str, onnx.NodeProto], + ): + """ + Interface function that tries to fuse a node sequence containing a ReduceMean node into a single + LayerNormalization node. + + +----------------------+ + | | + | v + [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add + (axis=2 or -1) | (Y=2) (axis=2 or -1) (E-6 or E-12 or 0) ^ + | | + +-------------------------------------------------+ + + It also handles cases of duplicated sub nodes exported from older version of PyTorch: + + +----------------------+ + | v + | +-------> Sub-----------------------------------------------+ + | | | + | | v + [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add + | ^ + | | + +----------------------+ + """ + children = self.model.get_children(reduce_mean_node, input_name_to_nodes) + if len(children) == 0 or len(children) > 2: + return + + root_input = reduce_mean_node.input[0] + + if children[0].op_type != "Sub" or children[0].input[0] != root_input: + return + + if len(children) == 2: + if children[1].op_type != "Sub" or children[1].input[0] != root_input: + return + + div_node = None + for child in children: + div_node = self.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False) + if div_node is not None: + break + if div_node is None: + return + + path_id, parent_nodes, _ = self.match_parent_paths( + div_node, + [ + (["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]), + ( + ["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], + [1, 0, 0, 0, 0, 0], + ), + ], + output_name_to_node, + ) + if path_id < 0: + return + + sub_node = parent_nodes[-1] + if sub_node not in children: + return + + second_add_node = parent_nodes[1] + i, add_weight = self.get_constant_input(second_add_node) + if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4: + # Skip fusion since epsilon value is not expected. + return + + pow_node = parent_nodes[3] + if self.find_constant_input(pow_node, 2.0) != 1: + return + + mul_node = input_name_to_nodes[div_node.output[0]][0] + if mul_node.op_type != "Mul": + return + + last_add_node = input_name_to_nodes[mul_node.output[0]][0] + if last_add_node.op_type != "Add": + return + + subgraph_nodes = [reduce_mean_node] + subgraph_nodes.extend(children) + subgraph_nodes.extend(parent_nodes[:-1]) + + subgraph_nodes.extend([last_add_node, mul_node, div_node]) + if not self.is_safe_to_fuse_nodes( + subgraph_nodes, + last_add_node.output, + input_name_to_nodes, + output_name_to_node, + ): + return + + weight_input = mul_node.input[1 - self.input_index(div_node.output[0], mul_node)] + if not self.is_constant_with_specified_rank(weight_input, 1): + return + + bias_input = last_add_node.input[1 - self.input_index(mul_node.output[0], last_add_node)] + if not self.is_constant_with_specified_rank(bias_input, 1): + return + + self.nodes_to_remove.extend(subgraph_nodes) + + normalize_node = onnx.helper.make_node( + "LayerNormalization", + inputs=[reduce_mean_node.input[0], weight_input, bias_input], + outputs=[last_add_node.output[0]], + ) + normalize_node.attribute.extend([onnx.helper.make_attribute("epsilon", float(add_weight))]) + self.nodes_to_add.append(normalize_node) diff --git a/onnxruntime/python/tools/quantization/onnx_model.py b/onnxruntime/python/tools/quantization/onnx_model.py index e4342908f68ea..4591c9c950e6e 100644 --- a/onnxruntime/python/tools/quantization/onnx_model.py +++ b/onnxruntime/python/tools/quantization/onnx_model.py @@ -1,3 +1,7 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- from pathlib import Path import onnx @@ -114,6 +118,14 @@ def ir_version(self): def opset_import(self): return self.model.opset_import + def set_opset_import(self, domain, version): + for opset in self.model.opset_import: + if opset.domain == domain: + opset.version = version + return + + self.model.opset_import.extend([onnx_helper.make_opsetid(domain, version)]) + def remove_node(self, node): if node in self.model.graph.node: self.model.graph.node.remove(node) @@ -140,6 +152,49 @@ def get_initializer(self, name): return tensor return None + def find_graph_input(self, input_name): + for input in self.model.graph.input: + if input.name == input_name: + return input + return None + + def find_graph_output(self, output_name): + for output in self.model.graph.output: + if output.name == output_name: + return output + return None + + def get_tensor_type(self, tensor_name: str): + tensor_type_map = {obj.name: obj.type for obj in self.model.graph.value_info} + + if tensor_name in tensor_type_map: + return tensor_type_map[tensor_name].tensor_type + + g_input = self.find_graph_input(tensor_name) + if g_input: + return g_input.type.tensor_type + + g_output = self.find_graph_output(tensor_name) + if g_output: + return g_output.type.tensor_type + + return None + + def get_constant_value(self, output_name): + for node in self.model.graph.node: + if node.op_type == "Constant": + if node.output[0] == output_name: + for attr in node.attribute: + if attr.name == "value": + return onnx_numpy_helper.to_array(attr.t) + + # Fallback to initializer since constant folding may have been applied. + initializer = self.get_initializer(output_name) + if initializer is not None: + return onnx_numpy_helper.to_array(initializer) + + return None + def get_initializer_name_set(self): return {initializer.name for initializer in self.model.graph.initializer} @@ -167,17 +222,19 @@ def input_name_to_nodes(self): input_name_to_nodes = {} for node in self.model.graph.node: for input_name in node.input: - if input_name not in input_name_to_nodes: - input_name_to_nodes[input_name] = [node] - else: - input_name_to_nodes[input_name].append(node) + if input_name: # Could be empty when it is optional + if input_name not in input_name_to_nodes: + input_name_to_nodes[input_name] = [node] + else: + input_name_to_nodes[input_name].append(node) return input_name_to_nodes def output_name_to_node(self): output_name_to_node = {} for node in self.model.graph.node: for output_name in node.output: - output_name_to_node[output_name] = node + if output_name: # Could be empty when it is optional + output_name_to_node[output_name] = node return output_name_to_node def get_children(self, node, input_name_to_nodes=None): diff --git a/onnxruntime/test/contrib_ops/math/matmul_sparse_test.cc b/onnxruntime/test/contrib_ops/math/matmul_sparse_test.cc index b77c5e0ed988b..8f8946e0d467d 100644 --- a/onnxruntime/test/contrib_ops/math/matmul_sparse_test.cc +++ b/onnxruntime/test/contrib_ops/math/matmul_sparse_test.cc @@ -140,7 +140,6 @@ void resize(Index size, double reserveSizeFactor = 0) { } */ #if !defined(DISABLE_SPARSE_TENSORS) -#if !defined(__i386__) && !defined(_M_IX86) && !defined(__wasm__) && !defined(__ANDROID__) TEST(SparseToDenseMatMul, TestCsr) { constexpr int64_t rows = 9; constexpr int64_t cols = 9; @@ -261,7 +260,6 @@ TEST(SparseToDenseMatMul, TestCsr) { tester.Run(OpTester::ExpectResult::kExpectSuccess); } } -#endif // //!defined(__i386__) && !defined(_M_IX86) && !defined(__wasm__) && !defined(__ANDROID__) TEST(SparseToDenseMatMul, TestCoo) { constexpr int64_t rows = 9; diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index eb2a77c07f803..6a99d6a0b0246 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -272,7 +272,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device } else { ORT_THROW("[ERROR] [OpenVINO] The value for the key 'enable_opencl_throttling' should be a boolean i.e. true or false. Default value is false.\n"); } - } else if (key == "enable_dynamic_shapes") { + } else if (key == "disable_dynamic_shapes") { if (value == "true" || value == "True" || value == "false" || value == "False") { ov_options[key] = value; @@ -298,7 +298,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device ov_options[key] = value; } } else { - ORT_THROW("[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO. ['device_type', 'device_id', 'enable_npu_fast_compile', 'num_of_threads', 'cache_dir', 'num_streams', 'enable_opencl_throttling|true'] \n"); + ORT_THROW("[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO. ['device_type', 'device_id', 'enable_npu_fast_compile', 'num_of_threads', 'cache_dir', 'num_streams', 'enable_opencl_throttling', 'disable_dynamic_shapes'] \n"); } } session_options.AppendExecutionProvider("OpenVINO", ov_options); diff --git a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc index b9f7e3fe465b8..0944e46ff8eaf 100644 --- a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc +++ b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc @@ -8,7 +8,6 @@ #include "core/framework/kernel_type_str_resolver.h" #include "core/session/inference_session.h" -#include "orttraining/core/session/training_session.h" #include "orttraining/core/framework/gradient_graph_builder.h" #include "orttraining/core/graph/gradient_config.h" @@ -76,7 +75,7 @@ void GradientOpTester::Run(int output_index_to_use_as_loss, } } - onnxruntime::training::TrainingSession session_object{so, GetEnvironment()}; + onnxruntime::InferenceSession session_object{so, GetEnvironment()}; ASSERT_TRUE(!execution_providers->empty()) << "Empty execution providers vector."; std::string provider_types; @@ -102,7 +101,7 @@ void GradientOpTester::Run(int output_index_to_use_as_loss, has_run = true; - ExecuteModel( + ExecuteModel( model, session_object, ExpectResult::kExpectSuccess, "", nullptr, feeds, output_names, provider_types); } else { for (const std::string& provider_type : all_provider_types) { @@ -158,11 +157,11 @@ void GradientOpTester::Run(int output_index_to_use_as_loss, continue; has_run = true; - onnxruntime::training::TrainingSession session_object{so, GetEnvironment()}; + onnxruntime::InferenceSession session_object{so, GetEnvironment()}; EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); - ExecuteModel( + ExecuteModel( model, session_object, ExpectResult::kExpectSuccess, "", nullptr, feeds, output_names, provider_type); } } diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index eb71f212a4b11..f944d8bc5ef42 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -6396,6 +6396,9 @@ def run_step(model, x): del os.environ["ORTMODULE_CONV_ALGO_SEARCH"] +@pytest.mark.skip( + reason="This test fail because bert forward loss is nan in updated transformers lib, disable for now." +) def test_bert_result_with_layerwise_recompute(): original_val = os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] if "ORTMODULE_MEMORY_OPT_LEVEL" in os.environ else None # Create PyTorch model with dropout disabled. diff --git a/setup.py b/setup.py index 2ede39915cc8d..44c97937ebe2a 100644 --- a/setup.py +++ b/setup.py @@ -408,6 +408,7 @@ def finalize_options(self): "onnxruntime.quantization", "onnxruntime.quantization.operators", "onnxruntime.quantization.CalTableFlatBuffers", + "onnxruntime.quantization.fusions", "onnxruntime.quantization.execution_providers.qnn", "onnxruntime.transformers", "onnxruntime.transformers.models.bart", diff --git a/tools/ci_build/github/apple/objectivec/assemble_objc_pod_package.py b/tools/ci_build/github/apple/objectivec/assemble_objc_pod_package.py index ec1feaae82175..ef2b645f988d6 100755 --- a/tools/ci_build/github/apple/objectivec/assemble_objc_pod_package.py +++ b/tools/ci_build/github/apple/objectivec/assemble_objc_pod_package.py @@ -154,6 +154,7 @@ def path_patterns_as_variable_value(patterns: list[str]): "DESCRIPTION": pod_config["description"], "INCLUDE_DIR_LIST": path_patterns_as_variable_value(include_dirs), "IOS_DEPLOYMENT_TARGET": framework_info["iphonesimulator"]["APPLE_DEPLOYMENT_TARGET"], + "MACOSX_DEPLOYMENT_TARGET": framework_info.get("macosx", {}).get("APPLE_DEPLOYMENT_TARGET", ""), "LICENSE_FILE": license_file, "NAME": pod_name, "PUBLIC_HEADER_FILE_LIST": path_patterns_as_variable_value(pod_files["public_header_files"]), diff --git a/tools/ci_build/github/apple/objectivec/objc.podspec.template b/tools/ci_build/github/apple/objectivec/objc.podspec.template index 8832b939f440f..b90ae4f8f267c 100644 --- a/tools/ci_build/github/apple/objectivec/objc.podspec.template +++ b/tools/ci_build/github/apple/objectivec/objc.podspec.template @@ -8,6 +8,12 @@ Pod::Spec.new do |s| s.author = { "ONNX Runtime" => "onnxruntime@microsoft.com" } s.source = { :http => "file:///http_source_placeholder" } s.ios.deployment_target = "@IOS_DEPLOYMENT_TARGET@" + + macosx_deployment_target = "@MACOSX_DEPLOYMENT_TARGET@" + if macosx_deployment_target != "" + s.osx.deployment_target = macosx_deployment_target + end + s.preserve_paths = [ "@LICENSE_FILE@" ] s.default_subspec = "Core" s.static_framework = true diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index f3c7930aa1ec7..7e389d1761613 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -1319,6 +1319,4 @@ stages: displayName: 'Publish Pipeline NuGet Artifact' inputs: artifactName: 'drop-signed-nuget-dml' - targetPath: '$(Build.ArtifactStagingDirectory)' - -- template: templates/publish-nuget.yml + targetPath: '$(Build.ArtifactStagingDirectory)' \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml index f46febee178e1..64b78dca504ca 100644 --- a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml @@ -106,8 +106,7 @@ stages: ls $(Build.BinariesDirectory)/gccbin/bin mkdir $(Build.BinariesDirectory)/arm32build cd $(Build.BinariesDirectory)/arm32build - # TODO: fix the warnings and remove the --compile-no-warning-as-error arg - cmake --compile-no-warning-as-error $(Build.SourcesDirectory)/cmake -Donnxruntime_ENABLE_CPUINFO=OFF -DPython_EXECUTABLE=/usr/bin/python3 -DPYTHON_EXECUTABLE=/usr/bin/python3 -DCMAKE_BUILD_TYPE=Debug -DCMAKE_TOOLCHAIN_FILE=$(Build.SourcesDirectory)/cmake/linux_arm32_crosscompile_toolchain.cmake -G Ninja + cmake $(Build.SourcesDirectory)/cmake -Donnxruntime_ENABLE_CPUINFO=OFF -DPython_EXECUTABLE=/usr/bin/python3 -DPYTHON_EXECUTABLE=/usr/bin/python3 -DCMAKE_BUILD_TYPE=Debug -DCMAKE_TOOLCHAIN_FILE=$(Build.SourcesDirectory)/cmake/linux_arm32_crosscompile_toolchain.cmake -G Ninja ninja rm -rf $(Build.BinariesDirectory)/arm32build $(Build.BinariesDirectory)/gccbin displayName: Cross-compile for Linux ARM32 and ARM64 diff --git a/tools/ci_build/github/azure-pipelines/templates/publish-nuget.yml b/tools/ci_build/github/azure-pipelines/publish-nuget.yml similarity index 68% rename from tools/ci_build/github/azure-pipelines/templates/publish-nuget.yml rename to tools/ci_build/github/azure-pipelines/publish-nuget.yml index 90020d217b800..8e029f4e679b2 100644 --- a/tools/ci_build/github/azure-pipelines/templates/publish-nuget.yml +++ b/tools/ci_build/github/azure-pipelines/publish-nuget.yml @@ -1,21 +1,12 @@ -parameters: -- name: PublishingNuget - displayName: Publishing Nuget Packages and report binary size to mysql - type: boolean - default: true +resources: + pipelines: + - pipeline: build + source: 'Zip-Nuget-Java-Nodejs Packaging Pipeline' + trigger: true + branch: main + stages: - stage: Publish_NuGet_Package_And_Report - condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) - dependsOn: - - NuGet_Test_Win_CPU - - NuGet_Test_Linux_CPU - - NuGet_Test_Win_GPU - - NuGet_Test_Linux_GPU - - NuGet_Test_Linux_ROCm - - NuGet_Test_MacOS - - NuGet_Packaging_DML - - NuGet_Test_Win_Training_CPU - - NuGet_Test_Linux_Training_CPU jobs: - job: workspace: @@ -28,18 +19,21 @@ stages: steps: - checkout: self submodules: false - - template: set-version-number-variables-step.yml - - - task: DownloadPipelineArtifact@0 + - template: templates/set-version-number-variables-step.yml + + - script: mkdir "$(Build.BinariesDirectory)\nuget-artifact\final-package" + + - download: build displayName: 'Download Pipeline Artifact - Signed NuGet Package' - inputs: - artifactName: 'drop-signed-nuget-CPU' - targetPath: $(Build.BinariesDirectory)/nuget-artifact/final-package + artifact: 'drop-signed-nuget-CPU' + + - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-CPU\*" "$(Build.BinariesDirectory)\nuget-artifact\final-package" - - template: ../nuget/templates/get-nuget-package-version-as-variable.yml + - template: nuget/templates/get-nuget-package-version-as-variable.yml parameters: packageFolder: '$(Build.BinariesDirectory)/nuget-artifact/final-package' + # TODO: the following step has no error checking - task: CmdLine@2 displayName: 'Post binary sizes to the dashboard database using command line' inputs: @@ -64,8 +58,10 @@ stages: ) ) + # Only report binary sizes to database if the build build was auto-triggered from the main branch - task: AzureCLI@2 displayName: 'Azure CLI' + condition: and (succeeded(), and(eq(variables['Build.SourceBranch'], 'refs/heads/main'), eq(variables['Build.Reason'], 'ResourceTrigger'))) inputs: azureSubscription: AIInfraBuildOnnxRuntimeOSS scriptLocation: inlineScript @@ -75,39 +71,36 @@ stages: python.exe $(Build.SourcesDirectory)\tools\ci_build\github\windows\post_binary_sizes_to_dashboard.py --commit_hash=$(Build.SourceVersion) --size_data_file=binary_size_data.txt --build_project=Lotus --build_id=$(Build.BuildId) workingDirectory: '$(Build.BinariesDirectory)' - - task: DownloadPipelineArtifact@0 + - download: build displayName: 'Download Pipeline Artifact - Signed NuGet Package' - inputs: - artifactName: 'drop-signed-nuget-dml' - targetPath: $(Build.BinariesDirectory)/nuget-artifact/final-package + artifact: 'drop-signed-nuget-dml' - - task: DownloadPipelineArtifact@0 + - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-dml\*" $(Build.BinariesDirectory)\nuget-artifact\final-package + + - download: build displayName: 'Download Pipeline Artifact - Signed NuGet Package' - inputs: - artifactName: 'drop-signed-nuget-Training-CPU' - targetPath: $(Build.BinariesDirectory)/nuget-artifact/final-package + artifact: 'drop-signed-nuget-Training-CPU' + - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-Training-CPU\*" $(Build.BinariesDirectory)\nuget-artifact\final-package - - task: DownloadPipelineArtifact@0 + - download: build displayName: 'Download Pipeline Artifact - Signed NuGet Package' - inputs: - artifactName: 'drop-signed-nuget-GPU' - targetPath: $(Build.BinariesDirectory)/nuget-artifact/final-package + artifact: 'drop-signed-nuget-GPU' + - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-GPU\*" $(Build.BinariesDirectory)\nuget-artifact\final-package - - task: DownloadPipelineArtifact@0 + - download: build displayName: 'Download Pipeline Artifact - Signed NuGet ROCm Package' - inputs: - artifactName: 'drop-signed-nuget-ROCm' - targetPath: $(Build.BinariesDirectory)/nuget-artifact/final-package + artifact: 'drop-signed-nuget-ROCm' + - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-ROCm\*" $(Build.BinariesDirectory)\nuget-artifact\final-package + #TODO: allow choosing different feeds - task: NuGetCommand@2 displayName: 'Copy Signed Native NuGet Package to ORT-NIGHTLY' - condition: ne(variables['IsReleaseBuild'], 'true') # release build has a different package naming scheme inputs: command: 'push' packagesToPush: '$(Build.BinariesDirectory)/nuget-artifact/final-package/*.nupkg' publishVstsFeed: '2692857e-05ef-43b4-ba9c-ccf1c22c437c/7982ae20-ed19-4a35-a362-a96ac99897b7' - - template: component-governance-component-detection-steps.yml + - template: templates/component-governance-component-detection-steps.yml parameters : condition : 'succeeded' - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 diff --git a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml index 1a7915172e211..d1dff0769e25f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml @@ -38,7 +38,7 @@ stages: cPodName: onnxruntime-training-c objcPodName: onnxruntime-training-objc - timeoutInMinutes: 180 + timeoutInMinutes: 210 steps: - script: |