From 85467631d133d1701d440b0db1d919208cc6854c Mon Sep 17 00:00:00 2001 From: shubhagr-quic Date: Tue, 7 Jan 2025 15:41:58 +0530 Subject: [PATCH] QNN Compilation support in High Level APIs of QEFFAutoModelForCausalLM (#187) * QNN Compilation support in QEFFAutoModelForCausalLM High Level APIs 1. Modified qnn_compiler.py to include qnn_binary_dir path to support hash suffix in qpc directory name. 2. Added tests/qnn_tests/test_causal_lm_models_qnn.py for unit testing. 3. Modified qnn_config.json to enable compiler_enable_depth_first if qnn_config file is passed. 4. Added _qnn_compile function in QEFFBaseModel to support QNN Compilation. Signed-off-by: Shubham Agrawal * Increased Non-CLI Non-QAIC Tests timeout Signed-off-by: Rishin Raj * Added sudo for executing QNN Docker commands Signed-off-by: Rishin Raj --------- Signed-off-by: Shubham Agrawal Signed-off-by: Rishin Raj Co-authored-by: Rishin Raj --- QEfficient/base/modeling_qeff.py | 98 ++++++++++ QEfficient/compile/compile_helper.py | 2 +- QEfficient/compile/qnn_compiler.py | 48 ++--- QEfficient/compile/qnn_config.json | 1 + .../transformers/models/modeling_auto.py | 68 ++++--- docs/source/hl_api.md | 9 +- scripts/Jenkinsfile | 26 ++- tests/qnn_tests/test_causal_lm_models_qnn.py | 172 ++++++++++++++++++ 8 files changed, 375 insertions(+), 49 deletions(-) create mode 100644 tests/qnn_tests/test_causal_lm_models_qnn.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 82fc42215..2760cf52f 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -21,8 +21,10 @@ from QEfficient.base.onnx_transforms import OnnxTransform from QEfficient.base.pytorch_transforms import PytorchTransform +from QEfficient.compile.qnn_compiler import compile as qnn_compile from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.utils import constants +from QEfficient.utils._utils import load_json from QEfficient.utils.cache import QEFF_HOME, to_hashable logger = logging.getLogger(__name__) @@ -319,3 +321,99 @@ def _compile( self.qpc_path = qpc_path return qpc_path + + def _qnn_compile( + self, + onnx_path: Optional[str] = None, + compile_dir: Optional[str] = None, + *, + specializations: Optional[List[Dict[str, int]]] = None, + prefill_seq_len: int = 32, + ctx_len: int = 128, + batch_size: int = 1, + full_batch_size: Optional[int] = None, + mdp_ts_num_devices: int = 1, + num_cores: int = 16, + mxfp6_matmul: bool = False, + mxint8_kv_cache: bool = False, + qnn_config: Optional[str] = None, + ) -> str: + """ + Interface for QNN compiler + + Args: + :onnx_path (str): Onnx file to compile + :compile_dir (str): Directory path to compile the qpc. A suffix is added to the directory path to avoid reusing same qpc for different parameters. + :specializations (list): List of specializations to compile for + :prefill_seq_len (int, optional): The length of the Prefill prompt should be less that ``prefill_seq_len``. ``Defaults to 32``. + :ctx_len (int, optional): Maximum ``ctx`` that the compiled model can remember. ``Defaults to 128``. + :batch_size (int, optional): Batch size. ``Defaults to 1``. + :full_batch_size (int, optional): Continuous batching batch size. + :mdp_ts_num_devices (int): Number of devices to partition to use Multi-Device Partitioning with tensor-slicing. + :num_cores (int): Number of cores used to compile the model. + :mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to True``. + :mxint8_kv_cache (bool, optional): Whether to use ``mxint8`` compression for KV cache. ``Defaults to False``. + :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.`` + """ + if onnx_path is None and self.onnx_path is None: + self.export() + + onnx_path = Path(onnx_path or self.onnx_path) + compile_dir = Path(compile_dir or onnx_path.parent) + qpc_path = compile_dir / "qpc" + if not onnx_path.is_file(): + raise FileNotFoundError(f"ONNX file not found at: {onnx_path}") + + compile_hash = hashlib.sha256(to_hashable("qnn")) + + if specializations is not None: + compile_hash.update(to_hashable(specializations)) + + if qnn_config is not None: + qnn_config_values = load_json(qnn_config) + compile_hash.update(to_hashable(qnn_config_values)) + + if mdp_ts_num_devices > 1: + compile_hash.update(to_hashable({"mdp_ts_num_devices": mdp_ts_num_devices})) + + compile_hash.update(to_hashable({"num_cores": num_cores})) + compile_hash.update(to_hashable({"mxfp6_matmul": mxfp6_matmul})) + compile_hash.update(to_hashable({"mxint8_kv_cache": mxint8_kv_cache})) + + # Check if already compiled + compile_hash = compile_hash.hexdigest()[:16] + qpc_path = qpc_path.with_name(qpc_path.name + "-" + compile_hash) + if qpc_path.is_dir(): + if (qpc_path / "programqpc.bin").is_file(): + self.qpc_path = qpc_path + return qpc_path + # Probably compilation failure last time, delete directory to start over + shutil.rmtree(qpc_path) + + # Write specializations.json file + if specializations is not None: + specializations_json = compile_dir / "specializations.json" + with open(specializations_json, "w") as fp: + json.dump( + {"specializations": [{k: str(v) for k, v in spec.items()} for spec in specializations]}, + fp, + indent=4, + ) + + qnn_compile( + onnx_path=onnx_path, + qpc_base_path=compile_dir, + num_cores=num_cores, + device_group=list(range(mdp_ts_num_devices)), + batch_size=batch_size, + prompt_len=prefill_seq_len, + ctx_len=ctx_len, + mxfp6=mxfp6_matmul, + mxint8=mxint8_kv_cache, + full_batch_size=full_batch_size, + qnn_config=qnn_config, + qnn_binary_dir=qpc_path, + ) + + self.qpc_path = qpc_path + return qpc_path diff --git a/QEfficient/compile/compile_helper.py b/QEfficient/compile/compile_helper.py index ba7c90a97..ae86b493a 100644 --- a/QEfficient/compile/compile_helper.py +++ b/QEfficient/compile/compile_helper.py @@ -183,7 +183,7 @@ def compile( if enable_qnn: qpc_path = qnn_compile( onnx_path=onnx_path, - qpc_path=qpc_path, + qpc_base_path=qpc_path, num_cores=num_cores, batch_size=batch_size, prompt_len=prompt_len, diff --git a/QEfficient/compile/qnn_compiler.py b/QEfficient/compile/qnn_compiler.py index 307deca19..ad5da9767 100644 --- a/QEfficient/compile/qnn_compiler.py +++ b/QEfficient/compile/qnn_compiler.py @@ -25,7 +25,7 @@ class QNN: def __init__( self, onnx_path: str, - qpc_path: str, + qpc_base_path: str, num_cores: int, custom_io_path: str, device_group: Optional[List[int]] = None, @@ -37,10 +37,11 @@ def __init__( compiler_mxfp6_matmul_weights: bool = True, qnn_target: str = QnnConstants.TARGET, qnn_config_path: Optional[str] = None, + qnn_binary_dir: Optional[str] = None, **kwargs, ) -> None: self.onnx_path = onnx_path - self.qpc_path = qpc_path + self.qpc_base_path = qpc_base_path self.num_cores = num_cores self.device_group = device_group self.compiler_enable_depth_first = compiler_enable_depth_first @@ -50,8 +51,9 @@ def __init__( self.ctx_len = ctx_len self.compiler_mxfp6_matmul_weights = compiler_mxfp6_matmul_weights self.qnn_config_path = qnn_config_path + self.qnn_binary_dir = qnn_binary_dir self.custom_io_path = custom_io_path - self.dlc_model_path = os.path.join(qpc_path, f"{QnnConstants.MODEL_NAME}.dlc") + self.dlc_model_path = os.path.join(qpc_base_path, f"{QnnConstants.MODEL_NAME}.dlc") self.qnn_target = qnn_target self.qnn_sdk_path = os.getenv(QnnConstants.QNN_SDK_PATH_ENV_VAR_NAME) if not self.qnn_sdk_path: @@ -118,7 +120,7 @@ def create_qnn_tensor_slicing_json(self) -> str: } ], } - tensor_slicing_json_path = os.path.join(self.qpc_path, "tensor_slicing.json") + tensor_slicing_json_path = os.path.join(self.qpc_base_path, "tensor_slicing.json") create_json(tensor_slicing_json_path, tensor_slicing) return tensor_slicing_json_path @@ -157,7 +159,7 @@ def create_qnn_compile_backend_json(self) -> str: for key, value in self.qnn_config[QnnConstants.QNN_COMPILATION_BACKEND_STR].items(): qnn_compile_backend[key] = value - qnn_compile_backend_json_path = os.path.join(self.qpc_path, "qnn_compile_backend.json") + qnn_compile_backend_json_path = os.path.join(self.qpc_base_path, "qnn_compile_backend.json") create_json(qnn_compile_backend_json_path, qnn_compile_backend) return qnn_compile_backend_json_path @@ -177,13 +179,13 @@ def create_qnn_compiler_config_json(self) -> str: ), } } - qnn_compiler_config_json_path = os.path.join(self.qpc_path, "qnn_compiler_config.json") + qnn_compiler_config_json_path = os.path.join(self.qpc_base_path, "qnn_compiler_config.json") create_json(qnn_compiler_config_json_path, qnn_compiler_config) return qnn_compiler_config_json_path def compile(self) -> str: """ - Compiles the given ``ONNX`` model during object creation using QNN compiler and saves the compiled ``qpc`` package at ``qpc_path``. + Compiles the given ``ONNX`` model during object creation using QNN compiler and saves the compiled ``qpc`` package at ``qnn_binary_dir``. - Creates convertor command and convert onnx model to model.dlc using qairt-convertor - command line arguments and qnn_config.json (if provided) are used to create qnn_compiler_config.json for context-binary-generator - model.dlc from convertor stage is passed into context-binary-generator command to create programqpc.bin. @@ -197,20 +199,21 @@ def compile(self) -> str: and self.qnn_config[QnnConstants.SKIP_QNN_CONVERTOR_STEP_STR] ): converter_cmd = self.converter() - execute_command("convertor", converter_cmd, self.qpc_path) + execute_command("convertor", converter_cmd, self.qpc_base_path) if not os.path.isfile(self.dlc_model_path): raise FileNotFoundError( - f"file {self.dlc_model_path} needs to exist in the qpc_path{self.qpc_path}. Please rerun infer/compile Api" + f"file {self.dlc_model_path} needs to exist in the qpc_base_path{self.qpc_base_path}. Please rerun infer/compile Api" ) - self.qnn_binary_dir = os.path.join(self.qpc_path, "qpcs") + if self.qnn_binary_dir is None: + self.qnn_binary_dir = os.path.join(self.qpc_base_path, "qpcs") if os.path.isdir(self.qnn_binary_dir): shutil.rmtree(self.qnn_binary_dir) os.makedirs(self.qnn_binary_dir) ctx_bin_cmd = self.generate_context_binary() - execute_command("context_binary", ctx_bin_cmd, self.qpc_path) + execute_command("context_binary", ctx_bin_cmd, self.qpc_base_path) print("\n===================== Compilation Done! =====================\n") return self.qnn_binary_dir @@ -221,7 +224,7 @@ def converter(self) -> str: IMMUTABLE parameters which can not be overridden by the user using qnn_config.json: :input_network (str): Generated ``ONNX`` Model Path. - :output_path (str): Path to generated DLC file, which is provided qpc_path/model.dlc + :output_path (str): Path to generated DLC file, which is provided qpc_base_path/model.dlc :io_config (str): Path to custom_io_config.yaml file created using GenerateQNNnetworkSpecializationconfig.py :float_bias_bitwidth (int): Bitwidth to use for float bias tensor :float_bitwidth (int): Converts the graph to the specified float bitwidth, either 32 or 16(Default). @@ -255,8 +258,8 @@ def generate_context_binary(self) -> str: IMMUTABLE parameters which can not be modified by the user using qnn_config.json: :binary_file (str): QNN Binary Graph name to be generated (qnngraph.serialized). - :backend_binary (str): Path to generated QPC binary file, which is provided qpc_path/qpcs/programqpc.bin - :output_dir (str): Path to store generated Binaries (qpc_path/qpcs/). + :backend_binary (str): Generated QPC binary file name, which is provided programqpc.bin + :output_dir (str): Path to store generated Binaries (qnn_binary_dir). :model (str): Path to the file containing a QNN network. :dlc_path (str): Path to DLC file generated by QNN-Convertor. :config_file(str): Path to created qnn_compiler_config.json containing qnn_compile_backend.json & shared_library_path. @@ -305,7 +308,7 @@ def generate_profiling(self): def compile( onnx_path: str, - qpc_path: str, + qpc_base_path: str, num_cores: int, device_group: Optional[List[int]] = None, aic_enable_depth_first: bool = False, @@ -318,16 +321,17 @@ def compile( allow_mxint8_mdp_io: Optional[bool] = False, full_batch_size=None, qnn_config: Optional[str] = None, + qnn_binary_dir: Optional[str] = None, **kwargs, ) -> str: """ - Compiles the given ``ONNX`` model using QNN compiler and saves the compiled ``qpc`` package at ``qpc_path``. + Compiles the given ``ONNX`` model using QNN compiler and saves the compiled ``qpc`` package at ``qnn_binary_dir``. Generates model.dlc during convertor stage, qnn_compile_backend.json for backend parameters of context-binary-generator. Generates tensor-slicing configuration if multiple devices are passed in ``device_group``. ``Mandatory`` Args: :onnx_path (str): Generated ``ONNX`` Model Path. - :qpc_path (str): Path for saving compiled qpc binaries. + :qpc_base_path (str): base directory for QNN compilation config & binary file. :num_cores (int): Number of cores to compile the model on. ``Optional`` Args: :device_group (List[int]): Used for finding the number of devices to compile for. @@ -341,6 +345,7 @@ def compile( :allow_mxint8_mdp_io (bool): Allows MXINT8 compression of MDP IO traffic ``Defaults to False.`` :mxint8 (bool): Compress Present/Past KV to ``MXINT8`` using ``CustomIO`` config. ``Defaults to False.`` :qnn_config (str): Path to ``qnn_config.json`` file (formatted as a string). ``Defaults to None.`` + :qnn_binary_dir (str): Path for saving qnn binaries. Returns: :str: Path to compiled ``qpc`` package. @@ -357,11 +362,11 @@ def compile( if mxint8: logger.warning("QNN doesn't support mxint8. Bypassing the value passed for mxint8") - os.makedirs(qpc_path, exist_ok=True) + os.makedirs(qpc_base_path, exist_ok=True) # Created custom_io_config.yaml file for QNN-Convertor stage. # TODO To make custom_io_config.yaml configurable as not all models need it. - custom_io_file_path = os.path.join(qpc_path, "custom_io_config.yaml") + custom_io_file_path = os.path.join(qpc_base_path, "custom_io_config.yaml") fetch_nodes_info( onnx_graph_path=onnx_path, batch_size=batch_size, @@ -373,12 +378,12 @@ def compile( if not os.path.isfile(custom_io_file_path): raise FileNotFoundError( - f"file {custom_io_file_path} needs to exist in the qpc_path for Compilation. Please rerun infer/compile Api" + f"file {custom_io_file_path} needs to exist in the qpc_base_path for Compilation. Please rerun infer/compile Api" ) qnn_obj = QNN( onnx_path=onnx_path, - qpc_path=qpc_path, + qpc_base_path=qpc_base_path, num_cores=num_cores, device_group=device_group, qnn_config_path=qnn_config, @@ -389,6 +394,7 @@ def compile( prompt_len=prompt_len, ctx_len=ctx_len, compiler_mxfp6_matmul_weights=mxfp6, + qnn_binary_dir=qnn_binary_dir, ) compiled_binary_path = qnn_obj.compile() diff --git a/QEfficient/compile/qnn_config.json b/QEfficient/compile/qnn_config.json index 18f12dd9a..369b55981 100644 --- a/QEfficient/compile/qnn_config.json +++ b/QEfficient/compile/qnn_config.json @@ -3,6 +3,7 @@ "context_binary_generator_args_extension": "--log_level debug", "qnn_compilation_backend": { + "compiler_enable_depth_first": true, "compiler_printDDRStats": false, "compiler_printPerfMetrics": false, "compiler_stat_level": 10 diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 83c573f6d..f565cbca9 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -245,6 +245,8 @@ def compile( mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, num_speculative_tokens: Optional[int] = None, + enable_qnn: bool = False, + qnn_config: Optional[str] = None, **compiler_options, ) -> str: """ @@ -266,6 +268,8 @@ def compile( :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model. :mos (int, optional): Effort level to reduce on-chip memory. Defaults to -1, meaning no effort. ``Defaults to -1``. :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``. + :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.`` + :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.`` Returns: :str: Path of the compiled ``qpc`` package. @@ -311,28 +315,48 @@ def compile( decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else None specializations.append(decode_specialization) - # Custom IO - custom_io = {} - kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16" - for suffix in ["", "_RetainedState"]: - for i in range(self.num_layers): - for kv in ["key", "value"]: - custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype - - return self._compile( - onnx_path, - compile_dir, - compile_only=True, - retained_state=True, - specializations=specializations, - convert_to_fp16=True, - mxfp6_matmul=mxfp6_matmul, - custom_io=custom_io, - mdp_ts_num_devices=num_devices, - num_speculative_tokens=num_speculative_tokens, - aic_num_cores=num_cores, - **compiler_options, - ) + if enable_qnn: + if compiler_options: + logger.warning("Extra arguments to QNN compilation are supported via qnn_config.json only") + + qpc_path = self._qnn_compile( + onnx_path, + compile_dir, + specializations=specializations, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + batch_size=batch_size, + full_batch_size=full_batch_size, + mdp_ts_num_devices=num_devices, + num_cores=num_cores, + mxfp6_matmul=mxfp6_matmul, + mxint8_kv_cache=mxint8_kv_cache, + qnn_config=qnn_config, + ) + else: + # Custom IO + custom_io = {} + kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16" + for suffix in ["", "_RetainedState"]: + for i in range(self.num_layers): + for kv in ["key", "value"]: + custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype + + qpc_path = self._compile( + onnx_path, + compile_dir, + compile_only=True, + retained_state=True, + specializations=specializations, + convert_to_fp16=True, + mxfp6_matmul=mxfp6_matmul, + custom_io=custom_io, + mdp_ts_num_devices=num_devices, + num_speculative_tokens=num_speculative_tokens, + aic_num_cores=num_cores, + **compiler_options, + ) + return qpc_path # FIXME: Update this method to match with transformers AutoModelForCausalLM.generate def generate( diff --git a/docs/source/hl_api.md b/docs/source/hl_api.md index 558965e76..5662b23a7 100644 --- a/docs/source/hl_api.md +++ b/docs/source/hl_api.md @@ -47,6 +47,13 @@ import QEfficient base_path, onnx_model_path = QEfficient.export(model_name="gpt2") qpc_path = QEfficient.compile(onnx_path=onnx_model_path, qpc_path=os.path.join(base_path, "qpc"), num_cores=14, device_group=[0]) + + # Similarly for QPC Compiled via QNN SDK + # 1. export $QNN_SDK_ROOT=/path/to/qnn_sdk_folder + # 2. add --enable_qnn in the command + # 3. An optional config file can be provided via qnn_config if user wish to override the default parameters. + qpc_path_qnn = QEfficient.compile(onnx_path=onnx_model_path, qpc_path=os.path.join(base_path, "qpc"), num_cores=14, device_group=[0], + enable_qnn=True, qnn_config = "QEfficient/compile/qnn_config.json") .. deprecated:: This function will be deprecated in version 1.19, please use QEFFAutoModelForCausalLM.compile instead ``` @@ -54,6 +61,6 @@ ```{eval-rst} .. automodule:: QEfficient.generation.text_generation_inference :members: - :show-inheritance: + :show-inheritance: :exclude-members: latency_stats_bertstyle,cloud_ai_100_exec_kv_helper ``` diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile index f05540c8a..c9f17a73e 100644 --- a/scripts/Jenkinsfile +++ b/scripts/Jenkinsfile @@ -32,7 +32,7 @@ pipeline { parallel { stage('Run Non-CLI Non-QAIC Tests') { steps { - timeout(time: 10, unit: 'MINUTES') { + timeout(time: 25, unit: 'MINUTES') { sh ''' sudo docker exec ${BUILD_TAG} bash -c " cd /efficient-transformers && @@ -56,7 +56,7 @@ pipeline { mkdir -p $PWD/Non_qaic && export TOKENIZERS_PARALLELISM=false && export QEFF_HOME=$PWD/Non_qaic && - pytest tests -m '(not cli) and (on_qaic)' -n 4 --junitxml=tests/tests_log2.xml && + pytest tests -m '(not cli) and (on_qaic) and (not qnn)' -n 4 --junitxml=tests/tests_log2.xml && deactivate" ''' } @@ -84,7 +84,7 @@ pipeline { steps { timeout(time: 30, unit: 'MINUTES') { sh ''' - docker exec ${BUILD_TAG} bash -c " + sudo docker exec ${BUILD_TAG} bash -c " source /qnn_sdk/bin/envsetup.sh && source /qnn_sdk/bin/envcheck -c && cd /efficient-transformers && @@ -93,7 +93,25 @@ pipeline { export TOKENIZERS_PARALLELISM=false && export QEFF_HOME=$PWD/Qnn_cli && pytest tests -m '(cli and qnn)' --junitxml=tests/tests_log4.xml && - junitparser merge tests/tests_log1.xml tests/tests_log2.xml tests/tests_log3.xml tests/tests_log4.xml tests/tests_log.xml && + deactivate" + ''' + } + } + } + stage('QNN Non-CLI Tests') { + steps { + timeout(time: 60, unit: 'MINUTES') { + sh ''' + sudo docker exec ${BUILD_TAG} bash -c " + source /qnn_sdk/bin/envsetup.sh && + source /qnn_sdk/bin/envcheck -c && + cd /efficient-transformers && + . preflight_qeff/bin/activate && + mkdir -p $PWD/Qnn_non_cli && + export TOKENIZERS_PARALLELISM=false && + export QEFF_HOME=$PWD/Qnn_non_cli && + pytest tests -m '(not cli) and (qnn) and (on_qaic)' --junitxml=tests/tests_log5.xml && + junitparser merge tests/tests_log1.xml tests/tests_log2.xml tests/tests_log3.xml tests/tests_log4.xml tests/tests_log5.xml tests/tests_log.xml && deactivate" ''' } diff --git a/tests/qnn_tests/test_causal_lm_models_qnn.py b/tests/qnn_tests/test_causal_lm_models_qnn.py new file mode 100644 index 000000000..50ad3551d --- /dev/null +++ b/tests/qnn_tests/test_causal_lm_models_qnn.py @@ -0,0 +1,172 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import numpy as np +import pytest +from transformers import AutoModelForCausalLM + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM +from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers +from QEfficient.utils import hf_download +from QEfficient.utils._utils import load_hf_tokenizer +from QEfficient.utils.constants import Constants +from QEfficient.utils.device_utils import get_available_device_id +from QEfficient.utils.run_utils import ApiRunner + +test_models = [ + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "gpt2", +] + + +def load_causal_lm_model(model_config): + """ + Function to load model from huggingface and transform to KV model + -------- + + :model_config: Dict + + :return model_hf, params + """ + model_path = hf_download( + repo_id=model_config["model_name"], + ignore_patterns=["*.onnx", "*.ot", "*.md", "*.tflite", "*.pdf", "*.h5", "*.msgpack"], + ) + model_hf = AutoModelForCausalLM.from_pretrained( + model_path, + use_cache=True, + num_hidden_layers=model_config["n_layer"], + attn_implementation="eager", + low_cpu_mem_usage=False, + ) # Run models for single layers only + params = sum(p.numel() for p in model_hf.parameters()) + model_hf.eval() + return model_hf, params + + +def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name: str, + prompt_len: int = Constants.PROMPT_LEN, + ctx_len: int = Constants.CTX_LEN, + n_layer: int = 1, +): + """ + Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + :prompt_len (int): Prompt length for the model to compile. + :ctx_len (int): Maximum context length to compile the model. + :n_layers (int): Number of layers for the Model. + """ + replace_transformers_quantizers() + model_config = {"model_name": model_name} + model_config["n_layer"] = n_layer + + model_hf, _ = load_causal_lm_model(model_config) + + tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name) + config = model_hf.config + batch_size = len(Constants.INPUT_STR) + api_runner = ApiRunner( + batch_size, + tokenizer, + config, + Constants.INPUT_STR, + Constants.PROMPT_LEN, + Constants.CTX_LEN, + ) + + pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) + + qeff_model = QEFFAutoModelForCausalLM(model_hf) + + pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) + + assert ( + pytorch_hf_tokens == pytorch_kv_tokens + ).all(), "Tokens don't match for HF PyTorch model output and KV PyTorch model output" + + onnx_model_path = qeff_model.export() + ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path) + + assert (pytorch_kv_tokens == ort_tokens).all(), "Tokens don't match for ONNXRT output and PyTorch output." + + if not get_available_device_id(): + pytest.skip("No available devices to run model on Cloud AI 100") + + _ = qeff_model.compile( + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + num_cores=14, + mxfp6=False, + aic_enable_depth_first=False, + enable_qnn=True, + ) + exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR) + cloud_ai_100_tokens = exec_info.generated_ids[0] # Because we always run for single input and single batch size + gen_len = ort_tokens.shape[-1] + assert ( + ort_tokens == cloud_ai_100_tokens[:, :gen_len] + ).all(), "Tokens don't match for ONNXRT output and Cloud AI 100 output." + + # testing for CB models + model_hf, _ = load_causal_lm_model(model_config) + full_batch_size = 4 + fbs_prompts = Constants.INPUT_STR * 4 + api_runner = ApiRunner( + batch_size, + tokenizer, + config, + fbs_prompts, + Constants.PROMPT_LEN, + Constants.CTX_LEN, + full_batch_size, + ) + + pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch_CB(model_hf) + pytorch_hf_tokens = np.vstack(pytorch_hf_tokens) + + qeff_model = QEFFAutoModelForCausalLM(model_hf, continuous_batching=True) + onnx_model_path = qeff_model.export() + + if not get_available_device_id(): + pytest.skip("No available devices to run model on Cloud AI 100") + + _ = qeff_model.compile( + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + num_cores=14, + mxfp6=False, + aic_enable_depth_first=False, + full_batch_size=full_batch_size, + enable_qnn=True, + ) + exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts) + + assert all( + [ + all(pt_token[:24] == cloud_token[:24]) + for pt_token, cloud_token in zip(pytorch_hf_tokens, exec_info_fbs.generated_ids) + ] + ), "Tokens don't match for HF PyTorch model output and Cloud AI 100 output." + + +@pytest.mark.on_qaic +@pytest.mark.qnn +@pytest.mark.parametrize("model_name", test_models) +def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): + """ + Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + if model_name == "microsoft/Phi-3-mini-4k-instruct": + n_layer = 2 # test only 2 layer models + else: + n_layer = 1 + + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer)