Skip to content

Commit

Permalink
Code for SDK configs Inclusion
Browse files Browse the repository at this point in the history
Signed-off-by: Abukhoyer Shaik <[email protected]>
  • Loading branch information
abukhoy committed Dec 20, 2024
1 parent dc2c509 commit b361ba9
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 2 deletions.
35 changes: 33 additions & 2 deletions QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import hashlib
import logging
import os
import warnings
from pathlib import Path
from typing import List, Optional, Union
Expand All @@ -23,7 +24,7 @@
from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform, SpDTransform
from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, with_replaced_quantizers
from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform
from QEfficient.utils import constants, get_padding_shape_from_config
from QEfficient.utils import constants, create_and_dump_configs, get_padding_shape_from_config
from QEfficient.utils.cache import to_hashable

logger = logging.getLogger(__file__)
Expand Down Expand Up @@ -319,7 +320,7 @@ def compile(
for kv in ["key", "value"]:
custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype

return self._compile(
self._compile(
onnx_path,
compile_dir,
compile_only=True,
Expand All @@ -333,6 +334,36 @@ def compile(
aic_num_cores=num_cores,
**compiler_options,
)
# Construct the qconfig json file path
qconfig_file_path = os.path.join(os.path.dirname(self.qpc_path), "qconfig.json")
huggingface_config = self.model.config.__dict__

pytorch_transforms = [cls.__name__ for cls in self._pytorch_transforms]
onnx_transforms = [cls.__name__ for cls in self._onnx_transforms]

onnx_path = str(self.onnx_path)
specializations_file_path = str(os.path.join(os.path.dirname(self.qpc_path), "specializations.json"))
compile_dir = str(os.path.dirname(self.qpc_path))

create_and_dump_configs(
qconfig_file_path,
specializations_file_path,
huggingface_config,
pytorch_transforms,
onnx_transforms,
onnx_path,
compile_dir,
prefill_seq_len,
ctx_len,
batch_size,
full_batch_size,
num_devices,
num_cores,
mxfp6_matmul,
mxint8_kv_cache,
num_speculative_tokens,
)
return self.qpc_path

# FIXME: Update this method to match with transformers AutoModelForCausalLM.generate
def generate(
Expand Down
1 change: 1 addition & 0 deletions QEfficient/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from QEfficient.utils._utils import ( # noqa: F401
check_and_assign_cache_dir,
create_and_dump_configs,
get_num_layers_from_config,
get_onnx_dir_name,
get_padding_shape_from_config,
Expand Down
68 changes: 68 additions & 0 deletions QEfficient/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import json
import os
import subprocess
import xml.etree.ElementTree as ET
from typing import Any, Dict, List, Optional, Tuple, Union

import requests
Expand Down Expand Up @@ -394,3 +395,70 @@ def create_json(file_path: str, json_data: object):
json.dump(json_data, file, indent=4)
except Exception as e:
print(f"Failed to create JSON File {file_path}: {e}")


def create_and_dump_configs(
    config_file_path,
    specializations_file_path,
    huggingface_config,
    pytorch_transforms,
    onnx_transforms,
    onnx_path,
    compile_dir,
    prefill_seq_len,
    ctx_len,
    batch_size,
    full_batch_size,
    num_devices,
    num_cores,
    mxfp6_matmul,
    mxint8_kv_cache,
    num_speculative_tokens,
):
    """Write a JSON summary of the model and compilation settings.

    The output has two top-level sections: ``huggingface_config`` (the model
    configuration passed in) and ``qpc_config`` (the transforms and ONNX path
    under ``QEff_config``, and the compile options plus the detected SDK
    version under ``compilation_config``).

    Args:
        config_file_path: Destination path of the JSON file to write.
        specializations_file_path: Recorded as-is in the compilation section.
        huggingface_config: Mapping with the model configuration; values that
            are not JSON-native are stored as their ``str()`` form.
        pytorch_transforms: Sequence describing the PyTorch transforms applied.
        onnx_transforms: Sequence describing the ONNX transforms applied.
        onnx_path: Path of the exported ONNX model.
        compile_dir: Directory holding the compilation artifacts.
        prefill_seq_len, ctx_len, batch_size, full_batch_size, num_devices,
        num_cores, mxfp6_matmul, mxint8_kv_cache, num_speculative_tokens:
            Compile options, recorded verbatim.

    Raises:
        OSError: If ``config_file_path`` cannot be opened for writing.
    """
    # The SDK version is informational only: if the XML file is missing,
    # unreadable, malformed, or has no <base_version> element, record null
    # rather than failing the whole dump.
    # NOTE(review): `Constants` must be importable in this module's scope.
    try:
        root = ET.parse(Constants.SDK_APPS_XML).getroot()
        version = root.find(".//base_version").text
    except (OSError, ET.ParseError, AttributeError):
        version = None

    json_scalars = (int, float, str, bool, type(None))

    def make_serializable(obj):
        """Recursively convert ``obj`` into something ``json`` can encode."""
        if isinstance(obj, json_scalars):
            return obj
        if isinstance(obj, (list, tuple)):
            return [make_serializable(item) for item in obj]
        if isinstance(obj, dict):
            # json only accepts scalar keys; stringify anything else instead
            # of letting the encoder raise TypeError.
            return {
                (key if isinstance(key, json_scalars) else str(key)): make_serializable(value)
                for key, value in obj.items()
            }
        return str(obj)  # Convert non-serializable objects to strings

    # Sanitise the *entire* structure, not just selected fields, so that e.g.
    # a pathlib.Path passed as onnx_path or compile_dir cannot break the dump.
    configs = make_serializable(
        {
            "huggingface_config": huggingface_config,
            "qpc_config": {
                "QEff_config": {
                    "pytorch_transforms": pytorch_transforms,
                    "onnx_transforms": onnx_transforms,
                    "onnx_path": onnx_path,
                },
                "compilation_config": {
                    "apps_sdk_version": version,
                    "compile_dir": compile_dir,
                    # Key was previously misspelled "specializtions_file_path".
                    "specializations_file_path": specializations_file_path,
                    "prefill_seq_len": prefill_seq_len,
                    "ctx_len": ctx_len,
                    "batch_size": batch_size,
                    "full_batch_size": full_batch_size,
                    "num_devices": num_devices,
                    "num_cores": num_cores,
                    "mxfp6_matmul": mxfp6_matmul,
                    "mxint8_kv_cache": mxint8_kv_cache,
                    "num_speculative_tokens": num_speculative_tokens,
                },
            },
        }
    )

    # Encode before opening the file so an encoding error can never leave a
    # truncated or half-written config on disk.
    payload = json.dumps(configs, indent=4)
    with open(config_file_path, "w", encoding="utf-8") as file:
        file.write(payload)
1 change: 1 addition & 0 deletions QEfficient/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class Constants:
MAX_QPC_LIMIT = 30
MAX_RETRIES = 5 # This constant will be used set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download
NUM_SPECULATIVE_TOKENS = 2
# Path of the XML file that is parsed to find out the installed SDK version.
# (A stray chained `tree =` target used to create an unintended second name here.)
SDK_APPS_XML = "/opt/qti-aic/versions/apps.xml"


@dataclass
Expand Down

0 comments on commit b361ba9

Please sign in to comment.