diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 83c573f6..8aa66b91 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -7,6 +7,7 @@ import hashlib import logging +import os import warnings from pathlib import Path from typing import List, Optional, Union @@ -23,7 +24,7 @@ from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform, SpDTransform from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, with_replaced_quantizers from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform -from QEfficient.utils import constants, get_padding_shape_from_config +from QEfficient.utils import constants, create_and_dump_configs, get_padding_shape_from_config from QEfficient.utils.cache import to_hashable logger = logging.getLogger(__file__) @@ -319,7 +320,7 @@ def compile( for kv in ["key", "value"]: custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype - return self._compile( + self._compile( onnx_path, compile_dir, compile_only=True, @@ -333,6 +334,36 @@ def compile( aic_num_cores=num_cores, **compiler_options, ) + # Construct the qconfig json file path + qconfig_file_path = os.path.join(os.path.dirname(self.qpc_path), "qconfig.json") + huggingface_config = self.model.config.__dict__ + + pytorch_transforms = [cls.__name__ for cls in self._pytorch_transforms] + onnx_transforms = [cls.__name__ for cls in self._onnx_transforms] + + onnx_path = str(self.onnx_path) + specializations_file_path = str(os.path.join(os.path.dirname(self.qpc_path), "specializations.json")) + compile_dir = str(os.path.dirname(self.qpc_path)) + + create_and_dump_configs( + qconfig_file_path, + specializations_file_path, + huggingface_config, + pytorch_transforms, + onnx_transforms, + onnx_path, + compile_dir, + prefill_seq_len, + ctx_len, + 
batch_size, + full_batch_size, + num_devices, + num_cores, + mxfp6_matmul, + mxint8_kv_cache, + num_speculative_tokens, + ) + return self.qpc_path # FIXME: Update this method to match with transformers AutoModelForCausalLM.generate def generate( diff --git a/QEfficient/utils/__init__.py b/QEfficient/utils/__init__.py index 2506b923..a7a32cfb 100755 --- a/QEfficient/utils/__init__.py +++ b/QEfficient/utils/__init__.py @@ -11,6 +11,7 @@ ) from QEfficient.utils._utils import ( # noqa: F401 check_and_assign_cache_dir, + create_and_dump_configs, get_num_layers_from_config, get_onnx_dir_name, get_padding_shape_from_config, diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index 2729267d..9d5ca88b 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -8,6 +8,7 @@ import json import os import subprocess +import xml.etree.ElementTree as ET from typing import Any, Dict, List, Optional, Tuple, Union import requests @@ -394,3 +395,70 @@ def create_json(file_path: str, json_data: object): json.dump(json_data, file, indent=4) except Exception as e: print(f"Failed to create JSON File {file_path}: {e}") + + +def create_and_dump_configs( + config_file_path, + specializations_file_path, + huggingface_config, + pytorch_transforms, + onnx_transforms, + onnx_path, + compile_dir, + prefill_seq_len, + ctx_len, + batch_size, + full_batch_size, + num_devices, + num_cores, + mxfp6_matmul, + mxint8_kv_cache, + num_speculative_tokens, +): + try: + # Parse the XML file + tree = ET.parse(Constants.SDK_APPS_XML) + root = tree.getroot() + # Try to find the base_version element and get its text + version = root.find(".//base_version").text + except (FileNotFoundError, ET.ParseError, AttributeError): + version = None + + # Ensure all objects in the configs dictionary are JSON serializable + def make_serializable(obj): + if isinstance(obj, (int, float, str, bool, type(None))): + return obj + elif isinstance(obj, (list, tuple)): + return 
[make_serializable(item) for item in obj] + elif isinstance(obj, dict): + return {key: make_serializable(value) for key, value in obj.items()} + else: + return str(obj) # Convert non-serializable objects to strings + + configs = { + "huggingface_config": make_serializable(huggingface_config), + "qpc_config": { + "QEff_config": { + "pytorch_transforms": make_serializable(pytorch_transforms), + "onnx_transforms": make_serializable(onnx_transforms), + "onnx_path": onnx_path, + }, + "compilation_config": { + "apps_sdk_version": version, + "compile_dir": compile_dir, + "specializations_file_path": specializations_file_path, + "prefill_seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "batch_size": batch_size, + "full_batch_size": full_batch_size, + "num_devices": num_devices, + "num_cores": num_cores, + "mxfp6_matmul": mxfp6_matmul, + "mxint8_kv_cache": mxint8_kv_cache, + "num_speculative_tokens": num_speculative_tokens, + }, + }, + } + # Dump the configs dictionary to a JSON file + with open(config_file_path, "w") as file: + json.dump(configs, file, indent=4) diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index bfbac905..b484928a 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -63,6 +63,7 @@ class Constants: MAX_QPC_LIMIT = 30 MAX_RETRIES = 5 # This constant will be used set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download NUM_SPECULATIVE_TOKENS = 2 + SDK_APPS_XML = "/opt/qti-aic/versions/apps.xml" # This xml file is parsed to find out the SDK version. @dataclass