diff --git a/cmake/patches/cutlass/cutlass_3.5.0.patch b/cmake/patches/cutlass/cutlass_3.5.0.patch index 3b829d2f8b2cf..93b8c474af9ed 100644 --- a/cmake/patches/cutlass/cutlass_3.5.0.patch +++ b/cmake/patches/cutlass/cutlass_3.5.0.patch @@ -1,13 +1,64 @@ +diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h +index 4c80f549..34327633 100644 +--- a/examples/41_fused_multi_head_attention/kernel_forward.h ++++ b/examples/41_fused_multi_head_attention/kernel_forward.h +@@ -221,6 +221,8 @@ struct AttentionKernel { + int32_t num_batches = 0; + int32_t num_heads = 0; + ++ bool use_smooth_softmax = false; ++ + // dropout + bool use_dropout = false; + unsigned long long dropout_batch_head_rng_offset = 0; +@@ -897,7 +899,8 @@ struct AttentionKernel { + p.num_keys - iter_key_start, + iter_key_start == 0, + iteratorC_tile_offset, +- kSupportsBias ? 1.0f : p.scale); ++ kSupportsBias ? 1.0f : p.scale, ++ p.use_smooth_softmax); + + // Output results to shared-memory + int warp_idx_mn_0 = my_warp_id % +@@ -1166,7 +1169,8 @@ struct AttentionKernel { + int max_col, + bool is_first, + typename WarpIteratorC::TensorCoord const& tile_offset, +- float scaling) { ++ float scaling, ++ bool use_smooth_softmax) { + /* Iterates on the accumulator and corresponding position on result matrix + + (1) Update `mi[r]` to the max value of the row `r` +@@ -1257,7 +1261,7 @@ struct AttentionKernel { + accum_t mi_row, total_row; + LambdaIterator::iterateRows( + lane_offset, +- [&](int accum_m) { mi_row = mi[accum_m]; }, ++ [&](int accum_m) { mi_row = mi[accum_m];}, + [&](int accum_m, int accum_n, int idx) { + frag[idx] = + (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0); +@@ -1294,7 +1298,7 @@ struct AttentionKernel { + for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) { + total_row += addition_storage[id + kQueriesPerBlock * i]; + } +- s_prime[id] = total_row; ++ s_prime[id] = (use_smooth_softmax && (max_col <= kKeysPerBlock)) ? total_row + exp2f(-mi[id]) : total_row; + } + } + diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index 964d2ff3..b366bc14 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -39,6 +39,7 @@ #include "cutlass/numeric_types.h" - + #include +#include - + #if defined(CUTLASS_ARCH_WMMA_ENABLED) #include @@ -230,8 +231,12 @@ struct inverse_square_root { @@ -19,7 +70,7 @@ index 964d2ff3..b366bc14 100644 return reinterpret_cast(result); +#else + return half_t::convert((rsqrtf(half_t::convert(lhs)))); -+#endif ++#endif #else return half_t(1.f / std::sqrt(half_t::convert(lhs))); - #endif + #endif \ No newline at end of file diff --git a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h index c80b8c0c164b6..9942f8c656760 100644 --- a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h @@ -49,3 +49,8 @@ static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_con // If the value is set to -1, cuda graph capture/replay is disabled in that run. // User are not expected to set the value to 0 as it is reserved for internal use. static const char* const kOrtRunOptionsConfigCudaGraphAnnotation = "gpu_graph_id"; + +// Specify the type of workload for this run. +// “Default”: OS determines the scheduling priority and processor performance to service this workload. 
[Default] +// “Efficient”: OS treats this workload is efficiency oriented with low scheduling priority and efficient processor performance. +static const char* const kOrtRunOptionsWorkloadType = "run.workload_type"; diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 209fd4279cc99..02dd622f42e88 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -279,3 +279,8 @@ static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas // Refer to MatMulNBits op schema for more details. // If not provided, default is 4. static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level"; + +// Specify the type of workload for this session. +// “Default”: OS determines the scheduling priority and processor performance to service this workload. [Default] +// “Efficient”: OS treats this workload is efficiency oriented with low scheduling priority and efficient processor performance. +static const char* const kOrtSessionOptionsWorkloadType = "session.workload_type"; diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index b31fbc6255c41..2f0e5da2b3f27 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -23,14 +23,6 @@ class TensorViewImpl implements TensorView { public readonly dims: readonly number[], ) {} - getUint16Array(): Uint16Array { - if (this.dataType !== DataType.float16 && this.dataType !== DataType.uint16) { - throw new Error('Invalid data type'); - } - const elementCount = ShapeUtil.size(this.dims); - return elementCount === 0 ? new Uint16Array() : new Uint16Array(this.module.HEAP8.buffer, this.data, elementCount); - } - getFloat32Array(): Float32Array { if (this.dataType !== DataType.float) { throw new Error('Invalid data type'); @@ -59,6 +51,14 @@ class TensorViewImpl implements TensorView { return elementCount === 0 ? new Int32Array() : new Int32Array(this.module.HEAP8.buffer, this.data, elementCount); } + getUint16Array(): Uint16Array { + if (this.dataType !== DataType.float16 && this.dataType !== DataType.uint16) { + throw new Error('Invalid data type'); + } + const elementCount = ShapeUtil.size(this.dims); + return elementCount === 0 ? new Uint16Array() : new Uint16Array(this.module.HEAP8.buffer, this.data, elementCount); + } + reshape(newDims: readonly number[]): TensorView { if (ShapeUtil.size(newDims) !== ShapeUtil.size(this.dims)) { throw new Error('Invalid new shape'); diff --git a/js/web/lib/wasm/jsep/tensor-view.ts b/js/web/lib/wasm/jsep/tensor-view.ts index 5f1fdfa4534cd..027c6f5660c51 100644 --- a/js/web/lib/wasm/jsep/tensor-view.ts +++ b/js/web/lib/wasm/jsep/tensor-view.ts @@ -48,6 +48,11 @@ export interface TensorView { */ getInt32Array(): Int32Array; + /** + * get a Uint16Array data view of the tensor data. tensor data must be on CPU. + */ + getUint16Array(): Uint16Array; + /** * create a new tensor view with the same data but different dimensions. 
*/ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 1fc2732f245a8..168d644fe064c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -3,11 +3,18 @@ import { DataType } from '../../../wasm-common'; import { TensorView } from '../../tensor-view'; -import { MAX_CLIP, MIN_CLIP, ShapeUtil } from '../../util'; +import { ShapeUtil } from '../../util'; import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; -import { ComputeContext, ProgramInfo } from '../types'; +import { ComputeContext, ProgramInfo, ProgramUniform } from '../types'; -import { inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType } from './common'; +import { + inputVariable, + outputVariable, + ShaderHelper, + tensorTypeToWsglValueType, + UniformDataElementType, + UniformsArrayType, +} from './common'; type BuiltinFunctionName = string; type ElementwiseCustomExpression = (expression: string) => string; @@ -20,6 +27,7 @@ const createElementwiseProgramShader = ( outputDataType: number, funcCall: ElementwiseFunctionCall, additionalImplementation?: string, + additionalUniformsType?: UniformsArrayType, ): string => { const vecSize = Math.ceil(datasize / 4); @@ -32,9 +40,13 @@ const createElementwiseProgramShader = ( const input = inputVariable('inputData', inputDataType, [vecSize], 4); const output = outputVariable('outputData', outputDataType, [vecSize], 4); + const uniforms: UniformsArrayType = [{ name: 'vec_size', type: 'u32' }]; + if (additionalUniformsType) { + uniforms.push(...additionalUniformsType); + } return ` - ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(input, output)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)} ${additionalImplementation ?? 
''} @@ -53,24 +65,38 @@ const createElementwiseProgramInfo = ( additionalImplementation?: string, cacheKey?: string, outputDataType: number = input.dataType, -): ProgramInfo => ({ - name, - shaderCache: { hint: cacheKey, inputDependencies: ['type'] }, - getShaderSource: (shaderHelper) => - createElementwiseProgramShader( - shaderHelper, - ShapeUtil.size(input.dims), - input.dataType, - outputDataType, - funcCall, - additionalImplementation, - ), - getRunData: (inputTensors) => ({ - outputs: [{ dims: input.dims, dataType: outputDataType }], - dispatchGroup: { x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */) }, - programUniforms: [{ type: DataType.uint32, data: Math.ceil(ShapeUtil.size(input.dims) / 4) }], - }), -}); + additionalUniforms?: ProgramUniform[], + additionalUniformsType?: UniformsArrayType, +): ProgramInfo => { + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: Math.ceil(ShapeUtil.size(input.dims) / 4) }, + ]; + if (additionalUniforms) { + programUniforms.push(...additionalUniforms); + } + + return { + name, + shaderCache: { hint: cacheKey, inputDependencies: ['type'] }, + getShaderSource: (shaderHelper) => + createElementwiseProgramShader( + shaderHelper, + ShapeUtil.size(input.dims), + input.dataType, + outputDataType, + funcCall, + additionalImplementation, + additionalUniformsType, + ), + getRunData: (inputTensors) => ({ + outputs: [{ dims: input.dims, dataType: outputDataType }], + dispatchGroup: { + x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */), + }, + programUniforms, + }), + }; +}; export const abs = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Abs', 'abs')); @@ -139,24 +165,46 @@ export interface ClipAttributes extends AttributeWithCacheKey { } const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => { - const min = inputs.length >= 2 && inputs[1].data !== 0 ? inputs[1].getFloat32Array()[0] : MIN_CLIP; - const max = inputs.length >= 3 && inputs[2].data !== 0 ? inputs[2].getFloat32Array()[0] : MAX_CLIP; + let min: number; + let max: number; + const hasMin = inputs.length >= 2 && inputs[1].data !== 0; + const hasMax = inputs.length >= 3 && inputs[2].data !== 0; + + switch (inputs[0].dataType) { + case DataType.float: + min = hasMin ? inputs[1].getFloat32Array()[0] : -3.4028234663852886e38; + max = hasMax ? inputs[2].getFloat32Array()[0] : 3.4028234663852886e38; + break; + case DataType.float16: + min = hasMin ? inputs[1].getUint16Array()[0] : 64511; // uint16(64511) <-> float16(-65504.0) + max = hasMax ? inputs[2].getUint16Array()[0] : 31743; // uint16(31743) <-> float16(65504.0) + break; + default: + throw new Error('Unsupport data type'); + } + return createAttributeWithCacheKey({ min, max }); }; export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => { - const attributes = context.inputs.length === 1 ? clipAttributes : generateClipAttributesFromInputs(context.inputs); + const attributes = clipAttributes ? 
clipAttributes : generateClipAttributesFromInputs(context.inputs); const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute( createElementwiseProgramInfo( context.inputs[0], 'Clip', - (a) => `clamp(${a}, clip_min_, clip_max_)`, - ` - const clip_min_: vec4<${dataType}> = vec4(${dataType}(${attributes.min})); - const clip_max_: vec4<${dataType}> = vec4(${dataType}(${attributes.max})); -`, + (a) => `clamp(${a}, vec4<${dataType}>(uniforms.min), vec4<${dataType}>(uniforms.max))`, + undefined, attributes.cacheKey, + undefined, + [ + { type: context.inputs[0].dataType, data: attributes.min }, + { type: context.inputs[0].dataType, data: attributes.max }, + ], + [ + { name: 'min', type: dataType as UniformDataElementType }, + { name: 'max', type: dataType as UniformDataElementType }, + ], ), { inputs: [0] }, ); @@ -302,9 +350,7 @@ export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttr context.inputs[0], 'HardSigmoid', (a) => - `max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${ - attributes.beta - })))`, + `max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${attributes.beta})))`, undefined, attributes.cacheKey, ), diff --git a/js/web/test/data/ops/clip.jsonc b/js/web/test/data/ops/clip.jsonc new file mode 100644 index 0000000000000..f2bcc2fd58469 --- /dev/null +++ b/js/web/test/data/ops/clip.jsonc @@ -0,0 +1,248 @@ +[ + { + "name": "clip float32 type with min and max attributes", + "operator": "Clip", + "opset": { "domain": "", "version": 10 }, + "attributes": [ + { "name": "min", "type": "float", "data": 1.0 }, + { "name": "max", "type": "float", "data": 5.0 } + ], + "cases": [ + { + "name": "T[2, 3]", + "inputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [2, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.0], + "dims": [2, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "clip float32 type with min attribute but no max attribute", + "operator": "Clip", + "opset": { "domain": "", "version": 10 }, + "attributes": [{ "name": "min", "type": "float", "data": 1.0 }], + "cases": [ + { + "name": "T[2, 3]", + "inputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [2, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [2, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "clip float32 type without min and max attributes", + "operator": "Clip", + "opset": { "domain": "", "version": 10 }, + "attributes": [], + "cases": [ + { + "name": "T[2, 3]", + "inputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [2, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [2, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "clip float32 type with min and max inputs", + "operator": "Clip", + "cases": [ + { + "name": "T[2, 3]", + "inputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [1.0], + "dims": [], + "type": "float32" + }, + { + "data": [5.0], + "dims": [], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.0], + "dims": [2, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "clip float32 type with min input but no max input", + "operator": "Clip", + "cases": [ + { + "name": "T[3, 2]", + "inputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 
5.8], + "dims": [3, 2], + "type": "float32" + }, + { + "data": [1.0], + "dims": [], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [3, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "clip float32 type without min and max inputs", + "operator": "Clip", + "cases": [ + { + "name": "T[3, 2]", + "inputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [3, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [3, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "clip float16 type with min and max inputs", + "operator": "Clip", + "cases": [ + { + "name": "T[2, 3]", + "inputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [2, 3], + "type": "float16" + }, + { + "data": [1.0], + "dims": [], + "type": "float16" + }, + { + "data": [5.0], + "dims": [], + "type": "float16" + } + ], + "outputs": [ + { + "data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.0], + "dims": [2, 3], + "type": "float16" + } + ] + } + ] + }, + { + "name": "clip float16 type with min input but no max input", + "operator": "Clip", + "cases": [ + { + "name": "T[3, 2]", + "inputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [3, 2], + "type": "float16" + }, + { + "data": [1.0], + "dims": [], + "type": "float16" + } + ], + "outputs": [ + { + "data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [3, 2], + "type": "float16" + } + ] + } + ] + }, + { + "name": "clip float16 type without min and max inputs", + "operator": "Clip", + "cases": [ + { + "name": "T[3, 2]", + "inputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [3, 2], + "type": "float16" + } + ], + "outputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [3, 2], + "type": "float16" + } + ] + } + ] + } +] diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 347cf946e6ff3..3af3751ba0e51 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -415,6 +415,7 @@ Status EfficientAttention( p.v_head_size = parameters.v_head_size; p.causal = parameters.is_unidirectional; p.scale = scale; + p.use_smooth_softmax = false; if (nullptr == data.mask_index) { p.seqlen_k_ptr = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index 1598a7e8bcf1e..5ffa63c54c8fb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -220,6 +220,8 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { p.bias_strideM = 0; p.bias_strideB = 0; } + + p.use_smooth_softmax = params.use_smooth_softmax; } auto kernel_fn = attention_kernel_batched_impl; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h index a9777800f6038..ec2c92c437283 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h @@ -25,6 +25,7 @@ struct MemoryEfficientAttentionParams { int32_t qk_head_size; int32_t v_head_size; bool causal; + bool use_smooth_softmax; float scale; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc 
b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 48ecfd7304f4b..1f378a184ab9b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -153,7 +153,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { #if USE_MEMORY_EFFICIENT_ATTENTION int sm = (device_prop.major * 10) + device_prop.minor; bool use_memory_efficient_attention = - !use_smooth_softmax_ && !use_flash_attention && !disable_memory_efficient_attention_ && local_window_size_ == -1 && diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 63e94f95b04ff..04aa1c14a0f69 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -678,8 +678,8 @@ Status FlashAttention( reinterpret_cast(data.softmax_lse), seqlens_k, cos_cache, sin_cache, /*block_table*/ nullptr, batch_size, num_heads, kv_num_heads, head_size, sequence_length, parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim, - scale, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits, - reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), + scale, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits, + reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved, parameters.is_packed_qkv)); // if (parameters.left_padding && parameters.is_prompt) { @@ -843,6 +843,7 @@ Status EfficientAttention( : nullptr; p.stream = stream; p.has_custom_right_padding = true; + p.use_smooth_softmax = parameters.use_smooth_softmax; run_memory_efficient_attention(p); DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index 849a57512dc3d..ea410998b8eef 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -515,6 +515,7 @@ Status FusedScaledDotProductAttentionCutlass( p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = false; + p.use_smooth_softmax = false; p.scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) : parameters.scale; p.seqlen_k_ptr = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index c00eefc8e49de..9bb93b6d06167 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -693,6 +693,7 @@ Status FusedAttentionCutlass( p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = false; + p.use_smooth_softmax = false; p.scale = parameters.scale == 0.0f ? 
1.f / sqrt(static_cast(qk_head_size)) : parameters.scale; p.seqlen_k_ptr = nullptr; diff --git a/onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc index 113cc3df5438d..594e75042f2ae 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc @@ -19,10 +19,9 @@ common::Status ComputeConvPads(const std::vector input_shape, const std::vector& onnx_strides, const std::vector& onnx_dilations, AutoPadType auto_pad_type, - std::vector& pads_out, - bool use_nchw) { - const int64_t input_size_y = use_nchw ? input_shape[2] : input_shape[1]; - const int64_t input_size_x = use_nchw ? input_shape[3] : input_shape[2]; + std::vector& pads_out) { + const int64_t input_size_y = input_shape[2]; + const int64_t input_size_x = input_shape[3]; const int64_t stride_y = onnx_strides[0]; const int64_t stride_x = onnx_strides[1]; const int64_t dilation_y = onnx_dilations[0]; @@ -54,16 +53,15 @@ common::Status HandleAutoPad(const std::vector input_shape, const std::vector& onnx_strides, const std::vector& onnx_dilations, AutoPadType auto_pad_type, - std::vector& pads_out, - bool use_nchw) { + std::vector& pads_out) { if (AutoPadType::SAME_UPPER == auto_pad_type) { ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x, onnx_pads, onnx_strides, onnx_dilations, - AutoPadType::SAME_UPPER, pads_out, use_nchw)); + AutoPadType::SAME_UPPER, pads_out)); } else { ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x, onnx_pads, onnx_strides, onnx_dilations, - AutoPadType::SAME_LOWER, pads_out, use_nchw)); + AutoPadType::SAME_LOWER, pads_out)); } return Status::OK(); } @@ -111,10 +109,9 @@ common::Status ComputeConvTransposePadsAndOutputShape(const std::vector const std::vector& onnx_output_padding, AutoPadType auto_pad_type, std::vector& pads_out, - std::vector& output_shape_out, - bool use_nchw) { - const int64_t input_size_y = use_nchw ? input_shape[2] : input_shape[1]; - const int64_t input_size_x = use_nchw ? input_shape[3] : input_shape[2]; + std::vector& output_shape_out) { + const int64_t input_size_y = input_shape[2]; + const int64_t input_size_x = input_shape[3]; const int64_t stride_y = onnx_strides[0]; const int64_t stride_x = onnx_strides[1]; const int64_t dilation_y = onnx_dilations[0]; diff --git a/onnxruntime/core/providers/webnn/builders/impl/builder_utils.h b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.h index 5a156c96c4852..f9f9746d6ed83 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/builder_utils.h +++ b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.h @@ -21,8 +21,7 @@ common::Status HandleAutoPad(const std::vector input_shape, const std::vector& onnx_strides, const std::vector& onnx_dilations, AutoPadType auto_pad_type, - std::vector& pads_out, - bool use_nchw) ORT_MUST_USE_RESULT; + std::vector& pads_out) ORT_MUST_USE_RESULT; // Compute pads and output shape for ConvTranspose. 
common::Status ComputeConvTransposePadsAndOutputShape(const std::vector input_shape, @@ -34,8 +33,7 @@ common::Status ComputeConvTransposePadsAndOutputShape(const std::vector const std::vector& onnx_output_padding, AutoPadType auto_pad_type, std::vector& pads_out, - std::vector& output_shape_out, - bool use_nchw) ORT_MUST_USE_RESULT; + std::vector& output_shape_out) ORT_MUST_USE_RESULT; } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 76a8a178678df..980c5dcd184c0 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -18,9 +18,6 @@ namespace webnn { class ConvOpBuilder : public BaseOpBuilder { // Add operator related. - public: - void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; @@ -33,13 +30,6 @@ class ConvOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override; }; -void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { - // skip the weight for conv as we need to transpose for preferred layout NHWC. - if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { - model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // W - } -} - // Helper functions common::Status SetConvBaseOptions(ModelBuilder& model_builder, const Node& node, emscripten::val& options, @@ -48,7 +38,6 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, const std::vector& strides, const std::vector& dilations, std::vector& pads, - const bool is_nhwc, const bool is_conv1d, const logging::Logger& logger) { NodeAttrHelper helper(node); @@ -61,7 +50,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, // Calculate explicit padding for autoPad. if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3], - pads, strides, dilations, auto_pad_type, pads_out, !is_nhwc)); + pads, strides, dilations, auto_pad_type, pads_out)); pads = pads_out; } } else if (node.OpType() == "ConvTranspose") { @@ -82,7 +71,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, // Otherwise compute the output shape, as well as the pads if the auto_pad attribute is SAME_UPPER/SAME_LOWER. ORT_RETURN_IF_ERROR(ComputeConvTransposePadsAndOutputShape(input_shape, weight_shape[2], weight_shape[3], pads, strides, dilations, output_padding, - auto_pad_type, pads_out, output_shape, !is_nhwc)); + auto_pad_type, pads_out, output_shape)); if (output_shape[0] != -1 && output_shape[1] != -1) { options.set("outputSizes", emscripten::val::array(GetVecUint32FromVecInt64(output_shape))); @@ -111,89 +100,6 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, return Status::OK(); } -// Both depthwise Conv and ConvTranspose share the same logic to add the layout. 
-Status AddInitializerInNewLayout(ModelBuilder& model_builder, - const std::string& name, - bool is_conv, - bool is_conv1d) { - const auto& tensor = *model_builder.GetInitializerTensors().at(name); - auto data_type = tensor.data_type(); - - const auto& shape = tensor.dims(); - std::vector dims = GetVecUint32FromVecInt64(std::vector(std::begin(shape), std::end(shape))); - - if (is_conv1d) { - // Support conv1d by prepending a 1 size dimension. - dims.push_back(1); - } - - const uint8_t* src = nullptr; - Initializer unpacked_tensor(tensor, model_builder.GetGraphViewer().ModelPath()); - src = unpacked_tensor.DataAsByteSpan().data(); - const auto out_t = dims[0], in_t = dims[1], - h_t = dims[2], w_t = dims[3]; - std::vector dest_shape; - if (is_conv == 1) - dest_shape = {out_t, h_t, w_t, in_t}; // L_0231 - else - dest_shape = {in_t, h_t, w_t, out_t}; // L_1230 for depthwise conv and convTranspose weight - - SafeInt num_elements = SafeInt(Product(dest_shape)); - - size_t element_size{0}; - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_UINT8: - element_size = sizeof(uint8_t); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT8: - element_size = sizeof(int8_t); - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - element_size = sizeof(uint16_t); - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - element_size = sizeof(float); - break; - default: - break; - } - std::unique_ptr buffer_holder(new uint8_t[element_size * num_elements]); - uint8_t* buffer = buffer_holder.get(); - - for (uint32_t out = 0; out < out_t; out++) { - for (uint32_t in = 0; in < in_t; in++) { - for (uint32_t h = 0; h < h_t; h++) { - for (uint32_t w = 0; w < w_t; w++) { - auto onnx_idx = out * in_t * h_t * w_t + - in * h_t * w_t + - h * w_t + - w; - - uint32_t nnapi_idx; - if (is_conv == 1) { // L_0231 - nnapi_idx = out * h_t * w_t * in_t + - h * w_t * in_t + - w * in_t + - in; - } else { // L_1230 for depthwise conv weight - nnapi_idx = in * h_t * w_t * out_t + - h * w_t * out_t + - w * out_t + - out; - } - - for (size_t i = 0; i < element_size; i++) { - buffer[element_size * nnapi_idx + i] = src[element_size * onnx_idx + i]; - } - } - } - } - } - ORT_RETURN_IF_ERROR(model_builder.AddOperandFromPersistMemoryBuffer(name, buffer, num_elements * element_size, - dest_shape, data_type)); - return Status::OK(); -} - // Add operator related. Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, @@ -203,7 +109,6 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N const auto& op_type = node.OpType(); emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); emscripten::val output = emscripten::val::object(); - const auto& initializers(model_builder.GetInitializerTensors()); std::vector input_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); @@ -216,19 +121,11 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N auto dilations = helper.Get("dilations", std::vector{1, 1}); auto pads = helper.Get("pads", std::vector{0, 0, 0, 0}); - const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC; const bool is_conv1d = input_shape.size() == 3 && weight_shape.size() == 3; - const bool is_constant_weight = Contains(initializers, weight_name); // Support conv1d by prepending a 1 or 2 size dimensions. if (is_conv1d) { // Reshape input. - if (is_nhwc) { - // For NHWC preferred layout, the input has been transposed. 
- // For conv1d it is NCD1 -> ND1C, so we need to prepend 1 to the index 2. - input_shape.insert(input_shape.begin() + 2, 1); - } else { - input_shape.push_back(1); - } + input_shape.push_back(1); std::vector new_shape = GetVecUint32FromVecInt64(input_shape); input = model_builder.GetBuilder().call("reshape", input, emscripten::val::array(new_shape)); @@ -244,63 +141,19 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); ORT_RETURN_IF_ERROR(SetConvBaseOptions( - model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_nhwc, is_conv1d, logger)); - bool depthwise = false; - if (op_type == "Conv" || op_type == "ConvInteger") { - int groups = options["groups"].as(); - if (is_nhwc) { - depthwise = (groups == input_shape[3] && groups != 1); - options.set("inputLayout", emscripten::val("nhwc")); - if (is_constant_weight) { - ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, !depthwise, is_conv1d)); - } - if (!depthwise) { - options.set("filterLayout", emscripten::val("ohwi")); - } else { - options.set("filterLayout", emscripten::val("ihwo")); - } - } - } else { // ConvTranspose - if (is_nhwc) { - options.set("inputLayout", emscripten::val("nhwc")); - options.set("filterLayout", emscripten::val("ohwi")); - if (is_constant_weight) { - ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, true, is_conv1d)); - } - } - } - + model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_conv1d, logger)); emscripten::val filter = model_builder.GetOperand(weight_name); if (is_conv1d) { // Reshape weight to 4D for conv1d. - if (!is_nhwc || !is_constant_weight) { - // The weight_shape has been appended 1's, reshape weight operand. - std::vector new_shape = GetVecUint32FromVecInt64(weight_shape); - emscripten::val reshape_options = emscripten::val::object(); - reshape_options.set("label", node.Name() + "_reshape_filter"); - filter = model_builder.GetBuilder().call("reshape", - filter, - emscripten::val::array(new_shape), - reshape_options); - } - } - - emscripten::val transpose_options = emscripten::val::object(); - if (is_nhwc && !is_constant_weight) { - // For NHWC preferred layout, if the weight is input: - // - Transpose it from iohw -> ohwi for convTranspose. - // - Transpose it from oihw -> ihwo for depthwise conv. - // - Transpose it from oihw -> ohwi for conv. - std::vector perm(4); - if (op_type == "ConvTranspose" || depthwise) { - perm = {1, 2, 3, 0}; // L_1230 for depthwise conv and convTranspose weight - } else { - perm = {0, 2, 3, 1}; // L_0231 - } - transpose_options.set("permutation", emscripten::val::array(perm)); - transpose_options.set("label", node.Name() + "_transpose_filter"); - filter = model_builder.GetBuilder().call("transpose", filter, transpose_options); + // The weight_shape has been appended 1's, reshape weight operand. 
+ std::vector new_shape = GetVecUint32FromVecInt64(weight_shape); + emscripten::val reshape_options = emscripten::val::object(); + reshape_options.set("label", node.Name() + "_reshape_filter"); + filter = model_builder.GetBuilder().call("reshape", + filter, + emscripten::val::array(new_shape), + reshape_options); } if (op_type == "Conv") { diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 4d068baf35e72..347cd11898d25 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -79,9 +79,6 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder ORT_RETURN_IF_NOT(input_defs.size() == 5, "BatchNormalization requires five inputs."); emscripten::val mean = model_builder.GetOperand(input_defs[3]->Name()); emscripten::val variance = model_builder.GetOperand(input_defs[4]->Name()); - if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { - options.set("axis", rank - 1); - } output = model_builder.GetBuilder().call("batchNormalization", input, mean, variance, options); } else if (op_type == "LayerNormalization") { @@ -104,9 +101,8 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder std::back_inserter(new_shape), [](int64_t dim) -> uint32_t { return SafeInt(dim); }); - size_t insertion_offset = (model_builder.GetPreferredLayout() == DataLayout::NHWC) ? 2 : 3; ptrdiff_t excess_rank = new_shape.size() - webnn_shape_rank; - auto insertion_point = new_shape.begin() + insertion_offset; + auto insertion_point = new_shape.begin() + 3; if (input_shape.size() < webnn_shape_rank) { // Pad the shape with extra 1's to satisfy WebNN v1's rank requirements. new_shape.insert(insertion_point, -excess_rank, 1); @@ -125,9 +121,6 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder reshape_input_options); } - if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { - options.set("layout", emscripten::val("nhwc")); - } output = model_builder.GetBuilder().call("instanceNormalization", input, options); // Reshape back to the original output shape for 3D input. if (input_shape.size() != 4) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc index 0af62dacedbd5..09eb8e79ce1d3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc @@ -70,11 +70,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, options.set("strides", emscripten::val::array(strides)); const auto dilations = helper.Get("dilations", std::vector{1, 1}); options.set("dilations", emscripten::val::array(dilations)); - if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { - options.set("layout", emscripten::val("nhwc")); - } else { - options.set("layout", emscripten::val("nchw")); - } + options.set("layout", emscripten::val("nchw")); // Add Padding. // Usually using autopadding is more efficient than using explicit padding. 
@@ -93,8 +89,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, helper.Get("strides", std::vector{1, 1}), helper.Get("dilations", std::vector{1, 1}), auto_pad_type, - pads_out, - model_builder.GetPreferredLayout() == DataLayout::NCHW)); + pads_out)); pads = GetVecUint32FromVecInt64(pads_out); } // Permute the ONNX's pads, which is [beginning_height, beginning_width, ending_height, ending_width], diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc index 2218c858951d3..0e211de5a3986 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -120,18 +120,10 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::vector scales; std::vector sizes; - std::vector scales_hw; - std::vector sizes_hw; - std::vector axes; std::string scales_name = GetTensorName(input_defs, 2); - const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC; if (!scales_name.empty()) { // Use scales. ORT_RETURN_IF_NOT(GetResizeScales(initializers, node, scales, logger), "Error getting resize scales"); - if (is_nhwc) { - scales_hw = {scales[1], scales[2]}; - } else { - scales_hw = {scales[2], scales[3]}; - } + std::vector scales_hw = {scales[2], scales[3]}; options.set("scales", emscripten::val::array(scales_hw)); } else { // Use sizes, we already checked inputs in IsOpSupportedImpl. std::vector output_sizes; @@ -140,19 +132,11 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::transform(output_sizes.cbegin(), output_sizes.cend(), std::back_inserter(sizes), [](int64_t dim) -> int32_t { return SafeInt(dim); }); - if (is_nhwc) { - sizes_hw = {sizes[1], sizes[2]}; - } else { - sizes_hw = {sizes[2], sizes[3]}; - } + std::vector sizes_hw = {sizes[2], sizes[3]}; options.set("sizes", emscripten::val::array(sizes_hw)); } - if (is_nhwc) { - axes = {1, 2}; - } else { - axes = {2, 3}; - } + std::vector axes = {2, 3}; options.set("axes", emscripten::val::array(axes)); emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); @@ -221,7 +205,6 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return false; } - const bool is_nhwc = node.Domain() == kMSInternalNHWCDomain; // We want to check if the scales or sizes are not trying to resize on N/C channels here. if (has_scales) { // We are using scales. std::vector scales; @@ -229,7 +212,7 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return false; float scale_n = scales[0]; - float scale_c = is_nhwc ? scales[3] : scales[1]; + float scale_c = scales[1]; if (scale_n != 1.0f || scale_c != 1.0f) { LOGS(logger, VERBOSE) << "Scales of N/C channel should be 1" << "Resize of N/C channels are not supported" @@ -239,8 +222,8 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers // For now we only support upscale, so the scale_h and scale_w should be an integer >= 1. // TODO support ResizeBilinear. - float scale_h = is_nhwc ? scales[1] : scales[2]; - float scale_w = is_nhwc ? scales[2] : scales[3]; + float scale_h = scales[2]; + float scale_w = scales[3]; // Onnx spec requires scale to be a positive float, so we are not checking that here. 
if (roundf(scale_h) != scale_h) { @@ -261,12 +244,11 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return false; auto output_size_n = output_sizes[0]; - const int c_idx = is_nhwc ? 3 : 1; - if (output_size_n != input_shape[0] || output_sizes[c_idx] != input_shape[c_idx]) { + if (output_size_n != input_shape[0] || output_sizes[1] != input_shape[1]) { LOGS(logger, VERBOSE) << "Output sizes of N/C chanel should match the input sizes, " << "Resize of N/C channels are not supported" << ", input_size_n, " << input_shape[0] << ", output_size_n, " << output_size_n - << ". input_size_c, " << input_shape[c_idx] << ", output_size_c, " << output_sizes[c_idx]; + << ". input_size_c, " << input_shape[1] << ", output_size_c, " << output_sizes[1]; return false; } } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 44bec1fb6fd48..02fb8e732b3c7 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -20,12 +20,10 @@ namespace onnxruntime { namespace webnn { ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, - const emscripten::val& context, const DataLayout preferred_layout, - const WebnnDeviceType wnn_device_type) + const emscripten::val& context, const WebnnDeviceType wnn_device_type) : graph_viewer_(graph_viewer), logger_(logger), wnn_context_(context), - preferred_layout_(preferred_layout), wnn_device_type_(wnn_device_type) { // Create WebNN MLGraphBuilder for each ModelBuilder, because MLGraphBuilder.build() // is only allowed to be called once. @@ -254,64 +252,6 @@ Status ModelBuilder::AddOperations() { return Status::OK(); } -Status ModelBuilder::AddOperandFromPersistMemoryBuffer( - const std::string& name, const void* buffer, const size_t size, - const std::vector shape, const int32_t data_type) { - auto persist_buffer = std::make_unique(size); - uint8_t* dest = persist_buffer.get(); - memcpy(dest, buffer, size); - emscripten::val view = emscripten::val::undefined(); - emscripten::val desc = emscripten::val::object(); - ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type"); - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_UINT8: - view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint8_t), - reinterpret_cast(dest))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT8: - view = emscripten::val{emscripten::typed_memory_view(size / sizeof(int8_t), - reinterpret_cast(dest))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint16_t), - reinterpret_cast(dest))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - view = emscripten::val{emscripten::typed_memory_view(size / sizeof(float), - reinterpret_cast(dest))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - view = emscripten::val{emscripten::typed_memory_view(size / sizeof(int32_t), - reinterpret_cast(dest))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - view = emscripten::val{emscripten::typed_memory_view(size / sizeof(int64_t), - reinterpret_cast(dest))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_UINT32: - view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint32_t), - reinterpret_cast(dest))}; - break; - case 
ONNX_NAMESPACE::TensorProto_DataType_UINT64: - view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint64_t), - reinterpret_cast(dest))}; - break; - default: - break; - } - - desc.set("dimensions", emscripten::val::array(shape)); - emscripten::val operand = emscripten::val::object(); - // Wasm memory grow will cause all array buffers reallocation, which will be treated as detached - // buffers in JS side. Simply create a copy to fix it. - operand = wnn_builder_.call("constant", desc, view.call("slice")); - - AddOperand(name, operand); - mem_persist_buffers_.push_back(std::move(persist_buffer)); - return Status::OK(); -} - Status ModelBuilder::RegisterModelOutputs() { for (const auto* node_arg : graph_viewer_.GetOutputs()) { ORT_RETURN_IF_ERROR(RegisterModelInputOutput(*node_arg, false /* is_input */)); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index 2d686070cdcc1..a954daa855e4a 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -22,8 +22,7 @@ class IOpBuilder; class ModelBuilder { public: ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, - const emscripten::val& context, const DataLayout preferred_layout, - const WebnnDeviceType wnn_device_type); + const emscripten::val& context, const WebnnDeviceType wnn_device_type); ~ModelBuilder() = default; Status Compile(std::unique_ptr& model) ORT_MUST_USE_RESULT; @@ -37,15 +36,6 @@ class ModelBuilder { const emscripten::val& GetOperand(const std::string& name) const { return wnn_operands_.at(name); } void AddOperand(const std::string& name, const emscripten::val& operand); const emscripten::val& GetZeroConstant(const std::string& data_type); - // Use the buffers to persist WebNN allocated data like transposed weight. - // It ensures the validity during inference session. - std::vector> mem_persist_buffers_; - // Add a constant operand (allocate persist buffer and move the ownership to mem_persist_buffers_). - Status AddOperandFromPersistMemoryBuffer( - const std::string& name, const void* buffer, - const size_t size, const std::vector shape, const int32_t data_type); - - DataLayout GetPreferredLayout() const { return preferred_layout_; } WebnnDeviceType GetWebnnDeviceType() const { return wnn_device_type_; } @@ -64,7 +54,6 @@ class ModelBuilder { emscripten::val wnn_context_ = emscripten::val::undefined(); emscripten::val wnn_builder_ = emscripten::val::undefined(); - DataLayout preferred_layout_; WebnnDeviceType wnn_device_type_; InlinedHashMap wnn_operands_; std::vector input_names_; diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index b918daf838c99..a6fe00241e55f 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -19,12 +19,9 @@ namespace onnxruntime { WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags) : IExecutionProvider{onnxruntime::kWebNNExecutionProvider} { - // WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend. 
if (webnn_device_flags.compare("cpu") == 0) { - preferred_layout_ = DataLayout::NHWC; wnn_device_type_ = webnn::WebnnDeviceType::CPU; } else { - preferred_layout_ = DataLayout::NCHW; if (webnn_device_flags.compare("gpu") == 0) { wnn_device_type_ = webnn::WebnnDeviceType::GPU; } else if (webnn_device_flags.compare("npu") == 0) { @@ -212,8 +209,7 @@ common::Status WebNNExecutionProvider::Compile(const std::vector model; ORT_RETURN_IF_ERROR(builder.Compile(model)); diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.h b/onnxruntime/core/providers/webnn/webnn_execution_provider.h index d8c1e90c86cdb..1fbc99098e30f 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.h +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.h @@ -26,7 +26,8 @@ class WebNNExecutionProvider : public IExecutionProvider { GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_registries*/) const override; - DataLayout GetPreferredLayout() const override { return preferred_layout_; } + // WebNN EP uses default NCHW layout for all backends. + DataLayout GetPreferredLayout() const override { return DataLayout::NCHW; } // We implement the Compile that takes FusedNodeAndGraph instances. FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; } @@ -44,7 +45,6 @@ class WebNNExecutionProvider : public IExecutionProvider { private: emscripten::val wnn_context_ = emscripten::val::undefined(); - DataLayout preferred_layout_; webnn::WebnnDeviceType wnn_device_type_; InlinedHashMap> models_; ModelMetadefIdGenerator metadef_id_generator_; diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index d48964203ce76..b20af5137d206 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -230,7 +230,9 @@ def quantize_bias_static_impl(self, bias_name, input_scale, weight_scale, beta=1 # TODO: This formula should be explained including why the scale is not estimated for the bias as well. 
bias_scale = input_scale * weight_scale * beta - quantized_data = (np.asarray(bias_data) / bias_scale).round().astype(np.int32) + quantized_data = (np.asarray(bias_data) / bias_scale).round() + quantized_data = np.clip(quantized_data, np.iinfo(np.int32).min, np.iinfo(np.int32).max) + quantized_data = quantized_data.astype(np.int32) # update bias initializer bias_np_data = np.asarray(quantized_data, dtype=np.int32).reshape(bias_initializer.dims) diff --git a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py index 17b9276a882eb..13bf51f74389a 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py @@ -2219,7 +2219,7 @@ def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleave rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, - use_smooth_softmax=False, + use_smooth_softmax=True, ) @parameterized.expand(gqa_no_past_flash_attention_test_cases()) @@ -2263,7 +2263,7 @@ def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, - use_smooth_softmax=False, + use_smooth_softmax=True, ) parity_check_gqa_past_no_buff( config, diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 74fc64fa53a4a..8b4fe66465bb1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -492,6 +492,9 @@ stages: - Linux_C_API_Packaging_CPU - Linux_C_API_Packaging_GPU - MacOS_C_API_Package_Publish + - Windows_Packaging_CPU_x86_${{ parameters.BuildVariant }} + - Windows_Packaging_CPU_x64_${{ parameters.BuildVariant }} + - Windows_Packaging_CPU_arm64_${{ parameters.BuildVariant }} condition: succeeded() jobs: - job: Nodejs_Packaging diff --git a/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml b/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml index 4987e3019d24d..a5351a182b7a2 100644 --- a/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml @@ -35,5 +35,13 @@ steps: $(Build.Repository.LocalPath)/cmake/external/onnxruntime-extensions, $(Build.Repository.LocalPath)/js/react_native/e2e/node_modules, $(Build.Repository.LocalPath)/js/node_modules, + $(Build.Repository.LocalPath)/onnxruntime-inference-examples, + $(Build.SourcesDirectory)/cmake/external/emsdk/upstream/emscripten/tests, + $(Build.SourcesDirectory)/cmake/external/onnx/third_party/benchmark, + $(Build.SourcesDirectory)/cmake/external/onnx/third_party/pybind11, + $(Build.SourcesDirectory)/cmake/external/onnx/third_party/pybind11/tests, + $(Build.SourcesDirectory)/cmake/external/onnxruntime-extensions, + $(Build.SourcesDirectory)/js/react_native/e2e/node_modules, + $(Build.SourcesDirectory)/js/node_modules, $(Build.SourcesDirectory)/onnxruntime-inference-examples, $(Build.BinariesDirectory)' \ No newline at end of file
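
Note on the smooth-softmax change above (illustrative sketch, not part of the patch): the kernel_forward.h hunk adds an extra exp2f(-mi[id]) term to the row sum s_prime when use_smooth_softmax is set. In base-e terms this appears to correspond to softmax with an implicit extra zero logit whose probability mass is discarded, which is how the GQA smooth-softmax option is usually described. A minimal NumPy sketch of that formula, assuming a single key block so the extra term is added exactly once:

import numpy as np

def smooth_softmax(x):
    # Softmax with an implicit extra logit of 0 (an "attention sink"):
    #   p_i = exp(x_i) / (1 + sum_j exp(x_j))
    # computed with the usual max-subtraction for numerical stability, which is
    # where the extra exp(-max) term in the denominator comes from.
    m = np.max(x)
    e = np.exp(x - m)
    return e / (e.sum() + np.exp(-m))

x = np.array([1.0, 2.0, 3.0])
p = smooth_softmax(x)
sink = np.exp(-np.max(x)) / (np.exp(x - np.max(x)).sum() + np.exp(-np.max(x)))
# The row no longer sums to 1; the missing mass belongs to the implicit zero logit.
assert np.isclose(p.sum() + sink, 1.0)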
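Note on the new workload_type configuration keys (illustrative sketch, not part of the patch): the key strings "run.workload_type" / "session.workload_type" and the values "Default" / "Efficient" come from the header additions above. The sketch assumes they are set through the existing string-keyed config-entry APIs of the Python bindings; the model path and input name are placeholders.

import onnxruntime as ort

sess_options = ort.SessionOptions()
# Ask the OS to treat this session's work as efficiency oriented.
sess_options.add_session_config_entry("session.workload_type", "Efficient")
session = ort.InferenceSession("model.onnx", sess_options=sess_options)  # placeholder model path

run_options = ort.RunOptions()
# Set the workload type for an individual Run() call via the run-level key.
run_options.add_run_config_entry("run.workload_type", "Default")
# outputs = session.run(None, {"input": input_tensor}, run_options=run_options)  # placeholder input name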
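Note on the float16 Clip defaults above (illustrative check, not part of the patch): when min/max are absent, the WebGPU handler falls back to the raw uint16 bit patterns 64511 and 31743, which encode the float16 extremes -65504.0 and +65504.0. A small NumPy verification of that correspondence:

import numpy as np

# Raw half-precision bit patterns used as the Clip defaults in unary-op.ts.
assert np.array([-65504.0], dtype=np.float16).view(np.uint16)[0] == 64511  # 0xFBFF
assert np.array([65504.0], dtype=np.float16).view(np.uint16)[0] == 31743   # 0x7BFF

# Reinterpreting the bit pattern back gives the float16 value.
assert np.array([31743], dtype=np.uint16).view(np.float16)[0] == np.float16(65504.0)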