From 9f68a27c7a3a932a574d50db19f40393a0cedf81 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Tue, 30 Jan 2024 17:04:01 +0800 Subject: [PATCH 01/11] [ORTModule] Handle Cast on Constant Number on Triton Code-gen (#19321) When using scaled_dot_product_attention on float16 type, the exported graph has Sqrt(float16(constant)), which cannot be ConstantFold in ORT because Sqrt CPU kernel doesn't support float16. This causes Triton code-gen generates code like: result = 128.0.to(tl.float32) This code cannot be compiled because .to() cannot be applied to constant. This PR is to handle such case that constant number will not do the Cast. --- .../python/training/ort_triton/_codegen.py | 4 +-- .../python/training/ort_triton/_utils.py | 8 ++++++ .../orttraining_test_ortmodule_triton.py | 27 +++++++++++++++++++ 3 files changed, 37 insertions(+), 2 deletions(-) diff --git a/orttraining/orttraining/python/training/ort_triton/_codegen.py b/orttraining/orttraining/python/training/ort_triton/_codegen.py index e0f65ed272d38..9c7214f467af1 100644 --- a/orttraining/orttraining/python/training/ort_triton/_codegen.py +++ b/orttraining/orttraining/python/training/ort_triton/_codegen.py @@ -37,7 +37,7 @@ from ._lowering import lower from ._sorted_graph import SortedGraph from ._sympy_utils import parse_shape, sympy_dot -from ._utils import may_add_brackets +from ._utils import is_number, may_add_brackets class TritonCodegen(NodeVisitor): @@ -318,7 +318,7 @@ def ComputeNode( # noqa: N802 if op_type == "Cast": from_dtype = node.inputs[0].dtype.type to_dtype = node.outputs[0].dtype.type - if from_dtype == to_dtype: + if from_dtype == to_dtype or is_number(kwargs["i0"]): op_type = "Identity" elif to_dtype == np.bool_: op_type = "CastBool" diff --git a/orttraining/orttraining/python/training/ort_triton/_utils.py b/orttraining/orttraining/python/training/ort_triton/_utils.py index c80e28f6f73df..95e6703be8783 100644 --- a/orttraining/orttraining/python/training/ort_triton/_utils.py +++ b/orttraining/orttraining/python/training/ort_triton/_utils.py @@ -150,3 +150,11 @@ def next_power_of_2(n: int) -> int: n |= n >> 16 n += 1 return n + + +def is_number(name: str) -> bool: + try: + float(name) + return True + except ValueError: + return name.startswith("float(") and name.endswith(")") diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py index 0c381d70ca4c1..922f5c696500d 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py @@ -12,6 +12,7 @@ import pytest import torch from onnx import TensorProto, helper +from packaging.version import Version from torch._C import _from_dlpack from torch.utils.dlpack import to_dlpack @@ -842,6 +843,32 @@ def _gen_inputs(dtype): _run_module_test(NeuralNetSliceScel, dtype, _gen_inputs, 2) +@pytest.mark.skipif( + Version(torch.__version__) < Version("2.1"), reason="PyTorch has scaled_dot_product_attention since 2.1." 
+) +def test_scaled_dot_product_attention_module(): + class NeuralNetScaledDotProductAttention(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(64, 64, bias=False, dtype=torch.float16) + self.linear2 = torch.nn.Linear(64, 64, bias=False, dtype=torch.float16) + self.linear3 = torch.nn.Linear(64, 64, bias=False, dtype=torch.float16) + + def forward(self, q, k, v): + return torch.nn.functional.scaled_dot_product_attention( + self.linear1(q), self.linear2(k), self.linear3(v) + ).to(torch.float16) + + def _gen_inputs(dtype): + return [ + (torch.rand(32, 8, 128, 64) * 0.01).to(dtype=torch.float16, device=DEVICE), + (torch.rand(32, 8, 128, 64) * 0.01).to(dtype=torch.float16, device=DEVICE), + (torch.rand(32, 8, 128, 64) * 0.01).to(dtype=torch.float16, device=DEVICE), + ] + + _run_module_test(NeuralNetScaledDotProductAttention, torch.float16, _gen_inputs, 3) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) @pytest.mark.parametrize("input_shapes", [([128, 64], [64, 64]), ([16, 64, 128], [16, 128, 64])]) def test_matmul_tunable_op(dtype, input_shapes): From a92802f9403e3ca7313e7d29f663038669bffc57 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Tue, 30 Jan 2024 08:16:57 -0800 Subject: [PATCH 02/11] Disable a few tests for wasm build (#19316) --- cmake/onnxruntime_unittests.cmake | 5 ++++- onnxruntime/test/unittest_main/test_main.cc | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 0987d6d164dbd..351ea1a95581b 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -824,6 +824,9 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") "${TEST_SRC_DIR}/providers/memcpy_test.cc" ) endif() + list(REMOVE_ITEM all_tests "${TEST_SRC_DIR}/providers/cpu/reduction/reduction_ops_test.cc" + "${TEST_SRC_DIR}/providers/cpu/tensor/grid_sample_test.cc" + "${TEST_SRC_DIR}/providers/cpu/math/einsum_test.cc") endif() set(test_all_args) @@ -906,7 +909,7 @@ if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set_target_properties(onnxruntime_test_all PROPERTIES LINK_DEPENDS ${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js) - set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s STACK_SIZE=5242880 -s ALLOW_MEMORY_GROWTH=1 -s MAXIMUM_MEMORY=4294967296 -s INCOMING_MODULE_JS_API=[preRun,locateFile,arguments,onExit,wasmMemory,buffer,instantiateWasm] --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1 -s DEMANGLE_SUPPORT=1") + set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s STACK_SIZE=5242880 -s INITIAL_MEMORY=536870912 -s ALLOW_MEMORY_GROWTH=1 -s MAXIMUM_MEMORY=4294967296 -s INCOMING_MODULE_JS_API=[preRun,locateFile,arguments,onExit,wasmMemory,buffer,instantiateWasm] --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1 -s DEMANGLE_SUPPORT=1") if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) set_property(TARGET onnxruntime_test_all APPEND_STRING PROPERTY LINK_FLAGS " -s DEFAULT_PTHREAD_STACK_SIZE=131072 -s PROXY_TO_PTHREAD=1") endif() diff --git a/onnxruntime/test/unittest_main/test_main.cc b/onnxruntime/test/unittest_main/test_main.cc index 97169df36fdd7..4c38c90c2b418 100644 --- 
a/onnxruntime/test/unittest_main/test_main.cc +++ b/onnxruntime/test/unittest_main/test_main.cc @@ -59,8 +59,8 @@ int TEST_MAIN(int argc, char** argv) { int status = 0; ORT_TRY { - ::testing::InitGoogleTest(&argc, argv); ortenv_setup(); + ::testing::InitGoogleTest(&argc, argv); // allow verbose logging to be enabled by setting this environment variable to a numeric log level constexpr auto kLogLevelEnvironmentVariableName = "ORT_UNIT_TEST_MAIN_LOG_LEVEL"; From 3e17ca3dabd76d370827ef119f092be1b85422ea Mon Sep 17 00:00:00 2001 From: Rachel Guo <35738743+YUNQIUGUO@users.noreply.github.com> Date: Tue, 30 Jan 2024 08:44:20 -0800 Subject: [PATCH 03/11] Fix iOS artifacts issue in Microsoft.ML.OnnxRuntime Nuget Package (#19311) ### Description Updates to only include ios archs framework in artifacts included in Nuget Package. ### Motivation and Context Related issue: https://github.com/microsoft/onnxruntime/issues/19295#issuecomment-1914143256 --------- Co-authored-by: rachguo Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- .../apple/apple_package_test/Podfile.template | 6 +++- ...ult_full_ios_framework_build_settings.json | 30 +++++++++++++++++++ .../github/apple/test_apple_packages.py | 13 ++++++-- .../azure-pipelines/templates/c-api-cpu.yml | 23 +++++++------- 4 files changed, 58 insertions(+), 14 deletions(-) create mode 100644 tools/ci_build/github/apple/default_full_ios_framework_build_settings.json diff --git a/onnxruntime/test/platform/apple/apple_package_test/Podfile.template b/onnxruntime/test/platform/apple/apple_package_test/Podfile.template index 3d191d6fb1cc6..4958e4fa85490 100644 --- a/onnxruntime/test/platform/apple/apple_package_test/Podfile.template +++ b/onnxruntime/test/platform/apple/apple_package_test/Podfile.template @@ -1,6 +1,10 @@ def include_macos_target if '@C_POD_NAME@' != 'onnxruntime-mobile-c' - return true + if ENV['SKIP_MACOS_TEST'] != 'true' + return true + else + return false + end end return false end diff --git a/tools/ci_build/github/apple/default_full_ios_framework_build_settings.json b/tools/ci_build/github/apple/default_full_ios_framework_build_settings.json new file mode 100644 index 0000000000000..445bfca9889ff --- /dev/null +++ b/tools/ci_build/github/apple/default_full_ios_framework_build_settings.json @@ -0,0 +1,30 @@ +{ + "build_osx_archs": { + "iphoneos": [ + "arm64" + ], + "iphonesimulator": [ + "arm64", + "x86_64" + ] + }, + "build_params": { + "base": [ + "--parallel", + "--use_xcode", + "--build_apple_framework", + "--use_coreml", + "--use_xnnpack", + "--skip_tests", + "--cmake_extra_defines=onnxruntime_BUILD_UNIT_TESTS=OFF" + ], + "iphoneos": [ + "--ios", + "--apple_deploy_target=12.0" + ], + "iphonesimulator": [ + "--ios", + "--apple_deploy_target=12.0" + ] + } +} diff --git a/tools/ci_build/github/apple/test_apple_packages.py b/tools/ci_build/github/apple/test_apple_packages.py index 6dc4868dac8a3..cd360a63a3a0f 100644 --- a/tools/ci_build/github/apple/test_apple_packages.py +++ b/tools/ci_build/github/apple/test_apple_packages.py @@ -112,7 +112,10 @@ def _test_apple_packages(args): subprocess.run(["pod", "cache", "clean", "--all"], shell=False, check=True, cwd=target_proj_path) # install pods - subprocess.run(["pod", "install"], shell=False, check=True, cwd=target_proj_path) + # set env to skip macos test targets accordingly + env = os.environ.copy() + env["SKIP_MACOS_TEST"] = "true" if args.skip_macos_test else "false" + subprocess.run(["pod", "install"], shell=False, check=True, cwd=target_proj_path, env=env) # 
run the tests if not args.prepare_test_project_only: @@ -144,7 +147,7 @@ def _test_apple_packages(args): cwd=target_proj_path, ) - if PackageVariant[args.variant] != PackageVariant.Mobile: + if PackageVariant[args.variant] != PackageVariant.Mobile and not args.skip_macos_test: subprocess.run( [ "xcrun", @@ -206,6 +209,12 @@ def parse_args(): help="Prepare the test project only, without running the tests", ) + parser.add_argument( + "--skip_macos_test", + action="store_true", + help="Skip macos platform tests. Specify this argument when build targets only contain ios archs. ", + ) + return parser.parse_args() diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 168602a17910b..8bdb395c00dc3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -119,31 +119,32 @@ stages: - script: | set -e -x python3 tools/ci_build/github/apple/build_apple_framework.py \ - --build_dir "$(Build.BinariesDirectory)/apple_framework" \ + --build_dir "$(Build.BinariesDirectory)/ios_framework" \ --path_to_protoc_exe $(Build.BinariesDirectory)/protobuf_install/bin/protoc \ - tools/ci_build/github/apple/default_full_apple_framework_build_settings.json + tools/ci_build/github/apple/default_full_ios_framework_build_settings.json mkdir $(Build.BinariesDirectory)/artifacts - mkdir -p $(Build.BinariesDirectory)/artifacts_staging/onnxruntime-apple-xcframework-$(OnnxRuntimeVersion) - cp -R $(Build.BinariesDirectory)/apple_framework/framework_out/onnxruntime.xcframework \ - $(Build.BinariesDirectory)/artifacts_staging/onnxruntime-apple-xcframework-$(OnnxRuntimeVersion) + mkdir -p $(Build.BinariesDirectory)/artifacts_staging/onnxruntime-ios-xcframework-$(OnnxRuntimeVersion) + cp -R $(Build.BinariesDirectory)/ios_framework/framework_out/onnxruntime.xcframework \ + $(Build.BinariesDirectory)/artifacts_staging/onnxruntime-ios-xcframework-$(OnnxRuntimeVersion) pushd $(Build.BinariesDirectory)/artifacts_staging zip -vr $(Build.BinariesDirectory)/artifacts/onnxruntime_xcframework.zip \ - onnxruntime-apple-xcframework-$(OnnxRuntimeVersion) + onnxruntime-ios-xcframework-$(OnnxRuntimeVersion) popd displayName: "Build Apple xcframework" - script: | python3 tools/ci_build/github/apple/test_apple_packages.py \ --fail_if_cocoapods_missing \ - --framework_info_file "$(Build.BinariesDirectory)/apple_framework/xcframework_info.json" \ - --c_framework_dir "$(Build.BinariesDirectory)/apple_framework/framework_out" \ - --variant Full + --framework_info_file "$(Build.BinariesDirectory)/ios_framework/xcframework_info.json" \ + --c_framework_dir "$(Build.BinariesDirectory)/ios_framework/framework_out" \ + --variant Full \ + --skip_macos_test displayName: "Test Apple framework" - task: PublishBuildArtifacts@1 inputs: pathtoPublish: '$(Build.BinariesDirectory)/artifacts' - artifactName: 'onnxruntime-apple-full-xcframework' + artifactName: 'onnxruntime-ios-full-xcframework' - template: component-governance-component-detection-steps.yml parameters: @@ -350,7 +351,7 @@ stages: - template: flex-downloadPipelineArtifact.yml parameters: StepName: 'Download iOS Pipeline Artifact' - ArtifactName: 'onnxruntime-apple-full-xcframework' + ArtifactName: 'onnxruntime-ios-full-xcframework' TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' SpecificArtifact: ${{ parameters.specificArtifact }} BuildId: ${{ parameters.BuildId }} From ffc3431a660ba2fe3fb220be24f0ff3260d828bd Mon Sep 
17 00:00:00 2001
From: Wei-Sheng Chin
Date: Tue, 30 Jan 2024 09:18:50 -0800
Subject: [PATCH 04/11] Update ScatterElements to Support Opset 13, 15, 18 (#19198)

`ScatterElements` in opset 18 has been around for a while. However, the highest opset supporting `ScatterElements` in ORT is 13. This PR implements this op in the CUDA EP by replacing `assignment` in the current CUDA kernel with `atomic reduction` (e.g., atomic add, atomic max). A series of fundamental atomic functions (e.g., atomic max for int8_t and half) is implemented in `common.cuh`; the implementation is general enough to cover both old and new CUDA versions.

- The core changes are in `cuda/atomic/common.cuh`, with very detailed documentation including visualizations of the bit-wise operations. They are also copied to `rocm/atomic/common.cuh` to support AMD GPUs.
- `/cuda/tensor/gather_elements_impl.cu` contains small changes that call the new atomic functions to support the new `reduction` behavior of the new `ScatterElements`.
- The new `ScatterElements` versions are defined in `rocm_execution_provider.cc` and `cuda_execution_provider.cc`.
---
 docs/OperatorKernels.md | 4 +-
 .../core/providers/cpu/tensor/scatter.cc | 14 -
 .../core/providers/cuda/atomic/common.cuh | 311 ++++++++++++++++++
 .../providers/cuda/cuda_execution_provider.cc | 8 +-
 .../cuda/tensor/gather_elements_impl.cu | 52 ++-
 .../cuda/tensor/gather_elements_impl.h | 11 +
 .../providers/cuda/tensor/scatter_elements.cc | 32 +-
 .../providers/cuda/tensor/scatter_elements.h | 10 +
 .../core/providers/rocm/atomic/common.cuh | 299 +++++++++++++++++
 .../providers/rocm/rocm_execution_provider.cc | 10 +-
 .../providers/cpu/tensor/scatter_op_test.cc | 132 ++++++++
 11 files changed, 858 insertions(+), 25 deletions(-)

diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 9d9b266355335..2ea557b7d61fe 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -744,7 +744,9 @@ Do not modify directly.*
 |||[9, 10]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |||8|**I** = tensor(int64)<br>
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |Scatter|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **Tind** = tensor(int32), tensor(int64)|
-|ScatterElements|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **Tind** = tensor(int32), tensor(int64)|
+|ScatterElements|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **Tind** = tensor(int32), tensor(int64)|
+|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **Tind** = tensor(int32), tensor(int64)|
+|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **Tind** = tensor(int32), tensor(int64)|
 |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **Tind** = tensor(int32), tensor(int64)|
 |ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br>
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/providers/cpu/tensor/scatter.cc b/onnxruntime/core/providers/cpu/tensor/scatter.cc index 8844b7e7a26c4..c7a2005924836 100644 --- a/onnxruntime/core/providers/cpu/tensor/scatter.cc +++ b/onnxruntime/core/providers/cpu/tensor/scatter.cc @@ -198,13 +198,6 @@ struct Func_Min { } }; -template <> -struct Func_Min { - void operator()(MLFloat16*, const MLFloat16*) const { - ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'."); - } -}; - template <> struct Func_Min { void operator()(BFloat16*, const BFloat16*) const { @@ -233,13 +226,6 @@ struct Func_Max { } }; -template <> -struct Func_Max { - void operator()(MLFloat16*, const MLFloat16*) const { - ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'."); - } -}; - template <> struct Func_Max { void operator()(BFloat16*, const BFloat16*) const { diff --git a/onnxruntime/core/providers/cuda/atomic/common.cuh b/onnxruntime/core/providers/cuda/atomic/common.cuh index 14fa2d0706f73..170aa3a2d8d0c 100644 --- a/onnxruntime/core/providers/cuda/atomic/common.cuh +++ b/onnxruntime/core/providers/cuda/atomic/common.cuh @@ -122,5 +122,316 @@ __device__ __forceinline__ void AtomicAdd(half* start_addr, size_t index, #endif } +// Disable default template instantiation. +// For every type T, we need to define a specialization +// to select the right type for calling atomicCAS. +template +class AtomicCasType; + +template<> +class AtomicCasType { + public: + using type = unsigned short int; + static const unsigned int mask = 0xffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned short int; + static const unsigned int mask = 0xffffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned int; + static const unsigned int mask = 0xffffffffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned long long int; + static const unsigned int mask = 0xffffffffu; +}; + +template<> +class AtomicCasType { + public: + using type = int; + static const unsigned int mask = 0xffffffffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned long long int; + static const unsigned int mask = 0xffffffffu; +}; + +// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh. +// +// This function compute 8-bit atomic binary operation using 32-bit atomicCAS. +// It accumulate `val` into the `address` using the `func`. +// The accumulation is atomic (i.e., thread-safe). +// +// E.g., Assume ValueType is +// int8_t +// and BinaryFunc is +// struct AddFunc { +// __device__ __forceinline__ int8_t operator()(int8_t a, int8_t b) const { +// return a + b; +// } +// This function becomes atomic_add for int8_t. +template +__device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* address, ValueType val, BinaryFunc func) { + // Assert to ensure the following bit-wise manipulation is correct. 
+ static_assert(sizeof(ValueType) == 1 | sizeof(ValueType) == 2 | sizeof(ValueType) == 4, + "ValueType must be 1-byte, 2-byte or 4-byte large."); + // Number of bytes to the lower 4-byte aligned address. + // If the current address is b1010"10", then offset = b10 = 2, + // which means the current address is 2 bytes away from + // the lower 4-byte aligned address b1010"00". + size_t offset = (size_t)address & 3; + // Find an new 4-byte aligned address `address_as_ui` lower than + // or equal to `address`. Lower than `address` so that the actual + // int8_t byte is in the 4-byte word that we load. + // + // This address has the following properties: + // 1. It is 4-byte aligned. + // 2. It is lower than or equal to `address`. + // 3. De-referencing this address may return + // a uint32_t value that contains the same int8_t + // value indicated by `address`. + // + // E.g., + // address = b101010 + // offset = b101010 & b000011 = b10 = 2 + // (char*)address - offset => (char*)b101010 - b000010 => b1010"00", + // which is (32-bit aligned). + uint32_t * address_as_ui = (uint32_t*)((char*)address - offset); + uint32_t old = *address_as_ui; + // E.g., offset = 2. + // address_as_ui is an address 2 bytes lower than `address`. + // + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... + // ^ ^ ^ + // | | | + // | address <--- offset * 8 (bit)-----> address_as_ui + // | ^ + // | | + // ------------------------- *address_as_ui ----------------------- + // + // This visualization shows + // 1. the 32-bit word at address_as_ui. + // 2. the gap between address_as_ui and address. + // 3. *address_as_ui contains the int8_t value at `address`. + uint32_t shift = offset * 8; + uint32_t old_byte; + uint32_t newval; + uint32_t assumed; + do { + assumed = old; + // Select 8-bit value from 32-bit word. Assume offset = 2 (byte), so + // we want to select the 3rd byte (byte 2 below) from the word. + // + // Journey of a 32-bit value: + // + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... + // + // | + // | old >> offset * 8, where offset = 2. + // | Effectively, push lower two bytes + // | out of the word. + // V + // + // 00000000 | 00000000 | ..... byte 3 ..... | ..... byte 2 ..... + // + // | apply bit-wise AND, + // | & 0xff (i.e., & b11111111), + // | so that we only keep + // | the byte of interest. + // | Otherwise, overflow may + // | happen when casting this + // | 32-bit value to int8_t. + // V + // + // 00000000 | 00000000 | 00000000 | ..... byte 2 ..... + old_byte = (old >> shift) & AtomicCasType::mask; + // Compute new int8_t value and store it to newrawvalue. + // Journey of a 32-bit value (cont'd): + // + // newrawvalue + // ... new byte 2 ... + auto newrawvalue = func(val, reinterpret_cast(old_byte)); + // Put the new int8_t value back to 32-bit word. + // Also ensure that bits not occupied by the int8_t value are 0s. + // + // Journey of a 32-bit value (cont'd): + // + // reinterpret_cast(newrawvalue) + // random values | random values | random values | ... new byte 2 ... + // + // reinterpret_cast(newrawvalue) & AtomicCasType::mask + // 00000000 | 00000000 | 00000000 | ... new byte 2 ... + newval = reinterpret_cast(newrawvalue) & AtomicCasType::mask; + // Journey of a 32-bit value (cont'd): + // + // old + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... 
+ // + // 0x000000ff + // 00000000 | 00000000 | 00000000 | 11111111 + // + // 0x000000ff << shift + // 00000000 | 11111111 | 00000000 | 00000000 + // + // ~(0x000000ff << shift) + // 11111111 | 00000000 | 11111111 | 11111111 + // + // old & ~(0x000000ff << shift) + // ..... byte 3 ..... | 00000000 | ..... byte 1 ..... | ..... byte 0 ..... + // + // newval << shift + // 00000000 | ... new byte 2 ... | 00000000 | 00000000 + // + // (old & ~(0x000000ff << shift)) | (newval << shift) + // ..... byte 3 ..... | ... new byte 2 ... | ..... byte 1 ..... | ..... byte 0 ..... + newval = (old & ~(AtomicCasType::mask << shift)) | (newval << shift); + old = atomicCAS(address_as_ui, assumed, newval); + } while (assumed != old); +} + +// It accumulates `val` into the `address` using the `func`. +// This function is thread-safe (i.e., atomic). +template +__device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) { + ValueType observed = *address, assumed, new_value; + using CasType = typename AtomicCasType::type; + static_assert(sizeof(ValueType) == sizeof(CasType), + "ValueType and CasType must have the same size for calling atomicCAS."); + auto address_as_cas_type = reinterpret_cast(address); + do { + // Record the value used to compute new value. + assumed = observed; + + // Compute expected new value. + new_value = func(observed, val); + + // Cast to aribitrary 2-byte type to desired integer type supported by atomicCAS. + // 4 + // 8 + auto observed_as_cas_type = *reinterpret_cast(&observed); + auto new_value_as_cas_type = *reinterpret_cast(&new_value); + + // Call atomicCAS as if the 2-byte type variables are all unsigned short int. + // 4 unsigned int (or int) + // 8 unsigned long long int + auto cas_observed_as_cas_type = atomicCAS(address_as_cas_type, observed_as_cas_type, new_value_as_cas_type); + + // Cast the freshly observed value in memory back to the TwoByteType. + observed = *reinterpret_cast(&cas_observed_as_cas_type); + + // Two cases: + // 1. compare-and-swap success + // a. `address` holds `new_value` + // b. `observed` becomes the new value after the assignment. + // Thus, the following `observed != new_value` is false, + // and the loop terminates. + // 2. compare-and-swap fails + // a. `address` holds a value different from `observed`, thus, + // the `new_value` is stale. + // b. `observed` becomes the fresh value observed in `address`. + // Thus, the following (observed != new_value) is true, + // and the loop continues. In the next iteration, the + // `new_value` is computed again using the fresh `observed`. + } while (observed != assumed); +} + +struct AddFunc { + template + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +struct MulFunc { + template + __device__ __forceinline__ T operator()(T a, T b) const { + return a * b; + } +}; + +struct MaxFunc { + template + __device__ __forceinline__ T operator()(T a, T b) const { + return b > a ? b : a; + } +}; + +struct MinFunc { + template + __device__ __forceinline__ T operator()(T a, T b) const { + return b < a ? 
b : a; + } +}; + +__device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) { + atomic_byte_func_with_unit32_cas(address, value, AddFunc()); +} +__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) { + atomic_byte_func_with_unit32_cas(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) { + atomic_byte_func_with_unit32_cas(address, value, MaxFunc()); +} +__device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) { + atomic_byte_func_with_unit32_cas(address, value, MinFunc()); +} + +__device__ __forceinline__ void atomic_mul(half* address, half value) { +#if __CUDA_ARCH__ >= 700 + atomic_binary_func(address, value, MulFunc()); +#else + atomic_byte_func_with_unit32_cas(address, value, MulFunc()); +#endif +} +__device__ __forceinline__ void atomic_max(half* address, half value) { +#if __CUDA_ARCH__ >= 700 + atomic_binary_func(address, value, MaxFunc()); +#else + atomic_byte_func_with_unit32_cas(address, value, MaxFunc()); +#endif +} +__device__ __forceinline__ void atomic_min(half* address, half value) { +#if __CUDA_ARCH__ >= 700 + atomic_binary_func(address, value, MinFunc()); +#else + atomic_byte_func_with_unit32_cas(address, value, MinFunc()); +#endif +} + +__device__ __forceinline__ void atomic_mul(float* address, float value) { + atomic_binary_func(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(float* address, float value) { + atomic_binary_func(address, value, MaxFunc()); +} +__device__ __forceinline__ void atomic_min(float* address, float value) { + atomic_binary_func(address, value, MinFunc()); +} + +__device__ __forceinline__ void atomic_mul(double* address, double value) { + atomic_binary_func(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(double* address, double value) { + atomic_binary_func(address, value, MaxFunc()); +} +__device__ __forceinline__ void atomic_min(double* address, double value) { + atomic_binary_func(address, value, MinFunc()); +} + + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 3fc4ed355a12b..77e682e05a2a4 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1046,7 +1046,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Size); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Transpose); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterElements); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterElements); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Softmax); @@ -1254,6 +1254,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, LessOrEqual); class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); // Opset 17 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); @@ -1269,6 +1270,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); @@ -1937,7 +1939,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2138,6 +2140,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 17 BuildKernelCreateInfo, @@ -2159,6 +2162,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu index 10c8625b39ef8..b710e8a1b48c2 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu @@ -95,7 +95,37 @@ struct OffsetCalculatorFor2D { template struct FuncAssignment { - __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { start_addr[index] = value; } + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { + start_addr[index] = value; + } +}; + +template +struct FuncAdd { + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { + atomic_add(start_addr + index, value); + } +}; + +template +struct FuncMul { + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { + atomic_mul(start_addr + index, value); + } +}; + +template +struct FuncMax { + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { + atomic_max(start_addr + index, value); + } +}; + +template +struct FuncMin { + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { + atomic_min(start_addr + index, value); + } }; template @@ -238,8 +268,24 @@ Status ScatterElementsImplInternal(cudaStream_t stream, const T* input_data, con template Status ScatterElementsImpl(cudaStream_t stream, const T* input_data, const TIndex* indices_data, const T* 
updates_data, T* output_data, const GatherScatterElementsArgs& args) { - return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, - FuncAssignment()); + if (args.operation == GatherScatterElementsArgs::Operation::NONE) { + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncAssignment()); + } else if (args.operation == GatherScatterElementsArgs::Operation::ADD) { + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncAdd()); + } else if (args.operation == GatherScatterElementsArgs::Operation::MUL) { + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncMul()); + } else if (args.operation == GatherScatterElementsArgs::Operation::MAX) { + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncMax()); + } else if (args.operation == GatherScatterElementsArgs::Operation::MIN) { + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncMin()); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported reduction operator."); + } } #define GATHER_SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL(T, TIndex) \ diff --git a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h index 631d0bf049c6f..7b1c88f1fc1cb 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h @@ -10,6 +10,14 @@ namespace onnxruntime { namespace cuda { struct GatherScatterElementsArgs { + enum class Operation { + NONE, + ADD, + MUL, + MAX, + MIN + }; + int64_t rank; int64_t axis; int64_t input_size; @@ -19,6 +27,9 @@ struct GatherScatterElementsArgs { TArray indices_fdms; TArray indices_strides; int64_t indices_size; + // operation used to combine values associated the same + // memory location in the output tensor. 
+ Operation operation; }; template diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc b/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc index e4d145154971e..42a9f50001103 100755 --- a/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc +++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc @@ -27,7 +27,23 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 11, 12, kCudaExe DataTypeImpl::GetTensorType()}), ScatterElements); -ONNX_OPERATOR_KERNEL_EX(ScatterElements, kOnnxDomain, 13, kCudaExecutionProvider, +ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 13, 15, kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("Tind", + std::vector{DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + ScatterElements); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 16, 17, kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("Tind", + std::vector{DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + ScatterElements); + +ONNX_OPERATOR_KERNEL_EX(ScatterElements, kOnnxDomain, 18, kCudaExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) .TypeConstraint("Tind", std::vector{DataTypeImpl::GetTensorType(), @@ -106,6 +122,20 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const { TensorShapeVector indices_shape_vec = indices_shape.AsShapeVector(); CoalesceDimensions(input_shape_vec, indices_shape_vec, nullptr, axis, args); + if (reduction_ == "none") { + args.operation = GatherScatterElementsArgs::Operation::NONE; + } else if (reduction_ == "add") { + args.operation = GatherScatterElementsArgs::Operation::ADD; + } else if (reduction_ == "mul") { + args.operation = GatherScatterElementsArgs::Operation::MUL; + } else if (reduction_ == "min") { + args.operation = GatherScatterElementsArgs::Operation::MIN; + } else if (reduction_ == "max") { + args.operation = GatherScatterElementsArgs::Operation::MAX; + } else { + ORT_THROW("Unsupported reduction type"); + } + // Use element size instead of concrete types so we can specialize less template functions to reduce binary size. int dtype = GetElementType(input_tensor->DataType()->Size()); if (dtype == ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) { diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements.h b/onnxruntime/core/providers/cuda/tensor/scatter_elements.h index 3e9e0ce041845..3884b716da308 100755 --- a/onnxruntime/core/providers/cuda/tensor/scatter_elements.h +++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements.h @@ -14,6 +14,12 @@ class ScatterElements final : public CudaKernel { ScatterElements(const OpKernelInfo& info) : CudaKernel(info) { ORT_ENFORCE(info.GetAttr("axis", &axis_).IsOK(), "Missing/Invalid 'axis' attribute value"); + reduction_ = info.GetAttrOrDefault("reduction", "none"); + + ORT_ENFORCE(reduction_ == "none" || reduction_ == "add" || + reduction_ == "mul" || reduction_ == "max" || + reduction_ == "min", + "Invalid reduction attribute value of ", reduction_); } ~ScatterElements() = default; Status ComputeInternal(OpKernelContext* context) const override; @@ -23,6 +29,10 @@ class ScatterElements final : public CudaKernel { struct ComputeImpl; int64_t axis_; + // "reduction" attribute has been defined since opset 13 but + // we never implemented it. 
Let's try to support them starting + // with opset 18. + std::string reduction_; }; } // namespace cuda diff --git a/onnxruntime/core/providers/rocm/atomic/common.cuh b/onnxruntime/core/providers/rocm/atomic/common.cuh index 4e235702028c6..b5d01b91c70ed 100644 --- a/onnxruntime/core/providers/rocm/atomic/common.cuh +++ b/onnxruntime/core/providers/rocm/atomic/common.cuh @@ -59,5 +59,304 @@ __device__ __forceinline__ void AtomicAdd(T *start_addr, size_t index, const siz atomic_add(start_addr + index, value); } +// Disable default template instantiation. +// For every type T, we need to define a specialization +// to select the right type for calling atomicCAS. +template +class AtomicCasType; + +template<> +class AtomicCasType { + public: + using type = unsigned short int; + static const unsigned int mask = 0xffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned short int; + static const unsigned int mask = 0xffffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned int; + static const unsigned int mask = 0xffffffffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned long long int; + static const unsigned int mask = 0xffffffffu; +}; + +template<> +class AtomicCasType { + public: + using type = int; + static const unsigned int mask = 0xffffffffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned long long int; + static const unsigned int mask = 0xffffffffu; +}; + +// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh. +// +// This function compute 8-bit atomic binary operation using 32-bit atomicCAS. +// It accumulate `val` into the `address` using the `func`. +// The accumulation is atomic (i.e., thread-safe). +// +// E.g., Assume ValueType is +// int8_t +// and BinaryFunc is +// struct AddFunc { +// __device__ __forceinline__ int8_t operator()(int8_t a, int8_t b) const { +// return a + b; +// } +// This function becomes atomic_add for int8_t. +template +__device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* address, ValueType val, BinaryFunc func) { + // Assert to ensure the following bit-wise manipulation is correct. + static_assert(sizeof(ValueType) == 1 | sizeof(ValueType) == 2 | sizeof(ValueType) == 4, + "ValueType must be 1-byte, 2-byte or 4-byte large."); + // Number of bytes to the lower 4-byte aligned address. + // If the current address is b1010"10", then offset = b10 = 2, + // which means the current address is 2 bytes away from + // the lower 4-byte aligned address b1010"00". + size_t offset = (size_t)address & 3; + // Find an new 4-byte aligned address `address_as_ui` lower than + // or equal to `address`. Lower than `address` so that the actual + // int8_t byte is in the 4-byte word that we load. + // + // This address has the following properties: + // 1. It is 4-byte aligned. + // 2. It is lower than or equal to `address`. + // 3. De-referencing this address may return + // a uint32_t value that contains the same int8_t + // value indicated by `address`. + // + // E.g., + // address = b101010 + // offset = b101010 & b000011 = b10 = 2 + // (char*)address - offset => (char*)b101010 - b000010 => b1010"00", + // which is (32-bit aligned). + uint32_t * address_as_ui = (uint32_t*)((char*)address - offset); + uint32_t old = *address_as_ui; + // E.g., offset = 2. + // address_as_ui is an address 2 bytes lower than `address`. + // + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... 
+ // ^ ^ ^ + // | | | + // | address <--- offset * 8 (bit)-----> address_as_ui + // | ^ + // | | + // ------------------------- *address_as_ui ----------------------- + // + // This visualization shows + // 1. the 32-bit word at address_as_ui. + // 2. the gap between address_as_ui and address. + // 3. *address_as_ui contains the int8_t value at `address`. + uint32_t shift = offset * 8; + uint32_t old_byte; + uint32_t newval; + uint32_t assumed; + do { + assumed = old; + // Select 8-bit value from 32-bit word. Assume offset = 2 (byte), so + // we want to select the 3rd byte (byte 2 below) from the word. + // + // Journey of a 32-bit value: + // + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... + // + // | + // | old >> offset * 8, where offset = 2. + // | Effectively, push lower two bytes + // | out of the word. + // V + // + // 00000000 | 00000000 | ..... byte 3 ..... | ..... byte 2 ..... + // + // | apply bit-wise AND, + // | & 0xff (i.e., & b11111111), + // | so that we only keep + // | the byte of interest. + // | Otherwise, overflow may + // | happen when casting this + // | 32-bit value to int8_t. + // V + // + // 00000000 | 00000000 | 00000000 | ..... byte 2 ..... + old_byte = (old >> shift) & AtomicCasType::mask; + // Compute new int8_t value and store it to newrawvalue. + // Journey of a 32-bit value (cont'd): + // + // newrawvalue + // ... new byte 2 ... + auto newrawvalue = func(val, reinterpret_cast(old_byte)); + // Put the new int8_t value back to 32-bit word. + // Also ensure that bits not occupied by the int8_t value are 0s. + // + // Journey of a 32-bit value (cont'd): + // + // reinterpret_cast(newrawvalue) + // random values | random values | random values | ... new byte 2 ... + // + // reinterpret_cast(newrawvalue) & AtomicCasType::mask + // 00000000 | 00000000 | 00000000 | ... new byte 2 ... + newval = reinterpret_cast(newrawvalue) & AtomicCasType::mask; + // Journey of a 32-bit value (cont'd): + // + // old + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... + // + // 0x000000ff + // 00000000 | 00000000 | 00000000 | 11111111 + // + // 0x000000ff << shift + // 00000000 | 11111111 | 00000000 | 00000000 + // + // ~(0x000000ff << shift) + // 11111111 | 00000000 | 11111111 | 11111111 + // + // old & ~(0x000000ff << shift) + // ..... byte 3 ..... | 00000000 | ..... byte 1 ..... | ..... byte 0 ..... + // + // newval << shift + // 00000000 | ... new byte 2 ... | 00000000 | 00000000 + // + // (old & ~(0x000000ff << shift)) | (newval << shift) + // ..... byte 3 ..... | ... new byte 2 ... | ..... byte 1 ..... | ..... byte 0 ..... + newval = (old & ~(AtomicCasType::mask << shift)) | (newval << shift); + old = atomicCAS(address_as_ui, assumed, newval); + } while (assumed != old); +} + +// It accumulates `val` into the `address` using the `func`. +// This function is thread-safe (i.e., atomic). +template +__device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) { + ValueType observed = *address, assumed, new_value; + using CasType = typename AtomicCasType::type; + static_assert(sizeof(ValueType) == sizeof(CasType), + "ValueType and CasType must have the same size for calling atomicCAS."); + auto address_as_cas_type = reinterpret_cast(address); + do { + // Record the value used to compute new value. + assumed = observed; + + // Compute expected new value. 
+ new_value = func(observed, val); + + // Cast to aribitrary 2-byte type to desired integer type supported by atomicCAS. + // 4 + // 8 + auto observed_as_cas_type = *reinterpret_cast(&observed); + auto new_value_as_cas_type = *reinterpret_cast(&new_value); + + // Call atomicCAS as if the 2-byte type variables are all unsigned short int. + // 4 unsigned int (or int) + // 8 unsigned long long int + auto cas_observed_as_cas_type = atomicCAS(address_as_cas_type, observed_as_cas_type, new_value_as_cas_type); + + // Cast the freshly observed value in memory back to the TwoByteType. + observed = *reinterpret_cast(&cas_observed_as_cas_type); + + // Two cases: + // 1. compare-and-swap success + // a. `address` holds `new_value` + // b. `observed` becomes the new value after the assignment. + // Thus, the following `observed != new_value` is false, + // and the loop terminates. + // 2. compare-and-swap fails + // a. `address` holds a value different from `observed`, thus, + // the `new_value` is stale. + // b. `observed` becomes the fresh value observed in `address`. + // Thus, the following (observed != new_value) is true, + // and the loop continues. In the next iteration, the + // `new_value` is computed again using the fresh `observed`. + } while (observed != assumed); +} + +struct AddFunc { + template + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +struct MulFunc { + template + __device__ __forceinline__ T operator()(T a, T b) const { + return a * b; + } +}; + +struct MaxFunc { + template + __device__ __forceinline__ T operator()(T a, T b) const { + return b > a ? b : a; + } +}; + +struct MinFunc { + template + __device__ __forceinline__ T operator()(T a, T b) const { + return b < a ? b : a; + } +}; + +__device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) { + atomic_byte_func_with_unit32_cas(address, value, AddFunc()); +} +__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) { + atomic_byte_func_with_unit32_cas(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) { + atomic_byte_func_with_unit32_cas(address, value, MaxFunc()); +} +__device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) { + atomic_byte_func_with_unit32_cas(address, value, MinFunc()); +} + +__device__ __forceinline__ void atomic_mul(half* address, half value) { + atomic_byte_func_with_unit32_cas(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(half* address, half value) { + atomic_byte_func_with_unit32_cas(address, value, MaxFunc()); +} +__device__ __forceinline__ void atomic_min(half* address, half value) { + atomic_byte_func_with_unit32_cas(address, value, MinFunc()); +} + +__device__ __forceinline__ void atomic_mul(float* address, float value) { + atomic_binary_func(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(float* address, float value) { + atomic_binary_func(address, value, MaxFunc()); +} +__device__ __forceinline__ void atomic_min(float* address, float value) { + atomic_binary_func(address, value, MinFunc()); +} + +__device__ __forceinline__ void atomic_mul(double* address, double value) { + atomic_binary_func(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(double* address, double value) { + atomic_binary_func(address, value, MaxFunc()); +} +__device__ __forceinline__ void atomic_min(double* address, double value) { + atomic_binary_func(address, value, MinFunc()); +} + + } // namespace rocm } 
// namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index fff3d14b763d5..ee3578326ac6d 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -1069,7 +1069,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Size); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Transpose); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterElements); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 15, ScatterElements); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Softmax); @@ -1290,6 +1290,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); // Opset 17 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); @@ -1302,7 +1303,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad); - +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split); // Opset 19 @@ -2004,7 +2005,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2225,6 +2226,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 17 BuildKernelCreateInfo, @@ -2237,7 +2239,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - + BuildKernelCreateInfo, BuildKernelCreateInfo, // Opset 19 diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc index 9b44bf400c05e..30e27bb15fa57 100644 --- a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc @@ -302,5 +302,137 @@ TEST(Scatter, 
BoolInputWithAxis) { scatter_bool_with_axis_tests("ScatterElements", 11); } +TEST(ScatterElements, AddReduction) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "add"); + + test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}); + test.AddInput("indices", {4, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {4, 3}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f}); + test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, -7.f + (1.f + 2.f + 3.f + 4.f), -3.f + (1.f + 2.f + 3.f + 4.f), -6.f + (1.f + 2.f + 3.f + 4.f)}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(ScatterElements, AddReductionAxis1) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 1); + test.AddAttribute("reduction", "add"); + + // update's slice shape is {2, 1} + test.AddInput("data", {2, 3}, {9.f, 4.f, 1.f, 7.f, 3.f, 6.f}); + test.AddInput("indices", {2, 4}, {1, 1, 1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 4}, {2.f, 5.f, 3.f, 6.f, 7.f, 9.f, 8.f, 10.f}); + test.AddOutput("y", {2, 3}, {9.f, 4.f + (2.f + 5.f + 3.f + 6.f), 1.f, 7.f, 3.f + (7.f + 9.f + 8.f + 10.f), 6.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(ScatterElements, MulReduction) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "mul"); + + test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, {7.f, 3.f, 6.f, 7.f, 3.f, 6.f}); + test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, -7.f * 7.f * 7.f, -3.f * 3.f * 3.f, -6.f * 6.f * 6.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(ScatterElements, MulReductionAxis1) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 1); + test.AddAttribute("reduction", "mul"); + + // update's slice shape is {2, 1} + test.AddInput("data", {2, 3}, {9.f, 4.f, 1.f, 7.f, 3.f, 6.f}); + test.AddInput("indices", {2, 4}, {1, 1, 1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 4}, {2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}); + test.AddOutput("y", {2, 3}, {9.f, 4.f * (2.f * 3.f * 4.f * 5.f), 1.f, 7.f, 3.f * (6.f * 7.f * 8.f * 9.f), 6.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(ScatterElements, MaxReduction_MLFloat16) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "max"); + + test.AddInput("data", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, -7.f, -3.f, -6.f})); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, ToFloat16({1.f, 5.f, 3.f, 7.f, 3.f, 6.f})); + test.AddOutput("y", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 7.f, 5.f, 6.f})); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(ScatterElements, MaxReduction_Float) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "max"); + + test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f}); + test.AddOutput("y", {2, 3}, 
{-9.f, -4.f, -1.f, 7.f, 5.f, 6.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(ScatterElements, MaxReduction_Double) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "max"); + + test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f}); + test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, 7.f, 5.f, 6.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(ScatterElements, MinReduction_MLFloat16) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "min"); + + test.AddInput("data", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 8.f, -3.f, 5.f})); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, ToFloat16({1.f, 5.f, 3.f, 7.f, 3.f, 6.f})); + test.AddOutput("y", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 1.f, -3.f, 3.f})); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(ScatterElements, MinReduction_Float) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "min"); + + test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, 8.f, -3.f, 5.f}); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f}); + test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, 1.f, -3.f, 3.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(ScatterElements, MinReduction_Double) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "min"); + + test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, 8.f, -3.f, 5.f}); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f}); + test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, 1.f, -3.f, 3.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + } // namespace test } // namespace onnxruntime From b84cb247e3ef06639925120a84838ab970ef6843 Mon Sep 17 00:00:00 2001 From: liqun Fu Date: Tue, 30 Jan 2024 10:25:14 -0800 Subject: [PATCH 05/11] io_binding to handle optional input of sequence type_proto (#19273) --- onnxruntime/python/onnxruntime_pybind_mlvalue.cc | 7 ++++++- .../test/python/onnxruntime_test_python.py | 8 ++++++++ onnxruntime/test/testdata/identity_opt.onnx | Bin 0 -> 133 bytes 3 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/test/testdata/identity_opt.onnx diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index f470e9f6b6ed1..0bbcee12ea5cf 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -659,7 +659,12 @@ static bool CheckIfInputIsSequenceType(const std::string& name_input, if (!temp) { throw std::runtime_error("Corresponding type_proto is null"); } else { - type_proto = *temp; + if (temp->has_optional_type()) { + const ::onnx::TypeProto_Optional& optional_type_proto = temp->optional_type(); + type_proto = optional_type_proto.elem_type(); + } 
else { + type_proto = *temp; + } } return type_proto.has_sequence_type(); diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index e210917e7ad9a..68e441c87860e 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -650,6 +650,14 @@ def do_test_get_and_set_tuning_results(ep): if "ROCMExecutionProvider" in onnxrt.get_available_providers(): do_test_get_and_set_tuning_results("ROCMExecutionProvider") + def test_run_model_with_optional_sequence_input(self): + sess = onnxrt.InferenceSession(get_name("identity_opt.onnx")) + x = [np.array([1, 2, 3, 4, 5]).astype(np.float32)] + input_name = sess.get_inputs()[0].name + output_name = sess.get_outputs()[0].name + res = sess.run([output_name], {input_name: x}) + np.testing.assert_allclose(res[0], x) + def test_run_model(self): sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=available_providers) x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) diff --git a/onnxruntime/test/testdata/identity_opt.onnx b/onnxruntime/test/testdata/identity_opt.onnx new file mode 100644 index 0000000000000000000000000000000000000000..24c05f7b7227f6f91601bc0490c5a24b493774ca GIT binary patch literal 133 zcmd Date: Tue, 30 Jan 2024 10:53:10 -0800 Subject: [PATCH 06/11] Windows - Only set thread affinity on Server with auto affinity (#19318) ### Description Only set thread affinity on Server with auto affinity. Auto affinity = when the API user does not specify thread settings or affinity themselves. ### Motivation and Context On client it is best to let the OS scheduler handle it. On big (P-Core) / little (E-Core) CPU designs affinity overrides win32 Quality of Service (QoS) and has high power usage. Specifically on background workloads whose process is tagged QoS Utility (Background), this affinity setting overrides the OS scheduler that only wants to schedule on the E-Cores. Thus P-Cores waking up use more energy than intended on client and users get less battery life. Foreground AI workloads would be tagged QoS High and would run the ORT threads on all cores. --- onnxruntime/core/util/thread_utils.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/onnxruntime/core/util/thread_utils.cc b/onnxruntime/core/util/thread_utils.cc index 48f58add8237b..a5a165e150cf1 100644 --- a/onnxruntime/core/util/thread_utils.cc +++ b/onnxruntime/core/util/thread_utils.cc @@ -7,6 +7,7 @@ #ifdef _WIN32 #include +#include #endif #include #include "core/session/ort_apis.h" @@ -98,7 +99,16 @@ CreateThreadPoolHelper(Env* env, OrtThreadPoolParams options) { } options.thread_pool_size = static_cast(default_affinities.size()); if (options.auto_set_affinity) { +#ifdef _WIN32 + // Only set thread affinity on Server with auto affinity. + // On client best to let OS scheduler handle. + // On big (P-Core) / little (E-Core) CPU designs affinity overrides QoS and has high power usage + if (IsWindowsServer()) { + to.affinities = std::move(default_affinities); + } +#else to.affinities = std::move(default_affinities); +#endif } } if (options.thread_pool_size <= 1) { From febec1c5860c9b39e7ddd7167ea3cfa28ec2d2db Mon Sep 17 00:00:00 2001 From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Date: Tue, 30 Jan 2024 11:59:15 -0800 Subject: [PATCH 07/11] Update Whisper export with beam search (#19322) ### Description This PR updates the Whisper export with beam search by adding the following.
- Fixes a bug when running `DecoderMaskedMultiHeadAttention` in the Whisper with beam search model - Sets the default PyTorch attention implementation to `eager` to allow existing attention fusions to continue working - Re-uses the cache directory when loading the PyTorch model to reduce the disk space used - Adds `--disable_auto_mixed_precision` to the example FP16 export command ### Motivation and Context - [This PR](https://github.com/microsoft/onnxruntime/pull/19112) added the `is_unidirectional` parameter to `CheckInputs`, but it was not provided when checking the inputs in `DecoderMaskedMultiHeadAttention`. - [This PR](https://github.com/microsoft/onnxruntime/pull/19200) explains the reasoning behind why `eager` is used to load the `WhisperAttention` class. - By re-using the cache directory for loading the PyTorch model, only one copy of the PyTorch model is saved on disk instead of two copies. - By providing this flag, there will be fewer Cast nodes in the Whisper with beam search model to switch between FP16 and FP32 precision. --- .../bert/decoder_masked_multihead_attention.cc | 2 ++ .../tools/transformers/models/whisper/README.md | 4 ++-- .../models/whisper/convert_to_onnx.py | 2 +- .../transformers/models/whisper/whisper_helper.py | 15 +++++++++++++-- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc index a9b60da0c96ca..66c0aceaed1e7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc @@ -74,6 +74,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* parameters.kv_data_in_flight = ParseEnvironmentVariableWithDefault( attention::kDecoderMaskedAttentionLoadKVDataInFlight, false); + bool is_unidirectional = false; bool is_dmmha_packing = (key == nullptr && value == nullptr); ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, @@ -88,6 +89,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* num_heads_, mask_filter_value_, scale_, + is_unidirectional, past_present_share_buffer_, is_dmmha_packing, // dmmha_packing device_prop.maxThreadsPerBlock)); diff --git a/onnxruntime/python/tools/transformers/models/whisper/README.md b/onnxruntime/python/tools/transformers/models/whisper/README.md index 8ff5c8a6e1de0..02100266200f8 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/README.md +++ b/onnxruntime/python/tools/transformers/models/whisper/README.md @@ -60,10 +60,10 @@ $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/w Export + Optimize for FP16 and GPU ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx
--precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision ``` Export + Quantize for INT8 diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index 50637b772c233..e15a12c07bed7 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -478,7 +478,7 @@ def main(argv=None): # Wrap parity check in try-except to allow export to continue in case this produces an error try: with torch.no_grad(): - max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, ort_session, device) + max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device) if max_diff > 1e-4: logger.warning("PyTorch and ONNX Runtime results are NOT close") else: diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 8c22cd5e745b3..a4bef1f06b4fe 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -12,7 +12,9 @@ import numpy as np import torch +from packaging import version from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor +from transformers import __version__ as transformers_version from whisper_decoder import WhisperDecoder, WhisperDecoderHelper, WhisperDecoderInit from whisper_encoder import WhisperEncoder, WhisperEncoderHelper from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper @@ -88,7 +90,10 @@ def load_model( Returns: Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion. """ - model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir) + extra_kwargs = {} + if version.parse(transformers_version) >= version.parse("4.36.0"): + extra_kwargs["attn_implementation"] = "eager" + model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir, **extra_kwargs) if state_dict_path: model.load_state_dict(torch.load(state_dict_path), strict=False) @@ -262,11 +267,17 @@ def optimize_onnx( @staticmethod def verify_onnx( model_name_or_path: str, + cache_dir: str, ort_session: InferenceSession, device: torch.device, ): """Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good.""" - pt_model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path).to(device) + extra_kwargs = {} + if version.parse(transformers_version) >= version.parse("4.36.0"): + extra_kwargs["attn_implementation"] = "eager" + pt_model = WhisperForConditionalGeneration.from_pretrained( + model_name_or_path, cache_dir=cache_dir, **extra_kwargs + ).to(device) processor = WhisperProcessor.from_pretrained(model_name_or_path) config = WhisperConfig.from_pretrained(model_name_or_path) From 04afe77305c06181e31b8934df9ee8d3c19af2a7 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Tue, 30 Jan 2024 12:40:30 -0800 Subject: [PATCH 08/11] Update ThirdPartyNotices.txt: Add Intel neural-speed (#19332) Add Intel neural-speed to ThirdPartyNotices.txt because it will be shipped in the default build in most of our packages. 
--- ThirdPartyNotices.txt | 207 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) diff --git a/ThirdPartyNotices.txt b/ThirdPartyNotices.txt index 700206180decd..30894903ec8d2 100644 --- a/ThirdPartyNotices.txt +++ b/ThirdPartyNotices.txt @@ -6299,3 +6299,210 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +_____ + +neural-speed + +https://github.com/intel/neural-speed + + Apache License + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + ============================================================================ + + Copyright 2016-2019 Intel Corporation + Copyright 2018 YANDEX LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + This distribution includes third party software ("third party programs"). + This third party software, even if included with the distribution of + the Intel software, may be governed by separate license terms, including + without limitation, third party license terms, other Intel software license + terms, and open source software license terms. These separate license terms + govern your use of the third party programs as set forth in the + "THIRD-PARTY-PROGRAMS" file. From c379a89bcb26bc4838efe70fae04760106e8d081 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Tue, 30 Jan 2024 14:29:12 -0800 Subject: [PATCH 09/11] [MLAS AArch64] SQNBitGemm optimization (#19272) 1. Add support for packing 4-bit values 32 at a time for CompInt8. 32 4-bit values can fit into a single 128-bit NEON register. For CompInt8, this enables a more efficient path for block sizes greater than or equal to 32. CompFp32 seems to do better with handling 16 elements at a time, so this 32-value packing is not used there. Pack differently based on compute type. Adjust APIs to handle this. 2. Introduce template argument for whether to handle zero-point. This results in less code for the no zero-point (symmetric) case. However, there is a binary size increase due to the additional template instantiations. --- .../cpu/quantization/matmul_nbits.cc | 130 +++-- onnxruntime/core/mlas/inc/mlas_qnbit.h | 16 +- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 146 +++--- onnxruntime/core/mlas/lib/sqnbitgemm.h | 27 ++ .../core/mlas/lib/sqnbitgemm_kernel_neon.cpp | 445 ++++++++++++++---- .../test/mlas/bench/bench_sqnbitgemm.cpp | 19 +- .../test/mlas/unittest/test_sqnbitgemm.cpp | 7 +- 7 files changed, 558 insertions(+), 232 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 72948c74d7877..166f5c8f52f54 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -9,6 +9,7 @@ #include "core/mlas/inc/mlas_q4.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" + #ifdef ORT_NEURAL_SPEED #include "contrib_ops/cpu/quantization/neural_speed_gemm.h" #endif @@ -16,6 +17,39 @@ namespace onnxruntime { namespace contrib { +namespace { +int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level_attr) { + const auto accuracy_level = std::clamp(accuracy_level_attr, + static_cast(CompMostAccurate), + static_cast(CompLeastAccurate)); + +#if defined(ORT_NEURAL_SPEED) + + ORT_UNUSED_PARAMETER(nbits); + ORT_UNUSED_PARAMETER(block_size); + + // Neural Speed APIs already expect a minimum accuracy level so just use the given value. 
+ return accuracy_level; + +#else // defined(ORT_NEURAL_SPEED) + + // Find a supported accuracy level that is not less accurate than the one given. + // CompMostAccurate is always supported with the fallback implementation. + // Note: A higher numeric accuracy level value means lower accuracy, so the comparison order is reversed. + int64_t effective_accuracy_level = accuracy_level; + for (; effective_accuracy_level > CompMostAccurate; --effective_accuracy_level) { + const auto compute_type = static_cast(effective_accuracy_level); + if (MlasIsSQNBitGemmAvailable(nbits, block_size, compute_type)) { + break; + } + } + + return effective_accuracy_level; + +#endif // defined(ORT_NEURAL_SPEED) +} +} // namespace + class MatMulNBits final : public OpKernel { public: MatMulNBits(const OpKernelInfo& info) @@ -24,7 +58,7 @@ class MatMulNBits final : public OpKernel { N_{narrow(info.GetAttr("N"))}, block_size_{narrow(info.GetAttr("block_size"))}, nbits_{narrow(info.GetAttr("bits"))}, - accuracy_level_{info.GetAttr("accuracy_level")} { + accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr("accuracy_level"))} { ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); #ifdef ORT_NEURAL_SPEED @@ -58,17 +92,22 @@ class MatMulNBits final : public OpKernel { const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_; size_t packed_b_size_{0}; -#ifdef ORT_NEURAL_SPEED + +#if defined(ORT_NEURAL_SPEED) + bool is_asym_{false}; bool all_constant_{false}; -#endif + +#endif // defined(ORT_NEURAL_SPEED) }; Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; -#ifdef ORT_NEURAL_SPEED + +#if defined(ORT_NEURAL_SPEED) + if (!all_constant_) { return Status::OK(); } @@ -116,11 +155,17 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat #else // defined(ORT_NEURAL_SPEED) if (input_idx == 1) { - packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_); - if (packed_b_size_ == 0) return Status::OK(); + const auto compute_type = static_cast(accuracy_level_); + if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { + return Status::OK(); + } + packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type); + if (packed_b_size_ == 0) { + return Status::OK(); + } auto qptr = tensor.DataRaw(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, qptr, packed_b_.get()); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get()); if (prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); prepacked_weights->buffer_sizes_.push_back(packed_b_size_); @@ -136,7 +181,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; -#ifdef ORT_NEURAL_SPEED + +#if defined(ORT_NEURAL_SPEED) + // Pack three tensors into one buffer if (input_idx == 1) { used_shared_buffers = true; @@ -159,6 +206,7 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep } #endif // defined(ORT_NEURAL_SPEED) + return Status::OK(); } @@ -167,8 +215,10 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const Tensor* a 
= ctx->Input(0); const auto* a_data = a->Data(); -#ifdef ORT_NEURAL_SPEED - if (packed_b_.get()) { + +#if defined(ORT_NEURAL_SPEED) + + if (packed_b_) { TensorShape b_shape({static_cast(N_), static_cast(K_)}); MatMulComputeHelper helper; @@ -234,37 +284,43 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), [](size_t offset) { return offset == 0; }); - if (has_single_b_matrix && packed_b_) { - for (int64_t accuracy_level = accuracy_level_; - accuracy_level >= static_cast(CompMostAccurate); - --accuracy_level) { - const auto compute_type = static_cast(accuracy_level); - if (MlasIsSQNBitGemmAvailable(M, N, K, nbits_, block_size_, compute_type)) { - IAllocatorUniquePtr workspace{}; - if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count, - nbits_, block_size_, compute_type); - workspace_size > 0) { - AllocatorPtr allocator; - ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); - workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); - } + if (has_single_b_matrix) { + const auto compute_type = static_cast(accuracy_level_); + + if (MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { + IAllocatorUniquePtr workspace{}; + if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count, + nbits_, block_size_, compute_type); + workspace_size > 0) { + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); + workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); + } - InlinedVector data(batch_count); - for (size_t i = 0; i < batch_count; ++i) { - data[i].A = a_data + helper.LeftOffsets()[i]; - data[i].lda = lda; - data[i].QuantBData = packed_b_.get(); - data[i].QuantBScale = scales_data; - data[i].QuantBZeroPoint = zero_points_data; - data[i].C = y_data + helper.OutputOffsets()[i]; - data[i].ldc = N; + const void* b_data = [&]() -> const void* { + if (packed_b_) { + return packed_b_.get(); } - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), - thread_pool); - - return Status::OK(); + const Tensor* b = ctx->Input(1); + return b->DataRaw(); + }(); + + InlinedVector data(batch_count); + for (size_t i = 0; i < batch_count; ++i) { + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].QuantBData = b_data; + data[i].QuantBScale = scales_data; + data[i].QuantBZeroPoint = zero_points_data; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; } + + MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), + thread_pool); + + return Status::OK(); } } diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 047011e70bd4d..32e9cc98106d5 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -37,9 +37,7 @@ typedef enum { CompMostAccurate = CompUndef, CompLeastAccurate = CompInt8, -} MLAS_SQNBIT_COMPUTE_TYPE; - -using MLAS_SQNBIT_GEMM_COMPUTE_TYPE = MLAS_SQNBIT_COMPUTE_TYPE; // TODO consolidate these +} MLAS_SQNBIT_GEMM_COMPUTE_TYPE; /** * @brief Data parameters for float/n-bit quantized int GEMM routine. @@ -102,18 +100,12 @@ MlasSQNBitGemmBatch( /** * @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform. 
* - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ bool MLASCALL MlasIsSQNBitGemmAvailable( - size_t M, - size_t N, - size_t K, size_t BlkBitWidth, size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType @@ -153,13 +145,15 @@ MlasSQNBitGemmBatchWorkspaceSize( * @param[in] K column size of matrix A and row size of matrix B * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ size_t MLASCALL MlasSQNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkBitWidth, - size_t BlkLen + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** @@ -169,6 +163,7 @@ MlasSQNBitGemmPackQuantBDataSize( * @param[in] K column size of matrix A and row size of matrix B * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) * @param[in] QuantBData quantized B data * @param[out] PackedQuantBData packed quantized B data * @param[in] ThreadPool optional thread pool to use @@ -179,6 +174,7 @@ MlasSQNBitGemmPackQuantBData( size_t K, size_t BlkBitWidth, size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, void* PackedQuantBData, MLAS_THREADPOOL* ThreadPool = nullptr diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 0d8a5692359a6..38c31c8841761 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -39,23 +39,17 @@ enum SQNBitGemmVariant { SQNBitGemmVariant GetSQNBitGemmVariant( - size_t M, - size_t N, - size_t K, size_t BlkBitWidth, size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - MLAS_UNREFERENCED_PARAMETER(N); - MLAS_UNREFERENCED_PARAMETER(K); - if (BlkBitWidth == 4 && (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { if (ComputeType == CompFp32 || ComputeType == CompUndef) { // treat CompUndef (undefined) as CompFp32 return SQNBitGemmVariant_BitWidth4_CompFp32; - } else if (ComputeType == CompInt8 && M == 1) { + } else if (ComputeType == CompInt8) { return SQNBitGemmVariant_BitWidth4_CompInt8; } } @@ -67,9 +61,6 @@ GetSQNBitGemmVariant( bool MLASCALL MlasIsSQNBitGemmAvailable( - size_t M, - size_t N, - size_t K, size_t BlkBitWidth, size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType @@ -80,7 +71,7 @@ MlasIsSQNBitGemmAvailable( return false; } - const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); switch (Variant) { case SQNBitGemmVariant_BitWidth4_CompFp32: { @@ -164,7 +155,7 @@ MlasSQNBitGemmBatchWorkspaceSize( MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen); if (PerGemmWorkspaceStride == 0) { @@ 
-178,91 +169,24 @@ MlasSQNBitGemmBatchWorkspaceSize( return WorkspaceSize + Alignment - 1; } -namespace -{ - -void -SQ4BitGemmPackQuantBData( - size_t N, - size_t K, - size_t BlkLen, - const std::byte* QuantBDataBegin, - std::byte* PackedQuantBDataBegin, - MLAS_THREADPOOL* ThreadPool -) -{ - constexpr size_t BlkBitWidth = 4; - - assert(BlkLen % 16 == 0); - - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t Iterations = N * BlockCountK; // one iteration per block - - MlasTrySimpleParallel( - ThreadPool, Iterations, - [&](ptrdiff_t tid) { - const size_t n = tid / BlockCountK; - const size_t k_blk = tid % BlockCountK; - - const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize; - const std::byte* QuantBData = QuantBDataBegin + data_offset; - std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset; - - // - // Pack 16 4-bit values (8 bytes) at a time like this: - // - // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | - // => - // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | - // - for (size_t kk = 0; kk < BlkLen; kk += 16) { - for (size_t byte_pair_idx = 0; byte_pair_idx < 4; ++byte_pair_idx) { - const std::byte src0 = QuantBData[byte_pair_idx]; - const std::byte src1 = QuantBData[byte_pair_idx + 4]; - - std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; - std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; - - dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); - dst1 = (src0 >> 4) | ((src1 >> 4) << 4); - } - - QuantBData += 8; - PackedQuantBData += 8; - } - } - ); -} - -} // namespace - size_t MLASCALL MlasSQNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkBitWidth, - size_t BlkLen + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - // Ensure that a general implementation is available on this platform. - // For now, all implementations share the same packed format. - { - // Currently, there are implementations specific to M = 1, so pick a more general M > 1. - constexpr size_t M = 2; - // A CompUndef implementation should be available if any is available. 
- constexpr MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType = CompUndef; - const bool HasGeneralImplementation = - MlasIsSQNBitGemmAvailable(M, N, K, BlkBitWidth, BlkLen, ComputeType); - if (!HasGeneralImplementation) { - return 0; - } + const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + if (Dispatch == nullptr) { + return 0; } - if (BlkBitWidth == 4) { - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - return PackedQuantBDataSize; + if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBDataSize != nullptr) { + return Dispatch->SQ4BitGemmPackQuantBDataSize( + N, K, BlkLen, ComputeType + ); } return 0; @@ -274,20 +198,28 @@ MlasSQNBitGemmPackQuantBData( size_t K, size_t BlkBitWidth, size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, void* PackedQuantBData, MLAS_THREADPOOL* ThreadPool ) { - if (BlkBitWidth == 4) { - SQ4BitGemmPackQuantBData( + const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + if (Dispatch == nullptr) { + return; + } + + if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBData != nullptr) { + Dispatch->SQ4BitGemmPackQuantBData( N, K, BlkLen, + ComputeType, static_cast(QuantBData), static_cast(PackedQuantBData), ThreadPool ); + return; } } @@ -512,7 +444,37 @@ SQ4BitGemm_CompInt8( return; } - assert(false && "not implemented for M > 1"); + // This is a naive M > 1 implementation that repeatedly calls the M=1 kernel. + // TODO Replace it with an optimized implementation. + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const std::byte* a_row = QuantA; + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + for (size_t m = 0; m < RangeCountM; ++m) { + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( + BlkLen, + a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + + c_blk += ldc; + a_row += lda; + } + } } typedef void(InitializeWorkspaceFn)( @@ -594,7 +556,7 @@ MlasSQNBitGemmBatch( MLAS_THREADPOOL* ThreadPool ) { - const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); assert(Variant != SQNBitGemmVariantInvalid); // diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index a66db79dc290a..3992bc3e452a3 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -99,6 +99,33 @@ Q8BlkAlignment() // struct MLAS_SQNBIT_GEMM_DISPATCH { + // + // Quantized B data packing function prototypes. + // + + /** Gets size of packed quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBDataSize(). */ + typedef size_t(SQ4BitGemmPackQuantBDataSize_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + ); + + SQ4BitGemmPackQuantBDataSize_Fn* SQ4BitGemmPackQuantBDataSize = nullptr; + + /** Packs quantized B data containing 4-bit integers. 
See MlasSQNBitGemmPackQuantBData(). */ + typedef void(SQ4BitGemmPackQuantBData_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool + ); + + SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + // // CompFp32 kernel function prototypes. // diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index 69fd427fa574a..c4c54a9be34d8 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -15,14 +15,115 @@ Module Name: --*/ -#include "sqnbitgemm.h" - #include #include #include #include +#include "sqnbitgemm.h" + +// +// Quantized B data packing function implementation. +// + +namespace +{ + +size_t +SQ4BitGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType + + constexpr size_t BlkBitWidth = 4; + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize; +} + +void +SQ4BitGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + constexpr size_t BlkBitWidth = 4; + + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t Iterations = N * BlockCountK; // one iteration per block + + const size_t SubBlkLen = (ComputeType == CompInt8) + ? ((BlkLen == 16) ? 16 : 32) + : 16; + + const size_t SubBlkDataSize = SubBlkLen / 2; + const size_t SubBlkBytePairCount = SubBlkLen / 4; + + // + // For SubBlkLen == 16, pack 16 4-bit values (8 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | + // => + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + // + + // + // For SubBlkLen == 32, pack 32 4-bit values (16 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | + // => + // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + // + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t k_blk = tid % BlockCountK; + + const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize; + const std::byte* QuantBData = QuantBDataBegin + data_offset; + std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset; + + for (size_t kk = 0; kk < BlkLen; kk += SubBlkLen) { + for (size_t byte_pair_idx = 0; byte_pair_idx < SubBlkBytePairCount; ++byte_pair_idx) { + const std::byte src0 = QuantBData[byte_pair_idx]; + const std::byte src1 = QuantBData[byte_pair_idx + SubBlkDataSize / 2]; + + std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; + std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; + + dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); + dst1 = (src0 >> 4) | ((src1 >> 4) << 4); + } + + QuantBData += SubBlkDataSize; + PackedQuantBData += SubBlkDataSize; + } + } + ); +} + +} // namespace + +// +// General helpers. 
+// + namespace { @@ -95,7 +196,16 @@ LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4]) } } -template +} // namespace + +// +// CompFp32 kernel implementation. +// + +namespace +{ + +template MLAS_FORCEINLINE void ComputeDotProducts_BlkBitWidth4_CompFp32( size_t BlkLen, @@ -112,11 +222,11 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( ) { constexpr size_t BlkBitWidth = 4; + constexpr size_t SubBlkLen = 16; static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); - constexpr size_t SubBlkLen = 16; // number of block elements to process in a sub-block iteration - assert(BlkLen % SubBlkLen == 0); + assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0); const uint8x8_t LowMask = vdup_n_u8(0x0F); @@ -137,7 +247,8 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( const std::byte* QuantBData = QuantBDataColPtr; const float* QuantBScale = QuantBScaleColPtr; - size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + // only used if HasZeroPoint == true for (size_t k = 0; k < CountK; k += BlkLen) { const size_t k_blk_len = std::min(CountK - k, BlkLen); @@ -147,8 +258,9 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( [&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; } ); - float offset[NCols]; // Includes zero point and float conversion offset of 16. - if (QuantBZeroPointColPtr != nullptr) { + [[maybe_unused]] float offset[NCols]; // Includes zero point and float conversion offset of 16. + // only used if HasZeroPoint == true + if constexpr (HasZeroPoint) { UnrolledLoop([&](size_t i) { const std::byte zp_packed = QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; @@ -157,11 +269,6 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( : (zp_packed & std::byte{0x0F}); offset[i] = 16.0f + std::to_integer(zp); }); - } else { - UnrolledLoop([&](size_t i) { - constexpr float zp = 8.0f; - offset[i] = 16.0f + zp; - }); } for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { @@ -187,8 +294,6 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); }); - // dequantize B - // shift left 3 and widen to 16 bits uint16x8_t bv_u16[NCols][2]; UnrolledLoop([&](size_t i) { @@ -217,10 +322,17 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( }); // subtract float conversion offset (16) and zero point - UnrolledLoop([&](size_t i) { - const float32x4_t offset_v = vdupq_n_f32(offset[i]); - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); - }); + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const float32x4_t offset_v = vdupq_n_f32(offset[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } else { + const float32x4_t offset_v = vdupq_n_f32(16.0f + 8.0f); + UnrolledLoop([&](size_t i) { + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } // multiply by scale UnrolledLoop([&](size_t i) { @@ -237,7 +349,9 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( // increment pointers to next block QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); QuantBScale += 1; - QuantBZeroPointIdx += 1; + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } } if constexpr (NCols == 4) { @@ -258,8 +372,9 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( } } -MLAS_FORCEINLINE void -SQ4BitGemmM1Kernel_CompFp32( +template +void 
+SQ4BitGemmM1Kernel_CompFp32_Impl( size_t BlkLen, const float* A, const std::byte* QuantBData, @@ -295,7 +410,7 @@ SQ4BitGemmM1Kernel_CompFp32( int64_t nblk = static_cast(CountN) - NCols; while (nblk >= 0) { - ComputeDotProducts_BlkBitWidth4_CompFp32( + ComputeDotProducts_BlkBitWidth4_CompFp32( BlkLen, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, @@ -306,7 +421,7 @@ SQ4BitGemmM1Kernel_CompFp32( QuantBDataColPtr += NCols * StrideQuantBData; QuantBScaleColPtr += NCols * StrideQuantBScale; - if (QuantBZeroPointColPtr != nullptr) { + if constexpr (HasZeroPoint) { QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; } @@ -319,7 +434,7 @@ SQ4BitGemmM1Kernel_CompFp32( // left over columns less than `NCols`? nblk += NCols; for (int64_t n = 0; n < nblk; ++n) { - ComputeDotProducts_BlkBitWidth4_CompFp32<1>( + ComputeDotProducts_BlkBitWidth4_CompFp32<1, HasZeroPoint>( BlkLen, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, @@ -330,7 +445,7 @@ SQ4BitGemmM1Kernel_CompFp32( QuantBDataColPtr += StrideQuantBData; QuantBScaleColPtr += StrideQuantBScale; - if (QuantBZeroPointColPtr != nullptr) { + if constexpr (HasZeroPoint) { QuantBZeroPointColPtr += StrideQuantBZeroPoint; } @@ -339,6 +454,49 @@ SQ4BitGemmM1Kernel_CompFp32( } } +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_CompFp32( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_CompFp32_Impl( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_CompFp32_Impl( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } +} + MLAS_FORCEINLINE void Q4BitBlkDequantBForSgemm_CompFp32( size_t BlkLen, @@ -353,6 +511,7 @@ Q4BitBlkDequantBForSgemm_CompFp32( { auto impl0_reference = [&]() { constexpr size_t BlkBitWidth = 4; + constexpr size_t SubBlkLen = 16; float* Dst = FpData; @@ -378,11 +537,11 @@ Q4BitBlkDequantBForSgemm_CompFp32( : 8; for (size_t kk = 0; kk < kklen; ++kk) { - const size_t packed_idx = kk % 16; + const size_t packed_idx = kk % SubBlkLen; - const bool is_low_half = packed_idx < 8; - const size_t packed_byte_idx = packed_idx % 8; - const size_t packed_range_offset = (kk / 16) * 8; + const bool is_low_half = packed_idx < (SubBlkLen / 2); + const size_t packed_byte_idx = packed_idx % (SubBlkLen / 2); + const size_t packed_range_offset = (kk / SubBlkLen) * (SubBlkLen / 2); const std::byte b_packed = b_data[packed_range_offset + packed_byte_idx]; const std::byte b_byte = is_low_half ? (b_packed & std::byte{0x0F}) : (b_packed >> 4); @@ -415,7 +574,7 @@ Q4BitBlkDequantBForSgemm_CompFp32( } // -// CompInt8 kernel implementation and related helpers +// CompInt8 kernel implementation. // template @@ -431,8 +590,6 @@ QuantizeBlock( assert(BlkLen % SubBlkLen == 0); - constexpr size_t VectorCount = SubBlkLen / 4; - // // Scan block values first to determine scale. 
// @@ -443,16 +600,16 @@ QuantizeBlock( for (k = 0; k < ElementCount; k += SubBlkLen) { const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); - float32x4_t a[VectorCount]{}; + float32x4_t a[SubBlkLen / 4]{}; LoadFloatData(A + k, SubBlkElementCount, a); - float32x4_t abs_a[VectorCount]; - UnrolledLoop([&](size_t i) { + float32x4_t abs_a[SubBlkLen / 4]; + UnrolledLoop([&](size_t i) { abs_a[i] = vabsq_f32(a[i]); }); // find amax of SubBlkLen elements - for (size_t interval = VectorCount / 2; interval > 0; interval /= 2) { + for (size_t interval = SubBlkLen / 4 / 2; interval > 0; interval /= 2) { for (size_t i = 0; i < interval; ++i) { abs_a[i] = vmaxq_f32(abs_a[i], abs_a[i + interval]); } @@ -477,19 +634,19 @@ QuantizeBlock( for (k = 0; k < ElementCount; k += SubBlkLen) { const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); - float32x4_t a[VectorCount]{}; + float32x4_t a[SubBlkLen / 4]{}; LoadFloatData(A + k, SubBlkElementCount, a); - UnrolledLoop([&](size_t i) { + UnrolledLoop([&](size_t i) { a[i] = vmulq_n_f32(a[i], scale_reciprocal); }); - int32x4_t a_s32[VectorCount]; - UnrolledLoop([&](size_t i) { + int32x4_t a_s32[SubBlkLen / 4]; + UnrolledLoop([&](size_t i) { a_s32[i] = vcvtaq_s32_f32(a[i]); }); - UnrolledLoop([&](size_t i) { + UnrolledLoop([&](size_t i) { QuantAData[k + i * 4 + 0] = static_cast(vgetq_lane_s32(a_s32[i], 0)); QuantAData[k + i * 4 + 1] = static_cast(vgetq_lane_s32(a_s32[i], 1)); QuantAData[k + i * 4 + 2] = static_cast(vgetq_lane_s32(a_s32[i], 2)); @@ -530,7 +687,7 @@ QuantizeARow_CompInt8( } } -template +template MLAS_FORCEINLINE void ComputeDotProducts_BlkBitWidth4_CompInt8( size_t BlkLen, @@ -546,20 +703,22 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( const float* BiasPtr ) { - static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); - constexpr size_t BlkBitWidth = 4; - constexpr size_t SubBlkLen = 16; // number of block elements to process in a sub-block iteration - assert(BlkLen % SubBlkLen == 0); + static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); + static_assert(SubBlkLen == 16 || SubBlkLen == 32, "SubBlkLen must be 16 or 32"); - const uint8x8_t LowMask = vdup_n_u8(0x0F); + assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0); + + [[maybe_unused]] const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F); // only used if SubBlkLen == 16 + [[maybe_unused]] const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); // only used if SubBlkLen == 32 const std::byte* QuantA = QuantARowPtr; const std::byte* QuantBData = QuantBDataColPtr; const float* QuantBScale = QuantBScaleColPtr; - size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + // only used if HasZeroPoint == true float32x4_t acc[NCols]{}; @@ -572,8 +731,8 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( float b_scale[NCols]; UnrolledLoop([&](size_t i) { b_scale[i] = QuantBScale[i * StrideQuantBScale]; }); - int8_t b_zp[NCols]; - if (QuantBZeroPointColPtr != nullptr) { + [[maybe_unused]] int8_t b_zp[NCols]; // only used if HasZeroPoint == true + if constexpr (HasZeroPoint) { UnrolledLoop([&](size_t i) { const std::byte zp_packed = QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; @@ -581,42 +740,73 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( ? 
std::to_integer(zp_packed >> 4) : std::to_integer(zp_packed & std::byte{0x0F}); }); - } else { - UnrolledLoop([&](size_t i) { - b_zp[i] = 8; - }); } for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { // load A row vector - int8x16_t av = vld1q_s8(a_data + k_idx_in_blk); + int8x16_t av[SubBlkLen / 16]; + UnrolledLoop([&](size_t i) { + av[i] = vld1q_s8(a_data + k_idx_in_blk + i * 16); + }); // load B column vectors - uint8x8_t bv_packed[NCols]; + int8x16_t bv[NCols][SubBlkLen / 16]; + const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; - UnrolledLoop([&](size_t i) { - bv_packed[i] = vld1_u8( - reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset - ); - }); - int8x16_t bv[NCols]; - UnrolledLoop([&](size_t i) { - const int8x8_t lo = vreinterpret_s8_u8(vand_u8(bv_packed[i], LowMask)); - const int8x8_t hi = vreinterpret_s8_u8(vshr_n_u8(bv_packed[i], 4)); - bv[i] = vcombine_s8(lo, hi); - }); + if constexpr (SubBlkLen == 16) { + uint8x8_t bv_packed[NCols]; + UnrolledLoop([&](size_t i) { + bv_packed[i] = vld1_u8( + reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset + ); + }); + + UnrolledLoop([&](size_t i) { + const int8x8_t lo = vreinterpret_s8_u8(vand_u8(bv_packed[i], LowMaskU8x8)); + const int8x8_t hi = vreinterpret_s8_u8(vshr_n_u8(bv_packed[i], 4)); + bv[i][0] = vcombine_s8(lo, hi); + }); + } else { + static_assert(SubBlkLen == 32); + + uint8x16_t bv_packed[NCols]; + UnrolledLoop([&](size_t i) { + bv_packed[i] = vld1q_u8( + reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset + ); + }); + + UnrolledLoop([&](size_t i) { + bv[i][0] = vreinterpretq_s8_u8(vandq_u8(bv_packed[i], LowMaskU8x16)); + bv[i][1] = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed[i], 4)); + }); + } // subtract B zero point - UnrolledLoop([&](size_t i) { - const int8x16_t zp_v = vdupq_n_s8(b_zp[i]); - bv[i] = vsubq_s8(bv[i], zp_v); - }); + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const int8x16_t zp_v = vdupq_n_s8(b_zp[i]); + UnrolledLoop([&](size_t j) { + bv[i][j] = vsubq_s8(bv[i][j], zp_v); + }); + }); + } else { + const int8x16_t zp_v = vdupq_n_s8(8); + + UnrolledLoop([&](size_t i) { + UnrolledLoop([&](size_t j) { + bv[i][j] = vsubq_s8(bv[i][j], zp_v); + }); + }); + } // compute quantized dot product int32x4_t dot[NCols]{}; UnrolledLoop([&](size_t i) { - dot[i] = vdotq_s32(dot[i], av, bv[i]); + UnrolledLoop([&](size_t j) { + dot[i] = vdotq_s32(dot[i], av[j], bv[i][j]); + }); }); // convert dot product result to float @@ -636,7 +826,9 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( QuantA += Q8BlkSize(BlkLen); QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); QuantBScale += 1; - QuantBZeroPointIdx += 1; + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } } if constexpr (NCols == 4) { @@ -657,9 +849,9 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( } } -MLAS_FORCEINLINE +template void -SQ4BitGemmM1Kernel_CompInt8( +SQ4BitGemmM1Kernel_CompInt8_Impl( size_t BlkLen, const std::byte* QuantA, const std::byte* QuantBData, @@ -673,7 +865,6 @@ SQ4BitGemmM1Kernel_CompInt8( ) { constexpr size_t BlkBitWidth = 4; - constexpr size_t NCols = 4; const std::byte* QuantARowPtr = QuantA; float* CRowPtr = C; @@ -695,7 +886,7 @@ SQ4BitGemmM1Kernel_CompInt8( int64_t nblk = static_cast(CountN) - NCols; while (nblk >= 0) { - ComputeDotProducts_BlkBitWidth4_CompInt8( + ComputeDotProducts_BlkBitWidth4_CompInt8( BlkLen, QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, 
CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, @@ -706,7 +897,7 @@ SQ4BitGemmM1Kernel_CompInt8( QuantBDataColPtr += NCols * StrideQuantBData; QuantBScaleColPtr += NCols * StrideQuantBScale; - if (QuantBZeroPointColPtr != nullptr) { + if constexpr (HasZeroPoint) { QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; } @@ -719,7 +910,7 @@ SQ4BitGemmM1Kernel_CompInt8( // left over columns less than `NCols`? nblk += NCols; for (int64_t n = 0; n < nblk; ++n) { - ComputeDotProducts_BlkBitWidth4_CompInt8<1>( + ComputeDotProducts_BlkBitWidth4_CompInt8<1, SubBlkLen, HasZeroPoint>( BlkLen, QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, @@ -730,7 +921,7 @@ SQ4BitGemmM1Kernel_CompInt8( QuantBDataColPtr += StrideQuantBData; QuantBScaleColPtr += StrideQuantBScale; - if (QuantBZeroPointColPtr != nullptr) { + if constexpr (HasZeroPoint) { QuantBZeroPointColPtr += StrideQuantBZeroPoint; } @@ -739,6 +930,94 @@ SQ4BitGemmM1Kernel_CompInt8( } } +template +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (BlkLen == 16) { + SQ4BitGemmM1Kernel_CompInt8_Impl<4, 16, HasZeroPoint>( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_CompInt8_Impl<4, 32, HasZeroPoint>( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } +} + +MLAS_FORCEINLINE +void +SQ4BitGemmM1Kernel_CompInt8( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } +} + } // namespace // @@ -748,8 +1027,12 @@ SQ4BitGemmM1Kernel_CompInt8( const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { MLAS_SQNBIT_GEMM_DISPATCH d; + d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32; + d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8; d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index 668d7a0611367..b7b453415838a 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -61,10 +61,11 @@ void SQNBITGEMM(benchmark::State& state) { } std::unique_ptr PackedQuantBData; - if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen); + if (const auto PackedQuantBDataSize = 
MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); PackedQuantBDataSize > 0) { PackedQuantBData = std::make_unique(PackedQuantBDataSize); - MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, QuantBData.data(), PackedQuantBData.get(), tp.get()); + MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(), + tp.get()); } MLAS_SQNBIT_GEMM_DATA_PARAMS params{}; @@ -87,7 +88,9 @@ void SQNBITGEMM(benchmark::State& state) { } } -static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) { +static void SQ4BitGemmArgs(benchmark::internal::Benchmark* b) { + constexpr size_t BlkBitWidth = 4; + b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "ComputeType"}); ArgsProductWithFilter(b, @@ -96,19 +99,17 @@ static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) { {1, 1024, 2048}, // M {4096, 11008}, // N {4096, 11008}, // K - {8}, // Threads + {1, 8}, // Threads {int64_t{false}, int64_t{true}}, // Symmetric {int64_t{CompFp32}, int64_t{CompInt8}}}, // ComputeType - [](const std::vector& args) { + [&](const std::vector& args) { return MlasIsSQNBitGemmAvailable( - // M, N, K - narrow(args[1]), narrow(args[2]), narrow(args[3]), // BlkBitWidth, BlkLen - 4, narrow(args[0]), + BlkBitWidth, narrow(args[0]), // ComputeType static_cast(args[6])); }); } -BENCHMARK(SQNBITGEMM<4>)->Apply(SQNBitGemmArgs)->UseRealTime(); +BENCHMARK(SQNBITGEMM<4>)->Apply(SQ4BitGemmArgs)->UseRealTime(); diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 4fb8ab41745d5..ed09d7ee92b2a 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -259,10 +259,11 @@ class MlasSQNBitGemmTest : public MlasTestBase { } void* PackedQuantBData = nullptr; - if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen); + if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); PackedQuantBDataSize > 0) { PackedQuantBData = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize); - MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, QuantBData, PackedQuantBData, GetMlasThreadPool()); + MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBData, + GetMlasThreadPool()); } if (ComputeType == CompFp32) { @@ -330,7 +331,7 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture Date: Tue, 30 Jan 2024 15:59:37 -0800 Subject: [PATCH 10/11] Move einsum's test data to constexpr variables (#19320) ### Description emscripten's C++ compiler has difficulty on compiling einsum_test.cc because the file has too many local variables. So I moved them to constexpr. 
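
To illustrate the pattern, here is a minimal, self-contained sketch (not the actual ORT test code): per-test data that used to be built as local `std::string`/`std::vector` objects is hoisted into file-scope `static constexpr` arrays and referenced through non-owning views, so the test function itself carries almost no local variables. The names `MyTestCase`, `kCases`, and `RunCase` are illustrative only, and `std::span` stands in for the `gsl::span` used by the real `EinsumTestCase`; the sample values are taken from `equation0`/`shape0`/`expected0` in the patch.

```cpp
// Minimal sketch of the refactoring pattern (illustrative names, not the ORT sources):
// hoist per-test data into static constexpr storage and reference it via views.
#include <array>
#include <cstdint>
#include <cstdio>
#include <span>         // the actual patch uses gsl::span instead
#include <string_view>

struct MyTestCase {              // stand-in for EinsumTestCase
  std::string_view equation;     // non-owning views into static storage
  std::span<const int64_t> shape;
  std::span<const float> expected;
};

// Data lives in static storage instead of on the test function's stack.
static constexpr std::string_view kEquation0 = "abc,cd->abc";
static constexpr std::array<int64_t, 3> kShape0{2, 2, 2};
static constexpr std::array<float, 8> kExpected0{0.f, 5.f, 2.f, 15.f, 4.f, 25.f, 6.f, 35.f};

static constexpr std::array<MyTestCase, 1> kCases{{
    {kEquation0, kShape0, kExpected0},
}};

// A real test would feed each case to the Einsum op; this just walks the table.
static void RunCase(const MyTestCase& tc) {
  std::printf("equation=%.*s rank=%zu outputs=%zu\n",
              static_cast<int>(tc.equation.size()), tc.equation.data(),
              tc.shape.size(), tc.expected.size());
}

int main() {
  for (const MyTestCase& tc : kCases) {
    RunCase(tc);
  }
  return 0;
}
```

Because the case table now sits in static storage and each entry holds only pointers and lengths, the enclosing test function no longer accumulates hundreds of stack-allocated vectors, which is what the Emscripten compiler was struggling with.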
--- cmake/onnxruntime_unittests.cmake | 3 +- .../test/providers/cpu/math/einsum_test.cc | 1670 +++++++++++++---- 2 files changed, 1316 insertions(+), 357 deletions(-) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 351ea1a95581b..714f35380ca02 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -825,8 +825,7 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") ) endif() list(REMOVE_ITEM all_tests "${TEST_SRC_DIR}/providers/cpu/reduction/reduction_ops_test.cc" - "${TEST_SRC_DIR}/providers/cpu/tensor/grid_sample_test.cc" - "${TEST_SRC_DIR}/providers/cpu/math/einsum_test.cc") + "${TEST_SRC_DIR}/providers/cpu/tensor/grid_sample_test.cc") endif() set(test_all_args) diff --git a/onnxruntime/test/providers/cpu/math/einsum_test.cc b/onnxruntime/test/providers/cpu/math/einsum_test.cc index 05b936a41e3c1..4e968d3de6b8a 100644 --- a/onnxruntime/test/providers/cpu/math/einsum_test.cc +++ b/onnxruntime/test/providers/cpu/math/einsum_test.cc @@ -769,374 +769,1334 @@ TEST(Einsum, ExplicitEinsumAsTensorContraction_Half) { // for two and three inputs (most common use-case of Einsum operator) struct EinsumTestCase { - std::string equation; - std::vector shape; - std::vector expected; - EinsumTestCase(const std::string& eq, const std::vector& sh, const std::vector& exp) : equation(eq), shape(sh), expected(exp) {} + std::string_view equation; + gsl::span shape; + gsl::span expected; }; +static constexpr std::string_view equation0 = "abc,cd->abc"; +static constexpr std::array shape0{2, 2, 2}; +static constexpr std::array expected0{0.f, 5.f, 2.f, 15.f, 4.f, 25.f, 6.f, 35.f}; +static constexpr std::string_view equation1 = "abc,cd->abd"; +static constexpr std::array shape1{2, 2, 2}; +static constexpr std::array expected1{2.f, 3.f, 6.f, 11.f, 10.f, 19.f, 14.f, 27.f}; +static constexpr std::string_view equation2 = "abc,cd->acd"; +static constexpr std::array shape2{2, 2, 2}; +static constexpr std::array expected2{0.f, 2.f, 8.f, 12.f, 0.f, 10.f, 24.f, 36.f}; +static constexpr std::string_view equation3 = "abc,dc->abd"; +static constexpr std::array shape3{2, 2, 2}; +static constexpr std::array expected3{1.f, 3.f, 3.f, 13.f, 5.f, 23.f, 7.f, 33.f}; +static constexpr std::string_view equation4 = "abc,dc->abc"; +static constexpr std::array shape4{2, 2, 2}; +static constexpr std::array expected4{0.f, 4.f, 4.f, 12.f, 8.f, 20.f, 12.f, 28.f}; +static constexpr std::string_view equation5 = "abc,dc->acd"; +static constexpr std::array shape5{2, 2, 2}; +static constexpr std::array expected5{0.f, 4.f, 4.f, 12.f, 0.f, 20.f, 12.f, 36.f}; +static constexpr std::string_view equation6 = "acb,cd->acd"; +static constexpr std::array shape6{2, 2, 2}; +static constexpr std::array expected6{0.f, 1.f, 10.f, 15.f, 0.f, 9.f, 26.f, 39.f}; +static constexpr std::string_view equation7 = "acb,cd->abc"; +static constexpr std::array shape7{2, 2, 2}; +static constexpr std::array expected7{0.f, 10.f, 1.f, 15.f, 4.f, 30.f, 5.f, 35.f}; +static constexpr std::string_view equation8 = "acb,cd->abd"; +static constexpr std::array shape8{2, 2, 2}; +static constexpr std::array expected8{4.f, 6.f, 6.f, 10.f, 12.f, 22.f, 14.f, 26.f}; +static constexpr std::string_view equation9 = "acb,dc->acd"; +static constexpr std::array shape9{2, 2, 2}; +static constexpr std::array expected9{0.f, 2.f, 5.f, 15.f, 0.f, 18.f, 13.f, 39.f}; +static constexpr std::string_view equation10 = "acb,dc->abd"; +static constexpr std::array shape10{2, 2, 2}; +static constexpr std::array expected10{2.f, 6.f, 3.f, 11.f, 
6.f, 26.f, 7.f, 31.f}; +static constexpr std::string_view equation11 = "acb,dc->abc"; +static constexpr std::array shape11{2, 2, 2}; +static constexpr std::array expected11{0.f, 8.f, 2.f, 12.f, 8.f, 24.f, 10.f, 28.f}; +static constexpr std::string_view equation12 = "bac,cd->bac"; +static constexpr std::array shape12{2, 2, 2}; +static constexpr std::array expected12{0.f, 5.f, 2.f, 15.f, 4.f, 25.f, 6.f, 35.f}; +static constexpr std::string_view equation13 = "bac,cd->bad"; +static constexpr std::array shape13{2, 2, 2}; +static constexpr std::array expected13{2.f, 3.f, 6.f, 11.f, 10.f, 19.f, 14.f, 27.f}; +static constexpr std::string_view equation14 = "bac,cd->bcd"; +static constexpr std::array shape14{2, 2, 2}; +static constexpr std::array expected14{0.f, 2.f, 8.f, 12.f, 0.f, 10.f, 24.f, 36.f}; +static constexpr std::string_view equation15 = "bac,dc->bad"; +static constexpr std::array shape15{2, 2, 2}; +static constexpr std::array expected15{1.f, 3.f, 3.f, 13.f, 5.f, 23.f, 7.f, 33.f}; +static constexpr std::string_view equation16 = "bac,dc->bac"; +static constexpr std::array shape16{2, 2, 2}; +static constexpr std::array expected16{0.f, 4.f, 4.f, 12.f, 8.f, 20.f, 12.f, 28.f}; +static constexpr std::string_view equation17 = "bac,dc->bcd"; +static constexpr std::array shape17{2, 2, 2}; +static constexpr std::array expected17{0.f, 4.f, 4.f, 12.f, 0.f, 20.f, 12.f, 36.f}; +static constexpr std::string_view equation18 = "bca,cd->bcd"; +static constexpr std::array shape18{2, 2, 2}; +static constexpr std::array expected18{0.f, 1.f, 10.f, 15.f, 0.f, 9.f, 26.f, 39.f}; +static constexpr std::string_view equation19 = "bca,cd->bac"; +static constexpr std::array shape19{2, 2, 2}; +static constexpr std::array expected19{0.f, 10.f, 1.f, 15.f, 4.f, 30.f, 5.f, 35.f}; +static constexpr std::string_view equation20 = "bca,cd->bad"; +static constexpr std::array shape20{2, 2, 2}; +static constexpr std::array expected20{4.f, 6.f, 6.f, 10.f, 12.f, 22.f, 14.f, 26.f}; +static constexpr std::string_view equation21 = "bca,dc->bcd"; +static constexpr std::array shape21{2, 2, 2}; +static constexpr std::array expected21{0.f, 2.f, 5.f, 15.f, 0.f, 18.f, 13.f, 39.f}; +static constexpr std::string_view equation22 = "bca,dc->bad"; +static constexpr std::array shape22{2, 2, 2}; +static constexpr std::array expected22{2.f, 6.f, 3.f, 11.f, 6.f, 26.f, 7.f, 31.f}; +static constexpr std::string_view equation23 = "bca,dc->bac"; +static constexpr std::array shape23{2, 2, 2}; +static constexpr std::array expected23{0.f, 8.f, 2.f, 12.f, 8.f, 24.f, 10.f, 28.f}; +static constexpr std::string_view equation24 = "cab,cd->cad"; +static constexpr std::array shape24{2, 2, 2}; +static constexpr std::array expected24{0.f, 1.f, 0.f, 5.f, 18.f, 27.f, 26.f, 39.f}; +static constexpr std::string_view equation25 = "cab,cd->cbd"; +static constexpr std::array shape25{2, 2, 2}; +static constexpr std::array expected25{0.f, 2.f, 0.f, 4.f, 20.f, 30.f, 24.f, 36.f}; +static constexpr std::string_view equation26 = "cab,dc->cad"; +static constexpr std::array shape26{2, 2, 2}; +static constexpr std::array expected26{0.f, 2.f, 0.f, 10.f, 9.f, 27.f, 13.f, 39.f}; +static constexpr std::string_view equation27 = "cab,dc->cbd"; +static constexpr std::array shape27{2, 2, 2}; +static constexpr std::array expected27{0.f, 4.f, 0.f, 8.f, 10.f, 30.f, 12.f, 36.f}; +static constexpr std::string_view equation28 = "cba,cd->cbd"; +static constexpr std::array shape28{2, 2, 2}; +static constexpr std::array expected28{0.f, 1.f, 0.f, 5.f, 18.f, 27.f, 26.f, 39.f}; +static constexpr 
std::string_view equation29 = "cba,cd->cad"; +static constexpr std::array shape29{2, 2, 2}; +static constexpr std::array expected29{0.f, 2.f, 0.f, 4.f, 20.f, 30.f, 24.f, 36.f}; +static constexpr std::string_view equation30 = "cba,dc->cbd"; +static constexpr std::array shape30{2, 2, 2}; +static constexpr std::array expected30{0.f, 2.f, 0.f, 10.f, 9.f, 27.f, 13.f, 39.f}; +static constexpr std::string_view equation31 = "cba,dc->cad"; +static constexpr std::array shape31{2, 2, 2}; +static constexpr std::array expected31{0.f, 4.f, 0.f, 8.f, 10.f, 30.f, 12.f, 36.f}; +static constexpr std::array case0 = {{ + {equation0, shape0, expected0}, + {equation1, shape1, expected1}, + {equation2, shape2, expected2}, + {equation3, shape3, expected3}, + {equation4, shape4, expected4}, + {equation5, shape5, expected5}, + {equation6, shape6, expected6}, + {equation7, shape7, expected7}, + {equation8, shape8, expected8}, + {equation9, shape9, expected9}, + {equation10, shape10, expected10}, + {equation11, shape11, expected11}, + {equation12, shape12, expected12}, + {equation13, shape13, expected13}, + {equation14, shape14, expected14}, + {equation15, shape15, expected15}, + {equation16, shape16, expected16}, + {equation17, shape17, expected17}, + {equation18, shape18, expected18}, + {equation19, shape19, expected19}, + {equation20, shape20, expected20}, + {equation21, shape21, expected21}, + {equation22, shape22, expected22}, + {equation23, shape23, expected23}, + {equation24, shape24, expected24}, + {equation25, shape25, expected25}, + {equation26, shape26, expected26}, + {equation27, shape27, expected27}, + {equation28, shape28, expected28}, + {equation29, shape29, expected29}, + {equation30, shape30, expected30}, + {equation31, shape31, expected31}, +}}; + +static constexpr std::string_view equation32 = "abc,cd,def->abd"; +static constexpr std::array shape32{2, 2, 2}; +static constexpr std::array expected32{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}; +static constexpr std::string_view equation33 = "abc,cd,def->abe"; +static constexpr std::array shape33{2, 2, 2}; +static constexpr std::array expected33{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}; +static constexpr std::string_view equation34 = "abc,cd,def->acd"; +static constexpr std::array shape34{2, 2, 2}; +static constexpr std::array expected34{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 792.f}; +static constexpr std::string_view equation35 = "abc,cd,def->ace"; +static constexpr std::array shape35{2, 2, 2}; +static constexpr std::array expected35{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}; +static constexpr std::string_view equation36 = "abc,cd,dfe->abd"; +static constexpr std::array shape36{2, 2, 2}; +static constexpr std::array expected36{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}; +static constexpr std::string_view equation37 = "abc,cd,dfe->abf"; +static constexpr std::array shape37{2, 2, 2}; +static constexpr std::array expected37{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}; +static constexpr std::string_view equation38 = "abc,cd,dfe->acd"; +static constexpr std::array shape38{2, 2, 2}; +static constexpr std::array expected38{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 792.f}; +static constexpr std::string_view equation39 = "abc,cd,dfe->acf"; +static constexpr std::array shape39{2, 2, 2}; +static constexpr std::array expected39{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}; +static constexpr std::string_view equation40 = "abc,cd,edf->abe"; +static constexpr std::array shape40{2, 2, 2}; +static 
constexpr std::array expected40{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}; +static constexpr std::string_view equation41 = "abc,cd,edf->abd"; +static constexpr std::array shape41{2, 2, 2}; +static constexpr std::array expected41{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}; +static constexpr std::string_view equation42 = "abc,cd,edf->ace"; +static constexpr std::array shape42{2, 2, 2}; +static constexpr std::array expected42{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}; +static constexpr std::string_view equation43 = "abc,cd,edf->acd"; +static constexpr std::array shape43{2, 2, 2}; +static constexpr std::array expected43{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}; +static constexpr std::string_view equation44 = "abc,cd,efd->abe"; +static constexpr std::array shape44{2, 2, 2}; +static constexpr std::array expected44{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 464.f}; +static constexpr std::string_view equation45 = "abc,cd,efd->abf"; +static constexpr std::array shape45{2, 2, 2}; +static constexpr std::array expected45{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}; +static constexpr std::string_view equation46 = "abc,cd,efd->ace"; +static constexpr std::array shape46{2, 2, 2}; +static constexpr std::array expected46{8.f, 24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}; +static constexpr std::string_view equation47 = "abc,cd,efd->acf"; +static constexpr std::array shape47{2, 2, 2}; +static constexpr std::array expected47{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}; +static constexpr std::string_view equation48 = "abc,cd,fde->abf"; +static constexpr std::array shape48{2, 2, 2}; +static constexpr std::array expected48{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}; +static constexpr std::string_view equation49 = "abc,cd,fde->abd"; +static constexpr std::array shape49{2, 2, 2}; +static constexpr std::array expected49{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}; +static constexpr std::string_view equation50 = "abc,cd,fde->acf"; +static constexpr std::array shape50{2, 2, 2}; +static constexpr std::array expected50{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}; +static constexpr std::string_view equation51 = "abc,cd,fde->acd"; +static constexpr std::array shape51{2, 2, 2}; +static constexpr std::array expected51{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}; +static constexpr std::string_view equation52 = "abc,cd,fed->abf"; +static constexpr std::array shape52{2, 2, 2}; +static constexpr std::array expected52{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 464.f}; +static constexpr std::string_view equation53 = "abc,cd,fed->abe"; +static constexpr std::array shape53{2, 2, 2}; +static constexpr std::array expected53{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}; +static constexpr std::string_view equation54 = "abc,cd,fed->acf"; +static constexpr std::array shape54{2, 2, 2}; +static constexpr std::array expected54{8.f, 24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}; +static constexpr std::string_view equation55 = "abc,cd,fed->ace"; +static constexpr std::array shape55{2, 2, 2}; +static constexpr std::array expected55{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}; +static constexpr std::string_view equation56 = "abc,dc,def->abd"; +static constexpr std::array shape56{2, 2, 2}; +static constexpr std::array expected56{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}; +static constexpr std::string_view equation57 = "abc,dc,def->abe"; +static constexpr std::array shape57{2, 2, 2}; +static constexpr std::array 
expected57{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}; +static constexpr std::string_view equation58 = "abc,dc,def->acd"; +static constexpr std::array shape58{2, 2, 2}; +static constexpr std::array expected58{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}; +static constexpr std::string_view equation59 = "abc,dc,def->ace"; +static constexpr std::array shape59{2, 2, 2}; +static constexpr std::array expected59{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}; +static constexpr std::string_view equation60 = "abc,dc,dfe->abd"; +static constexpr std::array shape60{2, 2, 2}; +static constexpr std::array expected60{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}; +static constexpr std::string_view equation61 = "abc,dc,dfe->abf"; +static constexpr std::array shape61{2, 2, 2}; +static constexpr std::array expected61{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}; +static constexpr std::string_view equation62 = "abc,dc,dfe->acd"; +static constexpr std::array shape62{2, 2, 2}; +static constexpr std::array expected62{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}; +static constexpr std::string_view equation63 = "abc,dc,dfe->acf"; +static constexpr std::array shape63{2, 2, 2}; +static constexpr std::array expected63{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}; +static constexpr std::string_view equation64 = "abc,dc,edf->abe"; +static constexpr std::array shape64{2, 2, 2}; +static constexpr std::array expected64{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}; +static constexpr std::string_view equation65 = "abc,dc,edf->abd"; +static constexpr std::array shape65{2, 2, 2}; +static constexpr std::array expected65{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}; +static constexpr std::string_view equation66 = "abc,dc,edf->ace"; +static constexpr std::array shape66{2, 2, 2}; +static constexpr std::array expected66{20.f, 52.f, 64.f, 192.f, 100.f, 260.f, 192.f, 576.f}; +static constexpr std::string_view equation67 = "abc,dc,edf->acd"; +static constexpr std::array shape67{2, 2, 2}; +static constexpr std::array expected67{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}; +static constexpr std::string_view equation68 = "abc,dc,efd->abe"; +static constexpr std::array shape68{2, 2, 2}; +static constexpr std::array expected68{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}; +static constexpr std::string_view equation69 = "abc,dc,efd->abf"; +static constexpr std::array shape69{2, 2, 2}; +static constexpr std::array expected69{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}; +static constexpr std::string_view equation70 = "abc,dc,efd->ace"; +static constexpr std::array shape70{2, 2, 2}; +static constexpr std::array expected70{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}; +static constexpr std::string_view equation71 = "abc,dc,efd->acf"; +static constexpr std::array shape71{2, 2, 2}; +static constexpr std::array expected71{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}; +static constexpr std::string_view equation72 = "abc,dc,fde->abf"; +static constexpr std::array shape72{2, 2, 2}; +static constexpr std::array expected72{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}; +static constexpr std::string_view equation73 = "abc,dc,fde->abd"; +static constexpr std::array shape73{2, 2, 2}; +static constexpr std::array expected73{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}; +static constexpr std::string_view equation74 = "abc,dc,fde->acf"; +static constexpr std::array shape74{2, 2, 2}; +static constexpr std::array expected74{20.f, 52.f, 64.f, 
192.f, 100.f, 260.f, 192.f, 576.f}; +static constexpr std::string_view equation75 = "abc,dc,fde->acd"; +static constexpr std::array shape75{2, 2, 2}; +static constexpr std::array expected75{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}; +static constexpr std::string_view equation76 = "abc,dc,fed->abf"; +static constexpr std::array shape76{2, 2, 2}; +static constexpr std::array expected76{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}; +static constexpr std::string_view equation77 = "abc,dc,fed->abe"; +static constexpr std::array shape77{2, 2, 2}; +static constexpr std::array expected77{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}; +static constexpr std::string_view equation78 = "abc,dc,fed->acf"; +static constexpr std::array shape78{2, 2, 2}; +static constexpr std::array expected78{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}; +static constexpr std::string_view equation79 = "abc,dc,fed->ace"; +static constexpr std::array shape79{2, 2, 2}; +static constexpr std::array expected79{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}; +static constexpr std::string_view equation80 = "acb,cd,def->acd"; +static constexpr std::array shape80{2, 2, 2}; +static constexpr std::array expected80{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}; +static constexpr std::string_view equation81 = "acb,cd,def->ace"; +static constexpr std::array shape81{2, 2, 2}; +static constexpr std::array expected81{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}; +static constexpr std::string_view equation82 = "acb,cd,def->abd"; +static constexpr std::array shape82{2, 2, 2}; +static constexpr std::array expected82{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}; +static constexpr std::string_view equation83 = "acb,cd,def->abe"; +static constexpr std::array shape83{2, 2, 2}; +static constexpr std::array expected83{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}; +static constexpr std::string_view equation84 = "acb,cd,dfe->acd"; +static constexpr std::array shape84{2, 2, 2}; +static constexpr std::array expected84{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}; +static constexpr std::string_view equation85 = "acb,cd,dfe->acf"; +static constexpr std::array shape85{2, 2, 2}; +static constexpr std::array expected85{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}; +static constexpr std::string_view equation86 = "acb,cd,dfe->abd"; +static constexpr std::array shape86{2, 2, 2}; +static constexpr std::array expected86{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}; +static constexpr std::string_view equation87 = "acb,cd,dfe->abf"; +static constexpr std::array shape87{2, 2, 2}; +static constexpr std::array expected87{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}; +static constexpr std::string_view equation88 = "acb,cd,edf->ace"; +static constexpr std::array shape88{2, 2, 2}; +static constexpr std::array expected88{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}; +static constexpr std::string_view equation89 = "acb,cd,edf->acd"; +static constexpr std::array shape89{2, 2, 2}; +static constexpr std::array expected89{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}; +static constexpr std::string_view equation90 = "acb,cd,edf->abe"; +static constexpr std::array shape90{2, 2, 2}; +static constexpr std::array expected90{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}; +static constexpr std::string_view equation91 = "acb,cd,edf->abd"; +static constexpr std::array shape91{2, 2, 2}; +static constexpr std::array expected91{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 
468.f}; +static constexpr std::string_view equation92 = "acb,cd,efd->ace"; +static constexpr std::array shape92{2, 2, 2}; +static constexpr std::array expected92{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}; +static constexpr std::string_view equation93 = "acb,cd,efd->acf"; +static constexpr std::array shape93{2, 2, 2}; +static constexpr std::array expected93{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}; +static constexpr std::string_view equation94 = "acb,cd,efd->abe"; +static constexpr std::array shape94{2, 2, 2}; +static constexpr std::array expected94{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}; +static constexpr std::string_view equation95 = "acb,cd,efd->abf"; +static constexpr std::array shape95{2, 2, 2}; +static constexpr std::array expected95{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}; +static constexpr std::string_view equation96 = "acb,cd,fde->acf"; +static constexpr std::array shape96{2, 2, 2}; +static constexpr std::array expected96{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}; +static constexpr std::string_view equation97 = "acb,cd,fde->acd"; +static constexpr std::array shape97{2, 2, 2}; +static constexpr std::array expected97{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}; +static constexpr std::string_view equation98 = "acb,cd,fde->abf"; +static constexpr std::array shape98{2, 2, 2}; +static constexpr std::array expected98{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}; +static constexpr std::string_view equation99 = "acb,cd,fde->abd"; +static constexpr std::array shape99{2, 2, 2}; +static constexpr std::array expected99{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 468.f}; +static constexpr std::string_view equation100 = "acb,cd,fed->acf"; +static constexpr std::array shape100{2, 2, 2}; +static constexpr std::array expected100{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}; +static constexpr std::string_view equation101 = "acb,cd,fed->ace"; +static constexpr std::array shape101{2, 2, 2}; +static constexpr std::array expected101{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}; +static constexpr std::string_view equation102 = "acb,cd,fed->abf"; +static constexpr std::array shape102{2, 2, 2}; +static constexpr std::array expected102{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}; +static constexpr std::string_view equation103 = "acb,cd,fed->abe"; +static constexpr std::array shape103{2, 2, 2}; +static constexpr std::array expected103{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}; +static constexpr std::string_view equation104 = "acb,dc,def->acd"; +static constexpr std::array shape104{2, 2, 2}; +static constexpr std::array expected104{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 858.f}; +static constexpr std::string_view equation105 = "acb,dc,def->ace"; +static constexpr std::array shape105{2, 2, 2}; +static constexpr std::array expected105{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f}; + +static constexpr std::string_view equation106 = "acb,dc,def->abd"; +static constexpr std::array shape106{2, 2, 2}; +static constexpr std::array expected106{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}; +static constexpr std::string_view equation107 = "acb,dc,def->abe"; +static constexpr std::array shape107{2, 2, 2}; +static constexpr std::array expected107{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}; +static constexpr std::string_view equation108 = "acb,dc,dfe->acd"; +static constexpr std::array shape108{2, 2, 2}; +static constexpr std::array expected108{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 
858.f}; +static constexpr std::string_view equation109 = "acb,dc,dfe->acf"; +static constexpr std::array shape109{2, 2, 2}; +static constexpr std::array expected109{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f}; +static constexpr std::string_view equation110 = "acb,dc,dfe->abd"; +static constexpr std::array shape110{2, 2, 2}; +static constexpr std::array expected110{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}; +static constexpr std::string_view equation111 = "acb,dc,dfe->abf"; +static constexpr std::array shape111{2, 2, 2}; +static constexpr std::array expected111{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}; +static constexpr std::string_view equation112 = "acb,dc,edf->ace"; +static constexpr std::array shape112{2, 2, 2}; +static constexpr std::array expected112{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}; +static constexpr std::string_view equation113 = "acb,dc,edf->acd"; +static constexpr std::array shape113{2, 2, 2}; +static constexpr std::array expected113{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}; +static constexpr std::string_view equation114 = "acb,dc,edf->abe"; +static constexpr std::array shape114{2, 2, 2}; +static constexpr std::array expected114{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}; +static constexpr std::string_view equation115 = "acb,dc,edf->abd"; +static constexpr std::array shape115{2, 2, 2}; +static constexpr std::array expected115{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}; +static constexpr std::string_view equation116 = "acb,dc,efd->ace"; +static constexpr std::array shape116{2, 2, 2}; +static constexpr std::array expected116{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}; +static constexpr std::string_view equation117 = "acb,dc,efd->acf"; +static constexpr std::array shape117{2, 2, 2}; +static constexpr std::array expected117{12.f, 20.f, 110.f, 190.f, 108.f, 180.f, 286.f, 494.f}; +static constexpr std::string_view equation118 = "acb,dc,efd->abe"; +static constexpr std::array shape118{2, 2, 2}; +static constexpr std::array expected118{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}; +static constexpr std::string_view equation119 = "acb,dc,efd->abf"; +static constexpr std::array shape119{2, 2, 2}; +static constexpr std::array expected119{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}; +static constexpr std::string_view equation120 = "acb,dc,fde->acf"; +static constexpr std::array shape120{2, 2, 2}; +static constexpr std::array expected120{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}; +static constexpr std::string_view equation121 = "acb,dc,fde->acd"; +static constexpr std::array shape121{2, 2, 2}; +static constexpr std::array expected121{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}; +static constexpr std::string_view equation122 = "acb,dc,fde->abf"; +static constexpr std::array shape122{2, 2, 2}; +static constexpr std::array expected122{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}; +static constexpr std::string_view equation123 = "acb,dc,fde->abd"; +static constexpr std::array shape123{2, 2, 2}; +static constexpr std::array expected123{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}; +static constexpr std::string_view equation124 = "acb,dc,fed->acf"; +static constexpr std::array shape124{2, 2, 2}; +static constexpr std::array expected124{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}; +static constexpr std::string_view equation125 = "acb,dc,fed->ace"; +static constexpr std::array shape125{2, 2, 2}; +static constexpr std::array expected125{12.f, 20.f, 110.f, 190.f, 
108.f, 180.f, 286.f, 494.f}; +static constexpr std::string_view equation126 = "acb,dc,fed->abf"; +static constexpr std::array shape126{2, 2, 2}; +static constexpr std::array expected126{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}; +static constexpr std::string_view equation127 = "acb,dc,fed->abe"; +static constexpr std::array shape127{2, 2, 2}; +static constexpr std::array expected127{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}; +static constexpr std::string_view equation128 = "bac,cd,def->bad"; +static constexpr std::array shape128{2, 2, 2}; +static constexpr std::array expected128{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}; +static constexpr std::string_view equation129 = "bac,cd,def->bae"; +static constexpr std::array shape129{2, 2, 2}; +static constexpr std::array expected129{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}; +static constexpr std::string_view equation130 = "bac,cd,def->bcd"; +static constexpr std::array shape130{2, 2, 2}; +static constexpr std::array expected130{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 792.f}; +static constexpr std::string_view equation131 = "bac,cd,def->bce"; +static constexpr std::array shape131{2, 2, 2}; +static constexpr std::array expected131{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}; +static constexpr std::string_view equation132 = "bac,cd,dfe->bad"; +static constexpr std::array shape132{2, 2, 2}; +static constexpr std::array expected132{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}; +static constexpr std::string_view equation133 = "bac,cd,dfe->baf"; +static constexpr std::array shape133{2, 2, 2}; +static constexpr std::array expected133{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}; +static constexpr std::string_view equation134 = "bac,cd,dfe->bcd"; +static constexpr std::array shape134{2, 2, 2}; +static constexpr std::array expected134{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 792.f}; +static constexpr std::string_view equation135 = "bac,cd,dfe->bcf"; +static constexpr std::array shape135{2, 2, 2}; +static constexpr std::array expected135{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}; +static constexpr std::string_view equation136 = "bac,cd,edf->bae"; +static constexpr std::array shape136{2, 2, 2}; +static constexpr std::array expected136{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}; +static constexpr std::string_view equation137 = "bac,cd,edf->bad"; +static constexpr std::array shape137{2, 2, 2}; +static constexpr std::array expected137{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}; +static constexpr std::string_view equation138 = "bac,cd,edf->bce"; +static constexpr std::array shape138{2, 2, 2}; +static constexpr std::array expected138{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}; +static constexpr std::string_view equation139 = "bac,cd,edf->bcd"; +static constexpr std::array shape139{2, 2, 2}; +static constexpr std::array expected139{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}; +static constexpr std::string_view equation140 = "bac,cd,efd->bae"; +static constexpr std::array shape140{2, 2, 2}; +static constexpr std::array expected140{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 464.f}; +static constexpr std::string_view equation141 = "bac,cd,efd->baf"; +static constexpr std::array shape141{2, 2, 2}; +static constexpr std::array expected141{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}; +static constexpr std::string_view equation142 = "bac,cd,efd->bce"; +static constexpr std::array shape142{2, 2, 2}; +static constexpr std::array expected142{8.f, 
24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}; +static constexpr std::string_view equation143 = "bac,cd,efd->bcf"; +static constexpr std::array shape143{2, 2, 2}; +static constexpr std::array expected143{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}; +static constexpr std::string_view equation144 = "bac,cd,fde->baf"; +static constexpr std::array shape144{2, 2, 2}; +static constexpr std::array expected144{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}; +static constexpr std::string_view equation145 = "bac,cd,fde->bad"; +static constexpr std::array shape145{2, 2, 2}; +static constexpr std::array expected145{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}; +static constexpr std::string_view equation146 = "bac,cd,fde->bcf"; +static constexpr std::array shape146{2, 2, 2}; +static constexpr std::array expected146{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}; +static constexpr std::string_view equation147 = "bac,cd,fde->bcd"; +static constexpr std::array shape147{2, 2, 2}; +static constexpr std::array expected147{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}; +static constexpr std::string_view equation148 = "bac,cd,fed->baf"; +static constexpr std::array shape148{2, 2, 2}; +static constexpr std::array expected148{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 464.f}; +static constexpr std::string_view equation149 = "bac,cd,fed->bae"; +static constexpr std::array shape149{2, 2, 2}; +static constexpr std::array expected149{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}; +static constexpr std::string_view equation150 = "bac,cd,fed->bcf"; +static constexpr std::array shape150{2, 2, 2}; +static constexpr std::array expected150{8.f, 24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}; +static constexpr std::string_view equation151 = "bac,cd,fed->bce"; +static constexpr std::array shape151{2, 2, 2}; +static constexpr std::array expected151{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}; +static constexpr std::string_view equation152 = "bac,dc,def->bad"; +static constexpr std::array shape152{2, 2, 2}; +static constexpr std::array expected152{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}; +static constexpr std::string_view equation153 = "bac,dc,def->bae"; +static constexpr std::array shape153{2, 2, 2}; +static constexpr std::array expected153{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}; +static constexpr std::string_view equation154 = "bac,dc,def->bcd"; +static constexpr std::array shape154{2, 2, 2}; +static constexpr std::array expected154{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}; +static constexpr std::string_view equation155 = "bac,dc,def->bce"; +static constexpr std::array shape155{2, 2, 2}; +static constexpr std::array expected155{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}; +static constexpr std::string_view equation156 = "bac,dc,dfe->bad"; +static constexpr std::array shape156{2, 2, 2}; +static constexpr std::array expected156{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}; +static constexpr std::string_view equation157 = "bac,dc,dfe->baf"; +static constexpr std::array shape157{2, 2, 2}; +static constexpr std::array expected157{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}; +static constexpr std::string_view equation158 = "bac,dc,dfe->bcd"; +static constexpr std::array shape158{2, 2, 2}; +static constexpr std::array expected158{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}; +static constexpr std::string_view equation159 = "bac,dc,dfe->bcf"; +static constexpr std::array shape159{2, 2, 2}; +static constexpr std::array 
expected159{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}; +static constexpr std::string_view equation160 = "bac,dc,edf->bae"; +static constexpr std::array shape160{2, 2, 2}; +static constexpr std::array expected160{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}; +static constexpr std::string_view equation161 = "bac,dc,edf->bad"; +static constexpr std::array shape161{2, 2, 2}; +static constexpr std::array expected161{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}; +static constexpr std::string_view equation162 = "bac,dc,edf->bce"; +static constexpr std::array shape162{2, 2, 2}; +static constexpr std::array expected162{20.f, 52.f, 64.f, 192.f, 100.f, 260.f, 192.f, 576.f}; +static constexpr std::string_view equation163 = "bac,dc,edf->bcd"; +static constexpr std::array shape163{2, 2, 2}; +static constexpr std::array expected163{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}; +static constexpr std::string_view equation164 = "bac,dc,efd->bae"; +static constexpr std::array shape164{2, 2, 2}; +static constexpr std::array expected164{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}; +static constexpr std::string_view equation165 = "bac,dc,efd->baf"; +static constexpr std::array shape165{2, 2, 2}; +static constexpr std::array expected165{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}; +static constexpr std::string_view equation166 = "bac,dc,efd->bce"; +static constexpr std::array shape166{2, 2, 2}; +static constexpr std::array expected166{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}; +static constexpr std::string_view equation167 = "bac,dc,efd->bcf"; +static constexpr std::array shape167{2, 2, 2}; +static constexpr std::array expected167{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}; +static constexpr std::string_view equation168 = "bac,dc,fde->baf"; +static constexpr std::array shape168{2, 2, 2}; +static constexpr std::array expected168{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}; +static constexpr std::string_view equation169 = "bac,dc,fde->bad"; +static constexpr std::array shape169{2, 2, 2}; +static constexpr std::array expected169{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}; +static constexpr std::string_view equation170 = "bac,dc,fde->bcf"; +static constexpr std::array shape170{2, 2, 2}; +static constexpr std::array expected170{20.f, 52.f, 64.f, 192.f, 100.f, 260.f, 192.f, 576.f}; +static constexpr std::string_view equation171 = "bac,dc,fde->bcd"; +static constexpr std::array shape171{2, 2, 2}; +static constexpr std::array expected171{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}; +static constexpr std::string_view equation172 = "bac,dc,fed->baf"; +static constexpr std::array shape172{2, 2, 2}; +static constexpr std::array expected172{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}; +static constexpr std::string_view equation173 = "bac,dc,fed->bae"; +static constexpr std::array shape173{2, 2, 2}; +static constexpr std::array expected173{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}; +static constexpr std::string_view equation174 = "bac,dc,fed->bcf"; +static constexpr std::array shape174{2, 2, 2}; +static constexpr std::array expected174{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}; +static constexpr std::string_view equation175 = "bac,dc,fed->bce"; +static constexpr std::array shape175{2, 2, 2}; +static constexpr std::array expected175{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}; +static constexpr std::string_view equation176 = "bca,cd,def->bcd"; +static constexpr std::array shape176{2, 2, 2}; +static 
constexpr std::array expected176{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}; +static constexpr std::string_view equation177 = "bca,cd,def->bce"; +static constexpr std::array shape177{2, 2, 2}; +static constexpr std::array expected177{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}; +static constexpr std::string_view equation178 = "bca,cd,def->bad"; +static constexpr std::array shape178{2, 2, 2}; +static constexpr std::array expected178{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}; +static constexpr std::string_view equation179 = "bca,cd,def->bae"; +static constexpr std::array shape179{2, 2, 2}; +static constexpr std::array expected179{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}; +static constexpr std::string_view equation180 = "bca,cd,dfe->bcd"; +static constexpr std::array shape180{2, 2, 2}; +static constexpr std::array expected180{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}; +static constexpr std::string_view equation181 = "bca,cd,dfe->bcf"; +static constexpr std::array shape181{2, 2, 2}; +static constexpr std::array expected181{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}; +static constexpr std::string_view equation182 = "bca,cd,dfe->bad"; +static constexpr std::array shape182{2, 2, 2}; +static constexpr std::array expected182{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}; +static constexpr std::string_view equation183 = "bca,cd,dfe->baf"; +static constexpr std::array shape183{2, 2, 2}; +static constexpr std::array expected183{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}; +static constexpr std::string_view equation184 = "bca,cd,edf->bce"; +static constexpr std::array shape184{2, 2, 2}; +static constexpr std::array expected184{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}; +static constexpr std::string_view equation185 = "bca,cd,edf->bcd"; +static constexpr std::array shape185{2, 2, 2}; +static constexpr std::array expected185{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}; +static constexpr std::string_view equation186 = "bca,cd,edf->bae"; +static constexpr std::array shape186{2, 2, 2}; +static constexpr std::array expected186{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}; +static constexpr std::string_view equation187 = "bca,cd,edf->bad"; +static constexpr std::array shape187{2, 2, 2}; +static constexpr std::array expected187{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 468.f}; +static constexpr std::string_view equation188 = "bca,cd,efd->bce"; +static constexpr std::array shape188{2, 2, 2}; +static constexpr std::array expected188{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}; +static constexpr std::string_view equation189 = "bca,cd,efd->bcf"; +static constexpr std::array shape189{2, 2, 2}; +static constexpr std::array expected189{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}; +static constexpr std::string_view equation190 = "bca,cd,efd->bae"; +static constexpr std::array shape190{2, 2, 2}; +static constexpr std::array expected190{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}; +static constexpr std::string_view equation191 = "bca,cd,efd->baf"; +static constexpr std::array shape191{2, 2, 2}; +static constexpr std::array expected191{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}; +static constexpr std::string_view equation192 = "bca,cd,fde->bcf"; +static constexpr std::array shape192{2, 2, 2}; +static constexpr std::array expected192{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}; +static constexpr std::string_view equation193 = "bca,cd,fde->bcd"; +static constexpr std::array shape193{2, 
2, 2}; +static constexpr std::array expected193{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}; +static constexpr std::string_view equation194 = "bca,cd,fde->baf"; +static constexpr std::array shape194{2, 2, 2}; +static constexpr std::array expected194{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}; +static constexpr std::string_view equation195 = "bca,cd,fde->bad"; +static constexpr std::array shape195{2, 2, 2}; +static constexpr std::array expected195{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 468.f}; +static constexpr std::string_view equation196 = "bca,cd,fed->bcf"; +static constexpr std::array shape196{2, 2, 2}; +static constexpr std::array expected196{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}; +static constexpr std::string_view equation197 = "bca,cd,fed->bce"; +static constexpr std::array shape197{2, 2, 2}; +static constexpr std::array expected197{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}; +static constexpr std::string_view equation198 = "bca,cd,fed->baf"; +static constexpr std::array shape198{2, 2, 2}; +static constexpr std::array expected198{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}; +static constexpr std::string_view equation199 = "bca,cd,fed->bae"; +static constexpr std::array shape199{2, 2, 2}; +static constexpr std::array expected199{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}; +static constexpr std::string_view equation200 = "bca,dc,def->bcd"; +static constexpr std::array shape200{2, 2, 2}; +static constexpr std::array expected200{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 858.f}; +static constexpr std::string_view equation201 = "bca,dc,def->bce"; +static constexpr std::array shape201{2, 2, 2}; +static constexpr std::array expected201{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f}; +static constexpr std::string_view equation202 = "bca,dc,def->bad"; +static constexpr std::array shape202{2, 2, 2}; +static constexpr std::array expected202{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}; +static constexpr std::string_view equation203 = "bca,dc,def->bae"; +static constexpr std::array shape203{2, 2, 2}; +static constexpr std::array expected203{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}; +static constexpr std::string_view equation204 = "bca,dc,dfe->bcd"; +static constexpr std::array shape204{2, 2, 2}; +static constexpr std::array expected204{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 858.f}; +static constexpr std::string_view equation205 = "bca,dc,dfe->bcf"; +static constexpr std::array shape205{2, 2, 2}; +static constexpr std::array expected205{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f}; +static constexpr std::string_view equation206 = "bca,dc,dfe->bad"; +static constexpr std::array shape206{2, 2, 2}; +static constexpr std::array expected206{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}; +static constexpr std::string_view equation207 = "bca,dc,dfe->baf"; +static constexpr std::array shape207{2, 2, 2}; +static constexpr std::array expected207{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}; +static constexpr std::string_view equation208 = "bca,dc,edf->bce"; +static constexpr std::array shape208{2, 2, 2}; +static constexpr std::array expected208{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}; +static constexpr std::string_view equation209 = "bca,dc,edf->bcd"; +static constexpr std::array shape209{2, 2, 2}; +static constexpr std::array expected209{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}; +static constexpr std::string_view equation210 = "bca,dc,edf->bae"; +static constexpr 
std::array shape210{2, 2, 2}; +static constexpr std::array expected210{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}; +static constexpr std::string_view equation211 = "bca,dc,edf->bad"; +static constexpr std::array shape211{2, 2, 2}; +static constexpr std::array expected211{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}; +static constexpr std::string_view equation212 = "bca,dc,efd->bce"; +static constexpr std::array shape212{2, 2, 2}; +static constexpr std::array expected212{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}; +static constexpr std::string_view equation213 = "bca,dc,efd->bcf"; +static constexpr std::array shape213{2, 2, 2}; +static constexpr std::array expected213{12.f, 20.f, 110.f, 190.f, 108.f, 180.f, 286.f, 494.f}; +static constexpr std::string_view equation214 = "bca,dc,efd->bae"; +static constexpr std::array shape214{2, 2, 2}; +static constexpr std::array expected214{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}; +static constexpr std::string_view equation215 = "bca,dc,efd->baf"; +static constexpr std::array shape215{2, 2, 2}; +static constexpr std::array expected215{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}; +static constexpr std::string_view equation216 = "bca,dc,fde->bcf"; +static constexpr std::array shape216{2, 2, 2}; +static constexpr std::array expected216{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}; +static constexpr std::string_view equation217 = "bca,dc,fde->bcd"; +static constexpr std::array shape217{2, 2, 2}; +static constexpr std::array expected217{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}; +static constexpr std::string_view equation218 = "bca,dc,fde->baf"; +static constexpr std::array shape218{2, 2, 2}; +static constexpr std::array expected218{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}; +static constexpr std::string_view equation219 = "bca,dc,fde->bad"; +static constexpr std::array shape219{2, 2, 2}; +static constexpr std::array expected219{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}; +static constexpr std::string_view equation220 = "bca,dc,fed->bcf"; +static constexpr std::array shape220{2, 2, 2}; +static constexpr std::array expected220{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}; +static constexpr std::string_view equation221 = "bca,dc,fed->bce"; +static constexpr std::array shape221{2, 2, 2}; +static constexpr std::array expected221{12.f, 20.f, 110.f, 190.f, 108.f, 180.f, 286.f, 494.f}; +static constexpr std::string_view equation222 = "bca,dc,fed->baf"; +static constexpr std::array shape222{2, 2, 2}; +static constexpr std::array expected222{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}; +static constexpr std::string_view equation223 = "bca,dc,fed->bae"; +static constexpr std::array shape223{2, 2, 2}; +static constexpr std::array expected223{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}; +static constexpr std::string_view equation224 = "cab,cd,def->cad"; +static constexpr std::array shape224{2, 2, 2}; +static constexpr std::array expected224{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}; +static constexpr std::string_view equation225 = "cab,cd,def->cae"; +static constexpr std::array shape225{2, 2, 2}; +static constexpr std::array expected225{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}; +static constexpr std::string_view equation226 = "cab,cd,def->cbd"; +static constexpr std::array shape226{2, 2, 2}; +static constexpr std::array expected226{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}; +static constexpr std::string_view equation227 = "cab,cd,def->cbe"; 
+static constexpr std::array shape227{2, 2, 2}; +static constexpr std::array expected227{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}; +static constexpr std::string_view equation228 = "cab,cd,dfe->cad"; +static constexpr std::array shape228{2, 2, 2}; +static constexpr std::array expected228{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}; +static constexpr std::string_view equation229 = "cab,cd,dfe->caf"; +static constexpr std::array shape229{2, 2, 2}; +static constexpr std::array expected229{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}; +static constexpr std::string_view equation230 = "cab,cd,dfe->cbd"; +static constexpr std::array shape230{2, 2, 2}; +static constexpr std::array expected230{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}; +static constexpr std::string_view equation231 = "cab,cd,dfe->cbf"; +static constexpr std::array shape231{2, 2, 2}; +static constexpr std::array expected231{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}; +static constexpr std::string_view equation232 = "cab,cd,edf->cae"; +static constexpr std::array shape232{2, 2, 2}; +static constexpr std::array expected232{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}; +static constexpr std::string_view equation233 = "cab,cd,edf->cad"; +static constexpr std::array shape233{2, 2, 2}; +static constexpr std::array expected233{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}; +static constexpr std::string_view equation234 = "cab,cd,edf->cbe"; +static constexpr std::array shape234{2, 2, 2}; +static constexpr std::array expected234{10.f, 26.f, 20.f, 52.f, 170.f, 570.f, 204.f, 684.f}; +static constexpr std::string_view equation235 = "cab,cd,edf->cbd"; +static constexpr std::array shape235{2, 2, 2}; +static constexpr std::array expected235{0.f, 36.f, 0.f, 72.f, 200.f, 540.f, 240.f, 648.f}; +static constexpr std::string_view equation236 = "cab,cd,efd->cae"; +static constexpr std::array shape236{2, 2, 2}; +static constexpr std::array expected236{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}; +static constexpr std::string_view equation237 = "cab,cd,efd->caf"; +static constexpr std::array shape237{2, 2, 2}; +static constexpr std::array expected237{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}; +static constexpr std::string_view equation238 = "cab,cd,efd->cbe"; +static constexpr std::array shape238{2, 2, 2}; +static constexpr std::array expected238{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}; +static constexpr std::string_view equation239 = "cab,cd,efd->cbf"; +static constexpr std::array shape239{2, 2, 2}; +static constexpr std::array expected239{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}; +static constexpr std::string_view equation240 = "cab,cd,fde->caf"; +static constexpr std::array shape240{2, 2, 2}; +static constexpr std::array expected240{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}; +static constexpr std::string_view equation241 = "cab,cd,fde->cad"; +static constexpr std::array shape241{2, 2, 2}; +static constexpr std::array expected241{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}; +static constexpr std::string_view equation242 = "cab,cd,fde->cbf"; +static constexpr std::array shape242{2, 2, 2}; +static constexpr std::array expected242{10.f, 26.f, 20.f, 52.f, 170.f, 570.f, 204.f, 684.f}; +static constexpr std::string_view equation243 = "cab,cd,fde->cbd"; +static constexpr std::array shape243{2, 2, 2}; +static constexpr std::array expected243{0.f, 36.f, 0.f, 72.f, 200.f, 540.f, 240.f, 648.f}; +static constexpr std::string_view equation244 = "cab,cd,fed->caf"; 
+static constexpr std::array shape244{2, 2, 2}; +static constexpr std::array expected244{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}; +static constexpr std::string_view equation245 = "cab,cd,fed->cae"; +static constexpr std::array shape245{2, 2, 2}; +static constexpr std::array expected245{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}; +static constexpr std::string_view equation246 = "cab,cd,fed->cbf"; +static constexpr std::array shape246{2, 2, 2}; +static constexpr std::array expected246{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}; +static constexpr std::string_view equation247 = "cab,cd,fed->cbe"; +static constexpr std::array shape247{2, 2, 2}; +static constexpr std::array expected247{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}; +static constexpr std::string_view equation248 = "cab,dc,def->cad"; +static constexpr std::array shape248{2, 2, 2}; +static constexpr std::array expected248{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}; +static constexpr std::string_view equation249 = "cab,dc,def->cae"; +static constexpr std::array shape249{2, 2, 2}; +static constexpr std::array expected249{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}; +static constexpr std::string_view equation250 = "cab,dc,def->cbd"; +static constexpr std::array shape250{2, 2, 2}; +static constexpr std::array expected250{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f}; + +static constexpr std::string_view equation251 = "cab,dc,def->cbe"; +static constexpr std::array shape251{2, 2, 2}; +static constexpr std::array expected251{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}; +static constexpr std::string_view equation252 = "cab,dc,dfe->cad"; +static constexpr std::array shape252{2, 2, 2}; +static constexpr std::array expected252{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}; +static constexpr std::string_view equation253 = "cab,dc,dfe->caf"; +static constexpr std::array shape253{2, 2, 2}; +static constexpr std::array expected253{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}; +static constexpr std::string_view equation254 = "cab,dc,dfe->cbd"; +static constexpr std::array shape254{2, 2, 2}; +static constexpr std::array expected254{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f}; +static constexpr std::string_view equation255 = "cab,dc,dfe->cbf"; +static constexpr std::array shape255{2, 2, 2}; +static constexpr std::array expected255{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}; +static constexpr std::string_view equation256 = "cab,dc,edf->cae"; +static constexpr std::array shape256{2, 2, 2}; +static constexpr std::array expected256{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}; +static constexpr std::string_view equation257 = "cab,dc,edf->cad"; +static constexpr std::array shape257{2, 2, 2}; +static constexpr std::array expected257{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}; +static constexpr std::string_view equation258 = "cab,dc,edf->cbe"; +static constexpr std::array shape258{2, 2, 2}; +static constexpr std::array expected258{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}; +static constexpr std::string_view equation259 = "cab,dc,edf->cbd"; +static constexpr std::array shape259{2, 2, 2}; +static constexpr std::array expected259{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}; +static constexpr std::string_view equation260 = "cab,dc,efd->cae"; +static constexpr std::array shape260{2, 2, 2}; +static constexpr std::array expected260{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}; +static constexpr std::string_view equation261 = 
"cab,dc,efd->caf"; +static constexpr std::array shape261{2, 2, 2}; +static constexpr std::array expected261{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}; +static constexpr std::string_view equation262 = "cab,dc,efd->cbe"; +static constexpr std::array shape262{2, 2, 2}; +static constexpr std::array expected262{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}; +static constexpr std::string_view equation263 = "cab,dc,efd->cbf"; +static constexpr std::array shape263{2, 2, 2}; +static constexpr std::array expected263{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f}; +static constexpr std::string_view equation264 = "cab,dc,fde->caf"; +static constexpr std::array shape264{2, 2, 2}; +static constexpr std::array expected264{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}; +static constexpr std::string_view equation265 = "cab,dc,fde->cad"; +static constexpr std::array shape265{2, 2, 2}; +static constexpr std::array expected265{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}; +static constexpr std::string_view equation266 = "cab,dc,fde->cbf"; +static constexpr std::array shape266{2, 2, 2}; +static constexpr std::array expected266{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}; +static constexpr std::string_view equation267 = "cab,dc,fde->cbd"; +static constexpr std::array shape267{2, 2, 2}; +static constexpr std::array expected267{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}; +static constexpr std::string_view equation268 = "cab,dc,fed->caf"; +static constexpr std::array shape268{2, 2, 2}; +static constexpr std::array expected268{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}; +static constexpr std::string_view equation269 = "cab,dc,fed->cae"; +static constexpr std::array shape269{2, 2, 2}; +static constexpr std::array expected269{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}; +static constexpr std::string_view equation270 = "cab,dc,fed->cbf"; +static constexpr std::array shape270{2, 2, 2}; +static constexpr std::array expected270{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}; +static constexpr std::string_view equation271 = "cab,dc,fed->cbe"; +static constexpr std::array shape271{2, 2, 2}; +static constexpr std::array expected271{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f}; +static constexpr std::string_view equation272 = "cba,cd,def->cbd"; +static constexpr std::array shape272{2, 2, 2}; +static constexpr std::array expected272{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}; +static constexpr std::string_view equation273 = "cba,cd,def->cbe"; +static constexpr std::array shape273{2, 2, 2}; +static constexpr std::array expected273{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}; +static constexpr std::string_view equation274 = "cba,cd,def->cad"; +static constexpr std::array shape274{2, 2, 2}; +static constexpr std::array expected274{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}; +static constexpr std::string_view equation275 = "cba,cd,def->cae"; +static constexpr std::array shape275{2, 2, 2}; +static constexpr std::array expected275{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}; +static constexpr std::string_view equation276 = "cba,cd,dfe->cbd"; +static constexpr std::array shape276{2, 2, 2}; +static constexpr std::array expected276{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}; +static constexpr std::string_view equation277 = "cba,cd,dfe->cbf"; +static constexpr std::array shape277{2, 2, 2}; +static constexpr std::array expected277{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}; +static constexpr std::string_view 
equation278 = "cba,cd,dfe->cad"; +static constexpr std::array shape278{2, 2, 2}; +static constexpr std::array expected278{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}; +static constexpr std::string_view equation279 = "cba,cd,dfe->caf"; +static constexpr std::array shape279{2, 2, 2}; +static constexpr std::array expected279{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}; +static constexpr std::string_view equation280 = "cba,cd,edf->cbe"; +static constexpr std::array shape280{2, 2, 2}; +static constexpr std::array expected280{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}; +static constexpr std::string_view equation281 = "cba,cd,edf->cbd"; +static constexpr std::array shape281{2, 2, 2}; +static constexpr std::array expected281{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}; +static constexpr std::string_view equation282 = "cba,cd,edf->cae"; +static constexpr std::array shape282{2, 2, 2}; +static constexpr std::array expected282{10.f, 26.f, 20.f, 52.f, 170.f, 570.f, 204.f, 684.f}; +static constexpr std::string_view equation283 = "cba,cd,edf->cad"; +static constexpr std::array shape283{2, 2, 2}; +static constexpr std::array expected283{0.f, 36.f, 0.f, 72.f, 200.f, 540.f, 240.f, 648.f}; +static constexpr std::string_view equation284 = "cba,cd,efd->cbe"; +static constexpr std::array shape284{2, 2, 2}; +static constexpr std::array expected284{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}; +static constexpr std::string_view equation285 = "cba,cd,efd->cbf"; +static constexpr std::array shape285{2, 2, 2}; +static constexpr std::array expected285{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}; +static constexpr std::string_view equation286 = "cba,cd,efd->cae"; +static constexpr std::array shape286{2, 2, 2}; +static constexpr std::array expected286{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}; +static constexpr std::string_view equation287 = "cba,cd,efd->caf"; +static constexpr std::array shape287{2, 2, 2}; +static constexpr std::array expected287{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}; +static constexpr std::string_view equation288 = "cba,cd,fde->cbf"; +static constexpr std::array shape288{2, 2, 2}; +static constexpr std::array expected288{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}; +static constexpr std::string_view equation289 = "cba,cd,fde->cbd"; +static constexpr std::array shape289{2, 2, 2}; +static constexpr std::array expected289{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}; +static constexpr std::string_view equation290 = "cba,cd,fde->caf"; +static constexpr std::array shape290{2, 2, 2}; +static constexpr std::array expected290{10.f, 26.f, 20.f, 52.f, 170.f, 570.f, 204.f, 684.f}; +static constexpr std::string_view equation291 = "cba,cd,fde->cad"; +static constexpr std::array shape291{2, 2, 2}; +static constexpr std::array expected291{0.f, 36.f, 0.f, 72.f, 200.f, 540.f, 240.f, 648.f}; +static constexpr std::string_view equation292 = "cba,cd,fed->cbf"; +static constexpr std::array shape292{2, 2, 2}; +static constexpr std::array expected292{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}; +static constexpr std::string_view equation293 = "cba,cd,fed->cbe"; +static constexpr std::array shape293{2, 2, 2}; +static constexpr std::array expected293{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}; +static constexpr std::string_view equation294 = "cba,cd,fed->caf"; +static constexpr std::array shape294{2, 2, 2}; +static constexpr std::array expected294{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}; +static constexpr std::string_view 
equation295 = "cba,cd,fed->cae"; +static constexpr std::array shape295{2, 2, 2}; +static constexpr std::array expected295{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}; +static constexpr std::string_view equation296 = "cba,dc,def->cbd"; +static constexpr std::array shape296{2, 2, 2}; +static constexpr std::array expected296{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}; +static constexpr std::string_view equation297 = "cba,dc,def->cbe"; +static constexpr std::array shape297{2, 2, 2}; +static constexpr std::array expected297{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}; +static constexpr std::string_view equation298 = "cba,dc,def->cad"; +static constexpr std::array shape298{2, 2, 2}; +static constexpr std::array expected298{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f}; +static constexpr std::string_view equation299 = "cba,dc,def->cae"; +static constexpr std::array shape299{2, 2, 2}; +static constexpr std::array expected299{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}; +static constexpr std::string_view equation300 = "cba,dc,dfe->cbd"; +static constexpr std::array shape300{2, 2, 2}; +static constexpr std::array expected300{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}; +static constexpr std::string_view equation301 = "cba,dc,dfe->cbf"; +static constexpr std::array shape301{2, 2, 2}; +static constexpr std::array expected301{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}; +static constexpr std::string_view equation302 = "cba,dc,dfe->cad"; +static constexpr std::array shape302{2, 2, 2}; +static constexpr std::array expected302{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f}; +static constexpr std::string_view equation303 = "cba,dc,dfe->caf"; +static constexpr std::array shape303{2, 2, 2}; +static constexpr std::array expected303{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}; +static constexpr std::string_view equation304 = "cba,dc,edf->cbe"; +static constexpr std::array shape304{2, 2, 2}; +static constexpr std::array expected304{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}; +static constexpr std::string_view equation305 = "cba,dc,edf->cbd"; +static constexpr std::array shape305{2, 2, 2}; +static constexpr std::array expected305{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}; +static constexpr std::string_view equation306 = "cba,dc,edf->cae"; +static constexpr std::array shape306{2, 2, 2}; +static constexpr std::array expected306{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}; +static constexpr std::string_view equation307 = "cba,dc,edf->cad"; +static constexpr std::array shape307{2, 2, 2}; +static constexpr std::array expected307{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}; +static constexpr std::string_view equation308 = "cba,dc,efd->cbe"; +static constexpr std::array shape308{2, 2, 2}; +static constexpr std::array expected308{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}; +static constexpr std::string_view equation309 = "cba,dc,efd->cbf"; +static constexpr std::array shape309{2, 2, 2}; +static constexpr std::array expected309{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}; +static constexpr std::string_view equation310 = "cba,dc,efd->cae"; +static constexpr std::array shape310{2, 2, 2}; +static constexpr std::array expected310{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}; +static constexpr std::string_view equation311 = "cba,dc,efd->caf"; +static constexpr std::array shape311{2, 2, 2}; +static constexpr std::array expected311{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f}; +static constexpr 
std::string_view equation312 = "cba,dc,fde->cbf"; +static constexpr std::array shape312{2, 2, 2}; +static constexpr std::array expected312{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}; +static constexpr std::string_view equation313 = "cba,dc,fde->cbd"; +static constexpr std::array shape313{2, 2, 2}; +static constexpr std::array expected313{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}; +static constexpr std::string_view equation314 = "cba,dc,fde->caf"; +static constexpr std::array shape314{2, 2, 2}; +static constexpr std::array expected314{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}; +static constexpr std::string_view equation315 = "cba,dc,fde->cad"; +static constexpr std::array shape315{2, 2, 2}; +static constexpr std::array expected315{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}; +static constexpr std::string_view equation316 = "cba,dc,fed->cbf"; +static constexpr std::array shape316{2, 2, 2}; +static constexpr std::array expected316{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}; +static constexpr std::string_view equation317 = "cba,dc,fed->cbe"; +static constexpr std::array shape317{2, 2, 2}; +static constexpr std::array expected317{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}; +static constexpr std::string_view equation318 = "cba,dc,fed->caf"; +static constexpr std::array shape318{2, 2, 2}; +static constexpr std::array expected318{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}; +static constexpr std::string_view equation319 = "cba,dc,fed->cae"; +static constexpr std::array shape319{2, 2, 2}; +static constexpr std::array expected319{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f}; +static constexpr std::array case1 = {{{equation32, shape32, expected32}, + {equation33, shape33, expected33}, + {equation34, shape34, expected34}, + {equation35, shape35, expected35}, + {equation36, shape36, expected36}, + {equation37, shape37, expected37}, + {equation38, shape38, expected38}, + {equation39, shape39, expected39}, + {equation40, shape40, expected40}, + {equation41, shape41, expected41}, + {equation42, shape42, expected42}, + {equation43, shape43, expected43}, + {equation44, shape44, expected44}, + {equation45, shape45, expected45}, + {equation46, shape46, expected46}, + {equation47, shape47, expected47}, + {equation48, shape48, expected48}, + {equation49, shape49, expected49}, + {equation50, shape50, expected50}, + {equation51, shape51, expected51}, + {equation52, shape52, expected52}, + {equation53, shape53, expected53}, + {equation54, shape54, expected54}, + {equation55, shape55, expected55}, + {equation56, shape56, expected56}, + {equation57, shape57, expected57}, + {equation58, shape58, expected58}, + {equation59, shape59, expected59}, + {equation60, shape60, expected60}, + {equation61, shape61, expected61}, + {equation62, shape62, expected62}, + {equation63, shape63, expected63}, + {equation64, shape64, expected64}, + {equation65, shape65, expected65}, + {equation66, shape66, expected66}, + {equation67, shape67, expected67}, + {equation68, shape68, expected68}, + {equation69, shape69, expected69}, + {equation70, shape70, expected70}, + {equation71, shape71, expected71}, + {equation72, shape72, expected72}, + {equation73, shape73, expected73}, + {equation74, shape74, expected74}, + {equation75, shape75, expected75}, + {equation76, shape76, expected76}, + {equation77, shape77, expected77}, + {equation78, shape78, expected78}, + {equation79, shape79, expected79}, + {equation80, shape80, expected80}, + {equation81, shape81, expected81}, + 
{equation82, shape82, expected82}, + {equation83, shape83, expected83}, + {equation84, shape84, expected84}, + {equation85, shape85, expected85}, + {equation86, shape86, expected86}, + {equation87, shape87, expected87}, + {equation88, shape88, expected88}, + {equation89, shape89, expected89}, + {equation90, shape90, expected90}, + {equation91, shape91, expected91}, + {equation92, shape92, expected92}, + {equation93, shape93, expected93}, + {equation94, shape94, expected94}, + {equation95, shape95, expected95}, + {equation96, shape96, expected96}, + {equation97, shape97, expected97}, + {equation98, shape98, expected98}, + {equation99, shape99, expected99}, + {equation100, shape100, expected100}, + {equation101, shape101, expected101}, + {equation102, shape102, expected102}, + {equation103, shape103, expected103}, + {equation104, shape104, expected104}, + {equation105, shape105, expected105}, + {equation106, shape106, expected106}, + {equation107, shape107, expected107}, + {equation108, shape108, expected108}, + {equation109, shape109, expected109}, + {equation110, shape110, expected110}, + {equation111, shape111, expected111}, + {equation112, shape112, expected112}, + {equation113, shape113, expected113}, + {equation114, shape114, expected114}, + {equation115, shape115, expected115}, + {equation116, shape116, expected116}, + {equation117, shape117, expected117}, + {equation118, shape118, expected118}, + {equation119, shape119, expected119}, + {equation120, shape120, expected120}, + {equation121, shape121, expected121}, + {equation122, shape122, expected122}, + {equation123, shape123, expected123}, + {equation124, shape124, expected124}, + {equation125, shape125, expected125}, + {equation126, shape126, expected126}, + {equation127, shape127, expected127}, + {equation128, shape128, expected128}, + {equation129, shape129, expected129}, + {equation130, shape130, expected130}, + {equation131, shape131, expected131}, + {equation132, shape132, expected132}, + {equation133, shape133, expected133}, + {equation134, shape134, expected134}, + {equation135, shape135, expected135}, + {equation136, shape136, expected136}, + {equation137, shape137, expected137}, + {equation138, shape138, expected138}, + {equation139, shape139, expected139}, + {equation140, shape140, expected140}, + {equation141, shape141, expected141}, + {equation142, shape142, expected142}, + {equation143, shape143, expected143}, + {equation144, shape144, expected144}, + {equation145, shape145, expected145}, + {equation146, shape146, expected146}, + {equation147, shape147, expected147}, + {equation148, shape148, expected148}, + {equation149, shape149, expected149}, + {equation150, shape150, expected150}, + {equation151, shape151, expected151}, + {equation152, shape152, expected152}, + {equation153, shape153, expected153}, + {equation154, shape154, expected154}, + {equation155, shape155, expected155}, + {equation156, shape156, expected156}, + {equation157, shape157, expected157}, + {equation158, shape158, expected158}, + {equation159, shape159, expected159}, + {equation160, shape160, expected160}, + {equation161, shape161, expected161}, + {equation162, shape162, expected162}, + {equation163, shape163, expected163}, + {equation164, shape164, expected164}, + {equation165, shape165, expected165}, + {equation166, shape166, expected166}, + {equation167, shape167, expected167}, + {equation168, shape168, expected168}, + {equation169, shape169, expected169}, + {equation170, shape170, expected170}, + {equation171, shape171, expected171}, + 
{equation172, shape172, expected172}, + {equation173, shape173, expected173}, + {equation174, shape174, expected174}, + {equation175, shape175, expected175}, + {equation176, shape176, expected176}, + {equation177, shape177, expected177}, + {equation178, shape178, expected178}, + {equation179, shape179, expected179}, + {equation180, shape180, expected180}, + {equation181, shape181, expected181}, + {equation182, shape182, expected182}, + {equation183, shape183, expected183}, + {equation184, shape184, expected184}, + {equation185, shape185, expected185}, + {equation186, shape186, expected186}, + {equation187, shape187, expected187}, + {equation188, shape188, expected188}, + {equation189, shape189, expected189}, + {equation190, shape190, expected190}, + {equation191, shape191, expected191}, + {equation192, shape192, expected192}, + {equation193, shape193, expected193}, + {equation194, shape194, expected194}, + {equation195, shape195, expected195}, + {equation196, shape196, expected196}, + {equation197, shape197, expected197}, + {equation198, shape198, expected198}, + {equation199, shape199, expected199}, + {equation200, shape200, expected200}, + {equation201, shape201, expected201}, + {equation202, shape202, expected202}, + {equation203, shape203, expected203}, + {equation204, shape204, expected204}, + {equation205, shape205, expected205}, + {equation206, shape206, expected206}, + {equation207, shape207, expected207}, + {equation208, shape208, expected208}, + {equation209, shape209, expected209}, + {equation210, shape210, expected210}, + {equation211, shape211, expected211}, + {equation212, shape212, expected212}, + {equation213, shape213, expected213}, + {equation214, shape214, expected214}, + {equation215, shape215, expected215}, + {equation216, shape216, expected216}, + {equation217, shape217, expected217}, + {equation218, shape218, expected218}, + {equation219, shape219, expected219}, + {equation220, shape220, expected220}, + {equation221, shape221, expected221}, + {equation222, shape222, expected222}, + {equation223, shape223, expected223}, + {equation224, shape224, expected224}, + {equation225, shape225, expected225}, + {equation226, shape226, expected226}, + {equation227, shape227, expected227}, + {equation228, shape228, expected228}, + {equation229, shape229, expected229}, + {equation230, shape230, expected230}, + {equation231, shape231, expected231}, + {equation232, shape232, expected232}, + {equation233, shape233, expected233}, + {equation234, shape234, expected234}, + {equation235, shape235, expected235}, + {equation236, shape236, expected236}, + {equation237, shape237, expected237}, + {equation238, shape238, expected238}, + {equation239, shape239, expected239}, + {equation240, shape240, expected240}, + {equation241, shape241, expected241}, + {equation242, shape242, expected242}, + {equation243, shape243, expected243}, + {equation244, shape244, expected244}, + {equation245, shape245, expected245}, + {equation246, shape246, expected246}, + {equation247, shape247, expected247}, + {equation248, shape248, expected248}, + {equation249, shape249, expected249}, + {equation250, shape250, expected250}, + {equation251, shape251, expected251}, + {equation252, shape252, expected252}, + {equation253, shape253, expected253}, + {equation254, shape254, expected254}, + {equation255, shape255, expected255}, + {equation256, shape256, expected256}, + {equation257, shape257, expected257}, + {equation258, shape258, expected258}, + {equation259, shape259, expected259}, + {equation260, shape260, 
expected260},
+    {equation261, shape261, expected261},
+    {equation262, shape262, expected262},
+    {equation263, shape263, expected263},
+    {equation264, shape264, expected264},
+    {equation265, shape265, expected265},
+    {equation266, shape266, expected266},
+    {equation267, shape267, expected267},
+    {equation268, shape268, expected268},
+    {equation269, shape269, expected269},
+    {equation270, shape270, expected270},
+    {equation271, shape271, expected271},
+    {equation272, shape272, expected272},
+    {equation273, shape273, expected273},
+    {equation274, shape274, expected274},
+    {equation275, shape275, expected275},
+    {equation276, shape276, expected276},
+    {equation277, shape277, expected277},
+    {equation278, shape278, expected278},
+    {equation279, shape279, expected279},
+    {equation280, shape280, expected280},
+    {equation281, shape281, expected281},
+    {equation282, shape282, expected282},
+    {equation283, shape283, expected283},
+    {equation284, shape284, expected284},
+    {equation285, shape285, expected285},
+    {equation286, shape286, expected286},
+    {equation287, shape287, expected287},
+    {equation288, shape288, expected288},
+    {equation289, shape289, expected289},
+    {equation290, shape290, expected290},
+    {equation291, shape291, expected291},
+    {equation292, shape292, expected292},
+    {equation293, shape293, expected293},
+    {equation294, shape294, expected294},
+    {equation295, shape295, expected295},
+    {equation296, shape296, expected296},
+    {equation297, shape297, expected297},
+    {equation298, shape298, expected298},
+    {equation299, shape299, expected299},
+    {equation300, shape300, expected300},
+    {equation301, shape301, expected301},
+    {equation302, shape302, expected302},
+    {equation303, shape303, expected303},
+    {equation304, shape304, expected304},
+    {equation305, shape305, expected305},
+    {equation306, shape306, expected306},
+    {equation307, shape307, expected307},
+    {equation308, shape308, expected308},
+    {equation309, shape309, expected309},
+    {equation310, shape310, expected310},
+    {equation311, shape311, expected311},
+    {equation312, shape312, expected312},
+    {equation313, shape313, expected313},
+    {equation314, shape314, expected314},
+    {equation315, shape315, expected315},
+    {equation316, shape316, expected316},
+    {equation317, shape317, expected317},
+    {equation318, shape318, expected318},
+    {equation319, shape319, expected319}}};
 TEST(Einsum, EinsumTransposeMatMulTwoInputsTestSuite) {
-  std::vector test_cases{
-      EinsumTestCase("abc,cd->abc", std::vector{2, 2, 2}, std::vector{0.f, 5.f, 2.f, 15.f, 4.f, 25.f, 6.f, 35.f}),
-      EinsumTestCase("abc,cd->abd", std::vector{2, 2, 2}, std::vector{2.f, 3.f, 6.f, 11.f, 10.f, 19.f, 14.f, 27.f}),
-      EinsumTestCase("abc,cd->acd", std::vector{2, 2, 2}, std::vector{0.f, 2.f, 8.f, 12.f, 0.f, 10.f, 24.f, 36.f}),
-      EinsumTestCase("abc,dc->abd", std::vector{2, 2, 2}, std::vector{1.f, 3.f, 3.f, 13.f, 5.f, 23.f, 7.f, 33.f}),
-      EinsumTestCase("abc,dc->abc", std::vector{2, 2, 2}, std::vector{0.f, 4.f, 4.f, 12.f, 8.f, 20.f, 12.f, 28.f}),
-      EinsumTestCase("abc,dc->acd", std::vector{2, 2, 2}, std::vector{0.f, 4.f, 4.f, 12.f, 0.f, 20.f, 12.f, 36.f}),
-      EinsumTestCase("acb,cd->acd", std::vector{2, 2, 2}, std::vector{0.f, 1.f, 10.f, 15.f, 0.f, 9.f, 26.f, 39.f}),
-      EinsumTestCase("acb,cd->abc", std::vector{2, 2, 2}, std::vector{0.f, 10.f, 1.f, 15.f, 4.f, 30.f, 5.f, 35.f}),
-      EinsumTestCase("acb,cd->abd", std::vector{2, 2, 2}, std::vector{4.f, 6.f, 6.f, 10.f, 12.f, 22.f, 14.f, 26.f}),
-      EinsumTestCase("acb,dc->acd", std::vector{2, 2, 2}, std::vector{0.f, 2.f, 5.f, 15.f, 0.f, 18.f, 13.f, 39.f}),
-      EinsumTestCase("acb,dc->abd", std::vector{2, 2, 2}, std::vector{2.f, 6.f, 3.f, 11.f, 6.f, 26.f, 7.f, 31.f}),
-      EinsumTestCase("acb,dc->abc", std::vector{2, 2, 2}, std::vector{0.f, 8.f, 2.f, 12.f, 8.f, 24.f, 10.f, 28.f}),
-      EinsumTestCase("bac,cd->bac", std::vector{2, 2, 2}, std::vector{0.f, 5.f, 2.f, 15.f, 4.f, 25.f, 6.f, 35.f}),
-      EinsumTestCase("bac,cd->bad", std::vector{2, 2, 2}, std::vector{2.f, 3.f, 6.f, 11.f, 10.f, 19.f, 14.f, 27.f}),
-      EinsumTestCase("bac,cd->bcd", std::vector{2, 2, 2}, std::vector{0.f, 2.f, 8.f, 12.f, 0.f, 10.f, 24.f, 36.f}),
-      EinsumTestCase("bac,dc->bad", std::vector{2, 2, 2}, std::vector{1.f, 3.f, 3.f, 13.f, 5.f, 23.f, 7.f, 33.f}),
-      EinsumTestCase("bac,dc->bac", std::vector{2, 2, 2}, std::vector{0.f, 4.f, 4.f, 12.f, 8.f, 20.f, 12.f, 28.f}),
-      EinsumTestCase("bac,dc->bcd", std::vector{2, 2, 2}, std::vector{0.f, 4.f, 4.f, 12.f, 0.f, 20.f, 12.f, 36.f}),
-      EinsumTestCase("bca,cd->bcd", std::vector{2, 2, 2}, std::vector{0.f, 1.f, 10.f, 15.f, 0.f, 9.f, 26.f, 39.f}),
-      EinsumTestCase("bca,cd->bac", std::vector{2, 2, 2}, std::vector{0.f, 10.f, 1.f, 15.f, 4.f, 30.f, 5.f, 35.f}),
-      EinsumTestCase("bca,cd->bad", std::vector{2, 2, 2}, std::vector{4.f, 6.f, 6.f, 10.f, 12.f, 22.f, 14.f, 26.f}),
-      EinsumTestCase("bca,dc->bcd", std::vector{2, 2, 2}, std::vector{0.f, 2.f, 5.f, 15.f, 0.f, 18.f, 13.f, 39.f}),
-      EinsumTestCase("bca,dc->bad", std::vector{2, 2, 2}, std::vector{2.f, 6.f, 3.f, 11.f, 6.f, 26.f, 7.f, 31.f}),
-      EinsumTestCase("bca,dc->bac", std::vector{2, 2, 2}, std::vector{0.f, 8.f, 2.f, 12.f, 8.f, 24.f, 10.f, 28.f}),
-      EinsumTestCase("cab,cd->cad", std::vector{2, 2, 2}, std::vector{0.f, 1.f, 0.f, 5.f, 18.f, 27.f, 26.f, 39.f}),
-      EinsumTestCase("cab,cd->cbd", std::vector{2, 2, 2}, std::vector{0.f, 2.f, 0.f, 4.f, 20.f, 30.f, 24.f, 36.f}),
-      EinsumTestCase("cab,dc->cad", std::vector{2, 2, 2}, std::vector{0.f, 2.f, 0.f, 10.f, 9.f, 27.f, 13.f, 39.f}),
-      EinsumTestCase("cab,dc->cbd", std::vector{2, 2, 2}, std::vector{0.f, 4.f, 0.f, 8.f, 10.f, 30.f, 12.f, 36.f}),
-      EinsumTestCase("cba,cd->cbd", std::vector{2, 2, 2}, std::vector{0.f, 1.f, 0.f, 5.f, 18.f, 27.f, 26.f, 39.f}),
-      EinsumTestCase("cba,cd->cad", std::vector{2, 2, 2}, std::vector{0.f, 2.f, 0.f, 4.f, 20.f, 30.f, 24.f, 36.f}),
-      EinsumTestCase("cba,dc->cbd", std::vector{2, 2, 2}, std::vector{0.f, 2.f, 0.f, 10.f, 9.f, 27.f, 13.f, 39.f}),
-      EinsumTestCase("cba,dc->cad", std::vector{2, 2, 2}, std::vector{0.f, 4.f, 0.f, 8.f, 10.f, 30.f, 12.f, 36.f})};
-
   std::vector m1{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f};
   std::vector m2{0.f, 1.f, 2.f, 3.f};
-  for (const auto& tst : test_cases) {
+  for (const auto& tst : case0) {
     OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
-    test.AddAttribute("equation", tst.equation);
+    std::string s(tst.equation);
+    test.AddAttribute("equation", s);
     test.AddInput("x", {2, 2, 2}, m1);
     test.AddInput("y", {2, 2}, m2);
-    test.AddOutput("o", tst.shape, tst.expected);
+
+    std::vector v1(tst.shape.begin(), tst.shape.end());
+    std::vector v2(tst.expected.begin(), tst.expected.end());
+    test.AddOutput("o", v1, v2);
     test.Run();
   }
 }
-TEST(Einsum, EinsumTransposeMatMulThreeInputsTestSuite) {
-  std::vector test_cases_set_1{
-      EinsumTestCase("abc,cd,def->abd", std::vector{2, 2, 2}, std::vector{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}),
-      EinsumTestCase("abc,cd,def->abe", std::vector{2, 2, 2}, std::vector{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}),
-      EinsumTestCase("abc,cd,def->acd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f,
792.f}), - EinsumTestCase("abc,cd,def->ace", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}), - EinsumTestCase("abc,cd,dfe->abd", std::vector{2, 2, 2}, std::vector{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}), - EinsumTestCase("abc,cd,dfe->abf", std::vector{2, 2, 2}, std::vector{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}), - EinsumTestCase("abc,cd,dfe->acd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 792.f}), - EinsumTestCase("abc,cd,dfe->acf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}), - EinsumTestCase("abc,cd,edf->abe", std::vector{2, 2, 2}, std::vector{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}), - EinsumTestCase("abc,cd,edf->abd", std::vector{2, 2, 2}, std::vector{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}), - EinsumTestCase("abc,cd,edf->ace", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}), - EinsumTestCase("abc,cd,edf->acd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}), - EinsumTestCase("abc,cd,efd->abe", std::vector{2, 2, 2}, std::vector{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 464.f}), - EinsumTestCase("abc,cd,efd->abf", std::vector{2, 2, 2}, std::vector{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}), - EinsumTestCase("abc,cd,efd->ace", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}), - EinsumTestCase("abc,cd,efd->acf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}), - EinsumTestCase("abc,cd,fde->abf", std::vector{2, 2, 2}, std::vector{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}), - EinsumTestCase("abc,cd,fde->abd", std::vector{2, 2, 2}, std::vector{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}), - EinsumTestCase("abc,cd,fde->acf", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}), - EinsumTestCase("abc,cd,fde->acd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}), - EinsumTestCase("abc,cd,fed->abf", std::vector{2, 2, 2}, std::vector{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 464.f}), - EinsumTestCase("abc,cd,fed->abe", std::vector{2, 2, 2}, std::vector{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}), - EinsumTestCase("abc,cd,fed->acf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}), - EinsumTestCase("abc,cd,fed->ace", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}), - EinsumTestCase("abc,dc,def->abd", std::vector{2, 2, 2}, std::vector{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}), - EinsumTestCase("abc,dc,def->abe", std::vector{2, 2, 2}, std::vector{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}), - EinsumTestCase("abc,dc,def->acd", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}), - EinsumTestCase("abc,dc,def->ace", std::vector{2, 2, 2}, std::vector{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}), - EinsumTestCase("abc,dc,dfe->abd", std::vector{2, 2, 2}, std::vector{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}), - EinsumTestCase("abc,dc,dfe->abf", std::vector{2, 2, 2}, std::vector{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}), - EinsumTestCase("abc,dc,dfe->acd", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}), - EinsumTestCase("abc,dc,dfe->acf", std::vector{2, 2, 2}, 
std::vector{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}), - EinsumTestCase("abc,dc,edf->abe", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}), - EinsumTestCase("abc,dc,edf->abd", std::vector{2, 2, 2}, std::vector{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}), - EinsumTestCase("abc,dc,edf->ace", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 64.f, 192.f, 100.f, 260.f, 192.f, 576.f}), - EinsumTestCase("abc,dc,edf->acd", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}), - EinsumTestCase("abc,dc,efd->abe", std::vector{2, 2, 2}, std::vector{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}), - EinsumTestCase("abc,dc,efd->abf", std::vector{2, 2, 2}, std::vector{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}), - EinsumTestCase("abc,dc,efd->ace", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}), - EinsumTestCase("abc,dc,efd->acf", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}), - EinsumTestCase("abc,dc,fde->abf", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}), - EinsumTestCase("abc,dc,fde->abd", std::vector{2, 2, 2}, std::vector{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}), - EinsumTestCase("abc,dc,fde->acf", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 64.f, 192.f, 100.f, 260.f, 192.f, 576.f}), - EinsumTestCase("abc,dc,fde->acd", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}), - EinsumTestCase("abc,dc,fed->abf", std::vector{2, 2, 2}, std::vector{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}), - EinsumTestCase("abc,dc,fed->abe", std::vector{2, 2, 2}, std::vector{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}), - EinsumTestCase("abc,dc,fed->acf", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}), - EinsumTestCase("abc,dc,fed->ace", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}), - EinsumTestCase("acb,cd,def->acd", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}), - EinsumTestCase("acb,cd,def->ace", std::vector{2, 2, 2}, std::vector{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}), - EinsumTestCase("acb,cd,def->abd", std::vector{2, 2, 2}, std::vector{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}), - EinsumTestCase("acb,cd,def->abe", std::vector{2, 2, 2}, std::vector{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}), - EinsumTestCase("acb,cd,dfe->acd", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}), - EinsumTestCase("acb,cd,dfe->acf", std::vector{2, 2, 2}, std::vector{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}), - EinsumTestCase("acb,cd,dfe->abd", std::vector{2, 2, 2}, std::vector{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}), - EinsumTestCase("acb,cd,dfe->abf", std::vector{2, 2, 2}, std::vector{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}), - EinsumTestCase("acb,cd,edf->ace", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}), - EinsumTestCase("acb,cd,edf->acd", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}), - EinsumTestCase("acb,cd,edf->abe", std::vector{2, 2, 2}, std::vector{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}), - EinsumTestCase("acb,cd,edf->abd", std::vector{2, 2, 2}, std::vector{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 468.f}), - 
EinsumTestCase("acb,cd,efd->ace", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}), - EinsumTestCase("acb,cd,efd->acf", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}), - EinsumTestCase("acb,cd,efd->abe", std::vector{2, 2, 2}, std::vector{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}), - EinsumTestCase("acb,cd,efd->abf", std::vector{2, 2, 2}, std::vector{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}), - EinsumTestCase("acb,cd,fde->acf", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}), - EinsumTestCase("acb,cd,fde->acd", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}), - EinsumTestCase("acb,cd,fde->abf", std::vector{2, 2, 2}, std::vector{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}), - EinsumTestCase("acb,cd,fde->abd", std::vector{2, 2, 2}, std::vector{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 468.f}), - EinsumTestCase("acb,cd,fed->acf", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}), - EinsumTestCase("acb,cd,fed->ace", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}), - EinsumTestCase("acb,cd,fed->abf", std::vector{2, 2, 2}, std::vector{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}), - EinsumTestCase("acb,cd,fed->abe", std::vector{2, 2, 2}, std::vector{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}), - EinsumTestCase("acb,dc,def->acd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 858.f}), - EinsumTestCase("acb,dc,def->ace", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f})}; - - std::vector test_cases_set_2{ - EinsumTestCase("acb,dc,def->abd", std::vector{2, 2, 2}, std::vector{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}), - EinsumTestCase("acb,dc,def->abe", std::vector{2, 2, 2}, std::vector{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}), - EinsumTestCase("acb,dc,dfe->acd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 858.f}), - EinsumTestCase("acb,dc,dfe->acf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f}), - EinsumTestCase("acb,dc,dfe->abd", std::vector{2, 2, 2}, std::vector{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}), - EinsumTestCase("acb,dc,dfe->abf", std::vector{2, 2, 2}, std::vector{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}), - EinsumTestCase("acb,dc,edf->ace", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}), - EinsumTestCase("acb,dc,edf->acd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}), - EinsumTestCase("acb,dc,edf->abe", std::vector{2, 2, 2}, std::vector{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}), - EinsumTestCase("acb,dc,edf->abd", std::vector{2, 2, 2}, std::vector{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}), - EinsumTestCase("acb,dc,efd->ace", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}), - EinsumTestCase("acb,dc,efd->acf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 110.f, 190.f, 108.f, 180.f, 286.f, 494.f}), - EinsumTestCase("acb,dc,efd->abe", std::vector{2, 2, 2}, std::vector{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}), - EinsumTestCase("acb,dc,efd->abf", std::vector{2, 2, 2}, std::vector{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}), - EinsumTestCase("acb,dc,fde->acf", 
std::vector{2, 2, 2}, std::vector{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}), - EinsumTestCase("acb,dc,fde->acd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}), - EinsumTestCase("acb,dc,fde->abf", std::vector{2, 2, 2}, std::vector{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}), - EinsumTestCase("acb,dc,fde->abd", std::vector{2, 2, 2}, std::vector{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}), - EinsumTestCase("acb,dc,fed->acf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}), - EinsumTestCase("acb,dc,fed->ace", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 110.f, 190.f, 108.f, 180.f, 286.f, 494.f}), - EinsumTestCase("acb,dc,fed->abf", std::vector{2, 2, 2}, std::vector{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}), - EinsumTestCase("acb,dc,fed->abe", std::vector{2, 2, 2}, std::vector{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}), - EinsumTestCase("bac,cd,def->bad", std::vector{2, 2, 2}, std::vector{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}), - EinsumTestCase("bac,cd,def->bae", std::vector{2, 2, 2}, std::vector{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}), - EinsumTestCase("bac,cd,def->bcd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 792.f}), - EinsumTestCase("bac,cd,def->bce", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}), - EinsumTestCase("bac,cd,dfe->bad", std::vector{2, 2, 2}, std::vector{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}), - EinsumTestCase("bac,cd,dfe->baf", std::vector{2, 2, 2}, std::vector{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}), - EinsumTestCase("bac,cd,dfe->bcd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 792.f}), - EinsumTestCase("bac,cd,dfe->bcf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}), - EinsumTestCase("bac,cd,edf->bae", std::vector{2, 2, 2}, std::vector{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}), - EinsumTestCase("bac,cd,edf->bad", std::vector{2, 2, 2}, std::vector{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}), - EinsumTestCase("bac,cd,edf->bce", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}), - EinsumTestCase("bac,cd,edf->bcd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}), - EinsumTestCase("bac,cd,efd->bae", std::vector{2, 2, 2}, std::vector{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 464.f}), - EinsumTestCase("bac,cd,efd->baf", std::vector{2, 2, 2}, std::vector{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}), - EinsumTestCase("bac,cd,efd->bce", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}), - EinsumTestCase("bac,cd,efd->bcf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}), - EinsumTestCase("bac,cd,fde->baf", std::vector{2, 2, 2}, std::vector{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}), - EinsumTestCase("bac,cd,fde->bad", std::vector{2, 2, 2}, std::vector{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}), - EinsumTestCase("bac,cd,fde->bcf", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}), - EinsumTestCase("bac,cd,fde->bcd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}), - EinsumTestCase("bac,cd,fed->baf", std::vector{2, 2, 2}, std::vector{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 
464.f}), - EinsumTestCase("bac,cd,fed->bae", std::vector{2, 2, 2}, std::vector{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}), - EinsumTestCase("bac,cd,fed->bcf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}), - EinsumTestCase("bac,cd,fed->bce", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}), - EinsumTestCase("bac,dc,def->bad", std::vector{2, 2, 2}, std::vector{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}), - EinsumTestCase("bac,dc,def->bae", std::vector{2, 2, 2}, std::vector{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}), - EinsumTestCase("bac,dc,def->bcd", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}), - EinsumTestCase("bac,dc,def->bce", std::vector{2, 2, 2}, std::vector{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}), - EinsumTestCase("bac,dc,dfe->bad", std::vector{2, 2, 2}, std::vector{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}), - EinsumTestCase("bac,dc,dfe->baf", std::vector{2, 2, 2}, std::vector{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}), - EinsumTestCase("bac,dc,dfe->bcd", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}), - EinsumTestCase("bac,dc,dfe->bcf", std::vector{2, 2, 2}, std::vector{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}), - EinsumTestCase("bac,dc,edf->bae", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}), - EinsumTestCase("bac,dc,edf->bad", std::vector{2, 2, 2}, std::vector{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}), - EinsumTestCase("bac,dc,edf->bce", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 64.f, 192.f, 100.f, 260.f, 192.f, 576.f}), - EinsumTestCase("bac,dc,edf->bcd", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}), - EinsumTestCase("bac,dc,efd->bae", std::vector{2, 2, 2}, std::vector{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}), - EinsumTestCase("bac,dc,efd->baf", std::vector{2, 2, 2}, std::vector{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}), - EinsumTestCase("bac,dc,efd->bce", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}), - EinsumTestCase("bac,dc,efd->bcf", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}), - EinsumTestCase("bac,dc,fde->baf", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}), - EinsumTestCase("bac,dc,fde->bad", std::vector{2, 2, 2}, std::vector{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}), - EinsumTestCase("bac,dc,fde->bcf", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 64.f, 192.f, 100.f, 260.f, 192.f, 576.f}), - EinsumTestCase("bac,dc,fde->bcd", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}), - EinsumTestCase("bac,dc,fed->baf", std::vector{2, 2, 2}, std::vector{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}), - EinsumTestCase("bac,dc,fed->bae", std::vector{2, 2, 2}, std::vector{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}), - EinsumTestCase("bac,dc,fed->bcf", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}), - EinsumTestCase("bac,dc,fed->bce", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}), - EinsumTestCase("bca,cd,def->bcd", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}), - EinsumTestCase("bca,cd,def->bce", std::vector{2, 2, 2}, 
std::vector{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}), - EinsumTestCase("bca,cd,def->bad", std::vector{2, 2, 2}, std::vector{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}), - EinsumTestCase("bca,cd,def->bae", std::vector{2, 2, 2}, std::vector{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}), - EinsumTestCase("bca,cd,dfe->bcd", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}), - EinsumTestCase("bca,cd,dfe->bcf", std::vector{2, 2, 2}, std::vector{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}), - EinsumTestCase("bca,cd,dfe->bad", std::vector{2, 2, 2}, std::vector{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}), - EinsumTestCase("bca,cd,dfe->baf", std::vector{2, 2, 2}, std::vector{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}), - EinsumTestCase("bca,cd,edf->bce", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}), - EinsumTestCase("bca,cd,edf->bcd", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}), - EinsumTestCase("bca,cd,edf->bae", std::vector{2, 2, 2}, std::vector{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}), - EinsumTestCase("bca,cd,edf->bad", std::vector{2, 2, 2}, std::vector{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 468.f}), - EinsumTestCase("bca,cd,efd->bce", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}), - EinsumTestCase("bca,cd,efd->bcf", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}), - EinsumTestCase("bca,cd,efd->bae", std::vector{2, 2, 2}, std::vector{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}), - EinsumTestCase("bca,cd,efd->baf", std::vector{2, 2, 2}, std::vector{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}), - EinsumTestCase("bca,cd,fde->bcf", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}), - EinsumTestCase("bca,cd,fde->bcd", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}), - EinsumTestCase("bca,cd,fde->baf", std::vector{2, 2, 2}, std::vector{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}), - EinsumTestCase("bca,cd,fde->bad", std::vector{2, 2, 2}, std::vector{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 468.f}), - EinsumTestCase("bca,cd,fed->bcf", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}), - EinsumTestCase("bca,cd,fed->bce", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}), - EinsumTestCase("bca,cd,fed->baf", std::vector{2, 2, 2}, std::vector{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}), - EinsumTestCase("bca,cd,fed->bae", std::vector{2, 2, 2}, std::vector{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}), - EinsumTestCase("bca,dc,def->bcd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 858.f}), - EinsumTestCase("bca,dc,def->bce", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f}), - EinsumTestCase("bca,dc,def->bad", std::vector{2, 2, 2}, std::vector{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}), - EinsumTestCase("bca,dc,def->bae", std::vector{2, 2, 2}, std::vector{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}), - EinsumTestCase("bca,dc,dfe->bcd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 858.f}), - EinsumTestCase("bca,dc,dfe->bcf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f}), - 
EinsumTestCase("bca,dc,dfe->bad", std::vector{2, 2, 2}, std::vector{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}), - EinsumTestCase("bca,dc,dfe->baf", std::vector{2, 2, 2}, std::vector{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}), - EinsumTestCase("bca,dc,edf->bce", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}), - EinsumTestCase("bca,dc,edf->bcd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}), - EinsumTestCase("bca,dc,edf->bae", std::vector{2, 2, 2}, std::vector{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}), - EinsumTestCase("bca,dc,edf->bad", std::vector{2, 2, 2}, std::vector{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}), - EinsumTestCase("bca,dc,efd->bce", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}), - EinsumTestCase("bca,dc,efd->bcf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 110.f, 190.f, 108.f, 180.f, 286.f, 494.f}), - EinsumTestCase("bca,dc,efd->bae", std::vector{2, 2, 2}, std::vector{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}), - EinsumTestCase("bca,dc,efd->baf", std::vector{2, 2, 2}, std::vector{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}), - EinsumTestCase("bca,dc,fde->bcf", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}), - EinsumTestCase("bca,dc,fde->bcd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}), - EinsumTestCase("bca,dc,fde->baf", std::vector{2, 2, 2}, std::vector{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}), - EinsumTestCase("bca,dc,fde->bad", std::vector{2, 2, 2}, std::vector{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}), - EinsumTestCase("bca,dc,fed->bcf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}), - EinsumTestCase("bca,dc,fed->bce", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 110.f, 190.f, 108.f, 180.f, 286.f, 494.f}), - EinsumTestCase("bca,dc,fed->baf", std::vector{2, 2, 2}, std::vector{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}), - EinsumTestCase("bca,dc,fed->bae", std::vector{2, 2, 2}, std::vector{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}), - EinsumTestCase("cab,cd,def->cad", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}), - EinsumTestCase("cab,cd,def->cae", std::vector{2, 2, 2}, std::vector{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}), - EinsumTestCase("cab,cd,def->cbd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}), - EinsumTestCase("cab,cd,def->cbe", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}), - EinsumTestCase("cab,cd,dfe->cad", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}), - EinsumTestCase("cab,cd,dfe->caf", std::vector{2, 2, 2}, std::vector{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}), - EinsumTestCase("cab,cd,dfe->cbd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}), - EinsumTestCase("cab,cd,dfe->cbf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}), - EinsumTestCase("cab,cd,edf->cae", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}), - EinsumTestCase("cab,cd,edf->cad", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}), - EinsumTestCase("cab,cd,edf->cbe", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 20.f, 
52.f, 170.f, 570.f, 204.f, 684.f}), - EinsumTestCase("cab,cd,edf->cbd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 72.f, 200.f, 540.f, 240.f, 648.f}), - EinsumTestCase("cab,cd,efd->cae", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}), - EinsumTestCase("cab,cd,efd->caf", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}), - EinsumTestCase("cab,cd,efd->cbe", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}), - EinsumTestCase("cab,cd,efd->cbf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}), - EinsumTestCase("cab,cd,fde->caf", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}), - EinsumTestCase("cab,cd,fde->cad", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}), - EinsumTestCase("cab,cd,fde->cbf", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 20.f, 52.f, 170.f, 570.f, 204.f, 684.f}), - EinsumTestCase("cab,cd,fde->cbd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 72.f, 200.f, 540.f, 240.f, 648.f}), - EinsumTestCase("cab,cd,fed->caf", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}), - EinsumTestCase("cab,cd,fed->cae", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}), - EinsumTestCase("cab,cd,fed->cbf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}), - EinsumTestCase("cab,cd,fed->cbe", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}), - EinsumTestCase("cab,dc,def->cad", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}), - EinsumTestCase("cab,dc,def->cae", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}), - EinsumTestCase("cab,dc,def->cbd", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f})}; - - std::vector test_cases_set_3{ - EinsumTestCase("cab,dc,def->cbe", std::vector{2, 2, 2}, std::vector{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}), - EinsumTestCase("cab,dc,dfe->cad", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}), - EinsumTestCase("cab,dc,dfe->caf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}), - EinsumTestCase("cab,dc,dfe->cbd", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f}), - EinsumTestCase("cab,dc,dfe->cbf", std::vector{2, 2, 2}, std::vector{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}), - EinsumTestCase("cab,dc,edf->cae", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}), - EinsumTestCase("cab,dc,edf->cad", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}), - EinsumTestCase("cab,dc,edf->cbe", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}), - EinsumTestCase("cab,dc,edf->cbd", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}), - EinsumTestCase("cab,dc,efd->cae", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}), - EinsumTestCase("cab,dc,efd->caf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}), - EinsumTestCase("cab,dc,efd->cbe", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}), - 
EinsumTestCase("cab,dc,efd->cbf", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f}), - EinsumTestCase("cab,dc,fde->caf", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}), - EinsumTestCase("cab,dc,fde->cad", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}), - EinsumTestCase("cab,dc,fde->cbf", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}), - EinsumTestCase("cab,dc,fde->cbd", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}), - EinsumTestCase("cab,dc,fed->caf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}), - EinsumTestCase("cab,dc,fed->cae", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}), - EinsumTestCase("cab,dc,fed->cbf", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}), - EinsumTestCase("cab,dc,fed->cbe", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f}), - EinsumTestCase("cba,cd,def->cbd", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}), - EinsumTestCase("cba,cd,def->cbe", std::vector{2, 2, 2}, std::vector{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}), - EinsumTestCase("cba,cd,def->cad", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}), - EinsumTestCase("cba,cd,def->cae", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}), - EinsumTestCase("cba,cd,dfe->cbd", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}), - EinsumTestCase("cba,cd,dfe->cbf", std::vector{2, 2, 2}, std::vector{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}), - EinsumTestCase("cba,cd,dfe->cad", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}), - EinsumTestCase("cba,cd,dfe->caf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}), - EinsumTestCase("cba,cd,edf->cbe", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}), - EinsumTestCase("cba,cd,edf->cbd", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}), - EinsumTestCase("cba,cd,edf->cae", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 20.f, 52.f, 170.f, 570.f, 204.f, 684.f}), - EinsumTestCase("cba,cd,edf->cad", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 72.f, 200.f, 540.f, 240.f, 648.f}), - EinsumTestCase("cba,cd,efd->cbe", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}), - EinsumTestCase("cba,cd,efd->cbf", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}), - EinsumTestCase("cba,cd,efd->cae", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}), - EinsumTestCase("cba,cd,efd->caf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}), - EinsumTestCase("cba,cd,fde->cbf", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}), - EinsumTestCase("cba,cd,fde->cbd", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}), - EinsumTestCase("cba,cd,fde->caf", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 20.f, 52.f, 170.f, 570.f, 204.f, 684.f}), - EinsumTestCase("cba,cd,fde->cad", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 72.f, 200.f, 
540.f, 240.f, 648.f}), - EinsumTestCase("cba,cd,fed->cbf", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}), - EinsumTestCase("cba,cd,fed->cbe", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}), - EinsumTestCase("cba,cd,fed->caf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}), - EinsumTestCase("cba,cd,fed->cae", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}), - EinsumTestCase("cba,dc,def->cbd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}), - EinsumTestCase("cba,dc,def->cbe", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}), - EinsumTestCase("cba,dc,def->cad", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f}), - EinsumTestCase("cba,dc,def->cae", std::vector{2, 2, 2}, std::vector{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}), - EinsumTestCase("cba,dc,dfe->cbd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}), - EinsumTestCase("cba,dc,dfe->cbf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}), - EinsumTestCase("cba,dc,dfe->cad", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f}), - EinsumTestCase("cba,dc,dfe->caf", std::vector{2, 2, 2}, std::vector{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}), - EinsumTestCase("cba,dc,edf->cbe", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}), - EinsumTestCase("cba,dc,edf->cbd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}), - EinsumTestCase("cba,dc,edf->cae", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}), - EinsumTestCase("cba,dc,edf->cad", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}), - EinsumTestCase("cba,dc,efd->cbe", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}), - EinsumTestCase("cba,dc,efd->cbf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}), - EinsumTestCase("cba,dc,efd->cae", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}), - EinsumTestCase("cba,dc,efd->caf", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f}), - EinsumTestCase("cba,dc,fde->cbf", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}), - EinsumTestCase("cba,dc,fde->cbd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}), - EinsumTestCase("cba,dc,fde->caf", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}), - EinsumTestCase("cba,dc,fde->cad", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}), - EinsumTestCase("cba,dc,fed->cbf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}), - EinsumTestCase("cba,dc,fed->cbe", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}), - EinsumTestCase("cba,dc,fed->caf", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}), - EinsumTestCase("cba,dc,fed->cae", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f})}; - - auto test_lambda = [](const std::vector& test_cases_set) { - 
std::vector<float> m1{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; - std::vector<float> m2{0.f, 1.f, 2.f, 3.f}; - std::vector<float> m3{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; - for (const auto& tst : test_cases_set) { - OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); - test.AddAttribute("equation", tst.equation); - test.AddInput<float>("x", {2, 2, 2}, m1); - test.AddInput<float>("y", {2, 2}, m2); - test.AddInput<float>("z", {2, 2, 2}, m3); - test.AddOutput<float>("o", tst.shape, tst.expected); - test.Run(); - } - }; - - test_lambda(test_cases_set_1); - test_lambda(test_cases_set_2); - test_lambda(test_cases_set_3); +class EinsumTransposeMatMulThreeInputsTest : public testing::TestWithParam<EinsumTestCase> { +}; -} // namespace test +TEST_P(EinsumTransposeMatMulThreeInputsTest, EinsumTransposeMatMulThreeInputsTestSuite) { + const auto& tst = GetParam(); + std::vector<float> m1{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; + std::vector<float> m2{0.f, 1.f, 2.f, 3.f}; + std::vector<float> m3{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + std::string s(tst.equation); + test.AddAttribute("equation", s); + test.AddInput<float>("x", {2, 2, 2}, m1); + test.AddInput<float>("y", {2, 2}, m2); + test.AddInput<float>("z", {2, 2, 2}, m3); + std::vector<int64_t> v1(tst.shape.begin(), tst.shape.end()); + std::vector<float> v2(tst.expected.begin(), tst.expected.end()); + test.AddOutput<float>("o", v1, v2); + test.Run(); +} + +INSTANTIATE_TEST_SUITE_P(EinsumTransposeMatMulThreeInputsTests, EinsumTransposeMatMulThreeInputsTest, testing::ValuesIn(case1)); } // namespace test } // namespace onnxruntime From 90883a366ae2ce5402c8b887638a8b2ae3e0efdd Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 31 Jan 2024 08:28:53 +0800 Subject: [PATCH 11/11] [js/webgpu] Add hardSigmoid activation for fusedConv (#19233) ### Description Add hardSigmoid activation for fusedConv. It will be used by the mobilenetv3-small-100 model.
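For reference, the HardSigmoid activation fused here computes `max(0, min(1, alpha * x + beta))`, with `alpha`/`beta` taken from the `activation_params` attribute and defaulting to `[0.2, 0.5]` when that attribute is absent. The TypeScript sketch below is illustrative only (the `hardSigmoid` helper is not part of this change); it simply mirrors the WGSL snippet emitted by `getActivationSnippet`:

```ts
// Reference semantics of the fused HardSigmoid activation (sketch, not part of this patch).
// alpha/beta default to [0.2, 0.5], matching parseInternalActivationAttributes when no
// activation_params attribute is supplied.
const hardSigmoid = (x: number, alpha = 0.2, beta = 0.5): number =>
    Math.max(0, Math.min(1, alpha * x + beta));

// With activation_params = [2.0, 5.0], as in the new fused-conv.jsonc cases, any conv output
// <= -2.5 clamps to 0 and any output >= -2 clamps to 1, which is why the expected outputs
// there contain only zeros and ones.
console.log(hardSigmoid(-3, 2.0, 5.0));  // 0
console.log(hardSigmoid(10, 2.0, 5.0));  // 1
```

Passing `alpha` and `beta` as shader uniforms (via `appendActivationUniforms`/`appendActivationUniformsData`) rather than baking them into the generated source follows the existing pattern for Clip's `clip_min`/`clip_max`, so the same compiled program can serve different attribute values.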
--- .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 11 +- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 11 +- .../ops/3rd-party/matmul_packed_webgpu.ts | 12 +- .../lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 37 +++-- js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts | 35 ++++- js/web/lib/wasm/jsep/webgpu/ops/matmul.ts | 12 +- js/web/test/data/ops/fused-conv.jsonc | 144 ++++++++++++++++++ .../core/optimizer/conv_activation_fusion.cc | 2 +- 8 files changed, 207 insertions(+), 57 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 1a03621512888..e5ca3204d4433 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -24,7 +24,7 @@ import {TensorView} from '../../../tensor-view'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvAttributes} from '../conv'; -import {getActivationSnippet} from '../fuse-utils'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; import {biasSnippet, typeSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; @@ -193,10 +193,7 @@ export const createConv2DMatMulProgramInfo = {type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations} ]; - if (attributes.activation === 'Clip') { - programUniforms.push( - {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); - } + appendActivationUniformsData(attributes, programUniforms); programUniforms.push( ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; @@ -212,9 +209,7 @@ export const createConv2DMatMulProgramInfo = {name: 'pad', type: 'i32', length: 2}, {name: 'stride', type: 'i32', length: 2}, {name: 'dilation', type: 'i32', length: 2} ]; - if (attributes.activation === 'Clip') { - uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); - } + appendActivationUniforms(attributes, uniforms); // TODO: support component 2, 3. const components = isVec4 ? 
4 : 1; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index 33e50a9a39cb9..e50733559dbe9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -24,7 +24,7 @@ import {TensorView} from '../../../tensor-view'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; -import {getActivationSnippet} from '../fuse-utils'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; import {biasSnippet, typeSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; @@ -201,10 +201,7 @@ export const createConv2DTransposeMatMulProgramInfo = {type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations}, {type: 'int32', data: filterDims}, {type: 'int32', data: pads} ]; - if (attributes.activation === 'Clip') { - programUniforms.push( - {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); - } + appendActivationUniformsData(attributes, programUniforms); programUniforms.push( ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)); @@ -237,9 +234,7 @@ export const createConv2DTransposeMatMulProgramInfo = {name: 'filter_dims', type: 'i32', length: filterDims.length}, {name: 'pads', type: 'i32', length: pads.length} ]; - if (attributes.activation === 'Clip') { - uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); - } + appendActivationUniforms(attributes, uniforms); return ` ${utilFunctions('uniforms.result_strides')} ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 5881c055ef135..00c1f86d67419 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -23,7 +23,7 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; import {createTensorShapeVariables, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; -import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -449,11 +449,7 @@ export const createMatmulProgramInfo = const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; const programUniforms: ProgramUniform[] = [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - if (activationAttributes.activation === 'Clip') { - programUniforms.push( - {type: 'float32', data: activationAttributes.clipMax!}, - {type: 'float32', data: activationAttributes.clipMin!}); - } + 
appendActivationUniformsData(activationAttributes, programUniforms); programUniforms.push( ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp), ...createTensorShapeVariables(bShapeTemp)); @@ -481,9 +477,7 @@ export const createMatmulProgramInfo = } const uniforms: UniformsArrayType = [{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}]; - if (activationAttributes.activation === 'Clip') { - uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); - } + appendActivationUniforms(activationAttributes, uniforms); const applyActivation = getActivationSnippet(activationAttributes, output.type.value); const declareFunctions = matMulReadWriteFnSource( components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index f81d6577890c5..c0aaaa7ce134b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -7,7 +7,7 @@ import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../ import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; import {calculateOutputShape, ConvAttributes} from './conv'; -import {getActivationSnippet} from './fuse-utils'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from './fuse-utils'; /** * naive grouped conv implementation, supports 1d/2d conv @@ -32,10 +32,7 @@ export const createGroupedConvProgramInfo = {type: 'uint32', data: [attributes.strides[0], attributes.strides[1]]}, {type: 'uint32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'uint32', data: outputChannelsPerGroup} ]; - if (attributes.activation === 'Clip') { - programUniforms.push( - {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); - } + appendActivationUniformsData(attributes, programUniforms); programUniforms.push( ...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape), ...createTensorShapeVariables(outputShape)); @@ -61,9 +58,7 @@ export const createGroupedConvProgramInfo = {name: 'strides', type: 'u32', length: 2}, {name: 'pads', type: 'u32', length: 2}, {name: 'output_channels_per_group', type: 'u32'} ]; - if (attributes.activation === 'Clip') { - uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); - } + appendActivationUniforms(attributes, uniforms); return ` ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)} @@ -132,10 +127,13 @@ export const createGroupedConvVectorizeProgramInfo = const outputShapeInShader = [outputShape[0], outputShape[1], outputShape[2], outputShape[3] / components]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'int32', data: attributes.strides}, - {type: 'int32', data: attributes.pads}, ...createTensorShapeVariables(xShape), - ...createTensorShapeVariables(wShape), ...createTensorShapeVariables(outputShapeInShader) + {type: 'uint32', data: outputSize}, {type: 'int32', data: [attributes.strides[0], attributes.strides[1]]}, + {type: 'int32', data: [attributes.pads[0], attributes.pads[1]]} ]; + appendActivationUniformsData(attributes, programUniforms); + programUniforms.push( + ...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape), + 
...createTensorShapeVariables(outputShapeInShader)); const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1]; const getShaderSource = (shaderHelper: ShaderHelper) => { const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); @@ -147,13 +145,14 @@ export const createGroupedConvVectorizeProgramInfo = inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims, components)); } const processBias = hasBias ? 'value += b[output_channel];' : ''; - + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, + {name: 'strides', type: 'i32', length: 2}, + {name: 'pads', type: 'i32', length: 2}, + ]; + appendActivationUniforms(attributes, uniforms); return ` - ${ - shaderHelper.registerUniform('output_size', 'u32') - .registerUniform('strides', 'i32', 2) - .registerUniform('pads', 'i32', 2) - .declareVariables(...inputVars, output)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let width0 = uniforms.output_shape[3]; @@ -173,7 +172,7 @@ export const createGroupedConvVectorizeProgramInfo = // Use constant instead of uniform can give better performance for w's height/width. for (var w_height: u32 = 0u; w_height < ${wShape[0]}; w_height++) { let x_height = x_corner.x + i32(w_height); - if (x_height >= 0 || u32(x_height) < uniforms.x_shape[1]) { + if (x_height >= 0 && u32(x_height) < uniforms.x_shape[1]) { for (var i = 0; i < ${xNumber}; i++) { let x_width = x_corner.y + i; if (x_width >= 0 && u32(x_width) < uniforms.x_shape[2]) { @@ -185,7 +184,7 @@ export const createGroupedConvVectorizeProgramInfo = for (var w_width: u32 = 0u; w_width < ${wShape[1]}; w_width++) { let w_val = ${w.get('w_height', 'w_width', '0', 'output_channel')}; for (var i = 0u; i < ${outputNumber}u; i++) { - values[i] = fma(x_vals[i * ${attributes.strides[1]}u + w_width], w_val, values[i]); + values[i] = fma(x_vals[i * u32(uniforms.strides[1]) + w_width], w_val, values[i]); } } } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index 2e0aa33a957dc..e1dc9a5e0ab7d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -2,11 +2,16 @@ // Licensed under the MIT License. import {MAX_CLIP, MIN_CLIP} from '../../util'; +import {ProgramUniform} from '../types'; + +import {UniformsArrayType} from './common'; export interface InternalActivationAttributes { readonly activation: string; readonly clipMin?: number; readonly clipMax?: number; + readonly alpha?: number; + readonly beta?: number; } export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): string => { @@ -17,17 +22,41 @@ export const getActivationSnippet = (attributes: InternalActivationAttributes, v return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`; case 'Clip': return `value = clamp(value, ${valueType}(uniforms.clip_min), ${valueType}(uniforms.clip_max));`; + case 'HardSigmoid': + return `value = max(${valueType}(0.0), min(${valueType}(1.0), ${valueType}(uniforms.alpha) * value + ${ + valueType}(uniforms.beta)));`; + case '': + return ''; // TODO: adding other activations that can be fused. 
default: - return ''; + throw new Error(`Unsupported activation ${attributes.activation}`); + } +}; + +export const appendActivationUniformsData = + (attributes: InternalActivationAttributes, programUniform: ProgramUniform[]) => { + if (attributes.activation === 'Clip') { + programUniform.push({type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); + } else if (attributes.activation === 'HardSigmoid') { + programUniform.push({type: 'float32', data: attributes.alpha!}, {type: 'float32', data: attributes.beta!}); + } + }; + +export const appendActivationUniforms = (attributes: InternalActivationAttributes, uniforms: UniformsArrayType) => { + if (attributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } else if (attributes.activation === 'HardSigmoid') { + uniforms.push({name: 'alpha', type: 'f32'}, {name: 'beta', type: 'f32'}); } }; export const parseInternalActivationAttributes = (attributes: Record|undefined): InternalActivationAttributes => { const activation = attributes?.activation as string || ''; - - if (activation === 'Clip') { + if (activation === 'HardSigmoid') { + const [alpha, beta] = attributes?.activation_params as [number, number] || [0.2, 0.5]; + return {activation, alpha, beta}; + } else if (activation === 'Clip') { const [clipMin, clipMax] = attributes?.activation_params as [number, number] || [MIN_CLIP, MAX_CLIP]; return {activation, clipMax, clipMin}; } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index c946ea6366123..188b88b2510d8 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -7,7 +7,7 @@ import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, UniformsArrayType,} from './common'; -import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; export const createNaiveMatmulProgramInfo = (inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, outputShape: readonly number[], @@ -32,11 +32,7 @@ export const createNaiveMatmulProgramInfo = {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, {type: 'uint32', data: K} ]; - if (activationAttributes.activation === 'Clip') { - programUniforms.push( - {type: 'float32', data: activationAttributes.clipMax!}, - {type: 'float32', data: activationAttributes.clipMin!}); - } + appendActivationUniformsData(activationAttributes, programUniforms); programUniforms.push( ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape), ...createTensorShapeVariables(bShape)); @@ -69,9 +65,7 @@ export const createNaiveMatmulProgramInfo = {name: 'output_size', type: 'u32'}, {name: 'M', type: 'u32'}, {name: 'N', type: 'u32'}, {name: 'K', type: 'u32'} ]; - if (activationAttributes.activation === 'Clip') { - uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); - } + appendActivationUniforms(activationAttributes, uniforms); const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => { const rank = variable.rank; diff --git 
a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc index ad1c0a72c11d3..c734d6db9b92a 100644 --- a/js/web/test/data/ops/fused-conv.jsonc +++ b/js/web/test/data/ops/fused-conv.jsonc @@ -142,5 +142,149 @@ ] } ] + }, + { + "name": "fused conv with HardSigmoid", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 1, 1], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC conv with HardSigmoid", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 1, 1], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "fused group-conv with HardSigmoid", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" }, + { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0, + 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1], + "dims": [1, 3, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC group-conv with HardSigmoid", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" }, + { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0, + 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "dims": [1, 2, 2, 3], + "type": "float32" + } + ] + } + ] } ] diff --git 
a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index d27603e4ab3a1..b7cb3ba488c62 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -111,7 +111,7 @@ class ConvActivationSelector : public NodeSelector { if (!graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Relu", {6, 13, 14})) { return std::nullopt; } - } else if (node_ep.empty() || node_ep == kCpuExecutionProvider) { + } else if (node_ep.empty() || node_ep == kCpuExecutionProvider || node_ep == kJsExecutionProvider) { if (!is_supported_non_cuda_rocm_ep_activation(*next_node) && !graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "HardSigmoid", {6})) { return std::nullopt;