From 50e6235af111e5113860dfd7a0ece55dc00316a0 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 28 Nov 2023 15:15:59 -0800 Subject: [PATCH 1/9] [js/web] allow ShaderHelper to use internal (non-I/O) variables (#18525) ### Description This PR includes a change that inspired from #18452 to resolve a requirement: a shader may depend on an instance of `IndicesHelper` to generate WGSL code snippet, but the IndicesHelper instance is not necessarily an input/output of the program. So the existing `declareVariables()` function does not work with this scenario. In order to support this requirement, I added this "use" function to `interface ShaderHelper`, which takes a helper-like object as parameter. The hidden implementation `ShaderHelperImpl` class will iterate the helpers and call `impl()` for each. @axinging @qjia7 --- .../ops/3rd-party/matmul_packed_webgpu.ts | 26 ++--- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 108 ++++++++++++------ 2 files changed, 83 insertions(+), 51 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 3e520571779e4..a8f296ea0c865 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -22,7 +22,7 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, enableShapesUniforms, getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; +import {createTensorShapeVariables, enableShapesUniforms, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -341,13 +341,8 @@ fn main(@builtin(local_invocation_id) localId : vec3, const matMulReadWriteFnSource = (component: number, hasBias: boolean, applyActivation: string, variables: IndicesHelper[], batchShapes: Array, isChannelsLast = false): string => { - const batchAShape = batchShapes[0]; - const batchBShape = batchShapes[1]; - const batchShape = batchShapes[2]; - const batchVariable = variables[0]; - const aVariable = variables[1]; - const bVariable = variables[2]; - const outputVariable = variables[3]; + const [batchAShape, batchBShape, batchShape] = batchShapes; + const [batchVariable, aVariable, bVariable, outputVariable] = variables; const broadCastADims = getBroadcastDims(batchAShape, batchShape); const broadCastBDims = getBroadcastDims(batchBShape, batchShape); const dataType = tensorTypeToWsglStorageType(variables[0].type.tensor); @@ -434,9 +429,7 @@ export const createMatmulProgramInfo = const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); const enableBatchUniforms = enableShapesUniforms(outerDims.length); const batchShapeOrRank = enableBatchUniforms ? 
outerDims.length : outerDims; - const batchDims = inputVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1, true); - const variables = [batchDims]; - const batchShapes = [outerDimsA, outerDimsB, outerDims]; + const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1); const batchSize = ShapeUtil.size(outerDims); const dimAOuter = aShape[aShape.length - 2]; @@ -469,10 +462,7 @@ export const createMatmulProgramInfo = const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components); const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components); const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); - variables.push(A); - variables.push(B); - variables.push(output); - const inputVariables = [batchDims, A, B]; + const inputVariables = [A, B]; const programUniforms: ProgramUniform[] = [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; if (enableBatchUniforms) { @@ -490,8 +480,9 @@ export const createMatmulProgramInfo = const hasBias = inputs.length > 2; const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); - const declareFunctions = - matMulReadWriteFnSource(components, hasBias, applyActivation, variables, batchShapes, isChannelsLast); + const declareFunctions = matMulReadWriteFnSource( + components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], + isChannelsLast); if (hasBias) { const biasComponents = isChannelsLast ? components : 1; inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); @@ -506,6 +497,7 @@ export const createMatmulProgramInfo = shaderHelper.registerUniform('dimAOuter', 'i32') .registerUniform('dimBOuter', 'i32') .registerUniform('dimInner', 'i32') + .registerInternalVariables(batchDims) .declareVariables(...inputVariables, output)} ${activationFunction} ${declareFunctions} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index f7ae18998b218..b7a391ee667bb 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -58,10 +58,11 @@ interface IndicesHelperTypes { * create an instance of an indices helper: * - `inputVariable()`: create an indices helper instance for an input. * - `outputVariable()`: create an indices helper instance for an output. + * - `internalVariable()`: create an indices helper instance for an internal variable. * * An indices helper instance contains helper functions for the following operations: * - access readonly basic information, including: `name`(the name of the input or output), `usage`(whether it's an - * input or an output) and `shape`(the passed in shape). + * input, an output or an internal variable) and `shape`(the passed in shape). * - `type`: access readonly type information, including: `indices`(the type of indices), `value`(the type of value at * runtime), `storage`(the type of value at storage) and `tensor`(the tensor type as represented in TensorView). * - generate WGSL code for getting indices from offset. Use `offsetToIndices()` for WGSL code snippet to calculate @@ -192,9 +193,9 @@ export interface IndicesHelper { readonly name: string; /** - * whether the helper is for an input or an output. + * whether the helper is for an input, an output or an internal variable. 
*/ - readonly usage: 'input'|'output'; + readonly usage: 'input'|'output'|'internal'; /** * the rank of the input or output. @@ -210,11 +211,6 @@ export interface IndicesHelper { * a string representing the variable name for the strides of the input or output. */ readonly strides: string; - - /** - * representing variable with uniforms, but without binding. - */ - readonly uniformOnly: boolean; } const getWgslMappedType = (type: number, components: 1|2|3|4): string|[string, string] => { @@ -335,13 +331,13 @@ export const sumVector = (name: string, components: number) => { * @param name - the name of the input or output. * @param tensorType - the tensor type of the input or output. * @param shapeOrRank - the tensor shape or the rank of the input or output. - * @param isInput - whether the helper is for an input or an output. + * @param usage - the usage of the indices helper. * @param components - indicates the number of components of each element. 1 for scalar, 2 for vec2, 3 for vec3, 4 for * vec4. */ const createIndicesHelper = - (name: string, tensorType: number, shapeOrRank: number|readonly number[], isInput: boolean, components: 1|2|3|4, - uniformOnly = false): IndicesHelper => { + (name: string, tensorType: number, shapeOrRank: number|readonly number[], usage: IndicesHelper['usage'], + components: 1|2|3|4): IndicesHelper => { const useUniform = typeof shapeOrRank === 'number'; const rank = useUniform ? shapeOrRank : shapeOrRank.length; const rankIdentity = [...new Array(rank).keys()]; @@ -363,7 +359,7 @@ const createIndicesHelper = getByIndices: false, }; - const uniformPrefix = useUniform || uniformOnly ? 'uniforms.' : ''; + const uniformPrefix = useUniform ? 'uniforms.' : ''; const shape = `${uniformPrefix}${name}_shape`; const strides = `${uniformPrefix}${name}_strides`; let o2iSnippet = ''; @@ -617,12 +613,11 @@ const createIndicesHelper = getByOffset, getByIndices, // isVec4, - usage: isInput ? 'input' : 'output', + usage, name, strides, shape, - rank, - uniformOnly + rank }; }; @@ -636,8 +631,8 @@ const createIndicesHelper = * @returns an IndicesHelper for the input. */ export const inputVariable = - (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1, uniformOnly = false): - IndicesHelper => createIndicesHelper(name, type, shapeOrRank, true, components, uniformOnly); + (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => + createIndicesHelper(name, type, shapeOrRank, 'input', components); /** * Create a IndicesHelper for an output. @@ -650,7 +645,20 @@ export const inputVariable = */ export const outputVariable = (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => - createIndicesHelper(name, type, shapeOrRank, false, components); + createIndicesHelper(name, type, shapeOrRank, 'output', components); + +/** + * Create a IndicesHelper for an internal variable. + * + * @param name - the name of the variable. + * @param type - the tensor type of the variable. + * @param shapeOrRank - the tensor shape or the rank of the variable. + * @param components - the number of components of the variable. available values are 1, 2, 3, 4. default is 1. + * @returns an IndicesHelper for the variable. 
+ */ +export const internalVariable = + (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => + createIndicesHelper(name, type, shapeOrRank, 'internal', components); export type UniformsArrayType = Array<{name: string; type: string}>; @@ -703,9 +711,27 @@ export interface ShaderHelper { /** * A helper function to register one uniform. Can be called multiple times to register multiple uniforms. + * + * @param name - the name of the uniform. + * @param type - the type of the uniform. */ registerUniform(name: string, type: string): ShaderHelper; - registerUniforms(nameToTypeMap: UniformsArrayType): ShaderHelper; + + /** + * A helper function to register multiple uniforms. Can be called multiple times to register multiple uniforms. + * + * @param uniforms - an array of uniforms. Each element of the array is an object with 2 properties: `name` and + * `type`. + */ + registerUniforms(uniforms: UniformsArrayType): ShaderHelper; + + /** + * A helper function to register multiple internal variables. Can be called multiple times to register multiple + * internal variables. + * + * @param variables - an array of IndicesHelper for the variables. + */ + registerInternalVariables(...variables: IndicesHelper[]): ShaderHelper; } class ShaderHelperImpl implements ShaderHelper { @@ -740,8 +766,7 @@ class ShaderHelperImpl implements ShaderHelper { `; } - private declareVariable(variable: IndicesHelper, bindingIndex = -1): string { - this.indicesHelpers.push(variable); + private appendVariableUniforms(variable: IndicesHelper): void { if (variable.rank !== 0) { if (variable.shape.startsWith('uniforms.')) { this.uniforms.push({name: variable.shape.replace('uniforms.', ''), type: variable.type.indices}); @@ -750,24 +775,37 @@ class ShaderHelperImpl implements ShaderHelper { this.uniforms.push({name: variable.strides.replace('uniforms.', ''), type: variable.type.indices}); } } - if (variable.uniformOnly) { - return ''; + } + + private declareVariable(variable: IndicesHelper, bindingIndex: number): string { + if (variable.usage === 'internal') { + throw new Error('cannot use internal variable with declareVariable(). use registerInternalVariables() instead.'); } + this.variables.push(variable); + this.appendVariableUniforms(variable); + const access = variable.usage === 'input' ? 'read' : 'read_write'; const storageType = variable.type.storage; return `@group(0) @binding(${bindingIndex}) var ${variable.name}: array<${storageType}>;`; } declareVariables(...variables: IndicesHelper[]): string { - return variables - .map(v => { - if (v.uniformOnly === true) { - return this.declareVariable(v); - } else { - return this.declareVariable(v, this.variableIndex++); - } - }) - .join('\n'); + return variables.map(v => this.declareVariable(v, this.variableIndex++)).join('\n'); + } + + private registerInternalVariable(variable: IndicesHelper): void { + if (variable.usage !== 'internal') { + throw new Error( + 'cannot use input or output variable with registerInternalVariable(). 
use declareVariables() instead.'); + } + + this.internalVariables.push(variable); + this.appendVariableUniforms(variable); + } + + registerInternalVariables(...variables: IndicesHelper[]): ShaderHelper { + variables.forEach(v => this.registerInternalVariable(v)); + return this; } registerUniform(name: string, type: string): ShaderHelper { @@ -780,7 +818,8 @@ class ShaderHelperImpl implements ShaderHelper { return this; } - private indicesHelpers: IndicesHelper[] = []; + private internalVariables: IndicesHelper[] = []; + private variables: IndicesHelper[] = []; private uniforms: UniformsArrayType = []; private uniformDeclaration(): string { if (this.uniforms.length === 0) { @@ -802,7 +841,8 @@ class ShaderHelperImpl implements ShaderHelper { * Get additional implementation that needs to be added to the shader source. */ get additionalImplementations(): string { - return this.uniformDeclaration() + this.indicesHelpers.map(i => i.impl()).join('\n'); + return this.uniformDeclaration() + this.variables.map(i => i.impl()).join('\n') + + this.internalVariables.map(i => i.impl()).join('\n'); } } From f13380f3d8d25df797be60b4899b43504a5576b5 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 28 Nov 2023 15:46:42 -0800 Subject: [PATCH 2/9] Support LoRA and Control Net in Stable Diffusion demo (#18593) ### Description (1) Export onnx model with LoRA weights for both SD 1.5 and SDXL (2) Export onnx model with Control Net for both SD 1.5 and SDXL. For SD 1.5, it is allowed to use multiple control nets. For SDXL, at most one control net is supported right now. (3) Add demo of LCM LoRA (3) Add demo of control net. --- .../models/stable_diffusion/README.md | 19 +- .../models/stable_diffusion/demo_txt2img.py | 34 +- .../stable_diffusion/demo_txt2img_xl.py | 42 +- .../models/stable_diffusion/demo_utils.py | 345 ++++++++++++++- .../stable_diffusion/diffusion_models.py | 392 +++++++++++++++--- .../models/stable_diffusion/engine_builder.py | 80 +++- .../engine_builder_ort_cuda.py | 44 +- .../engine_builder_ort_trt.py | 25 +- .../engine_builder_tensorrt.py | 45 +- .../models/stable_diffusion/ort_optimizer.py | 46 +- .../pipeline_stable_diffusion.py | 134 +++--- .../stable_diffusion/pipeline_txt2img.py | 27 +- .../stable_diffusion/pipeline_txt2img_xl.py | 22 + .../models/stable_diffusion/requirements.txt | 1 + 14 files changed, 1044 insertions(+), 212 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index 54af8844d0c6c..3d00c9cd6bf59 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -83,8 +83,21 @@ For example: If you do not provide prompt, the script will generate different image sizes for a list of prompts for demonstration. 
-#### Generate an image with SDXL LCM guided by a text prompt -```python3 demo_txt2img_xl.py --lcm --disable-refiner "an astronaut riding a rainbow unicorn, cinematic, dramatic"``` +### Generate an image guided by a text prompt using LCM LoRA +``` +python3 demo_txt2img_xl.py "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" --scheduler LCM --lora-weights latent-consistency/lcm-lora-sdxl --denoising-steps 4 +``` +#### Generate an image with SDXL LCM model guided by a text prompt +``` +python3 demo_txt2img_xl.py --lcm --disable-refiner "an astronaut riding a rainbow unicorn, cinematic, dramatic" +``` + +#### Generate an image with a text prompt using a control net +``` +python3 demo_txt2img.py "Stormtrooper's lecture in beautiful lecture hall" --controlnet-type depth --controlnet-scale 1.0 + +python3 demo_txt2img_xl.py "young Mona Lisa" --controlnet-type canny --controlnet-scale 0.5 --scheduler UniPC --disable-refiner +``` ## Optimize Stable Diffusion ONNX models for Hugging Face Diffusers or Optimum @@ -482,7 +495,7 @@ Most ROCm kernel optimizations are from [composable kernel](https://github.com/R Some kernels are enabled by MIOpen. We hereby thank for the AMD developers' collaboration. ### Future Works -* Update demo to support inpainting, LoRA Weights and Control Net. +* Update demo to support inpainting. * Support flash attention in Windows. * Integration with UI. * Optimization for H100 GPU. diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py index b3056cc47c647..c18747d5c6518 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py @@ -22,7 +22,16 @@ import coloredlogs from cuda import cudart -from demo_utils import get_metadata, init_pipeline, parse_arguments, repeat_prompt +from demo_utils import ( + add_controlnet_arguments, + arg_parser, + get_metadata, + init_pipeline, + max_batch, + parse_arguments, + process_controlnet_arguments, + repeat_prompt, +) from diffusion_models import PipelineInfo from engine_builder import EngineType, get_engine_type from pipeline_txt2img import Txt2ImgPipeline @@ -30,7 +39,12 @@ if __name__ == "__main__": coloredlogs.install(fmt="%(funcName)20s: %(message)s") - args = parse_arguments(is_xl=False, description="Options for Stable Diffusion Demo") + parser = arg_parser("Options for Stable Diffusion Demo") + add_controlnet_arguments(parser) + args = parse_arguments(is_xl=False, parser=parser) + + controlnet_images, controlnet_scale = process_controlnet_arguments(args) + prompt, negative_prompt = repeat_prompt(args) image_height = args.height @@ -43,9 +57,7 @@ init_trt_plugins() - max_batch_size = 16 - if engine_type != EngineType.ORT_CUDA and (args.build_dynamic_shape or image_height > 512 or image_width > 512): - max_batch_size = 4 + max_batch_size = max_batch(args) batch_size = len(prompt) if batch_size > max_batch_size: @@ -58,7 +70,15 @@ # This range can cover common used shape of landscape 512x768, portrait 768x512, or square 512x512 and 768x768. 
min_image_size = 512 if args.engine != "ORT_CUDA" else 256 max_image_size = 768 if args.engine != "ORT_CUDA" else 1024 - pipeline_info = PipelineInfo(args.version, min_image_size=min_image_size, max_image_size=max_image_size) + pipeline_info = PipelineInfo( + args.version, + min_image_size=min_image_size, + max_image_size=max_image_size, + do_classifier_free_guidance=(args.guidance > 1.0), + controlnet=args.controlnet_type, + lora_weights=args.lora_weights, + lora_scale=args.lora_scale, + ) # Ideally, the optimized batch size and image size for TRT engine shall align with user's preference. That is to # optimize the shape used most frequently. We can let user config it when we develop a UI plugin. @@ -99,6 +119,8 @@ def run_inference(warmup=False): denoising_steps=args.denoising_steps, guidance=args.guidance, seed=args.seed, + controlnet_images=controlnet_images, + controlnet_scales=controlnet_scale, return_type="image", ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py index 7ff1794a68f8c..646e3518fa053 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py @@ -22,7 +22,16 @@ import coloredlogs from cuda import cudart -from demo_utils import get_metadata, init_pipeline, parse_arguments, repeat_prompt +from demo_utils import ( + add_controlnet_arguments, + arg_parser, + get_metadata, + init_pipeline, + max_batch, + parse_arguments, + process_controlnet_arguments, + repeat_prompt, +) from diffusion_models import PipelineInfo from engine_builder import EngineType, get_engine_type from pipeline_img2img_xl import Img2ImgXLPipeline @@ -37,11 +46,7 @@ def load_pipelines(args, batch_size): init_trt_plugins() - max_batch_size = 16 - if (engine_type in [EngineType.ORT_TRT, EngineType.TRT]) and ( - args.build_dynamic_shape or args.height > 512 or args.width > 512 - ): - max_batch_size = 4 + max_batch_size = max_batch(args) if batch_size > max_batch_size: raise ValueError(f"Batch size {batch_size} is larger than allowed {max_batch_size}.") @@ -59,6 +64,10 @@ def load_pipelines(args, batch_size): min_image_size=min_image_size, max_image_size=max_image_size, use_lcm=args.lcm, + do_classifier_free_guidance=(args.guidance > 1.0), + controlnet=args.controlnet_type, + lora_weights=args.lora_weights, + lora_scale=args.lora_scale, ) # Ideally, the optimized batch size and image size for TRT engine shall align with user's preference. 
That is to @@ -113,7 +122,9 @@ def load_pipelines(args, batch_size): return base, refiner -def run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False): +def run_pipelines( + args, base, refiner, prompt, negative_prompt, controlnet_image=None, controlnet_scale=None, is_warm_up=False +): image_height = args.height image_width = args.width batch_size = len(prompt) @@ -131,6 +142,8 @@ def run_base_and_refiner(warmup=False): denoising_steps=args.denoising_steps, guidance=args.guidance, seed=args.seed, + controlnet_images=controlnet_image, + controlnet_scales=controlnet_scale, return_type="latent" if refiner else "image", ) if refiner is None: @@ -180,9 +193,9 @@ def run_base_and_refiner(warmup=False): cudart.cudaProfilerStop() if refiner: - print("|------------|--------------|") - print("| {:^10} | {:>9.2f} ms |".format("e2e", perf_data["latency"])) - print("|------------|--------------|") + print("|----------------|--------------|") + print("| {:^14} | {:>9.2f} ms |".format("e2e", perf_data["latency"])) + print("|----------------|--------------|") metadata = get_metadata(args, True) metadata.update({"base." + key: val for key, val in base.metadata().items()}) @@ -197,11 +210,11 @@ def run_base_and_refiner(warmup=False): def run_demo(args): """Run Stable Diffusion XL Base + Refiner together (known as ensemble of expert denoisers) to generate an image.""" - + controlnet_image, controlnet_scale = process_controlnet_arguments(args) prompt, negative_prompt = repeat_prompt(args) batch_size = len(prompt) base, refiner = load_pipelines(args, batch_size) - run_pipelines(args, base, refiner, prompt, negative_prompt) + run_pipelines(args, base, refiner, prompt, negative_prompt, controlnet_image, controlnet_scale) base.teardown() if refiner: refiner.teardown() @@ -294,7 +307,10 @@ def run_dynamic_shape_demo(args): if __name__ == "__main__": coloredlogs.install(fmt="%(funcName)20s: %(message)s") - args = parse_arguments(is_xl=True, description="Options for Stable Diffusion XL Demo") + parser = arg_parser("Options for Stable Diffusion XL Demo") + add_controlnet_arguments(parser) + args = parse_arguments(is_xl=True, parser=parser) + no_prompt = isinstance(args.prompt, list) and len(args.prompt) == 1 and not args.prompt[0] if no_prompt: run_dynamic_shape_demo(args) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py index 70b4f34fdd988..f0c83fc507ae4 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py @@ -19,22 +19,33 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -------------------------------------------------------------------------- - import argparse -from typing import Any, Dict - +import os +import sys +from importlib.metadata import PackageNotFoundError, version +from io import BytesIO +from typing import Any, Dict, List + +import controlnet_aux +import cv2 +import numpy as np +import requests import torch +from diffusers.utils import load_image from diffusion_models import PipelineInfo from engine_builder import EngineType, get_engine_paths +from PIL import Image class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter): pass -def parse_arguments(is_xl: bool, description: str): - parser = argparse.ArgumentParser(description=description, formatter_class=RawTextArgumentDefaultsHelpFormatter) +def arg_parser(description: str): + return argparse.ArgumentParser(description=description, formatter_class=RawTextArgumentDefaultsHelpFormatter) + +def parse_arguments(is_xl: bool, parser): engines = ["ORT_CUDA", "ORT_TRT", "TRT"] parser.add_argument( @@ -69,7 +80,7 @@ def parse_arguments(is_xl: bool, description: str): "--scheduler", type=str, default="DDIM", - choices=["DDIM", "UniPC", "LCM"] if is_xl else ["DDIM", "EulerA", "UniPC"], + choices=["DDIM", "UniPC", "LCM"] if is_xl else ["DDIM", "EulerA", "UniPC", "LCM"], help="Scheduler for diffusion process" + " of base" if is_xl else "", ) @@ -106,6 +117,11 @@ def parse_arguments(is_xl: bool, description: str): help="Higher guidance scale encourages to generate images that are closely linked to the text prompt.", ) + parser.add_argument( + "--lora-scale", type=float, default=1, help="Scale of LoRA weights, default 1 (must between 0 and 1)" + ) + parser.add_argument("--lora-weights", type=str, default="", help="LoRA weights to apply in the base model") + if is_xl: parser.add_argument( "--lcm", @@ -142,6 +158,10 @@ def parse_arguments(is_xl: bool, description: str): help="A value between 0 and 1. The higher the value less the final image similar to the seed image.", ) + parser.add_argument( + "--disable-refiner", action="store_true", help="Disable refiner and only run base for XL pipeline." + ) + # ONNX export parser.add_argument( "--onnx-opset", @@ -182,10 +202,6 @@ def parse_arguments(is_xl: bool, description: str): parser.add_argument("--seed", type=int, default=None, help="Seed for random generator to get consistent results.") parser.add_argument("--disable-cuda-graph", action="store_true", help="Disable cuda graph.") - parser.add_argument( - "--disable-refiner", action="store_true", help="Disable refiner and only run base for XL pipeline." 
- ) - group = parser.add_argument_group("Options for ORT_CUDA engine only") group.add_argument("--enable-vae-slicing", action="store_true", help="True will feed only one image to VAE once.") @@ -228,25 +244,39 @@ def parse_arguments(is_xl: bool, description: str): args.onnx_opset = 14 if args.engine == "ORT_CUDA" else 17 if is_xl: - if args.lcm: - if args.guidance > 1.0: - print("[I] Use --guidance=1.0 for base since LCM is used.") - args.guidance = 1.0 - if args.scheduler != "LCM": - print("[I] Use --scheduler=LCM for base since LCM is used.") - args.scheduler = "LCM" - if args.denoising_steps > 16: - print("[I] Use --denoising_steps=8 (no more than 16) for base since LCM is used.") - args.denoising_steps = 8 + if args.lcm and args.scheduler != "LCM": + print("[I] Use --scheduler=LCM for base since LCM is used.") + args.scheduler = "LCM" + assert args.strength > 0.0 and args.strength < 1.0 + assert not (args.lcm and args.lora_weights), "it is not supported to use both lcm unet and Lora together" + + if args.scheduler == "LCM": + if args.guidance > 1.0: + print("[I] Use --guidance=1.0 for base since LCM is used.") + args.guidance = 1.0 + if args.denoising_steps > 16: + print("[I] Use --denoising_steps=8 (no more than 16) for base since LCM is used.") + args.denoising_steps = 8 + print(args) return args +def max_batch(args): + do_classifier_free_guidance = args.guidance > 1.0 + batch_multiplier = 2 if do_classifier_free_guidance else 1 + max_batch_size = 32 // batch_multiplier + if args.engine != "ORT_CUDA" and (args.build_dynamic_shape or args.height > 512 or args.width > 512): + max_batch_size = 8 // batch_multiplier + return max_batch_size + + def get_metadata(args, is_xl: bool = False) -> Dict[str, Any]: metadata = { + "command": " ".join(['"' + x + '"' if " " in x else x for x in sys.argv]), "args.prompt": args.prompt, "args.negative_prompt": args.negative_prompt, "args.batch_size": args.batch_size, @@ -257,6 +287,14 @@ def get_metadata(args, is_xl: bool = False) -> Dict[str, Any]: "engine": args.engine, } + if args.lora_weights: + metadata["lora_weights"] = args.lora_weights + metadata["lora_scale"] = args.lora_scale + + if args.controlnet_type: + metadata["controlnet_type"] = args.controlnet_type + metadata["controlnet_scale"] = args.controlnet_scale + if is_xl and not args.disable_refiner: metadata["base.scheduler"] = args.scheduler metadata["base.denoising_steps"] = args.denoising_steps @@ -270,6 +308,27 @@ def get_metadata(args, is_xl: bool = False) -> Dict[str, Any]: metadata["denoising_steps"] = args.denoising_steps metadata["guidance"] = args.guidance + # Version of installed python packages + packages = "" + for name in [ + "onnxruntime-gpu", + "torch", + "tensorrt", + "transformers", + "diffusers", + "onnx", + "onnx-graphsurgeon", + "polygraphy", + "controlnet_aux", + ]: + try: + packages += (" " if packages else "") + f"{name}=={version(name)}" + except PackageNotFoundError: + continue + metadata["packages"] = packages + metadata["device"] = torch.cuda.get_device_name() + metadata["torch.version.cuda"] = torch.version.cuda + return metadata @@ -318,6 +377,7 @@ def init_pipeline( engine_dir=engine_dir, framework_model_dir=framework_model_dir, onnx_dir=onnx_dir, + tmp_dir=os.path.join(args.work_dir or ".", engine_type.name, pipeline_info.short_name(), "tmp"), force_engine_rebuild=args.force_engine_build, device_id=torch.cuda.current_device(), ) @@ -361,3 +421,248 @@ def init_pipeline( ) return pipeline + + +def get_depth_image(image): + """ + Create depth map for SDXL depth 
control net. + """ + from transformers import DPTFeatureExtractor, DPTForDepthEstimation + + depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda") + feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas") + + image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda") + with torch.no_grad(), torch.autocast("cuda"): + depth_map = depth_estimator(image).predicted_depth + + depth_map = torch.nn.functional.interpolate( + depth_map.unsqueeze(1), + size=(1024, 1024), + mode="bicubic", + align_corners=False, + ) + depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True) + depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True) + depth_map = (depth_map - depth_min) / (depth_max - depth_min) + image = torch.cat([depth_map] * 3, dim=1) + + image = image.permute(0, 2, 3, 1).cpu().numpy()[0] + image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8)) + return image + + +def get_canny_image(image) -> Image.Image: + """ + Create canny image for SDXL control net. + """ + image = np.array(image) + image = cv2.Canny(image, 100, 200) + image = image[:, :, None] + image = np.concatenate([image, image, image], axis=2) + image = Image.fromarray(image) + return image + + +def process_controlnet_images_xl(args) -> List[Image.Image]: + """ + Process control image for SDXL control net. + """ + image = None + if args.controlnet_image: + image = Image.open(args.controlnet_image[0]) + else: + # If no image is provided, download an image for demo purpose. + if args.controlnet_type[0] == "canny": + image = load_image( + "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" + ) + elif args.controlnet_type[0] == "depth": + image = load_image( + "https://huggingface.co/lllyasviel/sd-controlnet-depth/resolve/main/images/stormtrooper.png" + ) + + controlnet_images = [] + if args.controlnet_type[0] == "canny": + controlnet_images.append(get_canny_image(image)) + elif args.controlnet_type[0] == "depth": + controlnet_images.append(get_depth_image(image)) + else: + raise ValueError(f"The controlnet is not supported for SDXL: {args.controlnet_type}") + + return controlnet_images + + +def add_controlnet_arguments(parser, is_xl: bool = False): + """ + Add control net related arguments. + """ + group = parser.add_argument_group("Options for ControlNet (only supports SD 1.5 or XL).") + + group.add_argument( + "--controlnet-image", + nargs="*", + type=str, + default=[], + help="Path to the input regular RGB image/images for controlnet", + ) + group.add_argument( + "--controlnet-type", + nargs="*", + type=str, + default=[], + choices=list(PipelineInfo.supported_controlnet("xl-1.0" if is_xl else "1.5").keys()), + help="A list of controlnet type", + ) + group.add_argument( + "--controlnet-scale", + nargs="*", + type=float, + default=[], + help="The outputs of the controlnet are multiplied by `controlnet_scale` before they are added to the residual in the original unet. Default is 0.35 for SDXL, or 1.0 for SD 1.5", + ) + + +def download_image(url) -> Image.Image: + response = requests.get(url) + return Image.open(BytesIO(response.content)).convert("RGB") + + +def controlnet_demo_images(controlnet_list: List[str], height, width) -> List[Image.Image]: + """ + Return demo images of control net v1.1 for Stable Diffusion 1.5. 
+ """ + control_images = [] + shape = (height, width) + for controlnet in controlnet_list: + if controlnet == "canny": + canny_image = download_image( + "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" + ) + canny_image = controlnet_aux.CannyDetector()(canny_image) + control_images.append(canny_image.resize(shape)) + elif controlnet == "normalbae": + normal_image = download_image( + "https://huggingface.co/lllyasviel/sd-controlnet-normal/resolve/main/images/toy.png" + ) + normal_image = controlnet_aux.NormalBaeDetector.from_pretrained("lllyasviel/Annotators")(normal_image) + control_images.append(normal_image.resize(shape)) + elif controlnet == "depth": + depth_image = download_image( + "https://huggingface.co/lllyasviel/sd-controlnet-depth/resolve/main/images/stormtrooper.png" + ) + depth_image = controlnet_aux.LeresDetector.from_pretrained("lllyasviel/Annotators")(depth_image) + control_images.append(depth_image.resize(shape)) + elif controlnet == "mlsd": + mlsd_image = download_image( + "https://huggingface.co/lllyasviel/sd-controlnet-mlsd/resolve/main/images/room.png" + ) + mlsd_image = controlnet_aux.MLSDdetector.from_pretrained("lllyasviel/Annotators")(mlsd_image) + control_images.append(mlsd_image.resize(shape)) + elif controlnet == "openpose": + openpose_image = download_image( + "https://huggingface.co/lllyasviel/sd-controlnet-openpose/resolve/main/images/pose.png" + ) + openpose_image = controlnet_aux.OpenposeDetector.from_pretrained("lllyasviel/Annotators")(openpose_image) + control_images.append(openpose_image.resize(shape)) + elif controlnet == "scribble": + scribble_image = download_image( + "https://huggingface.co/lllyasviel/sd-controlnet-scribble/resolve/main/images/bag.png" + ) + scribble_image = controlnet_aux.HEDdetector.from_pretrained("lllyasviel/Annotators")( + scribble_image, scribble=True + ) + control_images.append(scribble_image.resize(shape)) + elif controlnet == "seg": + seg_image = download_image( + "https://huggingface.co/lllyasviel/sd-controlnet-seg/resolve/main/images/house.png" + ) + seg_image = controlnet_aux.SamDetector.from_pretrained( + "ybelkada/segment-anything", subfolder="checkpoints" + )(seg_image) + control_images.append(seg_image.resize(shape)) + else: + raise ValueError(f"There is no demo image of this controlnet: {controlnet}") + return control_images + + +def process_controlnet_image(controlnet_type: str, image: Image.Image, height, width): + """ + Process control images of control net v1.1 for Stable Diffusion 1.5. 
+ """ + control_image = None + shape = (height, width) + image = image.convert("RGB") + if controlnet_type == "canny": + canny_image = controlnet_aux.CannyDetector()(image) + control_image = canny_image.resize(shape) + elif controlnet_type == "normalbae": + normal_image = controlnet_aux.NormalBaeDetector.from_pretrained("lllyasviel/Annotators")(image) + control_image = normal_image.resize(shape) + elif controlnet_type == "depth": + depth_image = controlnet_aux.LeresDetector.from_pretrained("lllyasviel/Annotators")(image) + control_image = depth_image.resize(shape) + elif controlnet_type == "mlsd": + mlsd_image = controlnet_aux.MLSDdetector.from_pretrained("lllyasviel/Annotators")(image) + control_image = mlsd_image.resize(shape) + elif controlnet_type == "openpose": + openpose_image = controlnet_aux.OpenposeDetector.from_pretrained("lllyasviel/Annotators")(image) + control_image = openpose_image.resize(shape) + elif controlnet_type == "scribble": + scribble_image = controlnet_aux.HEDdetector.from_pretrained("lllyasviel/Annotators")(image, scribble=True) + control_image = scribble_image.resize(shape) + elif controlnet_type == "seg": + seg_image = controlnet_aux.SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")( + image + ) + control_image = seg_image.resize(shape) + else: + raise ValueError(f"There is no demo image of this controlnet_type: {controlnet_type}") + return control_image + + +def process_controlnet_arguments(args): + """ + Process control net arguments, and returns a list of control images and a tensor of control net scales. + """ + assert isinstance(args.controlnet_type, list) + assert isinstance(args.controlnet_scale, list) + assert isinstance(args.controlnet_image, list) + if args.version not in ["1.5", "xl-1.0"]: + raise ValueError("This demo only supports ControlNet in Stable Diffusion 1.5 or XL.") + + is_xl = args.version == "xl-1.0" + if is_xl and len(args.controlnet_type) > 1: + raise ValueError("This demo only support one ControlNet for Stable Diffusion XL.") + + if len(args.controlnet_image) != 0 and len(args.controlnet_image) != len(args.controlnet_scale): + raise ValueError( + f"Numbers of ControlNets {len(args.controlnet_image)} should be equal to number of ControlNet scales {len(args.controlnet_scale)}." + ) + + if len(args.controlnet_type) == 0: + return None, None + + if len(args.controlnet_scale) == 0: + args.controlnet_scale = [0.5 if is_xl else 1.0] * len(args.controlnet_type) + elif len(args.controlnet_type) != len(args.controlnet_scale): + raise ValueError( + f"Numbers of ControlNets {len(args.controlnet_type)} should be equal to number of ControlNet scales {len(args.controlnet_scale)}." 
+ ) + + # Convert controlnet scales to tensor + controlnet_scale = torch.FloatTensor(args.controlnet_scale) + + if is_xl: + images = process_controlnet_images_xl(args) + else: + images = [] + if len(args.controlnet_image) > 0: + for i, image in enumerate(args.controlnet_image): + images.append( + process_controlnet_image(args.controlnet_type[i], Image.open(image), args.height, args.width) + ) + else: + images = controlnet_demo_images(args.controlnet_type, args.height, args.width) + + return images, controlnet_scale diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py index 8206bee753859..c09aff2f514c6 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py @@ -29,7 +29,7 @@ import onnx import onnx_graphsurgeon as gs import torch -from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from onnx import GraphProto, ModelProto, shape_inference from ort_optimizer import OrtStableDiffusionOptimizer from polygraphy.backend.onnx.loader import fold_constants @@ -92,6 +92,10 @@ def __init__( max_image_size=1024, use_fp16_vae=True, use_lcm=False, + do_classifier_free_guidance=True, + controlnet=None, + lora_weights=None, + lora_scale=1.0, ): self.version = version self._is_inpaint = is_inpaint @@ -101,6 +105,11 @@ def __init__( self._max_image_size = max_image_size self._use_fp16_vae = use_fp16_vae self._use_lcm = use_lcm + self.do_classifier_free_guidance = do_classifier_free_guidance and not use_lcm + self.controlnet = controlnet # A list of control net type + self.lora_weights = lora_weights + self.lora_scale = lora_scale + if is_refiner: assert not use_lcm assert self.is_xl() @@ -224,6 +233,41 @@ def default_image_size(self): return 768 return 512 + @staticmethod + def supported_controlnet(version="1.5"): + if version == "xl-1.0": + return { + "canny": "diffusers/controlnet-canny-sdxl-1.0", + "depth": "diffusers/controlnet-depth-sdxl-1.0", + } + elif version == "1.5": + return { + "canny": "lllyasviel/control_v11p_sd15_canny", + "depth": "lllyasviel/control_v11f1p_sd15_depth", + "openpose": "lllyasviel/control_v11p_sd15_openpose", + # "tile": "lllyasviel/control_v11f1e_sd15_tile", + # "lineart": "lllyasviel/control_v11p_sd15_lineart", + # "inpaint": "lllyasviel/control_v11p_sd15_inpaint", + # "softedge": "lllyasviel/control_v11p_sd15_softedge", + "mlsd": "lllyasviel/control_v11p_sd15_mlsd", + "scribble": "lllyasviel/control_v11p_sd15_scribble", + # "ip2p": "lllyasviel/control_v11e_sd15_ip2p", + "normalbae": "lllyasviel/control_v11p_sd15_normalbae", + "seg": "lllyasviel/control_v11p_sd15_seg", + # "shuffle": "lllyasviel/control_v11e_sd15_shuffle", + # "lineart_anime": "lllyasviel/control_v11p_sd15s2_lineart_anime", + } + return None + + def controlnet_name(self): + """Return a list of controlnet name""" + if not self.controlnet: + return None + controlnet_map = PipelineInfo.supported_controlnet(self.version) + if controlnet_map is None: + return None + return [controlnet_map[controlnet] for controlnet in self.controlnet] + class BaseModel: def __init__( @@ -254,6 +298,9 @@ def __init__( self.embedding_dim = embedding_dim self.text_maxlen = text_maxlen + def get_batch_multiplier(self): + return 2 if self.pipeline_info.do_classifier_free_guidance else 1 + def 
get_ort_optimizer(self): model_name_to_model_type = { "CLIP": "clip", @@ -316,7 +363,10 @@ def get_profile_id(self, batch_size, image_height, image_width, static_batch, st _, ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) - profile_id = f"_b_{batch_size}" if static_batch else f"_b_{min_batch}_{max_batch}" + if (self.name in ["UNet", "UNetXL"]) and (self.get_batch_multiplier() == 1): + profile_id = f"_b1_{batch_size}" if static_batch else f"_b1_{min_batch}_{max_batch}" + else: + profile_id = f"_b_{batch_size}" if static_batch else f"_b_{min_batch}_{max_batch}" if self.name != "CLIP": if static_image_shape: @@ -348,6 +398,7 @@ def optimize_ort( fp32_op_list=None, optimize_by_ort=True, optimize_by_fusion=True, + tmp_dir=None, ): optimizer = self.get_ort_optimizer() optimizer.optimize( @@ -358,6 +409,7 @@ def optimize_ort( fp32_op_list=fp32_op_list, optimize_by_ort=optimize_by_ort, optimize_by_fusion=optimize_by_fusion, + tmp_dir=tmp_dir, ) def optimize_trt(self, input_onnx_path, optimized_onnx_path): @@ -525,6 +577,7 @@ def optimize_ort( fp32_op_list=None, optimize_by_ort=True, optimize_by_fusion=True, + tmp_dir=None, ): optimizer = self.get_ort_optimizer() @@ -538,6 +591,7 @@ def optimize_ort( keep_outputs=["text_embeddings"], optimize_by_ort=optimize_by_ort, optimize_by_fusion=optimize_by_fusion, + tmp_dir=tmp_dir, ) elif optimize_by_fusion: with tempfile.TemporaryDirectory() as tmp_dir: @@ -556,6 +610,7 @@ def optimize_ort( keep_outputs=["text_embeddings", "hidden_states"], optimize_by_ort=optimize_by_ort, optimize_by_fusion=optimize_by_fusion, + tmp_dir=tmp_dir, ) else: # input is optimized model, there is no need to add hidden states. optimizer.optimize( @@ -567,6 +622,7 @@ def optimize_ort( keep_outputs=["text_embeddings", "hidden_states"], optimize_by_ort=optimize_by_ort, optimize_by_fusion=optimize_by_fusion, + tmp_dir=tmp_dir, ) def optimize_trt(self, input_onnx_path, optimized_onnx_path): @@ -622,6 +678,100 @@ def get_shape_dict(self, batch_size, image_height, image_width): return output +class UNet2DConditionControlNetModel(torch.nn.Module): + def __init__(self, unet, controlnets: ControlNetModel): + super().__init__() + self.unet = unet + self.controlnets = controlnets + + def forward(self, sample, timestep, encoder_hidden_states, controlnet_images, controlnet_scales): + for i, (controlnet_image, conditioning_scale, controlnet) in enumerate( + zip(controlnet_images, controlnet_scales, self.controlnets) + ): + down_samples, mid_sample = controlnet( + sample, + timestep, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=controlnet_image, + return_dict=False, + ) + + down_samples = [down_sample * conditioning_scale for down_sample in down_samples] + mid_sample *= conditioning_scale + + # merge samples + if i == 0: + down_block_res_samples, mid_block_res_sample = down_samples, mid_sample + else: + down_block_res_samples = [ + samples_prev + samples_curr + for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) + ] + mid_block_res_sample += mid_sample + + noise_pred = self.unet( + sample, + timestep, + encoder_hidden_states=encoder_hidden_states, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + return noise_pred[0] + + +# Modified from convert_stable_diffusion_controlnet_to_onnx.py in diffusers +class UNet2DConditionXLControlNetModel(torch.nn.Module): + def __init__(self, unet, controlnets: ControlNetModel): + super().__init__() + 
self.unet = unet + self.controlnets = controlnets + + def forward( + self, + sample, + timestep, + encoder_hidden_states, + text_embeds, + time_ids, + controlnet_images, + controlnet_scales, + ): + added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids} + for i, (controlnet_image, conditioning_scale, controlnet) in enumerate( + zip(controlnet_images, controlnet_scales, self.controlnets) + ): + down_samples, mid_sample = controlnet( + sample, + timestep, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=controlnet_image, + conditioning_scale=conditioning_scale, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + ) + + # merge samples + if i == 0: + down_block_res_samples, mid_block_res_sample = down_samples, mid_sample + else: + down_block_res_samples = [ + samples_prev + samples_curr + for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) + ] + mid_block_res_sample += mid_sample + + noise_pred = self.unet( + sample, + timestep, + encoder_hidden_states=encoder_hidden_states, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + ) + return noise_pred[0] + + class UNet(BaseModel): def __init__( self, @@ -642,72 +792,129 @@ def __init__( embedding_dim=pipeline_info.unet_embedding_dim(), text_maxlen=text_maxlen, ) + self.unet_dim = unet_dim + self.controlnet = pipeline_info.controlnet_name() def load_model(self, framework_model_dir, hf_token, subfolder="unet"): options = {"variant": "fp16", "torch_dtype": torch.float16} if self.fp16 else {} - return self.from_pretrained(UNet2DConditionModel, framework_model_dir, hf_token, subfolder, **options) + + model = self.from_pretrained(UNet2DConditionModel, framework_model_dir, hf_token, subfolder, **options) + + if self.controlnet: + cnet_model_opts = {"torch_dtype": torch.float16} if self.fp16 else {} + controlnets = torch.nn.ModuleList( + [ControlNetModel.from_pretrained(name, **cnet_model_opts).to(self.device) for name in self.controlnet] + ) + model = UNet2DConditionControlNetModel(model, controlnets) + + return model def get_input_names(self): - return ["sample", "timestep", "encoder_hidden_states"] + if not self.controlnet: + return ["sample", "timestep", "encoder_hidden_states"] + else: + return ["sample", "timestep", "encoder_hidden_states", "controlnet_images", "controlnet_scales"] def get_output_names(self): return ["latent"] def get_dynamic_axes(self): - return { - "sample": {0: "2B", 2: "H", 3: "W"}, - "encoder_hidden_states": {0: "2B"}, - "latent": {0: "2B", 2: "H", 3: "W"}, + b = "2B" if self.get_batch_multiplier() == 2 else "B" + output = { + "sample": {0: b, 2: "H", 3: "W"}, + "encoder_hidden_states": {0: b}, + "latent": {0: b, 2: "H", 3: "W"}, } + if self.controlnet: + output.update( + { + "controlnet_images": {1: b, 3: "8H", 4: "8W"}, + } + ) + return output def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) ( min_batch, max_batch, - _, - _, - _, - _, + min_image_height, + max_image_height, + min_image_width, + max_image_width, min_latent_height, max_latent_height, min_latent_width, max_latent_width, ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) - return { + m = self.get_batch_multiplier() + output = { "sample": [ - (2 * min_batch, self.unet_dim, min_latent_height, 
min_latent_width), - (2 * batch_size, self.unet_dim, latent_height, latent_width), - (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), + (m * min_batch, self.unet_dim, min_latent_height, min_latent_width), + (m * batch_size, self.unet_dim, latent_height, latent_width), + (m * max_batch, self.unet_dim, max_latent_height, max_latent_width), ], "encoder_hidden_states": [ - (2 * min_batch, self.text_maxlen, self.embedding_dim), - (2 * batch_size, self.text_maxlen, self.embedding_dim), - (2 * max_batch, self.text_maxlen, self.embedding_dim), + (m * min_batch, self.text_maxlen, self.embedding_dim), + (m * batch_size, self.text_maxlen, self.embedding_dim), + (m * max_batch, self.text_maxlen, self.embedding_dim), ], } + if self.controlnet: + output.update( + { + "controlnet_images": [ + (len(self.controlnet), m * min_batch, 3, min_image_height, min_image_width), + (len(self.controlnet), m * batch_size, 3, image_height, image_width), + (len(self.controlnet), m * max_batch, 3, max_image_height, max_image_width), + ] + } + ) + return output + def get_shape_dict(self, batch_size, image_height, image_width): latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return { - "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), + m = self.get_batch_multiplier() + output = { + "sample": (m * batch_size, self.unet_dim, latent_height, latent_width), "timestep": [1], - "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), - "latent": (2 * batch_size, 4, latent_height, latent_width), + "encoder_hidden_states": (m * batch_size, self.text_maxlen, self.embedding_dim), + "latent": (m * batch_size, 4, latent_height, latent_width), } + if self.controlnet: + output.update( + { + "controlnet_images": (len(self.controlnet), m * batch_size, 3, image_height, image_width), + "controlnet_scales": [len(self.controlnet)], + } + ) + return output + def get_sample_input(self, batch_size, image_height, image_width): latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) dtype = torch.float16 if self.fp16 else torch.float32 - return ( + m = self.get_batch_multiplier() + output = ( torch.randn( - 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + m * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device ), torch.tensor([1.0], dtype=torch.float32, device=self.device), - torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + torch.randn(m * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), ) + if self.controlnet: + output = ( + *output, + torch.randn( + len(self.controlnet), m * batch_size, 3, image_height, image_width, dtype=dtype, device=self.device + ), + torch.randn(len(self.controlnet), dtype=dtype, device=self.device), + ) + return output + def fp32_input_output_names(self) -> List[str]: return ["sample", "timestep"] @@ -737,8 +944,7 @@ def __init__( self.time_dim = time_dim self.custom_unet = pipeline_info.custom_unet() - self.do_classifier_free_guidance = not (self.custom_unet and "lcm" in self.custom_unet) - self.batch_multiplier = 2 if self.do_classifier_free_guidance else 1 + self.controlnet = pipeline_info.controlnet_name() def load_model(self, framework_model_dir, hf_token, subfolder="unet"): options = {"variant": "fp16", "torch_dtype": torch.float16} if self.fp16 else {} @@ -750,49 +956,62 @@ def load_model(self, 
framework_model_dir, hf_token, subfolder="unet"): unet.save_pretrained(model_dir) else: unet = UNet2DConditionModel.from_pretrained(model_dir, **options) - return unet.to(self.device) + model = unet.to(self.device) + else: + model = self.from_pretrained(UNet2DConditionModel, framework_model_dir, hf_token, subfolder, **options) + + if self.controlnet: + cnet_model_opts = {"torch_dtype": torch.float16} if self.fp16 else {} + controlnets = torch.nn.ModuleList( + [ControlNetModel.from_pretrained(path, **cnet_model_opts).to(self.device) for path in self.controlnet] + ) + model = UNet2DConditionXLControlNetModel(model, controlnets) - return self.from_pretrained(UNet2DConditionModel, framework_model_dir, hf_token, subfolder, **options) + return model def get_input_names(self): - return ["sample", "timestep", "encoder_hidden_states", "text_embeds", "time_ids"] + input_names = ["sample", "timestep", "encoder_hidden_states", "text_embeds", "time_ids"] + if self.controlnet: + return [*input_names, "controlnet_images", "controlnet_scales"] + return input_names def get_output_names(self): return ["latent"] def get_dynamic_axes(self): - if self.do_classifier_free_guidance: - return { - "sample": {0: "2B", 2: "H", 3: "W"}, - "encoder_hidden_states": {0: "2B"}, - "latent": {0: "2B", 2: "H", 3: "W"}, - "text_embeds": {0: "2B"}, - "time_ids": {0: "2B"}, - } - return { - "sample": {0: "B", 2: "H", 3: "W"}, - "encoder_hidden_states": {0: "B"}, - "latent": {0: "B", 2: "H", 3: "W"}, - "text_embeds": {0: "B"}, - "time_ids": {0: "B"}, + b = "2B" if self.get_batch_multiplier() == 2 else "B" + output = { + "sample": {0: b, 2: "H", 3: "W"}, + "encoder_hidden_states": {0: b}, + "text_embeds": {0: b}, + "time_ids": {0: b}, + "latent": {0: b, 2: "H", 3: "W"}, } + if self.controlnet: + output.update( + { + "controlnet_images": {1: b, 3: "8H", 4: "8W"}, + } + ) + return output + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) ( min_batch, max_batch, - _, - _, - _, - _, + min_image_height, + max_image_height, + min_image_width, + max_image_width, min_latent_height, max_latent_height, min_latent_width, max_latent_width, ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) - m = self.batch_multiplier - return { + m = self.get_batch_multiplier() + output = { "sample": [ (m * min_batch, self.unet_dim, min_latent_height, min_latent_width), (m * batch_size, self.unet_dim, latent_height, latent_width), @@ -811,35 +1030,72 @@ def get_input_profile(self, batch_size, image_height, image_width, static_batch, ], } + if self.controlnet: + output.update( + { + "controlnet_images": [ + (len(self.controlnet), m * min_batch, 3, min_image_height, min_image_width), + (len(self.controlnet), m * batch_size, 3, image_height, image_width), + (len(self.controlnet), m * max_batch, 3, max_image_height, max_image_width), + ], + } + ) + return output + def get_shape_dict(self, batch_size, image_height, image_width): latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - m = self.batch_multiplier - return { + m = self.get_batch_multiplier() + output = { "sample": (m * batch_size, self.unet_dim, latent_height, latent_width), "timestep": (1,), "encoder_hidden_states": (m * batch_size, self.text_maxlen, self.embedding_dim), - "latent": (m * batch_size, 4, latent_height, latent_width), "text_embeds": (m * batch_size, 1280), "time_ids": (m * 
batch_size, self.time_dim), + "latent": (m * batch_size, 4, latent_height, latent_width), } + if self.controlnet: + output.update( + { + "controlnet_images": (len(self.controlnet), m * batch_size, 3, image_height, image_width), + "controlnet_scales": [len(self.controlnet)], + } + ) + return output + def get_sample_input(self, batch_size, image_height, image_width): latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) dtype = torch.float16 if self.fp16 else torch.float32 - m = self.batch_multiplier - return ( - torch.randn( - m * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device - ), - torch.tensor([1.0], dtype=torch.float32, device=self.device), - torch.randn(m * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), - { - "added_cond_kwargs": { - "text_embeds": torch.randn(m * batch_size, 1280, dtype=dtype, device=self.device), - "time_ids": torch.randn(m * batch_size, self.time_dim, dtype=dtype, device=self.device), - } - }, - ) + m = self.get_batch_multiplier() + if not self.controlnet: + return ( + torch.randn( + m * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + ), + torch.tensor([1.0], dtype=torch.float32, device=self.device), + torch.randn(m * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + { + "added_cond_kwargs": { + "text_embeds": torch.randn(m * batch_size, 1280, dtype=dtype, device=self.device), + "time_ids": torch.randn(m * batch_size, self.time_dim, dtype=dtype, device=self.device), + } + }, + ) + else: + # sample, timestep, encoder_hidden_states, text_embeds, time_ids, controlnet_images, controlnet_scales, + return ( + torch.randn( + m * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + ), + torch.tensor([1.0], dtype=torch.float32, device=self.device), + torch.randn(m * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + torch.randn(m * batch_size, 1280, dtype=dtype, device=self.device), + torch.randn(m * batch_size, self.time_dim, dtype=dtype, device=self.device), + torch.randn( + len(self.controlnet), m * batch_size, 3, image_height, image_width, dtype=dtype, device=self.device + ), + torch.randn(len(self.controlnet), dtype=dtype, device=self.device), + ) def fp32_input_output_names(self) -> List[str]: return ["sample", "timestep"] diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py index fac72be346b3d..8e167b74d6918 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
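The ControlNet-aware shape dictionaries above all hinge on the batch multiplier `m` (2 when classifier-free guidance doubles the effective batch, 1 otherwise) and on the 8x ratio between image and latent resolution. The following sketch is illustrative only and not part of the patch; the SD 1.5 defaults for `unet_dim`, `text_maxlen` and `embedding_dim` are assumptions:

```python
def unet_controlnet_shapes(batch_size, image_height, image_width, num_controlnets,
                           do_classifier_free_guidance=True,
                           unet_dim=4, text_maxlen=77, embedding_dim=768):
    # Batch multiplier: classifier-free guidance runs the conditional and
    # unconditional passes together, doubling the effective batch.
    m = 2 if do_classifier_free_guidance else 1
    # Latents are 8x smaller than the image in each spatial dimension.
    latent_height, latent_width = image_height // 8, image_width // 8
    shapes = {
        "sample": (m * batch_size, unet_dim, latent_height, latent_width),
        "encoder_hidden_states": (m * batch_size, text_maxlen, embedding_dim),
        "latent": (m * batch_size, 4, latent_height, latent_width),
    }
    if num_controlnets:
        # Control images stay at full image resolution, one slice per ControlNet.
        shapes["controlnet_images"] = (num_controlnets, m * batch_size, 3, image_height, image_width)
        shapes["controlnet_scales"] = (num_controlnets,)
    return shapes

print(unet_controlnet_shapes(1, 512, 512, num_controlnets=1))
```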
# -------------------------------------------------------------------------- +import hashlib import os from enum import Enum @@ -68,18 +69,46 @@ def __init__( self.torch_models = {} self.use_vae_slicing = False + self.torch_sdpa = getattr(torch.nn.functional, "scaled_dot_product_attention", None) + def enable_vae_slicing(self): self.use_vae_slicing = True + def disable_torch_spda(self): + if hasattr(torch.nn.functional, "scaled_dot_product_attention"): + delattr(torch.nn.functional, "scaled_dot_product_attention") + + def enable_torch_spda(self): + if (not hasattr(torch.nn.functional, "scaled_dot_product_attention")) and self.torch_sdpa: + torch.nn.functional.scaled_dot_product_attention = self.torch_sdpa + def teardown(self): for engine in self.engines.values(): del engine self.engines = {} def get_cached_model_name(self, model_name): + hash_source = [] + if model_name in ["clip", "clip2", "unet", "unetxl"] and self.pipeline_info.lora_weights: + if self.pipeline_info.lora_weights in [ + "latent-consistency/lcm-lora-sdxl", + "latent-consistency/lcm-lora-sdv1-5", + ]: + if model_name in ["unet", "unetxl"]: + model_name = model_name + "_lcm-lora" + else: + model_name = model_name + "_lora" + hash_source.append(self.pipeline_info.lora_weights) + # TODO(tianleiwu): save custom model to a directory named by its original model. if model_name == "unetxl" and self.pipeline_info.custom_unet(): - model_name = "lcm_" + model_name + model_name = model_name + "_lcm" + + if model_name in ["unet", "unetxl"] and self.pipeline_info.controlnet: + model_name = model_name + "_" + "_".join(self.pipeline_info.controlnet) + + if hash_source: + model_name += "_" + hashlib.md5("\t".join(hash_source).encode("utf-8")).digest().hex()[:8] # TODO: When we support original VAE, we shall save custom VAE to another directory. @@ -87,22 +116,54 @@ def get_cached_model_name(self, model_name): model_name += "_inpaint" return model_name - def get_onnx_path(self, model_name, onnx_dir, opt=True, suffix=""): + def get_model_dir(self, model_name, root_dir, opt=True, suffix="", create=True): engine_name = self.engine_type.name.lower() directory_name = self.get_cached_model_name(model_name) + (f".{engine_name}" if opt else "") + suffix - onnx_model_dir = os.path.join(onnx_dir, directory_name) - os.makedirs(onnx_model_dir, exist_ok=True) + onnx_model_dir = os.path.join(root_dir, directory_name) + if create: + os.makedirs(onnx_model_dir, exist_ok=True) + return onnx_model_dir + + def get_onnx_path(self, model_name, onnx_dir, opt=True, suffix=""): + onnx_model_dir = self.get_model_dir(model_name, onnx_dir, opt=opt, suffix=suffix) return os.path.join(onnx_model_dir, "model.onnx") def get_engine_path(self, engine_dir, model_name, profile_id): return os.path.join(engine_dir, self.get_cached_model_name(model_name) + profile_id) - def load_models(self, framework_model_dir: str): - # Disable torch SDPA since torch 2.0.* cannot export it to ONNX - # TODO(tianleiwu): Test and remove it if this is not needed in Torch 2.1. 
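With LoRA weights in play, `get_cached_model_name` above keeps cache directory names both readable and unique by appending a short MD5 digest of the LoRA identifier (well-known LCM LoRAs get a readable `_lcm-lora` suffix instead). A minimal sketch of the generic branch; the helper name and sample input are illustrative:

```python
import hashlib

def cached_model_name(model_name, lora_weights=None):
    # Append a readable "_lora" marker plus a short, stable hash of the LoRA
    # identifier so each LoRA variant gets its own cache directory.
    hash_source = []
    if lora_weights:
        model_name += "_lora"
        hash_source.append(lora_weights)
    if hash_source:
        model_name += "_" + hashlib.md5("\t".join(hash_source).encode("utf-8")).digest().hex()[:8]
    return model_name

print(cached_model_name("unetxl", "some-org/some-lora-weights"))  # unetxl_lora_<8 hex chars>
```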
- if hasattr(torch.nn.functional, "scaled_dot_product_attention"): - delattr(torch.nn.functional, "scaled_dot_product_attention") + def load_pipeline_with_lora(self): + """Load text encoders and UNet with diffusers pipeline""" + from diffusers import DiffusionPipeline + + pipeline = DiffusionPipeline.from_pretrained( + self.pipeline_info.name(), + variant="fp16", + torch_dtype=torch.float16, + ) + pipeline.load_lora_weights(self.pipeline_info.lora_weights) + pipeline.fuse_lora(lora_scale=self.pipeline_info.lora_scale) + + del pipeline.vae + pipeline.vae = None + return pipeline + + def get_or_load_model(self, pipeline, model_name, model_obj, framework_model_dir): + if model_name in ["clip", "clip2", "unet", "unetxl"] and pipeline: + if model_name == "clip": + model = pipeline.text_encoder + pipeline.text_encoder = None + elif model_name == "clip2": + model = pipeline.text_encoder_2 + pipeline.text_encoder_2 = None + else: + model = pipeline.unet + pipeline.unet = None + else: + model = model_obj.load_model(framework_model_dir, self.hf_token) + + return model.to(self.torch_device) + def load_models(self, framework_model_dir: str): # For TRT or ORT_TRT, we will export fp16 torch model for UNet. # For ORT_CUDA, we export fp32 model first, then optimize to fp16. export_fp16_unet = self.engine_type in [EngineType.ORT_TRT, EngineType.TRT] @@ -198,6 +259,7 @@ def get_engine_paths(work_dir: str, pipeline_info: PipelineInfo, engine_type: En onnx_dir = os.path.join(root_dir, engine_type.name, short_name, "onnx") engine_dir = os.path.join(root_dir, engine_type.name, short_name, "engine") output_dir = os.path.join(root_dir, engine_type.name, short_name, "output") + timing_cache = os.path.join(root_dir, engine_type.name, "timing_cache") framework_model_dir = os.path.join(root_dir, engine_type.name, "torch_model") diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py index a03ca7ce2912c..2ac9a45577676 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py @@ -158,6 +158,7 @@ def build_engines( engine_dir: str, framework_model_dir: str, onnx_dir: str, + tmp_dir: Optional[str] = None, onnx_opset_version: int = 17, force_engine_rebuild: bool = False, device_id: int = 0, @@ -187,22 +188,39 @@ def build_engines( if model_name not in self.model_config: self.model_config[model_name] = _ModelConfig(onnx_opset_version, self.use_cuda_graph) + # Load lora only when we need export text encoder or UNet to ONNX. 
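The loop that follows (and its counterparts in the ORT-TRT and TensorRT builders later in this series) decides whether the LoRA-fused diffusers pipeline needs to be loaded at all: only when some text encoder or UNet still has to be exported to ONNX. A condensed sketch of that decision, with the path helpers stubbed out as plain callables and all names illustrative:

```python
import os

def should_load_lora(lora_weights, model_names, onnx_path_of, onnx_opt_path_of):
    # Only materialize the LoRA-fused pipeline when some CLIP text encoder or
    # UNet still needs an ONNX export; every other case reuses cached files.
    if not lora_weights:
        return False
    for name in model_names:
        if name not in ["clip", "clip2", "unet", "unetxl"]:
            continue
        if not os.path.exists(onnx_opt_path_of(name)) and not os.path.exists(onnx_path_of(name)):
            return True
    return False

# Example with stubbed path helpers (hypothetical paths):
print(should_load_lora("some-org/some-lora", ["unet", "vae"],
                       lambda n: f"/tmp/onnx/{n}/model.onnx",
                       lambda n: f"/tmp/engine/{n}.fp16/model.onnx"))
```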
+ load_lora = False + if self.pipeline_info.lora_weights: + for model_name in self.models: + if model_name not in ["clip", "clip2", "unet", "unetxl"]: + continue + onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False) + + suffix = ".fp16" if self.model_config[model_name].fp16 else ".fp32" + onnx_opt_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=suffix) + if not os.path.exists(onnx_opt_path): + if not os.path.exists(onnx_path): + load_lora = True + break + # Export models to ONNX + self.disable_torch_spda() + pipe = self.load_pipeline_with_lora() if load_lora else None + for model_name, model_obj in self.models.items(): if model_name == "vae" and self.vae_torch_fallback: continue onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False) - onnx_fp32_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=".fp32") - onnx_fp16_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=".fp16") - onnx_opt_path = onnx_fp16_path if self.model_config[model_name].fp16 else onnx_fp32_path + suffix = ".fp16" if self.model_config[model_name].fp16 else ".fp32" + onnx_opt_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=suffix) if not os.path.exists(onnx_opt_path): if not os.path.exists(onnx_path): print("----") logger.info("Exporting model: %s", onnx_path) - model = model_obj.load_model(framework_model_dir, self.hf_token) - if model_name == "vae": - model.to(torch.float32) + + model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir) + model = model.to(torch.float32) with torch.inference_mode(): # For CUDA EP, export FP32 onnx since some graph fusion only supports fp32 graph pattern. @@ -230,18 +248,19 @@ def build_engines( # If final target is fp16 model, we save fp32 optimized model so that it is easy to tune # fp16 conversion. That could save a lot of time in developing. use_fp32_intermediate = save_fp32_intermediate_model and self.model_config[model_name].fp16 + onnx_fp32_path = onnx_path if use_fp32_intermediate: + onnx_fp32_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=".fp32") if not os.path.exists(onnx_fp32_path): print("------") logger.info("Generating optimized model: %s", onnx_fp32_path) - - # There is risk that some ORT fused ops fp32 only. So far, we have not encountered such issue. 
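The rest of this hunk performs the optimization in up to two stages: when the final target is fp16 and an fp32 intermediate is requested, fusion and ORT-level graph optimization run once on the fp32 graph, and the second stage only converts to fp16. A compact sketch of that flow; file names are illustrative, the real code derives them via `get_onnx_path`:

```python
def plan_optimization(fp16, save_fp32_intermediate, optimize_by_ort=True):
    # Returns the (source, target, options) steps applied to the raw export.
    steps = []
    use_fp32_intermediate = save_fp32_intermediate and fp16
    source = "model.onnx"  # raw fp32 export
    if use_fp32_intermediate:
        # Stage 1: fusion + ORT graph optimization, still in fp32.
        steps.append((source, "model.fp32.onnx",
                      dict(to_fp16=False, by_fusion=True, by_ort=optimize_by_ort)))
        source = "model.fp32.onnx"
    target = "model.fp16.onnx" if fp16 else "model.fp32.onnx"
    # Stage 2 (or the only stage): with an fp32 intermediate in place, only the
    # fp16 conversion remains; otherwise fusion/ORT optimization happens here.
    steps.append((source, target,
                  dict(to_fp16=fp16, by_fusion=not use_fp32_intermediate,
                       by_ort=False if use_fp32_intermediate else optimize_by_ort)))
    return steps

for step in plan_optimization(fp16=True, save_fp32_intermediate=True):
    print(step)
```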
model_obj.optimize_ort( onnx_path, onnx_fp32_path, to_fp16=False, fp32_op_list=self.model_config[model_name].force_fp32_ops, optimize_by_ort=self.model_config[model_name].optimize_by_ort, + tmp_dir=self.get_model_dir(model_name, tmp_dir, opt=False, suffix=".fp32", create=False), ) else: logger.info("Found cached optimized model: %s", onnx_fp32_path) @@ -255,24 +274,25 @@ def build_engines( optimize_by_ort = False if use_fp32_intermediate else self.model_config[model_name].optimize_by_ort model_obj.optimize_ort( - onnx_fp32_path if use_fp32_intermediate else onnx_path, + onnx_fp32_path, onnx_opt_path, to_fp16=self.model_config[model_name].fp16, fp32_op_list=self.model_config[model_name].force_fp32_ops, optimize_by_ort=optimize_by_ort, optimize_by_fusion=not use_fp32_intermediate, + tmp_dir=self.get_model_dir(model_name, tmp_dir, opt=False, suffix=".fp16", create=False), ) else: logger.info("Found cached optimized model: %s", onnx_opt_path) + self.enable_torch_spda() built_engines = {} for model_name in self.models: if model_name == "vae" and self.vae_torch_fallback: continue - onnx_fp32_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=".fp32") - onnx_fp16_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=".fp16") - onnx_opt_path = onnx_fp16_path if self.model_config[model_name].fp16 else onnx_fp32_path + suffix = ".fp16" if self.model_config[model_name].fp16 else ".fp32" + onnx_opt_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=suffix) use_cuda_graph = self.model_config[model_name].use_cuda_graph diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py index d966833aba394..8c637007b840d 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py @@ -189,7 +189,28 @@ def build_engines( if not os.path.isdir(onnx_dir): os.makedirs(onnx_dir) + # Load lora only when we need export text encoder or UNet to ONNX. 
+ load_lora = False + if self.pipeline_info.lora_weights: + for model_name, model_obj in self.models.items(): + if model_name not in ["clip", "clip2", "unet", "unetxl"]: + continue + profile_id = model_obj.get_profile_id( + opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape + ) + engine_path = self.get_engine_path(engine_dir, model_name, profile_id) + if not self.has_engine_file(engine_path): + onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False) + onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) + if not os.path.exists(onnx_opt_path): + if not os.path.exists(onnx_path): + load_lora = True + break + # Export models to ONNX + self.disable_torch_spda() + pipe = self.load_pipeline_with_lora() if load_lora else None + for model_name, model_obj in self.models.items(): if model_name == "vae" and self.vae_torch_fallback: continue @@ -204,7 +225,8 @@ def build_engines( if not os.path.exists(onnx_opt_path): if not os.path.exists(onnx_path): logger.info(f"Exporting model: {onnx_path}") - model = model_obj.load_model(framework_model_dir, self.hf_token) + model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir) + with torch.inference_mode(), torch.autocast("cuda"): inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) torch.onnx.export( @@ -230,6 +252,7 @@ def build_engines( model_obj.optimize_trt(onnx_path, onnx_opt_path) else: logger.info("Found cached optimized model: %s", onnx_opt_path) + self.enable_torch_spda() built_engines = {} for model_name, model_obj in self.models.items(): diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py index 61a9c0d2c8fa9..bac1a8bb8140d 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py @@ -407,11 +407,32 @@ def load_engines( self.load_models(framework_model_dir) + # Load lora only when we need export text encoder or UNet to ONNX. 
+ load_lora = False + if self.pipeline_info.lora_weights: + for model_name, model_obj in self.models.items(): + if model_name not in ["clip", "clip2", "unet", "unetxl"]: + continue + profile_id = model_obj.get_profile_id( + opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape + ) + engine_path = self.get_engine_path(engine_dir, model_name, profile_id) + if force_export or force_build or not os.path.exists(engine_path): + onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False) + onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) + if force_export or not os.path.exists(onnx_opt_path): + if force_export or not os.path.exists(onnx_path): + load_lora = True + break + # Export models to ONNX - for model_name, obj in self.models.items(): + self.disable_torch_spda() + pipe = self.load_pipeline_with_lora() if load_lora else None + + for model_name, model_obj in self.models.items(): if model_name == "vae" and self.vae_torch_fallback: continue - profile_id = obj.get_profile_id( + profile_id = model_obj.get_profile_id( opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape ) engine_path = self.get_engine_path(engine_dir, model_name, profile_id) @@ -421,9 +442,10 @@ def load_engines( if force_export or not os.path.exists(onnx_opt_path): if force_export or not os.path.exists(onnx_path): print(f"Exporting model: {onnx_path}") - model = obj.load_model(framework_model_dir, self.hf_token) + model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir) + with torch.inference_mode(), torch.autocast("cuda"): - inputs = obj.get_sample_input(1, opt_image_height, opt_image_width) + inputs = model_obj.get_sample_input(1, opt_image_height, opt_image_width) torch.onnx.export( model, inputs, @@ -431,9 +453,9 @@ def load_engines( export_params=True, opset_version=onnx_opset, do_constant_folding=True, - input_names=obj.get_input_names(), - output_names=obj.get_output_names(), - dynamic_axes=obj.get_dynamic_axes(), + input_names=model_obj.get_input_names(), + output_names=model_obj.get_output_names(), + dynamic_axes=model_obj.get_dynamic_axes(), ) del model torch.cuda.empty_cache() @@ -444,15 +466,16 @@ def load_engines( # Optimize onnx if force_optimize or not os.path.exists(onnx_opt_path): print(f"Generating optimizing model: {onnx_opt_path}") - obj.optimize_trt(onnx_path, onnx_opt_path) + model_obj.optimize_trt(onnx_path, onnx_opt_path) else: print(f"Found cached optimized model: {onnx_opt_path} ") + self.enable_torch_spda() # Build TensorRT engines - for model_name, obj in self.models.items(): + for model_name, model_obj in self.models.items(): if model_name == "vae" and self.vae_torch_fallback: continue - profile_id = obj.get_profile_id( + profile_id = model_obj.get_profile_id( opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape ) engine_path = self.get_engine_path(engine_dir, model_name, profile_id) @@ -463,7 +486,7 @@ def load_engines( engine.build( onnx_opt_path, fp16=True, - input_profile=obj.get_input_profile( + input_profile=model_obj.get_input_profile( opt_batch_size, opt_image_height, opt_image_width, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py index 28e79abb9f018..ff91bf416bf51 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py @@ -8,6 +8,8 
@@ """ import logging +import os +import shutil import tempfile from pathlib import Path @@ -33,23 +35,32 @@ def __init__(self, model_type: str): "clip": ClipOnnxModel, } - def optimize_by_ort(self, onnx_model, use_external_data_format=False): + def _optimize_by_ort(self, onnx_model, use_external_data_format, tmp_dir): + # Save to a temporary file so that we can load it with Onnx Runtime. + logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...") + tmp_model_path = Path(tmp_dir) / "model.onnx" + onnx_model.save_model_to_file(str(tmp_model_path), use_external_data_format=use_external_data_format) + ort_optimized_model_path = Path(tmp_dir) / "optimized.onnx" + optimize_by_onnxruntime( + str(tmp_model_path), + use_gpu=True, + optimized_model_path=str(ort_optimized_model_path), + save_as_external_data=use_external_data_format, + external_data_filename="optimized.onnx_data", + ) + model = onnx.load(str(ort_optimized_model_path), load_external_data=True) + return self.model_type_class_mapping[self.model_type](model) + + def optimize_by_ort(self, onnx_model, use_external_data_format=False, tmp_dir=None): # Use this step to see the final graph that executed by Onnx Runtime. - with tempfile.TemporaryDirectory() as tmp_dir: - # Save to a temporary file so that we can load it with Onnx Runtime. - logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...") - tmp_model_path = Path(tmp_dir) / "model.onnx" - onnx_model.save_model_to_file(str(tmp_model_path), use_external_data_format=use_external_data_format) - ort_optimized_model_path = Path(tmp_dir) / "optimized.onnx" - optimize_by_onnxruntime( - str(tmp_model_path), - use_gpu=True, - optimized_model_path=str(ort_optimized_model_path), - save_as_external_data=use_external_data_format, - external_data_filename="optimized.onnx_data", - ) - model = onnx.load(str(ort_optimized_model_path), load_external_data=True) - return self.model_type_class_mapping[self.model_type](model) + if tmp_dir is None: + with tempfile.TemporaryDirectory() as temp_dir: + return self._optimize_by_ort(onnx_model, use_external_data_format, temp_dir) + else: + os.makedirs(tmp_dir, exist_ok=True) + model = self._optimize_by_ort(onnx_model, use_external_data_format, tmp_dir) + shutil.rmtree(tmp_dir) + return model def optimize( self, @@ -62,6 +73,7 @@ def optimize( optimize_by_ort=True, optimize_by_fusion=True, final_target_float16=True, + tmp_dir=None, ): """Optimize onnx model using ONNX Runtime transformers optimizer""" logger.info(f"Optimize {input_fp32_onnx_path}...") @@ -104,7 +116,7 @@ def optimize( from onnxruntime import __version__ as ort_version if optimize_by_ort and (version.parse(ort_version) >= version.parse("1.16.0") or not use_external_data_format): - m = self.optimize_by_ort(m, use_external_data_format=use_external_data_format) + m = self.optimize_by_ort(m, use_external_data_format=use_external_data_format, tmp_dir=tmp_dir) if float16: logger.info("Convert to float16 ...") diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py index a0b3c3a1c85b1..5d51554a5cee4 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py @@ -25,6 +25,7 @@ import random from typing import Any, Dict, List +import numpy as np import nvtx import torch from cuda 
import cudart @@ -103,8 +104,6 @@ def __init__( self.verbose = verbose self.nvtx_profile = nvtx_profile - self.stages = pipeline_info.stages() - self.use_cuda_graph = use_cuda_graph self.tokenizer = None @@ -138,11 +137,20 @@ def __init__( self.pipeline_info, self.framework_model_dir, self.hf_token, subfolder="tokenizer_2" ) + self.control_image_processor = None + if self.pipeline_info.is_xl() and self.pipeline_info.controlnet: + from diffusers.image_processor import VaeImageProcessor + + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=8, do_convert_rgb=True, do_normalize=False + ) + # Create CUDA events self.events = {} for stage in ["clip", "denoise", "vae", "vae_encoder"]: for marker in ["start", "stop"]: self.events[stage + "-" + marker] = cudart.cudaEventCreate()[1] + self.markers = {} def is_backend_tensorrt(self): return self.engine_type == EngineType.TRT @@ -219,19 +227,63 @@ def initialize_timesteps(self, timesteps, strength): timesteps = self.scheduler.timesteps[t_start:].to(self.device) return timesteps, t_start - def preprocess_images(self, batch_size, images=()): + def start_profile(self, name, color="blue"): if self.nvtx_profile: - nvtx_image_preprocess = nvtx.start_range(message="image_preprocess", color="pink") + self.markers[name] = nvtx.start_range(message=name, color=color) + event_name = name + "-start" + if event_name in self.events: + cudart.cudaEventRecord(self.events[event_name], 0) + + def stop_profile(self, name): + event_name = name + "-stop" + if event_name in self.events: + cudart.cudaEventRecord(self.events[event_name], 0) + if self.nvtx_profile: + nvtx.end_range(self.markers[name]) + + def preprocess_images(self, batch_size, images=()): + self.start_profile("preprocess", color="pink") init_images = [] for i in images: image = i.to(self.device).float() if image.shape[0] != batch_size: image = image.repeat(batch_size, 1, 1, 1) init_images.append(image) - if self.nvtx_profile: - nvtx.end_range(nvtx_image_preprocess) + self.stop_profile("preprocess") return tuple(init_images) + def preprocess_controlnet_images( + self, batch_size, images=None, do_classifier_free_guidance=True, height=1024, width=1024 + ): + """ + Process a list of PIL.Image.Image as control images, and return a torch tensor. + """ + if images is None: + return None + self.start_profile("preprocess", color="pink") + + if not self.pipeline_info.is_xl(): + images = [ + (np.array(i.convert("RGB")).astype(np.float32) / 255.0)[..., None] + .transpose(3, 2, 0, 1) + .repeat(batch_size, axis=0) + for i in images + ] + if do_classifier_free_guidance: + images = [torch.cat([torch.from_numpy(i).to(self.device).float()] * 2) for i in images] + else: + images = [torch.from_numpy(i).to(self.device).float() for i in images] + images = torch.cat([image[None, ...] 
for image in images], dim=0) + images = images.to(dtype=torch.float16) + else: + images = self.control_image_processor.preprocess(images, height=height, width=width).to(dtype=torch.float32) + images = images.repeat_interleave(batch_size, dim=0) + images = images.to(device=self.device, dtype=torch.float16) + if do_classifier_free_guidance: + images = torch.cat([images] * 2) + self.stop_profile("preprocess") + return images + def encode_prompt( self, prompt, @@ -246,9 +298,7 @@ def encode_prompt( if tokenizer is None: tokenizer = self.tokenizer - if self.nvtx_profile: - nvtx_clip = nvtx.start_range(message="clip", color="green") - cudart.cudaEventRecord(self.events["clip-start"], 0) + self.start_profile("clip", color="green") # Tokenize prompt text_input_ids = ( @@ -308,9 +358,7 @@ def encode_prompt( else: text_embeddings = hidden_states.to(dtype=torch.float16) - cudart.cudaEventRecord(self.events["clip-stop"], 0) - if self.nvtx_profile: - nvtx.end_range(nvtx_clip) + self.stop_profile("clip") if pooled_outputs: return text_embeddings, pooled_output @@ -330,14 +378,12 @@ def denoise_latent( ): do_classifier_free_guidance = guidance > 1.0 - cudart.cudaEventRecord(self.events["denoise-start"], 0) + self.start_profile("denoise", color="blue") + if not isinstance(timesteps, torch.Tensor): timesteps = self.scheduler.timesteps for step_index, timestep in enumerate(timesteps): - if self.nvtx_profile: - nvtx_latent_scale = nvtx.start_range(message="latent_scale", color="pink") - # Expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents @@ -347,8 +393,6 @@ def denoise_latent( if isinstance(mask, torch.Tensor): latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) - if self.nvtx_profile: - nvtx.end_range(nvtx_latent_scale) # Predict the noise residual if self.nvtx_profile: @@ -361,6 +405,7 @@ def denoise_latent( "timestep": timestep_float, "encoder_hidden_states": text_embeddings, } + if add_kwargs: params.update(add_kwargs) @@ -369,9 +414,6 @@ def denoise_latent( if self.nvtx_profile: nvtx.end_range(nvtx_unet) - if self.nvtx_profile: - nvtx_latent_step = nvtx.start_range(message="latent_step", color="pink") - # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) @@ -384,36 +426,23 @@ def denoise_latent( else: latents = self.scheduler.step(noise_pred, latents, step_offset + step_index, timestep) - if self.nvtx_profile: - nvtx.end_range(nvtx_latent_step) - - cudart.cudaEventRecord(self.events["denoise-stop"], 0) - # The actual number of steps. It might be different from denoising_steps. 
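For the non-XL branch of `preprocess_controlnet_images` above, each PIL control image becomes an NCHW float tensor in `[0, 1]`, repeated per batch and duplicated when classifier-free guidance is enabled, which matches the `(num_controlnets, m * batch, 3, H, W)` input the UNet expects. A small runnable sketch of that shape math, using only NumPy and Pillow for brevity:

```python
import numpy as np
from PIL import Image

def preprocess_controlnet_image_sd15(image, batch_size, do_classifier_free_guidance=True):
    # HWC uint8 in [0, 255] -> NCHW float32 in [0, 1].
    array = np.array(image.convert("RGB")).astype(np.float32) / 255.0  # (H, W, 3)
    array = array[..., None].transpose(3, 2, 0, 1)                     # (1, 3, H, W)
    array = array.repeat(batch_size, axis=0)                           # (B, 3, H, W)
    if do_classifier_free_guidance:
        # Duplicate for the unconditional pass, matching the 2B batch of the UNet.
        array = np.concatenate([array, array], axis=0)                 # (2B, 3, H, W)
    return array

dummy = Image.new("RGB", (512, 512))
print(preprocess_controlnet_image_sd15(dummy, batch_size=2).shape)  # (4, 3, 512, 512)
```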
self.actual_steps = len(timesteps) + self.stop_profile("denoise") return latents def encode_image(self, init_image): - if self.nvtx_profile: - nvtx_vae = nvtx.start_range(message="vae_encoder", color="red") - cudart.cudaEventRecord(self.events["vae_encoder-start"], 0) + self.start_profile("vae_encoder", color="red") init_latents = self.run_engine("vae_encoder", {"images": init_image})["latent"] - cudart.cudaEventRecord(self.events["vae_encoder-stop"], 0) - if self.nvtx_profile: - nvtx.end_range(nvtx_vae) - init_latents = self.vae_scaling_factor * init_latents + self.stop_profile("vae_encoder") return init_latents def decode_latent(self, latents): - if self.nvtx_profile: - nvtx_vae = nvtx.start_range(message="vae", color="red") - cudart.cudaEventRecord(self.events["vae-start"], 0) + self.start_profile("vae", color="red") images = self.backend.vae_decode(latents) - cudart.cudaEventRecord(self.events["vae-stop"], 0) - if self.nvtx_profile: - nvtx.end_range(nvtx_vae) + self.stop_profile("vae") return images def print_summary(self, tic, toc, batch_size, vae_enc=False) -> Dict[str, Any]: @@ -428,18 +457,23 @@ def print_summary(self, tic, toc, batch_size, vae_enc=False) -> Dict[str, Any]: ) latency = (toc - tic) * 1000.0 - print("|------------|--------------|") - print("| {:^10} | {:^12} |".format("Module", "Latency")) - print("|------------|--------------|") + print("|----------------|--------------|") + print("| {:^14} | {:^12} |".format("Module", "Latency")) + print("|----------------|--------------|") if vae_enc: - print("| {:^10} | {:>9.2f} ms |".format("VAE-Enc", latency_vae_encoder)) - print("| {:^10} | {:>9.2f} ms |".format("CLIP", latency_clip)) - print("| {:^10} | {:>9.2f} ms |".format("UNet x " + str(self.actual_steps), latency_unet)) - print("| {:^10} | {:>9.2f} ms |".format("VAE-Dec", latency_vae)) - - print("|------------|--------------|") - print("| {:^10} | {:>9.2f} ms |".format("Pipeline", latency)) - print("|------------|--------------|") + print("| {:^14} | {:>9.2f} ms |".format("VAE-Enc", latency_vae_encoder)) + print("| {:^14} | {:>9.2f} ms |".format("CLIP", latency_clip)) + print( + "| {:^14} | {:>9.2f} ms |".format( + "UNet" + ("+CNet" if self.pipeline_info.controlnet else "") + " x " + str(self.actual_steps), + latency_unet, + ) + ) + print("| {:^14} | {:>9.2f} ms |".format("VAE-Dec", latency_vae)) + + print("|----------------|--------------|") + print("| {:^14} | {:>9.2f} ms |".format("Pipeline", latency)) + print("|----------------|--------------|") print(f"Throughput: {throughput:.2f} image/s") perf_data = { diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py index 87ce85af247a5..2d2fdb542c845 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py @@ -51,6 +51,8 @@ def _infer( denoising_steps=50, guidance=7.5, seed=None, + controlnet_images=None, + controlnet_scales=None, warmup=False, return_type="latent", ): @@ -73,10 +75,25 @@ def _infer( e2e_tic = time.perf_counter() # CLIP text encoder - text_embeddings = self.encode_prompt(prompt, negative_prompt) + do_classifier_free_guidance = guidance > 1.0 + text_embeddings = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + add_kwargs = None + if self.pipeline_info.controlnet: + controlnet_images = 
self.preprocess_controlnet_images( + latents.shape[0], controlnet_images, do_classifier_free_guidance=do_classifier_free_guidance + ) + add_kwargs = { + "controlnet_images": controlnet_images, + "controlnet_scales": controlnet_scales.to(controlnet_images.dtype).to(controlnet_images.device), + } # UNet denoiser - latents = self.denoise_latent(latents, text_embeddings, guidance=guidance) + latents = self.denoise_latent(latents, text_embeddings, guidance=guidance, add_kwargs=add_kwargs) # VAE decode latent images = self.decode_latent(latents / self.vae_scaling_factor) @@ -99,6 +116,8 @@ def run( denoising_steps=30, guidance=7.5, seed=None, + controlnet_images=None, + controlnet_scales=None, warmup=False, return_type="image", ): @@ -138,6 +157,8 @@ def run( denoising_steps=denoising_steps, guidance=guidance, seed=seed, + controlnet_images=controlnet_images, + controlnet_scales=controlnet_scales, warmup=warmup, return_type=return_type, ) @@ -150,6 +171,8 @@ def run( denoising_steps=denoising_steps, guidance=guidance, seed=seed, + controlnet_images=controlnet_images, + controlnet_scales=controlnet_scales, warmup=warmup, return_type=return_type, ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py index 8ed7e20e94c07..d3387ab6db1bd 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py @@ -58,6 +58,8 @@ def _infer( denoising_steps=30, guidance=5.0, seed=None, + controlnet_images=None, + controlnet_scales=None, warmup=False, return_type="image", ): @@ -117,6 +119,20 @@ def _infer( add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids.to(self.device)} + if self.pipeline_info.controlnet: + controlnet_images = self.preprocess_controlnet_images( + latents.shape[0], + controlnet_images, + do_classifier_free_guidance=do_classifier_free_guidance, + height=image_height, + width=image_width, + ) + add_kwargs.update( + { + "controlnet_images": controlnet_images, + "controlnet_scales": controlnet_scales.to(controlnet_images.dtype).to(controlnet_images.device), + } + ) # UNet denoiser latents = self.denoise_latent( @@ -152,6 +168,8 @@ def run( denoising_steps=30, guidance=5.0, seed=None, + controlnet_images=None, + controlnet_scales=None, warmup=False, return_type="image", ): @@ -192,6 +210,8 @@ def run( denoising_steps=denoising_steps, guidance=guidance, seed=seed, + controlnet_images=controlnet_images, + controlnet_scales=controlnet_scales, warmup=warmup, return_type=return_type, ) @@ -204,6 +224,8 @@ def run( denoising_steps=denoising_steps, guidance=guidance, seed=seed, + controlnet_images=controlnet_images, + controlnet_scales=controlnet_scales, warmup=warmup, return_type=return_type, ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt index 63fa8acfbcc95..a04f05f4b23d8 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt @@ -9,6 +9,7 @@ packaging protobuf==3.20.3 psutil sympy +controlnet_aux # The following are for SDXL optimum==1.13.1 safetensors From e833d22f143f86529f4863b5da6cac4eb4a78bbb Mon Sep 17 00:00:00 2001 
From: ivberg Date: Tue, 28 Nov 2023 16:58:51 -0800 Subject: [PATCH 3/9] Change QNN EP Profiling logs to output to CSV (#18201) ### Description Change QNN EP Profiling logs to output to CSV. Output is in a similar format to QNN SDK Tools (instead of to ORT logs) https://onnxruntime.ai/docs/execution-providers/QNN-ExecutionProvider.html#configuration-options (profiling_level) ### Motivation and Context It is hard to read and interpret QNN profiling logs in the ORT logs. --------- Co-authored-by: Hector Li --- .../qnn/builder/qnn_backend_manager.cc | 232 ++++++++++++++++-- .../qnn/builder/qnn_backend_manager.h | 12 +- 2 files changed, 227 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 03d6b46c528c3..ab0ea042ea5e2 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -4,6 +4,8 @@ #include "qnn_backend_manager.h" #include "qnn_model.h" #include +#include +#include #include "QnnOpDef.h" #include "HTP/QnnHtpPerfInfrastructure.h" #include "CPU/QnnCpuCommon.h" @@ -829,16 +831,49 @@ Status QnnBackendManager::ExtractBackendProfilingInfo() { if (num_events > 0) { LOGS(*logger_, VERBOSE) << "profile_events: " << profile_events << " num_events: " << num_events; - } - for (size_t event_idx = 0; event_idx < num_events; event_idx++) { - ORT_RETURN_IF_ERROR(ExtractProfilingEvent(*(profile_events + event_idx))); - ORT_RETURN_IF_ERROR(ExtractProfilingSubEvents(*(profile_events + event_idx))); + bool backendSupportsExtendedEventData = false; + Qnn_ErrorHandle_t resultPropertyHasCapability = + qnn_interface_.propertyHasCapability(QNN_PROPERTY_PROFILE_SUPPORTS_EXTENDED_EVENT); + uint16_t errorCodePropertyHasCapability = static_cast(resultPropertyHasCapability & 0xFFFF); + if (errorCodePropertyHasCapability == QNN_PROFILE_NO_ERROR) { + LOGS(*logger_, VERBOSE) << "The QNN backend supports extended event data."; + backendSupportsExtendedEventData = true; + } else { + LOGS(*logger_, VERBOSE) << "The QNN backend does not support extended event data."; + } + + // Write to CSV in append mode + const char* profilingCsvFilename = "qnn-profiling-data.csv"; + std::ifstream infile(profilingCsvFilename); + bool exists = infile.good(); + infile.close(); + + std::ofstream outfile(profilingCsvFilename, std::ios_base::app); + ORT_RETURN_IF(!outfile.is_open(), "Failed to open qnn-profiling-data.csv"); + // If file didn't exist before, write the header + if (!exists) { + outfile << "Msg Timestamp,Message,Time,Unit of Measurement,Timing Source,Event Level,Event Identifier\n"; + } + + for (size_t event_idx = 0; event_idx < num_events; event_idx++) { + ORT_RETURN_IF_ERROR( + ExtractProfilingEvent(*(profile_events + event_idx), "ROOT", outfile, backendSupportsExtendedEventData)); + ORT_RETURN_IF_ERROR( + ExtractProfilingSubEvents(*(profile_events + event_idx), outfile, backendSupportsExtendedEventData)); + } + + outfile.close(); + LOGS(*logger_, INFO) << "Wrote QNN profiling events (" << num_events << ") to qnn-profiling-data.csv"; } + return Status::OK(); } -Status QnnBackendManager::ExtractProfilingSubEvents(QnnProfile_EventId_t profile_event_id) { +Status QnnBackendManager::ExtractProfilingSubEvents( + QnnProfile_EventId_t profile_event_id, + std::ofstream& outfile, + bool useExtendedEventData) { const QnnProfile_EventId_t* profile_sub_events{nullptr}; uint32_t num_sub_events{0}; auto result = 
qnn_interface_.profileGetSubEvents(profile_event_id, &profile_sub_events, &num_sub_events); @@ -846,28 +881,195 @@ Status QnnBackendManager::ExtractProfilingSubEvents(QnnProfile_EventId_t profile if (num_sub_events > 0) { LOGS(*logger_, VERBOSE) << "profile_sub_events: " << profile_sub_events << " num_sub_events: " << num_sub_events; - } - for (size_t sub_event_idx = 0; sub_event_idx < num_sub_events; sub_event_idx++) { - ORT_RETURN_IF_ERROR(ExtractProfilingEvent(*(profile_sub_events + sub_event_idx))); - ORT_RETURN_IF_ERROR(ExtractProfilingSubEvents(*(profile_sub_events + sub_event_idx))); + for (size_t sub_event_idx = 0; sub_event_idx < num_sub_events; sub_event_idx++) { + ORT_RETURN_IF_ERROR( + ExtractProfilingEvent(*(profile_sub_events + sub_event_idx), "SUB-EVENT", outfile, useExtendedEventData)); + ORT_RETURN_IF_ERROR( + ExtractProfilingSubEvents(*(profile_sub_events + sub_event_idx), outfile, useExtendedEventData)); + } + + LOGS(*logger_, INFO) << "Wrote QNN profiling sub events (" << num_sub_events << ") to qnn-profiling-data.csv"; } + return Status::OK(); } -Status QnnBackendManager::ExtractProfilingEvent(QnnProfile_EventId_t profile_event_id) { +Status QnnBackendManager::ExtractProfilingEvent( + QnnProfile_EventId_t profile_event_id, + const std::string& eventLevel, + std::ofstream& outfile, + bool useExtendedEventData) { + if (useExtendedEventData) { + return ExtractProfilingEventExtended(profile_event_id, eventLevel, outfile); + } else { + return ExtractProfilingEventBasic(profile_event_id, eventLevel, outfile); + } +} + +Status QnnBackendManager::ExtractProfilingEventBasic( + QnnProfile_EventId_t profile_event_id, + const std::string& eventLevel, + std::ofstream& outfile) { QnnProfile_EventData_t event_data; auto result = qnn_interface_.profileGetEventData(profile_event_id, &event_data); - ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result, "Failed to get profile event data."); + QnnProfile_Error_t errorCode = static_cast(result & 0xFFFF); + ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result, "Failed to get profile event data: " + std::string(QnnProfileErrorToString(errorCode))); + + std::string message = GetEventTypeString(event_data.type); + std::string unit = GetUnitString(event_data.unit); + + outfile << "UNKNOWN" + << "," + << message << "," + << event_data.value << "," + << unit << "," + << "BACKEND" + << "," + << eventLevel << "," + << (event_data.identifier ? 
event_data.identifier : "NULL") << "\n"; + + return Status::OK(); +} - LOGS(*logger_, VERBOSE) << "Profiling Event Info - Event Type: " << event_data.type - << ", Event Value: " << event_data.value - << ", Event Identifier: " << event_data.identifier - << ", Event Unit: " << event_data.unit; +Status QnnBackendManager::ExtractProfilingEventExtended( + QnnProfile_EventId_t profile_event_id, + const std::string& eventLevel, + std::ofstream& outfile) { + QnnProfile_ExtendedEventData_t event_data_extended; + auto resultGetExtendedEventData = qnn_interface_.profileGetExtendedEventData(profile_event_id, &event_data_extended); + QnnProfile_Error_t errorCode = static_cast(resultGetExtendedEventData & 0xFFFF); + ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != errorCode, "Failed to get profile event data: " + std::string(QnnProfileErrorToString(errorCode))); + + std::string message = GetEventTypeString(event_data_extended.v1.type); + std::string unit = GetUnitString(event_data_extended.v1.unit); + + if (event_data_extended.version == QNN_PROFILE_DATA_VERSION_1) { + outfile << event_data_extended.v1.timestamp << "," + << message << "," + << ExtractQnnScalarValue(event_data_extended.v1.value) << "," + << unit << "," + << "BACKEND" + << "," + << eventLevel << "," + << (event_data_extended.v1.identifier ? event_data_extended.v1.identifier : "NULL") << "\n"; + } return Status::OK(); } +const std::string& QnnBackendManager::GetUnitString(QnnProfile_EventUnit_t unitType) { + const auto& unitStringMap = GetUnitStringMap(); + auto it = unitStringMap.find(unitType); + if (it != unitStringMap.end()) { + return it->second; + } + static const std::string unknown = "UNKNOWN"; + return unknown; +} + +const std::unordered_map& QnnBackendManager::GetUnitStringMap() { + static const std::unordered_map unitStringMap = { + {QNN_PROFILE_EVENTUNIT_MICROSEC, "US"}, + {QNN_PROFILE_EVENTUNIT_BYTES, "BYTES"}, + {QNN_PROFILE_EVENTUNIT_CYCLES, "CYCLES"}, + {QNN_PROFILE_EVENTUNIT_COUNT, "COUNT"}, + {QNN_PROFILE_EVENTUNIT_OBJECT, "OBJECT"}, + {QNN_PROFILE_EVENTUNIT_BACKEND, "BACKEND"}}; + return unitStringMap; +} + +const std::string QnnBackendManager::GetEventTypeString(QnnProfile_EventType_t eventType) { + // Interpret the event type + switch (eventType) { + case QNN_PROFILE_EVENTTYPE_INIT: + return "INIT"; + case QNN_PROFILE_EVENTTYPE_FINALIZE: + return "FINALIZE"; + case QNN_PROFILE_EVENTTYPE_EXECUTE: + return "EXECUTE"; + case QNN_PROFILE_EVENTTYPE_NODE: + return "NODE"; + case QNN_PROFILE_EVENTTYPE_EXECUTE_QUEUE_WAIT: + return "EXECUTE QUEUE WAIT"; + case QNN_PROFILE_EVENTTYPE_EXECUTE_PREPROCESS: + return "EXECUTE PREPROCESS"; + case QNN_PROFILE_EVENTTYPE_EXECUTE_DEVICE: + return "EXECUTE DEVICE"; + case QNN_PROFILE_EVENTTYPE_EXECUTE_POSTPROCESS: + return "EXECUTE POSTPROCESS"; + case QNN_PROFILE_EVENTTYPE_DEINIT: + return "DE-INIT"; + case QNN_PROFILE_EVENTTYPE_BACKEND: + return "BACKEND"; + default: + if (eventType > QNN_PROFILE_EVENTTYPE_BACKEND) { + return "BACKEND"; + } + return "UNKNOWN"; + } +} + +const char* QnnBackendManager::QnnProfileErrorToString(QnnProfile_Error_t error) { + switch (error) { + case QNN_PROFILE_NO_ERROR: + return "QNN_PROFILE_NO_ERROR"; + case QNN_PROFILE_ERROR_UNSUPPORTED: + return "QNN_PROFILE_ERROR_UNSUPPORTED"; + case QNN_PROFILE_ERROR_INVALID_ARGUMENT: + return "QNN_PROFILE_ERROR_INVALID_ARGUMENT"; + case QNN_PROFILE_ERROR_MEM_ALLOC: + return "QNN_PROFILE_ERROR_MEM_ALLOC"; + case QNN_PROFILE_ERROR_INVALID_HANDLE: + return "QNN_PROFILE_ERROR_INVALID_HANDLE"; + case QNN_PROFILE_ERROR_HANDLE_IN_USE: + 
return "QNN_PROFILE_ERROR_HANDLE_IN_USE"; + case QNN_PROFILE_ERROR_INCOMPATIBLE_EVENT: + return "QNN_PROFILE_ERROR_INCOMPATIBLE_EVENT"; + default: + return "UNKNOWN_ERROR"; + } +} + +const std::string QnnBackendManager::ExtractQnnScalarValue(const Qnn_Scalar_t& scalar) { + switch (scalar.dataType) { + case QNN_DATATYPE_INT_8: + return std::to_string(static_cast(scalar.int8Value)); + case QNN_DATATYPE_INT_16: + return std::to_string(scalar.int16Value); + case QNN_DATATYPE_INT_32: + return std::to_string(scalar.int32Value); + case QNN_DATATYPE_INT_64: + return std::to_string(scalar.int64Value); + case QNN_DATATYPE_UINT_8: + return std::to_string(static_cast(scalar.uint8Value)); + case QNN_DATATYPE_UINT_16: + return std::to_string(scalar.uint16Value); + case QNN_DATATYPE_UINT_32: + return std::to_string(scalar.uint32Value); + case QNN_DATATYPE_UINT_64: + return std::to_string(scalar.uint64Value); + case QNN_DATATYPE_FLOAT_16: + return std::to_string(scalar.floatValue); + case QNN_DATATYPE_FLOAT_32: + return std::to_string(scalar.floatValue); + case QNN_DATATYPE_SFIXED_POINT_8: + case QNN_DATATYPE_SFIXED_POINT_16: + case QNN_DATATYPE_SFIXED_POINT_32: + return std::to_string(scalar.int32Value); // Assume using int types for signed fixed points. + case QNN_DATATYPE_UFIXED_POINT_8: + case QNN_DATATYPE_UFIXED_POINT_16: + case QNN_DATATYPE_UFIXED_POINT_32: + return std::to_string(scalar.uint32Value); // Assume using unsigned int types for unsigned fixed points. + case QNN_DATATYPE_BOOL_8: + return scalar.bool8Value ? "true" : "false"; + case QNN_DATATYPE_STRING: + return scalar.stringValue ? scalar.stringValue : "NULL"; + default: + return "UNKNOWN"; + } +} + QnnBackendManager::~QnnBackendManager() { ReleaseResources(); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 4edccea661642..bc05820da2f73 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -117,8 +117,8 @@ class QnnBackendManager { void Split(std::vector& split_string, const std::string& tokenized_string, const char separator); Status ExtractBackendProfilingInfo(); - Status ExtractProfilingSubEvents(QnnProfile_EventId_t profile_event_id); - Status ExtractProfilingEvent(QnnProfile_EventId_t profile_event_id); + Status ExtractProfilingSubEvents(QnnProfile_EventId_t profile_event_id, std::ofstream& outfile, bool backendSupportsExtendedEventData); + Status ExtractProfilingEvent(QnnProfile_EventId_t profile_event_id, const std::string& eventLevel, std::ofstream& outfile, bool backendSupportsExtendedEventData); void SetQnnBackendType(uint32_t backend_id); QnnBackendType GetQnnBackendType() { return qnn_backend_type_; } @@ -175,6 +175,14 @@ class QnnBackendManager { return (backend_build_id == nullptr ? 
std::string("") : std::string(backend_build_id)); } + Status ExtractProfilingEventBasic(QnnProfile_EventId_t profile_event_id, const std::string& eventLevel, std::ofstream& outfile); + Status ExtractProfilingEventExtended(QnnProfile_EventId_t profile_event_id, const std::string& eventLevel, std::ofstream& outfile); + static const std::string& GetUnitString(QnnProfile_EventUnit_t unitType); + static const std::unordered_map& GetUnitStringMap(); + static const std::string GetEventTypeString(QnnProfile_EventType_t eventType); + static const std::string ExtractQnnScalarValue(const Qnn_Scalar_t& scalar); + const char* QnnProfileErrorToString(QnnProfile_Error_t error); + private: const std::string backend_path_; const logging::Logger* logger_ = nullptr; From 14a343441dcd530bec24e18e34c3c068993eb06c Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Tue, 28 Nov 2023 17:14:20 -0800 Subject: [PATCH 4/9] Fix Objective-C static analysis build (#18606) - Patch abseil to fix a compile error about not finding `cxxabi.h`. - Fix some static analysis warnings. --- .../absl_gh_issue_1435_workaround.patch | 17 +++++++ include/onnxruntime/core/graph/graph.h | 2 +- .../core/providers/coreml/model/model.mm | 45 ++++++++++++------- .../mac-objc-static-analysis-ci-pipeline.yml | 5 +++ 4 files changed, 51 insertions(+), 18 deletions(-) create mode 100644 cmake/patches/abseil/absl_gh_issue_1435_workaround.patch diff --git a/cmake/patches/abseil/absl_gh_issue_1435_workaround.patch b/cmake/patches/abseil/absl_gh_issue_1435_workaround.patch new file mode 100644 index 0000000000000..0a864cdc019b4 --- /dev/null +++ b/cmake/patches/abseil/absl_gh_issue_1435_workaround.patch @@ -0,0 +1,17 @@ +--- absl/container/internal/layout.h 2023-11-28 09:35:48 ++++ absl/container/internal/layout.updated.h 2023-11-28 10:13:14 +@@ -181,9 +181,11 @@ + #include + #endif + +-#if defined(__GXX_RTTI) +-#define ABSL_INTERNAL_HAS_CXA_DEMANGLE +-#endif ++// Comment out ABSL_INTERNAL_HAS_CXA_DEMANGLE definition to work around this issue: ++// https://github.com/abseil/abseil-cpp/issues/1435 ++// #if defined(__GXX_RTTI) ++// #define ABSL_INTERNAL_HAS_CXA_DEMANGLE ++// #endif + + #ifdef ABSL_INTERNAL_HAS_CXA_DEMANGLE + #include diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index fe0734c51f807..22827d43b200f 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -668,7 +668,7 @@ class Node { The Graph representation containing the graph inputs and outputs, the Node instances, and the edges connecting the nodes. */ -class Graph { +class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve existing data member order for readability public: /** Gets the Graph name. */ const std::string& Name() const noexcept; diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 4a6743e9e5c52..32821fd02647a 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -32,6 +32,13 @@ using namespace onnxruntime::coreml; namespace { +// Converts a UTF8 const char* to an NSString. Throws on failure. +NSString* _Nonnull Utf8StringToNSString(const char* utf8_str) { + NSString* result = [NSString stringWithUTF8String:utf8_str]; + ORT_ENFORCE(result != nil, "NSString conversion failed."); + return result; +} + /** * Computes the static output shape used to allocate the output tensor. 
* `inferred_shape` is the inferred shape known at model compile time. It may contain dynamic dimensions (-1). @@ -152,19 +159,20 @@ Status CreateInputFeatureProvider(const std::unordered_map&)inputs get_output_tensor_mutable_raw_data_fn API_AVAILABLE_OS_VERSIONS; -@property MLModel* model API_AVAILABLE_OS_VERSIONS; +@property(nullable) MLModel* model API_AVAILABLE_OS_VERSIONS; @end @@ -297,12 +305,15 @@ - (void)dealloc { - (Status)loadModel { NSError* error = nil; NSURL* modelUrl = [NSURL URLWithString:coreml_model_path_]; - NSAssert(modelUrl != nil, @"modelUrl must not be nil"); + if (modelUrl == nil) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create model URL from path"); + } + NSURL* compileUrl = [MLModel compileModelAtURL:modelUrl error:&error]; if (error != nil) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Error compiling model ", - [[error localizedDescription] cStringUsingEncoding:NSUTF8StringEncoding]); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Error compiling model: ", + [[error localizedDescription] UTF8String]); } compiled_model_path_ = [compileUrl path]; @@ -313,9 +324,9 @@ - (Status)loadModel { : MLComputeUnitsAll; _model = [MLModel modelWithContentsOfURL:compileUrl configuration:config error:&error]; - if (error != NULL) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Error Creating MLModel ", - [[error localizedDescription] cStringUsingEncoding:NSUTF8StringEncoding]); + if (_model == nil) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create MLModel", + (error != nil) ? MakeString(", error: ", [[error localizedDescription] UTF8String]) : ""); } return Status::OK(); @@ -327,7 +338,7 @@ - (Status)predict:(const std::unordered_map&)inputs Status status = Status::OK(); ORT_TRY { if (_model == nil) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "model is not loaded"); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Model is not loaded"); } id input_features; @@ -342,12 +353,12 @@ - (Status)predict:(const std::unordered_map&)inputs if (error != nil) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Error executing model: ", - [[error localizedDescription] cStringUsingEncoding:NSUTF8StringEncoding]); + [[error localizedDescription] UTF8String]); } for (const auto& [output_name, output_tensor_info] : outputs) { MLFeatureValue* output_value = - [output_features featureValueForName:[NSString stringWithUTF8String:output_name.c_str()]]; + [output_features featureValueForName:Utf8StringToNSString(output_name.c_str())]; if (output_value == nil) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "output_features has no value for ", output_name); @@ -452,7 +463,7 @@ Status Predict(const std::unordered_map& inputs, return status; } - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Execution::LoadModel requires macos 10.15+ or ios 13+ "); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Execution::LoadModel requires macos 10.15+ or ios 13+"); } Status Execution::Predict(const std::unordered_map& inputs, @@ -468,7 +479,7 @@ Status Predict(const std::unordered_map& inputs, } } - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Execution::LoadModel requires macos 10.15+ or ios 13+ "); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Execution::Predict requires macos 10.15+ or ios 13+"); } Model::Model(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags) diff --git a/tools/ci_build/github/azure-pipelines/mac-objc-static-analysis-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-objc-static-analysis-ci-pipeline.yml index 6893fb95cfec5..482279fa07225 100644 --- 
a/tools/ci_build/github/azure-pipelines/mac-objc-static-analysis-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/mac-objc-static-analysis-ci-pipeline.yml @@ -29,6 +29,11 @@ jobs: --build --parallel --target onnx_proto displayName: Generate compile_commands.json and ONNX protobuf files + - script: | + patch < "$(Build.SourcesDirectory)/cmake/patches/abseil/absl_gh_issue_1435_workaround.patch" + workingDirectory: "$(Build.BinariesDirectory)/Debug/_deps/abseil_cpp-src" + displayName: Apply absl_gh_issue_1435_workaround.patch + - script: | set -e From 38b640c797613e2396f2975ccd4d8ff0e95a5baa Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 30 Nov 2023 00:00:23 +0800 Subject: [PATCH 5/9] [WebNN EP] Re-implement Unsqueeze, Squeeze, Flatten with WebNN's reshape (#18585) WebNN will not provide `unsqueeze`, `squeeze`, `flatten2d` ops, as it can be easily implemented by reshape. --- .../core/providers/webnn/builders/helper.h | 6 +-- .../webnn/builders/impl/flatten_op_builder.cc | 20 ++++++--- .../impl/squeeze_unsqueeze_op_builder.cc | 43 ++++++++++++++----- 3 files changed, 49 insertions(+), 20 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 28b54b9c9cf8d..617108c57d8a2 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -153,7 +153,7 @@ static const InlinedHashMap op_map = { {"Erf", {"erf", false}}, {"Exp", {"exp", false}}, {"Expand", {"expand", false}}, - {"Flatten", {"flattenTo2d", false}}, + {"Flatten", {"reshape", true}}, {"Floor", {"floor", true}}, {"Gather", {"gather", false}}, {"Gemm", {"gemm", true}}, @@ -206,12 +206,12 @@ static const InlinedHashMap op_map = { {"Softmax", {"softmax", true}}, {"Split", {"split", true}}, {"Sqrt", {"sqrt", false}}, - {"Squeeze", {"squeeze", false}}, + {"Squeeze", {"reshape", true}}, {"Sub", {"sub", true}}, {"Tan", {"tan", false}}, {"Tanh", {"tanh", true}}, {"Transpose", {"transpose", true}}, - {"Unsqueeze", {"unsqueeze", false}}, + {"Unsqueeze", {"reshape", true}}, {"Where", {"elementwiseIf", false}}, }; diff --git a/onnxruntime/core/providers/webnn/builders/impl/flatten_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/flatten_op_builder.cc index 6c59ca451f333..f0df27b523dfc 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/flatten_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/flatten_op_builder.cc @@ -36,14 +36,20 @@ Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, int64_t rank = input_shape.size(); NodeAttrHelper helper(node); int64_t axis = helper.Get("axis", 1); - ORT_ENFORCE(axis >= -rank && axis <= rank, "axis ", axis, - " is not in valid range [-", rank, ",", rank, "]"); - if (axis < 0) { - axis += rank; - } + axis = HandleNegativeAxis(axis, rank); + + // Use WebNN's reshape to implement Flatten. 
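The reshape target computed below collapses everything before `axis` and everything from `axis` onward into a 2-D shape, which is how Flatten is expressed once `flattenTo2d` is gone. An equivalent sketch in Python, illustrative only:

```python
from functools import reduce
from operator import mul

def flatten_shape(input_shape, axis):
    # Collapse dims before `axis` and dims from `axis` onward into 2-D.
    axis = axis if axis >= 0 else axis + len(input_shape)  # handle negative axis
    pre = reduce(mul, input_shape[:axis], 1)
    post = reduce(mul, input_shape[axis:], 1)
    return [pre, post]

print(flatten_shape([2, 3, 4, 5], axis=2))   # [6, 20]
print(flatten_shape([2, 3, 4, 5], axis=-1))  # [24, 5]
```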
+ int64_t num_pre_axis_elements = std::accumulate( + input_shape.begin(), input_shape.begin() + static_cast(axis), 1, std::multiplies()); + int64_t num_post_axis_elements = std::accumulate( + input_shape.begin() + static_cast(axis), input_shape.end(), 1, std::multiplies()); + + std::vector new_shape = {SafeInt(num_pre_axis_elements), + SafeInt(num_post_axis_elements)}; + emscripten::val inputs = model_builder.GetOperand(input_defs[0]->Name()); - emscripten::val output = model_builder.GetBuilder().call("flattenTo2d", inputs, - static_cast(axis)); + emscripten::val output = model_builder.GetBuilder().call( + "reshape", inputs, emscripten::val::array(new_shape)); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc index 1c0258944dbe9..2a1672c001b0e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc @@ -56,6 +56,7 @@ Status SqueezeUnsqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil emscripten::val options = emscripten::val::object(); std::vector axes_data; + auto rank = input_rank; if (node.SinceVersion() >= 13 && input_defs.size() > 1) { // Input axes is provided, use axes initializer data. @@ -63,35 +64,57 @@ Status SqueezeUnsqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil const auto& axes_tensor = *initializers.at(input_defs[1]->Name()); Initializer axes_initializer(axes_tensor); const auto axes_data_span = axes_initializer.DataAsSpan(); - const auto output_rank = input_rank + axes_data_span.size(); + if (op_type == "Unsqueeze") { + // Unsqueeze should check the expanded rank. + rank = input_rank + axes_data_span.size(); + } std::transform( axes_data_span.begin(), axes_data_span.end(), std::back_inserter(axes_data), - [output_rank](int64_t axis) -> int32_t { return SafeInt(HandleNegativeAxis(axis, output_rank)); }); + [rank](int64_t axis) -> int32_t { return SafeInt(HandleNegativeAxis(axis, rank)); }); } else { NodeAttrHelper helper(node); if (helper.HasAttr("axes")) { auto axes = helper.Get("axes", std::vector{}); - const auto output_rank = input_rank + axes.size(); + if (op_type == "Unsqueeze") { + // Unsqueeze should check the expanded rank. + rank = input_rank + axes.size(); + } std::transform( axes.begin(), axes.end(), std::back_inserter(axes_data), - [output_rank](int64_t axis) -> int32_t { return SafeInt(HandleNegativeAxis(axis, output_rank)); }); + [rank](int64_t axis) -> int32_t { return SafeInt(HandleNegativeAxis(axis, rank)); }); } } - if (axes_data.size() > 0) { - options.set("axes", emscripten::val::array(axes_data)); - } - emscripten::val output = emscripten::val::undefined(); + // Use WebNN's reshape to implement Squeeze/Unsqueeze. + std::vector new_shape; + std::transform( + input_shape.begin(), input_shape.end(), std::back_inserter(new_shape), + [](int64_t data) -> uint32_t { return SafeInt(data); }); + // Sort axes_data in ascending order. 
+ std::sort(axes_data.begin(), axes_data.end()); if (op_type == "Squeeze") { - output = model_builder.GetBuilder().call("squeeze", input, options); + if (!axes_data.empty()) { + for (auto axis = axes_data.rbegin(); axis != axes_data.rend(); ++axis) { + size_t index = *axis; + new_shape.erase(new_shape.begin() + index); + } + } else { + // Remove all the single dimensions. + new_shape.erase( + std::remove_if(new_shape.begin(), new_shape.end(), [](uint32_t axis) { return axis == 1; }), new_shape.end()); + } } else if (op_type == "Unsqueeze") { - output = model_builder.GetBuilder().call("unsqueeze", input, options); + // Expand new_shape according to axes_data. + for (const int32_t& axis : axes_data) { + new_shape.insert(new_shape.begin() + axis, 1); + } } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "SqueezeUnsqueezeOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); } + output = model_builder.GetBuilder().call("reshape", input, emscripten::val::array(new_shape)); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); } From 68209307daadfe21a74a36d44c4c170b91141772 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Thu, 30 Nov 2023 02:32:42 +0800 Subject: [PATCH 6/9] Replace all Azure-Pipelines-EO-Windows2022-aiinfrat to Onnxruntime-Win-CPU-2022 (#18614) ### Description Replace all Azure-Pipelines-EO-Windows2022-aiinfrat to Onnxruntime-Win-CPU-2022 ### Motivation and Context Reduce the maintenance cost --- .../azure-pipelines/c-api-noopenmp-packaging-pipelines.yml | 4 ++-- .../github/azure-pipelines/npm-packaging-pipeline.yml | 4 ++-- tools/ci_build/github/azure-pipelines/post-merge-jobs.yml | 2 +- .../github/azure-pipelines/py-package-test-pipeline.yml | 2 +- .../azure-pipelines/stages/nuget-combine-cuda-stage.yml | 6 ++---- .../templates/ondevice-training-cpu-packaging-pipeline.yml | 2 +- 6 files changed, 9 insertions(+), 11 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 67fa78da003a3..db1dcc3af792e 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -673,7 +673,7 @@ stages: clean: all # we need to use the 2022 pool to create the nuget package with both pre-net6+Xamarin and net6 targets. # VS2019 has no support for net6 and we need to use msbuild (from the VS install) to do the packing - pool: 'Azure-Pipelines-EO-Windows2022-aiinfra' + pool: 'Onnxruntime-Win-CPU-2022' variables: breakCodesignValidationInjection: ${{ parameters.DoEsrp }} ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] @@ -858,7 +858,7 @@ stages: clean: all # we need to use the 2022 pool to create the nuget package with both pre-net6+Xamarin and net6 targets. 
# VS2019 has no support for net6 and we need to use msbuild (from the VS install) to do the packing - pool: 'Azure-Pipelines-EO-Windows2022-aiinfra' + pool: 'Onnxruntime-Win-CPU-2022' variables: breakCodesignValidationInjection: ${{ parameters.DoEsrp }} ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] diff --git a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml index b98837078b2d5..fd26128b8b29a 100644 --- a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml @@ -48,7 +48,7 @@ stages: RunWebGpuTestsForDebugBuild: false RunWebGpuTestsForReleaseBuild: true WebGpuPoolName: 'onnxruntime-Win2022-webgpu-A10' - WebCpuPoolName: 'Azure-Pipelines-EO-Windows2022-aiinfra' + WebCpuPoolName: 'Onnxruntime-Win-CPU-2022' - template: templates/react-native-ci.yml parameters: @@ -65,7 +65,7 @@ stages: - Build_web_Debug jobs: - job: Download_Node_Package_And_Publish_Validation_Script - pool: 'Azure-Pipelines-EO-Windows2022-aiinfra' + pool: 'Onnxruntime-Win-CPU-2022' variables: runCodesignValidationInjection: false timeoutInMinutes: 10 diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml index c86920422b6f0..706c87fc079ca 100644 --- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml @@ -8,7 +8,7 @@ stages: BuildStaticLib: true ExtraBuildArgs: '' UseWebPoolName: true - WebCpuPoolName: 'Azure-Pipelines-EO-Windows2022-aiinfra' + WebCpuPoolName: 'Onnxruntime-Win-CPU-2022' # This stage is to test if the combined build works on # o Windows ARM64 diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml index c8aac6e8b130d..55d3150f21aa3 100644 --- a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -84,7 +84,7 @@ stages: skipComponentGovernanceDetection: true workspace: clean: all - pool: Azure-Pipelines-EO-Windows2022-aiinfra + pool: Onnxruntime-Win-CPU-2022 steps: - task: PowerShell@2 displayName: 'Add Build Tag' diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml index b69e75856c39f..d009e15559180 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml @@ -27,9 +27,7 @@ stages: - job: workspace: clean: all - # we need to use the 2022 pool to create the nuget package with both pre-net6+Xamarin and net6 targets. 
- # VS2019 has no support for net6 and we need to use msbuild (from the VS install) to do the packing - pool: 'Azure-Pipelines-EO-Windows2022-aiinfra' + pool: 'Onnxruntime-Win-CPU-2022' variables: breakCodesignValidationInjection: ${{ parameters.DoEsrp }} ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] @@ -225,4 +223,4 @@ stages: - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' - condition: always() \ No newline at end of file + condition: always() diff --git a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml index 24e46066a1f10..29cea63df1662 100644 --- a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml @@ -141,7 +141,7 @@ stages: clean: all # we need to use the 2022 pool to create the nuget package with both pre-net6+Xamarin and net6 targets. # VS2019 has no support for net6 and we need to use msbuild (from the VS install) to do the packing - pool: 'Azure-Pipelines-EO-Windows2022-aiinfra' + pool: 'Onnxruntime-Win-CPU-2022' variables: OrtPackageId: ${{ parameters.OrtNugetPackageId }} breakCodesignValidationInjection: ${{ parameters.DoEsrp }} From d2dfbf41795e72911643e2ffcadac069b72580bd Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Wed, 29 Nov 2023 10:44:59 -0800 Subject: [PATCH 7/9] Add float16 type support to SplitToSequence and make code type independent (#18594) ### Description Add support for `float16` type to address the below issue. Re-work the code to make it type independent. This reduces binary size by ~11 K. ![image](https://github.com/microsoft/onnxruntime/assets/11303988/1a77c7bc-34a8-478c-a16a-abd94062c6c6) ### Motivation and Context This PR addresses https://github.com/microsoft/onnxruntime/issues/18481 --- docs/OperatorKernels.md | 2 +- .../providers/cpu/sequence/sequence_ops.cc | 111 +++++++++--------- .../providers/cpu/sequence/sequence_ops.h | 3 +- .../cpu/sequence/sequence_ops_test.cc | 81 +++++++++---- 4 files changed, 114 insertions(+), 83 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 16df788c284ee..edf249a816923 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -373,7 +373,7 @@ Do not modify directly.* |||[13, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[2, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|SplitToSequence|*in* input:**T**
*in* split:**I**<br/> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))<br/> **T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string)|
+|SplitToSequence|*in* input:**T**<br/> *in* split:**I**<br/> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(string)|
 |Sqrt|*in* X:**T**<br/> *out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
 |||[6, 12]|**T** = tensor(double), tensor(float)|
 |Squeeze|*in* data:**T**<br/> *in* axes:**tensor(int64)**<br/> *out* squeezed:**T**<br/><br/>or<br/><br/>*in* data:**T**<br/>
*out* squeezed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc b/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc index 4759938cd8250..8064bc0a58cb1 100644 --- a/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc +++ b/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc @@ -334,27 +334,14 @@ Status SequenceConstruct::Compute(OpKernelContext* context) const { // SplitToSequence -namespace op_kernel_type_control { -ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS( - kCpuExecutionProvider, kOnnxDomain, SplitToSequence, Input, 0, - float, double, int32_t, int64_t, std::string); -} // namespace op_kernel_type_control - -namespace { -using EnabledSplitToSequenceDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS( - kCpuExecutionProvider, kOnnxDomain, SplitToSequence, Input, 0); -} // namespace - ONNX_CPU_OPERATOR_KERNEL( SplitToSequence, 11, KernelDefBuilder() .TypeConstraint("T", - BuildKernelDefConstraintsFromTypeList()) + BuildKernelDefConstraints()) .TypeConstraint("S", DataTypeImpl::AllSequenceTensorTypes()) - .TypeConstraint("I", std::vector{ - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), + .TypeConstraint("I", BuildKernelDefConstraints()), SplitToSequence); SplitToSequence::SplitToSequence(const OpKernelInfo& info) : OpKernel(info) { @@ -366,29 +353,14 @@ Status SplitToSequence::Compute(OpKernelContext* context) const { const Tensor& input = *context->Input(0); const Tensor* p_split_input = context->Input(1); - Status status; - - if (input.IsDataType()) - status = ComputeImpl(*context, input, p_split_input); - else if (input.IsDataType()) - status = ComputeImpl(*context, input, p_split_input); - else if (input.IsDataType()) - status = ComputeImpl(*context, input, p_split_input); - else if (input.IsDataType()) - status = ComputeImpl(*context, input, p_split_input); - else if (input.IsDataTypeString()) - status = ComputeImpl(*context, input, p_split_input); - else - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "SplitToSequence operator does not support ", input.DataType(), " yet"); - - return status; + return ComputeImpl(*context, input, p_split_input); } Status SplitToSequence::PrepareForCompute(const TensorShape& input_shape, int64_t split_scalar, bool is_split_input_scalar, int64_t& num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, bool& is_uneven_split, int& num_remaining_splits, - std::vector& split_sizes) const { + InlinedVector& split_sizes) const { auto input_dims = input_shape.GetDims(); const auto num_dimensions = gsl::narrow_cast(input_shape.NumDimensions()); axis = HandleNegativeAxis(axis_, num_dimensions); // handle negative and enforce axis is valid @@ -416,7 +388,7 @@ Status SplitToSequence::PrepareForCompute(const TensorShape& input_shape, int64_ // populate split_sizes with the same size for each output num_outputs = split_dim_size; // https://github.com/onnx/onnx/issues/2396 - split_sizes = std::vector(static_cast(num_outputs), DEFAULT_LENGTH_EACH_OUTPUT_); + split_sizes = InlinedVector(static_cast(num_outputs), DEFAULT_LENGTH_EACH_OUTPUT_); } else { auto split_size_sum = std::accumulate(split_sizes.cbegin(), split_sizes.cend(), 0LL); if (split_size_sum != split_dim_size) { @@ -453,7 +425,7 @@ static int64_t 
GetScalarSplitInput(const Tensor& tensor) { return retval; } -static void GetSplitSizesInput(const Tensor& tensor, std::vector& split_sizes) { +static void GetSplitSizesInput(const Tensor& tensor, InlinedVector& split_sizes) { auto num_elems = tensor.Shape().Size(); split_sizes.reserve(onnxruntime::narrow(num_elems)); if (tensor.IsDataType()) { @@ -467,13 +439,8 @@ static void GetSplitSizesInput(const Tensor& tensor, std::vector& split } } -template Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& input, const Tensor* p_split_input) const { - if (!utils::HasType()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Data type is not supported in this build."); - } - auto& input_shape = input.Shape(); int64_t num_outputs = 0; int64_t axis = axis_; @@ -484,7 +451,9 @@ Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& inpu bool is_split_input_scalar = false; bool is_uneven_split = false; int num_remaining_splits = 0; - std::vector split_sizes; + InlinedVector split_sizes; + const bool is_string_type = input.IsDataTypeString(); + const size_t element_size = (is_string_type) ? 0U : input.DataType()->Size(); // figure out split_scalar or split_sizes if (p_split_input) { @@ -520,8 +489,8 @@ Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& inpu // copy dimensions so we can update the selected axis in place auto output_dimensions = input_shape.AsShapeVector(); - int64_t input_offset = 0; - const T* input_data = input.Data(); + SafeInt input_offset = 0; + const void* input_data = input.DataRaw(); for (int i = 0; i < num_outputs; ++i) { // update size of dimension for axis we're splitting on while considering uneven split int split_size; @@ -535,20 +504,50 @@ Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& inpu AllocatorPtr alloc; ORT_RETURN_IF_ERROR(context.GetTempSpaceAllocator(&alloc)); Tensor output_tensor(input.DataType(), onnxruntime::TensorShape(output_dimensions), alloc); - T* output_data = output_tensor.MutableData(); - - ::onnxruntime::math::CopyMatrix( - before_dims, // M - split_size * after_dims_excluding_split, // N - static_cast(input_data + input_offset), // A - after_dims_including_split_axis, // lda - static_cast(output_data), // B - split_size * after_dims_excluding_split, // ldb - [](const T* src, T* dst, size_t count) { - copy_data(src, dst, count); - }); - - input_offset += static_cast(split_size) * after_dims_excluding_split; // offset by the N data we used in this iteration + void* output_data = output_tensor.MutableDataRaw(); + + const auto M = before_dims; + const auto* A = static_cast(input_data) + static_cast(input_offset * element_size); + const auto lda = after_dims_including_split_axis; + auto* B = output_data; + + const auto N = split_size * after_dims_excluding_split; + const auto ldb = N; + + if (is_string_type) { + const auto* src = reinterpret_cast(A); + auto* dst = reinterpret_cast(B); + if (lda == N) { + copy_data(src, dst, static_cast(M * N)); + } else { + size_t lda_offset = 0; + size_t ldb_offset = 0; + for (size_t idx = 0; idx < static_cast(M); ++idx, + lda_offset += lda, ldb_offset += ldb) { + copy_data(src + lda_offset, dst + ldb_offset, static_cast(N)); + } + } + } else { + if (lda == N) { + // if the data is contiguous, we can just copy the data + const size_t bytes_to_copy = static_cast(N) * static_cast(M) * element_size; + memcpy(B, A, bytes_to_copy); + } else { + // otherwise we need to copy each row + const size_t row_bytes = SafeInt(N) * 
element_size; + const auto lda_bytes_inc = SafeInt(lda) * element_size; + const auto ldb_bytes_inc = SafeInt(ldb) * element_size; + SafeInt lda_bytes_offset = 0; + SafeInt ldb_bytes_offset = 0; + for (size_t idx = 0; idx < static_cast(M); ++idx, + lda_bytes_offset += lda_bytes_inc, ldb_bytes_offset += ldb_bytes_inc) { + memcpy(reinterpret_cast(B) + static_cast(ldb_bytes_offset), + reinterpret_cast(A) + static_cast(lda_bytes_offset), row_bytes); + } + } + } + + input_offset += SafeInt(split_size) * after_dims_excluding_split; // offset by the N data we used in this iteration // if keep_dims = 0, reshape the tensor by dropping the dimension corresponding to 'axis' if (use_keep_dims && keepdims_ == 0) { diff --git a/onnxruntime/core/providers/cpu/sequence/sequence_ops.h b/onnxruntime/core/providers/cpu/sequence/sequence_ops.h index 9466d3f0fd108..ccca226fb07ee 100644 --- a/onnxruntime/core/providers/cpu/sequence/sequence_ops.h +++ b/onnxruntime/core/providers/cpu/sequence/sequence_ops.h @@ -60,13 +60,12 @@ class SplitToSequence final : public OpKernel { Status Compute(OpKernelContext* context) const override; private: - template Status ComputeImpl(OpKernelContext& context, const Tensor& input, const Tensor* p_split_input) const; Status PrepareForCompute(const TensorShape& input_shape, int64_t split_scalar, bool is_split_input_scalar, int64_t& num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, bool& is_uneven_split, int& num_remaining_splits, - std::vector& split_sizes) const; + InlinedVector& split_sizes) const; int64_t axis_{}; int64_t keepdims_{1}; const int64_t DEFAULT_LENGTH_EACH_OUTPUT_ = 1; diff --git a/onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc b/onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc index d29aac81150c5..60e75811e4333 100644 --- a/onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc +++ b/onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc @@ -330,15 +330,26 @@ TEST(SequenceOpsTest, SequenceConstructPositive) { // SplitToSequence template -static std::vector GetConsequtiveVector(T start, int num) { +static std::vector GetConsecutiveVector(T start, size_t num) { std::vector inputv(num); std::iota(inputv.begin(), inputv.end(), start); return inputv; } +template <> +std::vector GetConsecutiveVector(MLFloat16 start, size_t num) { + std::vector inputv; + inputv.reserve(num); + float start_f = start.ToFloat(); + for (size_t i = 0; i < num; ++i) { + inputv.push_back(MLFloat16{start_f + static_cast(i)}); + } + return inputv; +} + TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0EqualSplitFloat) { OpTester test("SplitToSequence", 11); - test.AddInput("input", {4, 2}, GetConsequtiveVector(1.f, 8)); + test.AddInput("input", {4, 2}, GetConsecutiveVector(1.f, 8)); test.AddInput("split", {1, 2}, {2, 2}); SeqTensors output; output.AddTensor({2, 2}, {1.f, 2.f, 3.f, 4.f}); @@ -347,9 +358,31 @@ TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0EqualSplitFloat) { test.Run(); } +TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0EqualSplitMLFloat16) { + OpTester test("SplitToSequence", 11); + test.AddInput("input", {4, 2}, GetConsecutiveVector(MLFloat16::One, 8)); + test.AddInput("split", {1, 2}, {2, 2}); + SeqTensors output; + + std::vector tensor_1; + const auto data_1 = {1.f, 2.f, 3.f, 4.f}; + for (auto f : data_1) + tensor_1.push_back(MLFloat16{f}); + + std::vector tensor_2; + const auto data_2 = {5.f, 6.f, 7.f, 8.f}; + for (auto f : data_2) + tensor_2.push_back(MLFloat16{f}); + 
+ output.AddTensor({2, 2}, tensor_1); + output.AddTensor({2, 2}, tensor_2); + test.AddSeqOutput("S2", output); + test.Run(); +} + TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0EqualSplitLong) { OpTester test("SplitToSequence", 11); - test.AddInput("input", {4, 2}, GetConsequtiveVector(1, 8)); + test.AddInput("input", {4, 2}, GetConsecutiveVector(1, 8)); test.AddInput("split", {1, 2}, {2, 2}); SeqTensors output; output.AddTensor({2, 2}, {1, 2, 3, 4}); @@ -360,7 +393,7 @@ TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0EqualSplitLong) { TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0EqualSplitFloatScalarSplit) { OpTester test("SplitToSequence", 11); - test.AddInput("input", {4, 2}, GetConsequtiveVector(1.f, 8)); + test.AddInput("input", {4, 2}, GetConsecutiveVector(1.f, 8)); test.AddInput("split", {}, {2}); SeqTensors output; output.AddTensor({2, 2}, {1.f, 2.f, 3.f, 4.f}); @@ -371,7 +404,7 @@ TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0EqualSplitFloatScalarSplit) { TEST(SequenceOpsTest, SplitToSequence_Axis0DefaultSplitFloatSetAxisExplicitly) { OpTester test("SplitToSequence", 11); - test.AddInput("input", {4, 2}, GetConsequtiveVector(1.f, 8)); + test.AddInput("input", {4, 2}, GetConsecutiveVector(1.f, 8)); int64_t axis = 0; test.AddAttribute("axis", axis); SeqTensors output; @@ -385,7 +418,7 @@ TEST(SequenceOpsTest, SplitToSequence_Axis0DefaultSplitFloatSetAxisExplicitly) { TEST(SequenceOpsTest, SplitToSequence_PositiveAxisScalarSplit) { OpTester test("SplitToSequence", 11); - test.AddInput("input", {2, 2, 6}, GetConsequtiveVector(1.f, 2 * 2 * 6)); + test.AddInput("input", {2, 2, 6}, GetConsecutiveVector(1.f, 2 * 2 * 6)); int64_t axis = 2; test.AddAttribute("axis", axis); test.AddInput("split", {}, {2}); @@ -411,11 +444,11 @@ TEST(SequenceOpsTest, SplitToSequence_PositiveAxisScalarSplit) { TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0UnevenSplitFloat) { OpTester test("SplitToSequence", 11); - test.AddInput("input", {5, 2}, GetConsequtiveVector(1.f, 10)); + test.AddInput("input", {5, 2}, GetConsecutiveVector(1.f, 10)); test.AddInput("split", {}, {2}); SeqTensors output; - output.AddTensor({2, 2}, GetConsequtiveVector(1.f, 4)); - output.AddTensor({2, 2}, GetConsequtiveVector(5.f, 4)); + output.AddTensor({2, 2}, GetConsecutiveVector(1.f, 4)); + output.AddTensor({2, 2}, GetConsecutiveVector(5.f, 4)); output.AddTensor({1, 2}, {9.f, 10.f}); test.AddSeqOutput("S2", output); test.Run(); @@ -423,22 +456,22 @@ TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0UnevenSplitFloat) { TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0UnevenSplitFloat2) { OpTester test("SplitToSequence", 11); - test.AddInput("input", {17, 2}, GetConsequtiveVector(1.f, 34)); + test.AddInput("input", {17, 2}, GetConsecutiveVector(1.f, 34)); test.AddInput("split", {}, {3}); SeqTensors output; - output.AddTensor({3, 2}, GetConsequtiveVector(1.f, 6)); - output.AddTensor({3, 2}, GetConsequtiveVector(7.f, 6)); - output.AddTensor({3, 2}, GetConsequtiveVector(13.f, 6)); - output.AddTensor({3, 2}, GetConsequtiveVector(19.f, 6)); - output.AddTensor({3, 2}, GetConsequtiveVector(25.f, 6)); - output.AddTensor({2, 2}, GetConsequtiveVector(31.f, 4)); + output.AddTensor({3, 2}, GetConsecutiveVector(1.f, 6)); + output.AddTensor({3, 2}, GetConsecutiveVector(7.f, 6)); + output.AddTensor({3, 2}, GetConsecutiveVector(13.f, 6)); + output.AddTensor({3, 2}, GetConsecutiveVector(19.f, 6)); + output.AddTensor({3, 2}, GetConsecutiveVector(25.f, 6)); + output.AddTensor({2, 2}, GetConsecutiveVector(31.f, 4)); test.AddSeqOutput("S2", 
output); test.Run(); } TEST(SequenceOpsTest, SplitToSequence_PositiveAxisUnevenSplit) { OpTester test("SplitToSequence", 11); - test.AddInput("input", {2, 5}, GetConsequtiveVector(1.f, 10)); + test.AddInput("input", {2, 5}, GetConsecutiveVector(1.f, 10)); test.AddInput("split", {}, {2}); int64_t axis = 1; test.AddAttribute("axis", axis); @@ -452,33 +485,33 @@ TEST(SequenceOpsTest, SplitToSequence_PositiveAxisUnevenSplit) { TEST(SequenceOpsTest, SplitToSequence_Axis0DefaultSplitFloatSetAxisExplicitlyDontKeepDims3Dim) { OpTester test("SplitToSequence", 11); - test.AddInput("input", {2, 3, 4}, GetConsequtiveVector(1.f, 2 * 3 * 4)); + test.AddInput("input", {2, 3, 4}, GetConsecutiveVector(1.f, 2 * 3 * 4)); test.AddAttribute("keepdims", 0); int64_t axis = 0; test.AddAttribute("axis", axis); SeqTensors output; - output.AddTensor({3, 4}, GetConsequtiveVector(1.f, 12)); - output.AddTensor({3, 4}, GetConsequtiveVector(13.f, 12)); + output.AddTensor({3, 4}, GetConsecutiveVector(1.f, 12)); + output.AddTensor({3, 4}, GetConsecutiveVector(13.f, 12)); test.AddSeqOutput("S2", output); test.Run(); } TEST(SequenceOpsTest, SplitToSequence_Axis0DefaultSplitFloatSetAxisExplicitlyDontKeepDims2Dim) { OpTester test("SplitToSequence", 11); - test.AddInput("input", {2, 3}, GetConsequtiveVector(1.f, 2 * 3)); + test.AddInput("input", {2, 3}, GetConsecutiveVector(1.f, 2 * 3)); test.AddAttribute("keepdims", 0); int64_t axis = 0; test.AddAttribute("axis", axis); SeqTensors output; - output.AddTensor({3}, GetConsequtiveVector(1.f, 3)); - output.AddTensor({3}, GetConsequtiveVector(4.f, 3)); + output.AddTensor({3}, GetConsecutiveVector(1.f, 3)); + output.AddTensor({3}, GetConsecutiveVector(4.f, 3)); test.AddSeqOutput("S2", output); test.Run(); } TEST(SequenceOpsTest, SplitToSequence_PositiveAxisDontKeepDims) { OpTester test("SplitToSequence", 11); - test.AddInput("input", {2, 3, 4}, GetConsequtiveVector(1.f, 2 * 3 * 4)); + test.AddInput("input", {2, 3, 4}, GetConsecutiveVector(1.f, 2 * 3 * 4)); test.AddAttribute("keepdims", 0); int64_t axis = 2; test.AddAttribute("axis", axis); From 483c490ec4db2d2b5001e42f5c842abfc9e379af Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Wed, 29 Nov 2023 14:38:44 -0800 Subject: [PATCH 8/9] Refine error checks in onnxruntime/core/providers/coreml/model/model.mm. (#18620) #18606 updated the original error checks to check that the returned object != nil to appease the static analyzer. However, per the API docs, checking `error != nil` is the way to determine whether an error occurred. This change adds back the `error != nil` check to be safe. --- onnxruntime/core/providers/coreml/model/model.mm | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 32821fd02647a..155201ad4c39c 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -159,7 +159,7 @@ Status CreateInputFeatureProvider(const std::unordered_map Date: Wed, 29 Nov 2023 15:30:33 -0800 Subject: [PATCH 9/9] [JS/Web] Add uniforms to Einsum (#18531) ### Description Add uinforms to Einsum ### Motivation and Context Improve performance. 
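The sketch below is not part of this change; it is a minimal, self-contained TypeScript illustration (all names are hypothetical) of why reading shape values from uniforms, instead of baking them into the generated WGSL, lets one compiled shader be reused across input shapes — which is where the performance win of this patch comes from.

```ts
// Illustrative sketch only (not onnxruntime's actual API).

// Build WGSL with the reduce-loop bound baked in as a literal: every new
// dimension value produces a different source string (and a recompile).
const sourceWithLiteral = (dim: number): string =>
  `for (var k: u32 = 0u; k < ${dim}u; k++) { /* ... */ }`;

// Build WGSL that reads the bound from a uniform: the source is identical
// for every shape of the same rank, so the shader cache hits after the first use.
const sourceWithUniform = (): string =>
  'for (var k: u32 = 0u; k < uniforms.k_max; k++) { /* ... */ }';

// Count how many distinct shader sources (i.e. compilations) each strategy needs.
const compileCount = new Map<string, number>();
const compile = (src: string): void => {
  compileCount.set(src, (compileCount.get(src) ?? 0) + 1);
};

// Simulate running the op for several input shapes.
for (const dim of [3, 5, 7]) {
  compile(sourceWithLiteral(dim)); // three distinct sources -> three compiles
  compile(sourceWithUniform());    // one source -> one compile; dim is passed at dispatch time
}

console.log('literal variants:', [...compileCount.keys()].filter(k => !k.includes('uniforms')).length); // 3
console.log('uniform variants:', [...compileCount.keys()].filter(k => k.includes('uniforms')).length);  // 1
```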
--- js/web/lib/wasm/jsep/webgpu/ops/einsum.ts | 220 +++++++++------ js/web/test/data/ops/einsum.jsonc | 330 +++++++++++++++++++++- 2 files changed, 453 insertions(+), 97 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts index a233d37a79e65..4db7c04ad67be 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts @@ -4,9 +4,10 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; + +import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface EinsumAttributes extends AttributeWithCacheKey { readonly equation: string; @@ -101,7 +102,7 @@ class EinsumEquation { this.outputDims.push(info.dimValue); } }); - this.rhs = this.processTerm(rhs, true, this.outputDims); + this.rhs = this.processTerm(rhs, false, this.outputDims); } // End of EinsumEqation constructor // Add a symbol to the equation @@ -157,12 +158,12 @@ class EinsumEquation { } // Add '0', '1', '2', '3', '4', etc to represent ellipsis dimensions to avoid special handling for (let j = 0; j < ellipsisDims.length; j++) { - const symbol = String.fromCharCode('0'.charCodeAt(0) + i); + const symbol = String.fromCharCode('0'.charCodeAt(0) + j); einsumTerm.addSymbol(symbol, i + j); this.addSymbol(symbol, dims[nextDim++], index); } } else { - einsumTerm.addSymbol(symbol, i); + einsumTerm.addSymbol(symbol, i + (this.hasEllipsis ? 
this.ellipsisDims.length - 1 : 0)); this.addSymbol(symbol, dims[nextDim++], index); } }); @@ -177,101 +178,132 @@ class EinsumEquation { outputDims: number[]; // Output dimensions of the equation } // End of class EinsumEquation -const createEinsumProgramInfo = (inputs: readonly TensorView[], einsumEquation: EinsumEquation): ProgramInfo => { - const dataType = inputs[0].dataType; - const inputVars = new Array(inputs.length); - for (let i = 0; i < inputs.length; ++i) { - inputVars[i] = inputVariable(`input${i}`, dataType, inputs[i].dims); - } - const outputShape = einsumEquation.outputDims; - const outputSize = ShapeUtil.size(outputShape); - const output = outputVariable('output', dataType, outputShape); - const idxCopy: string[] = []; - const rhsSymbols = Array.from(einsumEquation.rhs.symbolToIndices.keys()); - const initProd = 'var prod = 1.0;'; - const initSum = 'var sum = 0.0;'; - const updateSum = 'sum += prod;'; - const reduceOpsSetIndices: string[] = []; - const reduceOpsLoopHeaders: string[] = []; - const reduceOpsLoopFooters: string[] = []; - const reduceOpCompute: string[] = []; - const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === rhsSymbols.length; - einsumEquation.symbolToInfo.forEach((info, symbol) => { - if (rhsSymbols.includes(symbol)) { - const outputIndex = rhsSymbols.indexOf(symbol); - einsumEquation.lhs.forEach((term, i) => { - if (info.inputIndices.includes(i)) { - const indices = term.symbolToIndices.get(symbol); - if (indices === undefined) { - throw new Error('Invalid symbol error'); +const appendMax = (name: string): string => name + '_max'; + +const createEinsumProgramInfo = + (enableInputShapesUniforms: readonly boolean[], inputShapes: Array, dataType: number, + einsumEquation: EinsumEquation, outputShape: readonly number[]): ProgramInfo => { + const shapeOrRanks = inputShapes.map((dims, index) => enableInputShapesUniforms[index] ? dims.length : dims); + const inputVars = shapeOrRanks.map((shapeOrRank, index) => inputVariable(`input${index}`, dataType, shapeOrRank)); + const outputSize = ShapeUtil.size(outputShape); + const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); + const outputShapeOrRank = enableOutputShapesUniforms ? 
outputShape.length : outputShape; + const output = outputVariable('output', dataType, outputShapeOrRank); + const uniformsSymbols = + [...einsumEquation.symbolToInfo.keys()].filter((symbol) => !einsumEquation.rhs.symbolToIndices.has(symbol)); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const idxCopy: string[] = []; + const initProd = 'var prod = 1.0;'; + const initSum = 'var sum = 0.0;'; + const updateSum = 'sum += prod;'; + const reduceOpsSetIndices: string[] = []; + const reduceOpsLoopHeaders: string[] = []; + const reduceOpsLoopFooters: string[] = []; + const reduceOpCompute: string[] = []; + const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === einsumEquation.rhs.symbolToIndices.size; + einsumEquation.symbolToInfo.forEach((info, symbol) => { + if (einsumEquation.rhs.symbolToIndices.has(symbol)) { + const outputIndex = einsumEquation.rhs.symbolToIndices.get(symbol)?.[0]; + if (outputIndex !== undefined) { + einsumEquation.lhs.forEach((term, i) => { + if (info.inputIndices.includes(i)) { + const indices = term.symbolToIndices.get(symbol); + if (indices === undefined) { + throw new Error('Invalid symbol error'); + } + indices.forEach((index) => { + idxCopy.push(`${ + inputVars[i].indicesSet( + `input${i}Indices`, index, output.indicesGet('outputIndices', outputIndex))}`); + }); + } + }); + } + } else { + einsumEquation.lhs.forEach((term, i) => { + if (info.inputIndices.includes(i)) { + const indices = term.symbolToIndices.get(symbol); + if (indices === undefined) { + throw new Error('Invalid symbol error'); + } + indices.forEach((index) => { + reduceOpsSetIndices.push(`${inputVars[i].indicesSet(`input${i}Indices`, index, `${symbol}`)}`); + }); + reduceOpCompute.push(`prod *= ${inputVars[i].getByIndices(`input${i}Indices`)};`); + } + }); + reduceOpsLoopHeaders.push( + `for(var ${symbol}: u32 = 0; ${symbol} < uniforms.${appendMax(symbol)}; ${symbol}++) {`); + reduceOpsLoopFooters.push('}'); } - indices.forEach((index) => { - idxCopy.push(`${ - inputVars[i].indicesSet(`input${i}Indices`, index, output.indicesGet('outputIndices', outputIndex))}`); - }); - } - }); - } else { - einsumEquation.lhs.forEach((term, i) => { - const info = einsumEquation.symbolToInfo.get(symbol); - if (info === undefined) { - throw new Error('Invalid symbol error'); - } - if (info.inputIndices.includes(i)) { - const indices = term.symbolToIndices.get(symbol); - if (indices === undefined) { - throw new Error('Invalid symbol error'); + }); + const reduceOps = isReduceOpsWithoutLoop ? 
+ [ + ...idxCopy, + `let sum = ${inputVars.map((inputVar, i) => inputVar.getByIndices(`input${i}Indices`)).join(' * ')};` + ] : + [ + ...idxCopy, + initSum, + ...reduceOpsLoopHeaders, + ...reduceOpsSetIndices, + initProd, + ...reduceOpCompute, + updateSum, + ...reduceOpsLoopFooters, + ]; + return ` + ${ + shaderHelper + .registerUniforms(uniformsSymbols.map((symbol) => ({name: `${appendMax(symbol)}`, type: 'u32'}))) + .registerUniform('outputSize', 'u32') + .declareVariables(...inputVars, output)} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} + var outputIndices = ${output.offsetToIndices('global_idx')}; + ${inputVars.map((_var, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')} + ${reduceOps.join('\n')}; + ${output.setByOffset('global_idx', 'sum')}; + }`; + }; + return { + name: 'Einsum', + shaderCache: { + hint: einsumEquation.equation, + inputDependencies: enableInputShapesUniforms.map((enableShapeUniform) => enableShapeUniform ? 'rank' : 'dims') + }, + getRunData: () => { + // The symbols from uniformSymbols array are guaranteed to exist in einsumEquations.symbolToInfo map. The + // filter is added to make sure that dimValue is never 0. + const programUniformsInit: ProgramUniform[] = + uniformsSymbols.filter((symbol) => einsumEquation.symbolToInfo.has(symbol)) + .map((symbol) => ({type: 'uint32', data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0})); + programUniformsInit.push({type: 'uint32', data: outputSize}); + const programUniforms: ProgramUniform[] = + inputShapes.filter((_, index) => enableInputShapesUniforms[index]) + .map((dims, _) => [...createTensorShapeVariables(dims)]) + .reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit); + if (enableOutputShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(outputShape)); } - indices.forEach((index) => { - reduceOpsSetIndices.push(`${inputVars[i].indicesSet(`input${i}Indices`, index, `${symbol}`)}`); + return ({ + outputs: [{dims: outputShape, dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }); - reduceOpCompute.push(`prod *= ${inputVars[i].getByIndices(`input${i}Indices`)};`); - } - }); - reduceOpsLoopHeaders.push(`for(var ${symbol}: u32 = 0; ${symbol} < ${ - einsumEquation.symbolToInfo.get(symbol)?.dimValue}; ${symbol}++) {`); - reduceOpsLoopFooters.push('}'); - } - }); - const reduceOps = isReduceOpsWithoutLoop ? 
- [ - ...idxCopy, - `let sum = ${inputVars.map((inputVar, i) => inputVar.getByIndices(`input${i}Indices`)).join(' * ')};` - ] : - [ - ...idxCopy, - initSum, - ...reduceOpsLoopHeaders, - ...reduceOpsSetIndices, - initProd, - ...reduceOpCompute, - updateSum, - ...reduceOpsLoopFooters, - ]; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(...inputVars, output)} - - ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - var outputIndices = ${output.offsetToIndices('global_idx')}; - ${inputVars.map((_var, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')} - ${reduceOps.join('\n')}; - ${output.setByOffset('global_idx', 'sum')}; - }`; - return { - name: 'Einsum', - shaderCache: {hint: einsumEquation.equation}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} - }), - getShaderSource, - }; -}; + }, + getShaderSource, + }; + }; export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => { const einsumEquation = new EinsumEquation(context.inputs, attributes.equation); - context.compute(createEinsumProgramInfo(context.inputs, einsumEquation)); + const enableInputShapesUniforms = context.inputs.map((input, _) => enableShapesUniforms(input.dims.length)); + const outputShape = einsumEquation.outputDims; + const inputShapes = context.inputs.map((input, _) => input.dims); + context.compute(createEinsumProgramInfo( + enableInputShapesUniforms, inputShapes, context.inputs[0].dataType, einsumEquation, outputShape)); }; export const parseEinsumAttributes = (attributes: Record): EinsumAttributes => { diff --git a/js/web/test/data/ops/einsum.jsonc b/js/web/test/data/ops/einsum.jsonc index baf30cf982148..45bba6a121bd1 100644 --- a/js/web/test/data/ops/einsum.jsonc +++ b/js/web/test/data/ops/einsum.jsonc @@ -171,7 +171,7 @@ ], "cases": [ { - "name": "Diagonal elementwise multiplication", + "name": "Diagonal elements dot product", "inputs": [ { "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], @@ -210,7 +210,7 @@ ], "cases": [ { - "name": "Dotproduct", + "name": "diagonal elements multiplication", "inputs": [ { "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], @@ -233,6 +233,240 @@ } ] }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "ij,ij -> ij", + "type": "string" + } + ], + "cases": [ + { + "name": "Elementwise multiplication", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [1, 0, 0, 0, 1, 0, 0, 0, 1], + "dims": [3, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 0, 0, 0, 5, 0, 0, 0, 9], + "dims": [3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "i,i", + "type": "string" + } + ], + "cases": [ + { + "name": "Dot product/scalar product", + "inputs": [ + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + }, + { + "data": [1, 1, 1], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [6], + "dims": [], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "i,j->ij", + "type": "string" + } + ], + 
"cases": [ + { + "name": "outer product", + "inputs": [ + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + }, + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 2, 4, 6, 3, 6, 9], + "dims": [3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "ij,ij -> ij", + "type": "string" + } + ], + "cases": [ + { + "name": "Elementwise multiplication", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [1, 0, 0, 0, 1, 0, 0, 0, 1], + "dims": [3, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 0, 0, 0, 5, 0, 0, 0, 9], + "dims": [3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "i,i", + "type": "string" + } + ], + "cases": [ + { + "name": "Dot product/scalar product", + "inputs": [ + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + }, + { + "data": [1, 1, 1], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [6], + "dims": [], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "i,j->ij", + "type": "string" + } + ], + "cases": [ + { + "name": "outer product", + "inputs": [ + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + }, + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 2, 4, 6, 3, 6, 9], + "dims": [3, 3], + "type": "float32" + } + ] + } + ] + }, { "name": "einsum", "operator": "Einsum", @@ -249,7 +483,7 @@ ], "cases": [ { - "name": "Multiply", + "name": "Multiply (2,3) X (3,4) -> (2,4)", "inputs": [ { "data": [1, 2, 3, 4, 5, 6], @@ -269,6 +503,28 @@ "type": "float32" } ] + }, + { + "name": "Multiply (2,6) X (6,4) -> (2,4)", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + "dims": [2, 6], + "type": "float32" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23], + "dims": [6, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [220, 235, 250, 265, 580, 631, 682, 733], + "dims": [2, 4], + "type": "float32" + } + ] } ] }, @@ -631,5 +887,73 @@ ] } ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "ijk->ikj", + "type": "string" + } + ], + "cases": [ + { + "name": "Transpose with 3 dims", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [1, 2, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 4, 2, 5, 3, 6], + "dims": [1, 3, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "...ij->...ji", + "type": "string" + } + ], + "cases": [ + { + "name": "Transpose with ellipsis with input/output dims > 4", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [1, 1, 1, 2, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 4, 2, 5, 3, 6], + "dims": [1, 1, 1, 3, 2], + "type": "float32" + } + ] + } + ] } ]