From a7ed425382f16eadcc7d7b874099d231bf42d2cf Mon Sep 17 00:00:00 2001 From: shewu-quic <138087975+shewu-quic@users.noreply.github.com> Date: Sat, 23 Nov 2024 05:48:27 +0800 Subject: [PATCH 01/18] Qaulcomm AI Engine Direct - Fix quantization annotation for per channel quant (#7026) summary: - Fix the 8a8w config in custom annotation - Enable to set act observer and symmetic argument for per channel quant - Remove unuse custom annotation in llama.py --- backends/qualcomm/_passes/layout_transform.py | 1 + backends/qualcomm/partition/common_defs.py | 1 + .../qualcomm/quantizer/custom_annotation.py | 2 +- backends/qualcomm/quantizer/qconfig.py | 29 ++++++++++---- backends/qualcomm/quantizer/quantizer.py | 40 +++++++++++++++---- .../qualcomm/oss_scripts/llama3_2/llama.py | 5 +-- 6 files changed, 58 insertions(+), 20 deletions(-) diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index a73ce9acbd..851b547eb6 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -64,6 +64,7 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.prelu.default, exir_ops.edge.aten.relu.default, exir_ops.edge.aten._softmax.default, # TODO: Need to find a new solution to do "axis_order" to transform axis. + exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.sqrt.default, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.sum.dim_IntList, diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py index d68441c2f7..1c24d00390 100644 --- a/backends/qualcomm/partition/common_defs.py +++ b/backends/qualcomm/partition/common_defs.py @@ -14,6 +14,7 @@ exir_ops.edge.aten.full.default, exir_ops.edge.aten.slice_scatter.default, exir_ops.edge.aten.copy.default, + exir_ops.edge.quantized_decomposed.embedding_4bit.dtype, ] to_be_implemented_operator = [ diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index 8a3ff40571..0e021c02e6 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -22,7 +22,7 @@ from torch.fx import Node -def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None: +def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None: # noqa: C901 """ This function is specific for matmul op 16a8w. """ diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py index e07ca24d90..abe51066ba 100644 --- a/backends/qualcomm/quantizer/qconfig.py +++ b/backends/qualcomm/quantizer/qconfig.py @@ -221,6 +221,7 @@ def get_ptq_per_channel_quant_config( act_dtype=torch.uint8, weight_dtype=torch.int8, act_observer=MovingAverageMinMaxObserver, + act_symmetric: bool = False, ) -> QuantizationConfig: extra_args: Dict[str, Any] = {"eps": 2**-12} @@ -241,13 +242,27 @@ def get_ptq_per_channel_quant_config( ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}" # torch do not support uint16 quantization, use int32 to bypass - act_quantization_spec = QuantizationSpec( - dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, - quant_min=torch.iinfo(act_dtype).min, - quant_max=torch.iinfo(act_dtype).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), - ) + if act_symmetric: + # If zero_point is 128, htp can do optimizations. + # If we keep quant_min and quant_max none, observer will default use 128 as zero_point. 
+ # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired. + act_quantization_spec = QuantizationSpec( + dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + else: + # PyTorch will remove redundant observers based on attributes such as: + # dtype, quant_min, quant_max, ch_axis, etc. + # Providing values like quant_min and quant_max can help observers compare + # and further reduce the number of observers. + act_quantization_spec = QuantizationSpec( + dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, + quant_min=torch.iinfo(act_dtype).min, + quant_max=torch.iinfo(act_dtype).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) weight_quantization_spec = QuantizationSpec( dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index da7b0174c0..7a41fb1ae2 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. from enum import IntEnum, unique +from functools import partial from typing import Callable, Optional, Sequence, Set import torch @@ -67,28 +68,44 @@ class QuantDtype(IntEnum): # PTQ (QuantDtype.use_16a16w, False): ( get_16a16w_qnn_ptq_config, - get_ptq_per_channel_quant_config(torch.uint16, torch.int16), + partial( + get_ptq_per_channel_quant_config, + act_dtype=torch.uint16, + weight_dtype=torch.int16, + ), ), (QuantDtype.use_16a8w, False): ( get_16a8w_qnn_ptq_config, - get_ptq_per_channel_quant_config(torch.uint16, torch.int8), + partial( + get_ptq_per_channel_quant_config, + act_dtype=torch.uint16, + weight_dtype=torch.int8, + ), ), (QuantDtype.use_16a4w, False): ( get_16a4w_qnn_ptq_config, - get_ptq_per_channel_quant_config(torch.uint16, "int4"), + partial( + get_ptq_per_channel_quant_config, + act_dtype=torch.uint16, + weight_dtype="int4", + ), ), (QuantDtype.use_8a8w, False): ( get_8a8w_qnn_ptq_config, - get_ptq_per_channel_quant_config(), + partial(get_ptq_per_channel_quant_config), ), # QAT, (QuantDtype.use_16a4w, True): ( get_16a4w_qnn_qat_config, - get_qat_per_channel_quant_config(torch.uint16, "int4"), + partial( + get_qat_per_channel_quant_config, + act_dtype=torch.uint16, + weight_dtype="int4", + ), ), (QuantDtype.use_8a8w, True): ( get_8a8w_qnn_qat_config, - get_qat_per_channel_quant_config(), + partial(get_qat_per_channel_quant_config), ), } @@ -176,11 +193,18 @@ def set_quant_config( f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not support" ) - quant_config_fuc, self.per_channel_quant_config = quant_config_dict[ + quant_config_fuc, per_channel_quant_config_fuc = quant_config_dict[ (quant_dtype, is_qat) ] self.quant_config = ( - quant_config_fuc(act_observer) if act_observer else quant_config_fuc() + quant_config_fuc(act_observer=act_observer) + if act_observer + else quant_config_fuc() + ) + self.per_channel_quant_config = ( + per_channel_quant_config_fuc(act_observer=act_observer) + if act_observer + else per_channel_quant_config_fuc() ) def set_per_channel_conv_quant(self, enable: bool) -> None: diff --git a/examples/qualcomm/oss_scripts/llama3_2/llama.py b/examples/qualcomm/oss_scripts/llama3_2/llama.py index 706c04fd0d..532eb68319 
100755 --- a/examples/qualcomm/oss_scripts/llama3_2/llama.py +++ b/examples/qualcomm/oss_scripts/llama3_2/llama.py @@ -293,10 +293,7 @@ def compile(args): start_quantize_ts = time.time() single_llama.quantize( quant_dtype, - custom_annotations=( - custom_annotate_llama_last_conv_16a8w, - annotate_matmul_16a8w, - ), + custom_annotations=(annotate_matmul_16a8w,), ) end_quantize_ts = time.time() logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}") From b3f2a793b6358956709f6db1adf51e8038c27745 Mon Sep 17 00:00:00 2001 From: Val Tarasyuk Date: Fri, 22 Nov 2024 15:48:16 -0800 Subject: [PATCH 02/18] Add HardTanh to RemovePermutesAroundElementwiseOps (#6992) Differential Revision: D66187338 Pull Request resolved: https://github.com/pytorch/executorch/pull/7036 --- backends/cadence/aot/remove_ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py index d2251bd9c0..038a219207 100644 --- a/backends/cadence/aot/remove_ops.py +++ b/backends/cadence/aot/remove_ops.py @@ -561,6 +561,7 @@ class Subgraph: exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.cat.default, + exir_ops.edge.aten.hardtanh.default, exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, } From fbcc9a13f875113ab2a0e0f8a67a6586f1fd3bda Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Sat, 23 Nov 2024 14:07:00 -0800 Subject: [PATCH 03/18] fix BUCK rules Differential Revision: D66413908 Pull Request resolved: https://github.com/pytorch/executorch/pull/7051 --- .../coreml/scripts/install_requirements.sh | 2 +- .../coreml/test/test_coreml_partitioner.py | 18 +++++------------- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/backends/apple/coreml/scripts/install_requirements.sh b/backends/apple/coreml/scripts/install_requirements.sh index b3ea0d77ca..b6a0a18b77 100755 --- a/backends/apple/coreml/scripts/install_requirements.sh +++ b/backends/apple/coreml/scripts/install_requirements.sh @@ -24,7 +24,7 @@ rm -rf "$COREML_DIR_PATH/third-party" mkdir "$COREML_DIR_PATH/third-party" echo "${green}ExecuTorch: Cloning coremltools." -git clone --depth 1 --branch 8.0 "https://github.com/apple/coremltools.git" $COREMLTOOLS_DIR_PATH +git clone --depth 1 --branch 8.1 "https://github.com/apple/coremltools.git" $COREMLTOOLS_DIR_PATH cd $COREMLTOOLS_DIR_PATH STATUS=$? 
diff --git a/backends/apple/coreml/test/test_coreml_partitioner.py b/backends/apple/coreml/test/test_coreml_partitioner.py index 72a7fbf093..64e1570f0b 100644 --- a/backends/apple/coreml/test/test_coreml_partitioner.py +++ b/backends/apple/coreml/test/test_coreml_partitioner.py @@ -71,23 +71,15 @@ def test_vit_skip_conv(self): ) ) - conv_block = ["aten.convolution.default", "executorch_call_delegate"] - safe_softmax_block = [ - "getitem", - "getitem", - "getitem", - "getitem", - "aten.any.dim", - "executorch_call_delegate", - ] - final_block = ["getitem"] - total = conv_block + 12 * safe_softmax_block + final_block - assert [ node.target.__name__ for node in delegated_program_manager.exported_program().graph.nodes if node.op == "call_function" - ] == total + ] == [ + "aten.convolution.default", + "executorch_call_delegate", + "getitem", + ] def test_buffer(self): embedding_dim = 3 From 089087b2caf4eb5eefe05fb6fe8fd216b69fb9b7 Mon Sep 17 00:00:00 2001 From: cccclai Date: Sat, 23 Nov 2024 15:51:30 -0800 Subject: [PATCH 04/18] Add qnn 16a16w quantization test (#7039) Add qnn 16a16w quantization test (#7039) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/7039 Differential Revision: D66390212 --- .ci/scripts/test_llama.sh | 8 ++++++++ .github/workflows/trunk.yml | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/.ci/scripts/test_llama.sh b/.ci/scripts/test_llama.sh index dad3e1101f..23a579e67c 100644 --- a/.ci/scripts/test_llama.sh +++ b/.ci/scripts/test_llama.sh @@ -27,6 +27,10 @@ while [[ $# -gt 0 ]]; do MODE="$2" # portable or xnnpack+custom or xnnpack+custom+qe shift 2 ;; + -pt2e_quantize) + PT2E_QUANTIZE="$2" + shift 2 + ;; -upload) UPLOAD_DIR="$2" shift 2 @@ -234,6 +238,10 @@ if [[ "${COREML}" == "ON" ]]; then fi if [[ "${QNN}" == "ON" ]]; then EXPORT_ARGS="${EXPORT_ARGS} -kv -v --qnn --disable_dynamic_shape" + echo "PT2E_QUANTIZE is ${PT2E_QUANTIZE}" + if [[ "${PT2E_QUANTIZE}" == "qnn_16a16w" ]]; then + EXPORT_ARGS+=" --tokenizer_path tokenizer.model --pt2e_quantize qnn_16a16w --calibration_tasks wikitext --calibration_limit 1 --calibration_seq_length 128 --calibration_data Once " + fi fi # Add dynamically linked library location $PYTHON_EXECUTABLE -m examples.models.llama.export_llama ${EXPORT_ARGS} diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 7afc385a19..ae1b88fb18 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -441,3 +441,39 @@ jobs: cmake-out/examples/models/llama/llama_main --model_path=${ET_MODEL_NAME}.pte --tokenizer_path=${TOKENIZER_BIN_FILE} --prompt="My name is" echo "::endgroup::" + + + test-llama-runner-qnn-linux: + name: test-llama-runner-qnn-linux + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + strategy: + matrix: + dtype: [fp32] + pt2e_quantize: [qnn_16a16w, qnn_8a8w] + mode: [qnn] + fail-fast: false + with: + runner: linux.2xlarge + docker-image: executorch-ubuntu-22.04-qnn-sdk + submodules: 'true' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 900 + script: | + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + conda activate "${CONDA_ENV}" + + BUILD_TOOL="cmake" + DTYPE=${{ matrix.dtype }} + MODE=${{ matrix.mode }} + PT2E_QUANTIZE=${{ matrix.pt2e_quantize }} + + PYTHON_EXECUTABLE=python bash .ci/scripts/setup-qnn-deps.sh + PYTHON_EXECUTABLE=python bash 
.ci/scripts/build-qnn-sdk.sh + + # Setup executorch + PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "${BUILD_TOOL}" + # Install requirements for export_llama + PYTHON_EXECUTABLE=python bash examples/models/llama/install_requirements.sh + # Test llama2 + PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh -model stories110M -build_tool "${BUILD_TOOL}" -mode "${MODE}" -dtype "${DTYPE}" -pt2e_quantize "${PT2E_QUANTIZE}" From 3f1b085cd5a3eaa57602c6d44fe7debcfbd3b818 Mon Sep 17 00:00:00 2001 From: Saoirse Stewart Date: Tue, 12 Nov 2024 15:26:34 +0000 Subject: [PATCH 05/18] Add aot_arm_compiler flag to allow the reordering of the inputs * Add capability to use cmd input order in the backend * Extend the test infrastructure to handle this --- backends/arm/arm_backend.py | 31 +++++++++++++++++++++++++++++-- backends/arm/arm_vela.py | 15 +++++++++------ backends/arm/test/common.py | 32 ++++++++++++++++++++++++++------ examples/arm/aot_arm_compiler.py | 18 ++++++++++++++++-- examples/arm/run.sh | 6 +++++- 5 files changed, 85 insertions(+), 17 deletions(-) diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py index 06207611e0..ad2d1e73af 100644 --- a/backends/arm/arm_backend.py +++ b/backends/arm/arm_backend.py @@ -52,6 +52,7 @@ def __init__(self): self.permute_nhwc = False self.quantize_io = False self.tosa_version = None + self.input_order = None def ethosu_compile_spec( self, @@ -134,6 +135,14 @@ def set_quantize_io(self, quantize_io: bool = False) -> "ArmCompileSpecBuilder": self.quantize_io = quantize_io return self + def set_input_order(self, input_order: str = None) -> "ArmCompileSpecBuilder": + """ + Reorder the inputs coming in. This may be required when inputs > 1. + And while using the U55/U85 CompileSpec. + """ + self.input_order = input_order + return self + def build(self) -> List[CompileSpec]: """ Generate a list of compile spec objects from the builder @@ -163,6 +172,13 @@ def build(self) -> List[CompileSpec]: CompileSpec("permute_memory_format", "nhwc".encode()) ) + if self.input_order: + self.compile_spec.append( + CompileSpec( + "input_order", " ".join(map(str, self.input_order)).encode() + ) + ) + if self.quantize_io: self.compile_spec.append(CompileSpec("quantize_io", "True".encode())) @@ -214,6 +230,7 @@ def preprocess( # noqa: C901 artifact_path = None output_format = "" compile_flags = [] + input_order = [] for spec in compile_spec: if spec.key == "debug_artifact_path": artifact_path = spec.value.decode() @@ -221,6 +238,8 @@ def preprocess( # noqa: C901 output_format = spec.value.decode() if spec.key == "compile_flags": compile_flags.append(spec.value.decode()) + if spec.key == "input_order": + input_order = list(map(int, spec.value.decode().split(","))) # Check that the output format is set in the compile spec if not output_format: @@ -246,12 +265,14 @@ def preprocess( # noqa: C901 ) node_visitors = get_node_visitors(edge_program, tosa_spec) - + input_count = 0 for node in graph_module.graph.nodes: if node.op == "call_function": process_call_function(node, tosa_graph, node_visitors, tosa_spec) elif node.op == "placeholder": process_placeholder(node, tosa_graph, edge_program, tosa_spec) + if node.name in edge_program.graph_signature.user_inputs: + input_count += 1 elif node.op == "output": process_output(node, tosa_graph) else: @@ -259,6 +280,12 @@ def preprocess( # noqa: C901 # any checking of compatibility. 
dbg_fail(node, tosa_graph, artifact_path) + if len(input_order) > 0: + if input_count != len(input_order): + raise RuntimeError( + "The rank of the input order is not equal to amount of input tensors" + ) + # TODO: It would be awesome if this dump could somehow be done on top level and not here. # Problem is that the desc.json has to be created on the tosa_graph object, which we can't # access from top level. @@ -275,7 +302,7 @@ def preprocess( # noqa: C901 # preprocess and some consume TOSA fb directly. if output_format == "vela": # Emit vela_bin_stream format - binary = vela_compile(tosa_graph, compile_flags) + binary = vela_compile(tosa_graph, compile_flags, input_order) elif output_format == "tosa": # Emit TOSA flatbuffer binary = bytes(tosa_graph.serialize()) diff --git a/backends/arm/arm_vela.py b/backends/arm/arm_vela.py index 01bb8bd55e..918d95ba37 100644 --- a/backends/arm/arm_vela.py +++ b/backends/arm/arm_vela.py @@ -17,10 +17,13 @@ # Pack either input or output tensor block, compose the related arrays into # per-io structs to simplify runtime use. -def vela_bin_pack_io(prefix, data): - ios = struct.pack(" list[CompileSpec]: """ Default compile spec for Ethos-U55 tests. """ return get_u55_compile_spec_unbuilt( - permute_memory_to_nhwc, quantize_io=quantize_io, custom_path=custom_path + permute_memory_to_nhwc, + quantize_io=quantize_io, + custom_path=custom_path, + reorder_inputs=reorder_inputs, ).build() def get_u85_compile_spec( - permute_memory_to_nhwc=True, quantize_io=False, custom_path=None + permute_memory_to_nhwc=True, + quantize_io=False, + custom_path=None, + reorder_inputs=None, ) -> list[CompileSpec]: """ Default compile spec for Ethos-U85 tests. """ return get_u85_compile_spec_unbuilt( - permute_memory_to_nhwc, quantize_io=quantize_io, custom_path=custom_path + permute_memory_to_nhwc, + quantize_io=quantize_io, + custom_path=custom_path, + reorder_inputs=reorder_inputs, ).build() def get_u55_compile_spec_unbuilt( - permute_memory_to_nhwc=True, quantize_io=False, custom_path=None + permute_memory_to_nhwc=True, + quantize_io=False, + custom_path=None, + reorder_inputs=None, ) -> ArmCompileSpecBuilder: """Get the ArmCompileSpecBuilder for the Ethos-U55 tests, to modify the compile spec before calling .build() to finalize it. @@ -257,12 +272,16 @@ def get_u55_compile_spec_unbuilt( .set_quantize_io(is_option_enabled("quantize_io") or quantize_io) .set_permute_memory_format(permute_memory_to_nhwc) .dump_intermediate_artifacts_to(artifact_path) + .set_input_order(reorder_inputs) ) return compile_spec def get_u85_compile_spec_unbuilt( - permute_memory_to_nhwc=True, quantize_io=False, custom_path=None + permute_memory_to_nhwc=True, + quantize_io=False, + custom_path=None, + reorder_inputs=None, ) -> list[CompileSpec]: """Get the ArmCompileSpecBuilder for the Ethos-U85 tests, to modify the compile spec before calling .build() to finalize it. 
@@ -279,6 +298,7 @@ def get_u85_compile_spec_unbuilt( .set_quantize_io(is_option_enabled("quantize_io") or quantize_io) .set_permute_memory_format(permute_memory_to_nhwc) .dump_intermediate_artifacts_to(artifact_path) + .set_input_order(reorder_inputs) ) return compile_spec diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 4953f8735e..ddd5fd6b0b 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -245,7 +245,9 @@ def get_calibration_data( def get_compile_spec( - target: str, intermediates: Optional[str] = None + target: str, + intermediates: Optional[str] = None, + reorder_inputs: Optional[str] = None, ) -> ArmCompileSpecBuilder: spec_builder = None if target == "TOSA": @@ -265,6 +267,7 @@ def get_compile_spec( ) .set_permute_memory_format(True) .set_quantize_io(True) + .set_input_order(reorder_inputs) ) elif "ethos-u85" in target: spec_builder = ( @@ -277,6 +280,7 @@ def get_compile_spec( ) .set_permute_memory_format(True) .set_quantize_io(True) + .set_input_order(reorder_inputs) ) if intermediates is not None: @@ -419,6 +423,14 @@ def get_args(): required=False, help="Location for outputs, if not the default of cwd.", ) + parser.add_argument( + "-r", + "--reorder_inputs", + type=str, + required=False, + default=None, + help="Provide the order of the inputs. This can be required when inputs > 1.", + ) args = parser.parse_args() if args.evaluate and ( @@ -481,7 +493,9 @@ def get_args(): if args.delegate: # As we can target multiple output encodings from ArmBackend, one must # be specified. - compile_spec = get_compile_spec(args.target, args.intermediates) + compile_spec = get_compile_spec( + args.target, args.intermediates, args.reorder_inputs + ) edge = to_edge_transform_and_lower( exported_program, partitioner=[ArmPartitioner(compile_spec)], diff --git a/examples/arm/run.sh b/examples/arm/run.sh index c2c04cd2fd..9dc95600d5 100755 --- a/examples/arm/run.sh +++ b/examples/arm/run.sh @@ -20,6 +20,7 @@ script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) root_dir=${script_dir}/ethos-u-scratch model_name="" +reorder_inputs="" aot_arm_compiler_flags="--delegate --quantize" target="ethos-u55-128" output_folder_set=false @@ -37,6 +38,7 @@ help() { echo " --output= Output folder Default: ${output_folder}" echo " --build_only Only build, don't run FVP" echo " --scratch-dir= Path to your Ethos-U scrach dir if you not using default" + echo " --reorder_inputs= Reorder the inputs. This can be required when inputs > 1." 
exit 0 } @@ -50,6 +52,7 @@ for arg in "$@"; do --output=*) output_folder="${arg#*=}" ; output_folder_set=true ;; --build_only) build_only=true ;; --scratch-dir=*) root_dir="${arg#*=}";; + --reorder_inputs=*) reorder_inputs="${arg#*=}";; *) ;; esac @@ -112,7 +115,7 @@ function generate_pte_file() { # We are using the aot_lib from build_quantization_aot_lib below SO_LIB=$(find cmake-out-aot-lib -name libquantized_ops_aot_lib.${SO_EXT}) - python3 -m examples.arm.aot_arm_compiler --model_name="${model}" --target=${target} ${model_compiler_flags} --output ${output_folder} --so_library="$SO_LIB" 1>&2 + python3 -m examples.arm.aot_arm_compiler --model_name="${model}" --target=${target} ${model_compiler_flags} --reorder_inputs=${reorder_inputs} --output ${output_folder} --so_library="$SO_LIB" 1>&2 [[ -f ${pte_file} ]] || { >&2 echo "Failed to generate a pte file - ${pte_file}"; exit 1; } echo "${pte_file}" } @@ -287,6 +290,7 @@ if [[ -z "$model_name" ]]; then else test_model=( "$model_name" ) model_compiler_flags=( "$aot_arm_compiler_flags" ) + reorder_inputs=( "$reorder_inputs" ) fi # loop over running the AoT flow and executing the model on device From 12ce0cebfb03f8ae0ea9a395f4c0ddee74618068 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Fri, 15 Nov 2024 13:20:13 +0100 Subject: [PATCH 06/18] Figure out target-board from compile spec Reduces boilerplate in FVP tests Signed-off-by: Erik Lundell Change-Id: I7b4cdec6ba3da91e9f510830d6d817acaf18c53e --- backends/arm/test/common.py | 11 +++++++++++ backends/arm/test/ops/test_add.py | 26 ++++++-------------------- backends/arm/test/runner_utils.py | 5 ++--- backends/arm/test/tester/arm_tester.py | 6 +++++- 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index c425493c36..17353cab31 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -330,3 +330,14 @@ def _clean_dir(dir: Path, filter: str, num_save=10): for remove in sorted_files[0 : len(sorted_files) - num_save]: file = remove[1] file.unlink() + + +def get_target_board(compile_spec: list[CompileSpec]) -> str | None: + for spec in compile_spec: + if spec.key == "compile_flags": + flags = spec.value.decode() + if "u55" in flags: + return "corstone-300" + elif "u85" in flags: + return "corstone-320" + return None diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index 66e278ee0f..6676a38add 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -115,6 +115,8 @@ def _test_add_ethos_BI_pipeline( .to_executorch() .serialize() ) + if common.is_option_enabled("corstone300"): + tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) return tester @@ -131,28 +133,20 @@ def test_add_tosa_BI(self, test_data: torch.Tensor): @parameterized.expand(Add.test_parameters) def test_add_u55_BI(self, test_data: torch.Tensor): test_data = (test_data,) - tester = self._test_add_ethos_BI_pipeline( + self._test_add_ethos_BI_pipeline( self.Add(), common.get_u55_compile_spec(permute_memory_to_nhwc=True), test_data, ) - if common.is_option_enabled("corstone300"): - tester.run_method_and_compare_outputs( - qtol=1, inputs=test_data, target_board="corstone-300" - ) @parameterized.expand(Add.test_parameters) def test_add_u85_BI(self, test_data: torch.Tensor): test_data = (test_data,) - tester = self._test_add_ethos_BI_pipeline( + self._test_add_ethos_BI_pipeline( self.Add(), common.get_u85_compile_spec(permute_memory_to_nhwc=True), test_data, ) - if 
common.is_option_enabled("corstone300"): - tester.run_method_and_compare_outputs( - qtol=1, inputs=test_data, target_board="corstone-320" - ) @parameterized.expand(Add2.test_parameters) def test_add2_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor): @@ -167,21 +161,13 @@ def test_add2_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): @parameterized.expand(Add2.test_parameters) def test_add2_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): test_data = (operand1, operand2) - tester = self._test_add_ethos_BI_pipeline( + self._test_add_ethos_BI_pipeline( self.Add2(), common.get_u55_compile_spec(), test_data ) - if common.is_option_enabled("corstone300"): - tester.run_method_and_compare_outputs( - qtol=1, inputs=test_data, target_board="corstone-300" - ) @parameterized.expand(Add2.test_parameters) def test_add2_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): test_data = (operand1, operand2) - tester = self._test_add_ethos_BI_pipeline( + self._test_add_ethos_BI_pipeline( self.Add2(), common.get_u85_compile_spec(), test_data ) - if common.is_option_enabled("corstone300"): - tester.run_method_and_compare_outputs( - qtol=1, inputs=test_data, target_board="corstone-320" - ) diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 608761098e..5940067af6 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -191,9 +191,6 @@ def init_run( target_board: str, ): - if target_board not in ["corstone-300", "corstone-320"]: - raise RuntimeError(f"Unknown target board: {target_board}") - self.input_names = _get_input_names(edge_program) self.output_node = _get_output_node(exported_program) self.output_name = self.output_node.name @@ -222,6 +219,8 @@ def run_corstone( assert ( self._has_init_run ), "RunnerUtil needs to be initialized using init_run() before running Corstone300." + if self.target_board not in ["corstone-300", "corstone-320"]: + raise RuntimeError(f"Unknown target board: {self.target_board}") pte_path = os.path.join(self.intermediate_path, "program.pte") assert os.path.exists(pte_path), f"Pte path '{pte_path}' not found." diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index e2062f2428..3564a3325a 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -26,6 +26,7 @@ arm_test_options, current_time_formated, get_option, + get_target_board, ) from executorch.backends.arm.test.runner_utils import ( @@ -267,7 +268,7 @@ def run_method_and_compare_outputs( self, inputs: Optional[Tuple[torch.Tensor]] = None, stage: Optional[str] = None, - target_board: Optional[str] = "corstone-300", + target_board: Optional[str] = None, num_runs=1, atol=1e-03, rtol=1e-03, @@ -301,6 +302,9 @@ def run_method_and_compare_outputs( test_stage = self.stages[stage] is_quantized = self.stages[self.stage_name(tester.Quantize)] is not None + if target_board is None: + target_board = get_target_board(self.compile_spec) + exported_program = self.stages[self.stage_name(tester.Export)].artifact edge_program = edge_stage.artifact.exported_program() self.runner_util.init_run( From fbee0c8fd32a4fdf0481e7dc3d6989bc9f29f0a0 Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Mon, 11 Nov 2024 07:42:41 +0100 Subject: [PATCH 07/18] Add initial support for rshift U55 is restricted to round=True which may cause numerical differences between TOSA and PyTorch. 
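As a rough illustration of that one-off difference (a standalone sketch, not part of this patch; the helper names are made up): TOSA's ARITHMETIC_RIGHT_SHIFT with round=True adds the last bit shifted out back into the result, whereas a plain arithmetic right shift (PyTorch's >> on signed integer tensors) floors toward negative infinity.

def rshift_trunc(x: int, shift: int) -> int:
    # Plain arithmetic right shift (round=False / PyTorch >>): floor toward -inf.
    return x >> shift

def rshift_round(x: int, shift: int) -> int:
    # round=True semantics for shift >= 1: add back the last bit shifted out.
    return (x >> shift) + ((x >> (shift - 1)) & 1)

assert rshift_trunc(5, 1) == 2
assert rshift_round(5, 1) == 3  # differs by one, as noted above
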
Signed-off-by: Oscar Andersson Change-Id: I280e0dd0573b31333f6386b48d20105023719eb7 --- backends/arm/arm_backend.py | 2 +- backends/arm/operator_support/__init__.py | 7 +- .../operator_support/right_shift_support.py | 35 +++++++ backends/arm/operators/__init__.py | 1 + backends/arm/operators/op_rshift.py | 99 +++++++++++++++++++ backends/arm/test/ops/test_rshift.py | 90 +++++++++++++++++ 6 files changed, 232 insertions(+), 2 deletions(-) create mode 100644 backends/arm/operator_support/right_shift_support.py create mode 100644 backends/arm/operators/op_rshift.py create mode 100644 backends/arm/test/ops/test_rshift.py diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py index ad2d1e73af..59473a9e6d 100644 --- a/backends/arm/arm_backend.py +++ b/backends/arm/arm_backend.py @@ -90,7 +90,7 @@ def ethosu_compile_spec( self.compiler_flags.append(extra_flags) base_tosa_version = "TOSA-0.80.0+BI" - if "U55" in config: + if "u55" in config: # Add the Ethos-U55 extension marker base_tosa_version += "+u55" self.tosa_version = TosaSpecification.create_from_string(base_tosa_version) diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py index 0a88bc45aa..c133ce8003 100644 --- a/backends/arm/operator_support/__init__.py +++ b/backends/arm/operator_support/__init__.py @@ -5,4 +5,9 @@ # pyre-unsafe -from . import mean_dim_support, tosa_supported_operators, var_correction_support # noqa +from . import ( # noqa + mean_dim_support, + right_shift_support, + tosa_supported_operators, + var_correction_support, +) diff --git a/backends/arm/operator_support/right_shift_support.py b/backends/arm/operator_support/right_shift_support.py new file mode 100644 index 0000000000..ee8d5965a1 --- /dev/null +++ b/backends/arm/operator_support/right_shift_support.py @@ -0,0 +1,35 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import logging + +import torch.fx as fx +from executorch.backends.arm.operator_support.tosa_supported_operators import ( + register_tosa_support_check, + SupportedTOSAOperatorCheck, +) +from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification +from executorch.exir.dialects._ops import ops as exir_ops + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + + +@register_tosa_support_check +class RightShiftSupported(SupportedTOSAOperatorCheck): + targets = [exir_ops.edge.aten.__rshift__.Scalar] + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+BI"), + TosaSpecification.create_from_string("TOSA-0.80.0+MI"), + ] + + def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + + # TODO MLETORCH-525 Remove warning + if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset: + logging.warning(f"{node.target} may introduce one-off errors.") + return True diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 988765990d..a5c2dd8dc5 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -27,6 +27,7 @@ op_reciprocal, op_relu, op_repeat, + op_rshift, op_rsqrt, op_select, op_sigmoid, diff --git a/backends/arm/operators/op_rshift.py b/backends/arm/operators/op_rshift.py new file mode 100644 index 0000000000..94b3f8b86d --- /dev/null +++ b/backends/arm/operators/op_rshift.py @@ -0,0 +1,99 @@ +# Copyright 2024 Arm Limited and/or its affiliates. 
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List + +import serializer.tosa_serializer as ts +import torch +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg +from executorch.backends.arm.tosa_specification import Tosa_0_80 +from executorch.backends.arm.tosa_utils import tosa_shape +from serializer.tosa_serializer import TosaOp + + +@register_node_visitor +class RshiftVisitor(NodeVisitor): + target = "aten.__rshift__.Scalar" + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + input_shape = inputs[0].shape + input_0_rank = len(input_shape) + shift_expanded_shape = [1] * input_0_rank + dtype = node.meta["val"].dtype + attr = ts.TosaSerializerAttribute() + cast_input = False + cast_output = False + round = False + cast_type = dtype + if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset: + # U55 only supports INT32 and round == True + # TODO MLETORCH-525 Emulate round == False with different decomposition + if dtype != torch.int32: + cast_input = True + cast_output = True + cast_type = torch.int32 + round = True + attr.ArithmeticRightShiftAttribute(round=round) + + if cast_input: + # input needs to be casted to INT32 + shift_input = tosa_graph.addIntermediate( + shape=tosa_shape(input_shape, inputs[0].dim_order), + dtype=map_dtype(cast_type), + ) + tosa_graph.addOperator( + TosaOp.Op().CAST, + [inputs[0].name], + [shift_input.name], + None, + ) + else: + shift_input = inputs[0] + if cast_output: + # add intermediate tensor for right shift + shift = tosa_graph.addIntermediate( + shape=tosa_shape(input_shape, inputs[0].dim_order), + dtype=map_dtype(cast_type), + ) + else: + shift = output + # create tensor with same rank as inputs[0] + data = torch.full( + shift_expanded_shape, fill_value=inputs[1].number, dtype=dtype + ) + shift_const_name = node.name + "-shift_const" + tosa_graph.addConst( + shift_expanded_shape, + map_dtype(cast_type), + data.detach().numpy(), + shift_const_name, + ) + # add right shift operator + tosa_graph.addOperator( + TosaOp.Op().ARITHMETIC_RIGHT_SHIFT, + [shift_input.name, shift_const_name], + [shift.name], + attr, + ) + if cast_output: + # cast output to original output dtype + tosa_graph.addOperator( + TosaOp.Op().CAST, + [shift.name], + [output.name], + None, + ) diff --git a/backends/arm/test/ops/test_rshift.py b/backends/arm/test/ops/test_rshift.py new file mode 100644 index 0000000000..dfbd0fdb3e --- /dev/null +++ b/backends/arm/test/ops/test_rshift.py @@ -0,0 +1,90 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from parameterized import parameterized + + +class TestRshift(unittest.TestCase): + """ + Tests arithmetic right shift + """ + + class Rshift(torch.nn.Module): + test_data = [ + ((torch.IntTensor(5, 5), 2),), + ((torch.IntTensor(1, 2, 3, 4), 3),), + ((torch.ShortTensor(1, 5, 3, 4), 5),), + ((torch.CharTensor(10, 12, 3, 4), 1),), + ] + + def forward(self, x: torch.Tensor, shift: int): + return x >> shift + + def _test_rshift_tosa_MI(self, test_data): + ( + ArmTester( + self.Rshift(), + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + ) + .export() + .to_edge_transform_and_lower() + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) + + def _test_rshift_tosa_BI(self, test_data): + ( + ArmTester( + self.Rshift(), + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + ) + .quantize() + .export() + .to_edge_transform_and_lower() + .to_executorch() + # TODO MLETORCH-250 Increase flexibility of ArmTester to handle int IO + # .run_method_and_compare_outputs(inputs=test_data) + ) + + def _test_rshift_ethosu_BI(self, test_data, compile_spec): + return ( + ArmTester( + self.Rshift(), + example_inputs=test_data, + compile_spec=compile_spec, + ) + .quantize() + .export() + .to_edge_transform_and_lower() + .to_executorch() + ) + + @parameterized.expand(Rshift.test_data) + def test_rshift_tosa_MI(self, test_data): + self._test_rshift_tosa_MI(test_data) + + @parameterized.expand(Rshift.test_data) + def test_rshift_tosa_BI(self, test_data): + self._test_rshift_tosa_BI(test_data) + + # TODO Enable FVP testing + @parameterized.expand(Rshift.test_data) + def test_rshift_u55_BI(self, test_data): + compile_spec = common.get_u55_compile_spec() + self._test_rshift_ethosu_BI(test_data, compile_spec) + + # TODO Enable FVP testing + @parameterized.expand(Rshift.test_data) + def test_rshift_u85_BI(self, test_data): + compile_spec = common.get_u85_compile_spec() + self._test_rshift_ethosu_BI(test_data, compile_spec) From 1139a1c6188e189247eac649bb2243b53a795590 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Mon, 25 Nov 2024 10:09:09 -0800 Subject: [PATCH 08/18] [ET] Add `tsv_path` to `inspector_cli` (#7046) Pull Request resolved: https://github.com/pytorch/executorch/pull/7035 Per https://fb.workplace.com/groups/pytorch.edge.users/posts/1640064163530537/?comment_id=1640127190190901 ghstack-source-id: 255101390 @exported-using-ghexport Differential Revision: [D66379005](https://our.internmc.facebook.com/intern/diff/D66379005/) Co-authored-by: jorgep31415 --- devtools/inspector/_inspector.py | 50 +++++++++++++++++++++++------ devtools/inspector/inspector_cli.py | 7 ++++ 2 files changed, 48 insertions(+), 9 deletions(-) diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index d951a1ada9..001ea50550 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -1143,23 +1143,18 @@ def to_dataframe( ] return pd.concat(df_list, ignore_index=True) - def print_data_tabular( + def _prepare_dataframe( self, - file: IO[str] = sys.stdout, include_units: bool = True, include_delegate_debug_data: bool = False, - ) -> None: + ) -> pd.DataFrame: """ - Displays the underlying EventBlocks in a structured tabular format, with each row representing an Event. - Args: - file: Which IO stream to print to. Defaults to stdout. 
- Not used if this is in an IPython environment such as a Jupyter notebook. include_units: Whether headers should include units (default true) include_delegate_debug_data: Whether to include delegate debug metadata (default false) Returns: - None + Returns a pandas DataFrame of the Events in each EventBlock in the inspector, with additional filtering. """ combined_df = self.to_dataframe(include_units, include_delegate_debug_data) @@ -1171,7 +1166,44 @@ def print_data_tabular( ] filtered_column_df.reset_index(drop=True, inplace=True) - display_or_print_df(filtered_column_df, file) + return filtered_column_df + + def print_data_tabular( + self, + file: IO[str] = sys.stdout, + include_units: bool = True, + include_delegate_debug_data: bool = False, + ) -> None: + """ + Displays the underlying EventBlocks in a structured tabular format, with each row representing an Event. + + Args: + file: Which IO stream to print to. Defaults to stdout. + Not used if this is in an IPython environment such as a Jupyter notebook. + include_units: Whether headers should include units (default true) + include_delegate_debug_data: Whether to include delegate debug metadata (default false) + + Returns: + None + """ + df = self._prepare_dataframe(include_units, include_delegate_debug_data) + display_or_print_df(df, file) + + def save_data_to_tsv( + self, + file: IO[str], + ) -> None: + """ + Stores the underlying EventBlocks in tsv format to facilitate copy-paste into spreadsheets. + + Args: + file: Which IO stream to print to. Do not use stdout, as tab separator is not preserved. + + Returns: + None + """ + df = self._prepare_dataframe() + df.to_csv(file, sep="\t") # TODO: write unit test def find_total_for_module(self, module_name: str) -> float: diff --git a/devtools/inspector/inspector_cli.py b/devtools/inspector/inspector_cli.py index db3536a84b..00e74cc25f 100644 --- a/devtools/inspector/inspector_cli.py +++ b/devtools/inspector/inspector_cli.py @@ -43,6 +43,11 @@ def main() -> None: required=False, help="Provide an optional buffer file path.", ) + parser.add_argument( + "--tsv_path", + required=False, + help="Provide an optional tsv file path.", + ) parser.add_argument("--compare_results", action="store_true") args = parser.parse_args() @@ -55,6 +60,8 @@ def main() -> None: target_time_scale=TimeScale(args.target_time_scale), ) inspector.print_data_tabular() + if args.tsv_path: + inspector.save_data_to_tsv(args.tsv_path) if args.compare_results: for event_block in inspector.event_blocks: if event_block.name == "Execute": From a60d929bc77a95bab583933a1096fae4039cb008 Mon Sep 17 00:00:00 2001 From: Hansong <107070759+kirklandsign@users.noreply.github.com> Date: Mon, 25 Nov 2024 10:32:52 -0800 Subject: [PATCH 09/18] Fix PyBind 2.10.4 compatibility issue in executorch/extension/pybindings/pybindings.cpp +1 Differential Revision: D66395519 Pull Request resolved: https://github.com/pytorch/executorch/pull/7056 --- extension/pybindings/pybindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extension/pybindings/pybindings.cpp b/extension/pybindings/pybindings.cpp index 3b3ba57093..518e66d284 100644 --- a/extension/pybindings/pybindings.cpp +++ b/extension/pybindings/pybindings.cpp @@ -795,7 +795,7 @@ struct PyModule final { py::isinstance(debug_buffer_path)) { // Also write out the debug buffer to a separate file if requested. 
std::string debug_buffer_path_str = - py::cast(debug_buffer_path); + py::cast(debug_buffer_path); const auto debug_buffer = module_->get_etdump_debug_buffer(); write_data_to_file( debug_buffer_path_str, debug_buffer.data(), debug_buffer.size()); From 04f9cedb48d49916193adcf8b6b0cad72e6bbde5 Mon Sep 17 00:00:00 2001 From: David Lin Date: Mon, 25 Nov 2024 11:01:28 -0800 Subject: [PATCH 10/18] Fix test-llama-runner-qnn-linux tests (#7055) fix test-llama-runner-qnn-linux (fp32, qnn) add default value change pull.yml to reflect same changes in trunk.yml --- .ci/scripts/test_llama.sh | 3 +++ .github/workflows/pull.yml | 4 +++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.ci/scripts/test_llama.sh b/.ci/scripts/test_llama.sh index 23a579e67c..e109845547 100644 --- a/.ci/scripts/test_llama.sh +++ b/.ci/scripts/test_llama.sh @@ -48,6 +48,9 @@ MODE=${MODE:-"xnnpack+custom"} # Default UPLOAD_DIR to empty string if not set UPLOAD_DIR="${UPLOAD_DIR:-}" +# Default PT2E_QUANTIZE to empty string if not set +PT2E_QUANTIZE="${PT2E_QUANTIZE:-}" + if [[ $# -lt 4 ]]; then # Assuming 4 mandatory args echo "Expecting atleast 4 positional arguments" echo "Usage: [...]" diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 6fc8ca9185..88cd8ff15a 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -368,6 +368,7 @@ jobs: strategy: matrix: dtype: [fp32] + pt2e_quantize: [qnn_16a16w, qnn_8a8w] mode: [qnn] fail-fast: false with: @@ -384,6 +385,7 @@ jobs: DTYPE=${{ matrix.dtype }} BUILD_TOOL="cmake" MODE=${{ matrix.mode }} + PT2E_QUANTIZE=${{ matrix.pt2e_quantize }} PYTHON_EXECUTABLE=python bash .ci/scripts/setup-qnn-deps.sh PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh @@ -393,7 +395,7 @@ jobs: # Install requirements for export_llama PYTHON_EXECUTABLE=python bash examples/models/llama/install_requirements.sh # Test llama2 - PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh -model stories110M -build_tool "${BUILD_TOOL}" -dtype "${DTYPE}" -mode "${MODE}" + PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh -model stories110M -build_tool "${BUILD_TOOL}" -mode "${MODE}" -dtype "${DTYPE}" -pt2e_quantize "${PT2E_QUANTIZE}" test-phi-3-mini-runner-linux: name: test-phi-3-mini-runner-linux From d7786272cddbc9635011c4f80b63e917d0a6daf1 Mon Sep 17 00:00:00 2001 From: ckmadhira Date: Tue, 26 Nov 2024 01:45:06 +0530 Subject: [PATCH 11/18] =?UTF-8?q?Added=20Fusion=20G3=20NN=20library=20with?= =?UTF-8?q?=20kernels=20related=20to=20add,=20mul,=20quantize=E2=80=A6=20(?= =?UTF-8?q?#6738)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Added Fusion G3 NN library with kernels related to add, mul, quantize, dequantize, cat, layer norm, softmax to backends/cadence folder. 
Added operators to backends/cadence folder * Updated name space of the operators by appending cadence Signed-off-by: cmadhira@cadence.com * Added nnlib-FusionG3 submodule from FOSS-xtensa git space Signed-off-by: cmadhira@cadence.com * Resolved Linter errors Signed-off-by: cmadhira@cadence.com --------- Signed-off-by: cmadhira@cadence.com Co-authored-by: cmadhira@cadence.com --- .gitmodules | 3 + backends/cadence/CMakeLists.txt | 7 +- backends/cadence/aot/functions_fusion_g3.yaml | 118 +++ .../fusion_g3/operators/CMakeLists.txt | 85 ++ .../cadence/fusion_g3/operators/op_add.cpp | 257 ++++++ .../cadence/fusion_g3/operators/op_cat.cpp | 202 +++++ .../fusion_g3/operators/op_dequantize.cpp | 810 ++++++++++++++++++ .../cadence/fusion_g3/operators/op_mul.cpp | 214 +++++ .../operators/op_native_layer_norm.cpp | 258 ++++++ .../fusion_g3/operators/op_quantize.cpp | 797 +++++++++++++++++ .../fusion_g3/operators/op_softmax.cpp | 118 +++ .../third-party/nnlib/CMakeLists.txt | 19 + .../third-party/nnlib/nnlib-FusionG3 | 1 + 13 files changed, 2888 insertions(+), 1 deletion(-) create mode 100644 backends/cadence/aot/functions_fusion_g3.yaml create mode 100644 backends/cadence/fusion_g3/operators/CMakeLists.txt create mode 100644 backends/cadence/fusion_g3/operators/op_add.cpp create mode 100644 backends/cadence/fusion_g3/operators/op_cat.cpp create mode 100644 backends/cadence/fusion_g3/operators/op_dequantize.cpp create mode 100644 backends/cadence/fusion_g3/operators/op_mul.cpp create mode 100644 backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp create mode 100644 backends/cadence/fusion_g3/operators/op_quantize.cpp create mode 100644 backends/cadence/fusion_g3/operators/op_softmax.cpp create mode 100644 backends/cadence/fusion_g3/third-party/nnlib/CMakeLists.txt create mode 160000 backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3 diff --git a/.gitmodules b/.gitmodules index d1ab8b9aa7..58f2133ed6 100644 --- a/.gitmodules +++ b/.gitmodules @@ -64,6 +64,9 @@ [submodule "third-party/pybind11"] path = third-party/pybind11 url = https://github.com/pybind/pybind11.git +[submodule "backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3"] + path = backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3 + url = https://github.com/foss-xtensa/nnlib-FusionG3/ [submodule "third-party/ao"] path = third-party/ao url = https://github.com/pytorch/ao.git diff --git a/backends/cadence/CMakeLists.txt b/backends/cadence/CMakeLists.txt index 3c1aa2945a..3cd880622c 100644 --- a/backends/cadence/CMakeLists.txt +++ b/backends/cadence/CMakeLists.txt @@ -76,7 +76,12 @@ endif() if(EXECUTORCH_NNLIB_OPT) set(TARGET_DIR hifi) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/third-party/nnlib) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/kernels) +endif() + +if(EXECUTORCH_FUSION_G3_OPT) + set(TARGET_DIR fusion_g3) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/third-party/nnlib) endif() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/operators) -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/kernels) diff --git a/backends/cadence/aot/functions_fusion_g3.yaml b/backends/cadence/aot/functions_fusion_g3.yaml new file mode 100644 index 0000000000..2c162e1444 --- /dev/null +++ b/backends/cadence/aot/functions_fusion_g3.yaml @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This yaml file contains operators that are also defined by the ATen library. 
+# For lean mode: +# - Codegen'd target `executorch_generated_lib` will be reading all the information +# from this file, including operator schema and kernel metadata. +# - Selective build target `codegen:executorch_defined_ops` now is selecting all the +# operators in this file, by dumping all the op names into `selected_operators.yaml`. +# +# See the README.md file in executorch/kernels/portable for a description of the syntax used +# by this file. + + +# aten ops +- op: _to_copy.out + kernels: + - arg_meta: null + kernel_name: torch::executor::to_copy_out + +- op: _softmax.out + kernels: + - arg_meta: null + kernel_name: cadence::impl::G3::softmax_out + +- op: add.out + kernels: + - arg_meta: null + kernel_name: cadence::impl::G3::add_out + +- op: add.Scalar_out + kernels: + - arg_meta: null + kernel_name: cadence::impl::G3::add_scalar_out + +- op: bmm.out + kernels: + - arg_meta: null + kernel_name: torch::executor::bmm_out + +- op: cat.out + kernels: + - arg_meta: null + kernel_name: cadence::impl::G3::cat_out + +- op: clone.out + kernels: + - arg_meta: null + kernel_name: torch::executor::clone_out + +- op: div.out + kernels: + - arg_meta: null + kernel_name: torch::executor::div_out + +- op: div.out_mode + kernels: + - arg_meta: null + kernel_name: torch::executor::div_out_mode + +- op: embedding.out + kernels: + - arg_meta: null + kernel_name: torch::executor::embedding_out + +- op: full.out + kernels: + - arg_meta: null + kernel_name: torch::executor::full_out + +- op: mul.out + kernels: + - arg_meta: null + kernel_name: cadence::impl::G3::mul_out + +- op: mul.Scalar_out + kernels: + - arg_meta: null + kernel_name: cadence::impl::G3::mul_scalar_out + +- op: permute_copy.out + kernels: + - arg_meta: null + kernel_name: torch::executor::permute_copy_out + +- op: sigmoid.out + kernels: + - arg_meta: null + kernel_name: torch::executor::sigmoid_out + +- op: slice_copy.Tensor_out + kernels: + - arg_meta: null + kernel_name: torch::executor::slice_copy_Tensor_out + +- op: split_with_sizes_copy.out + kernels: + - arg_meta: null + kernel_name: torch::executor::split_with_sizes_copy_out + +- op: sub.out + kernels: + - arg_meta: null + kernel_name: torch::executor::sub_out + +- op: view_copy.out + kernels: + - arg_meta: null + kernel_name: torch::executor::view_copy_out + +- op: where.self_out + kernels: + - arg_meta: null + kernel_name: torch::executor::where_out + +- op: native_layer_norm.out + kernels: + - arg_meta: null + kernel_name: cadence::impl::G3::native_layer_norm_out \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/CMakeLists.txt b/backends/cadence/fusion_g3/operators/CMakeLists.txt new file mode 100644 index 0000000000..704b4aa741 --- /dev/null +++ b/backends/cadence/fusion_g3/operators/CMakeLists.txt @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +if(NOT CMAKE_CXX_STANDARD) + set(CMAKE_CXX_STANDARD 17) +endif() + +include(${EXECUTORCH_ROOT}/build/Utils.cmake) +include(${EXECUTORCH_ROOT}/build/Codegen.cmake) + +if(NOT PYTHON_EXECUTABLE) + resolve_python_executable() +endif() + +# ATen compliant ops that are needed to run this model. 
+set(_aten_ops__srcs + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/activation_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/copy_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/broadcast_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/index_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/kernel_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/matmul_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/reduce_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/slice_util.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_add.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_mul.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_cat.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_softmax.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_native_layer_norm.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_quantize.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_dequantize.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_embedding.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_full.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_permute_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sigmoid.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_slice_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_split_with_sizes_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sub.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_to_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_view_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_where.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/dtype_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/normalization_ops_util.cpp" +) +add_library(aten_ops_cadence ${_aten_ops__srcs}) +target_link_libraries(aten_ops_cadence PUBLIC executorch) +target_link_libraries(aten_ops_cadence PRIVATE xa_nnlib) + +# Let files say "include ". +set(_common_include_directories ${EXECUTORCH_ROOT}/..) + +target_include_directories( + aten_ops_cadence PUBLIC ${ROOT_DIR}/.. ${CMAKE_BINARY_DIR} + ${_common_include_directories} + ${EXECUTORCH_ROOT}/backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3/xa_nnlib/algo/common/include/ + ${EXECUTORCH_ROOT}/backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3/xa_nnlib/include/nnlib + ${EXECUTORCH_ROOT}/backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3/xa_nnlib/include + ${EXECUTORCH_ROOT}/backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3/xa_nnlib/algo/kernels/tables/include +) + +# Generate C++ bindings to register kernels into both PyTorch (for AOT) and +# Executorch (for runtime). 
Here select all ops in functions.yaml +gen_selected_ops( + LIB_NAME "cadence_ops_lib" OPS_SCHEMA_YAML + "${CMAKE_CURRENT_LIST_DIR}/../../aot/functions_fusion_g3.yaml" "" "" +) +generate_bindings_for_kernels( + LIB_NAME "cadence_ops_lib" OPS_SCHEMA_YAML FUNCTIONS_YAML + ${CMAKE_CURRENT_SOURCE_DIR}/../../aot/functions_fusion_g3.yaml +) +message("Generated files ${gen_command_sources}") + +gen_operators_lib( + LIB_NAME "cadence_ops_lib" KERNEL_LIBS DEPS aten_ops_cadence +) diff --git a/backends/cadence/fusion_g3/operators/op_add.cpp b/backends/cadence/fusion_g3/operators/op_add.cpp new file mode 100644 index 0000000000..6dc710ce6e --- /dev/null +++ b/backends/cadence/fusion_g3/operators/op_add.cpp @@ -0,0 +1,257 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +using exec_aten::Scalar; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using executorch::runtime::canCast; +using torch::executor::Error; +using torch::executor::KernelRuntimeContext; + +namespace cadence { +namespace impl { +namespace G3 { +namespace native { + +Tensor& add_out( + KernelRuntimeContext& ctx, + const Tensor& a, + const Tensor& b, + const Scalar& alpha, + Tensor& out) { + // Common Dtype + ScalarType common_type = + executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type()); + + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, + (canCast(common_type, out.scalar_type()) && + torch::executor::check_alpha_type( + torch::executor::native::utils::get_scalar_dtype(alpha), + common_type)), + InvalidArgument, + out); + + // Check Dim Order + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(a, b, out), + InvalidArgument, + out); + + // Resize + ET_KERNEL_CHECK( + ctx, + torch::executor::resize_to_broadcast_target_size(a, b, out) == Error::Ok, + InvalidArgument, + out); + + // Compute Dtype + ScalarType compute_type = + torch::executor::native::utils::get_compute_type(common_type); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "add.out"; + + const exec_aten::ArrayRef a_size = a.sizes(); + const exec_aten::ArrayRef b_size = b.sizes(); + const exec_aten::ArrayRef out_size = out.sizes(); + + int kTensorDimensionLimit = 5; + + int inp1_shape[kTensorDimensionLimit]; + int inp2_shape[kTensorDimensionLimit]; + int out_shape[kTensorDimensionLimit]; + + /* input shapes and output shapes */ + for (auto i = 0; i < a_size.size(); i++) { + inp1_shape[i] = a_size[i]; + } + + for (auto i = 0; i < b_size.size(); i++) { + inp2_shape[i] = b_size[i]; + } + + for (auto i = 0; i < out_size.size(); i++) { + out_shape[i] = out_size[i]; + } + + /*find broadcast*/ + const bool a_is_broadcasted = !out.sizes().equals(a.sizes()); + const bool b_is_broadcasted = !out.sizes().equals(b.sizes()); + const bool broadcast = (a_is_broadcasted || b_is_broadcasted); + + int max_dim = a.dim() > b.dim() ? 
a.dim() : b.dim(); + + if (compute_type == ScalarType::Int) { + const int* const inp1_data = a.const_data_ptr(); + const int* const inp2_data = b.const_data_ptr(); + int* const out_data = out.mutable_data_ptr(); + + int alpha_val; + torch::executor::native::utils::extract_scalar(alpha, &alpha_val); + if (broadcast) { + xa_nn_elm_add_broadcast_5D_32x32_32( + out_data, + out_shape, + inp1_data, + inp1_shape, + inp2_data, + inp2_shape, + max_dim, + alpha_val); + } else { + xa_nn_elm_add_32x32_32( + out_data, inp1_data, inp2_data, alpha_val, out.numel()); + } + } else if (compute_type == ScalarType::Float) { + const float* const inp1_data = a.const_data_ptr(); + const float* const inp2_data = b.const_data_ptr(); + float* const out_data = out.mutable_data_ptr(); + + float alpha_val; + torch::executor::native::utils::extract_scalar(alpha, &alpha_val); + + if (broadcast) { + xa_nn_elm_add_broadcast_5D_f32xf32_f32( + out_data, + out_shape, + inp1_data, + inp1_shape, + inp2_data, + inp2_shape, + max_dim, + alpha_val); + } else { + xa_nn_elm_add_f32xf32_f32( + out_data, inp1_data, inp2_data, alpha_val, out.numel()); + } + } else { + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_alpha = + torch::executor::native::utils::scalar_to(alpha); + torch::executor::native::utils::apply_bitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name>( + [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + return val_a + val_alpha * val_b; + }, + ctx, + a, + torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16, + b, + torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16, + out, + torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16); + }); + } + + return out; +} + +Tensor& add_scalar_out( + KernelRuntimeContext& ctx, + const Tensor& a, + const Scalar& b, + const Scalar& alpha, + Tensor& out) { + // Common Dtype + ScalarType common_type = + torch::executor::native::utils::promote_type_with_scalar( + a.scalar_type(), b); + + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, + (common_type == out.scalar_type() && + torch::executor::check_alpha_type( + torch::executor::native::utils::get_scalar_dtype(alpha), + common_type)), + InvalidArgument, + out); + + // Check Dim Order + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(a, out), + InvalidArgument, + out); + + // Resize + ET_KERNEL_CHECK( + ctx, + executorch::runtime::resize_tensor(out, a.sizes()) == Error::Ok, + InvalidArgument, + out); + + // Compute Dtype + ScalarType compute_type = + torch::executor::native::utils::get_compute_type(common_type); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "add.Scalar_out"; + + if (compute_type == ScalarType::Int) { + const int* const inp1_data = a.const_data_ptr(); + int inp2_val; + torch::executor::native::utils::extract_scalar(b, &inp2_val); + + int alpha_val; + torch::executor::native::utils::extract_scalar(alpha, &alpha_val); + + int* const out_data = out.mutable_data_ptr(); + + xa_nn_elm_add_scalar_32x32_32( + out_data, inp1_data, inp2_val, alpha_val, out.numel()); + } else if (compute_type == ScalarType::Float) { + const float* const inp1_data = a.const_data_ptr(); + float inp2_val; + torch::executor::native::utils::extract_scalar(b, &inp2_val); + + float alpha_val; + torch::executor::native::utils::extract_scalar(alpha, &alpha_val); + + float* const out_data = out.mutable_data_ptr(); + + xa_nn_elm_add_scalar_f32xf32_f32( + out_data, inp1_data, 
inp2_val, alpha_val, out.numel()); + } else { + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + torch::executor::native::utils:: + apply_unitensor_elementwise_fn( + [b, alpha](const CTYPE_COMPUTE val_a) { + CTYPE_COMPUTE val_b = + torch::executor::native::utils::scalar_to(b); + CTYPE_COMPUTE val_alpha = + torch::executor::native::utils::scalar_to( + alpha); + return val_a + val_alpha * val_b; + }, + ctx, + a, + torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16, + out, + torch::executor::native::utils::SupportedTensorDtypes:: + SAME_AS_COMMON); + }); + } + return out; +} + +} // namespace native +} // namespace G3 +} // namespace impl +} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/op_cat.cpp b/backends/cadence/fusion_g3/operators/op_cat.cpp new file mode 100644 index 0000000000..62bbb0c9d4 --- /dev/null +++ b/backends/cadence/fusion_g3/operators/op_cat.cpp @@ -0,0 +1,202 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +using exec_aten::Scalar; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using torch::executor::Error; +using torch::executor::KernelRuntimeContext; + +/* ScalarType in Executorch do not have support for below data types. + * So, creating a placeholder for these data types. Once, ScalarTypes is + * updated to have support for below data types, these can be removed and + * operator need to be updated accordingly + */ +enum datatype { + Ushort = 20, + Uint = 23, +}; + +namespace cadence { +namespace impl { +namespace G3 { +namespace native { + +Tensor& cat_out( + KernelRuntimeContext& ctx, + exec_aten::ArrayRef tensors, + int64_t dim, + Tensor& out) { + if (dim < 0) { + dim += out.dim(); + } + + ET_KERNEL_CHECK( + ctx, + torch::executor::check_cat_args(tensors, dim, out), + InvalidArgument, + out); + + int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit; + Tensor::SizesType expected_out_size[kTensorDimensionLimit]; + size_t expected_out_dim = 0; + torch::executor::get_cat_out_target_size( + tensors, dim, expected_out_size, &expected_out_dim); + + ET_KERNEL_CHECK( + ctx, + executorch::runtime::resize_tensor( + out, {expected_out_size, expected_out_dim}) == Error::Ok, + InvalidArgument, + out); + + const signed char* inp_tensors[tensors.size()]; + const int* inp_tensors_shapes[tensors.size()]; + + int inp_shapes_size[tensors.size()]; + + int temp_sizes[tensors.size()][kTensorDimensionLimit]; + exec_aten::ArrayRef temp_size; + + for (int i = 0; i < tensors.size(); i++) { + inp_tensors[i] = tensors[i].const_data_ptr(); + temp_size = tensors[i].sizes(); + + for (int j = 0; j < temp_size.size(); j++) { + temp_sizes[i][j] = temp_size[j]; + } + inp_tensors_shapes[i] = temp_sizes[i]; // input shapes + inp_shapes_size[i] = temp_size.size(); // number of input dimensions + } + + signed char* out_data = out.mutable_data_ptr(); + + const exec_aten::ArrayRef out_size = out.sizes(); + int out_shapes[kTensorDimensionLimit]; + for (int i = 0; i < out_size.size(); i++) // output shapes + { + out_shapes[i] = out_size[i]; + } + + if (out.scalar_type() == ScalarType::Int) { + xa_nn_cat( + out_data, + out_shapes, + inp_tensors, + inp_tensors_shapes, + inp_shapes_size[0], + tensors.size(), + (int)dim, + sizeof(int)); + } else if (out.scalar_type() == 
ScalarType::Short) { + xa_nn_cat( + out_data, + out_shapes, + inp_tensors, + inp_tensors_shapes, + inp_shapes_size[0], + tensors.size(), + (int)dim, + sizeof(short)); + } else if (out.scalar_type() == ScalarType::Char) { + xa_nn_cat( + out_data, + out_shapes, + inp_tensors, + inp_tensors_shapes, + inp_shapes_size[0], + tensors.size(), + (int)dim, + sizeof(char)); + } + if (out.scalar_type() == (ScalarType)Uint) { + xa_nn_cat( + out_data, + out_shapes, + inp_tensors, + inp_tensors_shapes, + inp_shapes_size[0], + tensors.size(), + (int)dim, + sizeof(int)); + } else if (out.scalar_type() == (ScalarType)Ushort) { + xa_nn_cat( + out_data, + out_shapes, + inp_tensors, + inp_tensors_shapes, + inp_shapes_size[0], + tensors.size(), + (int)dim, + sizeof(short)); + } else if (out.scalar_type() == ScalarType::Byte) { + xa_nn_cat( + out_data, + out_shapes, + inp_tensors, + inp_tensors_shapes, + inp_shapes_size[0], + tensors.size(), + (int)dim, + sizeof(char)); + + } else { + // Special handling when all inputs are 1D-empty tensors for aten + // consistency In that case, just return an 1D-empty tensor without checking + // dim + bool all_1d_empty = true; + for (size_t i = 0; i < tensors.size(); ++i) { + if (tensors[i].numel() != 0 || tensors[i].dim() != 1) { + all_1d_empty = false; + break; + } + } + if (all_1d_empty) { + return out; + } + + const size_t outer = executorch::runtime::getLeadingDims(out, dim); + const size_t dim_stride = executorch::runtime::getTrailingDims(out, dim); + const size_t ninputs = tensors.size(); + + const auto out_type = out.scalar_type(); + ET_SWITCH_REALHB_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] { + CTYPE_OUT* out_ptr = out.mutable_data_ptr(); + for (size_t i = 0; i < outer; ++i) { + for (size_t j = 0; j < ninputs; ++j) { + const auto in_type = tensors[j].scalar_type(); + ET_SWITCH_REALHB_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] { + if (tensors[j].numel() == 0) { + return; + } + size_t inner = tensors[j].size(dim) * dim_stride; + const CTYPE_IN* const in_ptr = + tensors[j].const_data_ptr() + i * inner; + + for (size_t k = 0; k < inner; ++k) { + out_ptr[k] = static_cast(in_ptr[k]); + } + out_ptr += inner; + }); + } + } + }); + } + + return out; +} + +} // namespace native +} // namespace G3 +} // namespace impl +} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/op_dequantize.cpp b/backends/cadence/fusion_g3/operators/op_dequantize.cpp new file mode 100644 index 0000000000..784011332f --- /dev/null +++ b/backends/cadence/fusion_g3/operators/op_dequantize.cpp @@ -0,0 +1,810 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +using exec_aten::Scalar; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using torch::executor::Error; +using torch::executor::KernelRuntimeContext; + +template +using optional = exec_aten::optional; +/* ScalarType in Executorch do not have support for below data types. + * So, creating a placeholder for these data types. Once, ScalarTypes is + * updated to have support for below data types, these can be removed and + * operator need to be updated accordingly + */ + + enum datatype { + Ushort = 20, + Bits4u = 21, + Bits4 = 22 + }; + +/** + * For an input tensor, use the scale and zero_point arguments to quantize it. 
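The implementation below takes the asymmetric NNLib path only when at least one zero point is non-zero: it scans per channel when an axis is supplied, otherwise it checks the single per-tensor zero point. A minimal standalone sketch of that decision, with illustrative names that are not part of this patch:

#include <cstddef>

// Returns true when any supplied zero point is non-zero, i.e. when the
// asymmetric dequantization kernels are required. `zero_points` may be
// null (purely symmetric); `num_channels` is 1 for per-tensor quantization.
static bool needs_asymmetric_path(const int* zero_points, size_t num_channels) {
  if (zero_points == nullptr) {
    return false;
  }
  for (size_t c = 0; c < num_channels; ++c) {
    if (zero_points[c] != 0) {
      return true;
    }
  }
  return false;
}

Deciding this up front lets the symmetric kernels skip zero-point handling entirely.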
+ */ +namespace cadence { +namespace impl { +namespace G3 { +namespace native { + +namespace { + +/** + * Asserts that the parameters are valid. + */ +void check_dequantize_per_tensor_args(const Tensor& input, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + exec_aten::optional& out_dtype, + Tensor& out) +{ + ET_CHECK_MSG( + input.scalar_type() == ScalarType::Byte || + input.scalar_type() == ScalarType::Char || + input.scalar_type() == ScalarType::Bits16 || + input.scalar_type() == ScalarType::Short || + input.scalar_type() == (ScalarType) Ushort || + input.scalar_type() == (ScalarType) Bits4 || + input.scalar_type() == (ScalarType) Bits4u || + input.scalar_type() == ScalarType::Int, + + "input.scalar_type() %" PRId8 " is not supported:", + static_cast(input.scalar_type())); + + ET_CHECK_MSG( + input.scalar_type() == dtype, + "input.scalar_type() %" PRId8 " is not matching dtype argumenta:", + static_cast(input.scalar_type())); + + if (out_dtype.has_value()) { + ET_CHECK_MSG( + out.scalar_type() == out_dtype.value(), + "output_dtype must match the dtype of the out tensor"); + } + + ET_CHECK_MSG( + quant_min <= quant_max, + "quant min: %" PRId64 " is greater than quant max: %" PRId64, + quant_min, + quant_max); +} + +} // namespace + + +/* Local function which calls the kernels based on the input datatype */ +void Dequantize_impl(Tensor& out, + const Tensor& input, + float *scale_data, + int *zero_point_data, + int *axis, + exec_aten::optional out_dtype) +{ + const exec_aten::ArrayRef input_size = input.sizes(); + + int kTensorDimensionLimit = 5; + + int inp_shape[kTensorDimensionLimit]; + + for(auto i = 0; i < input_size.size(); i++) + { + inp_shape[i] = input_size[i]; + } + + bool is_asym_dequant = 0; + + if(zero_point_data != NULL) //asymmetric dequant + { + if(axis != NULL) //channel + { + for(int i = 0; i < input.size(*axis) ; i++) + { + if(zero_point_data[i] != 0) + { + is_asym_dequant |= 1; + } + } + } + else + { + if(*zero_point_data != 0) //tesor + { + is_asym_dequant |= 1; + } + } + } + float* out_data = out.mutable_data_ptr(); + + if(is_asym_dequant) + { + if (input.scalar_type() == ScalarType::Byte) + { + const uint8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym8u_f32( + out_data, input_data, inp_shape, input.dim(), axis, + zero_point_data, scale_data); + } + else if (input.scalar_type() == ScalarType::Char) + { + const int8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym8_f32( + out_data, input_data, inp_shape, input.dim(), axis, + zero_point_data, scale_data); + } + else if (input.scalar_type() == (ScalarType) Ushort) + { + const uint16_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym16u_f32( + out_data, input_data, inp_shape, input.dim(), axis, + zero_point_data, scale_data); + } + else if (input.scalar_type() == ScalarType::Short) + { + const int16_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym16_f32( + out_data, input_data, inp_shape, input.dim(), axis, + zero_point_data, scale_data); + } + else if (input.scalar_type() == (ScalarType) Bits4u) + { + const uint8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym4u_f32( + out_data, input_data, inp_shape, input.dim(), axis, + zero_point_data, scale_data); + } + else if (input.scalar_type() == (ScalarType) Bits4) + { + const int8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym4_f32( + out_data, input_data, inp_shape, input.dim(), axis, + zero_point_data, scale_data); + } + else + { + if(axis == 
NULL) + { + // calculate the dequantized output, cast scale to float to match fbgemm + // behavior + #define ASYM_DEQUANTIZE_IMPL_TESNOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ + case ScalarType::out_dtype: { \ + /* Hoist these function calls out of our inner loop because they might not \ + * get inlined without LTO, particularly in ATen mode. */ \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + const auto input_numel = input.numel(); \ + for (size_t i = 0; i < input_numel; i++) { \ + out_data_ptr[i] = static_cast( \ + (input_data_ptr[i] - static_cast(*zero_point_data)) * \ + static_cast(*scale_data)); \ + } \ + } break; + #define ASYM_CALCULATE_INT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, ASYM_DEQUANTIZE_IMPL_TESNOR); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + switch (input.scalar_type()) { + ET_FORALL_INT_TYPES(ASYM_CALCULATE_INT_TYPE_TENSOR); + ASYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, Bits16); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); + } + #undef ASYM_CALCULATE_INT_TYPE_TENSOR + #undef ASYM_DEQUANTIZE_IMPL_TESNOR + } + else + { + // a list contains all dimensions except axis + int64_t dims[input.dim() - 1]; + for (int64_t i = 0; i < input.dim() - 1; i++) + { + if (i < *axis) + { + dims[i] = i; + } + else + { + dims[i] = i + 1; + } + } + + exec_aten::optional> optional_dim_list{ + exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; + + // Actual dequantization logic + // input, out are the input and output tensors + // channel_ix is the index along the axis dimension. 0 <= channel_ix < + // input.size(axis). + // i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix + // will be 0, 1, 2, ... C-1 + // in_ix is the flat index of the element you are dequantizing. 
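To make the channel_ix / in_ix bookkeeping described here concrete, a small reference loop (illustrative only, not the kernel this patch dispatches to) that dequantizes a contiguous NCHW tensor per channel along axis 1 could look like:

#include <cstddef>
#include <cstdint>

// Reference per-channel dequantization for a contiguous NCHW int8 input:
// every element whose channel index is c uses scale[c] and zero_point[c],
// i.e. out[in_ix] = (in[in_ix] - zero_point[c]) * scale[c].
static void dequantize_per_channel_nchw(
    const int8_t* in, float* out,
    const float* scale, const int* zero_point,
    size_t N, size_t C, size_t H, size_t W) {
  const size_t inner = H * W;            // elements sharing one channel per batch
  for (size_t n = 0; n < N; ++n) {
    for (size_t c = 0; c < C; ++c) {     // the channel_ix of the surrounding code
      const size_t base = (n * C + c) * inner;
      for (size_t k = 0; k < inner; ++k) {
        const size_t in_ix = base + k;   // flat index of the element being dequantized
        out[in_ix] =
            (static_cast<float>(in[in_ix]) - zero_point[c]) * scale[c];
      }
    }
  }
}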
+ // in other words you are dequantizing in_data[in_ix] + #define ASYM_DEQUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ + case ScalarType::out_dtype: \ + if (input.dim() == 1) { \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + ET_CHECK_MSG( \ + *axis == 0, "Axis must be 0 for a single dimensional tensors"); \ + const optional dim; \ + torch::executor::apply_over_dim( \ + [input_data_ptr, out_data_ptr, zero_point_data, scale_data]( \ + size_t numel, size_t stride, size_t base_ix) { \ + for (size_t i = 0; i < numel; i++) { \ + size_t current_ix = base_ix * stride + i; \ + float _scale = scale_data[current_ix]; \ + int64_t zero_point = 0; \ + if (zero_point_data != nullptr) { \ + zero_point = zero_point_data[current_ix]; \ + } \ + out_data_ptr[current_ix] = \ + static_cast( \ + input_data_ptr[current_ix] - zero_point) * \ + _scale; \ + } \ + }, \ + input, \ + dim); \ + break; \ + } \ + for (size_t channel_ix = 0; channel_ix < input.size(*axis); ++channel_ix) { \ + float _scale = scale_data[channel_ix]; \ + int64_t _zero_point = 0; \ + if (zero_point_data != nullptr) { \ + _zero_point = zero_point_data[channel_ix]; \ + } \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + torch::executor::apply_over_dim_list( \ + [input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \ + out_data_ptr[in_ix] = static_cast( \ + (input_data_ptr[in_ix] - _zero_point) * _scale); \ + }, \ + input, \ + optional_dim_list, \ + channel_ix); \ + } \ + break; + #define ASYM_CALCULATE_INT_TYPE_CHANNEL(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, ASYM_DEQUANTIZE_IMPL_CHANNEL); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + switch (input.scalar_type()) { + ET_FORALL_INT_TYPES(ASYM_CALCULATE_INT_TYPE_CHANNEL); + ASYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, Bits16); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); + } + #undef ASYM_CALCULATE_INT_TYPE_CHANNEL + #undef ASYM_DEQUANTIZE_IMPL_CHANNEL + } + } + } + else + { + if (input.scalar_type() == ScalarType::Byte) + { + const uint8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym8u_f32( + out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } + else if (input.scalar_type() == ScalarType::Char) + { + const int8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym8_f32( + out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } + else if (input.scalar_type() == (ScalarType) Ushort) + { + const uint16_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym16u_f32( + out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } + else if (input.scalar_type() == ScalarType::Short) + { + const int16_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym16_f32( + out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } + else if (input.scalar_type() == (ScalarType) Bits4u) + { + const uint8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym4u_f32( + out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } + else if (input.scalar_type() == (ScalarType) Bits4) + { + const int8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym4_f32( + out_data, input_data, inp_shape, input.dim(), axis, 
scale_data); + } + else + { + if(axis == NULL) + { + // calculate the dequantized output, cast scale to float to match fbgemm + // behavior + #define SYM_DEQUANTIZE_IMPL_TESNOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ + case ScalarType::out_dtype: { \ + /* Hoist these function calls out of our inner loop because they might not \ + * get inlined without LTO, particularly in ATen mode. */ \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + const auto input_numel = input.numel(); \ + for (size_t i = 0; i < input_numel; i++) { \ + out_data_ptr[i] = static_cast( \ + (input_data_ptr[i] - static_cast(*zero_point_data)) * \ + static_cast(*scale_data)); \ + } \ + } break; + #define SYM_CALCULATE_INT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, SYM_DEQUANTIZE_IMPL_TESNOR); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + switch (input.scalar_type()) { + ET_FORALL_INT_TYPES(SYM_CALCULATE_INT_TYPE_TENSOR); + SYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, Bits16); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); + } + #undef SYM_DEQUANTIZE_IMPL_TESNOR + #undef SYM_CALCULATE_INT_TYPE_TENSOR + } + else + { + // a list contains all dimensions except axis + int64_t dims[input.dim() - 1]; + for (int64_t i = 0; i < input.dim() - 1; i++) + { + if (i < *axis) + { + dims[i] = i; + } + else + { + dims[i] = i + 1; + } + } + + exec_aten::optional> optional_dim_list{ + exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; + + // Actual dequantization logic + // input, out are the input and output tensors + // channel_ix is the index along the axis dimension. 0 <= channel_ix < + // input.size(axis). + // i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix + // will be 0, 1, 2, ... C-1 + // in_ix is the flat index of the element you are dequantizing. 
+ // in other words you are dequantizing in_data[in_ix] + #define SYM_DEQUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ + case ScalarType::out_dtype: \ + if (input.dim() == 1) { \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + ET_CHECK_MSG( \ + *axis == 0, "Axis must be 0 for a single dimensional tensors"); \ + const optional dim; \ + torch::executor::apply_over_dim( \ + [input_data_ptr, out_data_ptr, zero_point_data, scale_data]( \ + size_t numel, size_t stride, size_t base_ix) { \ + for (size_t i = 0; i < numel; i++) { \ + size_t current_ix = base_ix * stride + i; \ + float _scale = scale_data[current_ix]; \ + int64_t zero_point = 0; \ + if (zero_point_data != nullptr) { \ + zero_point = zero_point_data[current_ix]; \ + } \ + out_data_ptr[current_ix] = \ + static_cast( \ + input_data_ptr[current_ix] - zero_point) * \ + _scale; \ + } \ + }, \ + input, \ + dim); \ + break; \ + } \ + for (size_t channel_ix = 0; channel_ix < input.size(*axis); ++channel_ix) { \ + float _scale = scale_data[channel_ix]; \ + int64_t _zero_point = 0; \ + if (zero_point_data != nullptr) { \ + _zero_point = zero_point_data[channel_ix]; \ + } \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + torch::executor::apply_over_dim_list( \ + [input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \ + out_data_ptr[in_ix] = static_cast( \ + (input_data_ptr[in_ix] - _zero_point) * _scale); \ + }, \ + input, \ + optional_dim_list, \ + channel_ix); \ + } \ + break; + #define SYM_CALCULATE_INT_TYPE_CHANNEL(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, SYM_DEQUANTIZE_IMPL_CHANNEL); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + switch (input.scalar_type()) { + ET_FORALL_INT_TYPES(SYM_CALCULATE_INT_TYPE_CHANNEL); + SYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, Bits16); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); + } + #undef SYM_DEQUANTIZE_IMPL_CHANNEL + #undef SYM_CALCULATE_INT_TYPE_CHANNEL + } + } + } +} + +/** + * Dequantizes the input tensor according to the formula (input - zero_point) * + * scale + * + * NOTE: quant_min and quant_max are not used in computation, but rather + * metadata that is passed around which can be useful for pattern matching. See + * https://github.com/pytorch/pytorch/pull/87093#discussion_r1000841181 for more + * info. 
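As a quick worked example of the (input - zero_point) * scale formula documented here (a standalone sketch; the values are chosen only for illustration):

#include <cassert>
#include <cstdint>

int main() {
  // An int8 value quantized with scale = 0.25 and zero_point = -10
  // dequantizes back to (q - zero_point) * scale.
  const int8_t q = 30;
  const float scale = 0.25f;
  const int zero_point = -10;
  // quant_min / quant_max are metadata only and do not enter this computation.
  const float real = (static_cast<float>(q) - zero_point) * scale;  // (30 - (-10)) * 0.25
  assert(real == 10.0f);
  return 0;
}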
+ */ +Tensor& dequantize_per_tensor_out( + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + exec_aten::optional out_dtype, + Tensor& out) +{ + torch::executor::Error err = resize_tensor(out, input.sizes()); + ET_CHECK_MSG( + err == torch::executor::Error::Ok, + "Failed to resize out Tensor in dequantize_per_tensor_out"); + + check_dequantize_per_tensor_args( + input, quant_min, quant_max, dtype, out_dtype, out); + + float scale_data = (float)scale; + int zero_point_data = (int)zero_point; + + Dequantize_impl(out, + input, + &scale_data, + &zero_point_data, + NULL, + out_dtype); + + return out; +} + +Tensor& dequantize_per_tensor_tensor_args_out(const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + exec_aten::optional out_dtype, + Tensor& out) +{ + ET_CHECK_MSG( + scale.scalar_type() == ScalarType::Double, + "Expected scale to be Double tensor received: %" PRId8, + static_cast(scale.scalar_type())); + ET_CHECK_MSG( + zero_point.scalar_type() == ScalarType::Long, + "Expected scale to be Long tensor received: %" PRId8, + static_cast(zero_point.scalar_type())); + ET_CHECK_MSG( + scale.numel() == 1, + "Exepcted scale to only have one element received: %zd", + ssize_t(scale.numel())); + ET_CHECK_MSG( + zero_point.numel() == 1, + "Exepcted zero_point to only have one element received: %zd", + ssize_t(zero_point.numel())); + + dequantize_per_tensor_out( + input, + scale.const_data_ptr()[0], + zero_point.const_data_ptr()[0], + quant_min, + quant_max, + dtype, + out_dtype, + out); + + return out; +} + +Tensor& dequantize_per_channel_out(const Tensor& input, + const Tensor& scale, + const exec_aten::optional& opt_zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + exec_aten::optional out_dtype, + Tensor& out) +{ + torch::executor::Error err = resize_tensor(out, input.sizes()); + + // normalize axis + ET_CHECK_MSG( + executorch::runtime::tensor_has_dim(input, axis), + "axis %zd is not legal it should be -input.dim() <= axis < input.dim() %zd", + ssize_t(axis), + ssize_t(input.dim())); + + if (axis < 0) + { + axis += executorch::runtime::nonzero_dim(input); + } + + ET_CHECK_MSG( + err == torch::executor::Error::Ok, + "Failed to resize out Tensor in dequantize_per_channel_out"); + + ET_CHECK_MSG( + scale.scalar_type() == ScalarType::Double, + "scale.scalar_type() %" PRId8 " is not double type", + static_cast(scale.scalar_type())); + + ET_CHECK_MSG( + scale.numel() == input.size(axis), + "scale.numel() %zd != input.size(axis) %zd", + ssize_t(scale.numel()), + ssize_t(input.size(axis))); + + if (opt_zero_points.has_value()) { + auto zero_point = opt_zero_points.value(); + ET_CHECK_MSG( + zero_point.scalar_type() == ScalarType::Long, + "zero_point.scalar_type() %" PRId8 " is not integer type", + static_cast(zero_point.scalar_type())); + + ET_CHECK_MSG( + zero_point.numel() == input.size(axis), + "zero_point.numel() %zd != input.size(axis) %zd", + ssize_t(zero_point.numel()), + ssize_t(input.size(axis))); + } + + check_dequantize_per_tensor_args( + input, quant_min, quant_max, dtype, out_dtype, out); + + int *axis_ptr = (int *)&axis; + + const double* scale_dt = scale.const_data_ptr(); + const int64_t* zero_point_dt; + int zero_point_data[input.size(axis)]; + int *zero_point_ptr; + if (opt_zero_points.has_value()) + { + zero_point_dt = opt_zero_points.value().const_data_ptr(); + zero_point_ptr = 
&zero_point_data[0]; + for(int i = 0; i < scale.numel(); i++) + { + zero_point_ptr[i] = (int)zero_point_dt[i]; + } + } + else + { + zero_point_ptr = nullptr; + } + float scale_data[input.size(axis)]; + for(int i = 0; i < scale.numel(); i++) + { + scale_data[i] = (float)scale_dt[i]; + } + Dequantize_impl(out, + input, + scale_data, + zero_point_ptr, + axis_ptr, + out_dtype); + + return out; +} + +Tensor& dequantize_per_channel_out( + KernelRuntimeContext& context, + const Tensor& input, + const Tensor& scale, + const exec_aten::optional& opt_zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + exec_aten::optional out_dtype, + Tensor& out) +{ + (void)context; + torch::executor::Error err = resize_tensor(out, input.sizes()); + ET_CHECK_MSG( + err == torch::executor::Error::Ok, + "Failed to resize out Tensor in dequantize_per_channel_out"); + return dequantize_per_channel_out( + input, + scale, + opt_zero_points, + axis, + quant_min, + quant_max, + dtype, + out_dtype, + out); +} + +Tensor& dequantize_per_tensor_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + exec_aten::optional out_dtype, + Tensor& out) +{ + // TODO(larryliu): Add a context arg to the real op function and remove this + // wrapper + (void)context; + return dequantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); +} + +Tensor& dequantize_per_tensor_tensor_args_out( + KernelRuntimeContext& context, + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + exec_aten::optional out_dtype, + Tensor& out) +{ + // TODO(larryliu): Add a context arg to the real op function and remove this + // wrapper + (void)context; + return dequantize_per_tensor_tensor_args_out( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); +} + +Tensor& dequantize_per_token_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + ScalarType out_dtype, + Tensor& out) +{ + // Refactor this into a util + size_t num_channels = 1; + for (size_t i = 0; i < input.dim() - 1; i++) + { + num_channels *= input.size(i); + } + // This unfortunate change is needed because we compile op_quantize for aten + // mode as well + std::array input_sizes; + input_sizes[0] = static_cast(num_channels); + input_sizes[1] = + static_cast(input.size(input.dim() - 1)); +#ifdef USE_ATEN_LIB + Tensor reshaped_input = at::from_blob( + input.mutable_data_ptr(), + input_sizes, + at::TensorOptions(input.scalar_type())); +#else + std::array input_dim_order{0, 1}; + std::array input_strides; + executorch::runtime::dim_order_to_stride_nocheck( + input_sizes.data(), input_dim_order.data(), 2, input_strides.data()); + void* input_data = input.mutable_data_ptr(); + torch::executor::TensorImpl reshaped_input_impl = executorch::runtime::etensor::TensorImpl( + input.scalar_type(), + 2, + input_sizes.data(), + input_data, + input_dim_order.data(), + input_strides.data(), + executorch::runtime::TensorShapeDynamism::STATIC); + Tensor reshaped_input(&reshaped_input_impl); + torch::executor::Error err = resize_tensor(out, input.sizes()); + ET_CHECK_MSG( + err == torch::executor::Error::Ok, + "Failed to resize out Tensor in dequantize_per_channel_out"); +#endif + + return dequantize_per_channel_out( + reshaped_input, + scale, + 
zero_points, + 0, /* axis */ + quant_min, + quant_max, + dtype, + out_dtype, + out); +} + +Tensor& dequantize_per_token_out( + KernelRuntimeContext& context, + const Tensor& input, + const Tensor& scale, + const Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + ScalarType out_dtype, + Tensor& out) +{ + (void)context; + return dequantize_per_token_out( + input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out); +} + +} // namespace native +} // namespace G3 +} // namespace impl +} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/op_mul.cpp b/backends/cadence/fusion_g3/operators/op_mul.cpp new file mode 100644 index 0000000000..366982ae3f --- /dev/null +++ b/backends/cadence/fusion_g3/operators/op_mul.cpp @@ -0,0 +1,214 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +using exec_aten::Scalar; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using executorch::runtime::canCast; +using torch::executor::Error; +using torch::executor::KernelRuntimeContext; + +namespace cadence { +namespace impl { +namespace G3 { +namespace native { + +Tensor& mul_out( + KernelRuntimeContext& ctx, + const Tensor& a, + const Tensor& b, + Tensor& out) { + // Common Dtype + ScalarType common_type = + executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type()); + + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out); + + // Check Dim Order + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(a, b, out), + InvalidArgument, + out); + + // Resize + ET_KERNEL_CHECK( + ctx, + torch::executor::resize_to_broadcast_target_size(a, b, out) == Error::Ok, + InvalidArgument, + out); + + // Compute Dtype + ScalarType compute_type = + torch::executor::native::utils::get_compute_type(common_type); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "mul.out"; + + const exec_aten::ArrayRef a_size = a.sizes(); + const exec_aten::ArrayRef b_size = b.sizes(); + const exec_aten::ArrayRef out_size = out.sizes(); + + int kTensorDimensionLimit = 5; + + int inp1_shape[kTensorDimensionLimit]; + int inp2_shape[kTensorDimensionLimit]; + int out_shape[kTensorDimensionLimit]; + + /* input shapes and output shapes */ + for (auto i = 0; i < a_size.size(); i++) { + inp1_shape[i] = a_size[i]; + } + + for (auto i = 0; i < b_size.size(); i++) { + inp2_shape[i] = b_size[i]; + } + + for (auto i = 0; i < out_size.size(); i++) { + out_shape[i] = out_size[i]; + } + + /*find broadcast*/ + const bool a_is_broadcasted = !out.sizes().equals(a.sizes()); + const bool b_is_broadcasted = !out.sizes().equals(b.sizes()); + const bool broadcast = (a_is_broadcasted || b_is_broadcasted); + + int max_dim = a.dim() > b.dim() ? 
a.dim() : b.dim(); + + if (compute_type == ScalarType::Int) { + const int* const inp1_data = a.const_data_ptr(); + const int* const inp2_data = b.const_data_ptr(); + int* const out_data = out.mutable_data_ptr(); + + if (broadcast) { + xa_nn_elm_mul_broadcast_5D_32x32_32( + out_data, + out_shape, + inp1_data, + inp1_shape, + inp2_data, + inp2_shape, + max_dim); + } else { + xa_nn_elm_mul_32x32_32(out_data, inp1_data, inp2_data, out.numel()); + } + } else if (compute_type == ScalarType::Float) { + const float* const inp1_data = a.const_data_ptr(); + const float* const inp2_data = b.const_data_ptr(); + float* const out_data = out.mutable_data_ptr(); + + if (broadcast) { + xa_nn_elm_mul_broadcast_5D_f32xf32_f32( + out_data, + out_shape, + inp1_data, + inp1_shape, + inp2_data, + inp2_shape, + max_dim); + } else { + xa_nn_elm_mul_f32xf32_f32(out_data, inp1_data, inp2_data, out.numel()); + } + } else { + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + torch::executor::native::utils::apply_bitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name>( + [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + return val_a * val_b; + }, + ctx, + a, + torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16, + b, + torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16, + out, + torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16); + }); + } + + return out; +} + +Tensor& mul_scalar_out( + KernelRuntimeContext& ctx, + const Tensor& a, + const Scalar& b, + Tensor& out) { + // Common Dtype + ScalarType common_type = + torch::executor::native::utils::promote_type_with_scalar( + a.scalar_type(), b); + + // Check Common Dtype + ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out); + + // Check Dim Order + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(a, out), + InvalidArgument, + out); + + // Resize + ET_KERNEL_CHECK( + ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out); + + // Compute Dtype + ScalarType compute_type = + torch::executor::native::utils::get_compute_type(common_type); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "mul.Scalar_out"; + + if (compute_type == ScalarType::Int) { + const int* const inp1_data = a.const_data_ptr(); + int inp2_val; + torch::executor::native::utils::extract_scalar(b, &inp2_val); + int* const out_data = out.mutable_data_ptr(); + + xa_nn_elm_mul_scalar_32x32_32(out_data, inp1_data, inp2_val, out.numel()); + } else if (compute_type == ScalarType::Float) { + const float* const inp1_data = a.const_data_ptr(); + float inp2_val; + torch::executor::native::utils::extract_scalar(b, &inp2_val); + float* const out_data = out.mutable_data_ptr(); + + xa_nn_elm_mul_scalar_f32xf32_f32( + out_data, inp1_data, inp2_val, out.numel()); + } else { + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_b = + torch::executor::native::utils::scalar_to(b); + torch::executor::native::utils:: + apply_unitensor_elementwise_fn( + [val_b](const CTYPE_COMPUTE val_a) { return val_a * val_b; }, + ctx, + a, + torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16, + out, + torch::executor::native::utils::SupportedTensorDtypes:: + SAME_AS_COMMON); + }); + } + + return out; +} + +} // namespace native +} // namespace G3 +} // namespace impl +} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp 
b/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp new file mode 100644 index 0000000000..68d111795c --- /dev/null +++ b/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp @@ -0,0 +1,258 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +using Tensor = exec_aten::Tensor; +using ScalarType = exec_aten::ScalarType; +using IntArrayRef = exec_aten::ArrayRef; +using torch::executor::Error; +using torch::executor::KernelRuntimeContext; + +namespace cadence { +namespace impl { +namespace G3 { +namespace native { + +namespace { + +template +void layer_norm( + const Tensor& input, + IntArrayRef normalized_shape, + const exec_aten::optional& weight, + const exec_aten::optional& bias, + CTYPE eps, + Tensor& out, + Tensor& mean, + Tensor& rstd) { + size_t dim = input.dim() - normalized_shape.size(); + size_t dim_size = input.size(dim); + + size_t leading = executorch::runtime::getLeadingDims(input, dim); + size_t normalized = + executorch::runtime::getTrailingDims(input, dim) * dim_size; + + if (leading == 0) { + return; + } + + CTYPE* out_data = out.mutable_data_ptr(); + CTYPE* mean_data = mean.mutable_data_ptr(); + CTYPE* rstd_data = rstd.mutable_data_ptr(); + + if (normalized == 0) { + for (int i = 0; i < leading; ++i) { + mean_data[i] = static_cast(0); + rstd_data[i] = static_cast(NAN); + } + return; + } + + const CTYPE* input_data = input.const_data_ptr(); + const CTYPE* weight_data; + if (weight.has_value()) { + weight_data = weight.value().const_data_ptr(); + } else { + weight_data = nullptr; + } + const CTYPE* bias_data; + if (bias.has_value()) { + bias_data = bias.value().const_data_ptr(); + } else { + bias_data = nullptr; + } + + for (int i = 0; i < leading; ++i) { + const CTYPE* x = input_data + i * normalized; + CTYPE* y = out_data + i * normalized; + + // compute E[X] and Var[x] = E[x^2] - E[x]^2 + CTYPE sum = torch::executor::reduce_add(x, normalized); + CTYPE sq_sum = torch::executor::vec_powerf(x, normalized); + CTYPE mean_value = sum / normalized; + CTYPE variance = sq_sum / normalized - mean_value * mean_value; + CTYPE std = std::sqrt(variance + eps); + + // Calculate the elements of output + for (int j = 0; j < normalized; ++j) { + CTYPE w = weight_data ? weight_data[j] : static_cast(1); + CTYPE b = bias_data ? bias_data[j] : static_cast(0); + y[j] = (x[j] - mean_value) / std * w + b; + } + + mean_data[i] = mean_value; + rstd_data[i] = 1.0 / std; + } +} + +} // namespace + +// native_layer_norm.out(Tensor input, int[] normalized_shape, Tensor? weight, +// Tensor? bias, float eps, *, Tensor(a!) out, Tensor(b!) mean_out, Tensor(c!) 
+// rstd_out) -> (Tensor(a!), Tensor(b!), Tensor(c!)) +// As a reference, there's math_native_layer_norm in ATen: +// https://www.internalfb.com/code/fbsource/[2da5b17b086554c6cd0c3ab08a35aeec2a8bad8c]/xplat/caffe2/aten/src/ATen/native/layer_norm.cpp?lines=188 +std::tuple native_layer_norm_out( + KernelRuntimeContext& ctx, + const Tensor& input, + IntArrayRef normalized_shape, + const exec_aten::optional& weight, + const exec_aten::optional& bias, + double eps, + Tensor& out, + Tensor& mean_out, + Tensor& rstd_out) { + (void)ctx; + + std::tuple ret_val(out, mean_out, rstd_out); + + ET_KERNEL_CHECK( + ctx, + torch::executor::check_layer_norm_args( + input, normalized_shape, weight, bias, out, mean_out, rstd_out), + InvalidArgument, + ret_val); + + // Only support default dim order for now. + // TODO: Support other dim orders. + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensor_is_default_dim_order(input), + InvalidArgument, + ret_val); + + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order( + input, out, mean_out, rstd_out), + InvalidArgument, + ret_val); + + if (weight.has_value()) { + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(input, weight.value()), + InvalidArgument, + ret_val); + } + + if (bias.has_value()) { + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(input, bias.value()), + InvalidArgument, + ret_val); + } + int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit; + Tensor::SizesType mean_rstd_sizes[kTensorDimensionLimit]; + size_t mean_rstd_ndim = 0; + torch::executor::get_layer_norm_out_target_size( + input, normalized_shape, mean_rstd_sizes, &mean_rstd_ndim); + + ET_KERNEL_CHECK( + ctx, + resize_tensor(out, input.sizes()) == Error::Ok, + InvalidArgument, + ret_val); + + ET_KERNEL_CHECK( + ctx, + executorch::runtime::resize_tensor( + mean_out, {mean_rstd_sizes, mean_rstd_ndim}) == Error::Ok, + InvalidArgument, + ret_val); + + ET_KERNEL_CHECK( + ctx, + executorch::runtime::resize_tensor( + rstd_out, {mean_rstd_sizes, mean_rstd_ndim}) == Error::Ok, + InvalidArgument, + ret_val); + + int input_shape[kTensorDimensionLimit]; + for (int i = 0; i < input.dim(); i++) { + input_shape[i] = input.size(i); + } + + if (out.scalar_type() == ScalarType::Float) { + float* const out_data = out.mutable_data_ptr(); + float* const mean_data = mean_out.mutable_data_ptr(); + float* const rstd_data = rstd_out.mutable_data_ptr(); + const float* const inp_data = input.const_data_ptr(); + int dim = input.dim() - normalized_shape.size(); + + int num_elm = 1; + for (int i = 0; i < normalized_shape.size(); i++) { + num_elm *= normalized_shape[i]; + } + + float* weight_data; + if (weight.has_value()) { + weight_data = weight.value().mutable_data_ptr(); + } else { + weight_data = (float*)malloc(num_elm * sizeof(float)); + for (int i = 0; i < num_elm; i++) { + weight_data[i] = 1; + } + } + float* bias_data; + if (bias.has_value()) { + bias_data = bias.value().mutable_data_ptr(); + } else { + bias_data = (float*)malloc(num_elm * sizeof(float)); + for (int i = 0; i < num_elm; i++) { + bias_data[i] = 0; + } + } + + xa_nn_native_layer_norm_f32_f32( + out_data, + mean_data, + rstd_data, + inp_data, + input_shape, + input.dim(), + dim, + weight_data, + bias_data, + (float)eps); + + if (!bias.has_value()) { + free(bias_data); + } + if (!weight.has_value()) { + free(weight_data); + } + } else { + ET_SWITCH_FLOAT_TYPES( + input.scalar_type(), ctx, "native_layer_norm.out", CTYPE, [&]() { + layer_norm( + input, + 
normalized_shape, + weight, + bias, + eps, + out, + mean_out, + rstd_out); + }); + } + + return ret_val; +} + +} // namespace native +} // namespace G3 +} // namespace impl +} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/op_quantize.cpp b/backends/cadence/fusion_g3/operators/op_quantize.cpp new file mode 100644 index 0000000000..bc84829edb --- /dev/null +++ b/backends/cadence/fusion_g3/operators/op_quantize.cpp @@ -0,0 +1,797 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +using exec_aten::Scalar; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using torch::executor::Error; +using torch::executor::KernelRuntimeContext; + +/* ScalarType in Executorch do not have support for below data types. + * So, creating a placeholder for these data types. Once, ScalarTypes is + * updated to have support for below data types, these can be removed and + * operator need to be updated accordingly + */ + enum datatype { + Ushort = 20, + Bits4u = 21, + Bits4 = 22 + }; + +/** + * For an input tensor, use the scale and zero_point arguments to quantize it. + */ +namespace cadence { +namespace impl { +namespace FusionG3 { +namespace native { + + +namespace { + +/** + * Asserts that the parameters are valid. + */ +void check_quantize_per_tensor_args(const Tensor& input, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) +{ + // Ensure self and out has the same shape + ET_CHECK_MSG( + torch::executor::isFloatingType(input.scalar_type()), + "input.scalar_type() %" PRId8 " is not floating type", + static_cast(input.scalar_type())); + + int32_t quant_min_lower_bound = 0, quant_max_upper_bound = 0; + ScalarType out_dtype = out.scalar_type(); + ET_CHECK_MSG( + out_dtype == dtype, + "out.scalar_type() %" PRId8 " is not matching dtype argument %" PRId8, + static_cast(out_dtype), + static_cast(dtype)); + + if (out_dtype == ScalarType::Byte) + { + quant_min_lower_bound = + static_cast(std::numeric_limits::min()); + quant_max_upper_bound = + static_cast(std::numeric_limits::max()); + } + else if (dtype == ScalarType::Char) + { + quant_min_lower_bound = + static_cast(std::numeric_limits::min()); + quant_max_upper_bound = + static_cast(std::numeric_limits::max()); + } + else if (dtype == ScalarType::Bits16) + { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + } + else if (dtype == ScalarType::Short) + { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + } + else if (dtype == (ScalarType)Ushort) + { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + } + else if (dtype == (ScalarType)Bits4u) + { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + /* Minimum and maximum values fo unsigned 4-bit data type */ + quant_min_lower_bound = quant_min_lower_bound >> 4; + quant_max_upper_bound = quant_max_upper_bound >> 4; + } + else if (dtype == (ScalarType)Bits4) + { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + /* Minimum and maximum values fo signed 4-bit data type */ + quant_min_lower_bound = 
quant_min_lower_bound >> 4; + quant_max_upper_bound = quant_max_upper_bound >> 4; + } + else if (dtype == ScalarType::Int) + { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + } + else + { + ET_CHECK_MSG( + false, "Unsupported dtype: %" PRId8, static_cast(out_dtype)); + } + + ET_CHECK_MSG( + quant_min >= quant_min_lower_bound, + "quant_min out of bound for dtype, expected quant_min_lower_bound: %" PRId32 + " actual quant_min: %" PRId64, + quant_min_lower_bound, + quant_min); + + ET_CHECK_MSG( + quant_max <= quant_max_upper_bound, + "quant_max out of bound for dtype, expected quant_max_upper_bound: %" PRId32 + " actual quant_max: %" PRId64, + quant_max_upper_bound, + quant_max); +}/* check_quantize_per_tensor_args */ + +} // namespace + +template +T quantize_val( + double scale, + int64_t zero_point, + K value, + int64_t quant_min, + int64_t quant_max) +{ + int64_t qvalue; + float inv_scale = 1.0f / static_cast(scale); + qvalue = static_cast( + static_cast(zero_point) + + std::nearbyint(static_cast(inv_scale * value))); + + qvalue = std::max(qvalue, quant_min); + qvalue = std::min(qvalue, quant_max); + return static_cast(qvalue); +} + + +/* Local function which calls the kernels based on the output datatype */ +void quantize_impl(Tensor& out, + const Tensor& input, + float *scale_data, + int *zero_point_data, + int *axis, + int quant_min, + int quant_max) +{ + const exec_aten::ArrayRef input_size = input.sizes(); + + int kTensorDimensionLimit = 5; + + int inp_shape[kTensorDimensionLimit]; + + for(auto i = 0; i < input_size.size(); i++) + { + inp_shape[i] = input_size[i]; + } + + const float* input_data = input.const_data_ptr(); + + bool is_asym_quant = 0; + + if(zero_point_data != NULL) //asymmetric quant + { + if(axis != NULL) //channel + { + for(int i = 0; i < input.size(*axis) ; i++) + { + if(zero_point_data[i] != 0) + { + is_asym_quant |= 1; + } + } + } + else + { + if(*zero_point_data != 0) //tensor + { + is_asym_quant |= 1; + } + } + } + + if(is_asym_quant) + { + if (out.scalar_type() == ScalarType::Byte) + { + uint8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym8u( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, zero_point_data, quant_min, quant_max); + } + else if (out.scalar_type() == ScalarType::Char) + { + int8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym8( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, zero_point_data, quant_min, quant_max); + } + else if (out.scalar_type() == (ScalarType)Ushort) + { + uint16_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym16u( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, zero_point_data, quant_min, quant_max); + } + else if (out.scalar_type() == ScalarType::Short) + { + int16_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym16( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, zero_point_data, quant_min, quant_max); + } + else if (out.scalar_type() == (ScalarType)Bits4u) + { + uint8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym4u( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, zero_point_data, quant_min, quant_max); + } + else if (out.scalar_type() == (ScalarType)Bits4) + { + int8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym4( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, zero_point_data, quant_min, quant_max); + } + else + { + if(axis == 
NULL) + { + // Vector quantization + // calculate the quantized input + #define ASYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ + case ScalarType::out_dtype: { \ + /* Hoist these function calls out of our inner loop because they might not \ + * get inlined without LTO, particularly in ATen mode. */ \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + const auto input_numel = input.numel(); \ + for (size_t i = 0; i < input_numel; i++) { \ + IN_CTYPE value = input_data_ptr[i]; \ + out_data_ptr[i] = quantize_val( \ + (double)*scale_data, (int64_t)*zero_point_data, value, \ + (int64_t)quant_min, (int64_t)quant_max); \ + } \ + } break; + #define ASYM_CALCULATE_FLOAT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_INT_TYPES_WITH(IN_CTYPE, ASYM_QUANTIZE_IMPL_TENSOR); \ + ASYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, uint16_t, Bits16) \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + + switch (input.scalar_type()) { + ET_FORALL_FLOAT_TYPES(ASYM_CALCULATE_FLOAT_TYPE_TENSOR); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); + } + + } + else + { + // Channel based quantization + // a list contains all dimensions except axis + int64_t dims[input.dim() - 1]; + for (int64_t i = 0; i < input.dim() - 1; i++) + { + if (i < *axis) + { + dims[i] = i; + } + else + { + dims[i] = i + 1; + } + } + + exec_aten::optional> optional_dim_list{ + exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; + + // Actual quantization logic + // input, out are the input and output tensors + // channel_ix is the index along the axis dimension. 0 <= channel_ix < + // input.size(axis). + // i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix + // will be 0, 1, 2, ... C-1 + // in_ix is the flat index of the element you are quantizing. 
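For reference, the per-element math performed by quantize_val above is: multiply by 1/scale, round to nearest, add the zero point, then clamp to [quant_min, quant_max]. A standalone sketch with illustrative names that are not part of this patch:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Quantize one float to int8 the same way quantize_val does:
// q = clamp(zero_point + round(value / scale), quant_min, quant_max).
static int8_t quantize_one(
    float value, double scale, int64_t zero_point,
    int64_t quant_min, int64_t quant_max) {
  const float inv_scale = 1.0f / static_cast<float>(scale);
  int64_t q =
      zero_point + static_cast<int64_t>(std::nearbyint(inv_scale * value));
  q = std::max(q, quant_min);
  q = std::min(q, quant_max);
  return static_cast<int8_t>(q);
}

// Example: value 2.0f with scale 0.25 and zero_point -10 gives
// -10 + round(2.0 / 0.25) = -2, which already lies inside [-128, 127].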
+ // in other words you are quantizing in_data[in_ix] + #define ASYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ + case ScalarType::out_dtype: \ + for (size_t channel_ix = 0; channel_ix < input.size(*axis); ++channel_ix) { \ + double _scale = (double)scale_data[channel_ix]; \ + int64_t _zero_point = (int64_t)zero_point_data[channel_ix]; \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + torch::executor::apply_over_dim_list( \ + [input_data_ptr, \ + out_data_ptr, \ + _scale, \ + _zero_point, \ + quant_min, \ + quant_max](size_t in_ix) { \ + out_data_ptr[in_ix] = quantize_val( \ + _scale, \ + _zero_point, \ + input_data_ptr[in_ix], \ + quant_min, \ + quant_max); \ + }, \ + input, \ + optional_dim_list, \ + channel_ix); \ + } \ + break; + #define ASYM_CALCULATE_FLOAT_TYPE_CHANNEL(CTYPE_IN, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_INT_TYPES_WITH(CTYPE_IN, ASYM_QUANTIZE_IMPL_CHANNEL); \ + ASYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, uint16_t, Bits16) \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + + switch (input.scalar_type()) { + ET_FORALL_FLOAT_TYPES(ASYM_CALCULATE_FLOAT_TYPE_CHANNEL); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); + } + } + + #undef ASYM_CALCULATE_FLOAT_TYPE_TENSOR + #undef ASYM_CALCULATE_FLOAT_TYPE_CHANNEL + #undef ASYM_ASYM_QUANTIZE_IMPL_CHANNEL_TENSOR + #undef ASYM_ASYM_QUANTIZE_IMPL_CHANNEL_CHANNEL + } + } + else + { + if (out.scalar_type() == ScalarType::Byte) + { + uint8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym8u( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, quant_min, quant_max); + } + else if (out.scalar_type() == ScalarType::Char) + { + int8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym8( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, quant_min, quant_max); + } + else if (out.scalar_type() == (ScalarType) Ushort) + { + uint16_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym16u( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, quant_min, quant_max); + } + else if (out.scalar_type() == ScalarType::Short) + { + int16_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym16( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, quant_min, quant_max); + } + else if (out.scalar_type() == (ScalarType) Bits4u) + { + uint8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym4u( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, quant_min, quant_max); + } + else if (out.scalar_type() == (ScalarType) Bits4) + { + int8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym4( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, quant_min, quant_max); + } + else + { + if(axis == NULL) + { + // calculate the quantized input + #define SYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ + case ScalarType::out_dtype: { \ + /* Hoist these function calls out of our inner loop because they might not \ + * get inlined without LTO, particularly in ATen mode. 
*/ \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + const auto input_numel = input.numel(); \ + for (size_t i = 0; i < input_numel; i++) { \ + IN_CTYPE value = input_data_ptr[i]; \ + out_data_ptr[i] = quantize_val( \ + (double)*scale_data, (int64_t)*zero_point_data, value, \ + (int64_t)quant_min, (int64_t)quant_max); \ + } \ + } break; + #define SYM_CALCULATE_FLOAT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_INT_TYPES_WITH(IN_CTYPE, SYM_QUANTIZE_IMPL_TENSOR); \ + SYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, uint16_t, Bits16) \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + + switch (input.scalar_type()) { + ET_FORALL_FLOAT_TYPES(SYM_CALCULATE_FLOAT_TYPE_TENSOR); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); + } + + } + else + { + // a list contains all dimensions except axis + int64_t dims[input.dim() - 1]; + for (int64_t i = 0; i < input.dim() - 1; i++) + { + if (i < *axis) + { + dims[i] = i; + } + else + { + dims[i] = i + 1; + } + } + + exec_aten::optional> optional_dim_list{ + exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; + + // Actual quantization logic + // input, out are the input and output tensors + // channel_ix is the index along the axis dimension. 0 <= channel_ix < + // input.size(axis). + // i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix + // will be 0, 1, 2, ... C-1 + // in_ix is the flat index of the element you are quantizing. + // in other words you are quantizing in_data[in_ix] + #define SYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ + case ScalarType::out_dtype: \ + for (size_t channel_ix = 0; channel_ix < input.size(*axis); ++channel_ix) { \ + double _scale = (double)scale_data[channel_ix]; \ + int64_t _zero_point = (int64_t)zero_point_data[channel_ix]; \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + torch::executor::apply_over_dim_list( \ + [input_data_ptr, \ + out_data_ptr, \ + _scale, \ + _zero_point, \ + quant_min, \ + quant_max](size_t in_ix) { \ + out_data_ptr[in_ix] = quantize_val( \ + _scale, \ + _zero_point, \ + input_data_ptr[in_ix], \ + quant_min, \ + quant_max); \ + }, \ + input, \ + optional_dim_list, \ + channel_ix); \ + } \ + break; + #define SYM_CALCULATE_FLOAT_TYPE_CHANNEL(CTYPE_IN, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_INT_TYPES_WITH(CTYPE_IN, SYM_QUANTIZE_IMPL_CHANNEL); \ + SYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, uint16_t, Bits16) \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + + switch (input.scalar_type()) { + ET_FORALL_FLOAT_TYPES(SYM_CALCULATE_FLOAT_TYPE_CHANNEL); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); + } + } + #undef SYM_CALCULATE_FLOAT_TYPE_TENSOR + #undef SYM_CALCULATE_FLOAT_TYPE_CHANNEL + #undef SYM_ASYM_QUANTIZE_IMPL_CHANNEL_TENSOR + #undef SYM_ASYM_QUANTIZE_IMPL_CHANNEL_CHANNEL + } + } +} + +// Quantize the input tensor +Tensor& quantize_per_tensor_out(KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) +{ + torch::executor::Error err = resize_tensor(out, 
input.sizes()); + ET_CHECK_MSG( + err == torch::executor::Error::Ok, + "Failed to resize out Tensor in quantize_per_tensor_out"); + + check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out); + + float scale_data = (float)scale; + int zero_point_data = (int)zero_point; + quantize_impl(out, + input, + &scale_data, + &zero_point_data, + NULL, + (int) quant_min, + (int) quant_max); + + return out; +} + + +Tensor& quantize_per_tensor_tensor_args_out(KernelRuntimeContext& context, + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) +{ + // Temporary change to allow not fatal failure for now to unblock some + // expected failure tests that are dying instead of failure. Will revisit + // after ET_KERNEL_CHECK is fully implemented and properly allows non fatal + // failures. + if (scale.scalar_type() != ScalarType::Double) + { + context.fail(torch::executor::Error::InvalidArgument); + return out; + } + ET_CHECK_MSG( + scale.scalar_type() == ScalarType::Double, + "Expected scale to be Double tensor received: %" PRId8, + static_cast(scale.scalar_type())); + ET_CHECK_MSG( + zero_point.scalar_type() == ScalarType::Long, + "Expected zero_point to be Long tensor received: %" PRId8, + static_cast(zero_point.scalar_type())); + ET_CHECK_MSG( + scale.numel() == 1, + "Exepcted scale to only have one element received: %zd", + ssize_t(scale.numel())); + ET_CHECK_MSG( + zero_point.numel() == 1, + "Exepcted zero_point to only have one element received: %zd", + ssize_t(zero_point.numel())); + + quantize_per_tensor_out(context, + input, + scale.const_data_ptr()[0], + zero_point.const_data_ptr()[0], + quant_min, + quant_max, + dtype, + out); + + return out; +} + +Tensor& quantize_per_tensor_tensor_args_out(const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) +{ + auto context = torch::executor::RuntimeContext(); + auto& res = quantize_per_tensor_tensor_args_out( + context, input, scale, zero_point, quant_min, quant_max, dtype, out); + ET_CHECK(context.failure_state() == Error::Ok); + return res; +} + +Tensor& quantize_per_channel_out(const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) +{ + torch::executor::Error err = resize_tensor(out, input.sizes()); + + // normalize axis + ET_CHECK_MSG( + executorch::runtime::tensor_has_dim(input, axis), + "axis %zd is not legal it should be -input.dim() <= axis < input.dim() %zd", + ssize_t(axis), + ssize_t(input.dim())); + + if (axis < 0) + { + axis += executorch::runtime::nonzero_dim(input); + } + + ET_CHECK_MSG( + err == torch::executor::Error::Ok, + "Failed to resize out Tensor in quantize_per_channel_out"); + + ET_CHECK_MSG( + scale.scalar_type() == ScalarType::Double, + "scale.scalar_type() %" PRId8 " is not double type", + static_cast(scale.scalar_type())); + + ET_CHECK_MSG( + scale.numel() == input.size(axis), + "scale.numel() %zd != input.size(axis) %zd", + scale.numel(), + input.size(axis)); + + ET_CHECK_MSG( + zero_point.scalar_type() == ScalarType::Long, + "zero_point.scalar_type() %" PRId8 " is not integer type", + static_cast(zero_point.scalar_type())); + + ET_CHECK_MSG( + zero_point.numel() == input.size(axis), + "zero_point.numel() %zd != input.size(axis) %zd", + zero_point.numel(), + input.size(axis)); + + check_quantize_per_tensor_args(input, 
quant_min, quant_max, dtype, out); + + + const double* scale_dt = scale.const_data_ptr(); + const int64_t* zero_point_dt = zero_point.const_data_ptr(); + + float scale_data[input.size(axis)]; + int zero_point_data[input.size(axis)]; + + for(int i = 0; i < scale.numel(); i++) + { + scale_data[i] = (float)scale_dt[i]; + zero_point_data[i] = (int)zero_point_dt[i]; + } + + int *axis_ptr = (int *)&axis; + + quantize_impl(out, + input, + scale_data, + zero_point_data, + axis_ptr, + (int) quant_min, + (int) quant_max); + + return out; +} + +Tensor& quantize_per_channel_out(KernelRuntimeContext& context, + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) +{ + (void)context; + return quantize_per_channel_out( + input, scale, zero_point, axis, quant_min, quant_max, dtype, out); +} + +Tensor& quantize_per_token_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) +{ + size_t num_tokens = 1; + for (size_t i = 0; i < input.dim() - 1; i++) + { + num_tokens *= input.size(i); + } + // This unfortunate change is needed because we compile op_quantize for aten + // mode as well +#ifdef USE_ATEN_LIB + std::vector sizes(2); + sizes[0] = num_tokens; + sizes[1] = input.size(input.dim() - 1); + Tensor reshaped_input = at::from_blob( + input.mutable_data_ptr(), sizes, at::TensorOptions(input.scalar_type())); +#else + std::array input_dim_order{0, 1}; + std::array input_sizes; + input_sizes[0] = num_tokens; + input_sizes[1] = input.size(input.dim() - 1); + std::array input_strides; + executorch::runtime::dim_order_to_stride_nocheck( + input_sizes.data(), input_dim_order.data(), 2, input_strides.data()); + void* input_data = input.mutable_data_ptr(); + torch::executor::TensorImpl reshaped_input_impl = executorch::runtime::etensor::TensorImpl( + input.scalar_type(), + 2, + input_sizes.data(), + input_data, + input_dim_order.data(), + input_strides.data(), + executorch::runtime::TensorShapeDynamism::STATIC); + Tensor reshaped_input(&reshaped_input_impl); + torch::executor::Error err = resize_tensor(out, input.sizes()); + ET_CHECK_MSG( + err == torch::executor::Error::Ok, + "Failed to resize out Tensor in quantize_per_channel_out"); +#endif + + return quantize_per_channel_out( + reshaped_input, scale, zero_point, 0, quant_min, quant_max, dtype, out); +} + +Tensor& quantize_per_token_out( + KernelRuntimeContext& context, + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) +{ + (void)context; + return quantize_per_token_out( + input, scale, zero_point, quant_min, quant_max, dtype, out); +} + +} // namespace native +} // namespace G3 +} // namespace impl +} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/op_softmax.cpp b/backends/cadence/fusion_g3/operators/op_softmax.cpp new file mode 100644 index 0000000000..79ec6dc5d7 --- /dev/null +++ b/backends/cadence/fusion_g3/operators/op_softmax.cpp @@ -0,0 +1,118 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
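The quantize_per_token_out path above collapses every dimension except the innermost one into a single token dimension and then reuses the per-channel kernel with axis 0, so each row gets its own scale and zero point. A rough standalone illustration of that idea, with hypothetical names and none of the tensor plumbing from the patch:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative only: view a tensor of shape (d0, d1, ..., dk) as
// (num_tokens, dk) with num_tokens = d0 * d1 * ... * d(k-1), then quantize
// each row ("token") with its own scale / zero point.
void quantize_per_token_sketch(
    const float* in,
    const std::vector<size_t>& sizes,  // full input shape
    const float* scale,                // one per token
    const int32_t* zero_pt,            // one per token
    int32_t quant_min, int32_t quant_max,
    int8_t* out) {
  size_t num_tokens = 1;
  for (size_t i = 0; i + 1 < sizes.size(); ++i) {
    num_tokens *= sizes[i];            // product of all leading dims
  }
  const size_t row = sizes.back();     // innermost dimension stays intact
  for (size_t t = 0; t < num_tokens; ++t) {
    const float inv_scale = 1.0f / scale[t];
    for (size_t j = 0; j < row; ++j) {
      int32_t q = static_cast<int32_t>(std::nearbyint(in[t * row + j] * inv_scale)) + zero_pt[t];
      out[t * row + j] = static_cast<int8_t>(std::min(std::max(q, quant_min), quant_max));
    }
  }
}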
+ */ + +#include +#include +#include +#include +#include +#include + +using exec_aten::Scalar; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using torch::executor::Error; +using torch::executor::KernelRuntimeContext; + +namespace cadence { +namespace impl { +namespace G3 { +namespace native { + +Tensor& softmax_out( + KernelRuntimeContext& ctx, + const Tensor& in, + int64_t dim, + bool half_to_float, + Tensor& out) +{ + (void)ctx; + + ET_KERNEL_CHECK( + ctx, + torch::executor::check_softmax_args(in, dim, half_to_float, out), + InvalidArgument, + out); + + ET_KERNEL_CHECK( + ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out); + + ET_KERNEL_CHECK( + ctx, executorch::runtime::tensors_have_same_dim_order(in, out), InvalidArgument, out); + + // Adjust for negative dim + dim = dim < 0 ? dim + executorch::runtime::nonzero_dim(in) : dim; + + int inp_shapes[in.dim()]; + const exec_aten::ArrayRef in_size = in.sizes(); + for(int i = 0; i < in.dim(); i++) + { + inp_shapes[i] = in_size[i]; + } + + if(out.scalar_type() == ScalarType::Float) + { + const float * const inp_data = in.const_data_ptr(); + float * const out_data = out.mutable_data_ptr(); + int axis = dim; + xa_nn_softmax_f32_f32(out_data, inp_data, inp_shapes, + in.dim(), &axis); + } + else + { + ET_SWITCH_FLOATH_TYPES(in.scalar_type(), ctx, "_softmax.out", CTYPE, [&]() { + const CTYPE* const in_data = in.const_data_ptr(); + CTYPE* const out_data = out.mutable_data_ptr(); + + torch::executor::apply_over_dim( + [in_data, out_data]( + const size_t size, const size_t stride, const size_t base) { + // calculate max in softmax dim. During softmax computation each + // value is subtracted by the maximum in value before calling exp + // to preserve numerical stability. + const CTYPE max_in = torch::executor::apply_unary_reduce_fn( + [](const CTYPE val_in, CTYPE val_accum) { + return std::max(val_in, val_accum); + }, + in_data + base, + size, + stride); + + const CTYPE temp_sum = torch::executor:: + apply_unary_map_reduce_fn( + [max_in](const CTYPE val_in) { + return std::exp(val_in - max_in); + }, + [](const CTYPE mapped_in, CTYPE val_accum) { + return val_accum + mapped_in; + }, + in_data + base, + size, + stride); + + torch::executor::apply_unary_map_fn( + [max_in, temp_sum](const CTYPE val_in) { + return std::exp(val_in - max_in) / temp_sum; + }, + in_data + base, + out_data + base, + size, + stride); + }, + in, + dim); + }); + } + + return out; +} + +} // namespace native +} // namespace G3 +} // namespace impl +} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/third-party/nnlib/CMakeLists.txt b/backends/cadence/fusion_g3/third-party/nnlib/CMakeLists.txt new file mode 100644 index 0000000000..a2615e0851 --- /dev/null +++ b/backends/cadence/fusion_g3/third-party/nnlib/CMakeLists.txt @@ -0,0 +1,19 @@ +cmake_minimum_required(VERSION 3.10.0) +project(cadence_nnlib) + +add_custom_target( + nnlib_target ALL + COMMAND + make install_nnlib -f makefile -C + ${EXECUTORCH_ROOT}/backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3/xa_nnlib/build + OBJDIR=${CMAKE_CURRENT_BINARY_DIR}/obj + LIBDIR=${CMAKE_CURRENT_BINARY_DIR}/lib -j8 +) + +add_library(xa_nnlib STATIC IMPORTED GLOBAL) +add_dependencies(xa_nnlib nnlib_target) + +set_property( + TARGET xa_nnlib PROPERTY IMPORTED_LOCATION + "${CMAKE_CURRENT_BINARY_DIR}/lib/xa_nnlib.a" +) diff --git a/backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3 b/backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3 new file mode 
160000 index 0000000000..8ddd1c39d4 --- /dev/null +++ b/backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3 @@ -0,0 +1 @@ +Subproject commit 8ddd1c39d4b20235ebe9dac68d92848da2885ece From ffb1b7d03c4ec02ebc804731229ef5d1ee3162fe Mon Sep 17 00:00:00 2001 From: JP <46308822+zonglinpeng@users.noreply.github.com> Date: Mon, 25 Nov 2024 13:03:56 -0800 Subject: [PATCH 12/18] fix lint issue from g3 PR (#7060) Summary: ~ Reviewed By: mcremon-meta Differential Revision: D66465813 --- .../fusion_g3/operators/op_dequantize.cpp | 1049 +++++++------- .../fusion_g3/operators/op_quantize.cpp | 1234 +++++++++-------- .../fusion_g3/operators/op_softmax.cpp | 135 +- 3 files changed, 1191 insertions(+), 1227 deletions(-) diff --git a/backends/cadence/fusion_g3/operators/op_dequantize.cpp b/backends/cadence/fusion_g3/operators/op_dequantize.cpp index 784011332f..f450ed398f 100644 --- a/backends/cadence/fusion_g3/operators/op_dequantize.cpp +++ b/backends/cadence/fusion_g3/operators/op_dequantize.cpp @@ -23,20 +23,16 @@ template using optional = exec_aten::optional; /* ScalarType in Executorch do not have support for below data types. * So, creating a placeholder for these data types. Once, ScalarTypes is - * updated to have support for below data types, these can be removed and + * updated to have support for below data types, these can be removed and * operator need to be updated accordingly */ - - enum datatype { - Ushort = 20, - Bits4u = 21, - Bits4 = 22 - }; + +enum datatype { Ushort = 20, Bits4u = 21, Bits4 = 22 }; /** * For an input tensor, use the scale and zero_point arguments to quantize it. */ -namespace cadence { +namespace cadence { namespace impl { namespace G3 { namespace native { @@ -46,38 +42,38 @@ namespace { /** * Asserts that the parameters are valid. 
*/ -void check_dequantize_per_tensor_args(const Tensor& input, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - exec_aten::optional& out_dtype, - Tensor& out) -{ - ET_CHECK_MSG( +void check_dequantize_per_tensor_args( + const Tensor& input, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + exec_aten::optional& out_dtype, + Tensor& out) { + ET_CHECK_MSG( input.scalar_type() == ScalarType::Byte || input.scalar_type() == ScalarType::Char || input.scalar_type() == ScalarType::Bits16 || input.scalar_type() == ScalarType::Short || - input.scalar_type() == (ScalarType) Ushort || - input.scalar_type() == (ScalarType) Bits4 || - input.scalar_type() == (ScalarType) Bits4u || + input.scalar_type() == (ScalarType)Ushort || + input.scalar_type() == (ScalarType)Bits4 || + input.scalar_type() == (ScalarType)Bits4u || input.scalar_type() == ScalarType::Int, - + "input.scalar_type() %" PRId8 " is not supported:", static_cast(input.scalar_type())); - ET_CHECK_MSG( + ET_CHECK_MSG( input.scalar_type() == dtype, "input.scalar_type() %" PRId8 " is not matching dtype argumenta:", static_cast(input.scalar_type())); - if (out_dtype.has_value()) { + if (out_dtype.has_value()) { ET_CHECK_MSG( out.scalar_type() == out_dtype.value(), "output_dtype must match the dtype of the out tensor"); - } + } - ET_CHECK_MSG( + ET_CHECK_MSG( quant_min <= quant_max, "quant min: %" PRId64 " is greater than quant max: %" PRId64, quant_min, @@ -86,412 +82,395 @@ void check_dequantize_per_tensor_args(const Tensor& input, } // namespace - /* Local function which calls the kernels based on the input datatype */ -void Dequantize_impl(Tensor& out, - const Tensor& input, - float *scale_data, - int *zero_point_data, - int *axis, - exec_aten::optional out_dtype) -{ - const exec_aten::ArrayRef input_size = input.sizes(); +void Dequantize_impl( + Tensor& out, + const Tensor& input, + float* scale_data, + int* zero_point_data, + int* axis, + exec_aten::optional out_dtype) { + const exec_aten::ArrayRef input_size = input.sizes(); - int kTensorDimensionLimit = 5; + int kTensorDimensionLimit = 5; - int inp_shape[kTensorDimensionLimit]; + int inp_shape[kTensorDimensionLimit]; - for(auto i = 0; i < input_size.size(); i++) - { - inp_shape[i] = input_size[i]; - } + for (auto i = 0; i < input_size.size(); i++) { + inp_shape[i] = input_size[i]; + } - bool is_asym_dequant = 0; + bool is_asym_dequant = 0; - if(zero_point_data != NULL) //asymmetric dequant + if (zero_point_data != NULL) // asymmetric dequant + { + if (axis != NULL) // channel { - if(axis != NULL) //channel - { - for(int i = 0; i < input.size(*axis) ; i++) - { - if(zero_point_data[i] != 0) - { - is_asym_dequant |= 1; - } + for (int i = 0; i < input.size(*axis); i++) { + if (zero_point_data[i] != 0) { + is_asym_dequant |= 1; } } - else + } else { + if (*zero_point_data != 0) // tesor { - if(*zero_point_data != 0) //tesor - { - is_asym_dequant |= 1; - } + is_asym_dequant |= 1; } } - float* out_data = out.mutable_data_ptr(); - - if(is_asym_dequant) - { - if (input.scalar_type() == ScalarType::Byte) - { - const uint8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_asym8u_f32( - out_data, input_data, inp_shape, input.dim(), axis, - zero_point_data, scale_data); - } - else if (input.scalar_type() == ScalarType::Char) - { - const int8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_asym8_f32( - out_data, input_data, inp_shape, input.dim(), axis, - zero_point_data, scale_data); + } + float* out_data = out.mutable_data_ptr(); + + if 
(is_asym_dequant) { + if (input.scalar_type() == ScalarType::Byte) { + const uint8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym8u_f32( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + zero_point_data, + scale_data); + } else if (input.scalar_type() == ScalarType::Char) { + const int8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym8_f32( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + zero_point_data, + scale_data); + } else if (input.scalar_type() == (ScalarType)Ushort) { + const uint16_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym16u_f32( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + zero_point_data, + scale_data); + } else if (input.scalar_type() == ScalarType::Short) { + const int16_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym16_f32( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + zero_point_data, + scale_data); + } else if (input.scalar_type() == (ScalarType)Bits4u) { + const uint8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym4u_f32( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + zero_point_data, + scale_data); + } else if (input.scalar_type() == (ScalarType)Bits4) { + const int8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym4_f32( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + zero_point_data, + scale_data); + } else { + if (axis == NULL) { +// calculate the dequantized output, cast scale to float to match fbgemm +// behavior +#define ASYM_DEQUANTIZE_IMPL_TESNOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ + case ScalarType::out_dtype: { \ + /* Hoist these function calls out of our inner loop because they might not \ + * get inlined without LTO, particularly in ATen mode. 
*/ \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + const auto input_numel = input.numel(); \ + for (size_t i = 0; i < input_numel; i++) { \ + out_data_ptr[i] = static_cast( \ + (input_data_ptr[i] - static_cast(*zero_point_data)) * \ + static_cast(*scale_data)); \ + } \ + } break; +#define ASYM_CALCULATE_INT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, ASYM_DEQUANTIZE_IMPL_TESNOR); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + switch (input.scalar_type()) { + ET_FORALL_INT_TYPES(ASYM_CALCULATE_INT_TYPE_TENSOR); + ASYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, Bits16); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); } - else if (input.scalar_type() == (ScalarType) Ushort) - { - const uint16_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_asym16u_f32( - out_data, input_data, inp_shape, input.dim(), axis, - zero_point_data, scale_data); - } - else if (input.scalar_type() == ScalarType::Short) - { - const int16_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_asym16_f32( - out_data, input_data, inp_shape, input.dim(), axis, - zero_point_data, scale_data); - } - else if (input.scalar_type() == (ScalarType) Bits4u) - { - const uint8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_asym4u_f32( - out_data, input_data, inp_shape, input.dim(), axis, - zero_point_data, scale_data); - } - else if (input.scalar_type() == (ScalarType) Bits4) - { - const int8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_asym4_f32( - out_data, input_data, inp_shape, input.dim(), axis, - zero_point_data, scale_data); +#undef ASYM_CALCULATE_INT_TYPE_TENSOR +#undef ASYM_DEQUANTIZE_IMPL_TESNOR + } else { + // a list contains all dimensions except axis + int64_t dims[input.dim() - 1]; + for (int64_t i = 0; i < input.dim() - 1; i++) { + if (i < *axis) { + dims[i] = i; + } else { + dims[i] = i + 1; + } } - else - { - if(axis == NULL) - { - // calculate the dequantized output, cast scale to float to match fbgemm - // behavior - #define ASYM_DEQUANTIZE_IMPL_TESNOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ - case ScalarType::out_dtype: { \ - /* Hoist these function calls out of our inner loop because they might not \ - * get inlined without LTO, particularly in ATen mode. 
*/ \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - const auto input_numel = input.numel(); \ - for (size_t i = 0; i < input_numel; i++) { \ - out_data_ptr[i] = static_cast( \ - (input_data_ptr[i] - static_cast(*zero_point_data)) * \ - static_cast(*scale_data)); \ - } \ - } break; - #define ASYM_CALCULATE_INT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, ASYM_DEQUANTIZE_IMPL_TESNOR); \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast(out.scalar_type())); \ - } \ - break; - switch (input.scalar_type()) { - ET_FORALL_INT_TYPES(ASYM_CALCULATE_INT_TYPE_TENSOR); - ASYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, Bits16); - default: - ET_CHECK_MSG( - false, - "Unhandled input dtype %" PRId8, - static_cast(input.scalar_type())); - } - #undef ASYM_CALCULATE_INT_TYPE_TENSOR - #undef ASYM_DEQUANTIZE_IMPL_TESNOR - } - else - { - // a list contains all dimensions except axis - int64_t dims[input.dim() - 1]; - for (int64_t i = 0; i < input.dim() - 1; i++) - { - if (i < *axis) - { - dims[i] = i; - } - else - { - dims[i] = i + 1; - } - } - - exec_aten::optional> optional_dim_list{ - exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; - - // Actual dequantization logic - // input, out are the input and output tensors - // channel_ix is the index along the axis dimension. 0 <= channel_ix < - // input.size(axis). - // i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix - // will be 0, 1, 2, ... C-1 - // in_ix is the flat index of the element you are dequantizing. - // in other words you are dequantizing in_data[in_ix] - #define ASYM_DEQUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ - case ScalarType::out_dtype: \ - if (input.dim() == 1) { \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - ET_CHECK_MSG( \ - *axis == 0, "Axis must be 0 for a single dimensional tensors"); \ - const optional dim; \ - torch::executor::apply_over_dim( \ - [input_data_ptr, out_data_ptr, zero_point_data, scale_data]( \ - size_t numel, size_t stride, size_t base_ix) { \ - for (size_t i = 0; i < numel; i++) { \ - size_t current_ix = base_ix * stride + i; \ - float _scale = scale_data[current_ix]; \ - int64_t zero_point = 0; \ - if (zero_point_data != nullptr) { \ - zero_point = zero_point_data[current_ix]; \ - } \ - out_data_ptr[current_ix] = \ - static_cast( \ - input_data_ptr[current_ix] - zero_point) * \ - _scale; \ - } \ - }, \ - input, \ - dim); \ - break; \ - } \ - for (size_t channel_ix = 0; channel_ix < input.size(*axis); ++channel_ix) { \ - float _scale = scale_data[channel_ix]; \ - int64_t _zero_point = 0; \ - if (zero_point_data != nullptr) { \ - _zero_point = zero_point_data[channel_ix]; \ - } \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - torch::executor::apply_over_dim_list( \ - [input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \ - out_data_ptr[in_ix] = static_cast( \ - (input_data_ptr[in_ix] - _zero_point) * _scale); \ - }, \ - input, \ - optional_dim_list, \ - channel_ix); \ - } \ - break; - #define ASYM_CALCULATE_INT_TYPE_CHANNEL(IN_CTYPE, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, ASYM_DEQUANTIZE_IMPL_CHANNEL); \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - 
static_cast(out.scalar_type())); \ - } \ - break; - switch (input.scalar_type()) { - ET_FORALL_INT_TYPES(ASYM_CALCULATE_INT_TYPE_CHANNEL); - ASYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, Bits16); - default: - ET_CHECK_MSG( - false, - "Unhandled input dtype %" PRId8, - static_cast(input.scalar_type())); - } - #undef ASYM_CALCULATE_INT_TYPE_CHANNEL - #undef ASYM_DEQUANTIZE_IMPL_CHANNEL - } + + exec_aten::optional> optional_dim_list{ + exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; + +// Actual dequantization logic +// input, out are the input and output tensors +// channel_ix is the index along the axis dimension. 0 <= channel_ix < +// input.size(axis). +// i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix +// will be 0, 1, 2, ... C-1 +// in_ix is the flat index of the element you are dequantizing. +// in other words you are dequantizing in_data[in_ix] +#define ASYM_DEQUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ + case ScalarType::out_dtype: \ + if (input.dim() == 1) { \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + ET_CHECK_MSG( \ + *axis == 0, "Axis must be 0 for a single dimensional tensors"); \ + const optional dim; \ + torch::executor::apply_over_dim( \ + [input_data_ptr, out_data_ptr, zero_point_data, scale_data]( \ + size_t numel, size_t stride, size_t base_ix) { \ + for (size_t i = 0; i < numel; i++) { \ + size_t current_ix = base_ix * stride + i; \ + float _scale = scale_data[current_ix]; \ + int64_t zero_point = 0; \ + if (zero_point_data != nullptr) { \ + zero_point = zero_point_data[current_ix]; \ + } \ + out_data_ptr[current_ix] = \ + static_cast( \ + input_data_ptr[current_ix] - zero_point) * \ + _scale; \ + } \ + }, \ + input, \ + dim); \ + break; \ + } \ + for (size_t channel_ix = 0; channel_ix < input.size(*axis); \ + ++channel_ix) { \ + float _scale = scale_data[channel_ix]; \ + int64_t _zero_point = 0; \ + if (zero_point_data != nullptr) { \ + _zero_point = zero_point_data[channel_ix]; \ + } \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + torch::executor::apply_over_dim_list( \ + [input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \ + out_data_ptr[in_ix] = static_cast( \ + (input_data_ptr[in_ix] - _zero_point) * _scale); \ + }, \ + input, \ + optional_dim_list, \ + channel_ix); \ + } \ + break; +#define ASYM_CALCULATE_INT_TYPE_CHANNEL(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, ASYM_DEQUANTIZE_IMPL_CHANNEL); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + switch (input.scalar_type()) { + ET_FORALL_INT_TYPES(ASYM_CALCULATE_INT_TYPE_CHANNEL); + ASYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, Bits16); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); } +#undef ASYM_CALCULATE_INT_TYPE_CHANNEL +#undef ASYM_DEQUANTIZE_IMPL_CHANNEL + } } - else - { - if (input.scalar_type() == ScalarType::Byte) - { - const uint8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_sym8u_f32( - out_data, input_data, inp_shape, input.dim(), axis, scale_data); - } - else if (input.scalar_type() == ScalarType::Char) - { - const int8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_sym8_f32( - out_data, input_data, inp_shape, input.dim(), axis, scale_data); - } - else if (input.scalar_type() 
== (ScalarType) Ushort) - { - const uint16_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_sym16u_f32( - out_data, input_data, inp_shape, input.dim(), axis, scale_data); - } - else if (input.scalar_type() == ScalarType::Short) - { - const int16_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_sym16_f32( - out_data, input_data, inp_shape, input.dim(), axis, scale_data); - } - else if (input.scalar_type() == (ScalarType) Bits4u) - { - const uint8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_sym4u_f32( - out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } else { + if (input.scalar_type() == ScalarType::Byte) { + const uint8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym8u_f32( + out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } else if (input.scalar_type() == ScalarType::Char) { + const int8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym8_f32( + out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } else if (input.scalar_type() == (ScalarType)Ushort) { + const uint16_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym16u_f32( + out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } else if (input.scalar_type() == ScalarType::Short) { + const int16_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym16_f32( + out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } else if (input.scalar_type() == (ScalarType)Bits4u) { + const uint8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym4u_f32( + out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } else if (input.scalar_type() == (ScalarType)Bits4) { + const int8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym4_f32( + out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } else { + if (axis == NULL) { +// calculate the dequantized output, cast scale to float to match fbgemm +// behavior +#define SYM_DEQUANTIZE_IMPL_TESNOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ + case ScalarType::out_dtype: { \ + /* Hoist these function calls out of our inner loop because they might not \ + * get inlined without LTO, particularly in ATen mode. 
*/ \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + const auto input_numel = input.numel(); \ + for (size_t i = 0; i < input_numel; i++) { \ + out_data_ptr[i] = static_cast( \ + (input_data_ptr[i] - static_cast(*zero_point_data)) * \ + static_cast(*scale_data)); \ + } \ + } break; +#define SYM_CALCULATE_INT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, SYM_DEQUANTIZE_IMPL_TESNOR); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + switch (input.scalar_type()) { + ET_FORALL_INT_TYPES(SYM_CALCULATE_INT_TYPE_TENSOR); + SYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, Bits16); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); } - else if (input.scalar_type() == (ScalarType) Bits4) - { - const int8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_sym4_f32( - out_data, input_data, inp_shape, input.dim(), axis, scale_data); +#undef SYM_DEQUANTIZE_IMPL_TESNOR +#undef SYM_CALCULATE_INT_TYPE_TENSOR + } else { + // a list contains all dimensions except axis + int64_t dims[input.dim() - 1]; + for (int64_t i = 0; i < input.dim() - 1; i++) { + if (i < *axis) { + dims[i] = i; + } else { + dims[i] = i + 1; + } } - else - { - if(axis == NULL) - { - // calculate the dequantized output, cast scale to float to match fbgemm - // behavior - #define SYM_DEQUANTIZE_IMPL_TESNOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ - case ScalarType::out_dtype: { \ - /* Hoist these function calls out of our inner loop because they might not \ - * get inlined without LTO, particularly in ATen mode. */ \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - const auto input_numel = input.numel(); \ - for (size_t i = 0; i < input_numel; i++) { \ - out_data_ptr[i] = static_cast( \ - (input_data_ptr[i] - static_cast(*zero_point_data)) * \ - static_cast(*scale_data)); \ - } \ - } break; - #define SYM_CALCULATE_INT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, SYM_DEQUANTIZE_IMPL_TESNOR); \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast(out.scalar_type())); \ - } \ - break; - switch (input.scalar_type()) { - ET_FORALL_INT_TYPES(SYM_CALCULATE_INT_TYPE_TENSOR); - SYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, Bits16); - default: - ET_CHECK_MSG( - false, - "Unhandled input dtype %" PRId8, - static_cast(input.scalar_type())); - } - #undef SYM_DEQUANTIZE_IMPL_TESNOR - #undef SYM_CALCULATE_INT_TYPE_TENSOR - } - else - { - // a list contains all dimensions except axis - int64_t dims[input.dim() - 1]; - for (int64_t i = 0; i < input.dim() - 1; i++) - { - if (i < *axis) - { - dims[i] = i; - } - else - { - dims[i] = i + 1; - } - } - - exec_aten::optional> optional_dim_list{ - exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; - - // Actual dequantization logic - // input, out are the input and output tensors - // channel_ix is the index along the axis dimension. 0 <= channel_ix < - // input.size(axis). - // i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix - // will be 0, 1, 2, ... C-1 - // in_ix is the flat index of the element you are dequantizing. 
- // in other words you are dequantizing in_data[in_ix] - #define SYM_DEQUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ - case ScalarType::out_dtype: \ - if (input.dim() == 1) { \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - ET_CHECK_MSG( \ - *axis == 0, "Axis must be 0 for a single dimensional tensors"); \ - const optional dim; \ - torch::executor::apply_over_dim( \ - [input_data_ptr, out_data_ptr, zero_point_data, scale_data]( \ - size_t numel, size_t stride, size_t base_ix) { \ - for (size_t i = 0; i < numel; i++) { \ - size_t current_ix = base_ix * stride + i; \ - float _scale = scale_data[current_ix]; \ - int64_t zero_point = 0; \ - if (zero_point_data != nullptr) { \ - zero_point = zero_point_data[current_ix]; \ - } \ - out_data_ptr[current_ix] = \ - static_cast( \ - input_data_ptr[current_ix] - zero_point) * \ - _scale; \ - } \ - }, \ - input, \ - dim); \ - break; \ - } \ - for (size_t channel_ix = 0; channel_ix < input.size(*axis); ++channel_ix) { \ - float _scale = scale_data[channel_ix]; \ - int64_t _zero_point = 0; \ - if (zero_point_data != nullptr) { \ - _zero_point = zero_point_data[channel_ix]; \ - } \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - torch::executor::apply_over_dim_list( \ - [input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \ - out_data_ptr[in_ix] = static_cast( \ - (input_data_ptr[in_ix] - _zero_point) * _scale); \ - }, \ - input, \ - optional_dim_list, \ - channel_ix); \ - } \ - break; - #define SYM_CALCULATE_INT_TYPE_CHANNEL(IN_CTYPE, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, SYM_DEQUANTIZE_IMPL_CHANNEL); \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast(out.scalar_type())); \ - } \ - break; - switch (input.scalar_type()) { - ET_FORALL_INT_TYPES(SYM_CALCULATE_INT_TYPE_CHANNEL); - SYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, Bits16); - default: - ET_CHECK_MSG( - false, - "Unhandled input dtype %" PRId8, - static_cast(input.scalar_type())); - } - #undef SYM_DEQUANTIZE_IMPL_CHANNEL - #undef SYM_CALCULATE_INT_TYPE_CHANNEL - } + + exec_aten::optional> optional_dim_list{ + exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; + +// Actual dequantization logic +// input, out are the input and output tensors +// channel_ix is the index along the axis dimension. 0 <= channel_ix < +// input.size(axis). +// i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix +// will be 0, 1, 2, ... C-1 +// in_ix is the flat index of the element you are dequantizing. 
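The per-channel dequantize logic described above reduces to out[in_ix] = (in[in_ix] - zero_point[channel_ix]) * scale[channel_ix]. A compact standalone sketch (illustrative names, not taken from this patch) for a contiguous NCHW int8 buffer with axis == 1:

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative only: each element is shifted by its channel's zero point
// and multiplied by its channel's scale.
std::vector<float> dequantize_per_channel_nchw(
    const std::vector<int8_t>& in,
    const std::vector<float>& scale,      // one entry per channel
    const std::vector<int32_t>& zero_pt,  // one entry per channel
    size_t N, size_t C, size_t H, size_t W) {
  std::vector<float> out(in.size());
  const size_t hw = H * W;
  for (size_t n = 0; n < N; ++n) {
    for (size_t c = 0; c < C; ++c) {                  // channel_ix
      for (size_t i = 0; i < hw; ++i) {
        const size_t in_ix = (n * C + c) * hw + i;    // flat element index
        out[in_ix] = static_cast<float>(in[in_ix] - zero_pt[c]) * scale[c];
      }
    }
  }
  return out;
}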
+// in other words you are dequantizing in_data[in_ix] +#define SYM_DEQUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ + case ScalarType::out_dtype: \ + if (input.dim() == 1) { \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + ET_CHECK_MSG( \ + *axis == 0, "Axis must be 0 for a single dimensional tensors"); \ + const optional dim; \ + torch::executor::apply_over_dim( \ + [input_data_ptr, out_data_ptr, zero_point_data, scale_data]( \ + size_t numel, size_t stride, size_t base_ix) { \ + for (size_t i = 0; i < numel; i++) { \ + size_t current_ix = base_ix * stride + i; \ + float _scale = scale_data[current_ix]; \ + int64_t zero_point = 0; \ + if (zero_point_data != nullptr) { \ + zero_point = zero_point_data[current_ix]; \ + } \ + out_data_ptr[current_ix] = \ + static_cast( \ + input_data_ptr[current_ix] - zero_point) * \ + _scale; \ + } \ + }, \ + input, \ + dim); \ + break; \ + } \ + for (size_t channel_ix = 0; channel_ix < input.size(*axis); \ + ++channel_ix) { \ + float _scale = scale_data[channel_ix]; \ + int64_t _zero_point = 0; \ + if (zero_point_data != nullptr) { \ + _zero_point = zero_point_data[channel_ix]; \ + } \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + torch::executor::apply_over_dim_list( \ + [input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \ + out_data_ptr[in_ix] = static_cast( \ + (input_data_ptr[in_ix] - _zero_point) * _scale); \ + }, \ + input, \ + optional_dim_list, \ + channel_ix); \ + } \ + break; +#define SYM_CALCULATE_INT_TYPE_CHANNEL(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, SYM_DEQUANTIZE_IMPL_CHANNEL); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + switch (input.scalar_type()) { + ET_FORALL_INT_TYPES(SYM_CALCULATE_INT_TYPE_CHANNEL); + SYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, Bits16); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); } +#undef SYM_DEQUANTIZE_IMPL_CHANNEL +#undef SYM_CALCULATE_INT_TYPE_CHANNEL + } } + } } /** @@ -511,56 +490,50 @@ Tensor& dequantize_per_tensor_out( int64_t quant_max, ScalarType dtype, exec_aten::optional out_dtype, - Tensor& out) -{ - torch::executor::Error err = resize_tensor(out, input.sizes()); - ET_CHECK_MSG( + Tensor& out) { + torch::executor::Error err = resize_tensor(out, input.sizes()); + ET_CHECK_MSG( err == torch::executor::Error::Ok, "Failed to resize out Tensor in dequantize_per_tensor_out"); - check_dequantize_per_tensor_args( + check_dequantize_per_tensor_args( input, quant_min, quant_max, dtype, out_dtype, out); - - float scale_data = (float)scale; - int zero_point_data = (int)zero_point; - - Dequantize_impl(out, - input, - &scale_data, - &zero_point_data, - NULL, - out_dtype); - - return out; + + float scale_data = (float)scale; + int zero_point_data = (int)zero_point; + + Dequantize_impl(out, input, &scale_data, &zero_point_data, NULL, out_dtype); + + return out; } -Tensor& dequantize_per_tensor_tensor_args_out(const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - exec_aten::optional out_dtype, - Tensor& out) -{ - ET_CHECK_MSG( +Tensor& dequantize_per_tensor_tensor_args_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t 
quant_min, + int64_t quant_max, + ScalarType dtype, + exec_aten::optional out_dtype, + Tensor& out) { + ET_CHECK_MSG( scale.scalar_type() == ScalarType::Double, "Expected scale to be Double tensor received: %" PRId8, static_cast(scale.scalar_type())); - ET_CHECK_MSG( + ET_CHECK_MSG( zero_point.scalar_type() == ScalarType::Long, "Expected scale to be Long tensor received: %" PRId8, static_cast(zero_point.scalar_type())); - ET_CHECK_MSG( + ET_CHECK_MSG( scale.numel() == 1, "Exepcted scale to only have one element received: %zd", ssize_t(scale.numel())); - ET_CHECK_MSG( + ET_CHECK_MSG( zero_point.numel() == 1, "Exepcted zero_point to only have one element received: %zd", ssize_t(zero_point.numel())); - dequantize_per_tensor_out( + dequantize_per_tensor_out( input, scale.const_data_ptr()[0], zero_point.const_data_ptr()[0], @@ -570,49 +543,48 @@ Tensor& dequantize_per_tensor_tensor_args_out(const Tensor& input, out_dtype, out); - return out; + return out; } -Tensor& dequantize_per_channel_out(const Tensor& input, - const Tensor& scale, - const exec_aten::optional& opt_zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - exec_aten::optional out_dtype, - Tensor& out) -{ - torch::executor::Error err = resize_tensor(out, input.sizes()); - - // normalize axis - ET_CHECK_MSG( +Tensor& dequantize_per_channel_out( + const Tensor& input, + const Tensor& scale, + const exec_aten::optional& opt_zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + exec_aten::optional out_dtype, + Tensor& out) { + torch::executor::Error err = resize_tensor(out, input.sizes()); + + // normalize axis + ET_CHECK_MSG( executorch::runtime::tensor_has_dim(input, axis), "axis %zd is not legal it should be -input.dim() <= axis < input.dim() %zd", ssize_t(axis), ssize_t(input.dim())); - if (axis < 0) - { - axis += executorch::runtime::nonzero_dim(input); - } + if (axis < 0) { + axis += executorch::runtime::nonzero_dim(input); + } - ET_CHECK_MSG( + ET_CHECK_MSG( err == torch::executor::Error::Ok, "Failed to resize out Tensor in dequantize_per_channel_out"); - ET_CHECK_MSG( + ET_CHECK_MSG( scale.scalar_type() == ScalarType::Double, "scale.scalar_type() %" PRId8 " is not double type", static_cast(scale.scalar_type())); - ET_CHECK_MSG( + ET_CHECK_MSG( scale.numel() == input.size(axis), "scale.numel() %zd != input.size(axis) %zd", ssize_t(scale.numel()), ssize_t(input.size(axis))); - if (opt_zero_points.has_value()) { + if (opt_zero_points.has_value()) { auto zero_point = opt_zero_points.value(); ET_CHECK_MSG( zero_point.scalar_type() == ScalarType::Long, @@ -624,41 +596,31 @@ Tensor& dequantize_per_channel_out(const Tensor& input, "zero_point.numel() %zd != input.size(axis) %zd", ssize_t(zero_point.numel()), ssize_t(input.size(axis))); - } + } - check_dequantize_per_tensor_args( + check_dequantize_per_tensor_args( input, quant_min, quant_max, dtype, out_dtype, out); - - int *axis_ptr = (int *)&axis; - - const double* scale_dt = scale.const_data_ptr(); - const int64_t* zero_point_dt; - int zero_point_data[input.size(axis)]; - int *zero_point_ptr; - if (opt_zero_points.has_value()) - { - zero_point_dt = opt_zero_points.value().const_data_ptr(); - zero_point_ptr = &zero_point_data[0]; - for(int i = 0; i < scale.numel(); i++) - { - zero_point_ptr[i] = (int)zero_point_dt[i]; - } - } - else - { - zero_point_ptr = nullptr; - } - float scale_data[input.size(axis)]; - for(int i = 0; i < scale.numel(); i++) - { - scale_data[i] = (float)scale_dt[i]; + + int* axis_ptr = 
(int*)&axis; + + const double* scale_dt = scale.const_data_ptr(); + const int64_t* zero_point_dt; + int zero_point_data[input.size(axis)]; + int* zero_point_ptr; + if (opt_zero_points.has_value()) { + zero_point_dt = opt_zero_points.value().const_data_ptr(); + zero_point_ptr = &zero_point_data[0]; + for (int i = 0; i < scale.numel(); i++) { + zero_point_ptr[i] = (int)zero_point_dt[i]; } - Dequantize_impl(out, - input, - scale_data, - zero_point_ptr, - axis_ptr, - out_dtype); + } else { + zero_point_ptr = nullptr; + } + float scale_data[input.size(axis)]; + for (int i = 0; i < scale.numel(); i++) { + scale_data[i] = (float)scale_dt[i]; + } + Dequantize_impl(out, input, scale_data, zero_point_ptr, axis_ptr, out_dtype); return out; } @@ -673,14 +635,13 @@ Tensor& dequantize_per_channel_out( int64_t quant_max, ScalarType dtype, exec_aten::optional out_dtype, - Tensor& out) -{ - (void)context; - torch::executor::Error err = resize_tensor(out, input.sizes()); - ET_CHECK_MSG( + Tensor& out) { + (void)context; + torch::executor::Error err = resize_tensor(out, input.sizes()); + ET_CHECK_MSG( err == torch::executor::Error::Ok, "Failed to resize out Tensor in dequantize_per_channel_out"); - return dequantize_per_channel_out( + return dequantize_per_channel_out( input, scale, opt_zero_points, @@ -701,12 +662,11 @@ Tensor& dequantize_per_tensor_out( int64_t quant_max, ScalarType dtype, exec_aten::optional out_dtype, - Tensor& out) -{ - // TODO(larryliu): Add a context arg to the real op function and remove this - // wrapper - (void)context; - return dequantize_per_tensor_out( + Tensor& out) { + // TODO(larryliu): Add a context arg to the real op function and remove this + // wrapper + (void)context; + return dequantize_per_tensor_out( input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); } @@ -719,12 +679,11 @@ Tensor& dequantize_per_tensor_tensor_args_out( int64_t quant_max, ScalarType dtype, exec_aten::optional out_dtype, - Tensor& out) -{ - // TODO(larryliu): Add a context arg to the real op function and remove this - // wrapper - (void)context; - return dequantize_per_tensor_tensor_args_out( + Tensor& out) { + // TODO(larryliu): Add a context arg to the real op function and remove this + // wrapper + (void)context; + return dequantize_per_tensor_tensor_args_out( input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); } @@ -736,47 +695,46 @@ Tensor& dequantize_per_token_out( int64_t quant_max, ScalarType dtype, ScalarType out_dtype, - Tensor& out) -{ - // Refactor this into a util - size_t num_channels = 1; - for (size_t i = 0; i < input.dim() - 1; i++) - { - num_channels *= input.size(i); - } - // This unfortunate change is needed because we compile op_quantize for aten - // mode as well - std::array input_sizes; - input_sizes[0] = static_cast(num_channels); - input_sizes[1] = + Tensor& out) { + // Refactor this into a util + size_t num_channels = 1; + for (size_t i = 0; i < input.dim() - 1; i++) { + num_channels *= input.size(i); + } + // This unfortunate change is needed because we compile op_quantize for aten + // mode as well + std::array input_sizes; + input_sizes[0] = static_cast(num_channels); + input_sizes[1] = static_cast(input.size(input.dim() - 1)); #ifdef USE_ATEN_LIB - Tensor reshaped_input = at::from_blob( + Tensor reshaped_input = at::from_blob( input.mutable_data_ptr(), input_sizes, at::TensorOptions(input.scalar_type())); #else - std::array input_dim_order{0, 1}; - std::array input_strides; - executorch::runtime::dim_order_to_stride_nocheck( + 
std::array input_dim_order{0, 1}; + std::array input_strides; + executorch::runtime::dim_order_to_stride_nocheck( input_sizes.data(), input_dim_order.data(), 2, input_strides.data()); - void* input_data = input.mutable_data_ptr(); - torch::executor::TensorImpl reshaped_input_impl = executorch::runtime::etensor::TensorImpl( - input.scalar_type(), - 2, - input_sizes.data(), - input_data, - input_dim_order.data(), - input_strides.data(), - executorch::runtime::TensorShapeDynamism::STATIC); - Tensor reshaped_input(&reshaped_input_impl); - torch::executor::Error err = resize_tensor(out, input.sizes()); - ET_CHECK_MSG( + void* input_data = input.mutable_data_ptr(); + torch::executor::TensorImpl reshaped_input_impl = + executorch::runtime::etensor::TensorImpl( + input.scalar_type(), + 2, + input_sizes.data(), + input_data, + input_dim_order.data(), + input_strides.data(), + executorch::runtime::TensorShapeDynamism::STATIC); + Tensor reshaped_input(&reshaped_input_impl); + torch::executor::Error err = resize_tensor(out, input.sizes()); + ET_CHECK_MSG( err == torch::executor::Error::Ok, "Failed to resize out Tensor in dequantize_per_channel_out"); #endif - return dequantize_per_channel_out( + return dequantize_per_channel_out( reshaped_input, scale, zero_points, @@ -797,8 +755,7 @@ Tensor& dequantize_per_token_out( int64_t quant_max, ScalarType dtype, ScalarType out_dtype, - Tensor& out) -{ + Tensor& out) { (void)context; return dequantize_per_token_out( input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out); diff --git a/backends/cadence/fusion_g3/operators/op_quantize.cpp b/backends/cadence/fusion_g3/operators/op_quantize.cpp index bc84829edb..2b8376dc8d 100644 --- a/backends/cadence/fusion_g3/operators/op_quantize.cpp +++ b/backends/cadence/fusion_g3/operators/op_quantize.cpp @@ -8,10 +8,10 @@ #include #include +#include #include #include #include -#include using exec_aten::Scalar; using exec_aten::ScalarType; @@ -21,14 +21,10 @@ using torch::executor::KernelRuntimeContext; /* ScalarType in Executorch do not have support for below data types. * So, creating a placeholder for these data types. Once, ScalarTypes is - * updated to have support for below data types, these can be removed and + * updated to have support for below data types, these can be removed and * operator need to be updated accordingly */ - enum datatype { - Ushort = 20, - Bits4u = 21, - Bits4 = 22 - }; +enum datatype { Ushort = 20, Bits4u = 21, Bits4 = 22 }; /** * For an input tensor, use the scale and zero_point arguments to quantize it. @@ -38,102 +34,84 @@ namespace impl { namespace FusionG3 { namespace native { - namespace { /** * Asserts that the parameters are valid. 
*/ -void check_quantize_per_tensor_args(const Tensor& input, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) -{ - // Ensure self and out has the same shape - ET_CHECK_MSG( +void check_quantize_per_tensor_args( + const Tensor& input, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + // Ensure self and out has the same shape + ET_CHECK_MSG( torch::executor::isFloatingType(input.scalar_type()), "input.scalar_type() %" PRId8 " is not floating type", static_cast(input.scalar_type())); - int32_t quant_min_lower_bound = 0, quant_max_upper_bound = 0; - ScalarType out_dtype = out.scalar_type(); - ET_CHECK_MSG( + int32_t quant_min_lower_bound = 0, quant_max_upper_bound = 0; + ScalarType out_dtype = out.scalar_type(); + ET_CHECK_MSG( out_dtype == dtype, "out.scalar_type() %" PRId8 " is not matching dtype argument %" PRId8, static_cast(out_dtype), static_cast(dtype)); - if (out_dtype == ScalarType::Byte) - { - quant_min_lower_bound = - static_cast(std::numeric_limits::min()); - quant_max_upper_bound = - static_cast(std::numeric_limits::max()); - } - else if (dtype == ScalarType::Char) - { - quant_min_lower_bound = - static_cast(std::numeric_limits::min()); - quant_max_upper_bound = - static_cast(std::numeric_limits::max()); - } - else if (dtype == ScalarType::Bits16) - { - quant_min_lower_bound = std::numeric_limits::min(); - quant_max_upper_bound = std::numeric_limits::max(); - } - else if (dtype == ScalarType::Short) - { - quant_min_lower_bound = std::numeric_limits::min(); - quant_max_upper_bound = std::numeric_limits::max(); - } - else if (dtype == (ScalarType)Ushort) - { - quant_min_lower_bound = std::numeric_limits::min(); - quant_max_upper_bound = std::numeric_limits::max(); - } - else if (dtype == (ScalarType)Bits4u) - { - quant_min_lower_bound = std::numeric_limits::min(); - quant_max_upper_bound = std::numeric_limits::max(); - /* Minimum and maximum values fo unsigned 4-bit data type */ - quant_min_lower_bound = quant_min_lower_bound >> 4; - quant_max_upper_bound = quant_max_upper_bound >> 4; - } - else if (dtype == (ScalarType)Bits4) - { - quant_min_lower_bound = std::numeric_limits::min(); - quant_max_upper_bound = std::numeric_limits::max(); - /* Minimum and maximum values fo signed 4-bit data type */ - quant_min_lower_bound = quant_min_lower_bound >> 4; - quant_max_upper_bound = quant_max_upper_bound >> 4; - } - else if (dtype == ScalarType::Int) - { - quant_min_lower_bound = std::numeric_limits::min(); - quant_max_upper_bound = std::numeric_limits::max(); - } - else - { - ET_CHECK_MSG( - false, "Unsupported dtype: %" PRId8, static_cast(out_dtype)); - } - + if (out_dtype == ScalarType::Byte) { + quant_min_lower_bound = + static_cast(std::numeric_limits::min()); + quant_max_upper_bound = + static_cast(std::numeric_limits::max()); + } else if (dtype == ScalarType::Char) { + quant_min_lower_bound = + static_cast(std::numeric_limits::min()); + quant_max_upper_bound = + static_cast(std::numeric_limits::max()); + } else if (dtype == ScalarType::Bits16) { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + } else if (dtype == ScalarType::Short) { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + } else if (dtype == (ScalarType)Ushort) { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + } else if (dtype == (ScalarType)Bits4u) { + quant_min_lower_bound = 
std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + /* Minimum and maximum values fo unsigned 4-bit data type */ + quant_min_lower_bound = quant_min_lower_bound >> 4; + quant_max_upper_bound = quant_max_upper_bound >> 4; + } else if (dtype == (ScalarType)Bits4) { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + /* Minimum and maximum values fo signed 4-bit data type */ + quant_min_lower_bound = quant_min_lower_bound >> 4; + quant_max_upper_bound = quant_max_upper_bound >> 4; + } else if (dtype == ScalarType::Int) { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + } else { ET_CHECK_MSG( + false, "Unsupported dtype: %" PRId8, static_cast(out_dtype)); + } + + ET_CHECK_MSG( quant_min >= quant_min_lower_bound, "quant_min out of bound for dtype, expected quant_min_lower_bound: %" PRId32 " actual quant_min: %" PRId64, quant_min_lower_bound, quant_min); - ET_CHECK_MSG( + ET_CHECK_MSG( quant_max <= quant_max_upper_bound, "quant_max out of bound for dtype, expected quant_max_upper_bound: %" PRId32 " actual quant_max: %" PRId64, quant_max_upper_bound, quant_max); -}/* check_quantize_per_tensor_args */ +} /* check_quantize_per_tensor_args */ } // namespace @@ -143,8 +121,7 @@ T quantize_val( int64_t zero_point, K value, int64_t quant_min, - int64_t quant_max) -{ + int64_t quant_max) { int64_t qvalue; float inv_scale = 1.0f / static_cast(scale); qvalue = static_cast( @@ -156,458 +133,495 @@ T quantize_val( return static_cast(qvalue); } - /* Local function which calls the kernels based on the output datatype */ -void quantize_impl(Tensor& out, - const Tensor& input, - float *scale_data, - int *zero_point_data, - int *axis, - int quant_min, - int quant_max) -{ - const exec_aten::ArrayRef input_size = input.sizes(); +void quantize_impl( + Tensor& out, + const Tensor& input, + float* scale_data, + int* zero_point_data, + int* axis, + int quant_min, + int quant_max) { + const exec_aten::ArrayRef input_size = input.sizes(); - int kTensorDimensionLimit = 5; + int kTensorDimensionLimit = 5; - int inp_shape[kTensorDimensionLimit]; + int inp_shape[kTensorDimensionLimit]; - for(auto i = 0; i < input_size.size(); i++) - { - inp_shape[i] = input_size[i]; - } - - const float* input_data = input.const_data_ptr(); + for (auto i = 0; i < input_size.size(); i++) { + inp_shape[i] = input_size[i]; + } - bool is_asym_quant = 0; + const float* input_data = input.const_data_ptr(); - if(zero_point_data != NULL) //asymmetric quant + bool is_asym_quant = 0; + + if (zero_point_data != NULL) // asymmetric quant + { + if (axis != NULL) // channel { - if(axis != NULL) //channel - { - for(int i = 0; i < input.size(*axis) ; i++) - { - if(zero_point_data[i] != 0) - { - is_asym_quant |= 1; - } + for (int i = 0; i < input.size(*axis); i++) { + if (zero_point_data[i] != 0) { + is_asym_quant |= 1; } } - else + } else { + if (*zero_point_data != 0) // tensor { - if(*zero_point_data != 0) //tensor - { - is_asym_quant |= 1; - } + is_asym_quant |= 1; } } - - if(is_asym_quant) - { - if (out.scalar_type() == ScalarType::Byte) - { - uint8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_asym8u( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, zero_point_data, quant_min, quant_max); + } + + if (is_asym_quant) { + if (out.scalar_type() == ScalarType::Byte) { + uint8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym8u( + out_data, + input_data, + 
inp_shape, + input.dim(), + axis, + scale_data, + zero_point_data, + quant_min, + quant_max); + } else if (out.scalar_type() == ScalarType::Char) { + int8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym8( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + zero_point_data, + quant_min, + quant_max); + } else if (out.scalar_type() == (ScalarType)Ushort) { + uint16_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym16u( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + zero_point_data, + quant_min, + quant_max); + } else if (out.scalar_type() == ScalarType::Short) { + int16_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym16( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + zero_point_data, + quant_min, + quant_max); + } else if (out.scalar_type() == (ScalarType)Bits4u) { + uint8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym4u( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + zero_point_data, + quant_min, + quant_max); + } else if (out.scalar_type() == (ScalarType)Bits4) { + int8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym4( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + zero_point_data, + quant_min, + quant_max); + } else { + if (axis == NULL) { + // Vector quantization +// calculate the quantized input +#define ASYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ + case ScalarType::out_dtype: { \ + /* Hoist these function calls out of our inner loop because they might not \ + * get inlined without LTO, particularly in ATen mode. */ \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + const auto input_numel = input.numel(); \ + for (size_t i = 0; i < input_numel; i++) { \ + IN_CTYPE value = input_data_ptr[i]; \ + out_data_ptr[i] = quantize_val( \ + (double)*scale_data, \ + (int64_t) * zero_point_data, \ + value, \ + (int64_t)quant_min, \ + (int64_t)quant_max); \ + } \ + } break; +#define ASYM_CALCULATE_FLOAT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_INT_TYPES_WITH(IN_CTYPE, ASYM_QUANTIZE_IMPL_TENSOR); \ + ASYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, uint16_t, Bits16) \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + + switch (input.scalar_type()) { + ET_FORALL_FLOAT_TYPES(ASYM_CALCULATE_FLOAT_TYPE_TENSOR); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); } - else if (out.scalar_type() == ScalarType::Char) - { - int8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_asym8( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, zero_point_data, quant_min, quant_max); - } - else if (out.scalar_type() == (ScalarType)Ushort) - { - uint16_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_asym16u( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, zero_point_data, quant_min, quant_max); - } - else if (out.scalar_type() == ScalarType::Short) - { - int16_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_asym16( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, zero_point_data, quant_min, quant_max); - } - else if (out.scalar_type() == (ScalarType)Bits4u) - { - uint8_t* out_data = out.mutable_data_ptr(); - 
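The ASYM_QUANTIZE_IMPL_TENSOR / ASYM_CALCULATE_FLOAT_TYPE_TENSOR pair above expands, per input/output dtype combination, into a flat loop over every element with a single scale and zero_point. A rough macro-free equivalent for the float to uint8 case (standalone sketch, assumed names):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // One scale / zero_point pair applied to every element of the tensor.
    std::vector<uint8_t> quantize_per_tensor_f32_u8(
        const std::vector<float>& input,
        float scale,
        int32_t zero_point,
        int32_t quant_min,
        int32_t quant_max) {
      std::vector<uint8_t> out(input.size());
      const float inv_scale = 1.0f / scale;
      for (size_t i = 0; i < input.size(); ++i) {
        int32_t q = zero_point +
            static_cast<int32_t>(std::nearbyint(inv_scale * input[i]));
        q = std::min(std::max(q, quant_min), quant_max);
        out[i] = static_cast<uint8_t>(q);
      }
      return out;
    }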
xa_nn_elm_quantize_f32_asym4u( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, zero_point_data, quant_min, quant_max); - } - else if (out.scalar_type() == (ScalarType)Bits4) - { - int8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_asym4( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, zero_point_data, quant_min, quant_max); + + } else { + // Channel based quantization + // a list contains all dimensions except axis + int64_t dims[input.dim() - 1]; + for (int64_t i = 0; i < input.dim() - 1; i++) { + if (i < *axis) { + dims[i] = i; + } else { + dims[i] = i + 1; + } } - else - { - if(axis == NULL) - { - // Vector quantization - // calculate the quantized input - #define ASYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ - case ScalarType::out_dtype: { \ - /* Hoist these function calls out of our inner loop because they might not \ - * get inlined without LTO, particularly in ATen mode. */ \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - const auto input_numel = input.numel(); \ - for (size_t i = 0; i < input_numel; i++) { \ - IN_CTYPE value = input_data_ptr[i]; \ - out_data_ptr[i] = quantize_val( \ - (double)*scale_data, (int64_t)*zero_point_data, value, \ - (int64_t)quant_min, (int64_t)quant_max); \ - } \ - } break; - #define ASYM_CALCULATE_FLOAT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_INT_TYPES_WITH(IN_CTYPE, ASYM_QUANTIZE_IMPL_TENSOR); \ - ASYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, uint16_t, Bits16) \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast(out.scalar_type())); \ - } \ - break; - - switch (input.scalar_type()) { - ET_FORALL_FLOAT_TYPES(ASYM_CALCULATE_FLOAT_TYPE_TENSOR); - default: - ET_CHECK_MSG( - false, - "Unhandled input dtype %" PRId8, - static_cast(input.scalar_type())); - } - - } - else - { - // Channel based quantization - // a list contains all dimensions except axis - int64_t dims[input.dim() - 1]; - for (int64_t i = 0; i < input.dim() - 1; i++) - { - if (i < *axis) - { - dims[i] = i; - } - else - { - dims[i] = i + 1; - } - } - - exec_aten::optional> optional_dim_list{ - exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; - - // Actual quantization logic - // input, out are the input and output tensors - // channel_ix is the index along the axis dimension. 0 <= channel_ix < - // input.size(axis). - // i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix - // will be 0, 1, 2, ... C-1 - // in_ix is the flat index of the element you are quantizing. 
- // in other words you are quantizing in_data[in_ix] - #define ASYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ - case ScalarType::out_dtype: \ - for (size_t channel_ix = 0; channel_ix < input.size(*axis); ++channel_ix) { \ - double _scale = (double)scale_data[channel_ix]; \ - int64_t _zero_point = (int64_t)zero_point_data[channel_ix]; \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - torch::executor::apply_over_dim_list( \ - [input_data_ptr, \ - out_data_ptr, \ - _scale, \ - _zero_point, \ - quant_min, \ - quant_max](size_t in_ix) { \ - out_data_ptr[in_ix] = quantize_val( \ - _scale, \ - _zero_point, \ - input_data_ptr[in_ix], \ - quant_min, \ - quant_max); \ - }, \ - input, \ - optional_dim_list, \ - channel_ix); \ - } \ - break; - #define ASYM_CALCULATE_FLOAT_TYPE_CHANNEL(CTYPE_IN, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_INT_TYPES_WITH(CTYPE_IN, ASYM_QUANTIZE_IMPL_CHANNEL); \ - ASYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, uint16_t, Bits16) \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast(out.scalar_type())); \ - } \ - break; - - switch (input.scalar_type()) { - ET_FORALL_FLOAT_TYPES(ASYM_CALCULATE_FLOAT_TYPE_CHANNEL); - default: - ET_CHECK_MSG( - false, - "Unhandled input dtype %" PRId8, - static_cast(input.scalar_type())); - } - } - - #undef ASYM_CALCULATE_FLOAT_TYPE_TENSOR - #undef ASYM_CALCULATE_FLOAT_TYPE_CHANNEL - #undef ASYM_ASYM_QUANTIZE_IMPL_CHANNEL_TENSOR - #undef ASYM_ASYM_QUANTIZE_IMPL_CHANNEL_CHANNEL + + exec_aten::optional> optional_dim_list{ + exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; + +// Actual quantization logic +// input, out are the input and output tensors +// channel_ix is the index along the axis dimension. 0 <= channel_ix < +// input.size(axis). +// i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix +// will be 0, 1, 2, ... C-1 +// in_ix is the flat index of the element you are quantizing. 
+// in other words you are quantizing in_data[in_ix] +#define ASYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ + case ScalarType::out_dtype: \ + for (size_t channel_ix = 0; channel_ix < input.size(*axis); \ + ++channel_ix) { \ + double _scale = (double)scale_data[channel_ix]; \ + int64_t _zero_point = (int64_t)zero_point_data[channel_ix]; \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + torch::executor::apply_over_dim_list( \ + [input_data_ptr, \ + out_data_ptr, \ + _scale, \ + _zero_point, \ + quant_min, \ + quant_max](size_t in_ix) { \ + out_data_ptr[in_ix] = quantize_val( \ + _scale, \ + _zero_point, \ + input_data_ptr[in_ix], \ + quant_min, \ + quant_max); \ + }, \ + input, \ + optional_dim_list, \ + channel_ix); \ + } \ + break; +#define ASYM_CALCULATE_FLOAT_TYPE_CHANNEL(CTYPE_IN, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_INT_TYPES_WITH(CTYPE_IN, ASYM_QUANTIZE_IMPL_CHANNEL); \ + ASYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, uint16_t, Bits16) \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + + switch (input.scalar_type()) { + ET_FORALL_FLOAT_TYPES(ASYM_CALCULATE_FLOAT_TYPE_CHANNEL); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); } + } + +#undef ASYM_CALCULATE_FLOAT_TYPE_TENSOR +#undef ASYM_CALCULATE_FLOAT_TYPE_CHANNEL +#undef ASYM_ASYM_QUANTIZE_IMPL_CHANNEL_TENSOR +#undef ASYM_ASYM_QUANTIZE_IMPL_CHANNEL_CHANNEL } - else - { - if (out.scalar_type() == ScalarType::Byte) - { - uint8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_sym8u( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, quant_min, quant_max); - } - else if (out.scalar_type() == ScalarType::Char) - { - int8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_sym8( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, quant_min, quant_max); - } - else if (out.scalar_type() == (ScalarType) Ushort) - { - uint16_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_sym16u( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, quant_min, quant_max); + } else { + if (out.scalar_type() == ScalarType::Byte) { + uint8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym8u( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + quant_min, + quant_max); + } else if (out.scalar_type() == ScalarType::Char) { + int8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym8( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + quant_min, + quant_max); + } else if (out.scalar_type() == (ScalarType)Ushort) { + uint16_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym16u( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + quant_min, + quant_max); + } else if (out.scalar_type() == ScalarType::Short) { + int16_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym16( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + quant_min, + quant_max); + } else if (out.scalar_type() == (ScalarType)Bits4u) { + uint8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym4u( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + quant_min, + quant_max); + } else if (out.scalar_type() == (ScalarType)Bits4) { + int8_t* out_data = 
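Per-channel quantization, which ASYM_QUANTIZE_IMPL_CHANNEL drives through apply_over_dim_list, amounts to quantizing every slice along the chosen axis with that channel's own scale and zero_point. A simplified reference for a contiguous (rows, channels) buffer with the channel axis innermost (standalone sketch, assumed names):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <cstdint>

    // Each channel c uses scales[c] / zero_points[c]; elements are addressed
    // as input[r * channels + c] for a contiguous (rows, channels) layout.
    void quantize_per_channel_ref(
        const float* input,
        int8_t* out,
        size_t rows,
        size_t channels,
        const float* scales,
        const int32_t* zero_points,
        int32_t quant_min,
        int32_t quant_max) {
      for (size_t r = 0; r < rows; ++r) {
        for (size_t c = 0; c < channels; ++c) {
          const float inv_scale = 1.0f / scales[c];
          int32_t q = zero_points[c] +
              static_cast<int32_t>(
                  std::nearbyint(inv_scale * input[r * channels + c]));
          q = std::min(std::max(q, quant_min), quant_max);
          out[r * channels + c] = static_cast<int8_t>(q);
        }
      }
    }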
out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym4( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + quant_min, + quant_max); + } else { + if (axis == NULL) { + // calculate the quantized input +#define SYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ + case ScalarType::out_dtype: { \ + /* Hoist these function calls out of our inner loop because they might not \ + * get inlined without LTO, particularly in ATen mode. */ \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + const auto input_numel = input.numel(); \ + for (size_t i = 0; i < input_numel; i++) { \ + IN_CTYPE value = input_data_ptr[i]; \ + out_data_ptr[i] = quantize_val( \ + (double)*scale_data, \ + (int64_t) * zero_point_data, \ + value, \ + (int64_t)quant_min, \ + (int64_t)quant_max); \ + } \ + } break; +#define SYM_CALCULATE_FLOAT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_INT_TYPES_WITH(IN_CTYPE, SYM_QUANTIZE_IMPL_TENSOR); \ + SYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, uint16_t, Bits16) \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + + switch (input.scalar_type()) { + ET_FORALL_FLOAT_TYPES(SYM_CALCULATE_FLOAT_TYPE_TENSOR); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); } - else if (out.scalar_type() == ScalarType::Short) - { - int16_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_sym16( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, quant_min, quant_max); - } - else if (out.scalar_type() == (ScalarType) Bits4u) - { - uint8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_sym4u( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, quant_min, quant_max); - } - else if (out.scalar_type() == (ScalarType) Bits4) - { - int8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_sym4( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, quant_min, quant_max); + + } else { + // a list contains all dimensions except axis + int64_t dims[input.dim() - 1]; + for (int64_t i = 0; i < input.dim() - 1; i++) { + if (i < *axis) { + dims[i] = i; + } else { + dims[i] = i + 1; + } } - else - { - if(axis == NULL) - { - // calculate the quantized input - #define SYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ - case ScalarType::out_dtype: { \ - /* Hoist these function calls out of our inner loop because they might not \ - * get inlined without LTO, particularly in ATen mode. 
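When every zero_point is zero the kernel stays on the symmetric branch above, where the mapping has no offset at all. A minimal sketch of that special case for a 16-bit output (assumed name, not the xa_nn kernel):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Symmetric quantization is the zero_point == 0 case of the affine map:
    // q = clamp(round(x / scale), quant_min, quant_max).
    inline int16_t quantize_symmetric_i16(float x, float scale) {
      const int32_t q = static_cast<int32_t>(std::nearbyint(x / scale));
      return static_cast<int16_t>(std::min(std::max(q, -32768), 32767));
    }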
*/ \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - const auto input_numel = input.numel(); \ - for (size_t i = 0; i < input_numel; i++) { \ - IN_CTYPE value = input_data_ptr[i]; \ - out_data_ptr[i] = quantize_val( \ - (double)*scale_data, (int64_t)*zero_point_data, value, \ - (int64_t)quant_min, (int64_t)quant_max); \ - } \ - } break; - #define SYM_CALCULATE_FLOAT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_INT_TYPES_WITH(IN_CTYPE, SYM_QUANTIZE_IMPL_TENSOR); \ - SYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, uint16_t, Bits16) \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast(out.scalar_type())); \ - } \ - break; - - switch (input.scalar_type()) { - ET_FORALL_FLOAT_TYPES(SYM_CALCULATE_FLOAT_TYPE_TENSOR); - default: - ET_CHECK_MSG( - false, - "Unhandled input dtype %" PRId8, - static_cast(input.scalar_type())); - } - - } - else - { - // a list contains all dimensions except axis - int64_t dims[input.dim() - 1]; - for (int64_t i = 0; i < input.dim() - 1; i++) - { - if (i < *axis) - { - dims[i] = i; - } - else - { - dims[i] = i + 1; - } - } - - exec_aten::optional> optional_dim_list{ - exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; - - // Actual quantization logic - // input, out are the input and output tensors - // channel_ix is the index along the axis dimension. 0 <= channel_ix < - // input.size(axis). - // i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix - // will be 0, 1, 2, ... C-1 - // in_ix is the flat index of the element you are quantizing. - // in other words you are quantizing in_data[in_ix] - #define SYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ - case ScalarType::out_dtype: \ - for (size_t channel_ix = 0; channel_ix < input.size(*axis); ++channel_ix) { \ - double _scale = (double)scale_data[channel_ix]; \ - int64_t _zero_point = (int64_t)zero_point_data[channel_ix]; \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - torch::executor::apply_over_dim_list( \ - [input_data_ptr, \ - out_data_ptr, \ - _scale, \ - _zero_point, \ - quant_min, \ - quant_max](size_t in_ix) { \ - out_data_ptr[in_ix] = quantize_val( \ - _scale, \ - _zero_point, \ - input_data_ptr[in_ix], \ - quant_min, \ - quant_max); \ - }, \ - input, \ - optional_dim_list, \ - channel_ix); \ - } \ - break; - #define SYM_CALCULATE_FLOAT_TYPE_CHANNEL(CTYPE_IN, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_INT_TYPES_WITH(CTYPE_IN, SYM_QUANTIZE_IMPL_CHANNEL); \ - SYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, uint16_t, Bits16) \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast(out.scalar_type())); \ - } \ - break; - - switch (input.scalar_type()) { - ET_FORALL_FLOAT_TYPES(SYM_CALCULATE_FLOAT_TYPE_CHANNEL); - default: - ET_CHECK_MSG( - false, - "Unhandled input dtype %" PRId8, - static_cast(input.scalar_type())); - } - } - #undef SYM_CALCULATE_FLOAT_TYPE_TENSOR - #undef SYM_CALCULATE_FLOAT_TYPE_CHANNEL - #undef SYM_ASYM_QUANTIZE_IMPL_CHANNEL_TENSOR - #undef SYM_ASYM_QUANTIZE_IMPL_CHANNEL_CHANNEL + + exec_aten::optional> optional_dim_list{ + exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; + +// Actual quantization logic +// input, out are the input and output tensors +// channel_ix is the index along the axis dimension. 0 <= channel_ix < +// input.size(axis). +// i.e. 
if the tensor has shape (N,C,H,W), axis being 1, then channel_ix +// will be 0, 1, 2, ... C-1 +// in_ix is the flat index of the element you are quantizing. +// in other words you are quantizing in_data[in_ix] +#define SYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ + case ScalarType::out_dtype: \ + for (size_t channel_ix = 0; channel_ix < input.size(*axis); \ + ++channel_ix) { \ + double _scale = (double)scale_data[channel_ix]; \ + int64_t _zero_point = (int64_t)zero_point_data[channel_ix]; \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + torch::executor::apply_over_dim_list( \ + [input_data_ptr, \ + out_data_ptr, \ + _scale, \ + _zero_point, \ + quant_min, \ + quant_max](size_t in_ix) { \ + out_data_ptr[in_ix] = quantize_val( \ + _scale, \ + _zero_point, \ + input_data_ptr[in_ix], \ + quant_min, \ + quant_max); \ + }, \ + input, \ + optional_dim_list, \ + channel_ix); \ + } \ + break; +#define SYM_CALCULATE_FLOAT_TYPE_CHANNEL(CTYPE_IN, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_INT_TYPES_WITH(CTYPE_IN, SYM_QUANTIZE_IMPL_CHANNEL); \ + SYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, uint16_t, Bits16) \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + + switch (input.scalar_type()) { + ET_FORALL_FLOAT_TYPES(SYM_CALCULATE_FLOAT_TYPE_CHANNEL); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); } + } +#undef SYM_CALCULATE_FLOAT_TYPE_TENSOR +#undef SYM_CALCULATE_FLOAT_TYPE_CHANNEL +#undef SYM_ASYM_QUANTIZE_IMPL_CHANNEL_TENSOR +#undef SYM_ASYM_QUANTIZE_IMPL_CHANNEL_CHANNEL } + } } // Quantize the input tensor -Tensor& quantize_per_tensor_out(KernelRuntimeContext& context, - const Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) -{ - torch::executor::Error err = resize_tensor(out, input.sizes()); - ET_CHECK_MSG( +Tensor& quantize_per_tensor_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + torch::executor::Error err = resize_tensor(out, input.sizes()); + ET_CHECK_MSG( err == torch::executor::Error::Ok, "Failed to resize out Tensor in quantize_per_tensor_out"); - check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out); + check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out); - float scale_data = (float)scale; - int zero_point_data = (int)zero_point; - quantize_impl(out, - input, - &scale_data, - &zero_point_data, - NULL, - (int) quant_min, - (int) quant_max); + float scale_data = (float)scale; + int zero_point_data = (int)zero_point; + quantize_impl( + out, + input, + &scale_data, + &zero_point_data, + NULL, + (int)quant_min, + (int)quant_max); - return out; + return out; } - -Tensor& quantize_per_tensor_tensor_args_out(KernelRuntimeContext& context, - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) -{ - // Temporary change to allow not fatal failure for now to unblock some - // expected failure tests that are dying instead of failure. Will revisit - // after ET_KERNEL_CHECK is fully implemented and properly allows non fatal - // failures. 
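As a concrete illustration of the scalar entry point above: with scale = 0.1 and zero_point = 10 into a Char (int8) output, an input of 2.5f becomes round(25) + 10 = 35, comfortably inside [-128, 127]. The tensor-args overload only unwraps single-element Double and Long tensors into those same scalars; a minimal sketch of that unwrapping outside any ExecuTorch types (assumed names):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    struct ScalarQuantParams {
      double scale;
      int64_t zero_point;
    };

    // Mirror of the single-element checks: both containers must hold exactly
    // one value before being treated as plain scalars.
    ScalarQuantParams unwrap_quant_params(
        const std::vector<double>& scale,
        const std::vector<int64_t>& zero_point) {
      assert(scale.size() == 1 && zero_point.size() == 1);
      return {scale[0], zero_point[0]};
    }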
- if (scale.scalar_type() != ScalarType::Double) - { - context.fail(torch::executor::Error::InvalidArgument); - return out; - } - ET_CHECK_MSG( +Tensor& quantize_per_tensor_tensor_args_out( + KernelRuntimeContext& context, + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + // Temporary change to allow not fatal failure for now to unblock some + // expected failure tests that are dying instead of failure. Will revisit + // after ET_KERNEL_CHECK is fully implemented and properly allows non fatal + // failures. + if (scale.scalar_type() != ScalarType::Double) { + context.fail(torch::executor::Error::InvalidArgument); + return out; + } + ET_CHECK_MSG( scale.scalar_type() == ScalarType::Double, "Expected scale to be Double tensor received: %" PRId8, static_cast(scale.scalar_type())); - ET_CHECK_MSG( + ET_CHECK_MSG( zero_point.scalar_type() == ScalarType::Long, "Expected zero_point to be Long tensor received: %" PRId8, static_cast(zero_point.scalar_type())); - ET_CHECK_MSG( + ET_CHECK_MSG( scale.numel() == 1, "Exepcted scale to only have one element received: %zd", ssize_t(scale.numel())); - ET_CHECK_MSG( + ET_CHECK_MSG( zero_point.numel() == 1, "Exepcted zero_point to only have one element received: %zd", ssize_t(zero_point.numel())); - quantize_per_tensor_out(context, + quantize_per_tensor_out( + context, input, scale.const_data_ptr()[0], zero_point.const_data_ptr()[0], @@ -616,113 +630,111 @@ Tensor& quantize_per_tensor_tensor_args_out(KernelRuntimeContext& context, dtype, out); - return out; + return out; } -Tensor& quantize_per_tensor_tensor_args_out(const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) -{ - auto context = torch::executor::RuntimeContext(); - auto& res = quantize_per_tensor_tensor_args_out( +Tensor& quantize_per_tensor_tensor_args_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + auto context = torch::executor::RuntimeContext(); + auto& res = quantize_per_tensor_tensor_args_out( context, input, scale, zero_point, quant_min, quant_max, dtype, out); - ET_CHECK(context.failure_state() == Error::Ok); - return res; + ET_CHECK(context.failure_state() == Error::Ok); + return res; } -Tensor& quantize_per_channel_out(const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) -{ - torch::executor::Error err = resize_tensor(out, input.sizes()); - - // normalize axis - ET_CHECK_MSG( - executorch::runtime::tensor_has_dim(input, axis), - "axis %zd is not legal it should be -input.dim() <= axis < input.dim() %zd", - ssize_t(axis), - ssize_t(input.dim())); +Tensor& quantize_per_channel_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + torch::executor::Error err = resize_tensor(out, input.sizes()); - if (axis < 0) - { - axis += executorch::runtime::nonzero_dim(input); - } + // normalize axis + ET_CHECK_MSG( + executorch::runtime::tensor_has_dim(input, axis), + "axis %zd is not legal it should be -input.dim() <= axis < input.dim() %zd", + ssize_t(axis), + ssize_t(input.dim())); - ET_CHECK_MSG( - err == torch::executor::Error::Ok, - "Failed to resize out Tensor 
in quantize_per_channel_out"); + if (axis < 0) { + axis += executorch::runtime::nonzero_dim(input); + } - ET_CHECK_MSG( - scale.scalar_type() == ScalarType::Double, - "scale.scalar_type() %" PRId8 " is not double type", - static_cast(scale.scalar_type())); + ET_CHECK_MSG( + err == torch::executor::Error::Ok, + "Failed to resize out Tensor in quantize_per_channel_out"); - ET_CHECK_MSG( - scale.numel() == input.size(axis), - "scale.numel() %zd != input.size(axis) %zd", - scale.numel(), - input.size(axis)); + ET_CHECK_MSG( + scale.scalar_type() == ScalarType::Double, + "scale.scalar_type() %" PRId8 " is not double type", + static_cast(scale.scalar_type())); - ET_CHECK_MSG( - zero_point.scalar_type() == ScalarType::Long, - "zero_point.scalar_type() %" PRId8 " is not integer type", - static_cast(zero_point.scalar_type())); + ET_CHECK_MSG( + scale.numel() == input.size(axis), + "scale.numel() %zd != input.size(axis) %zd", + scale.numel(), + input.size(axis)); - ET_CHECK_MSG( - zero_point.numel() == input.size(axis), - "zero_point.numel() %zd != input.size(axis) %zd", - zero_point.numel(), - input.size(axis)); + ET_CHECK_MSG( + zero_point.scalar_type() == ScalarType::Long, + "zero_point.scalar_type() %" PRId8 " is not integer type", + static_cast(zero_point.scalar_type())); - check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out); + ET_CHECK_MSG( + zero_point.numel() == input.size(axis), + "zero_point.numel() %zd != input.size(axis) %zd", + zero_point.numel(), + input.size(axis)); + check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out); - const double* scale_dt = scale.const_data_ptr(); - const int64_t* zero_point_dt = zero_point.const_data_ptr(); - - float scale_data[input.size(axis)]; - int zero_point_data[input.size(axis)]; + const double* scale_dt = scale.const_data_ptr(); + const int64_t* zero_point_dt = zero_point.const_data_ptr(); - for(int i = 0; i < scale.numel(); i++) - { - scale_data[i] = (float)scale_dt[i]; - zero_point_data[i] = (int)zero_point_dt[i]; - } + float scale_data[input.size(axis)]; + int zero_point_data[input.size(axis)]; - int *axis_ptr = (int *)&axis; + for (int i = 0; i < scale.numel(); i++) { + scale_data[i] = (float)scale_dt[i]; + zero_point_data[i] = (int)zero_point_dt[i]; + } - quantize_impl(out, - input, - scale_data, - zero_point_data, - axis_ptr, - (int) quant_min, - (int) quant_max); + int* axis_ptr = (int*)&axis; - return out; + quantize_impl( + out, + input, + scale_data, + zero_point_data, + axis_ptr, + (int)quant_min, + (int)quant_max); + + return out; } -Tensor& quantize_per_channel_out(KernelRuntimeContext& context, - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) -{ - (void)context; - return quantize_per_channel_out( +Tensor& quantize_per_channel_out( + KernelRuntimeContext& context, + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + (void)context; + return quantize_per_channel_out( input, scale, zero_point, axis, quant_min, quant_max, dtype, out); } @@ -733,46 +745,45 @@ Tensor& quantize_per_token_out( int64_t quant_min, int64_t quant_max, ScalarType dtype, - Tensor& out) -{ - size_t num_tokens = 1; - for (size_t i = 0; i < input.dim() - 1; i++) - { - num_tokens *= input.size(i); - } - // This unfortunate change is needed because we compile op_quantize for aten - // mode as well + 
Tensor& out) { + size_t num_tokens = 1; + for (size_t i = 0; i < input.dim() - 1; i++) { + num_tokens *= input.size(i); + } + // This unfortunate change is needed because we compile op_quantize for aten + // mode as well #ifdef USE_ATEN_LIB - std::vector sizes(2); - sizes[0] = num_tokens; - sizes[1] = input.size(input.dim() - 1); - Tensor reshaped_input = at::from_blob( + std::vector sizes(2); + sizes[0] = num_tokens; + sizes[1] = input.size(input.dim() - 1); + Tensor reshaped_input = at::from_blob( input.mutable_data_ptr(), sizes, at::TensorOptions(input.scalar_type())); #else - std::array input_dim_order{0, 1}; - std::array input_sizes; - input_sizes[0] = num_tokens; - input_sizes[1] = input.size(input.dim() - 1); - std::array input_strides; - executorch::runtime::dim_order_to_stride_nocheck( + std::array input_dim_order{0, 1}; + std::array input_sizes; + input_sizes[0] = num_tokens; + input_sizes[1] = input.size(input.dim() - 1); + std::array input_strides; + executorch::runtime::dim_order_to_stride_nocheck( input_sizes.data(), input_dim_order.data(), 2, input_strides.data()); - void* input_data = input.mutable_data_ptr(); - torch::executor::TensorImpl reshaped_input_impl = executorch::runtime::etensor::TensorImpl( - input.scalar_type(), - 2, - input_sizes.data(), - input_data, - input_dim_order.data(), - input_strides.data(), - executorch::runtime::TensorShapeDynamism::STATIC); - Tensor reshaped_input(&reshaped_input_impl); - torch::executor::Error err = resize_tensor(out, input.sizes()); - ET_CHECK_MSG( + void* input_data = input.mutable_data_ptr(); + torch::executor::TensorImpl reshaped_input_impl = + executorch::runtime::etensor::TensorImpl( + input.scalar_type(), + 2, + input_sizes.data(), + input_data, + input_dim_order.data(), + input_strides.data(), + executorch::runtime::TensorShapeDynamism::STATIC); + Tensor reshaped_input(&reshaped_input_impl); + torch::executor::Error err = resize_tensor(out, input.sizes()); + ET_CHECK_MSG( err == torch::executor::Error::Ok, "Failed to resize out Tensor in quantize_per_channel_out"); #endif - return quantize_per_channel_out( + return quantize_per_channel_out( reshaped_input, scale, zero_point, 0, quant_min, quant_max, dtype, out); } @@ -784,14 +795,13 @@ Tensor& quantize_per_token_out( int64_t quant_min, int64_t quant_max, ScalarType dtype, - Tensor& out) -{ - (void)context; - return quantize_per_token_out( + Tensor& out) { + (void)context; + return quantize_per_token_out( input, scale, zero_point, quant_min, quant_max, dtype, out); } } // namespace native -} // namespace G3 +} // namespace FusionG3 } // namespace impl } // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/op_softmax.cpp b/backends/cadence/fusion_g3/operators/op_softmax.cpp index 79ec6dc5d7..c3287643cc 100644 --- a/backends/cadence/fusion_g3/operators/op_softmax.cpp +++ b/backends/cadence/fusion_g3/operators/op_softmax.cpp @@ -6,12 +6,12 @@ * LICENSE file in the root directory of this source tree. 
*/ -#include #include #include #include #include #include +#include using exec_aten::Scalar; using exec_aten::ScalarType; @@ -21,95 +21,92 @@ using torch::executor::KernelRuntimeContext; namespace cadence { namespace impl { -namespace G3 { +namespace G3 { namespace native { Tensor& softmax_out( - KernelRuntimeContext& ctx, - const Tensor& in, - int64_t dim, - bool half_to_float, - Tensor& out) -{ - (void)ctx; + KernelRuntimeContext& ctx, + const Tensor& in, + int64_t dim, + bool half_to_float, + Tensor& out) { + (void)ctx; - ET_KERNEL_CHECK( - ctx, - torch::executor::check_softmax_args(in, dim, half_to_float, out), - InvalidArgument, - out); + ET_KERNEL_CHECK( + ctx, + torch::executor::check_softmax_args(in, dim, half_to_float, out), + InvalidArgument, + out); - ET_KERNEL_CHECK( - ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out); - - ET_KERNEL_CHECK( - ctx, executorch::runtime::tensors_have_same_dim_order(in, out), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out); - // Adjust for negative dim - dim = dim < 0 ? dim + executorch::runtime::nonzero_dim(in) : dim; + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(in, out), + InvalidArgument, + out); - int inp_shapes[in.dim()]; - const exec_aten::ArrayRef in_size = in.sizes(); - for(int i = 0; i < in.dim(); i++) - { - inp_shapes[i] = in_size[i]; - } + // Adjust for negative dim + dim = dim < 0 ? dim + executorch::runtime::nonzero_dim(in) : dim; - if(out.scalar_type() == ScalarType::Float) - { - const float * const inp_data = in.const_data_ptr(); - float * const out_data = out.mutable_data_ptr(); - int axis = dim; - xa_nn_softmax_f32_f32(out_data, inp_data, inp_shapes, - in.dim(), &axis); - } - else - { - ET_SWITCH_FLOATH_TYPES(in.scalar_type(), ctx, "_softmax.out", CTYPE, [&]() { - const CTYPE* const in_data = in.const_data_ptr(); - CTYPE* const out_data = out.mutable_data_ptr(); + int inp_shapes[in.dim()]; + const exec_aten::ArrayRef in_size = in.sizes(); + for (int i = 0; i < in.dim(); i++) { + inp_shapes[i] = in_size[i]; + } - torch::executor::apply_over_dim( - [in_data, out_data]( - const size_t size, const size_t stride, const size_t base) { - // calculate max in softmax dim. During softmax computation each - // value is subtracted by the maximum in value before calling exp - // to preserve numerical stability. - const CTYPE max_in = torch::executor::apply_unary_reduce_fn( - [](const CTYPE val_in, CTYPE val_accum) { - return std::max(val_in, val_accum); - }, - in_data + base, - size, - stride); + if (out.scalar_type() == ScalarType::Float) { + const float* const inp_data = in.const_data_ptr(); + float* const out_data = out.mutable_data_ptr(); + int axis = dim; + xa_nn_softmax_f32_f32(out_data, inp_data, inp_shapes, in.dim(), &axis); + } else { + ET_SWITCH_FLOATH_TYPES(in.scalar_type(), ctx, "_softmax.out", CTYPE, [&]() { + const CTYPE* const in_data = in.const_data_ptr(); + CTYPE* const out_data = out.mutable_data_ptr(); - const CTYPE temp_sum = torch::executor:: - apply_unary_map_reduce_fn( - [max_in](const CTYPE val_in) { - return std::exp(val_in - max_in); - }, - [](const CTYPE mapped_in, CTYPE val_accum) { - return val_accum + mapped_in; + torch::executor::apply_over_dim( + [in_data, out_data]( + const size_t size, const size_t stride, const size_t base) { + // calculate max in softmax dim. 
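The portable fallback being reindented here is the standard numerically stable softmax: subtract the per-slice maximum before exponentiating so large logits cannot overflow, then normalize by the sum of the shifted exponentials. A standalone 1-D reference in plain C++ (assumed name):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)); shifting by
    // the maximum leaves the result unchanged but keeps exp() in range.
    std::vector<float> softmax_ref(const std::vector<float>& x) {
      std::vector<float> out(x.size());
      if (x.empty()) {
        return out;
      }
      const float max_in = *std::max_element(x.begin(), x.end());
      float sum = 0.0f;
      for (size_t i = 0; i < x.size(); ++i) {
        out[i] = std::exp(x[i] - max_in);
        sum += out[i];
      }
      for (float& v : out) {
        v /= sum;
      }
      return out;
    }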
During softmax computation each + // value is subtracted by the maximum in value before calling exp + // to preserve numerical stability. + const CTYPE max_in = torch::executor::apply_unary_reduce_fn( + [](const CTYPE val_in, CTYPE val_accum) { + return std::max(val_in, val_accum); }, in_data + base, size, stride); - torch::executor::apply_unary_map_fn( - [max_in, temp_sum](const CTYPE val_in) { - return std::exp(val_in - max_in) / temp_sum; + const CTYPE temp_sum = + torch::executor::apply_unary_map_reduce_fn( + [max_in](const CTYPE val_in) { + return std::exp(val_in - max_in); + }, + [](const CTYPE mapped_in, CTYPE val_accum) { + return val_accum + mapped_in; + }, + in_data + base, + size, + stride); + + torch::executor::apply_unary_map_fn( + [max_in, temp_sum](const CTYPE val_in) { + return std::exp(val_in - max_in) / temp_sum; }, in_data + base, out_data + base, size, stride); - }, - in, - dim); - }); - } + }, + in, + dim); + }); + } - return out; + return out; } } // namespace native From 52fa043d2bdd50d3366436cd6e14b78c1e6697c6 Mon Sep 17 00:00:00 2001 From: Hansong <107070759+kirklandsign@users.noreply.github.com> Date: Mon, 25 Nov 2024 13:42:58 -0800 Subject: [PATCH 13/18] Fix pyre Differential Revision: D66468376 Pull Request resolved: https://github.com/pytorch/executorch/pull/7058 --- .../llama/source_transformation/apply_spin_quant_r1_r2.py | 2 +- exir/emit/_emitter.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/models/llama/source_transformation/apply_spin_quant_r1_r2.py b/examples/models/llama/source_transformation/apply_spin_quant_r1_r2.py index 7ec35c7b6c..89f564935f 100644 --- a/examples/models/llama/source_transformation/apply_spin_quant_r1_r2.py +++ b/examples/models/llama/source_transformation/apply_spin_quant_r1_r2.py @@ -146,9 +146,9 @@ def fuse_ln_linear( torch.zeros(linear.out_features, dtype=torch.float32) ) linear.bias.data = linear.bias.data.to(dtype=torch.float32) + torch.matmul( + W_, # pyre-fixme[6]: For 2nd argument expected `Tensor` but got # `Union[Tensor, Module]`. - W_, layernorm.bias.to(dtype=torch.float32), ) linear.bias.data = linear.bias.data.to(linear_dtype) diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index a1dcc23dce..381bab618c 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -1634,8 +1634,8 @@ def plan(self) -> ExecutionPlan: # missing in scenarios like unit test that does not enable memory planning, assume an # empty list. non_const_buffer_sizes=typing.cast( - # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorB... List[int], + # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorB... self.module.meta["non_const_buffer_sizes"], ), container_meta_type=self.container_meta_type, From a1f668d5f9ac64d046c392f0f2c3493f76c54675 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Mon, 25 Nov 2024 14:05:42 -0800 Subject: [PATCH 14/18] allow customized head_dim (#7065) Pull Request resolved: https://github.com/pytorch/executorch/pull/6872 This is for resolving the ask in this [post](https://fb.workplace.com/groups/pytorch.edge.users/permalink/1574875706716050/). 
Similar change in HF: https://github.com/huggingface/transformers/pull/32502 ghstack-source-id: 255340016 Differential Revision: [D65974454](https://our.internmc.facebook.com/intern/diff/D65974454/) Co-authored-by: Lunwen He --- examples/models/llama/llama_transformer.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py index 20b8b1e30d..3f8b8dd654 100644 --- a/examples/models/llama/llama_transformer.py +++ b/examples/models/llama/llama_transformer.py @@ -85,6 +85,7 @@ class ModelArgs: n_kv_heads: Optional[int] = None vocab_size: int = -1 # defined later by tokenizer hidden_dim: Optional[int] = None + head_dim: Optional[int] = None # Optional customized head_dim multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 ffn_dim_multiplier: Optional[float] = None norm_eps: float = 1e-5 @@ -142,6 +143,9 @@ def __post_init__(self): hidden_dim = int(self.ffn_dim_multiplier * hidden_dim) self.hidden_dim = find_multiple(hidden_dim, multiple_of) + if self.head_dim is None: + self.head_dim = self.dim // self.n_heads + class KVCache(nn.Module): def __init__( @@ -272,7 +276,7 @@ def __init__(self, args: ModelArgs, layer_id: int): self.n_local_heads = self.n_heads // model_parallel_size self.n_local_kv_heads = self.n_kv_heads // model_parallel_size self.n_rep = self.n_local_heads // self.n_local_kv_heads - self.head_dim = args.dim // self.n_heads + self.head_dim = args.head_dim self.max_batch_size = args.max_batch_size self.max_seq_len = args.max_seq_len self.dim = args.dim @@ -304,7 +308,7 @@ def __init__(self, args: ModelArgs, layer_id: int): ) self.SDPA = SDPA( kv_cache=self.kv_cache, - dim=self.dim, + dim=self.n_local_heads * self.head_dim, head_dim=self.head_dim, n_rep=self.n_rep, max_seq_len=self.max_seq_len, @@ -425,7 +429,7 @@ def __init__(self, layer_id: int, args: ModelArgs): self.use_kv_cache = args.use_kv_cache self.n_heads = args.n_heads self.dim = args.dim - self.head_dim = args.dim // args.n_heads + self.head_dim = args.head_dim self.attention = Attention(args, layer_id) if args.moe: self.block_sparse_moe = MOEFeedForward(args) @@ -472,7 +476,7 @@ def __init__(self, params: ModelArgs): precompute_freqs_cis, use_scaled=params.use_scaled_rope ) freqs_cos, freqs_sin = self.precompute_freqs_cis( - params.dim // params.n_heads, + params.head_dim, ( params.max_seq_len # Normal llama2. if params.ffn_dim_multiplier is None From 20c8e8c14a6e9fd7a5d9fb10ee9ec46443e96bf8 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Mon, 25 Nov 2024 15:11:15 -0800 Subject: [PATCH 15/18] Fix test with resources (#7071) Fix test failure due to resources not handled correctly by ios tests. 
Differential Revision: [D66392647](https://our.internmc.facebook.com/intern/diff/D66392647/) ghstack-source-id: 255370795 Pull Request resolved: https://github.com/pytorch/executorch/pull/7062 Co-authored-by: Mengwei Liu --- .../llama/tokenizer/test/test_tiktoken.cpp | 15 ++++++- .../llm/tokenizer/test/test_bpe_tokenizer.cpp | 8 ++++ .../llm/tokenizer/test/test_tiktoken.cpp | 40 ++++++++++--------- 3 files changed, 42 insertions(+), 21 deletions(-) diff --git a/examples/models/llama/tokenizer/test/test_tiktoken.cpp b/examples/models/llama/tokenizer/test/test_tiktoken.cpp index b9309f9921..442da62174 100644 --- a/examples/models/llama/tokenizer/test/test_tiktoken.cpp +++ b/examples/models/llama/tokenizer/test/test_tiktoken.cpp @@ -14,6 +14,10 @@ #include +#ifdef EXECUTORCH_FB_BUCK +#include +#endif + using namespace ::testing; using ::example::Version; @@ -21,13 +25,20 @@ using ::executorch::extension::llm::Tokenizer; using ::executorch::runtime::Error; using ::executorch::runtime::Result; +static std::string get_resource_path(const std::string& name) { +#ifdef EXECUTORCH_FB_BUCK + return facebook::xplat::testing::getPathForTestResource("resources/" + name); +#else + return std::getenv("RESOURCES_PATH") + std::string("/") + name; +#endif +} + class MultimodalTiktokenV5ExtensionTest : public Test { public: void SetUp() override { executorch::runtime::runtime_init(); tokenizer_ = get_tiktoken_for_llama(Version::Multimodal); - modelPath_ = std::getenv("RESOURCES_PATH") + - std::string("/test_tiktoken_tokenizer.model"); + modelPath_ = get_resource_path("test_tiktoken_tokenizer.model"); } std::unique_ptr tokenizer_; diff --git a/extension/llm/tokenizer/test/test_bpe_tokenizer.cpp b/extension/llm/tokenizer/test/test_bpe_tokenizer.cpp index c553fe59f9..d207578de1 100644 --- a/extension/llm/tokenizer/test/test_bpe_tokenizer.cpp +++ b/extension/llm/tokenizer/test/test_bpe_tokenizer.cpp @@ -6,6 +6,9 @@ * LICENSE file in the root directory of this source tree. */ +#ifdef EXECUTORCH_FB_BUCK +#include +#endif #include #include #include @@ -23,8 +26,13 @@ class TokenizerExtensionTest : public Test { void SetUp() override { executorch::runtime::runtime_init(); tokenizer_ = std::make_unique(); +#ifdef EXECUTORCH_FB_BUCK + modelPath_ = facebook::xplat::testing::getPathForTestResource( + "resources/test_bpe_tokenizer.bin"); +#else modelPath_ = std::getenv("RESOURCES_PATH") + std::string("/test_bpe_tokenizer.bin"); +#endif } std::unique_ptr tokenizer_; diff --git a/extension/llm/tokenizer/test/test_tiktoken.cpp b/extension/llm/tokenizer/test/test_tiktoken.cpp index ce2a781aa1..3132170683 100644 --- a/extension/llm/tokenizer/test/test_tiktoken.cpp +++ b/extension/llm/tokenizer/test/test_tiktoken.cpp @@ -6,11 +6,13 @@ * LICENSE file in the root directory of this source tree. 
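Both test suites above pick their model files either from a Buck resource bundle or from the RESOURCES_PATH environment variable. One defensive variation, an assumption on my part rather than what this patch ships, is to fail with a clear message when the variable is missing instead of concatenating a null pointer:

    #include <cstdlib>
    #include <stdexcept>
    #include <string>

    // std::getenv returns nullptr when RESOURCES_PATH is unset; surface that
    // as a readable error instead of undefined behavior.
    std::string resource_path_or_throw(const std::string& name) {
      const char* base = std::getenv("RESOURCES_PATH");
      if (base == nullptr) {
        throw std::runtime_error("RESOURCES_PATH is not set");
      }
      return std::string(base) + "/" + name;
    }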
*/ +#ifdef EXECUTORCH_FB_BUCK +#include +#endif #include #include #include #include -#include #include using namespace ::testing; @@ -47,6 +49,15 @@ static inline std::unique_ptr> _get_special_tokens() { } return special_tokens; } + +static inline std::string _get_resource_path(const std::string& name) { +#ifdef EXECUTORCH_FB_BUCK + return facebook::xplat::testing::getPathForTestResource("resources/" + name); +#else + return std::getenv("RESOURCES_PATH") + std::string("/") + name; +#endif +} + } // namespace class TiktokenExtensionTest : public Test { @@ -55,8 +66,7 @@ class TiktokenExtensionTest : public Test { executorch::runtime::runtime_init(); tokenizer_ = std::make_unique( _get_special_tokens(), kBOSTokenIndex, kEOSTokenIndex); - modelPath_ = std::getenv("RESOURCES_PATH") + - std::string("/test_tiktoken_tokenizer.model"); + modelPath_ = _get_resource_path("test_tiktoken_tokenizer.model"); } std::unique_ptr tokenizer_; @@ -144,44 +154,36 @@ TEST_F(TiktokenExtensionTest, ConstructionWithInvalidEOSIndex) { } TEST_F(TiktokenExtensionTest, LoadWithInvalidPath) { - auto invalidModelPath = - std::getenv("RESOURCES_PATH") + std::string("/nonexistent.model"); - - Error res = tokenizer_->load(invalidModelPath.c_str()); + auto invalidModelPath = "./nonexistent.model"; + Error res = tokenizer_->load(invalidModelPath); EXPECT_EQ(res, Error::InvalidArgument); } TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithInvalidRank) { - auto invalidModelPath = std::getenv("RESOURCES_PATH") + - std::string("/test_tiktoken_invalid_rank.model"); - + auto invalidModelPath = + _get_resource_path("test_tiktoken_invalid_rank.model"); Error res = tokenizer_->load(invalidModelPath.c_str()); EXPECT_EQ(res, Error::InvalidArgument); } TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithInvalidBase64) { - auto invalidModelPath = std::getenv("RESOURCES_PATH") + - std::string("/test_tiktoken_invalid_base64.model"); - + auto invalidModelPath = + _get_resource_path("test_tiktoken_invalid_base64.model"); Error res = tokenizer_->load(invalidModelPath.c_str()); EXPECT_EQ(res, Error::InvalidArgument); } TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithNoSpace) { - auto invalidModelPath = std::getenv("RESOURCES_PATH") + - std::string("/test_tiktoken_no_space.model"); - + auto invalidModelPath = _get_resource_path("test_tiktoken_no_space.model"); Error res = tokenizer_->load(invalidModelPath.c_str()); EXPECT_EQ(res, Error::InvalidArgument); } TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithBPEFile) { - auto invalidModelPath = - std::getenv("RESOURCES_PATH") + std::string("/test_bpe_tokenizer.bin"); - + auto invalidModelPath = _get_resource_path("test_bpe_tokenizer.bin"); Error res = tokenizer_->load(invalidModelPath.c_str()); EXPECT_EQ(res, Error::InvalidArgument); From ec68eb3270c0c4bb38c1743b009f99a8da6221a2 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Mon, 25 Nov 2024 15:11:58 -0800 Subject: [PATCH 16/18] Select python 3.1[0-2] on ExecuTorch nightly (#7064) * Select python 3.1[0-2] on ExecuTorch nightly * Another tweak * Should work now --- .github/workflows/build-wheels-linux.yml | 1 + .github/workflows/build-wheels-m1.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/build-wheels-linux.yml b/.github/workflows/build-wheels-linux.yml index a4132f6554..75f2c13fa8 100644 --- a/.github/workflows/build-wheels-linux.yml +++ b/.github/workflows/build-wheels-linux.yml @@ -27,6 +27,7 @@ jobs: test-infra-ref: main with-cuda: disabled with-rocm: disabled + python-versions: '["3.10", "3.11", "3.12"]' build: needs: 
generate-matrix diff --git a/.github/workflows/build-wheels-m1.yml b/.github/workflows/build-wheels-m1.yml index 1dad6ad5ea..a160f5ab9b 100644 --- a/.github/workflows/build-wheels-m1.yml +++ b/.github/workflows/build-wheels-m1.yml @@ -27,6 +27,7 @@ jobs: test-infra-ref: main with-cuda: disabled with-rocm: disabled + python-versions: '["3.10", "3.11", "3.12"]' build: needs: generate-matrix From a35cb73c38079d738f5bea57bb4fbd9bbf4fa5d1 Mon Sep 17 00:00:00 2001 From: derekxu Date: Mon, 25 Nov 2024 17:06:33 -0800 Subject: [PATCH 17/18] Add logging dependency to OSS QNN logging Differential Revision: D66468388 Pull Request resolved: https://github.com/pytorch/executorch/pull/7059 --- backends/qualcomm/runtime/targets.bzl | 1 + 1 file changed, 1 insertion(+) diff --git a/backends/qualcomm/runtime/targets.bzl b/backends/qualcomm/runtime/targets.bzl index 73d333f52d..ac65b442aa 100644 --- a/backends/qualcomm/runtime/targets.bzl +++ b/backends/qualcomm/runtime/targets.bzl @@ -28,6 +28,7 @@ def define_common_targets(): "//executorch/runtime/backend:interface", ], exported_deps = [ + "fbsource//third-party/toolchains:log", "//executorch/backends/qualcomm:schema", "//executorch/backends/qualcomm:qc_binary_info_schema", "//executorch/runtime/core:core", From 2967302c8834455bae7980c27f2634322f3d25b2 Mon Sep 17 00:00:00 2001 From: Justin Yip Date: Mon, 25 Nov 2024 21:42:21 -0800 Subject: [PATCH 18/18] Change weight to channel-packing in Conv1d Differential Revision: D66417572 Pull Request resolved: https://github.com/pytorch/executorch/pull/7057 --- .../vulkan/runtime/graph/ops/glsl/conv1d.glsl | 32 ++++++++++--------- .../runtime/graph/ops/impl/Convolution.cpp | 2 +- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl index e4880d8a22..1597b05e8d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl @@ -101,23 +101,25 @@ void main() { // "k" tracks the kernel's index for our input-kernel computation. // It reads out-of-bound zeros, but trying to avoid them complicates // for-loop conditions, which results in worse performance. - for (int k = 0; k < kernel_size; k += 4) { - // Since the weight tensor is width-packed, which is along the length - // dimension, we can batch-read four elements at a time. - const ivec3 w_lpos = ivec3(k / 4, in_c % in_group_size, out_c); - const VEC4_T weight = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map); - ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, n / 4), in_axis_map); - sum = fma(weight.xxxx, load_texel(t_in, in_pos), sum); - - in_pos[in_axis_map.x] += dilation; - sum = fma(weight.yyyy, load_texel(t_in, in_pos), sum); + // The weight tensor is channel-packed. It may not be trival choice for + // performance reason since need to have more data fetch. The reason is + // for some sequence model, we found that the weight tensor + // (out_channel, in_channel / group, kernel) often has a large + // out_channel >> kernel, leading to non-optimal use of memory as the + // weight tensor gets very deep. As a mitigation, we use channel-packing + // for the weight tensor, yielding a 75% reduction in weight-tensor + // memory. + + // It is possible to further reduce the memory footprint by swapping the + // dimensions, using x extent for out_channel, and y for kernel. 
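With channel-packed weights, four consecutive output channels share a single texel, so the loop below reads texel (k, in_c % in_group_size, out_c / 4) and selects component out_c % 4. The same index arithmetic in scalar C++ terms (illustrative only, names assumed):

    #include <cstdio>

    // Four output channels per texel: z picks the packed texel, lane picks
    // the component inside it.
    struct PackedWeightCoord {
      int x; // kernel tap k
      int y; // input channel within the group
      int z; // packed output-channel texel index
      int lane; // component (0..3) inside that texel
    };

    PackedWeightCoord packed_weight_coord(
        int k, int in_c, int in_group_size, int out_c) {
      return {k, in_c % in_group_size, out_c / 4, out_c % 4};
    }

    int main() {
      // out_c = 10 lives in texel z = 2, component 2.
      PackedWeightCoord c = packed_weight_coord(0, 5, 8, 10);
      std::printf("z=%d lane=%d\n", c.z, c.lane);
      return 0;
    }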
+ for (int k = 0; k < kernel_size; k += 1) { + const ivec3 w_lpos = ivec3(k, in_c % in_group_size, out_c / 4); + const VEC4_T weight_texel = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map); + VEC4_T weight = VEC4_T(weight_texel[out_c % 4]); - in_pos[in_axis_map.x] += dilation; - sum = fma(weight.zzzz, load_texel(t_in, in_pos), sum); - - in_pos[in_axis_map.x] += dilation; - sum = fma(weight.wwww, load_texel(t_in, in_pos), sum); + ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, n / 4), in_axis_map); + sum = fma(weight, load_texel(t_in, in_pos), sum); } } diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index 880d48e25e..1cdd7315f1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -407,7 +407,7 @@ void add_conv1d_node( const ValueRef out, const bool clamp_out) { ValueRef arg_weight = prepack_standard( - graph, weight, graph.storage_type_of(out), utils::kWidthPacked); + graph, weight, graph.storage_type_of(out), utils::kChannelsPacked); ValueRef arg_bias = prepack_biases( graph, bias,