diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 980a1dfbb404a..8d3ea403fb74b 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -202,4 +202,4 @@ endif() if (onnxruntime_USE_AZURE) include(onnxruntime_providers_azure.cmake) -endif() \ No newline at end of file +endif() diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index c19e703fb2128..003012f8da071 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -37,6 +37,7 @@ "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding_spec.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_matmul.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_slice.cc" ) endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio @@ -122,6 +123,7 @@ # CUDA 11.3+ supports parallel compilation # https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-guiding-compiler-driver-threads if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.3) + option(onnxruntime_NVCC_THREADS "Number of threads that NVCC can use for compilation." 1) target_compile_options(${target} PRIVATE "$<$:SHELL:--threads \"${onnxruntime_NVCC_THREADS}\">") endif() if (UNIX) diff --git a/cmake/onnxruntime_providers_dml.cmake b/cmake/onnxruntime_providers_dml.cmake index 4339c42cd4254..01b0bda9fea6b 100644 --- a/cmake/onnxruntime_providers_dml.cmake +++ b/cmake/onnxruntime_providers_dml.cmake @@ -56,13 +56,13 @@ if (GDK_PLATFORM STREQUAL Scarlett) target_link_libraries(onnxruntime_providers_dml PRIVATE ${gdk_dx_libs}) else() - target_link_libraries(onnxruntime_providers_dml PRIVATE dxguid.lib d3d12.lib dxgi.lib) + target_link_libraries(onnxruntime_providers_dml PRIVATE dxguid.lib d3d12.lib dxgi.lib dxcore.lib) endif() target_link_libraries(onnxruntime_providers_dml PRIVATE delayimp.lib) if (NOT GDK_PLATFORM) - set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:DirectML.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:dxgi.dll /DELAYLOAD:api-ms-win-core-com-l1-1-0.dll /DELAYLOAD:shlwapi.dll /DELAYLOAD:oleaut32.dll /ignore:4199") + set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:DirectML.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:dxgi.dll /DELAYLOAD:api-ms-win-core-com-l1-1-0.dll /DELAYLOAD:shlwapi.dll /DELAYLOAD:oleaut32.dll /DELAYLOAD:ext-ms-win-dxcore-l1-*.dll /ignore:4199") endif() target_compile_definitions(onnxruntime_providers_dml @@ -88,4 +88,4 @@ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() \ No newline at end of file + endif() diff --git a/docs/ORTModule_PythonOp_Notes.md b/docs/ORTModule_PythonOp_Notes.md new file mode 100644 index 0000000000000..1bb549f8fab9d --- /dev/null +++ b/docs/ORTModule_PythonOp_Notes.md @@ -0,0 +1,156 @@ +# ORTModule Custom Autograd Function Support + +## What is autograd Functions? + +`PyTorch` allows users to define customized operators (for its forward and backward implementations) [PyTorch: Defining New autograd Functions](https://github.com/pytorch/tutorials/blob/d98606855d3c8c5bd78d55b95717be5a02960363/beginner_source/examples_autograd/polynomial_custom_function.py#L25). 
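+
+For readers new to the API, here is a condensed sketch in the spirit of the linked tutorial (the class name and the particular polynomial are illustrative only): `forward` computes the result and stashes whatever `backward` will need, and `backward` returns one gradient per `forward` input.
+
+```python
+import torch
+
+class LegendrePolynomial3(torch.autograd.Function):
+    """y = 0.5 * (5 * x^3 - 3 * x), with a hand-written gradient."""
+
+    @staticmethod
+    def forward(ctx, x):
+        ctx.save_for_backward(x)  # stash inputs needed by backward
+        return 0.5 * (5 * x**3 - 3 * x)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (x,) = ctx.saved_tensors
+        return grad_output * 1.5 * (5 * x**2 - 1)  # dy/dx
+
+x = torch.randn(4, requires_grad=True)
+y = LegendrePolynomial3.apply(x)  # .apply() invokes the custom forward
+y.sum().backward()                # the custom backward produces x.grad
+```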
+
+There are many such use cases, and their number keeps growing as more optimized deep learning projects appear; to name just a few:
+- [NVIDIA/apex](https://github.com/NVIDIA/apex/blob/58acf96915eecd7e13adff61d2c389fba3efede2/apex/transformer/functional/fused_softmax.py#L21)
+- [NVIDIA/Megatron-LM](https://github.com/NVIDIA/Megatron-LM/blob/f7727433293427bef04858f67b2889fe9b177d88/megatron/core/tensor_parallel/mappings.py#L220C31-L220C31)
+- [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention/blob/3a9fe7b0faaa9d648394026c9c20231c07bf999d/flash_attn/flash_attn_interface.py#L429)
+- [openai/triton](https://github.com/openai/triton/blob/424e67e7275f0cb2cd231e7a4d17ff8570530b77/python/tutorials/06-fused-attention.py#L457)
+- ...
+
+These operators are used heavily in training and evaluation scenarios, which is exactly where ORTModule's capability overlaps with them.
+To fully unlock ORTModule's acceleration power, we need to tolerate and handle these customized operators
+across their whole lifecycle: the export to ONNX, the backward graph building, and their execution at runtime.
+
+## How does ORTModule support autograd.Function?
+
+Support is implemented through the `PythonOp`/`PythonOpGrad` operators introduced in the Microsoft domain of `ONNX Runtime`:
+- Map an autograd Function (`prim::PythonOp` in `PyTorch`) to `PythonOp` in `ONNX Runtime` during model export by [registering a customized export function](https://github.com/microsoft/onnxruntime/blob/c2bd5b70b29eb3c687c5497696e7b0a1930604d3/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py#L69C16-L69C16)
+  ```
+  class ScalarAndTupleFunction(torch.autograd.Function):
+      @staticmethod
+      def forward(ctx, input, alpha, beta, gamma):
+          ctx.save_for_backward(input)
+          ctx.alpha = alpha
+          ctx.beta = beta
+          ctx.gamma = gamma
+          return alpha * beta[0] * beta[1] * gamma * input.clamp(min=0)
+
+      @staticmethod
+      def backward(ctx, grad_output):
+          input, = ctx.saved_tensors
+          alpha = ctx.alpha
+          beta = ctx.beta
+          gamma = ctx.gamma
+          grad_input = grad_output.clone()
+          grad_input[input < 0] = 0
+          return alpha * beta[0] * beta[1] * gamma * grad_input, None, None, None
+  ```
+  The example above shows a customized function taking 4 inputs (besides `ctx`). The first input is a tensor, which the [exporter treats as an input of `PythonOp`](https://github.com/microsoft/onnxruntime/blob/c2bd5b70b29eb3c687c5497696e7b0a1930604d3/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py#L174);
+  the others are scalars, which the export function converts to constants and [stores
+  in `PythonOp`'s attributes](https://github.com/microsoft/onnxruntime/blob/c2bd5b70b29eb3c687c5497696e7b0a1930604d3/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py#L272). Note that if a non-tensor
+  input is one of the types "bool scalar, int scalar, float scalar, bool tuple, int tuple, float tuple", it is
+  stored in the corresponding attribute; otherwise, it is treated as an `object` and its address is stored in `input_pointer_scalars` (the [reference count is also increased](https://github.com/microsoft/onnxruntime/blob/c2bd5b70b29eb3c687c5497696e7b0a1930604d3/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py#L250C27-L250C27) to make sure the object stays alive during model runs). A sketch of how to inspect these attributes in the exported model follows this item.
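+
+  As a quick way to see this mapping, one can dump the ONNX model that ORTModule exports (for example via `DebugOptions(save_onnx=True, onnx_prefix=...)`) and list the attributes of each `PythonOp` node. A minimal sketch, assuming the model has already been saved; the file name below is illustrative only:
+
+  ```python
+  import onnx
+
+  # Load the training graph saved by ORTModule's debug options (adjust the path to your setup).
+  model = onnx.load("ortmodule_exported_training.onnx")
+  for node in model.graph.node:
+      if node.op_type == "PythonOp":
+          # Prints the attribute names, e.g. `input_pointer_scalars` for wrapped Python objects.
+          print(node.name, sorted(attr.name for attr in node.attribute))
+  ```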
+- The [PythonOp kernel](https://github.com/microsoft/onnxruntime/blob/c2bd5b70b29eb3c687c5497696e7b0a1930604d3/orttraining/orttraining/training_ops/cuda/torch/torch_custom_function_kernel.cc#L38) is responsible for running the user-defined `forward` method through the [forward runner](https://github.com/microsoft/onnxruntime/blob/c2bd5b70b29eb3c687c5497696e7b0a1930604d3/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py#L409).
+Similarly, the [PythonOpGrad kernel](https://github.com/microsoft/onnxruntime/blob/c2bd5b70b29eb3c687c5497696e7b0a1930604d3/orttraining/orttraining/training_ops/cuda/torch/torch_custom_function_kernel.cc#L49) is responsible for running the user-defined `backward` method through the [backward runner](https://github.com/microsoft/onnxruntime/blob/c2bd5b70b29eb3c687c5497696e7b0a1930604d3/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py#L554).
+
+For the training Python wheel, `PythonOp` support is enabled by default, so users do not need to do anything special. As long as the
+defined torch.autograd.Function works when run with `PyTorch`, it should also be runnable with `ORTModule`. If you need to enable or
+disable the support explicitly, refer to the [wiki](https://github.com/microsoft/onnxruntime/blob/main/docs/ORTModule_Training_Guidelines.md#ortmodule_enable_custom_autograd).
+
+
+
+## Known Issues and Workarounds
+
+PyTorch Versions
+- Minimum version 1.9 (which introduced "Support registering custom export for `prim::PythonOp` from torch.autograd.Function ([#55630](https://github.com/pytorch/pytorch/pull/55630)) ([#57600](https://github.com/pytorch/pytorch/pull/57600))").
+- If the static forward function has only one output, any PyTorch version from 1.9 onward is fine. Otherwise, a PyTorch version containing [this commit](https://github.com/pytorch/pytorch/commit/a55cae3d37e0f7852e391886c3904307caa4d06d) is required.
+- [`_Map_base::at` exception](https://github.com/pytorch/pytorch/issues/88286): the export fails with errors like this:
+  ```
+  RuntimeError: There was an error while exporting the PyTorch model to ONNX:
+
+  Traceback (most recent call last):
+    File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_utils.py", line 316, in get_exception_as_string
+      raise exception
+    File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 425, in _get_exported_model
+      torch.onnx.export(
+    File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/onnx/utils.py", line 506, in export
+      _export(
+    File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/onnx/utils.py", line 1548, in _export
+      graph, params_dict, torch_out = _model_to_graph(
+    File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
+      graph, params, torch_out, module = _create_jit_graph(model, args)
+    File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/onnx/utils.py", line 989, in _create_jit_graph
+      graph, torch_out = _trace_and_get_graph_from_model(model, args)
+    File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/onnx/utils.py", line 893, in _trace_and_get_graph_from_model
+      trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
+    File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/jit/_trace.py", line 1268, in _get_trace_graph
+      outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
+    ...
+ File "/opt/conda/envs/ptca/lib/python3.8/site-packages/deepspeed-0.9.5+95680ca-py3.8.egg/deepspeed/runtime/zero/parameter_offload.py", line 632, in _ort_post_forward_module_hook + a = ORTPostForwardwardFunction.apply(module, _post_forward_module_hook, _ort_run_before_backward_function, len(input), len(output), *input_and_output) + File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/autograd/function.py", line 506, in apply + return super().apply(*args, **kwargs) # type: ignore[misc] + RuntimeError: _Map_base::at + ``` + Resolution: upgrade `PyTorch` to new versions containing [this commit](https://github.com/thiagocrepaldi/pytorch/commit/3d3da109e3afa617c513e78aa999f5a1f44ffbce), when export param `autograd_inlining` is [set to false](https://github.com/microsoft/onnxruntime/blob/0e2782438a65b97919f15af14d2a4ada361157b6/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py#L387C26-L387C26) to skip this error. +- "Tried to trace <__torch__.torch.classes.c10d.ProcessGroup object at 0x2969c520> but it is not part of the active trace" + This usually happens when torch.autograd.Function's forward function used `PyTorch` collective calls and pass the group explicitly. + ``` + RuntimeError: There was an error while exporting the PyTorch model to ONNX: + + Traceback (most recent call last): + File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_utils.py", line 324, in get_exception_as_string + raise exception + File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 342, in _get_exported_model + torch.onnx.export( + File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 507, in export + _export( + File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 1567, in _export + graph, params_dict, torch_out = _model_to_graph( + File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 1124, in _model_to_graph + graph, params, torch_out, module = _create_jit_graph(model, args) + File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 1000, in _create_jit_graph + graph, torch_out = _trace_and_get_graph_from_model(model, args) + File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 904, in _trace_and_get_graph_from_model + trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph( + File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/jit/_trace.py", line 1269, in _get_trace_graph + outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs) + File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/jit/_trace.py", line 128, in forward + graph, out = torch._C._create_graph_by_tracing( + ... + File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/zero/parameter_offload.py", line 640, in _ort_pre_forward_module_hook + rets = ORTPreForwardwardFunction.apply(self, module, _ort_run_after_backward_function, *inputs) + ... + File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/zero/parameter_offload.py", line 823, in pre_sub_module_forward_function + param_coordinator.fetch_sub_module(sub_module, forward=True) + ... 
+ File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2841, in all_gather_into_tensor + work = group._allgather_base(output_tensor, input_tensor) + RuntimeError: Tried to trace <__torch__.torch.classes.c10d.ProcessGroup object at 0x56250ad114a0> but it is not part of the active trace. Modules that are called during a trace must be registered as submodules of the thing being traced. + ``` + Resolution: modify the autograd.Function, to skip the run the collection operator during onnx export, here is an example. + ```python + # Pre + def allgather_fn(output_tensor, input_tensor, group=None, async_op=False, debug=get_caller_func()): + return torch.distributed.all_gather_into_tensor(output_tensor, input_tensor, group=group, async_op=async_op, debug=debug) + + # Workaround + from typing import Any, List + class DummyWork(torch.distributed.distributed_c10d.Work): + def is_completed(self) -> bool: + return True + def is_success(self) -> bool: + return True + def exception(self) -> Any: + return None + def wait(self, timeout: timedelta = timedelta) -> bool: + return True + def source_rank(self) -> int: + return 0 + def _source_rank(self) -> int: + return 0 + def result(self) -> List[torch.Tensor]: + return [] + def synchronize(self): + pass + + def allgather_fn(output_tensor, input_tensor, group=None, async_op=False, debug=get_caller_func()): + if torch.onnx.is_in_onnx_export(): + return DummyWork() + + return torch.distributed.all_gather_into_tensor(output_tensor, input_tensor, group=group, async_op=async_op, debug=debug) + ``` diff --git a/include/onnxruntime/core/providers/dml/dml_provider_factory.h b/include/onnxruntime/core/providers/dml/dml_provider_factory.h index 0782d2d9ed760..dd4ffb835d51c 100644 --- a/include/onnxruntime/core/providers/dml/dml_provider_factory.h +++ b/include/onnxruntime/core/providers/dml/dml_provider_factory.h @@ -30,6 +30,31 @@ typedef struct IDMLDevice IDMLDevice; extern "C" { #endif +enum OrtDmlPerformancePreference { + Default = 0, + HighPerformance = 1, + MinimumPower = 2 +}; + +enum OrtDmlDeviceFilter : uint32_t { + Any = 0xffffffff, + Gpu = 1 << 0, + Npu = 1 << 1, +}; + +inline OrtDmlDeviceFilter operator~(OrtDmlDeviceFilter a) { return (OrtDmlDeviceFilter) ~(int)a; } +inline OrtDmlDeviceFilter operator|(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter)((int)a | (int)b); } +inline OrtDmlDeviceFilter operator&(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter)((int)a & (int)b); } +inline OrtDmlDeviceFilter operator^(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter)((int)a ^ (int)b); } +inline OrtDmlDeviceFilter& operator|=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a |= (int)b); } +inline OrtDmlDeviceFilter& operator&=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a &= (int)b); } +inline OrtDmlDeviceFilter& operator^=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a ^= (int)b); } + +struct OrtDmlDeviceOptions { + OrtDmlPerformancePreference Preference; + OrtDmlDeviceFilter Filter; +}; + /** * [[deprecated]] * This export is deprecated. @@ -99,6 +124,13 @@ struct OrtDmlApi { * This API gets the D3D12 resource when an OrtValue has been allocated by the DML EP. 
*/ ORT_API2_STATUS(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* provider, _In_ void* dml_resource, _Out_ ID3D12Resource** d3d_resource); + + /** + * SessionOptionsAppendExecutionProvider_DML2 + * Creates a DirectML Execution Provider given the supplied device options that contain a performance preference + * (high power, low power, or defult) and a device filter (None, GPU, or NPU). + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts); }; #ifdef __cplusplus diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index cf6d25e61acf7..5d66caf77f08f 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -259,10 +259,6 @@ export class WebGpuBackend { run(program: ProgramInfo, inputTensorViews: readonly TensorView[], outputIndices: readonly number[], createKernelOutput: (index: number, dataType: number, dims: readonly number[]) => TensorView, createIntermediateOutput: (dataType: number, dims: readonly number[]) => TensorView): TensorView[] { - if (inputTensorViews.length !== program.inputTypes.length) { - throw new Error(`Input size must be equal to ${program.inputTypes.length}.`); - } - // create info for inputs const inputDatas: GpuData[] = []; for (let i = 0; i < inputTensorViews.length; ++i) { @@ -277,7 +273,7 @@ export class WebGpuBackend { const key = getProgramInfoUniqueKey(program, inputTensorViews); let artifact = this.programManager.getArtifact(key); - const {outputs, dispatchGroup, variables} = program.getRunData(inputTensorViews); + const {outputs, dispatchGroup, programUniforms} = program.getRunData(inputTensorViews); // check output indices const validatedOutputIndices = outputIndices.length === 0 ? outputs.map((_, i) => i) : outputIndices; @@ -328,12 +324,12 @@ export class WebGpuBackend { // TODO: add cache for uniform (is it necessary?) // let uniformBufferBinding: GPUBindingResource|undefined; - if (variables) { + if (programUniforms) { let currentOffset = 0; let preLength = 0; const offsets: number[] = []; let maxAlignmentOfField = 1; - variables.forEach(v => { + programUniforms.forEach(v => { const data = typeof v.data === 'number' ? [v.data] : v.data; // https://www.w3.org/TR/WGSL/#alignof let baseAlignment: number; @@ -374,7 +370,7 @@ export class WebGpuBackend { currentOffset = Math.ceil(currentOffset / maxAlignmentOfField) * maxAlignmentOfField; const arrayBuffer = new ArrayBuffer(currentOffset); - variables.forEach((v, i) => { + programUniforms.forEach((v, i) => { const offset = offsets[i]; const data = typeof v.data === 'number' ? [v.data] : v.data; if (v.type === 'int32') { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 389c5c725b391..01ddca520deed 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -22,7 +22,7 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; -import {GpuDataType, ProgramInfo} from '../../types'; +import {ProgramInfo} from '../../types'; import {tensorTypeToWsglStorageType} from '../common'; import {ConvAttributes} from '../conv'; @@ -213,11 +213,9 @@ export const createConv2DMatMulProgramInfo = return { name: 'Conv2DMatMul', - inputTypes: hasBias ? 
[GpuDataType.default, GpuDataType.default, GpuDataType.default] : - [GpuDataType.default, GpuDataType.default], shaderCache: {hint: attributes.cacheKey}, getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, }), getShaderSource: () => ` diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index 32bff2f7586b9..840360223c75a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -22,7 +22,7 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; -import {GpuDataType, ProgramInfo} from '../../types'; +import {ProgramInfo} from '../../types'; import {ConvTransposeAttributes} from '../conv-transpose'; import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; @@ -200,11 +200,9 @@ export const createConv2DTransposeMatMulProgramInfo = } return { name: 'Conv2DTransposeMatMul', - inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] : - [GpuDataType.default, GpuDataType.default], shaderCache: {hint: attributes.cacheKey}, getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]} }), getShaderSource: () => ` diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index 414abe64eba9e..2e6392aada454 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -20,7 +20,7 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; -import {GpuDataType, ProgramInfo} from '../../types'; +import {ProgramInfo} from '../../types'; import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; @@ -260,15 +260,12 @@ export const createConvTranspose2DProgramInfo = const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); return { name: 'ConvTranspose2D', - inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] : - [GpuDataType.default, GpuDataType.default], shaderCache: {hint: attributes.cacheKey}, getRunData: () => ({ dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, outputs: [{ dims: squeezeOutputShapeFunction ? 
squeezeOutputShapeFunction(outputShape) : outputShape, - dataType: inputs[0].dataType, - gpuDataType: GpuDataType.default + dataType: inputs[0].dataType }] }), getShaderSource: (shaderHelper: ShaderHelper) => createConvTranspose2DOpProgramShaderSource( diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index c7a0b701ee86c..1032869412462 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -21,7 +21,7 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; -import {GpuDataType, ProgramInfo} from '../../types'; +import {ProgramInfo} from '../../types'; import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; import {getActicationSnippet, InternalActivationAttributes} from '../fuse-utils'; @@ -481,11 +481,9 @@ export const createMatmulProgramInfo = ${batchDims.impl()}`; return { name: 'MatMul', - inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] : - [GpuDataType.default, GpuDataType.default], shaderCache: {hint: activationAttributes.activationCacheKey}, getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]} }), getShaderSource, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts b/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts index 18bde55db6244..e2b8412000ef9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts @@ -3,7 +3,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {inputVariable, outputVariable, ShaderHelper} from './common'; @@ -51,9 +51,8 @@ const createBiasAddProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => return { name: 'BiasAdd', - inputTypes: Array(inputs.length).fill(GpuDataType.default), getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} }), getShaderSource, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts index fc171367a7071..14eefc344f3c0 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts @@ -3,7 +3,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {inputVariable, outputVariable, ShaderHelper} from './common'; import {erfImpl} from './unary-op'; @@ -58,9 +58,8 @@ const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramI return { name: 'BiasSplitGelu', - inputTypes: [GpuDataType.default, GpuDataType.default], getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], 
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} }), getShaderSource, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index e57b8869f395d..eab571e87f5f5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -4,7 +4,7 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {BroadcastUtil, ShapeUtil} from '../../util'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {inputVariable, outputVariable, ShaderHelper} from './common'; @@ -175,13 +175,12 @@ const createBinaryOpProgramInfo = return { name, - inputTypes: [GpuDataType.default, GpuDataType.default], shaderCache: {hint: cacheKey}, getShaderSource: (shaderHelper) => createBinaryOpProgramShader( shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, funcCall, a.dataType, b.dataType, outputDataType, additionalImplementation), getRunData: () => ({ - outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}], + outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)} }), }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 4354543aea713..55ef9b3366abb 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -3,7 +3,7 @@ import {DataType} from '../../../wasm-common'; import {ShapeUtil} from '../../util'; -import {ProgramVariable} from '../types'; +import {ProgramUniform} from '../types'; /** * constant value for a workgroup size. @@ -259,7 +259,7 @@ export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 = }; export const createTensorShapeVariables = (dims: readonly number[]): - ProgramVariable[] => [{type: 'uint32', data: dims}, {type: 'uint32', data: ShapeUtil.computeStrides(dims)}]; + ProgramUniform[] => [{type: 'uint32', data: dims}, {type: 'uint32', data: ShapeUtil.computeStrides(dims)}]; /** * A helper function to get a IndicesHelper for a given input or output. 
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 1b0505277b73c..4b5ca869f0dfb 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -4,7 +4,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; @@ -122,10 +122,9 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P }`; return { name: 'Concat', - inputTypes: Array(inputs.length).fill(GpuDataType.default), shaderCache: {hint: `${axis}`}, getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} }), getShaderSource, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 21a15e82c2750..7abf022928ade 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -3,7 +3,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {GpuDataType, ProgramInfo} from '../types'; +import {ProgramInfo} from '../types'; import {inputVariable, outputVariable, ShaderHelper} from './common'; import {calculateOutputShape, ConvAttributes} from './conv'; @@ -85,14 +85,11 @@ export const createGroupedConvProgramInfo = }`; return { name: 'GroupedConv', - inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] : - [GpuDataType.default, GpuDataType.default], shaderCache: {hint: attributes.cacheKey}, getRunData: () => ({ outputs: [{ dims: squeezeOutputShapeFunction ? 
squeezeOutputShapeFunction(outputShape) : outputShape, - dataType: inputs[0].dataType, - gpuDataType: GpuDataType.default + dataType: inputs[0].dataType }], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, }), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts index c54ead10ec08f..357eb5c0b84ad 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts @@ -4,7 +4,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; @@ -260,10 +260,9 @@ const createEinsumProgramInfo = (inputs: readonly TensorView[], einsumEquation: }`; return { name: 'Einsum', - inputTypes: Array(inputs.length).fill(GpuDataType.default), shaderCache: {hint: einsumEquation.equation}, getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} }), getShaderSource, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts index 0e76501795ed2..5680af4787b6a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -3,7 +3,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {inputVariable, outputVariable, ShaderHelper} from './common'; @@ -70,11 +70,10 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => }`; return { name: 'Expand', - inputTypes: [GpuDataType.default], shaderCache: {hint: `${outputShape}`}, getShaderSource, getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} }) }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts index aef45dd70e31e..9924a50e2ae6f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts @@ -4,7 +4,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {inputVariable, outputVariable, ShaderHelper} from './common'; @@ -87,10 +87,9 @@ const createGatherElementsProgramInfo = return { name: 'GatherElements', - inputTypes: [GpuDataType.default, GpuDataType.default], shaderCache: {hint: attributes.cacheKey}, getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} }), getShaderSource, diff --git 
a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index 561b9f9cca2b7..fdcd64abfe4e7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -4,7 +4,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {inputVariable, outputVariable, ShaderHelper} from './common'; @@ -72,11 +72,10 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath }`; return { name: 'Gather', - inputTypes: [GpuDataType.default, GpuDataType.default], shaderCache: {hint: attributes.cacheKey}, getRunData: () => ({ outputs: [ - {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}, + {dims: outputShape, dataType: inputs[0].dataType}, ], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} }), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts index 9ec84333b426a..6e9dee41ce488 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts @@ -4,7 +4,7 @@ import {TensorView} from '../../tensor-view'; import {GemmUtil, ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; @@ -112,11 +112,9 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt }`; return { name: 'Gemm', - inputTypes: inputs.length === 3 ? 
[GpuDataType.default, GpuDataType.default, GpuDataType.default] : - [GpuDataType.default, GpuDataType.default], shaderCache: {hint: attributes.cacheKey}, getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} }), getShaderSource, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index d55aa835464ba..0c39152f56dad 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -4,7 +4,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common'; @@ -14,8 +14,7 @@ export interface InstanceNormAttributes extends AttributeWithCacheKey { } const metadata = { - name: 'InstanceNormalization', - inputTypes: [GpuDataType.default, GpuDataType.default, GpuDataType.default], + name: 'InstanceNormalization' }; const createInstanceNormProgramInfo = @@ -104,7 +103,7 @@ const createInstanceNormProgramInfo = shaderCache: {hint: attributes.cacheKey}, getRunData: () => ({ outputs: [ - {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}, + {dims: outputShape, dataType: inputs[0].dataType}, ], dispatchGroup: {x: normCount} }), @@ -169,7 +168,7 @@ const createInstanceNormNHWCProgramInfo = shaderCache: {hint: attributes.cacheKey}, getRunData: () => ({ outputs: [ - {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}, + {dims: outputShape, dataType: inputs[0].dataType}, ], dispatchGroup: {x: Math.ceil(normCount / 64 /* workgroup size */)} }), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index 3fdf935b99b30..186a9999e53f2 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -5,7 +5,7 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; @@ -96,22 +96,20 @@ const createLayerNormProgramInfo = ${hasMeanDataOutput ? 'meanDataOutput[global_idx] = mean' : ''}; ${hasInvStdOutput ? 
'invStdOutput[global_idx] = 1 / meanSquare' : ''}; }`; - const outputs = [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}]; + const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; if (hasMeanDataOutput) { outputs.push( - {dims: meanInvStdDevDim, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}, + {dims: meanInvStdDevDim, dataType: inputs[0].dataType}, ); } if (hasInvStdOutput) { outputs.push( - {dims: meanInvStdDevDim, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}, + {dims: meanInvStdDevDim, dataType: inputs[0].dataType}, ); } return { name: 'LayerNormalization', - inputTypes: inputs.length === 2 ? [GpuDataType.default, GpuDataType.default] : - [GpuDataType.default, GpuDataType.default, GpuDataType.default], shaderCache: {hint: `${attributes.cacheKey}|${outputCount}|${inputs.length}`}, getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(normCount / 64 /* workgroup size */)}}), getShaderSource, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts index 021ae9a896ce6..180dab92a453a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -5,7 +5,7 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; @@ -198,10 +198,9 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr const outputShape = ShapeUtil.padShape(inputs[0].dims.slice(), attributes.pads); return { name: 'Pad', - inputTypes: [GpuDataType.default], shaderCache: {hint: attributes.cacheKey}, getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)} }), getShaderSource: shaderHelper => generatePadCode(shaderHelper, inputs, attributes, 'f32'), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts index c81afa47a5e23..05f02b07c4d89 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -4,7 +4,7 @@ import {TensorView} from '../../tensor-view'; import {PoolConvUtil, ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; @@ -254,10 +254,9 @@ const createAveragePoolProgramInfo = } return { name, - inputTypes: [GpuDataType.default], shaderCache: {hint: attributes.cacheKey}, getRunData: () => ({ - outputs: [{dims: outputShape, dataType: input.dataType, gpuDataType: GpuDataType.default}], + outputs: [{dims: outputShape, dataType: input.dataType}], dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)} }), getShaderSource: shaderHelper => @@ -320,10 +319,9 @@ const createMaxPoolProgramInfo = const x = inputVariable('x', input.dataType, input.dims); return { name, - 
inputTypes: [GpuDataType.default], shaderCache: {hint: attributes.cacheKey}, getRunData: () => ({ - outputs: [{dims: outputShape, dataType: input.dataType, gpuDataType: GpuDataType.default}], + outputs: [{dims: outputShape, dataType: input.dataType}], dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)} }), getShaderSource: shaderHelper => diff --git a/js/web/lib/wasm/jsep/webgpu/ops/range.ts b/js/web/lib/wasm/jsep/webgpu/ops/range.ts index b857e6380c7c9..9cf66111bf707 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/range.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/range.ts @@ -4,7 +4,7 @@ import {env} from 'onnxruntime-common'; import {DataType} from '../../../wasm-common'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {outputVariable, ShaderHelper} from './common'; @@ -34,13 +34,11 @@ const createRangeProgramInfo = (start: number, limit: number, delta: number, dat }`; return { name: 'Range', - inputTypes: [], shaderCache: {hint: [start, limit, delta].map(x => x.toString()).join('_')}, getShaderSource, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType, gpuDataType: GpuDataType.default}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} - }) + getRunData: () => ( + {outputs: [{dims: outputShape, dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}}) }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts index 0003ccfb4f32d..44d6332852d2a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts @@ -5,7 +5,7 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo, ProgramShaderCacheInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramShaderCacheInfo} from '../types'; import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; @@ -97,11 +97,10 @@ export const createReduceProgramInfo = return { name, - inputTypes: [GpuDataType.default], shaderCache, getShaderSource, getRunData: () => ({ - outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}], + outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} }), }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index 00474645b2d2c..fed1dbcf51e9b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -5,7 +5,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; @@ -513,14 +513,13 @@ const createResizeProgramInfo = return { name: 'Resize', - inputTypes: [GpuDataType.default], shaderCache: { hint: `${attributes.cacheKey}|${opsetVersion}|${scales.length > 0 ? scales : ''}|${ sizes.length > 0 ? 
sizes : ''}` }, getShaderSource, getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputTensor.dataType, gpuDataType: GpuDataType.default}], + outputs: [{dims: outputShape, dataType: inputTensor.dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} }) }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts index b4478a33d391b..75e6a84cd6fcd 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts @@ -5,7 +5,7 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; @@ -135,20 +135,19 @@ const createSkipLayerNormProgramInfo = output[offset + i] = (output[offset + i] - mean) / variance * gamma[i] + ${hasBetaInput ? 'beta[i]' : '0.0'}; } }`; - const outputs = [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}]; + const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; if (outputCount > 1) { - outputs.push({dims: meanInvStdDevDim, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}); + outputs.push({dims: meanInvStdDevDim, dataType: inputs[0].dataType}); } if (outputCount > 2) { - outputs.push({dims: meanInvStdDevDim, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}); + outputs.push({dims: meanInvStdDevDim, dataType: inputs[0].dataType}); } if (outputCount > 3) { - outputs.push({dims: inputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}); + outputs.push({dims: inputShape, dataType: inputs[0].dataType}); } return { name: 'SkipLayerNormalization', - inputTypes: new Array(inputs.length).fill(GpuDataType.default), shaderCache: {hint: attributes.cacheKey}, getShaderSource, getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}}), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index 06d960525d05a..d607351f69b74 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -5,7 +5,7 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo, TensorInfo} from '../types'; +import {ComputeContext, ProgramInfo, TensorInfo} from '../types'; import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; @@ -137,8 +137,7 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice outputShape[axis] = Math.ceil((ends[axis] - starts[axis]) / steps[axis]); }); - const outputTensorInfo: - TensorInfo = {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}; + const outputTensorInfo: TensorInfo = {dims: outputShape, dataType: inputs[0].dataType}; const output = outputVariable('output', inputs[0].dataType, outputShape); const input = inputVariable('input', inputs[0].dataType, inputShape); @@ -161,7 +160,6 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: 
Slice }`; return { name: 'Slice', - inputTypes: [GpuDataType.default], shaderCache: {hint: `${attributes.cacheKey}|${inputs[4]?.dims ?? ''}`}, getShaderSource, getRunData: () => ({ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index 8d53e3311fa74..d4dbad79e613e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -8,7 +8,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; @@ -119,11 +119,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut }`; return { name: 'Softmax', - inputTypes: [GpuDataType.default], - getRunData: () => ({ - outputs: [{dims: shape, dataType: input.dataType, gpuDataType: GpuDataType.default}], - dispatchGroup: {x: rows} - }), + getRunData: () => ({outputs: [{dims: shape, dataType: input.dataType}], dispatchGroup: {x: rows}}), getShaderSource, }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 7c3d16ba896bc..fd60d81b87ae1 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -4,7 +4,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo, TensorInfo} from '../types'; +import {ComputeContext, ProgramInfo, TensorInfo} from '../types'; import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; @@ -81,7 +81,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split outputShape[attributes.axis] = attributes.splitSizes[i]; outputShapes.push(outputShape); outputs[i] = outputVariable(`output${i}`, dataType, outputShapes[i]); - outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}); + outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType}); } const indicesAxis = rank < 2 ? 
'indices' : `indices[${adjustedAxis}]`; const getShaderSource = (shaderHelper: ShaderHelper) => ` @@ -102,7 +102,6 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split }`; return { name: 'Split', - inputTypes: [GpuDataType.default], shaderCache: {hint: attributes.cacheKey}, getShaderSource, getRunData: () => ({ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts index a9f3b4e7812da..e294541a775ca 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts @@ -4,7 +4,7 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {inputVariable, outputVariable, ShaderHelper} from './common'; @@ -74,10 +74,9 @@ export const createTileProgramInfo = (inputs: readonly TensorView[]): ProgramInf return { name: 'Tile', - inputTypes: [GpuDataType.default], shaderCache: {hint: `${repeats}`}, getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, }), getShaderSource, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index e436b4bbb380d..fe556a7fd8552 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -4,7 +4,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; @@ -56,15 +56,14 @@ export const createTransposeProgramInfo = }`; return { name: 'Transpose', - inputTypes: [GpuDataType.default], shaderCache: {hint: `${permAttr}`, inputDependencies: ['rank']}, getRunData: (inputs) => { const outputShape = getOutputShape(inputs[0].dims, perm); const outputSize = ShapeUtil.size(outputShape); return { - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - variables: [ + programUniforms: [ {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index dedf8ffdf6e74..bead3e72f63c7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -5,7 +5,7 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {MAX_CLIP, MIN_CLIP, ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common'; @@ -45,12 +45,11 @@ const 
createElementwiseProgramInfo = (input: TensorView, name: string, funcCall: ElementwiseFunctionCall, additionalImplementation?: string, cacheKey?: string, outputDataType: number = input.dataType): ProgramInfo => ({ name, - inputTypes: [GpuDataType.default], shaderCache: {hint: cacheKey}, getShaderSource: shaderHelper => createElementwiseProgramShader( shaderHelper, ShapeUtil.size(input.dims), input.dataType, outputDataType, funcCall, additionalImplementation), getRunData: (inputTensors) => ({ - outputs: [{dims: input.dims, dataType: outputDataType, gpuDataType: GpuDataType.default}], + outputs: [{dims: input.dims, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */)} }) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts index d481a1636f3a3..6f66dd86b4088 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -4,7 +4,7 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {BroadcastUtil, ShapeUtil} from '../../util'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {inputVariable, outputVariable, ShaderHelper} from './common'; @@ -92,11 +92,10 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => return { name: 'Where', - inputTypes: [GpuDataType.default, GpuDataType.default, GpuDataType.default], getShaderSource: (shaderHelper) => createWhereOpProgramShader(shaderHelper, inputs, outputShape, isBroadcast, outputDataType), getRunData: () => ({ - outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}], + outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)} }), }; diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index 032bb723a6332..23fa33a9bba8f 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -21,11 +21,10 @@ export interface GpuData { export interface TensorInfo { dims: readonly number[]; dataType: number; - gpuDataType: GpuDataType; } -export interface ProgramVariable { +export interface ProgramUniform { type: 'int32'|'float32'|'uint32'; data: number|readonly number[]; } @@ -88,11 +87,6 @@ export interface ProgramInfo { */ name: string; - /** - * gpu data types for each input - */ - inputTypes: GpuDataType[]; - /** * an optional object describing the cache information of the program shader. 
* @@ -115,7 +109,7 @@ export interface ProgramInfo { getRunData: (inputs: readonly TensorView[]) => { outputs: readonly TensorInfo[]; dispatchGroup: {x: number; y?: number; z?: number}; - variables?: readonly ProgramVariable[]; + programUniforms?: readonly ProgramUniform[]; }; } diff --git a/js/web/script/build.ts b/js/web/script/build.ts index 314fd7f6babbc..5151f27582c1f 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -408,7 +408,7 @@ async function main() { }); // ort.wasm-core[.min].js await addAllWebBuildTasks({ - outputBundleName: 'ort.wasm-core.min', + outputBundleName: 'ort.wasm-core', define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGPU': 'true', @@ -417,9 +417,9 @@ async function main() { 'BUILD_DEFS.DISABLE_WASM_THREAD': 'true', }, }); - // ort.training.wasm.min.js + // ort.training.wasm[.min].js await addAllWebBuildTasks({ - outputBundleName: 'ort.training.wasm.min', + outputBundleName: 'ort.training.wasm', define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_TRAINING': 'false', diff --git a/js/web/types.d.ts b/js/web/types.d.ts index 2cb4578d99687..b9d12cf47b5c5 100644 --- a/js/web/types.d.ts +++ b/js/web/types.d.ts @@ -24,3 +24,7 @@ declare module 'onnxruntime-web/webgl' { declare module 'onnxruntime-web/webgpu' { export * from 'onnxruntime-web'; } + +declare module 'onnxruntime-web/training' { + export * from 'onnxruntime-web'; +} diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.cc index 253a58bd82a20..9008edbf3db30 100644 --- a/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.cc +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.cc @@ -4,7 +4,6 @@ // Distributed computation. #include "sharding.h" #include "distributed_matmul.h" -#include "nccl_kernels.h" #include "mpi_include.h" // ORT system. @@ -63,20 +62,7 @@ static TensorShape InferMatmulOutputShape( }; template -DistributedMatMul::DistributedMatMul(const OpKernelInfo& info) : NcclKernel(info) { - std::vector device_mesh_elements = info.GetAttrsOrDefault("device_mesh_elements"); - std::vector device_mesh_shape = info.GetAttrsOrDefault("device_mesh_shape"); - std::vector input_shard_specs = info.GetAttrsOrDefault("input_shard_specs"); - std::vector output_shard_specs = info.GetAttrsOrDefault("output_shard_specs"); - - for (size_t i = 0; i < input_shard_specs.size(); ++i) { - auto spec = CreateTensorPartitionSpec(input_shard_specs[i], device_mesh_shape, device_mesh_elements); - input_shard_specs_.push_back(spec); - } - for (size_t i = 0; i < output_shard_specs.size(); ++i) { - auto spec = CreateTensorPartitionSpec(output_shard_specs[i], device_mesh_shape, device_mesh_elements); - output_shard_specs_.push_back(spec); - } +DistributedMatMul::DistributedMatMul(const OpKernelInfo& info) : DistributedKernel(info) { } template diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.h b/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.h index d8df24c03498f..da07f9a8b2c7b 100644 --- a/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.h +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.h @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
- -#include "sharding_spec.h" -#include "core/providers/cuda/cuda_kernel.h" +#include "sharding.h" #include #include @@ -20,15 +18,11 @@ namespace cuda { #if defined(ORT_USE_NCCL) template -class DistributedMatMul final : public NcclKernel { +class DistributedMatMul final : public DistributedKernel { public: explicit DistributedMatMul(const OpKernelInfo& info); Status ComputeInternal(OpKernelContext* context) const override; - - private: - std::vector input_shard_specs_; - std::vector output_shard_specs_; }; #endif diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_slice.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_slice.cc new file mode 100644 index 0000000000000..5768dba791292 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_slice.cc @@ -0,0 +1,181 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_slice.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cpu/tensor/slice.h" +#include "core/providers/cuda/tensor/slice.h" +#include "core/providers/cuda/math/matmul.h" +#include "core/providers/cuda/tensor/transpose.h" +#include "core/providers/cuda/cuda_check_memory.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) +template +DistributedSlice::DistributedSlice(const OpKernelInfo& info) : DistributedKernel(info) { +} + +template +Status DistributedSlice::ComputeInternal(OpKernelContext* context) const { + const auto tensor_shard_data = context->Input(0); + const auto tensor_shard_starts = context->Input(1); + const auto tensor_shard_ends = context->Input(2); + + const TensorPartitionSpec& spec_data = input_shard_specs_[0]; + const TensorPartitionSpec& spec_starts = input_shard_specs_[1]; + const TensorPartitionSpec& spec_ends = input_shard_specs_[2]; + const TensorPartitionSpec& spec_Y = output_shard_specs_[0]; + + const auto tensor_shard_axes = context->Input(3); + const TensorPartitionSpec& spec_axes = input_shard_specs_[3]; + + if (spec_starts.HasShard() || + spec_ends.HasShard() || + spec_axes.HasShard() || + (input_shard_specs_.size() > 4 && input_shard_specs_[4].HasShard())) + ORT_THROW("DistributedSlice: shard on starts / ends / axes / steps are not supported yet."); + + std::vector input_starts; + std::vector input_ends; + auto starts_data = tensor_shard_starts->DataAsSpan(); + input_starts.resize(starts_data.size()); + std::copy(starts_data.begin(), starts_data.end(), input_starts.begin()); + auto ends_data = tensor_shard_ends->DataAsSpan(); + input_ends.resize(ends_data.size()); + std::copy(ends_data.begin(), ends_data.end(), input_ends.begin()); + + std::vector input_axes; + if (tensor_shard_axes) { + auto axes_data = tensor_shard_axes->DataAsSpan(); + input_axes.resize(axes_data.size()); + std::copy(axes_data.begin(), axes_data.end(), input_axes.begin()); + } + + std::vector input_steps; + const auto tensor_shard_steps = context->Input(4); + if (tensor_shard_steps) { + const TensorPartitionSpec& spec_steps = input_shard_specs_[4]; + if (spec_steps.HasShard()) + ORT_THROW("Not supported yet."); + + auto steps_data = tensor_shard_steps->DataAsSpan(); + input_steps.resize(steps_data.size()); + std::copy(steps_data.begin(), steps_data.end(), input_steps.begin()); + } + + if (spec_data.GetPartitionAxis() != -1 && + std::find(input_axes.begin(), input_axes.end(), spec_data.GetPartitionAxis()) != input_axes.end()) { + // shard on slice axes, reshard first + auto 
tmp_spec_data = TensorPartitionSpec::CreateAllReplica(spec_data); + auto tensor_data = ReshardTensor(this, context, spec_data, tmp_spec_data, nccl_->Rank(), tensor_shard_data); + + const auto& input_shape = tensor_data->Shape(); + const auto input_dimensions = input_shape.GetDims(); + if (input_dimensions.empty()) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Cannot slice scalars"); + + SliceOp::PrepareForComputeMetadata compute_metadata(input_dimensions); + ORT_RETURN_IF_ERROR(SliceBase::PrepareForCompute(input_starts, input_ends, input_axes, input_steps, compute_metadata)); + TensorShape output_shape(compute_metadata.output_dims_); + + if (spec_Y.HasNoShard()) { + ORT_RETURN_IF_ERROR(FuncSlice(this, + context, + tensor_data.get(), + input_starts, + input_ends, + input_axes, + input_steps, + context->Output(0, output_shape))); + } else { + AllocatorPtr alloc; + ORT_ENFORCE(context->GetTempSpaceAllocator(&alloc) == Status::OK()); + auto dst_tensor = Tensor::Create(tensor_data->DataType(), output_shape, alloc); + ORT_RETURN_IF_ERROR(FuncSlice(this, + context, + tensor_data.get(), + input_starts, + input_ends, + input_axes, + input_steps, + dst_tensor.get())); + auto tmp_spec_output = TensorPartitionSpec::CreateAllReplica(spec_Y); + ReshardTensor(this, context, tmp_spec_output, spec_Y, nccl_->Rank(), dst_tensor.get(), 0); + } + } else { + const auto& input_shape = tensor_shard_data->Shape(); + const auto input_dimensions = input_shape.GetDims(); + if (input_dimensions.empty()) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Cannot slice scalars"); + + SliceOp::PrepareForComputeMetadata compute_metadata(input_dimensions); + ORT_RETURN_IF_ERROR(SliceBase::PrepareForCompute(input_starts, input_ends, input_axes, input_steps, compute_metadata)); + TensorShape output_shape(compute_metadata.output_dims_); + + if (spec_Y.GetPartitionAxis() == spec_data.GetPartitionAxis()) { + ORT_RETURN_IF_ERROR(FuncSlice(this, + context, + tensor_shard_data, + input_starts, + input_ends, + input_axes, + input_steps, + context->Output(0, output_shape))); + } else { + AllocatorPtr alloc; + ORT_ENFORCE(context->GetTempSpaceAllocator(&alloc) == Status::OK()); + auto dst_tensor = Tensor::Create(tensor_shard_data->DataType(), output_shape, alloc); + ORT_RETURN_IF_ERROR(FuncSlice(this, + context, + tensor_shard_data, + input_starts, + input_ends, + input_axes, + input_steps, + dst_tensor.get())); + ReshardTensor(this, context, spec_data, spec_Y, nccl_->Rank(), dst_tensor.get(), 0); + } + } + + return Status::OK(); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedSlice, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .InputMemoryType(OrtMemTypeCPUInput, 2) + .InputMemoryType(OrtMemTypeCPUInput, 3) + .InputMemoryType(OrtMemTypeCPUInput, 4) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), + DistributedSlice); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedSlice, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .InputMemoryType(OrtMemTypeCPUInput, 2) + .InputMemoryType(OrtMemTypeCPUInput, 3) + .InputMemoryType(OrtMemTypeCPUInput, 4) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), + DistributedSlice); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git 
a/onnxruntime/contrib_ops/cuda/collective/distributed_slice.h b/onnxruntime/contrib_ops/cuda/collective/distributed_slice.h new file mode 100644 index 0000000000000..48c77eee241de --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_slice.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "sharding.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedSlice final : public DistributedKernel { + public: + explicit DistributedSlice(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding.cc b/onnxruntime/contrib_ops/cuda/collective/sharding.cc index d9f2f3c1bcbca..7d106fd75e2d0 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharding.cc @@ -212,6 +212,46 @@ std::unique_ptr ReshardTensor( return dst; } +void ReshardTensor( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const TensorPartitionSpec& src_spec, + const TensorPartitionSpec& dst_spec, + const int64_t device_id, + const Tensor* src, + int output_idx) { + // Implement ReshardTensor but returning a unique_ptr to Tensor instead. + const auto origin_shape = ComputeOriginShape(src->Shape(), src_spec); + const auto dst_shape = ComputeShardShape(origin_shape, dst_spec); + ORT_ENFORCE(CanShard(origin_shape, dst_spec), "Cannot shard tensor. Shape:", origin_shape, ", sharding spec: ", dst_spec.ToString()); + + auto* dst = ctx->Output(output_idx, dst_shape); + ReshardTensor( + nccl_kernel, + ctx, + src_spec, + dst_spec, + device_id, + src, + dst); +} + +DistributedKernel::DistributedKernel(const OpKernelInfo& info) : NcclKernel(info) { + std::vector device_mesh_elements = info.GetAttrsOrDefault("device_mesh_elements"); + std::vector device_mesh_shape = info.GetAttrsOrDefault("device_mesh_shape"); + std::vector input_shard_specs = info.GetAttrsOrDefault("input_shard_specs"); + std::vector output_shard_specs = info.GetAttrsOrDefault("output_shard_specs"); + + for (size_t i = 0; i < input_shard_specs.size(); ++i) { + auto spec = CreateTensorPartitionSpec(input_shard_specs[i], device_mesh_shape, device_mesh_elements); + input_shard_specs_.push_back(spec); + } + for (size_t i = 0; i < output_shard_specs.size(); ++i) { + auto spec = CreateTensorPartitionSpec(output_shard_specs[i], device_mesh_shape, device_mesh_elements); + output_shard_specs_.push_back(spec); + } +} + #endif } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding.h b/onnxruntime/contrib_ops/cuda/collective/sharding.h index 497826160aaab..81a0f72f0c32f 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding.h +++ b/onnxruntime/contrib_ops/cuda/collective/sharding.h @@ -1,11 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
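To make the new DistributedSlice kernel easier to follow, here is a hedged, runnable Python sketch of the control flow in `DistributedSlice::ComputeInternal` above. It is illustrative only (the function and its arguments are invented for this sketch, not ORT APIs): when the data shard's partition axis is one of the sliced axes, the shard is first resharded to a fully replicated tensor, sliced locally, and then resharded to the output spec; otherwise the local shard is sliced directly and resharded only if the output spec differs from the input spec.

```
def distributed_slice_plan(data_partition_axis, out_partition_axis, slice_axes):
    """Conceptual sketch of the steps DistributedSlice performs.

    A partition axis of None means the tensor is replicated ("no shard")."""
    if data_partition_axis is not None and data_partition_axis in slice_axes:
        # Sharded along a sliced axis: gather to a replicated tensor first.
        steps = ["reshard data to all-replica", "run local Slice"]
        steps.append("keep result" if out_partition_axis is None
                     else "reshard sliced result to the output spec")
        return steps
    # The sharded axis is untouched by the slice: operate on the local shard.
    steps = ["run local Slice on the shard"]
    steps.append("keep result" if out_partition_axis == data_partition_axis
                 else "reshard sliced result to the output spec")
    return steps

# Example: data sharded on axis 0, slicing along axis 0, replicated output.
print(distributed_slice_plan(0, None, slice_axes=[0]))
```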
+#pragma once #include "sharding_spec.h" #include "nccl_kernels.h" -#pragma once - namespace onnxruntime { namespace contrib { namespace cuda { @@ -49,6 +48,16 @@ void ReshardTensor( const Tensor* src, Tensor* dst); +// Output from ctx +void ReshardTensor( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const TensorPartitionSpec& src_spec, + const TensorPartitionSpec& dst_spec, + const int64_t device_id, + const Tensor* src, + int output_idx); + std::unique_ptr ReshardTensor( const NcclKernel* nccl_kernel, OpKernelContext* ctx, @@ -57,6 +66,17 @@ std::unique_ptr ReshardTensor( const int64_t device_id, const Tensor* src); +class TensorPartitionSpec; + +class DistributedKernel : public NcclKernel { + public: + explicit DistributedKernel(const OpKernelInfo& info); + + protected: + std::vector input_shard_specs_; + std::vector output_shard_specs_; +}; + #endif } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h index 13982ee7711cf..0f5ef6927a545 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h +++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#pragma once #include "core/common/common.h" #include "core/framework/tensor_shape.h" @@ -8,8 +9,6 @@ #include #include -#pragma once - namespace onnxruntime { namespace contrib { namespace cuda { diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 71ee5ae1ddbe6..3e440a091870a 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -153,6 +153,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllT class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSlice); #endif template <> @@ -310,6 +313,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc index 84eed7fae6ac1..7cdd71014c02e 100644 --- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc @@ -105,6 +105,74 @@ void RegisterCollectiveOps() { "tensor(float)", }, "Constrain input and output types to float tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedSlice) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("device_mesh_elements", + "", + AttributeProto::INTS) + .Attr("device_mesh_shape", + "", + AttributeProto::INTS) + .Attr("input_shard_specs", + "The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.", + AttributeProto::STRINGS) + .Input( + 0, + "data", + "Tensor of data to extract slices from.", + "T", + OpSchema::Single, + true, + 1, + 
OpSchema::Differentiable) + .Input( + 1, + "starts", + "1-D tensor of starting indices of corresponding axis in `axes`", + "Tind", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Input( + 2, + "ends", + "1-D tensor of ending indices (exclusive) of corresponding axis in `axes`", + "Tind", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Input( + 3, + "axes", + "1-D tensor of axes that `starts` and `ends` apply to. Negative value means counting dimensions " + "from the back. Accepted range is [-r, r-1] where r = rank(data). Behavior is undefined if an " + "axis is repeated.", + "Tind", + OpSchema::Optional, + true, + 1, + OpSchema::NonDifferentiable) + .Input( + 4, + "steps", + "1-D tensor of slice step of corresponding axis in `axes`. " + "Negative value means slicing backward. 'steps' cannot be 0. " + "Defaults to 1s.", + "Tind", + OpSchema::Optional, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Sliced data tensor.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensor types.") + .TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types"); } } // namespace contrib diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 96bc1d8010bed..3c0f82408179b 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -458,12 +458,16 @@ Return Value: #if defined(_WIN32) HasDotProductInstructions = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); -#elif !defined(__APPLE__) // The next few lines result in an EXC_BAD_INSTRUCTION runtime error on a M1 Mac so we - // disable it there. - uint64_t isar0_el1; - asm("mrs %[reg], ID_AA64ISAR0_EL1\n" : [reg] "=r"(isar0_el1) : :); - HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u; #else + // Use the cpuinfo value which is read from sysctl and has some additional special cases. + // https://github.com/pytorch/cpuinfo/blob/959002f82d7962a473d8bf301845f2af720e0aa4/src/arm/mach/init.c#L369-L379 + // Do NOT use ID_AA64ISAR0_EL1. It causes illegal instruction errors on Mac M1 and ARMv8-A chips + // as well as failing on other ARM chips as it is an EL1 level register that requires extra + // privileges to read. + // + // uint64_t isar0_el1; + // asm("mrs %[reg], ID_AA64ISAR0_EL1\n" : [reg] "=r"(isar0_el1) : :); + // HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u; HasDotProductInstructions = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot(); #endif diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index fde61e73c2124..cd8bc8fe909dc 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
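For reference, the schema registered above could be instantiated from Python with `onnx.helper`. The mesh and shard-spec strings below are made-up examples (R = replicated, S[i] = sharded over device-mesh axis i, following the sharding-spec convention of the other distributed contrib ops), not values taken from this patch; note the kernel requires starts/ends/axes/steps to be unsharded.

```
from onnx import helper

# Hypothetical 2-device mesh; "data" is 2-D and sharded on axis 0, all other
# inputs are replicated, and the output is fully replicated.
node = helper.make_node(
    "DistributedSlice",
    inputs=["data", "starts", "ends", "axes"],
    outputs=["output"],
    domain="com.microsoft",
    device_mesh_shape=[2],
    device_mesh_elements=[0, 1],
    input_shard_specs=["S[0]R", "R", "R", "R"],
    output_shard_specs=["RR"],
)
```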
+#include +#include + #include #ifndef _GAMING_XBOX #include @@ -92,12 +95,298 @@ bool IsSoftwareAdapter(IDXGIAdapter1* adapter) { return isSoftwareAdapter || (isBasicRenderDriverVendorId && isBasicRenderDriverDeviceId); } +static bool IsHardwareAdapter(IDXCoreAdapter* adapter) { + bool is_hardware = false; + THROW_IF_FAILED(adapter->GetProperty( + DXCoreAdapterProperty::IsHardware, + &is_hardware)); + return is_hardware; +} + +static bool IsGPU(IDXCoreAdapter* compute_adapter) { + // Only considering hardware adapters + if (!IsHardwareAdapter(compute_adapter)) { + return false; + } + return compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS); +} + +static bool IsNPU(IDXCoreAdapter* compute_adapter) { + // Only considering hardware adapters + if (!IsHardwareAdapter(compute_adapter)) { + return false; + } + return !(compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS)); +} + +enum class DeviceType { GPU, NPU, BadDevice }; + +static DeviceType FilterAdapterTypeQuery(IDXCoreAdapter* adapter, OrtDmlDeviceFilter filter) { + auto allow_gpus = (filter & OrtDmlDeviceFilter::Gpu) == OrtDmlDeviceFilter::Gpu; + if (IsGPU(adapter) && allow_gpus) { + return DeviceType::GPU; + } + + auto allow_npus = (filter & OrtDmlDeviceFilter::Npu) == OrtDmlDeviceFilter::Npu; + if (IsNPU(adapter) && allow_npus) { + return DeviceType::NPU; + } + + return DeviceType::BadDevice; +} + +// Struct for holding each adapter +struct AdapterInfo { + ComPtr Adapter; + DeviceType Type; // GPU or NPU +}; + +static ComPtr EnumerateDXCoreAdapters(IDXCoreAdapterFactory* adapter_factory) { + ComPtr adapter_list; + + // TODO: use_dxcore_workload_enumeration should be determined by QI + // When DXCore APIs are available QI for relevant enumeration interfaces + constexpr bool use_dxcore_workload_enumeration = false; + if (!use_dxcore_workload_enumeration) { + // Get a list of all the adapters that support compute + GUID attributes[]{ DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE }; + ORT_THROW_IF_FAILED( + adapter_factory->CreateAdapterList(_countof(attributes), + attributes, + adapter_list.GetAddressOf())); + } + + return adapter_list; +} + +static void SortDXCoreAdaptersByPreference( + IDXCoreAdapterList* adapter_list, + OrtDmlPerformancePreference preference) { + if (adapter_list->GetAdapterCount() <= 1) { + return; + } + + // DML prefers the HighPerformance adapter by default + std::array adapter_list_preferences = { + DXCoreAdapterPreference::HighPerformance + }; + + // If callers specify minimum power change the DXCore sort policy + // NOTE DXCoreAdapterPrefernce does not apply to mixed adapter lists - only to GPU lists + if (preference == OrtDmlPerformancePreference::MinimumPower) { + adapter_list_preferences[0] = DXCoreAdapterPreference::MinimumPower; + } + + ORT_THROW_IF_FAILED(adapter_list->Sort( + static_cast(adapter_list_preferences.size()), + adapter_list_preferences.data())); +} + +static std::vector FilterDXCoreAdapters( + IDXCoreAdapterList* adapter_list, + OrtDmlDeviceFilter filter) { + auto adapter_infos = std::vector(); + const uint32_t count = adapter_list->GetAdapterCount(); + for (uint32_t i = 0; i < count; ++i) { + ComPtr candidate_adapter; + ORT_THROW_IF_FAILED(adapter_list->GetAdapter(i, candidate_adapter.GetAddressOf())); + + // Add the adapters that are valid based on the device filter (GPU, NPU, or Both) + auto adapter_type = FilterAdapterTypeQuery(candidate_adapter.Get(), filter); + if (adapter_type != DeviceType::BadDevice) { + 
adapter_infos.push_back(AdapterInfo{candidate_adapter, adapter_type}); + } + } + + return adapter_infos; +} + +static void SortHeterogenousDXCoreAdapterList( + std::vector& adapter_infos, + OrtDmlDeviceFilter filter, + OrtDmlPerformancePreference preference) { + if (adapter_infos.size() <= 1) { + return; + } + + // When considering both GPUs and NPUs sort them by performance preference + // of Default (Gpus first), HighPerformance (GPUs first), or LowPower (NPUs first) + auto keep_npus = (filter & OrtDmlDeviceFilter::Npu) == OrtDmlDeviceFilter::Npu; + auto only_npus = filter == OrtDmlDeviceFilter::Npu; + if (!keep_npus || only_npus) { + return; + } + + struct SortingPolicy { + // default is false because GPUs are considered higher priority in + // a mixed adapter environment + bool npus_first_ = false; + + SortingPolicy(bool npus_first = false) : npus_first_(npus_first) { } + + bool operator()(const AdapterInfo& a, const AdapterInfo& b) { + return npus_first_ ? a.Type < b.Type : a.Type > b.Type; + } + }; + + auto npus_first = (preference == OrtDmlPerformancePreference::MinimumPower); + auto policy = SortingPolicy(npus_first); + std::sort(adapter_infos.begin(), adapter_infos.end(), policy); +} + std::shared_ptr DMLProviderFactoryCreator::Create(int device_id) { return Create(device_id, /*skip_software_device_check*/ false); } -Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateD3D12Device(int device_id, bool skip_software_device_check) -{ +std::shared_ptr DMLProviderFactoryCreator::CreateFromOptions( + OrtDmlDeviceOptions* device_options) { + auto default_device_options = OrtDmlDeviceOptions { Default, Gpu }; + if (device_options == nullptr) { + device_options = &default_device_options; + } + + OrtDmlPerformancePreference preference = device_options->Preference; + OrtDmlDeviceFilter filter = device_options->Filter; + + // Create DXCore Adapter Factory + ComPtr adapter_factory; + ORT_THROW_IF_FAILED(::DXCoreCreateAdapterFactory(adapter_factory.GetAddressOf())); + + // Get all DML compatible DXCore adapters + ComPtr adapter_list; + adapter_list = EnumerateDXCoreAdapters(adapter_factory.Get()); + + if (adapter_list->GetAdapterCount() == 0) { + ORT_THROW("No GPUs or NPUs detected."); + } + + // Sort the adapter list to honor DXCore hardware ordering + SortDXCoreAdaptersByPreference(adapter_list.Get(), preference); + + // TODO: use_dxcore_workload_enumeration should be determined by QI + // When DXCore APIs are available QI for relevant enumeration interfaces + constexpr bool use_dxcore_workload_enumeration = false; + + std::vector adapter_infos; + if (!use_dxcore_workload_enumeration) { + // Filter all DXCore adapters to hardware type specified by the device filter + adapter_infos = FilterDXCoreAdapters(adapter_list.Get(), filter); + if (adapter_infos.size() == 0) { + ORT_THROW("No devices detected that match the filter criteria."); + } + } + + // DXCore Sort ignores NPUs. When both GPUs and NPUs are present, manually sort them. 
+ SortHeterogenousDXCoreAdapterList(adapter_infos, filter, preference); + + // Extract just the adapters + auto adapters = std::vector>(adapter_infos.size()); + std::transform( + adapter_infos.begin(), adapter_infos.end(), + adapters.begin(), + [](auto& a){ return a.Adapter; }); + + return onnxruntime::DMLProviderFactoryCreator::CreateFromAdapterList(std::move(adapters)); +} + +static std::optional ParsePerformancePreference(const ProviderOptions& provider_options) { + static const std::string PerformancePreference = "performance_preference"; + static const std::string Default = "default"; + static const std::string HighPerformance = "high_performance"; + static const std::string MinimumPower = "minimum_power"; + + auto preference_it = provider_options.find(PerformancePreference); + if (preference_it != provider_options.end()) { + if (preference_it->second == Default) { + return OrtDmlPerformancePreference::Default; + } + + if (preference_it->second == HighPerformance) { + return OrtDmlPerformancePreference::HighPerformance; + } + + if (preference_it->second == MinimumPower) { + return OrtDmlPerformancePreference::MinimumPower; + } + + ORT_THROW("Invalid PerformancePreference provided for DirectML EP device selection."); + } + + return {}; +} + +static std::optional ParseFilter(const ProviderOptions& provider_options) { + static const std::string Filter = "filter"; + static const std::string Any = "any"; + static const std::string Gpu = "gpu"; + static const std::string Npu = "npu"; + + auto preference_it = provider_options.find(Filter); + if (preference_it != provider_options.end()) { + if (preference_it->second == Any) { + return OrtDmlDeviceFilter::Any; + } + + if (preference_it->second == Gpu) { + return OrtDmlDeviceFilter::Gpu; + } + + if (preference_it->second == Npu) { + return OrtDmlDeviceFilter::Npu; + } + + ORT_THROW("Invalid Filter provided for DirectML EP device selection."); + } + + return {}; +} + +static std::optional ParseDeviceId(const ProviderOptions& provider_options) { + static const std::string DeviceId = "device_id"; + + auto preference_it = provider_options.find(DeviceId); + if (preference_it != provider_options.end()) { + if (!preference_it->second.empty()) { + return std::stoi(preference_it->second); + } + } + + return {}; +} + +std::shared_ptr DMLProviderFactoryCreator::CreateFromProviderOptions( + const ProviderOptions& provider_options) { + auto device_id = ParseDeviceId(provider_options); + if (device_id.has_value()) + { + return onnxruntime::DMLProviderFactoryCreator::Create(device_id.value()); + } + + auto preference = ParsePerformancePreference(provider_options); + auto filter = ParseFilter(provider_options); + + // If no preference/filters are specified then create with default preference/filters. 
+ if (!preference.has_value() && !filter.has_value()) { + return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(nullptr); + } + + if (!preference.has_value()) { + preference = OrtDmlPerformancePreference::Default; + } + + if (!filter.has_value()) { + filter = OrtDmlDeviceFilter::Gpu; + } + + OrtDmlDeviceOptions device_options; + device_options.Preference = preference.value(); + device_options.Filter = filter.value(); + return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(&device_options); +} + +Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateD3D12Device( + int device_id, + bool skip_software_device_check) { #ifdef _GAMING_XBOX ComPtr d3d12_device; D3D12XBOX_CREATE_DEVICE_PARAMETERS params = {}; @@ -128,8 +417,7 @@ Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateD3D12Devic return d3d12_device; } -Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateDMLDevice(ID3D12Device* d3d12_device) -{ +Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateDMLDevice(ID3D12Device* d3d12_device) { DML_CREATE_DEVICE_FLAGS flags = DML_CREATE_DEVICE_FLAG_NONE; // In debug builds, enable the DML debug layer if the D3D12 debug layer is also enabled @@ -153,9 +441,7 @@ Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateDMLDevice(ID return dml_device; } -std::shared_ptr DMLProviderFactoryCreator::Create(int device_id, bool skip_software_device_check) { - ComPtr d3d12_device = CreateD3D12Device(device_id, skip_software_device_check); - +std::shared_ptr CreateDMLDeviceAndProviderFactory(ID3D12Device* d3d12_device) { D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {}; cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT; @@ -163,10 +449,27 @@ std::shared_ptr DMLProviderFactoryCreator::Create(int ComPtr cmd_queue; ORT_THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_GRAPHICS_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf()))); - auto dml_device = CreateDMLDevice(d3d12_device.Get()); + auto dml_device = onnxruntime::DMLProviderFactoryCreator::CreateDMLDevice(d3d12_device); return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get()); } +std::shared_ptr DMLProviderFactoryCreator::Create(int device_id, bool skip_software_device_check) { + ComPtr d3d12_device = CreateD3D12Device(device_id, skip_software_device_check); + return CreateDMLDeviceAndProviderFactory(d3d12_device.Get()); +} + +std::shared_ptr DMLProviderFactoryCreator::CreateFromAdapterList( + std::vector>&& dxcore_devices) { + // Choose the first device from the list since it's the highest priority + auto dxcore_device = dxcore_devices[0]; + + // Create D3D12 Device from DXCore Adapter + ComPtr d3d12_device; + ORT_THROW_IF_FAILED(D3D12CreateDevice(dxcore_device.Get(), D3D_FEATURE_LEVEL_11_0, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf()))); + + return CreateDMLDeviceAndProviderFactory(d3d12_device.Get()); +} + } // namespace onnxruntime // [[deprecated]] @@ -211,6 +514,17 @@ ORT_API_STATUS_IMPL(FreeGPUAllocation, _In_ void* ptr) { API_IMPL_END } +ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_options) { +API_IMPL_BEGIN +#ifdef USE_DML + auto factory = onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(device_options); + // return the create function for a dxcore device + options->provider_factories.push_back(factory); +#endif // USE_DML + return nullptr; + API_IMPL_END +} + 
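The string-based provider options feed the new `CreateFromProviderOptions` path above. As a hedged sketch of driving it from Python (the model path is a placeholder), the keys and values mirror `ParseDeviceId`, `ParsePerformancePreference` and `ParseFilter`; an explicit `device_id` takes precedence and falls back to the legacy adapter path.

```
import onnxruntime as ort

# Sketch only: keys/values follow ParsePerformancePreference / ParseFilter / ParseDeviceId.
sess = ort.InferenceSession(
    "model.onnx",  # placeholder model path
    providers=[
        ("DmlExecutionProvider", {
            "performance_preference": "minimum_power",  # default | high_performance | minimum_power
            "filter": "npu",                            # any | gpu | npu
        }),
    ],
)
```

If neither key is present, the factory falls back to the default performance preference and a GPU-only filter, per the defaulting logic above.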
ORT_API_STATUS_IMPL(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* ort_allocator, _In_ void* allocation, _Out_ ID3D12Resource** d3d_resource) { API_IMPL_BEGIN #ifdef USE_DML @@ -233,7 +547,8 @@ static constexpr OrtDmlApi ort_dml_api_10_to_x = { &OrtSessionOptionsAppendExecutionProviderEx_DML, &CreateGPUAllocationFromD3DResource, &FreeGPUAllocation, - &GetD3D12ResourceFromAllocation + &GetD3D12ResourceFromAllocation, + &OrtSessionOptionsAppendExecutionProvider_DML2, }; const OrtDmlApi* GetOrtDmlApi(_In_ uint32_t /*version*/) NO_EXCEPTION { diff --git a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h index 574f4410fe3e3..4e13330a4cd71 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h +++ b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h @@ -7,14 +7,26 @@ #include #include +#include "core/framework/provider_options.h" #include "core/providers/providers.h" #include "core/providers/dml/dml_provider_factory.h" +#include +#include + namespace onnxruntime { struct DMLProviderFactoryCreator { static std::shared_ptr Create(int device_id); static std::shared_ptr Create(int device_id, bool skip_software_device_check); + + static std::shared_ptr CreateFromProviderOptions( + const ProviderOptions& provider_options_map); + static std::shared_ptr CreateFromOptions(OrtDmlDeviceOptions* device_options); + + static std::shared_ptr CreateFromAdapterList( + std::vector>&& dxcore_devices); + static Microsoft::WRL::ComPtr CreateD3D12Device(int device_id, bool skip_software_device_check); static Microsoft::WRL::ComPtr CreateDMLDevice(ID3D12Device* d3d12_device); }; diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc index fc8c2efc7a80f..17ce9b078b790 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc @@ -47,12 +47,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() { CreateSimpleOpBuilder("Where", *this); CreateSimpleOpBuilder("Sigmoid", *this); CreateSimpleOpBuilder("Sin", *this); - CreateSimpleOpBuilder("Softmax", *this); CreateSimpleOpBuilder("Sqrt", *this); CreateSimpleOpBuilder("Sub", *this); CreateSimpleOpBuilder("Tanh", *this); - CreateSimpleOpBuilder("LogSoftmax", *this); CreateSimpleOpBuilder("MatMul", *this); CreateSimpleOpBuilder("Concat", *this); @@ -67,6 +65,11 @@ OpBuilderRegistrations::OpBuilderRegistrations() { CreateSimpleOpBuilder("GridSample", *this); } + { + CreateSoftmaxOpBuilder("Softmax", *this); + CreateSoftmaxOpBuilder("LogSoftmax", *this); + } + { CreateCastOpBuilder("Cast", *this); } diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h index 5d59f4343d773..c2c9345e109a9 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h @@ -50,6 +50,8 @@ const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type); void CreateSimpleOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); + void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); diff --git 
a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index 7c9603692080b..acdcfdc66bf34 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -29,7 +29,7 @@ class SimpleOpBuilder : public BaseOpBuilder { bool do_op_validation) const override ORT_MUST_USE_RESULT; private: - Status ExplicitOpCheck(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const; + Status ExplicitOpCheck(const NodeUnit& node_unit) const; Status ProcessSigmoidOrTanhOutput(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::vector&& input_names, @@ -41,30 +41,9 @@ class SimpleOpBuilder : public BaseOpBuilder { static constexpr std::array gridsample_supported_padding_modes = {"zeros", "border", "reflection"}; }; -static int32_t GetDefaultAxisAttribute(const std::string& op_type, int opset_version) { - if (op_type == "Softmax" || op_type == "LogSoftmax") { - // Default axis changed from 1 to -1 in opset 13. - return opset_version < 13 ? 1 : -1; - } - - return 0; -} - -Status SimpleOpBuilder::ExplicitOpCheck(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const { +Status SimpleOpBuilder::ExplicitOpCheck(const NodeUnit& node_unit) const { const std::string& op_type = node_unit.OpType(); - // QNN Softmax and LogSoftmax only support an axis value equal to input_rank - 1 (i.e., same as -1). - if (op_type == "Softmax" || op_type == "LogSoftmax") { - int32_t axis = GetDefaultAxisAttribute(op_type, node_unit.SinceVersion()); - Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT; - ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, axis)); - std::vector input_shape; - ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(node_unit.Inputs()[0].node_arg, input_shape), - "QNN EP: Cannot get shape for Softmax input"); - ORT_RETURN_IF(axis != static_cast(input_shape.size() - 1), - "QNN ", op_type.c_str(), " only supports an `axis` attribute equal to input_rank-1 (or -1)"); - } - if (op_type == "GridSample") { NodeAttrHelper node_helper(node_unit); std::string mode = node_helper.Get("mode", "linear"); @@ -231,7 +210,7 @@ Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w const std::string& op_type = node_unit.OpType(); if (do_op_validation) { - ORT_RETURN_IF_ERROR(ExplicitOpCheck(qnn_model_wrapper, node_unit)); + ORT_RETURN_IF_ERROR(ExplicitOpCheck(node_unit)); // Skip the op validation for DepthToSpace & SpaceToDepth if it's not NHWC data layout if (node_unit.Domain() != kMSInternalNHWCDomain && (op_type == "DepthToSpace" || op_type == "SpaceToDepth" || op_type == "GridSample")) { return Status::OK(); @@ -251,8 +230,8 @@ Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w std::vector param_tensor_names; // Add attribute - if (op_type == "LogSoftmax" || op_type == "Softmax" || op_type == "Concat") { - int32_t default_axis = GetDefaultAxisAttribute(op_type, node_unit.SinceVersion()); + if (op_type == "Concat") { + int32_t default_axis = 0; Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT; ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, default_axis)); QnnParamWrapper axis_param(node_unit.Index(), node_unit.Name(), QNN_OP_SOFTMAX_PARAM_AXIS, axis_qnn_scalar); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc 
b/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc new file mode 100644 index 0000000000000..49d85d76e25a8 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc @@ -0,0 +1,237 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/framework/tensorprotoutils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/common/safeint.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace qnn { + +class SoftmaxOpBuilder : public BaseOpBuilder { + public: + SoftmaxOpBuilder() : BaseOpBuilder("SoftmaxOpBuilder") {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SoftmaxOpBuilder); + + Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger) const override final ORT_MUST_USE_RESULT; + + protected: + Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const override ORT_MUST_USE_RESULT; + + Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const override ORT_MUST_USE_RESULT; +}; + +constexpr int32_t GetDefaultAxisAttribute(int opset_version) { + // Default axis changed from 1 to -1 in opset 13. + return opset_version < 13 ? 1 : -1; +} + +Status SoftmaxOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + const int opset_version = node_unit.SinceVersion(); + + // The QNN HTP backend only supports an `axis` attribute that refers to the last input dimension. + // QNN EP is able to support arbitrary axis attributes by wrapping the QNN operator with transposes. + // However, the exception is Softmax/LogSoftmax with opset < 13. For these older ONNX operators, only + // axis == input_rank - 1 is supported. + if (opset_version < 13) { + const std::string& op_type = node_unit.OpType(); + + int32_t axis = GetDefaultAxisAttribute(opset_version); + Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT; + ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, axis)); + std::vector input_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(node_unit.Inputs()[0].node_arg, input_shape), + "QNN EP: Cannot get shape for Softmax input"); + ORT_RETURN_IF(axis != static_cast(input_shape.size() - 1), + "QNN ", op_type.c_str(), + " only supports an `axis` attribute equal to input_rank-1 (or -1) for ONNX opset < 13"); + } + + return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true); +} + +static std::vector GetTransposePermToUseLastAxis(uint32_t input_rank, uint32_t axis) { + assert(axis < input_rank); + std::vector transpose_perm; + transpose_perm.reserve(input_rank); + + for (uint32_t dim = 0; dim < input_rank; dim++) { + transpose_perm.push_back(dim); + } + + // Swap axis dim with last dim. 
+ transpose_perm[axis] = input_rank - 1; + transpose_perm[input_rank - 1] = axis; + + return transpose_perm; +} + +Status SoftmaxOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const { + const bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); + const auto& inputs = node_unit.Inputs(); + assert(inputs.size() == 1); + + int32_t axis = GetDefaultAxisAttribute(node_unit.SinceVersion()); + Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT; + ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, axis)); + + OnnxInputInfo input_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[0], input_info)); + const size_t input_rank = input_info.shape.size(); + + // If the axis attribute refers to the last dimension, then process the input as normal. + if (!is_npu_backend || axis == static_cast(input_rank) - 1) { + return ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names); + } + + // + // The axis does **not** refer to the last input dimension. Must wrap transposes around the operator to be able to use + // QNN's Softmax operator, which always uses an axis value that refers to the last dimension. + // + + std::vector transpose_perm = GetTransposePermToUseLastAxis(static_cast(input_rank), + static_cast(axis)); + + const std::string& input_name = inputs[0].node_arg.Name(); + std::string op_input_name = input_info.is_initializer ? input_name : input_name + "_ort_qnn_ep_transpose"; + input_names.push_back(op_input_name); + + std::vector op_input_shape = input_info.shape; + op_input_shape[input_rank - 1] = input_info.shape[axis]; + op_input_shape[axis] = input_info.shape[input_rank - 1]; + + ORT_RETURN_IF(input_info.is_initializer, "QNN EP does not support (Log)Softmax with an initializer input, ", + "which should be optimized away by the ORT optimizer"); + + // Input is dynamic, so add transpose node before input. 
+ const bool is_graph_input = qnn_model_wrapper.IsGraphInput(input_name); + + ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddTransposeNode(node_unit.Index(), + input_name, + op_input_name, + input_info.shape, + transpose_perm, + op_input_shape, + input_info.qnn_data_type, + input_info.quant_param, + do_op_validation, + is_graph_input)); + + Qnn_TensorType_t tensor_type = GetInputTensorType(qnn_model_wrapper, op_input_name); + QnnTensorWrapper input_tensorwrapper(op_input_name, tensor_type, input_info.qnn_data_type, input_info.quant_param, + std::move(op_input_shape), {}); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); + + return Status::OK(); +} + +Status SoftmaxOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const { + const bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); + const std::string& op_type = node_unit.OpType(); + const auto& outputs = node_unit.Outputs(); + assert(outputs.size() == 1); + + int32_t axis = GetDefaultAxisAttribute(node_unit.SinceVersion()); + Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT; + ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, axis)); + + OnnxInputInfo output_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(outputs[0], output_info)); + const size_t output_rank = output_info.shape.size(); + const bool axis_is_last_dim = static_cast(axis) == output_rank - 1; + + // If axis refers to the last dimension, process outputs as usual. + if (!is_npu_backend || axis_is_last_dim) { + QnnParamWrapper axis_param(node_unit.Index(), node_unit.Name(), QNN_OP_SOFTMAX_PARAM_AXIS, axis_qnn_scalar); + + std::vector param_tensor_names; + param_tensor_names.push_back(axis_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(axis_param)); + + return ProcessOutputs(qnn_model_wrapper, node_unit, + std::move(input_names), + std::move(param_tensor_names), + logger, do_op_validation, GetQnnOpType(op_type)); + } + + // + // The axis **does** not refer to the last dimension. Must wrap the operator with Transposes to be able to use + // QNN's Softmax operator, which only supports an axis that refers to the last dimension. + // + + axis_qnn_scalar.uint32Value = static_cast(output_rank - 1); // NOTE: override axis. 
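As a quick cross-check of the axis handling above, here is a small runnable Python sketch (illustrative only, not QNN EP code) of the opset-dependent default axis and of `GetTransposePermToUseLastAxis`, which swaps the requested axis with the last dimension so QNN's last-axis-only Softmax can run between a pair of Transposes.

```
def default_softmax_axis(opset_version):
    # ONNX (Log)Softmax changed the default axis from 1 to -1 in opset 13.
    return 1 if opset_version < 13 else -1

def transpose_perm_to_use_last_axis(rank, axis):
    # Swap `axis` with the last dimension; applying the same perm again restores the layout.
    perm = list(range(rank))
    perm[axis], perm[rank - 1] = perm[rank - 1], perm[axis]
    return perm

# Rank-4 input with axis=1 -> Transpose(perm=[0, 3, 2, 1]) -> Softmax(last axis) -> Transpose back.
assert default_softmax_axis(12) == 1 and default_softmax_axis(13) == -1
assert transpose_perm_to_use_last_axis(4, 1) == [0, 3, 2, 1]
```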
+ QnnParamWrapper axis_param(node_unit.Index(), node_unit.Name(), QNN_OP_SOFTMAX_PARAM_AXIS, axis_qnn_scalar); + + std::vector param_tensor_names; + param_tensor_names.push_back(axis_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(axis_param)); + + const std::string& orig_output_name = outputs[0].node_arg.Name(); + std::string op_output_name = orig_output_name + "_ort_qnn_ep_transpose"; + + std::vector op_output_shape = output_info.shape; + op_output_shape[output_rank - 1] = output_info.shape[axis]; + op_output_shape[axis] = output_info.shape[output_rank - 1]; + + QnnTensorWrapper output_tensorwrapper(op_output_name, QNN_TENSOR_TYPE_NATIVE, output_info.qnn_data_type, output_info.quant_param, + std::vector(op_output_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(GetNodeName(node_unit), + QNN_OP_PACKAGE_NAME_QTI_AISW, + GetQnnOpType(node_unit.OpType()), + std::move(input_names), + {op_output_name}, + std::move(param_tensor_names)), + "Failed to add node."); + + const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(orig_output_name); + std::vector transpose_perm = GetTransposePermToUseLastAxis(static_cast(output_rank), + static_cast(axis)); + + ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddTransposeNode(node_unit.Index(), + op_output_name, + orig_output_name, + op_output_shape, + transpose_perm, + output_info.shape, + output_info.qnn_data_type, + output_info.quant_param, + do_op_validation, + false, + is_graph_output)); + + return Status::OK(); +} + +void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.AddOpBuilder(op_type, std::make_unique()); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 6cd9cbac72620..d497bc1c069d2 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -238,22 +238,25 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, initializer_input_lookup, qnn_backend_manager_->GetQnnBackendType()); - for (const auto& node : graph_viewer.Nodes()) { - const NodeUnit* node_unit = node_unit_map.at(&node); + const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); + for (size_t i = 0; i < node_indices.size(); i++) { + gsl::not_null node(graph_viewer.GetNode(node_indices[i])); + + const NodeUnit* node_unit = node_unit_map.at(node); const bool supported = IsNodeSupported(qnn_model_wrapper, *node_unit, node_unit_supported_result, logger); LOGS(logger, VERBOSE) << "Node supported: [" << supported - << "] index: [" << node.Index() - << "] name: [" << node.Name() - << "] Operator type: [" << node.OpType() + << "] index: [" << node->Index() + << "] name: [" << node->Name() + << "] Operator type: [" << node->OpType() << "] as part of the NodeUnit type: [" << node_unit->OpType() << "] index: [" << node_unit->Index() << "] name: [" << node_unit->Name() << "]"; if (supported) { - supported_nodes.insert(&node); + supported_nodes.insert(node); } } diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 774df067fe347..38266f566e6e1 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ 
-142,5 +142,43 @@ bool IsValidMultidirectionalBroadcast(std::vector& shape_a, return true; } +bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type) { + // WebNN changed the name of the MLOperandDescriptor's data type from "type" to "dataType", + // use a duplicate entry temporarily to workaround this API breaking issue. + // TODO: Remove legacy "type" once all browsers implement the new "dataType". + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + desc.set("type", emscripten::val("uint8")); + desc.set("dataType", emscripten::val("uint8")); + return true; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + desc.set("type", emscripten::val("float16")); + desc.set("dataType", emscripten::val("float16")); + return true; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + desc.set("type", emscripten::val("float32")); + desc.set("dataType", emscripten::val("float32")); + return true; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + desc.set("type", emscripten::val("int32")); + desc.set("dataType", emscripten::val("int32")); + return true; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + desc.set("type", emscripten::val("int64")); + desc.set("dataType", emscripten::val("int64")); + return true; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + desc.set("type", emscripten::val("uint32")); + desc.set("dataType", emscripten::val("uint32")); + return true; + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: + desc.set("type", emscripten::val("uint64")); + desc.set("dataType", emscripten::val("uint64")); + return true; + default: + return false; + } +} + } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index cdad9b22a8ab8..46c456556e016 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -231,5 +231,8 @@ bool IsSupportedDataType(const int32_t data_type, const WebnnDeviceType device_t bool IsValidMultidirectionalBroadcast(std::vector& shape_a, std::vector& shape_b, const logging::Logger& logger); + +bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type); + } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc index 04e6d2b548aba..12c2cf6dd0a62 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc @@ -34,7 +34,7 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto rank = static_cast(input_shape.size()); emscripten::val desc = emscripten::val::object(); - desc.set("type", emscripten::val("int64")); + ORT_RETURN_IF_NOT(SetWebnnDataType(desc, ONNX_NAMESPACE::TensorProto_DataType_INT64), "Unsupported data type"); emscripten::val dims = emscripten::val::array(); dims.call("push", rank); desc.set("dimensions", dims); diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index 6a86ca7aca6e9..beee8b1d77cee 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -69,8 +69,7 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val options = 
emscripten::val::object(); std::vector permutation(input_shape.size()); std::iota(permutation.begin(), permutation.end(), 0); - permutation.erase(permutation.begin() + axis); - permutation.push_back(axis); + std::rotate(permutation.begin() + axis, permutation.begin() + axis + 1, permutation.end()); options.set("permutation", emscripten::val::array(permutation)); input = model_builder.GetBuilder().call("transpose", input, options); } @@ -87,7 +86,7 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, output = model_builder.GetBuilder().call("softmax", input); - // Transpose back to the axis. + // Restore from 2-D to the original shape. if (input_shape.size() != 2) { std::vector new_shape; std::transform(input_shape.begin(), input_shape.begin() + axis, std::back_inserter(new_shape), @@ -98,13 +97,12 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, output = model_builder.GetBuilder().call("reshape", output, emscripten::val::array(new_shape)); } - // Reshape to the original shape. + // Restore the corresponding axis back to the initial position from the last position. if (axis != static_cast(input_shape.size() - 1)) { emscripten::val options = emscripten::val::object(); std::vector permutation(input_shape.size()); std::iota(permutation.begin(), permutation.end(), 0); - permutation.pop_back(); - permutation.insert(permutation.begin() + axis, input_shape.size() - 1); + std::rotate(permutation.rbegin(), permutation.rbegin() + 1, permutation.rend() - axis); options.set("permutation", emscripten::val::array(permutation)); output = model_builder.GetBuilder().call("transpose", output, options); } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 2eae8cebbbd66..0ac9fb7ff380d 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -122,6 +122,7 @@ Status ModelBuilder::RegisterInitializers() { auto data_type = tensor.data_type(); emscripten::val operand = emscripten::val::object(); if (IsSupportedDataType(data_type, wnn_device_type_)) { + ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type"); unpacked_tensors_.push_back({}); std::vector& unpacked_tensor = unpacked_tensors_.back(); ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor)); @@ -129,37 +130,30 @@ Status ModelBuilder::RegisterInitializers() { emscripten::val view = emscripten::val::undefined(); switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - desc.set("type", emscripten::val("uint8")); view = emscripten::val{emscripten::typed_memory_view(num_elements, reinterpret_cast(unpacked_tensor.data()))}; break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - desc.set("type", emscripten::val("float16")); view = emscripten::val{emscripten::typed_memory_view(num_elements, reinterpret_cast(unpacked_tensor.data()))}; break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - desc.set("type", emscripten::val("float32")); view = emscripten::val{emscripten::typed_memory_view(num_elements, reinterpret_cast(unpacked_tensor.data()))}; break; case ONNX_NAMESPACE::TensorProto_DataType_INT32: - desc.set("type", emscripten::val("int32")); view = emscripten::val{emscripten::typed_memory_view(num_elements, reinterpret_cast(unpacked_tensor.data()))}; break; case ONNX_NAMESPACE::TensorProto_DataType_INT64: - desc.set("type", emscripten::val("int64")); view = 
emscripten::val{emscripten::typed_memory_view(num_elements, reinterpret_cast(unpacked_tensor.data()))}; break; case ONNX_NAMESPACE::TensorProto_DataType_UINT32: - desc.set("type", emscripten::val("uint32")); view = emscripten::val{emscripten::typed_memory_view(num_elements, reinterpret_cast(unpacked_tensor.data()))}; break; case ONNX_NAMESPACE::TensorProto_DataType_UINT64: - desc.set("type", emscripten::val("uint64")); view = emscripten::val{emscripten::typed_memory_view(num_elements, reinterpret_cast(unpacked_tensor.data()))}; break; @@ -238,35 +232,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i } data_type = type_proto->tensor_type().elem_type(); - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - desc.set("type", emscripten::val("uint8")); - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - desc.set("type", emscripten::val("float16")); - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - desc.set("type", emscripten::val("float32")); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - desc.set("type", emscripten::val("int32")); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - desc.set("type", emscripten::val("int64")); - break; - case ONNX_NAMESPACE::TensorProto_DataType_UINT32: - desc.set("type", emscripten::val("uint32")); - break; - case ONNX_NAMESPACE::TensorProto_DataType_UINT64: - desc.set("type", emscripten::val("uint64")); - break; - default: { - // TODO: support other type. - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "The ", input_output_type, " of graph doesn't have valid type, name: ", name, - " type: ", type_proto->tensor_type().elem_type()); - } - } + ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type"); } if (is_input) { @@ -316,41 +282,35 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer( memcpy(dest, buffer, size); emscripten::val view = emscripten::val::undefined(); emscripten::val desc = emscripten::val::object(); + ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type"); switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint8_t), reinterpret_cast(dest))}; - desc.set("type", emscripten::val("uint8")); break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint16_t), reinterpret_cast(dest))}; - desc.set("type", emscripten::val("float16")); break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: view = emscripten::val{emscripten::typed_memory_view(size / sizeof(float), reinterpret_cast(dest))}; - desc.set("type", emscripten::val("float32")); break; case ONNX_NAMESPACE::TensorProto_DataType_INT32: view = emscripten::val{emscripten::typed_memory_view(size / sizeof(int32_t), reinterpret_cast(dest))}; - desc.set("type", emscripten::val("int32")); break; case ONNX_NAMESPACE::TensorProto_DataType_INT64: view = emscripten::val{emscripten::typed_memory_view(size / sizeof(int64_t), reinterpret_cast(dest))}; - desc.set("type", emscripten::val("int64")); break; case ONNX_NAMESPACE::TensorProto_DataType_UINT32: view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint32_t), reinterpret_cast(dest))}; - desc.set("type", emscripten::val("uint32")); break; case ONNX_NAMESPACE::TensorProto_DataType_UINT64: view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint64_t), reinterpret_cast(dest))}; - desc.set("type", 
emscripten::val("uint64")); break; default: break; diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 50c9f6681a0c8..4649ac35c3647 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -12,6 +12,10 @@ #include "core/session/ort_apis.h" #include "core/providers/openvino/openvino_provider_factory_creator.h" +#if defined(USE_DML) +#include "core/providers/dml/dml_provider_factory_creator.h" +#endif + using namespace onnxruntime; namespace { @@ -67,7 +71,13 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, (std::string(provider_name) + " execution provider is not supported in this build. ").c_str()); }; - if (strcmp(provider_name, "QNN") == 0) { + if (strcmp(provider_name, "DML") == 0) { +#if defined(USE_DML) + options->provider_factories.push_back(DMLProviderFactoryCreator::CreateFromProviderOptions(provider_options)); +#else + status = create_not_supported_status(); +#endif + } else if (strcmp(provider_name, "QNN") == 0) { #if defined(USE_QNN) options->provider_factories.push_back(QNNProviderFactoryCreator::Create(provider_options, &(options->value))); #else diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 95a8f59186ff5..35e03bf9eacd5 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -879,18 +879,10 @@ std::unique_ptr CreateExecutionProviderInstance( #endif } else if (type == kDmlExecutionProvider) { #ifdef USE_DML - int device_id = 0; - auto it = provider_options_map.find(type); - if (it != provider_options_map.end()) { - for (auto option : it->second) { - if (option.first == "device_id") { - if (!option.second.empty()) { - device_id = std::stoi(option.second); - } - } - } - } - return onnxruntime::DMLProviderFactoryCreator::Create(device_id)->CreateProvider(); + auto cit = provider_options_map.find(type); + return onnxruntime::DMLProviderFactoryCreator::CreateFromProviderOptions( + cit == provider_options_map.end() ? ProviderOptions{} : cit->second) + ->CreateProvider(); #endif } else if (type == kNnapiExecutionProvider) { #if defined(USE_NNAPI) diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py index a4491d29b3698..cd7dc7017cf16 100644 --- a/onnxruntime/python/tools/transformers/fusion_group_norm.py +++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py @@ -88,13 +88,17 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): if instance_norm_bias is None: return - if not ( - len(instance_norm_scale.shape) == 1 - and len(instance_norm_bias.shape) == 1 - and instance_norm_scale.shape == instance_norm_bias.shape - and instance_norm_scale.shape[0] == 32 - ): - logger.info("InstanceNormalization groups=%d", instance_norm_scale.shape[0]) + # Only groups=32 is supported in GroupNorm kernel. Check the scale and bias is 1D tensor with shape [32]. 
+ if not (len(instance_norm_scale.shape) == 1 and instance_norm_scale.shape[0] == 32): + logger.debug( + "Skip GroupNorm fusion since scale shape is expected to be [32], Got %s", str(instance_norm_scale.shape) + ) + return + + if not (len(instance_norm_bias.shape) == 1 and instance_norm_bias.shape[0] == 32): + logger.debug( + "Skip GroupNorm fusion since bias shape is expected to be [32], Got %s", str(instance_norm_bias.shape) + ) return if not np.allclose(np.ones_like(instance_norm_scale), instance_norm_scale): @@ -105,7 +109,8 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): group_norm_name = self.model.create_node_name("GroupNorm", name_prefix="GroupNorm") if weight_elements not in [320, 640, 960, 1280, 1920, 2560, 128, 256, 512]: - logger.info("GroupNorm channels=%d", weight_elements) + logger.info("Skip GroupNorm fusion since channels=%d is not supported.", weight_elements) + return self.add_initializer( name=group_norm_name + "_gamma", diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index 087b9d604128e..bc88f69fa990f 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -6,6 +6,7 @@ #include "TestCase.h" #include +#include #include #include #include @@ -185,7 +186,7 @@ void LoopDataFile(int test_data_pb_fd, bool is_input, const TestModelInfo& model f.SetCloseOnDelete(true); google::protobuf::io::CodedInputStream coded_input(&f); bool clean_eof = false; - int item_id = 1; + [[maybe_unused]] int item_id = 1; for (proto::TraditionalMLData data; ParseDelimitedFromCodedStream(&data, &coded_input, &clean_eof); ++item_id, data.Clear()) { @@ -731,6 +732,8 @@ void LoadTests(const std::vector>& input_paths const std::vector>& whitelisted_test_cases, const TestTolerances& tolerances, const std::unordered_set>& disabled_tests, + std::unique_ptr> broken_tests, + std::unique_ptr> broken_tests_keyword_set, const std::function)>& process_function) { std::vector> paths(input_paths); while (!paths.empty()) { @@ -783,11 +786,60 @@ void LoadTests(const std::vector>& input_paths ORT_NOT_IMPLEMENTED(ToUTF8String(filename_str), " is not supported"); } + auto test_case_dir = model_info->GetDir(); + auto test_case_name_in_log = test_case_name + ORT_TSTR(" in ") + test_case_dir; + +#if !defined(ORT_MINIMAL_BUILD) && !defined(USE_QNN) + // to skip some models like *-int8 or *-qdq + if ((reinterpret_cast(model_info.get()))->HasDomain(ONNX_NAMESPACE::AI_ONNX_TRAINING_DOMAIN) || + (reinterpret_cast(model_info.get()))->HasDomain(ONNX_NAMESPACE::AI_ONNX_PREVIEW_TRAINING_DOMAIN)) { + fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " as it has training domain"); + return true; + } +#endif + + bool has_test_data = false; + LoopDir(test_case_dir, [&](const PATH_CHAR_TYPE* filename, OrtFileType f_type) -> bool { + if (filename[0] == '.') return true; + if (f_type == OrtFileType::TYPE_DIR) { + has_test_data = true; + return false; + } + return true; + }); + if (!has_test_data) { + fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " due to no test data"); + return true; + } + + if (broken_tests) { + BrokenTest t = {ToUTF8String(test_case_name), ""}; + auto iter = broken_tests->find(t); + auto opset_version = model_info->GetNominalOpsetVersion(); + if (iter != broken_tests->end() && + (opset_version == TestModelInfo::unknown_version || iter->broken_opset_versions_.empty() || + iter->broken_opset_versions_.find(opset_version) != 
iter->broken_opset_versions_.end())) { + fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " due to broken_tests"); + return true; + } + } + + if (broken_tests_keyword_set) { + for (auto iter2 = broken_tests_keyword_set->begin(); iter2 != broken_tests_keyword_set->end(); ++iter2) { + std::string keyword = *iter2; + if (ToUTF8String(test_case_name).find(keyword) != std::string::npos) { + fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " as it is in broken test keywords"); + return true; + } + } + } + const auto tolerance_key = ToUTF8String(my_dir_name); std::unique_ptr l = CreateOnnxTestCase(ToUTF8String(test_case_name), std::move(model_info), tolerances.absolute(tolerance_key), tolerances.relative(tolerance_key)); + fprintf(stdout, "Load Test Case: %s\n", ToUTF8String(test_case_name_in_log).c_str()); process_function(std::move(l)); return true; }); @@ -1178,6 +1230,7 @@ std::unique_ptr> GetBrokenTests(const std::string& provider broken_tests->insert({"candy", "Temporarily disabled pending investigation"}); broken_tests->insert({"BERT_Squad", "Temporarily disabled pending investigation"}); broken_tests->insert({"LSTM_Seq_lens_unpacked", "The parameter is incorrect"}); + broken_tests->insert({"mlperf_ssd_resnet34_1200", "The parameter is incorrect"}); broken_tests->insert({"resize_downsample_scales_linear", "DML uses half_pixel and this test assumed \"asymmetric\" but does not include \"mode\""}); diff --git a/onnxruntime/test/onnx/TestCase.h b/onnxruntime/test/onnx/TestCase.h index 4d4b2177019c9..96b0b5f6f7c08 100644 --- a/onnxruntime/test/onnx/TestCase.h +++ b/onnxruntime/test/onnx/TestCase.h @@ -101,12 +101,6 @@ class TestTolerances { const Map relative_overrides_; }; -void LoadTests(const std::vector>& input_paths, - const std::vector>& whitelisted_test_cases, - const TestTolerances& tolerances, - const std::unordered_set>& disabled_tests, - const std::function)>& process_function); - struct BrokenTest { std::string test_name_; std::string reason_; @@ -118,6 +112,16 @@ struct BrokenTest { } }; +void LoadTests(const std::vector>& input_paths, + const std::vector>& whitelisted_test_cases, + const TestTolerances& tolerances, + const std::unordered_set>& disabled_tests, + std::unique_ptr> broken_test_list, + std::unique_ptr> broken_tests_keyword_set, + const std::function)>& process_function); + std::unique_ptr> GetBrokenTests(const std::string& provider_name); std::unique_ptr> GetBrokenTestsKeyWordSet(const std::string& provider_name); + +std::unique_ptr> GetBrokenTestsKeyWordSet(const std::string& provider_name); diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index f165b3a4a647a..de5431ca4a460 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -783,10 +783,14 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); all_disabled_tests.insert(std::begin(x86_disabled_tests), std::end(x86_disabled_tests)); #endif + auto broken_tests = GetBrokenTests(provider_name); + auto broken_tests_keyword_set = GetBrokenTestsKeyWordSet(provider_name); std::vector tests; LoadTests(data_dirs, whitelisted_test_cases, LoadTestTolerances(enable_cuda, enable_openvino, override_tolerance, atol, rtol), all_disabled_tests, + std::move(broken_tests), + std::move(broken_tests_keyword_set), [&owned_tests, &tests](std::unique_ptr l) { tests.push_back(l.get()); owned_tests.push_back(std::move(l)); @@ -803,18 +807,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. 
\n)"); fwrite(res.c_str(), 1, res.size(), stdout); } - auto broken_tests = GetBrokenTests(provider_name); int result = 0; for (const auto& p : stat.GetFailedTest()) { - BrokenTest t = {p.first, ""}; - auto iter = broken_tests->find(t); - if (iter == broken_tests->end() || (p.second != TestModelInfo::unknown_version && !iter->broken_opset_versions_.empty() && - iter->broken_opset_versions_.find(p.second) == iter->broken_opset_versions_.end())) { - fprintf(stderr, "test %s failed, please fix it\n", p.first.c_str()); - result = -1; - } else { - fprintf(stderr, "test %s failed, but it is a known broken test, so we ignore it\n", p.first.c_str()); - } + fprintf(stderr, "test %s failed, please fix it\n", p.first.c_str()); + result = -1; } return result; } diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 57b2403e23a37..1111a92a385fd 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -662,9 +662,12 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #endif } else if (provider_name == onnxruntime::kDmlExecutionProvider) { #ifdef USE_DML - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(session_options, 0)); + std::unordered_map dml_options; + dml_options["performance_preference"] = "high_performance"; + dml_options["device_filter"] = "gpu"; + session_options.AppendExecutionProvider("DML", dml_options); #else - ORT_THROW("DirectML is not supported in this build\n"); + ORT_THROW("DML is not supported in this build\n"); #endif } else if (provider_name == onnxruntime::kAclExecutionProvider) { #ifdef USE_ACL diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index be8afa7636b3d..e024eafcd6572 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -447,8 +447,9 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Log_U16) { // Check that QNN compiles DQ -> Softmax -> Q as a single unit. // Test that the default axis (-1) for SoftMax opset 13 works. TEST_F(QnnHTPBackendTests, UnaryOp_Softmax13_DefaultAxis) { + const std::vector input_data = GetFloatDataInRange(-5.0f, 5.0f, 6); RunQDQOpTest("Softmax", - {TestInputDef({1, 2, 3}, false, -5.0f, 5.0f)}, + {TestInputDef({1, 2, 3}, false, input_data)}, {}, // Uses default axis of -1 for opset 13 13, ExpectedEPNodeAssignment::All); @@ -466,14 +467,43 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Softmax13_U16_DefaultAxis) { true); // Use com.microsoft domain for Q/DQ ops } -// Check that QNN compiles DQ -> Softmax -> Q as a single unit. -// Test that an axis != -1 is not supported. -TEST_F(QnnHTPBackendTests, UnaryOp_Softmax13_UnsupportedAxis) { +// Test that 8-bit QDQ Softmax (opset 13) with axis != -1 is supported by QNN EP. +// QNN EP will wrap the operator with transposes. +TEST_F(QnnHTPBackendTests, UnaryOp_Softmax13_NonLastAxis) { + const std::vector input_data = {0.0f, 1.0f, 2.0f, 10.0f, 11.0f, 12.0f, 100.0f, 110.0f, 120.0f, + 1.0856307f, 0.99734545f, 0.2829785f, 1.5062947f, 0.5786002f, 1.6514366f, + 2.4266791f, 0.42891264f, 1.2659363f}; RunQDQOpTest("Softmax", - {TestInputDef({1, 2, 3}, false, -5.0f, 5.0f)}, + {TestInputDef({1, 2, 3, 3}, false, input_data)}, {utils::MakeAttribute("axis", static_cast(1))}, 13, - ExpectedEPNodeAssignment::None); + ExpectedEPNodeAssignment::All); +} + +// Test that 8-bit QDQ Softmax (opset 13) with axis != -1 is supported by QNN EP. 
+// QNN EP will wrap the operator with transposes.
+// This is a configuration used in one of our partner's models.
+TEST_F(QnnHTPBackendTests, UnaryOp_Softmax13_NonLastAxis_LargeInput) {
+  const std::vector<float> input_data = GetFloatDataInRange(-50.0f, 50.0f, 124);
+  RunQDQOpTest("Softmax",
+               {TestInputDef<float>({1, 124, 1}, false, input_data)},
+               {utils::MakeAttribute("axis", static_cast<int64_t>(1))},
+               13,
+               ExpectedEPNodeAssignment::All);
+}
+
+// Test that 16-bit QDQ Softmax (opset 13) with axis != -1 is supported by QNN EP.
+// QNN EP will wrap the operator with transposes.
+// This is a configuration used in one of our partner's models.
+TEST_F(QnnHTPBackendTests, UnaryOp_Softmax13_U16_NonLastAxis_LargeInput) {
+  const std::vector<float> input_data = GetFloatDataInRange(-50.0f, 50.0f, 124);
+  RunQDQOpTest("Softmax",
+               {TestInputDef<float>({1, 124, 1}, false, input_data)},
+               {utils::MakeAttribute("axis", static_cast<int64_t>(1))},
+               13,
+               ExpectedEPNodeAssignment::All,
+               kOnnxDomain,
+               true);
 }
 
 // Check that QNN compiles DQ -> Softmax -> Q as a single unit.
@@ -507,15 +537,15 @@ TEST_F(QnnHTPBackendTests, UnaryOp_LogSoftmax13_DefaultAxis) {
                ExpectedEPNodeAssignment::All);
 }
 
-// Check that QNN compiles DQ -> LogSoftmax -> Q as a single unit.
-// Test that an axis != -1 is not supported.
-TEST_F(QnnHTPBackendTests, UnaryOp_LogSoftmax13_UnsupportedAxis) {
+// Test that 8-bit QDQ LogSoftmax (opset 13) with axis != -1 is supported by QNN EP.
+// QNN EP will wrap the operator with transposes.
+TEST_F(QnnHTPBackendTests, UnaryOp_LogSoftmax13_NonLastAxis) {
   std::vector<float> input_data = GetFloatDataInRange(-5.0f, 5.0f, 6);
   RunQDQOpTest("LogSoftmax",
                {TestInputDef<float>({1, 2, 3}, false, input_data)},
                {utils::MakeAttribute("axis", static_cast<int64_t>(1))},
                13,
-               ExpectedEPNodeAssignment::None);
+               ExpectedEPNodeAssignment::All);
 }
 
 // Check that QNN compiles DQ -> LogSoftmax -> Q as a single unit.
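The tests above now expect Softmax/LogSoftmax with axis != -1 to be fully assigned to the QNN EP, which handles the non-last axis by wrapping the operator with transposes, the same idea the WebNN softmax builder implements earlier in this diff with std::rotate. A small numpy sketch of that decomposition, independent of either EP, just to show the permutations involved:

```python
import numpy as np

def softmax(x, axis):
    # Numerically stable softmax along the requested axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def softmax_via_transposes(x, axis):
    rank = x.ndim
    # Rotate `axis` to the last position (the pre-transpose).
    perm = [i for i in range(rank) if i != axis] + [axis]
    # Inverse permutation: rotate the last position back to `axis` (the post-transpose).
    inv_perm = list(range(rank))
    inv_perm.insert(axis, inv_perm.pop())
    y = softmax(np.transpose(x, perm), axis=-1)
    return np.transpose(y, inv_perm)

x = np.random.rand(1, 2, 3, 3).astype(np.float32)
assert np.allclose(softmax_via_transposes(x, axis=1), softmax(x, axis=1), atol=1e-6)
```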
diff --git a/onnxruntime/test/python/onnxruntime_test_distributed.py b/onnxruntime/test/python/onnxruntime_test_distributed.py index 7f3cbc254969e..1baec80cb7c45 100644 --- a/onnxruntime/test/python/onnxruntime_test_distributed.py +++ b/onnxruntime/test/python/onnxruntime_test_distributed.py @@ -6,7 +6,7 @@ import numpy as np import onnxscript from mpi4py import MPI -from onnxscript import FLOAT +from onnxscript import FLOAT, INT64 import onnxruntime as ort @@ -18,7 +18,7 @@ def shard_tensor(X, rank, axis, num_shards): return np.split(X, num_shards, axis)[rank] -class TestDistributedMatMul(unittest.TestCase): +class TestDistributed(unittest.TestCase): def test_matmul_rs_sr_rr(self): @onnxscript.script() def matmul_rs_sr_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: @@ -312,6 +312,99 @@ def matmul_rr_sr_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: expected = np.matmul(tensor_x, tensor_w) np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) + def test_slice_sr_axis1(self): + @onnxscript.script() + def slice_sr_axis1(tensor_x: FLOAT, tensor_starts: INT64, tensor_ends: INT64, tensor_axes: INT64) -> FLOAT: + return MICROSOFT_OPSET.DistributedSlice( + tensor_x, + tensor_starts, + tensor_ends, + tensor_axes, + device_mesh_shape=[2], + device_mesh_elements=[0, 1], + input_shard_specs=["S[0]R", "R", "R", "R", "R"], + output_shard_specs=["S[0]R"], + ) + + rank = comm.Get_rank() + # Shape [2, 4] + tensor_x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32) + tensor_starts = np.array([0], dtype=np.int64) + tensor_ends = np.array([2], dtype=np.int64) + tensor_axes = np.array([1], dtype=np.int64) + + onnx_model = slice_sr_axis1.to_model_proto( + input_types=[FLOAT[1, 4], INT64[1], INT64[1], INT64[1]], + output_types=[FLOAT[1, 2]], + ) + + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + tensor_shard_x = shard_tensor(tensor_x, rank=rank, axis=0, num_shards=2) + + result = sess.run( + None, + { + "tensor_x": tensor_shard_x, + "tensor_starts": tensor_starts, + "tensor_ends": tensor_ends, + "tensor_axes": tensor_axes, + }, + ) + + expected = tensor_shard_x[:, 0:2] + np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) + + def test_slice_rs_axis1(self): + @onnxscript.script() + def slice_sr_axis1(tensor_x: FLOAT, tensor_starts: INT64, tensor_ends: INT64, tensor_axes: INT64) -> FLOAT: + return MICROSOFT_OPSET.DistributedSlice( + tensor_x, + tensor_starts, + tensor_ends, + tensor_axes, + device_mesh_shape=[2], + device_mesh_elements=[0, 1], + input_shard_specs=["RS[0]", "R", "R", "R", "R"], + output_shard_specs=["RS[0]"], + ) + + rank = comm.Get_rank() + # Shape [2, 4] + tensor_x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32) + tensor_starts = np.array([0], dtype=np.int64) + tensor_ends = np.array([2], dtype=np.int64) + tensor_axes = np.array([1], dtype=np.int64) + + onnx_model = slice_sr_axis1.to_model_proto( + input_types=[FLOAT[2, 2], INT64[1], INT64[1], INT64[1]], + output_types=[FLOAT[2, 1]], + ) + + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + tensor_shard_x = shard_tensor(tensor_x, rank=rank, axis=1, num_shards=2) + result = sess.run( + None, + { + "tensor_x": tensor_shard_x, + "tensor_starts": tensor_starts, + "tensor_ends": tensor_ends, + "tensor_axes": tensor_axes, + }, + ) + + expected = tensor_x[:, 0:2][:, rank : rank + 1] + 
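The expected values asserted in these DistributedSlice tests follow directly from the shard specs: the Slice is reasoned about on the full-tensor view, and the output spec then determines which shard of the sliced result each rank should hold. The following numpy restatement only re-derives the values checked by test_slice_rs_axis1; it makes no claim about how the kernel itself computes or communicates them.

```python
import numpy as np

def shard_tensor(X, rank, axis, num_shards):
    # Same helper the tests above use.
    return np.split(X, num_shards, axis)[rank]

# Full input of test_slice_rs_axis1: shape [2, 4], column-sharded ("RS[0]") over a 2-device mesh.
tensor_x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32)
full_slice = tensor_x[:, 0:2]  # Slice(starts=[0], ends=[2], axes=[1]) applied to the full tensor

for rank in range(2):
    # What each rank feeds into its session: the local [2, 2] column shard.
    local_input = shard_tensor(tensor_x, rank=rank, axis=1, num_shards=2)
    # Output spec "RS[0]": each rank holds its column shard of the sliced result,
    # which is exactly the `expected = tensor_x[:, 0:2][:, rank : rank + 1]` asserted above.
    local_output = shard_tensor(full_slice, rank=rank, axis=1, num_shards=2)
    assert np.array_equal(local_output, full_slice[:, rank : rank + 1])
```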
np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) + if __name__ == "__main__": unittest.main() diff --git a/orttraining/orttraining/python/training/ort_triton/_codegen.py b/orttraining/orttraining/python/training/ort_triton/_codegen.py index 8e21013da2353..0bf402b750115 100644 --- a/orttraining/orttraining/python/training/ort_triton/_codegen.py +++ b/orttraining/orttraining/python/training/ort_triton/_codegen.py @@ -280,7 +280,7 @@ def ReduceKernelNode( # noqa: N802 "Where": "{indent}{o0} = tl.where({i0}, {i1}, {i2})\n", "Sigmoid": "{indent}{o0} = tl.sigmoid({i0})\n", "Log": "{indent}{o0} = tl.log({i0})\n", - "DropoutGrad": "{indent}p = 1 - {i2}\n{indent}{o0} = tl.where({i1}, {i0} / p, 0.0)\n", + "DropoutGrad": "{indent}p = 1.0 - {i2}\n{indent}{o0} = tl.where({i1}, {i0} / p, 0.0)\n", "Identity": "{indent}{o0} = {i0}\n", } @@ -420,7 +420,7 @@ def DropoutNode( # noqa: N802 offset_str = f"{node.global_offset} + " if node.global_offset != sympy.Integer(0) else "" offset_str += self._get_offset_mask(node.offset_calc, node.inputs[0].name)[0] code_buffer += ( - f"{space_indent}p = 1 - {p_var_name}\n" + f"{space_indent}p = 1.0 - {p_var_name}\n" f"{space_indent}random = tl.rand(t_seed_cuda, {offset_str})\n" f"{space_indent}{mask_var_name} = random < p\n" f"{space_indent}{output_var_name} = tl.where({mask_var_name}, {input_var_name} / p, 0.0)\n" diff --git a/tools/android_custom_build/Dockerfile b/tools/android_custom_build/Dockerfile index bc50e4fb0a943..66b6a36e5a8c0 100644 --- a/tools/android_custom_build/Dockerfile +++ b/tools/android_custom_build/Dockerfile @@ -55,7 +55,7 @@ WORKDIR /workspace # install Android SDK and tools ENV ANDROID_HOME=~/android-sdk -ENV NDK_VERSION=25.0.8775105 +ENV NDK_VERSION=26.0.10792818 ENV ANDROID_NDK_HOME=${ANDROID_HOME}/ndk/${NDK_VERSION} RUN aria2c -q -d /tmp -o cmdline-tools.zip \ diff --git a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml index 2e7ac9508a41e..588b5d049ee3c 100644 --- a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml @@ -9,11 +9,6 @@ parameters: - 'custom' default: 'nightly (@dev)' -- name: NodePipelineId - displayName: 'Node npm package build Id' - type: string - default: 'latest' - variables: # pipeline should define the following varaibles # ExtraBuildArgs @@ -29,6 +24,11 @@ variables: NpmPackagingMode: '$(VersionSuffix)' resources: + pipelines: + - pipeline: build + source: 'Zip-Nuget-Java-Nodejs Packaging Pipeline' + trigger: true + branch: main repositories: - repository: manylinux type: Github @@ -65,35 +65,13 @@ stages: runCodesignValidationInjection: false timeoutInMinutes: 10 steps: - - - ${{ if eq(parameters.NodePipelineId, 'latest') }}: - - task: DownloadPipelineArtifact@2 - inputs: - buildType: 'specific' - project: '530acbc4-21bc-487d-8cd8-348ff451d2ff' - definition: '940' - specificBuildWithTriggering: true - buildVersionToDownload: 'latestFromBranch' - branchName: 'refs/heads/main' - artifactName: 'NPM_packages' - targetPath: '$(Pipeline.Workspace)' - displayName: 'Download onnxruntime-node Pipeline Artifact' - - - ${{ if ne(parameters.NodePipelineId, 'latest') }}: - - task: DownloadPipelineArtifact@2 - inputs: - buildType: 'specific' - project: '530acbc4-21bc-487d-8cd8-348ff451d2ff' - definition: '940' - buildVersionToDownload: 'specific' - pipelineId: '${{ parameters.NodePipelineId }}' - artifactName: 'NPM_packages' - targetPath: 
'$(Pipeline.Workspace)' - displayName: 'Download onnxruntime-node Pipeline Artifact' + - download: build + artifact: 'NPM_packages' + displayName: 'Download onnxruntime-node Pipeline Artifact' - task: CopyFiles@2 inputs: - sourceFolder: $(Pipeline.Workspace) + sourceFolder: '$(Pipeline.Workspace)\build\NPM_packages' contents: onnxruntime-*.tgz targetFolder: $(Build.ArtifactStagingDirectory)\node-artifacts displayName: 'Copy onnxruntime-node Artifacts' diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml index b1e36e63e86ab..81e8d67b79021 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml @@ -231,6 +231,15 @@ stages: searchPattern: '**/*.pdb' symbolServerType: teamServices + - ${{ if eq(parameters['DoCompliance'], 'true') }}: + - template: ../../templates/compliance.yml + parameters : + msbuildPlatform: ${{ parameters.sln_platform }} + + - template: ../../templates/component-governance-component-detection-steps.yml + parameters : + condition : 'succeeded' + # Node.js Publish - ${{ if eq(parameters['DoNodejsPack'], 'true') }}: - task: BatchScript@1 @@ -285,15 +294,6 @@ stages: targetPath: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\${{ parameters.sln_platform }}' artifactName: 'drop-onnxruntime-nodejs-win-${{ parameters.sln_platform }}-dml' - - ${{ if eq(parameters['DoCompliance'], 'true') }}: - - template: ../../templates/compliance.yml - parameters : - msbuildPlatform: ${{ parameters.sln_platform }} - - - template: ../../templates/component-governance-component-detection-steps.yml - parameters : - condition : 'succeeded' - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' condition: always() diff --git a/tools/ci_build/github/azure-pipelines/templates/use-android-ndk.yml b/tools/ci_build/github/azure-pipelines/templates/use-android-ndk.yml index 0e034dff9d0b2..8cc7f63a193cc 100644 --- a/tools/ci_build/github/azure-pipelines/templates/use-android-ndk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/use-android-ndk.yml @@ -3,7 +3,7 @@ parameters: - name: AndroidNdkVersion type: string - default: "25.0.8775105" # LTS version + default: "26.0.10792818" # LTS version steps: - bash: | diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index 788b02f539821..187c7656602f5 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -171,6 +171,10 @@ jobs: workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (Suite1, webgpu, IO-binding=gpu-tensor)' condition: eq('${{ parameters.RunWebGpuTests }}', 'true') + # temporarily allow this test to fail, so that people are not blocked. + # investigation is ongoing for the root cause of the random failure (Edge crash). + # TODO: remove this line once the root cause is found and fixed. 
+ continueOnError: true - script: | npm test -- suite1 -e=edge -b=webgpu --io-binding=gpu-location $(webgpuCommandlineExtraFlags) workingDirectory: '$(Build.SourcesDirectory)\js\web' diff --git a/winml/test/model/model_tests.cpp b/winml/test/model/model_tests.cpp index 5057f74046638..cb5cbbecb5ef0 100644 --- a/winml/test/model/model_tests.cpp +++ b/winml/test/model/model_tests.cpp @@ -238,11 +238,16 @@ static std::vector GetAllTestCases() { // Bad onnx test output caused by previously wrong SAME_UPPER/SAME_LOWER for ConvTranspose allDisabledTests.insert(ORT_TSTR("cntk_simple_seg")); + auto broken_tests = GetBrokenTests("dml"); + auto broken_tests_keyword_set = GetBrokenTestsKeyWordSet("dml"); + WINML_EXPECT_NO_THROW(LoadTests( dataDirs, whitelistedTestCases, TestTolerances(1e-3, 1e-3, {}, {}), allDisabledTests, + std::move(broken_tests), + std::move(broken_tests_keyword_set), [&tests](std::unique_ptr l) { tests.push_back(l.get()); ownedTests.push_back(std::move(l));
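With the winml change above, model_tests.cpp hands the same broken-test containers into LoadTests that onnx_test_runner now uses, so known-broken models are skipped while test cases are enumerated instead of being excused after they fail. A condensed Python restatement of the skip rule added in TestCase.cc; the dict/set shapes and names here are illustrative stand-ins, not the actual C++ types:

```python
def should_skip(test_name: str, opset: str,
                broken_tests: dict[str, set[str]],
                broken_keywords: set[str],
                unknown_version: str = "unknown") -> bool:
    """broken_tests maps test name -> broken opset versions; an empty set means broken for every opset."""
    versions = broken_tests.get(test_name)
    if versions is not None and (opset == unknown_version or not versions or opset in versions):
        return True
    # A test is also skipped if its name contains any broken-test keyword.
    return any(keyword in test_name for keyword in broken_keywords)

# The DML entry added in this diff is skipped regardless of opset version:
assert should_skip("mlperf_ssd_resnet34_1200", "12", {"mlperf_ssd_resnet34_1200": set()}, set())
```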