diff --git a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/__init__.py b/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/__init__.py index 3ab97fce7..c03e8d821 100644 --- a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/__init__.py +++ b/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/__init__.py @@ -14,5 +14,5 @@ # Refer from https://github.com/intel/intel-extension-for-transformers/tree/main/intel_extension_for_transformers/transformers/llm/evaluation/lm_eval -from evaluation.accuracy import cli_evaluate as evaluate -from evaluation.utils import LMEvalParser +from evaluation.accuracy import cli_evaluate as evaluate # noqa: F401 +from evaluation.utils import LMEvalParser # noqa: F401 diff --git a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/accuracy.py b/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/accuracy.py index 5608307f6..b10eebfc2 100644 --- a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/accuracy.py +++ b/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/accuracy.py @@ -31,7 +31,7 @@ def _handle_non_serializable(o): - if isinstance(o, np.int64) or isinstance(o, np.int32): + if isinstance(o, (np.int32, np.int64)): return int(o) elif isinstance(o, set): return list(o) @@ -45,7 +45,7 @@ def cli_evaluate(args) -> None: eval_logger = lm_eval.utils.eval_logger eval_logger.setLevel(getattr(logging, f"{args.verbosity}")) - eval_logger.info(f"Verbosity set to {args.verbosity}") + eval_logger.info("Verbosity set to %s", args.verbosity) os.environ["TOKENIZERS_PARALLELISM"] = "false" if args.predict_only: @@ -54,48 +54,45 @@ def cli_evaluate(args) -> None: raise ValueError("Specify --output_path if providing --log_samples or --predict_only") if args.include_path is not None: - eval_logger.info(f"Including path: {args.include_path}") + eval_logger.info(f"Including path: {args.include_path}") # noqa: G004 task_manager = lm_eval.tasks.TaskManager(args.verbosity, include_path=args.include_path) if args.limit: - eval_logger.warning( - " --limit SHOULD ONLY BE USED FOR TESTING." "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT." 
- ) + eval_logger.warning(" --limit SHOULD ONLY BE USED FOR TESTING.REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.") if args.tasks is None: eval_logger.error("Need to specify task to evaluate.") sys.exit() elif args.tasks == "list": - eval_logger.info("Available Tasks:\n - {}".format("\n - ".join(task_manager.all_tasks))) + eval_logger.info("Available Tasks:\n - {}".format("\n - ".join(task_manager.all_tasks))) # noqa: G001 sys.exit() + elif os.path.isdir(args.tasks): + task_names = [] + yaml_path = os.path.join(args.tasks, "*.yaml") + for yaml_file in glob.glob(yaml_path): + config = lm_eval.utils.load_yaml_config(yaml_file) + task_names.append(config) else: - if os.path.isdir(args.tasks): - task_names = [] - yaml_path = os.path.join(args.tasks, "*.yaml") - for yaml_file in glob.glob(yaml_path): - config = lm_eval.utils.load_yaml_config(yaml_file) + task_list = args.tasks.split(",") + task_names = task_manager.match_tasks(task_list) + for task in [task for task in task_list if task not in task_names]: + if os.path.isfile(task): + config = lm_eval.utils.load_yaml_config(task) task_names.append(config) - else: - task_list = args.tasks.split(",") - task_names = task_manager.match_tasks(task_list) - for task in [task for task in task_list if task not in task_names]: - if os.path.isfile(task): - config = lm_eval.utils.load_yaml_config(task) - task_names.append(config) - task_missing = [ - task for task in task_list if task not in task_names and "*" not in task - ] # we don't want errors if a wildcard ("*") task name was used - - if task_missing: - missing = ", ".join(task_missing) - eval_logger.error( - f"Tasks were not found: {missing}\n" - f"{lm_eval.utils.SPACING}Try `lm-eval --tasks list` for list of available tasks", - ) - raise ValueError( - f"Tasks not found: {missing}. Try `lm-eval --tasks list` for list of available tasks," - + " or '--verbosity DEBUG' to troubleshoot task registration issues." - ) + task_missing = [ + task for task in task_list if task not in task_names and "*" not in task + ] # we don't want errors if a wildcard ("*") task name was used + + if task_missing: + missing = ", ".join(task_missing) + eval_logger.error( + f"Tasks were not found: {missing}\n" # noqa: G004 + f"{lm_eval.utils.SPACING}Try `lm-eval --tasks list` for list of available tasks", + ) + raise ValueError( + f"Tasks not found: {missing}. Try `lm-eval --tasks list` for list of available tasks," # noqa: ISC003 + + " or '--verbosity DEBUG' to troubleshoot task registration issues." + ) if args.output_path: path = Path(args.output_path) @@ -104,7 +101,7 @@ def cli_evaluate(args) -> None: raise FileExistsError(f"File already exists at {path}") output_path_file = path.joinpath(DEFAULT_RESULTS_FILE) if output_path_file.is_file(): - eval_logger.warning(f"File {output_path_file} already exists. Results will be overwritten.") + eval_logger.warning(f"File {output_path_file} already exists. 
Results will be overwritten.") # noqa: G004 # if path json then get parent dir elif path.suffix in (".json", ".jsonl"): output_path_file = path @@ -118,7 +115,7 @@ def cli_evaluate(args) -> None: os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = str(args.trust_remote_code) args.model_args = args.model_args + f",trust_remote_code={os.environ['HF_DATASETS_TRUST_REMOTE_CODE']}" - eval_logger.info(f"Selected Tasks: {task_names}") + eval_logger.info(f"Selected Tasks: {task_names}") # noqa: G004 eval_logger.info("Loading selected tasks...") request_caching_args = evaluator.request_caching_arg_to_dict(cache_requests=args.cache_requests) @@ -164,14 +161,14 @@ def cli_evaluate(args) -> None: wandb_logger.log_eval_result() if args.log_samples: wandb_logger.log_eval_samples(samples) - except Exception as e: - eval_logger.info(f"Logging to Weights and Biases failed due to {e}") + except Exception as e: # noqa: BLE001 + eval_logger.info(f"Logging to Weights and Biases failed due to {e}") # noqa: G004 if args.output_path: output_path_file.open("w", encoding="utf-8").write(dumped) if args.log_samples: - for task_name, config in results["configs"].items(): + for task_name, config in results["configs"].items(): # noqa: B007 output_name = "{}_{}".format(re.sub("/|=", "__", args.model_args), task_name) filename = path.joinpath(f"{output_name}.jsonl") samples_dumped = json.dumps( diff --git a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/evaluator.py b/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/evaluator.py index 2b4a8b2d2..4246fdcff 100644 --- a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/evaluator.py +++ b/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/evaluator.py @@ -11,14 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +from __future__ import annotations import collections import itertools import logging import random import time -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING import lm_eval.api.metrics import lm_eval.api.registry @@ -40,30 +40,30 @@ @lm_eval.utils.positional_deprecated def simple_evaluate( model, - model_args: Optional[Union[str, dict, object]] = None, - tasks: Optional[List[Union[str, dict, object]]] = None, - num_fewshot: Optional[int] = None, - batch_size: Optional[int] = None, - max_batch_size: Optional[int] = None, - provider: Optional[str] = None, - use_cache: Optional[str] = None, + model_args: str | dict | object | None = None, + tasks: list[str | dict | object] | None = None, + num_fewshot: int | None = None, + batch_size: int | None = None, + max_batch_size: int | None = None, + provider: str | None = None, + use_cache: str | None = None, cache_requests: bool = False, rewrite_requests_cache: bool = False, delete_requests_cache: bool = False, - limit: Optional[Union[int, float]] = None, + limit: int | float | None = None, bootstrap_iters: int = 100000, check_integrity: bool = False, write_out: bool = False, log_samples: bool = True, - gen_kwargs: Optional[str] = None, - task_manager: Optional[lm_eval.tasks.TaskManager] = None, + gen_kwargs: str | None = None, + task_manager: lm_eval.tasks.TaskManager | None = None, verbosity: str = "INFO", predict_only: bool = False, random_seed: int = 0, numpy_random_seed: int = 1234, torch_random_seed: int = 1234, - user_model: Optional[object] = None, - tokenizer: Optional[object] = None, + user_model: object | None = None, + tokenizer: object | None = None, ): """Instantiate and evaluate a model on a list of tasks. @@ -178,7 +178,7 @@ def simple_evaluate( elif isinstance(user_model, optimum.onnxruntime.ORTModelForSeq2SeqLM): model_id = "optimum/t5-small" lm_eval.utils.eval_logger.info( - "We use '{}' to build `LM` instance, the actually run model is user_model you passed.".format(model_id) + f"We use '{model_id}' to build `LM` instance, the actually run model is user_model you passed." ) lm = lm_eval.api.registry.get_model(model).create_from_arg_string( "pretrained=" + model_id, @@ -193,7 +193,7 @@ def simple_evaluate( if tokenizer is not None: lm.tokenizer = tokenizer else: - assert False, "Please provide tokenizer in evaluation function" + raise AssertionError("Please provide tokenizer in evaluation function") elif isinstance(model_args, dict): lm = lm_eval.api.registry.get_model(model).create_from_arg_obj( model_args, @@ -231,7 +231,7 @@ def simple_evaluate( task_manager = lm_eval.tasks.TaskManager(verbosity) task_dict = lm_eval.tasks.get_task_dict(tasks, task_manager) - for task_name in task_dict.keys(): + for task_name in task_dict: task_obj = task_dict[task_name] if isinstance(task_obj, tuple): _, task_obj = task_obj @@ -255,7 +255,7 @@ def simple_evaluate( if num_fewshot is not None: if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0: lm_eval.utils.eval_logger.info( - f"num_fewshot has been set to 0 for {task_name} in its config." + f"num_fewshot has been set to 0 for {task_name} in its config." # noqa: ISC003 + "Manual configuration will be ignored." 
) else: @@ -263,10 +263,8 @@ def simple_evaluate( f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}" ) task_obj.set_config(key="num_fewshot", value=num_fewshot) - else: - # if num_fewshot not provided, and the task does not define a default one, default to 0 - if (default_num_fewshot := task_obj.get_config("num_fewshot")) is None: - task_obj.set_config(key="num_fewshot", value=0) + elif (default_num_fewshot := task_obj.get_config("num_fewshot")) is None: + task_obj.set_config(key="num_fewshot", value=0) if check_integrity: lm_eval.evaluator_utils.run_task_tests(task_list=tasks) @@ -307,7 +305,7 @@ def simple_evaluate( results["date"] = start_date try: lm_eval.logging_utils.add_env_info(results) # additional environment info to results - except: + except: # noqa: E722 lm_eval.utils.eval_logger.info("get env info failed.") return results else: @@ -316,12 +314,12 @@ def simple_evaluate( @lm_eval.utils.positional_deprecated def evaluate( - lm: "lm_eval.api.model.LM", + lm: lm_eval.api.model.LM, task_dict, - limit: Optional[int] = None, + limit: int | None = None, cache_requests: bool = False, rewrite_requests_cache: bool = False, - bootstrap_iters: Optional[int] = 100000, + bootstrap_iters: int | None = 100000, write_out: bool = False, log_samples: bool = True, verbosity: str = "INFO", @@ -362,9 +360,7 @@ def evaluate( # get lists of group hierarchy and each type of request task_hierarchy, eval_tasks = lm_eval.evaluator_utils.get_task_list(task_dict) if not log_samples: - if not all( - "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys() for task_output in eval_tasks - ): + if not all("bypass" not in getattr(task_output.task, "_metric_fn_list", {}) for task_output in eval_tasks): raise ValueError("log_samples must be True for 'bypass' metric-only tasks") for task_output in eval_tasks: task: lm_eval.tasks.Task = task_output.task @@ -420,8 +416,8 @@ def evaluate( if lm.world_size > 1: lm.accelerator.wait_for_everyone() - RANK = lm.rank - WORLD_SIZE = lm.world_size + RANK = lm.rank # noqa: N806 + WORLD_SIZE = lm.world_size # noqa: N806 ### Postprocess outputs ### # TODO: del model here, maybe (idea: allow user to specify device of e.g. 
reward model separately) for task_output in eval_tasks: @@ -439,7 +435,7 @@ def evaluate( for instances in instances_by_doc_id.values(): instances.sort(key=lambda x: x.idx) # iterate over different filters used - for filter_key in task.instances[0].filtered_resps.keys(): + for filter_key in task.instances[0].filtered_resps: doc_iterator = task.doc_iterator(rank=RANK, limit=limit, world_size=WORLD_SIZE) for doc_id, doc in doc_iterator: requests = instances_by_doc_id[doc_id] @@ -506,7 +502,7 @@ def evaluate( { key for task in task_list - for key in results[task].keys() + for key in results[task] if "_stderr" not in key and key not in ["alias", "samples"] } ) @@ -537,8 +533,8 @@ def evaluate( groups_agg = collections.defaultdict(dict) all_tasks_list = list(task_hierarchy.keys()) while True: - add_tasks_list = list(k for k in results_agg.keys()) - left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list))) + add_tasks_list = list(results_agg.keys()) + left_tasks_list = sorted(set(all_tasks_list) - set(add_tasks_list)) if len(left_tasks_list) == 0: break diff --git a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/models/__init__.py b/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/models/__init__.py index 8a19e05fd..e468cd437 100644 --- a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/models/__init__.py +++ b/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/models/__init__.py @@ -11,13 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +__all__ = ["huggingface"] from evaluation.models import huggingface -# TODO: implement __all__ - - try: # enable hf hub transfer if available import hf_transfer # type: ignore # noqa diff --git a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/models/huggingface.py b/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/models/huggingface.py index b682e4f47..da4b032db 100644 --- a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/models/huggingface.py +++ b/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/models/huggingface.py @@ -11,12 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from __future__ import annotations import copy import os import tempfile -from typing import List, Literal, Optional, Tuple, Union +from typing import Literal import accelerate import huggingface_hub @@ -29,7 +29,7 @@ import optimum.version import packaging.version import torch -import torch.nn.functional as F +import torch.nn.functional as F # noqa: N812 import tqdm import transformers @@ -40,55 +40,47 @@ class HFLM(lm_eval.api.model.TemplateLM): """An abstracted Huggingface model class. Enables usage with both models of `optimum.onnxruntime.ORTModelForCausalLM` and `optimum.onnxruntime.ORTModelForSeq2SeqLM` classes. 
- """ + """ # noqa: D205 AUTO_MODEL_CLASS = None _DEFAULT_MAX_LENGTH = 2048 def __init__( self, - pretrained: Optional[Union[str, transformers.PreTrainedModel]] = "gpt2", - backend: Optional[Literal["default", "causal", "seq2seq"]] = "default", + pretrained: str | transformers.PreTrainedModel | None = "gpt2", + backend: Literal["default", "causal", "seq2seq"] | None = "default", # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq) - revision: Optional[str] = "main", - tokenizer: Optional[ - Union[ - str, - transformers.PreTrainedTokenizer, - transformers.PreTrainedTokenizerFast, - ] - ] = None, - truncation: Optional[bool] = False, + revision: str | None = "main", + tokenizer: str | transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast | None = None, + truncation: bool | None = False, logits_cache: bool = True, - max_length: Optional[int] = None, - provider: Optional[str] = "CPUExecutionProvider", - batch_size: Optional[Union[int, str]] = 1, - max_batch_size: Optional[int] = 64, - trust_remote_code: Optional[bool] = False, - use_fast_tokenizer: Optional[bool] = True, - add_bos_token: Optional[bool] = False, - **kwargs, + max_length: int | None = None, + provider: str | None = "CPUExecutionProvider", + batch_size: int | str | None = 1, + max_batch_size: int | None = 64, + trust_remote_code: bool | None = False, + use_fast_tokenizer: bool | None = True, + add_bos_token: bool | None = False, + **kwargs, # noqa: ARG002 ) -> None: super().__init__() available_providers = onnxruntime.get_available_providers() - assert provider in available_providers, "{} is not available.".format(provider) + assert provider in available_providers, f"{provider} is not available." self._provider = provider self._device = torch.device("cpu") # use cpu to generate torch tensor # optionally: take in an already-initialized ORTModel if not isinstance(pretrained, str): eval_logger.warning( - "`pretrained` model kwarg is not of type `str`. " + "Many other model arguments may be ignored. " + "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. 
" ) self._model = pretrained self._config = self._model.config - self.model.providers + self.model.providers # noqa: B018 if tokenizer: - assert isinstance(tokenizer, transformers.PreTrainedTokenizer) or isinstance( - tokenizer, transformers.PreTrainedTokenizerFast - ) + assert isinstance(tokenizer, (transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast)) self.tokenizer = tokenizer else: # Get tokenizer @@ -134,31 +126,27 @@ def __init__( self.tokenizer.pad_token_id = self.tokenizer.unk_token_id elif self.tokenizer.eos_token: self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + elif getattr(self.config, "model_type", None) == "qwen": + # Qwen's trust_remote_code tokenizer does not allow for adding special tokens + self.tokenizer.pad_token = "<|endoftext|>" + elif self.tokenizer.__class__.__name__ in ("RWKVWorldTokenizer", "Rwkv5Tokenizer"): + # The RWKV world tokenizer, does not allow for adding special tokens / + # setting the pad token (which is set as 0) + # The additional tokenizer name check is needed, as there exists rwkv4 models with neox tokenizer + # --- + # Note that the world tokenizer class name, might change in the future + # for the final huggingface merge + # https://github.com/huggingface/transformers/pull/26963 + assert self.tokenizer.pad_token_id == 0 else: - if getattr(self.config, "model_type", None) == "qwen": - # Qwen's trust_remote_code tokenizer does not allow for adding special tokens - self.tokenizer.pad_token = "<|endoftext|>" - elif ( - self.tokenizer.__class__.__name__ == "RWKVWorldTokenizer" - or self.tokenizer.__class__.__name__ == "Rwkv5Tokenizer" - ): - # The RWKV world tokenizer, does not allow for adding special tokens / - # setting the pad token (which is set as 0) - # The additional tokenizer name check is needed, as there exists rwkv4 models with neox tokenizer - # --- - # Note that the world tokenizer class name, might change in the future - # for the final huggingface merge - # https://github.com/huggingface/transformers/pull/26963 - assert self.tokenizer.pad_token_id == 0 - else: - self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"}) + self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"}) # TODO: override this for Gemma self.add_bos_token = add_bos_token if getattr(self.config, "model_type", None) == "gemma": self.add_bos_token = True eval_logger.info( - f"Model type is '{self.config.model_type}', " + f"Model type is '{self.config.model_type}', " # noqa: ISC003, G003 + "a BOS token will be used as Gemma underperforms without it." ) @@ -178,7 +166,7 @@ def __init__( if not isinstance(pretrained, str): # if a PreTrainedModel was passed into HFLM, we forgo distributed setup. eval_logger.warning( - "Passed an already-initialized model through `pretrained`," + "Passed an already-initialized model through `pretrained`," # noqa: ISC003, G003 + " assuming single-process call to evaluate() or custom distributed integration" ) self._rank = 0 @@ -234,9 +222,9 @@ def world_size(self): def _get_backend( self, - config: Union[transformers.PretrainedConfig, transformers.AutoConfig], - backend: Optional[Literal["default", "causal", "seq2seq"]] = "default", - trust_remote_code: Optional[bool] = False, + config: transformers.PretrainedConfig | transformers.AutoConfig, + backend: Literal["default", "causal", "seq2seq"] | None = "default", + trust_remote_code: bool | None = False, ) -> None: """Helper method during initialization. 
@@ -250,37 +238,28 @@ def _get_backend( self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM elif backend == "seq2seq": self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM - eval_logger.info(f"Overrode HF model backend type, and using type '{backend}'") + eval_logger.info(f"Overrode HF model backend type, and using type '{backend}'") # noqa: G004 + elif config.model_type in transformers.models.auto.modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: + # first check if model type is listed under seq2seq models, since some + # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers. + # these special cases should be treated as seq2seq models. + self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM + elif self.config.model_type in transformers.models.auto.modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: + self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM else: - # determine and use the default HF backend for this model, based on its config + metadata. - if ( - getattr(config, "model_type") - in transformers.models.auto.modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES - ): - # first check if model type is listed under seq2seq models, since some - # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers. - # these special cases should be treated as seq2seq models. - self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM - elif ( - getattr(self.config, "model_type") - in transformers.models.auto.modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES - ): - self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM - else: - if not trust_remote_code: - eval_logger.warning( - "HF model type is neither marked as CausalLM or Seq2SeqLM. \ - This is expected if your model requires `trust_remote_code=True` but may be an error otherwise." - ) - # if model type is neither in HF transformers causal or seq2seq model registries - # then we default to AutoModelForCausalLM - self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM + if not trust_remote_code: + eval_logger.warning( + "HF model type is neither marked as CausalLM or Seq2SeqLM. \ + This is expected if your model requires `trust_remote_code=True` but may be an error otherwise." 
+ ) + # if model type is neither in HF transformers causal or seq2seq model registries + # then we default to AutoModelForCausalLM + self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM assert self.AUTO_MODEL_CLASS in [ transformers.AutoModelForCausalLM, transformers.AutoModelForSeq2SeqLM, ] - return None def _get_config( self, @@ -305,10 +284,10 @@ def _create_model( local_dir = tempfile.TemporaryDirectory().name huggingface_hub.snapshot_download(pretrained, local_dir=local_dir) pretrained = local_dir - except Exception as e: - raise e + except Exception: # noqa: TRY302 + raise - if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: + if transformers.AutoModelForCausalLM == self.AUTO_MODEL_CLASS: if ( not os.path.exists(os.path.join(pretrained, "decoder_model.onnx")) and not os.path.exists(os.path.join(pretrained, "decoder_with_past_model.onnx")) @@ -316,8 +295,9 @@ def _create_model( and not os.path.exists(os.path.join(pretrained, "model.onnx")) ): raise ValueError( - "Couldn't find any ONNX model name in " + "['decoder_model.onnx', 'decoder_with_past_model.onnx', " - "'decoder_model_merged.onnx', 'model.onnx'] in {}.".format(pretrained) + "Couldn't find any ONNX model name in " + "['decoder_model.onnx', 'decoder_with_past_model.onnx', " + f"'decoder_model_merged.onnx', 'model.onnx'] in {pretrained}." ) sess_options = onnxruntime.SessionOptions() @@ -328,93 +308,89 @@ def _create_model( session = optimum.onnxruntime.ORTModelForCausalLM.load_model( os.path.join(pretrained, "model.onnx"), provider=self.provider, session_options=sess_options ) - inputs_names = [input.name for input in session.get_inputs()] + inputs_names = [input.name for input in session.get_inputs()] # noqa: A001 key_value_input_names = [key for key in inputs_names if (".key" in key) or (".value" in key)] use_cache = len(key_value_input_names) > 0 self._model = optimum.onnxruntime.ORTModelForCausalLM( session, self.config, - use_cache=True if use_cache else False, - use_io_binding=True if use_cache else False, + use_cache=bool(use_cache), + use_io_binding=bool(use_cache), ) - else: - if os.path.exists(os.path.join(pretrained, "decoder_model_merged.onnx")): - session = optimum.onnxruntime.ORTModelForCausalLM.load_model( - os.path.join(pretrained, "decoder_model_merged.onnx"), - provider=self.provider, - session_options=sess_options, - ) - self._model = optimum.onnxruntime.ORTModelForCausalLM(session, self.config, use_cache=True) - elif os.path.exists(os.path.join(pretrained, "decoder_with_past_model.onnx")): - session = optimum.onnxruntime.ORTModelForCausalLM.load_model( - os.path.join(pretrained, "decoder_with_past_model.onnx"), - provider=self.provider, - session_options=sess_options, - ) - self._model = optimum.onnxruntime.ORTModelForCausalLM(session, self.config, use_cache=True) - elif os.path.exists(os.path.join(pretrained, "decoder_model.onnx")): - session = optimum.onnxruntime.ORTModelForCausalLM.load_model( - os.path.join(pretrained, "decoder_model.onnx"), - provider=self.provider, - session_options=sess_options, - ) - self._model = optimum.onnxruntime.ORTModelForCausalLM( - session, self.config, use_cache=False, use_io_binding=False - ) - else: - if os.path.exists(os.path.join(pretrained, "model.onnx")): + elif os.path.exists(os.path.join(pretrained, "decoder_model_merged.onnx")): session = optimum.onnxruntime.ORTModelForCausalLM.load_model( - os.path.join(pretrained, "model.onnx"), provider=self.provider, session_options=sess_options + os.path.join(pretrained, "decoder_model_merged.onnx"), + 
provider=self.provider, + session_options=sess_options, + ) + self._model = optimum.onnxruntime.ORTModelForCausalLM(session, self.config, use_cache=True) + elif os.path.exists(os.path.join(pretrained, "decoder_with_past_model.onnx")): + session = optimum.onnxruntime.ORTModelForCausalLM.load_model( + os.path.join(pretrained, "decoder_with_past_model.onnx"), + provider=self.provider, + session_options=sess_options, + ) + self._model = optimum.onnxruntime.ORTModelForCausalLM(session, self.config, use_cache=True) + elif os.path.exists(os.path.join(pretrained, "decoder_model.onnx")): + session = optimum.onnxruntime.ORTModelForCausalLM.load_model( + os.path.join(pretrained, "decoder_model.onnx"), + provider=self.provider, + session_options=sess_options, ) - inputs_names = session.get_inputs() - key_value_input_names = [key for key in inputs_names if (".key" in key) or (".value" in key)] - use_cache = len(key_value_input_names) > 0 - self._model = optimum.onnxruntime.ORTModelForCausalLM( - session[0], - self.config, - pretrained, - use_cache=True if use_cache else False, - use_io_binding=True if use_cache else False, + session, self.config, use_cache=False, use_io_binding=False ) - else: - if os.path.exists(os.path.join(pretrained, "decoder_model_merged.onnx")): - sessions = optimum.onnxruntime.ORTModelForCausalLM.load_model( - os.path.join(pretrained, "decoder_model_merged.onnx"), - provider=self.provider, - session_options=sess_options, - ) - self._model = optimum.onnxruntime.ORTModelForCausalLM( - sessions[0], self.config, pretrained, use_cache=True - ) - elif os.path.exists(os.path.join(pretrained, "decoder_with_past_model.onnx")): - sessions = optimum.onnxruntime.ORTModelForCausalLM.load_model( - os.path.join(pretrained, "decoder_model.onnx"), - os.path.join(pretrained, "decoder_with_past_model.onnx"), - provider=self.provider, - session_options=sess_options, - ) - self._model = optimum.onnxruntime.ORTModelForCausalLM( - sessions[0], self.config, pretrained, sessions[1], use_cache=True - ) - else: - sessions = optimum.onnxruntime.ORTModelForCausalLM.load_model( - os.path.join(pretrained, "decoder_model.onnx"), - provider=self.provider, - session_options=sess_options, - ) - self._model = optimum.onnxruntime.ORTModelForCausalLM( - sessions[0], self.config, pretrained, use_cache=False, use_io_binding=False - ) - elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: + elif os.path.exists(os.path.join(pretrained, "model.onnx")): + session = optimum.onnxruntime.ORTModelForCausalLM.load_model( + os.path.join(pretrained, "model.onnx"), provider=self.provider, session_options=sess_options + ) + inputs_names = session.get_inputs() + key_value_input_names = [key for key in inputs_names if (".key" in key) or (".value" in key)] + use_cache = len(key_value_input_names) > 0 + + self._model = optimum.onnxruntime.ORTModelForCausalLM( + session[0], + self.config, + pretrained, + use_cache=bool(use_cache), + use_io_binding=bool(use_cache), + ) + elif os.path.exists(os.path.join(pretrained, "decoder_model_merged.onnx")): + sessions = optimum.onnxruntime.ORTModelForCausalLM.load_model( + os.path.join(pretrained, "decoder_model_merged.onnx"), + provider=self.provider, + session_options=sess_options, + ) + self._model = optimum.onnxruntime.ORTModelForCausalLM( + sessions[0], self.config, pretrained, use_cache=True + ) + elif os.path.exists(os.path.join(pretrained, "decoder_with_past_model.onnx")): + sessions = optimum.onnxruntime.ORTModelForCausalLM.load_model( + os.path.join(pretrained, 
"decoder_model.onnx"), + os.path.join(pretrained, "decoder_with_past_model.onnx"), + provider=self.provider, + session_options=sess_options, + ) + self._model = optimum.onnxruntime.ORTModelForCausalLM( + sessions[0], self.config, pretrained, sessions[1], use_cache=True + ) + else: + sessions = optimum.onnxruntime.ORTModelForCausalLM.load_model( + os.path.join(pretrained, "decoder_model.onnx"), + provider=self.provider, + session_options=sess_options, + ) + self._model = optimum.onnxruntime.ORTModelForCausalLM( + sessions[0], self.config, pretrained, use_cache=False, use_io_binding=False + ) + elif transformers.AutoModelForSeq2SeqLM == self.AUTO_MODEL_CLASS: if not os.path.exists(os.path.join(pretrained, "encoder_model.onnx")) or ( not os.path.exists(os.path.join(pretrained, "decoder_model.onnx")) and not os.path.exists(os.path.join(pretrained, "decoder_model_merged.onnx")) ): raise ValueError( - "Please ensure encoder_model.onnx and " - "decoder_model(_merged).onnx are under {}.".format(pretrained) + f"Please ensure encoder_model.onnx and decoder_model(_merged).onnx are under {pretrained}." ) sess_options = onnxruntime.SessionOptions() @@ -463,21 +439,13 @@ def _create_model( use_io_binding=False, ) - return None - def _create_tokenizer( self, - pretrained: Union[str, transformers.PreTrainedModel], - tokenizer: Optional[ - Union[ - str, - transformers.PreTrainedTokenizer, - transformers.PreTrainedTokenizerFast, - ] - ], - revision: Optional[str] = "main", - trust_remote_code: Optional[bool] = False, - use_fast_tokenizer: Optional[bool] = True, + pretrained: str | transformers.PreTrainedModel, + tokenizer: str | transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast | None, + revision: str | None = "main", + trust_remote_code: bool | None = False, + use_fast_tokenizer: bool | None = True, ) -> None: """Helper method during initialization. 
@@ -493,9 +461,7 @@ def _create_tokenizer( use_fast=use_fast_tokenizer, ) else: - assert isinstance(tokenizer, transformers.PreTrainedTokenizer) or isinstance( - tokenizer, transformers.PreTrainedTokenizerFast - ) + assert isinstance(tokenizer, (transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast)) self.tokenizer = tokenizer else: # Get tokenizer based on 'pretrained' @@ -510,7 +476,6 @@ def _create_tokenizer( trust_remote_code=trust_remote_code, use_fast=use_fast_tokenizer, ) - return None def _detect_batch_size(self, requests=None, pos: int = 0): if requests: @@ -524,7 +489,7 @@ def _detect_batch_size(self, requests=None, pos: int = 0): # if OOM, then halves batch_size and tries again @accelerate.find_executable_batch_size(starting_batch_size=self.max_batch_size) def forward_batch(batch_size): - if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: + if transformers.AutoModelForSeq2SeqLM == self.AUTO_MODEL_CLASS: length = max(max_context_enc, max_cont_enc) batched_conts = torch.ones((batch_size, length), device=self._device).long() test_batch = torch.ones((batch_size, length), device=self._device).long() @@ -536,7 +501,7 @@ def forward_batch(batch_size): call_kwargs = {} test_batch = torch.ones((batch_size, max_length), device=self._device).long() for _ in range(5): - out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1) + F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1) return batch_size @@ -559,11 +524,11 @@ def forward_batch(batch_size): lm_eval.models.utils.clear_torch_cache() return batch_size - def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]: + def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> list[int]: if add_special_tokens is None: - if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: + if transformers.AutoModelForCausalLM == self.AUTO_MODEL_CLASS: add_special_tokens = False or self.add_bos_token - elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: + elif transformers.AutoModelForSeq2SeqLM == self.AUTO_MODEL_CLASS: # TODO: investigate best practices for enc-dec models + special tokens add_special_tokens = True @@ -577,18 +542,18 @@ def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=Non def tok_batch_encode( self, - strings: List[str], + strings: list[str], padding_side: str = "left", - left_truncate_len: int = None, + left_truncate_len: int | None = None, truncation: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode. 
old_padding_side = self.tokenizer.padding_side self.tokenizer.padding_side = padding_side - if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: + if transformers.AutoModelForCausalLM == self.AUTO_MODEL_CLASS: add_special_tokens = False or self.add_bos_token - elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: + elif transformers.AutoModelForSeq2SeqLM == self.AUTO_MODEL_CLASS: add_special_tokens = True encoding = self.tokenizer( @@ -606,9 +571,9 @@ def tok_batch_encode( return encoding["input_ids"], encoding["attention_mask"] def tok_decode(self, tokens, skip_special_tokens=True): - if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: + if transformers.AutoModelForCausalLM == self.AUTO_MODEL_CLASS: return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) - elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: + elif transformers.AutoModelForSeq2SeqLM == self.AUTO_MODEL_CLASS: return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) def _model_call(self, inps, attn_mask=None, labels=None): @@ -633,7 +598,7 @@ def _model_call(self, inps, attn_mask=None, labels=None): """ if attn_mask is not None or labels is not None: assert attn_mask is not None and labels is not None - assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM + assert transformers.AutoModelForSeq2SeqLM == self.AUTO_MODEL_CLASS decoder_start_token_id = self._config.decoder_start_token_id pad_token_id = self._config.pad_token_id shifted_input_ids = labels.new_zeros(labels.shape) @@ -647,7 +612,7 @@ def _model_call(self, inps, attn_mask=None, labels=None): labels=labels, ).logits else: - assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM + assert transformers.AutoModelForCausalLM == self.AUTO_MODEL_CLASS if ( hasattr(self.model, "config") and hasattr(self.model.config, "auto_map") @@ -657,7 +622,7 @@ def _model_call(self, inps, attn_mask=None, labels=None): bos = torch.tensor([64790, 64792]).repeat(input_bs, 1) inps = torch.cat((bos, inps), 1) - inputs_names = [input.name for input in self.model.model.get_inputs()] + inputs_names = [input.name for input in self.model.model.get_inputs()] # noqa: A001 if "position_ids" in inputs_names: # model is exported with optimum >= 1.14.0 with new input 'position_ids' input_shape = inps.shape @@ -698,13 +663,15 @@ def _model_generate(self, context, max_length, stop, **generation_kwargs): **generation_kwargs, ) - def _select_cont_toks(self, logits: torch.Tensor, contlen: int = None, inplen: int = None) -> torch.Tensor: - if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: + def _select_cont_toks( + self, logits: torch.Tensor, contlen: int | None = None, inplen: int | None = None + ) -> torch.Tensor: + if transformers.AutoModelForCausalLM == self.AUTO_MODEL_CLASS: assert contlen and inplen, "Must pass input len and cont. len to select scored logits for causal LM" # discard right-padding. # also discard the input/context tokens. we'll only score continuations. logits = logits[inplen - contlen : inplen] - elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: + elif transformers.AutoModelForSeq2SeqLM == self.AUTO_MODEL_CLASS: assert contlen and not inplen, "Selecting scored logits for Seq2SeqLM requires only cont. len" # only discard right-padding. # the logits input to this fn only contain decoder-side tokens. 
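A quick illustration of the slicing convention referenced in `_select_cont_toks` above (toy data, not from the repository): for a causal LM the input is the full sequence minus its last token, and the rows kept by `logits[inplen - contlen : inplen]` are exactly the positions that score the continuation tokens.

# With context [A, B, C] and continuation [D, E], the model input is
# [A, B, C, D] (inplen = 4, contlen = 2). logits[i] predicts token i + 1,
# so the scores for D and E live at positions 2 and 3, i.e. logits[4 - 2 : 4].
context = ["A", "B", "C"]
continuation = ["D", "E"]
inp = (context + continuation)[:-1]               # what the model sees
inplen, contlen = len(inp), len(continuation)
predicted_targets = (context + continuation)[1:]  # token scored by logits[i]
assert predicted_targets[inplen - contlen : inplen] == continuation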
@@ -713,8 +680,8 @@ def _select_cont_toks(self, logits: torch.Tensor, contlen: int = None, inplen: i return logits def loglikelihood_rolling( - self, requests: List[lm_eval.api.instance.Instance], disable_tqdm: bool = False - ) -> List[float]: + self, requests: list[lm_eval.api.instance.Instance], disable_tqdm: bool = False + ) -> list[float]: loglikelihoods = [] adaptive_batch_size = None @@ -740,7 +707,7 @@ def loglikelihood_rolling( # TODO: Right now, # we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case - rolling_token_windows = [(None,) + x for x in rolling_token_windows] + rolling_token_windows = [(None, *x) for x in rolling_token_windows] pad_amnt = 0 if self.world_size > 1: @@ -784,15 +751,15 @@ def _batch_scheduler(self, pos, n_reordered_requests): def _loglikelihood_tokens( self, - requests: List[Tuple[Tuple[str, str], List[int], List[int]]], + requests: list[tuple[tuple[str, str], list[int], list[int]]], disable_tqdm: bool = False, - override_bs: int = None, - ) -> List[Tuple[float, bool]]: + override_bs: int | None = None, + ) -> list[tuple[float, bool]]: # TODO: # implement some kind of efficient-request-middleware that lumps together requests with the same context res = [] - def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]): + def _collate(req: tuple[tuple[str, str], list[int], list[int]]): """Defines the key for the sorted method.""" # the negative sign on len(toks) sorts descending - this has a few advantages: # - time estimates will always be over not underestimates, which is more useful for planning @@ -804,7 +771,7 @@ def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]): toks = req[1] + req[2] return -len(toks), tuple(toks) - def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]): + def _lookup_one_token_cont(req: tuple[tuple[str, str], list[int], list[int]]): """Defines the key to group and lookup one-token continuations.""" # Use with group_by="contexts" (optional)" # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations. 
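A toy illustration of the collation key described in the comments above (the request data is made up): sorting by `(-len(tokens), tokens)` orders requests longest-first, so any out-of-memory failure surfaces in the first batches and running-time estimates err on the high side.

requests = [("short", [1, 2]), ("longest", [1, 2, 3, 4, 5]), ("mid", [1, 2, 3])]
ordered = sorted(requests, key=lambda req: (-len(req[1]), tuple(req[1])))
assert [name for name, _ in ordered] == ["longest", "mid", "short"]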
@@ -816,7 +783,7 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]): requests, sort_fn=_collate, group_by=( - "contexts" if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM and self.logits_cache else None + "contexts" if transformers.AutoModelForCausalLM == self.AUTO_MODEL_CLASS and self.logits_cache else None ), group_fn=_lookup_one_token_cont, ) @@ -865,14 +832,14 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]): # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice # when too long to fit in context, truncate from the left - if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: + if transformers.AutoModelForCausalLM == self.AUTO_MODEL_CLASS: inp = torch.tensor( (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1], dtype=torch.long, device=self._device, ) (inplen,) = inp.shape - elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: + elif transformers.AutoModelForSeq2SeqLM == self.AUTO_MODEL_CLASS: inp = torch.tensor( (context_enc)[-self.max_length :], dtype=torch.long, @@ -904,11 +871,11 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]): # create encoder attn mask and batched conts, if seq2seq call_kwargs = {} - if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: + if transformers.AutoModelForCausalLM == self.AUTO_MODEL_CLASS: batched_inps = lm_eval.models.utils.pad_and_concat( padding_len_inp, inps, padding_side="right" ) # [batch, padding_len_inp] - elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: + elif transformers.AutoModelForSeq2SeqLM == self.AUTO_MODEL_CLASS: # TODO: left-pad encoder inps and mask? batched_inps = lm_eval.models.utils.pad_and_concat(padding_len_inp, inps) # [batch, padding_len_inp] batched_conts = lm_eval.models.utils.pad_and_concat( @@ -936,11 +903,11 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]): # from prompt/prefix tuning tokens, if applicable ctx_len = ( inplen + (logits.shape[0] - padding_len_inp) - if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM + if transformers.AutoModelForCausalLM == self.AUTO_MODEL_CLASS else None ) - logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len) - logits = logits.unsqueeze(0) # [1, seq, vocab] + logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len) # noqa: PLW2901 + logits = logits.unsqueeze(0) # [1, seq, vocab] # noqa: PLW2901 # Check if per-token argmax is exactly equal to continuation greedy_tokens = logits.argmax(dim=-1) @@ -950,18 +917,22 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]): # original args. Otherwise, expands the logits batch dimension and yields each # batch along with matching continuation tokens and prompt strings. 
# logits -> [1, seq, vocab] - for request_str, cont_toks, logits in re_ord.get_cache( + for request_str, cont_toks, logits in re_ord.get_cache( # noqa: B020, PLW2901 req_str=request_str, cxt_toks=ctx_tokens, cont_toks=cont_toks, logits=logits, ): - cont_toks = torch.tensor(cont_toks, dtype=torch.long, device=self._device).unsqueeze(0) # [1, seq] + cont_toks = torch.tensor( # noqa: PLW2901 + cont_toks, dtype=torch.long, device=self._device + ).unsqueeze( + 0 + ) # [1, seq] max_equal = (greedy_tokens == cont_toks).all() # Obtain log-probs at the corresponding continuation token indices # last_token_slice = logits[:, -1, :].squeeze(0).tolist() - logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq] + logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq] # noqa: PLW2901 # Answer: (log prob, is-exact-match) answer = (float(logits.sum()), bool(max_equal)) @@ -975,10 +946,10 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]): return re_ord.get_original(res) - def generate_until(self, requests: List[lm_eval.api.instance.Instance], disable_tqdm: bool = False) -> List[str]: + def generate_until(self, requests: list[lm_eval.api.instance.Instance], disable_tqdm: bool = False) -> list[str]: res = [] - def _collate(req: Tuple[str, dict]): + def _collate(req: tuple[str, dict]): """Defines the key for the sorted method.""" # the negative sign on len(toks) sorts descending - this has a few advantages: # - time estimates will always be over not underestimates, which is more useful for planning @@ -1029,30 +1000,30 @@ def _collate(req: Tuple[str, dict]): until = None if isinstance(gen_kwargs, dict): kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1 - if "until" in kwargs.keys(): + if "until" in kwargs: until = kwargs.pop("until") if isinstance(until, str): until = [kwargs] elif not isinstance(until, list): raise ValueError(f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}") else: - raise ValueError(f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}") + raise ValueError(f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}") # noqa: TRY004 # add EOS token to stop sequences eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False) if not until: until = [eos] else: until.append(eos) - if "max_gen_toks" in kwargs.keys(): + if "max_gen_toks" in kwargs: max_gen_toks = kwargs.pop("max_gen_toks") else: max_gen_toks = self.max_gen_toks # set the max length in tokens of inputs ("context_enc") - if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: + if transformers.AutoModelForCausalLM == self.AUTO_MODEL_CLASS: # max len for inputs = max length, minus room to generate the max new tokens max_ctx_len = self.max_length - max_gen_toks - elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: + elif transformers.AutoModelForSeq2SeqLM == self.AUTO_MODEL_CLASS: # max len for inputs = encoder's whole max_length max_ctx_len = self.max_length @@ -1079,8 +1050,8 @@ def _collate(req: Tuple[str, dict]): cont_toks_list = cont.tolist() for cont_toks, context in zip(cont_toks_list, contexts): # discard context + left-padding toks if using causal decoder-only LM - if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: - cont_toks = cont_toks[context_enc.shape[1] :] + if transformers.AutoModelForCausalLM == self.AUTO_MODEL_CLASS: + cont_toks = cont_toks[context_enc.shape[1] :] # noqa: PLW2901 s = self.tok_decode(cont_toks) diff --git 
a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/utils.py b/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/utils.py index a9845eb41..adc5b507a 100644 --- a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/utils.py +++ b/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/utils.py @@ -38,9 +38,11 @@ def __init__( verbosity="INFO", wandb_args="", predict_only=False, - seed=[0, 1234, 1234], + seed=None, trust_remote_code=False, ): + if seed is None: + seed = [0, 1234, 1234] self.model = model self.tasks = tasks self.model_args = model_args diff --git a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/main.py b/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/main.py index 9cafe62d3..eaf92734b 100644 --- a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/main.py +++ b/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/main.py @@ -37,7 +37,7 @@ from onnx_neural_compressor.quantization import matmul_nbits_quantizer, tuning logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.WARN + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.WARNING ) parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -108,7 +108,7 @@ def replace_architectures(json_path): # replace 'LLaMATokenizer' to lowercase 'LlamaTokenizer' # to avoid bug 'Tokenizer class LLaMATokenizer does not exist or is not currently imported.' # refer to https://github.com/huggingface/transformers/issues/22222#issuecomment-1477171703 - with open(json_path, "r") as file: + with open(json_path) as file: data = json.load(file) data["architectures"] = ["LlamaForCausalLM"] @@ -136,10 +136,10 @@ def eval_func(model): eval_acc = 0 for task_name in args.tasks: if task_name == "wikitext": - print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity,none"])) + print("Accuracy for {} is: {}".format(task_name, results["results"][task_name]["word_perplexity,none"])) eval_acc += results["results"][task_name]["word_perplexity,none"] else: - print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc,none"])) + print("Accuracy for {} is: {}".format(task_name, results["results"][task_name]["acc,none"])) eval_acc += results["results"][task_name]["acc,none"] if len(args.tasks) != 0: @@ -162,8 +162,8 @@ def benchmark(model): model = optimum_ort.ORTModelForCausalLM( session, # pylint: disable=E1121 model_config, - use_cache=True if use_cache else False, - use_io_binding=True if use_cache else False, + use_cache=bool(use_cache), + use_io_binding=bool(use_cache), ) max_new_tokens = 32 @@ -177,7 +177,6 @@ def benchmark(model): num_warmup = 10 batch_size = 1 prompt = [prompt] * batch_size - total_list = [] for i in range(num_iter): tic = time.time() @@ -193,7 +192,7 @@ def benchmark(model): print("\n", "-" * 10, "Summary:", "-" * 10) print(args) throughput = (num_iter - num_warmup) / total_time - print("Throughput: {} samples/s".format(throughput)) + print(f"Throughput: {throughput} samples/s") class AWQDataloader(data_reader.CalibrationDataReader): @@ -212,7 +211,7 @@ def __init__(self, model_path, pad_max=196, batch_size=1, sub_folder="train", ca collate_fn=self.collate_batch, ) 
model = onnx.load(model_path, load_external_data=False) - inputs_names = [input.name for input in model.graph.input] + inputs_names = [input.name for input in model.graph.input] # noqa: A001 key_value_input_names = [key for key in inputs_names if (".key" in key) or (".value" in key)] use_cache = len(key_value_input_names) > 0 self.batch_size = batch_size @@ -274,17 +273,17 @@ def __init__(self, model_path, batch_size=1, seqlen=2048, sub_folder="train", ca traindata.set_format(type="torch", columns=["input_ids", "attention_mask"]) session = ort.InferenceSession(model_path) - inputs_names = [input.name for input in session.get_inputs()] + inputs_names = [input.name for input in session.get_inputs()] # noqa: A001 key_value_input_names = [key for key in inputs_names if (".key" in key) or (".value" in key)] use_cache = len(key_value_input_names) > 0 for i in range(calibration_sampling_size): while True: - i = random.randint(0, len(traindata) - 1) + i = random.randint(0, len(traindata) - 1) # noqa: PLW2901 trainenc = traindata[i] if trainenc["input_ids"].shape[0] > seqlen: break - i = random.randint(0, trainenc["input_ids"].shape[0] - seqlen - 1) + i = random.randint(0, trainenc["input_ids"].shape[0] - seqlen - 1) # noqa: PLW2901 j = i + seqlen inp = trainenc["input_ids"][i:j].unsqueeze(0) mask = torch.ones(inp.shape) @@ -325,7 +324,7 @@ def rewind(self): elif args.mode == "accuracy": acc_result = eval_func(args.model_path) print("Batch size = %d" % args.batch_size) - print("Accuracy: %.5f" % acc_result) + print(f"Accuracy: {acc_result:.5f}") if args.tune: model_name = "model.onnx" # require optimum >= 1.14.0 diff --git a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/prepare_model.py b/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/prepare_model.py index 3af820943..707cd3cf9 100644 --- a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/prepare_model.py +++ b/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/prepare_model.py @@ -40,6 +40,7 @@ def prepare_model(input_model, output_model, task): ], stdout=subprocess.PIPE, text=True, + check=False, ) assert os.path.exists(output_model), f"{output_model} doesn't exist!" diff --git a/onnx_neural_compressor/__init__.py b/onnx_neural_compressor/__init__.py index a8e492104..009c578a0 100644 --- a/onnx_neural_compressor/__init__.py +++ b/onnx_neural_compressor/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Copyright (c) 2021 Intel Corporation # diff --git a/onnx_neural_compressor/algorithms/layer_wise/core.py b/onnx_neural_compressor/algorithms/layer_wise/core.py index 2e381cfdb..2bb906417 100644 --- a/onnx_neural_compressor/algorithms/layer_wise/core.py +++ b/onnx_neural_compressor/algorithms/layer_wise/core.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Copyright (c) 2023 MIT HAN Lab # This source code is licensed under the MIT license @@ -16,26 +15,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations import copy import os import pathlib +from typing import Callable import onnx import onnxruntime as ort from onnx_neural_compressor import data_reader, logger, onnx_model, utility -from typing import Callable, List, Union # isort: skip - -def layer_wise_quant( - model: Union[onnx.ModelProto, onnx_model.ONNXModel, pathlib.Path, str], +def layer_wise_quant( # noqa: D417 + model: onnx.ModelProto | onnx_model.ONNXModel | pathlib.Path | str, quant_func: Callable, weight_config: dict, data_reader: data_reader.CalibrationDataReader = None, - *args, - **kwargs + *args, # noqa: ARG001 + **kwargs, ) -> onnx_model.ONNXModel: """Quantize model layer by layer to save memory. @@ -82,12 +81,10 @@ def layer_wise_quant( ) raise ValueError("Fail to run layer-wise quantization.") logger.info( - "Will split model into {} parts to do layer-wise quantization".format( - len([node.name for node in split_nodes]) + 1 - ) + f"Will split model into {len([node.name for node in split_nodes]) + 1} parts to do layer-wise quantization" ) logger.debug( - "Will split model with these nodes for layer-wise quantization: {}".format([node.name for node in split_nodes]) + f"Will split model with these nodes for layer-wise quantization: {[node.name for node in split_nodes]}" ) split_idx = 1 @@ -106,7 +103,7 @@ def layer_wise_quant( current_data_reader = lwq_data_reader.pop(0) # if no remaining split nodes, it means this is the last split, and the two split models will be saved. - save_both_split_models = True if len(split_nodes) == 0 else False + save_both_split_models = len(split_nodes) == 0 # split model with given split node split_model_part_1, split_model_part_2 = split_model.split_model_with_node( @@ -116,7 +113,7 @@ def layer_wise_quant( # append split_model_part_2 to do next split model_to_split.append(split_model_part_2) - logger.info("Quantize split model {}".format(split_idx)) + logger.info(f"Quantize split model {split_idx}") if require_data_reader: # process data_reader for current split and next split current_data_reader = _filter_data_reader_for_current_split_model( @@ -133,7 +130,7 @@ def layer_wise_quant( weight_config=weight_config, data_reader=current_data_reader, return_modelproto=False, - **kwargs + **kwargs, ) else: # perform quantization @@ -144,12 +141,12 @@ def layer_wise_quant( # check split model is valid try: ort.InferenceSession(split_model_part_1_quantized.model.SerializeToString(), providers=providers) - except Exception as e: + except Exception: logger.error( - "Layer-wise quantized model {} can't be inferred correctly. " - "Please check the raise exception".format(split_idx) + f"Layer-wise quantized model {split_idx} can't be inferred correctly. 
" + "Please check the raise exception" ) - raise e + raise # merge split quantized model if quantized_model_merged is None: @@ -161,7 +158,7 @@ def layer_wise_quant( split_idx += 1 # if this is the last split, quantize the last split model if save_both_split_models: - logger.info("Quantize split model {}".format(split_idx)) + logger.info(f"Quantize split model {split_idx}") # quantize split model if require_data_reader: @@ -177,7 +174,7 @@ def layer_wise_quant( weight_config=weight_config, data_reader=current_data_reader, return_modelproto=False, - **kwargs + **kwargs, ) else: # perform quantization @@ -188,12 +185,12 @@ def layer_wise_quant( # check split model is valid try: ort.InferenceSession(split_model_part_2_quantized.model.SerializeToString(), providers=providers) - except Exception as e: + except Exception: logger.error( - "Layer-wise quantized model {} can't be inferred correctly. " - "Please check the raise exception".format(split_idx) + f"Layer-wise quantized model {split_idx} can't be inferred correctly. " + "Please check the raise exception" ) - raise e + raise # merge split quantized model if quantized_model_merged is None: @@ -235,7 +232,7 @@ def _filter_data_reader_for_current_split_model(model: onnx.ModelProto, data_rea data_reader.CalibrationDataReader: filtered data reader. """ filter_inputs = [] - input_names = [input.name for input in model.graph.input] + input_names = [input.name for input in model.graph.input] # noqa: A001 while True: inputs = data_reader.get_next() if not inputs: @@ -247,10 +244,10 @@ def _filter_data_reader_for_current_split_model(model: onnx.ModelProto, data_rea return DataReader(filter_inputs) -def _prepare_data_reader_for_next_split_model( +def _prepare_data_reader_for_next_split_model( # noqa: D417 model_path: str, data_reader: data_reader.CalibrationDataReader, - providers: List[str] = ["CPUExecutionProvider"], + providers: list[str] | None = None, ): """Prepare data reader for next split model. @@ -264,6 +261,8 @@ def _prepare_data_reader_for_next_split_model( Returns: data_reader.CalibrationDataReader: data reader for next split model. """ + if providers is None: + providers = ["CPUExecutionProvider"] data_reader = copy.deepcopy(data_reader) data_reader_for_next_split_model = [] @@ -274,6 +273,6 @@ def _prepare_data_reader_for_next_split_model( if not inputs: break out = session.run(None, inputs) - inputs.update({name: value for name, value in zip(output_names, out)}) + inputs.update(dict(zip(output_names, out))) data_reader_for_next_split_model.append(inputs) return DataReader(data_reader_for_next_split_model) diff --git a/onnx_neural_compressor/algorithms/smoother/calibrator.py b/onnx_neural_compressor/algorithms/smoother/calibrator.py index fe0a862cc..10691c2e2 100644 --- a/onnx_neural_compressor/algorithms/smoother/calibrator.py +++ b/onnx_neural_compressor/algorithms/smoother/calibrator.py @@ -12,30 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Calibration for smooth quant.""" +from __future__ import annotations import importlib.util import pathlib import sys import tempfile -from typing import List import numpy as np import onnx import onnxruntime -from onnx_neural_compressor import data_reader, logger, onnx_model, utility +from onnx_neural_compressor import data_reader, logger, onnx_model class Calibrator: """Dump information for smooth quant.""" - def __init__( + def __init__( # noqa: D417 self, model: onnx_model.ONNXModel, dataloader: data_reader.CalibrationDataReader, - iterations: List[int] = [], - providers: List[str] = ["CPUExecutionProvider"], - **kwargs, + iterations: list[int] | None = None, + providers: list[str] | None = None, + **kwargs, # noqa: ARG002 ): """Initialize a Calibrator to dump information. @@ -45,6 +45,10 @@ def __init__( iterations (List[int], optional): tensor of which iteration will be collected. Defaults to []. providers (List[str], optional): execution provider for onnxruntime. Defaults to ["CPUExecutionProvider"]. """ + if providers is None: + providers = ["CPUExecutionProvider"] + if iterations is None: + iterations = [] self.model_wrapper = model self.dataloader = dataloader self.augmented_model = None @@ -82,7 +86,7 @@ def _check_is_group_conv(self, node): return True return False - def _get_input_tensor_of_ops(self, op_types: List[str] = ["MatMul", "Gemm", "Conv", "FusedConv"]): + def _get_input_tensor_of_ops(self, op_types: list[str] | None = None): """Traverse the graph and get all the data tensors flowing into layers of {op_types}. Group conv is excluded. @@ -95,6 +99,8 @@ def _get_input_tensor_of_ops(self, op_types: List[str] = ["MatMul", "Gemm", "Con Returns: dict: A dict of dumped tensor to node info """ + if op_types is None: + op_types = ["MatMul", "Gemm", "Conv", "FusedConv"] tensors_to_node = {} initializers = {i.name: i for i in self.model_wrapper.initializer()} @@ -103,7 +109,7 @@ def _get_input_tensor_of_ops(self, op_types: List[str] = ["MatMul", "Gemm", "Con if node.op_type in ["Conv", "FusedConv"] and self._check_is_group_conv(node): continue # also need to check whether the layer has weight - if len(node.input) >= 2 and node.input[1] in initializers.keys(): + if len(node.input) >= 2 and node.input[1] in initializers: tensors_to_node.setdefault(node.input[0], []).append([node.name, node.input, node.output]) return tensors_to_node @@ -129,7 +135,7 @@ def _get_max_per_channel(self, datas, percentile): elif len(data.shape) == 2: permute_datas.append(np.abs(data)) else: - assert False, "not supported" + raise AssertionError("not supported") permute_datas = np.stack(permute_datas, axis=0) permute_datas = permute_datas.reshape(-1, permute_datas.shape[-1]) max_per_channels = np.percentile(permute_datas, percentile, axis=0) @@ -175,7 +181,7 @@ def get_intermediate_outputs(self): node = output_name_to_node[data_name] elif data_name in input_name_to_nodes: node = input_name_to_nodes[data_name][0] - assert node, "{} is neither an input nor an output of nodes in augmented model.".format(data_name) + assert node, f"{data_name} is neither an input nor an output of nodes in augmented model." 
name_to_node[data_name] = node.name def _collect_data(ort_inputs): diff --git a/onnx_neural_compressor/algorithms/smoother/core.py b/onnx_neural_compressor/algorithms/smoother/core.py index d21641482..cc9d0762a 100644 --- a/onnx_neural_compressor/algorithms/smoother/core.py +++ b/onnx_neural_compressor/algorithms/smoother/core.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Smoother for onnxrt.""" +from __future__ import annotations import copy import os @@ -24,9 +25,6 @@ from onnx_neural_compressor import data_reader, logger, onnx_model, utility from onnx_neural_compressor.algorithms.smoother import calibrator -from typing import List, Union # isort: skip - - _dtype_map = { np.dtype("float32"): 1, np.dtype("uint8"): 2, @@ -65,7 +63,9 @@ def _make_sub_graph(node, inits, input_data, output_data, opset, ir_version): opset (object): opset of the model ir_version (object): ir_version of the model """ - input = onnx.helper.make_tensor_value_info(node.input[0], _dtype_map[input_data.dtype], input_data.shape) + input = onnx.helper.make_tensor_value_info( # noqa: A001 + node.input[0], _dtype_map[input_data.dtype], input_data.shape + ) output = onnx.helper.make_tensor_value_info(node.output[0], _dtype_map[output_data.dtype], output_data.shape) graph = onnx.helper.make_graph([node], "sub_graph", [input], [output], inits) model = onnx.helper.make_model(graph, opset_imports=opset) @@ -73,7 +73,7 @@ def _make_sub_graph(node, inits, input_data, output_data, opset, ir_version): return model -def _quant_dequant_data(data, qType=3, scheme="sym"): +def _quant_dequant_data(data, qType=3, scheme="sym"): # noqa: N803 """Quantize and then dequantize data. Args: @@ -100,11 +100,13 @@ class Smoother: def __init__( self, - model: Union[onnx.ModelProto, onnx_model.ONNXModel, pathlib.Path, str], + model: onnx.ModelProto | onnx_model.ONNXModel | pathlib.Path | str, dataloader: data_reader.CalibrationDataReader, - providers: List[str] = ["CPUExecutionProvider"], + providers: list[str] | None = None, ): """Initialize the attributes of class.""" + if providers is None: + providers = ["CPUExecutionProvider"] self.model = ( model if isinstance(model, onnx_model.ONNXModel) else onnx_model.ONNXModel(model, load_external_data=True) ) @@ -125,17 +127,17 @@ def __init__( self.tensors_to_node = None self._build_absorb_function() - def transform( + def transform( # noqa: D417 self, - alpha: Union[float, str] = 0.5, + alpha: float | str = 0.5, folding: bool = True, percentile: float = 99.999, - op_types: List[str] = ["Gemm", "Conv", "MatMul", "FusedConv"], + op_types: list[str] | None = None, scales_per_op: bool = True, calib_iter: int = 100, - auto_alpha_args: dict = {"alpha_min": 0.3, "alpha_max": 0.7, "alpha_step": 0.05, "attn_method": "min"}, - *args, - **kwargs + auto_alpha_args: dict | None = None, + *args, # noqa: ARG002 + **kwargs, # noqa: ARG002 ): """The main entry of smooth quant. 
@@ -159,6 +161,10 @@ def transform( onnx.ModelProto: A FP32 model with the same architecture as the orig model but with different weight which will be benefit to quantization """ + if auto_alpha_args is None: + auto_alpha_args = {"alpha_min": 0.3, "alpha_max": 0.7, "alpha_step": 0.05, "attn_method": "min"} + if op_types is None: + op_types = ["Gemm", "Conv", "MatMul", "FusedConv"] self.scales_per_op = scales_per_op self.clean() if isinstance(alpha, float) and (alpha < 0 or alpha > 1): @@ -203,7 +209,7 @@ def _dump_op_info(self, percentile, op_types, iterations): sq_calibrator = calibrator.Calibrator( self.model, self.dataloader, - iterations=list(range(0, iterations)), + iterations=list(range(iterations)), backend=self.providers, ) @@ -226,7 +232,7 @@ def recover(self): key = node_info[0] if self.scales_per_op else tensor_name if key not in self.tensor_scales_info: continue - input = node_info[1][1] + input = node_info[1][1] # noqa: A001 weight = onnx.numpy_helper.to_array( self.model.get_initializer(input), base_dir=os.path.dirname(self.model.model_path) if self.model.model_path is not None else "", @@ -277,7 +283,7 @@ def norm(node, scale): # pragma: no cover return True def mul(node, scale): # pragma: no cover - if all([self.model.get_initializer(inp) is None for inp in node.input]): + if all(self.model.get_initializer(inp) is None for inp in node.input): return False for inp in node.input: if self.model.get_initializer(inp) is not None: @@ -431,7 +437,7 @@ def _auto_tune_alpha( alpha_min: float = 0.3, alpha_max: float = 0.7, alpha_step: float = 0.05, - attn_method: str = "min", + attn_method: str = "min", # noqa: ARG002 ): """Perform alpha-tuning to obtain layer-wise optimal alpha values and adjust parameters accordingly. @@ -492,7 +498,7 @@ def _auto_tune_alpha( os.remove(os.path.join(os.path.dirname(self.model.model_path), "weights.pb")) return optimal_alphas - def _get_smooth_scales(self, alpha, target_list=[]): + def _get_smooth_scales(self, alpha, target_list=None): """Get the smooth scales for. The ops with the same input will share one mul layer. 
@@ -505,6 +511,8 @@ def _get_smooth_scales(self, alpha, target_list=[]): Returns: the smooth scales for weights, currently one input tensor only have one scale """ + if target_list is None: + target_list = [] logger.info("Start smooth scales collection.") scales = {} for tensor, nodes in self.tensors_to_node.items(): @@ -571,7 +579,7 @@ def _insert_smooth_mul_op(self, scales): Args: scales (dict): The smooth scales """ - for key in scales.keys(): + for key in scales: input_name = key if not self.scales_per_op else self.model.get_node(key).input[0] weight_name = ( self.tensors_to_node[key][0][1][1] if not self.scales_per_op else self.model.get_node(key).input[1] ) @@ -584,8 +592,7 @@ def _insert_smooth_mul_op(self, scales): elif len(self.shape_info[weight_name]) == 4: scale_factor = np.reshape(scale_factor, (1, -1, 1, 1)) else: - assert False, "not support" - name = key + "_" + "smooth_scale" + raise AssertionError("not supported") scale_tensor = onnx.helper.make_tensor( name=key + "_" + "smooth_scale", data_type=onnx.onnx_pb.TensorProto.FLOAT, @@ -623,7 +630,7 @@ def _adjust_weights(self, scales): key = node_info[0] if self.scales_per_op else tensor_name if key not in scales: continue - input = node_info[1][1] + input = node_info[1][1] # noqa: A001 node = self.model.get_node_by_weight(input) weight = onnx.numpy_helper.to_array( self.model.get_initializer(input), @@ -648,7 +655,7 @@ def _adjust_weights(self, scales): scale = np.reshape(scales[key], (1, -1, 1, 1)) new_weight = weight * scale else: - assert False, "not support" + raise AssertionError("not supported") self.tensor_scales_info[key] = 1.0 / scale new_tensor = onnx.numpy_helper.from_array(new_weight, input) diff --git a/onnx_neural_compressor/algorithms/weight_only/awq.py b/onnx_neural_compressor/algorithms/weight_only/awq.py index 30d9e8442..ace4fe83d 100644 --- a/onnx_neural_compressor/algorithms/weight_only/awq.py +++ b/onnx_neural_compressor/algorithms/weight_only/awq.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
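Several signatures in this patch, including awq_quantize below, swap mutable default arguments for None sentinels resolved inside the function. A minimal sketch of why, using hypothetical helper names:

def bad(cfg={}):  # one dict created at definition time, shared by every call
    cfg.setdefault("hits", 0)
    cfg["hits"] += 1
    return cfg["hits"]


def good(cfg=None):  # None sentinel, fresh dict per call (the pattern used below)
    if cfg is None:
        cfg = {}
    cfg.setdefault("hits", 0)
    cfg["hits"] += 1
    return cfg["hits"]


assert bad() == 1 and bad() == 2    # state leaks across calls
assert good() == 1 and good() == 1  # calls stay independent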
+from __future__ import annotations import copy import os @@ -28,8 +29,6 @@ from onnx_neural_compressor.algorithms.weight_only import rtn from onnx_neural_compressor.algorithms.weight_only import utility as woq_utility -from typing import List, Union # isort: skip - def _get_weight_scale(weight, group_size): """Get the scale of weight.""" @@ -41,7 +40,6 @@ def _get_weight_scale(weight, group_size): def _apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits, group_size, scheme): """Apply scale for salient weight.""" - best_scales = {} new_init_tensors = [] new_added_mul_nodes = [] replace_input = [] @@ -49,7 +47,7 @@ def _apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits, base_dir = os.path.dirname(model.model_path) if model.model_path is not None else "" for parent, nodes in absorb_pairs.items(): - if any([node.input[0] not in output_dicts for node in nodes]): + if any(node.input[0] not in output_dicts for node in nodes): logger.warning( "Miss input tensors of nodes {} during AWQ, skip it!".format( ", ".join([node.name for node in nodes if node.input[0] not in output_dicts]) @@ -72,12 +70,11 @@ def _apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits, # search scale best_error = float("inf") - best_ratio = -1 best_scale = None n_grid = 20 for ratio in range(n_grid): - ratio = ratio * 1 / n_grid + ratio = ratio * 1 / n_grid # noqa: PLW2901 loss = 0 for node in nodes: if weight_config.get((node.name, node.op_type), {}) == "fp32": @@ -119,7 +116,6 @@ def _apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits, is_best = loss < best_error if is_best: best_error = loss - best_ratio = ratio best_scale = scales for node in nodes: @@ -147,7 +143,7 @@ def _apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits, if init_share_num == 1: model.remove_initializer(weight_tensor) - parent = model.get_node(parent) + parent = model.get_node(parent) # noqa: PLW2901 if parent.name in updated_nodes: continue @@ -164,7 +160,7 @@ def _apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits, elif ( parent.op_type in ["SimplifiedLayerNormalization", "MatMul", "Gemm", "Mul"] - and not all([model.get_initializer(inp) is None for inp in parent.input]) + and not all(model.get_initializer(inp) is None for inp in parent.input) and len(model.input_name_to_nodes()[nodes[0].input[0]]) == len(nodes) ): # pragma: no cover for inp in parent.input: @@ -204,7 +200,7 @@ def _apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits, ) new_added_mul_nodes.append(mul_node) for node in nodes: - replace_input.append([node, node.input[0], mul_node.output[0]]) + replace_input.append([node, node.input[0], mul_node.output[0]]) # noqa: PERF401 updated_nodes.append(parent.name) output_dicts[mul_node.output[0]] = output_dicts[mul_node.input[0]] / np.reshape(best_scale, (1, -1)) @@ -220,8 +216,8 @@ def _apply_awq_clip(model, weight_config, absorb_pairs, output_dicts, num_bits, """Apply clip for weight by checking mse.""" base_dir = os.path.dirname(model.model_path) if model.model_path is not None else "" ratios = {} - for parent, nodes in absorb_pairs.items(): - if any([node.input[0] not in output_dicts for node in nodes]): + for nodes in absorb_pairs.values(): + if any(node.input[0] not in output_dicts for node in nodes): logger.warning( "Miss input tensors of nodes {} during AWQ, skip it!".format( ", ".join([node.name for node in nodes if node.input[0] not in output_dicts]) @@ -278,16 +274,16 @@ 
def _apply_awq_clip(model, weight_config, absorb_pairs, output_dicts, num_bits, def awq_quantize( - model: Union[onnx.ModelProto, onnx_model.ONNXModel, pathlib.Path, str], + model: onnx.ModelProto | onnx_model.ONNXModel | pathlib.Path | str, data_reader: data_reader.CalibrationDataReader, - weight_config: dict = {}, + weight_config: dict | None = None, num_bits: int = 4, group_size: int = 32, scheme: str = "asym", enable_auto_scale: bool = True, enable_mse_search: bool = True, accuracy_level: int = 0, - providers: List[str] = ["CPUExecutionProvider"], + providers: list[str] | None = None, ) -> onnx.ModelProto: """Quant the model with Activation-aware Weight quantization(AWQ) method. @@ -321,6 +317,10 @@ def awq_quantize( Returns: onnx.ModelProto: quantized onnx model. """ + if providers is None: + providers = ["CPUExecutionProvider"] + if weight_config is None: + weight_config = {} if not isinstance(model, onnx_model.ONNXModel): model = onnx_model.ONNXModel(model) output_dicts = {} @@ -343,7 +343,7 @@ def awq_quantize( and model.get_initializer(node.input[1]) is not None and weight_config.get((node.name, node.op_type), {}).get("weight_dtype", "fp32") != "fp32" ): - output_names.append(node.input[0]) + output_names.append(node.input[0]) # noqa: PERF401 output_names = list(set(output_names)) model.add_tensors_to_outputs(output_names) if model.is_large_model: @@ -415,7 +415,7 @@ def awq_quantize( def apply_awq_on_model( - model: Union[onnx.ModelProto, onnx_model.ONNXModel, pathlib.Path, str], + model: onnx.ModelProto | onnx_model.ONNXModel | pathlib.Path | str, quant_config: dict, calibration_data_reader: data_reader.CalibrationDataReader, ) -> onnx.ModelProto: diff --git a/onnx_neural_compressor/algorithms/weight_only/gptq.py b/onnx_neural_compressor/algorithms/weight_only/gptq.py index 5016a2780..221e06263 100644 --- a/onnx_neural_compressor/algorithms/weight_only/gptq.py +++ b/onnx_neural_compressor/algorithms/weight_only/gptq.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
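The gptq_quantize loop below accumulates the Hessian H = (2/N) * X^T X in a streaming way: each batch first rescales the running sum by nsamples / (nsamples + batch_size), then adds sqrt(2/n)-scaled products. A small self-check with made-up data, mirroring that update:

import numpy as np

rng = np.random.default_rng(0)
batches = [rng.standard_normal((4, 8)), rng.standard_normal((6, 8))]  # [samples, features]

H, nsamples = np.zeros((8, 8)), 0
for x in batches:
    tmp = x.shape[0]
    H = H * (nsamples / (nsamples + tmp))  # down-weight what was accumulated so far
    nsamples += tmp
    x = np.sqrt(2 / nsamples) * x
    H = H + x.T @ x                        # add this batch's contribution

X = np.concatenate(batches)
assert np.allclose(H, 2 / len(X) * X.T @ X)  # streaming update equals the batch formula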
+from __future__ import annotations import copy import os @@ -28,12 +29,10 @@ from onnx_neural_compressor.algorithms.layer_wise import core from onnx_neural_compressor.algorithms.weight_only import utility as woq_utility -from typing import List, Union # isort: skip - def _gptq( - W: np.array, - H: np.array, + W: np.array, # noqa: N803 + H: np.array, # noqa: N803 num_bits: int = 4, group_size: int = 32, scheme: str = "asym", @@ -60,7 +59,6 @@ def _gptq( Returns: Q: fake quantized weight """ - Qs = [] maxq = 2**num_bits - 1 grid = 100 maxshrink = 0.8 @@ -114,8 +112,6 @@ def find_params(weight): zero = np.reshape(zero, shape) return scale, zero - scales = [] - zps = [] shape = W.shape scale, zp = find_params(W) dead = np.diag(H) == 0 @@ -125,24 +121,24 @@ def find_params(weight): # rearrange considering the diag's value if actorder: perm = np.argsort(np.diag(H))[::-1] - W = W[perm, :] - H = H[perm, :][:, perm] - Losses = np.zeros_like(W) - Q = np.zeros_like(W) + W = W[perm, :] # noqa: N806 + H = H[perm, :][:, perm] # noqa: N806 + Losses = np.zeros_like(W) # noqa: N806 + Q = np.zeros_like(W) # noqa: N806 damp = percdamp * np.mean(np.diag(H)) diag = np.arange(shape[0]) H[diag, diag] += damp # add a average value of - H = np.linalg.cholesky(np.linalg.inv(H)).T - Hinv = H + H = np.linalg.cholesky(np.linalg.inv(H)).T # noqa: N806 + Hinv = H # noqa: N806 for i1 in range(0, shape[0], blocksize): i2 = min(i1 + blocksize, shape[0]) count = i2 - i1 - W1 = copy.deepcopy(W[i1:i2, :]) - Q1 = np.zeros_like(W1) - Err1 = np.zeros_like(W1) - Losses1 = np.zeros_like(W1) - Hinv1 = Hinv[i1:i2, i1:i2] + W1 = copy.deepcopy(W[i1:i2, :]) # noqa: N806 + Q1 = np.zeros_like(W1) # noqa: N806 + Err1 = np.zeros_like(W1) # noqa: N806 + Losses1 = np.zeros_like(W1) # noqa: N806 + Hinv1 = Hinv[i1:i2, i1:i2] # noqa: N806 for i in range(count): # within a block, channel wise w = W1[i, :] @@ -167,17 +163,17 @@ def find_params(weight): if actorder: invperm = np.argsort(perm) - Q = Q[invperm, :] + Q = Q[invperm, :] # noqa: N806 - Q = np.reshape(Q, W.shape) + Q = np.reshape(Q, W.shape) # noqa: N806 del W return Q def gptq_quantize( - model: Union[onnx.ModelProto, onnx_model.ONNXModel, pathlib.Path, str], + model: onnx.ModelProto | onnx_model.ONNXModel | pathlib.Path | str, data_reader: data_reader.CalibrationDataReader, - weight_config: dict = {}, + weight_config: dict | None = None, num_bits: int = 4, group_size: int = 32, scheme: str = "asym", @@ -187,7 +183,7 @@ def gptq_quantize( mse: bool = False, perchannel: bool = True, accuracy_level: int = 0, - providers: List[str] = ["CPUExecutionProvider"], + providers: list[str] | None = None, return_modelproto: bool = True, ): """Quant the model with GPTQ method. 
@@ -226,6 +222,10 @@ def gptq_quantize( Returns: onnx.ModelProto: quantized onnx model """ + if providers is None: + providers = ["CPUExecutionProvider"] + if weight_config is None: + weight_config = {} if not isinstance(model, onnx_model.ONNXModel): model = onnx_model.ONNXModel(model) base_dir = os.path.dirname(model.model_path) if model.model_path is not None else "" @@ -244,7 +244,7 @@ def gptq_quantize( and model.get_initializer(node.input[1]) is not None and weight_config.get((node.name, node.op_type), {}).get("weight_dtype", "fp32") != "fp32" ): - output_names.append(node.input[0]) + output_names.append(node.input[0]) # noqa: PERF401 output_names = list(set(output_names)) model.add_tensors_to_outputs(output_names) if model.is_large_model: @@ -288,21 +288,21 @@ def gptq_quantize( if len(weights) == 0: continue - Hs = [np.zeros((i.shape[0], i.shape[0])) for i in weights] + Hs = [np.zeros((i.shape[0], i.shape[0])) for i in weights] # noqa: N806 nsamples = 0 for data in inputs: inp = session.run([input_name], data)[0] tmp = inp.shape[0] inp = np.reshape(inp, (-1, inp.shape[-1])) - Hs = [i * (nsamples / (nsamples + tmp)) for i in Hs] + Hs = [i * (nsamples / (nsamples + tmp)) for i in Hs] # noqa: N806 nsamples += tmp inp = np.sqrt(2 / nsamples) * inp - Hs = [i + np.matmul(inp.T, inp) for i in Hs] + Hs = [i + np.matmul(inp.T, inp) for i in Hs] # noqa: N806 for ( node, weight, - H, + H, # noqa: N806 ) in zip(node_list, weights, Hs): if (node.name, node.op_type) in weight_config: num_bits = weight_config[(node.name, node.op_type)].get("weight_bits", 4) @@ -328,13 +328,13 @@ def gptq_quantize( weight_tensor = model.get_initializer(node.input[1]) init_share_num = model.get_initializer_share_num(node.input[1]) - satisfy_MatMulNBits_condition = Version(ort.__version__) > constants.ONNXRT1161_VERSION and num_bits == 4 - satisfy_MatMulFpQ4_condition = ( + satisfy_matmul_nbits_condition = Version(ort.__version__) > constants.ONNXRT1161_VERSION and num_bits == 4 + satisfy_matmul_fpq4_condition = ( Version(ort.__version__) >= constants.ONNXRT116_VERSION and num_bits == 4 and group_size == 32 ) - if ("CUDAExecutionProvider" in providers and satisfy_MatMulNBits_condition) or ( + if ("CUDAExecutionProvider" in providers and satisfy_matmul_nbits_condition) or ( "CUDAExecutionProvider" not in providers - and (satisfy_MatMulFpQ4_condition or satisfy_MatMulNBits_condition) + and (satisfy_matmul_fpq4_condition or satisfy_matmul_nbits_condition) ): # pragma: no cover # MatMulFpQ4 support 4 bits and 32 group_size with ort 1.16.0 and 1.16.1 versions, supported by CPU EP # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1, supported by CPU EP AND CUDA EP @@ -360,7 +360,7 @@ def gptq_quantize( model.add_node(q_matmul_node) else: q_weight_tensor = onnx.helper.make_tensor( - name=node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size)), + name=node.input[1] + f"_Q{num_bits!s}G{group_size!s}", data_type=utility.dtype_mapping[str(dtype)], dims=q_weight.shape, vals=q_weight.astype(dtype).tobytes(), @@ -388,7 +388,7 @@ def gptq_quantize( def apply_gptq_on_model( - model: Union[onnx.ModelProto, onnx_model.ONNXModel, pathlib.Path, str], + model: onnx.ModelProto | onnx_model.ONNXModel | pathlib.Path | str, quant_config: dict, calibration_data_reader: data_reader.CalibrationDataReader, ) -> onnx.ModelProto: @@ -419,7 +419,7 @@ def apply_gptq_on_model( quant_func=gptq_quantize, weight_config=quant_config, data_reader=calibration_data_reader, - **quant_kwargs + **quant_kwargs, ) else: quantized_model = 
gptq_quantize( diff --git a/onnx_neural_compressor/algorithms/weight_only/rtn.py b/onnx_neural_compressor/algorithms/weight_only/rtn.py index 619c055e1..c4cadc456 100644 --- a/onnx_neural_compressor/algorithms/weight_only/rtn.py +++ b/onnx_neural_compressor/algorithms/weight_only/rtn.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- # # Copyright (c) 2023 MIT HAN Lab # This source code is licensed under the MIT license @@ -17,6 +16,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import os import pathlib @@ -30,18 +30,16 @@ from onnx_neural_compressor.algorithms.layer_wise import core from onnx_neural_compressor.algorithms.weight_only import utility as woq_utility -from typing import List, Union # isort: skip - def rtn_quantize( - model: Union[onnx.ModelProto, onnx_model.ONNXModel, pathlib.Path, str], - weight_config: dict = {}, + model: onnx.ModelProto | onnx_model.ONNXModel | pathlib.Path | str, + weight_config: dict | None = None, num_bits: int = 4, group_size: int = 32, scheme: str = "asym", - ratios: dict = {}, + ratios: dict | None = None, accuracy_level: int = 0, - providers: List[str] = ["CPUExecutionProvider"], + providers: list[str] | None = None, return_modelproto: bool = True, ): """Quantize the model with round to nearst method. @@ -74,6 +72,12 @@ def rtn_quantize( Returns: onnx.ModelProto: quantized onnx model. """ + if providers is None: + providers = ["CPUExecutionProvider"] + if ratios is None: + ratios = {} + if weight_config is None: + weight_config = {} if not isinstance(model, onnx_model.ONNXModel): model = onnx_model.ONNXModel(model) base_dir = os.path.dirname(model.model_path) if model.model_path is not None else "" @@ -114,10 +118,10 @@ def rtn_quantize( weight = woq_utility.pad_tensor(weight, group_size, k_blocks) - satisfy_MatMulNBits_condition = ( + satisfy_MatMulNBits_condition = ( # noqa: N806 version.Version(ort.__version__) > constants.ONNXRT1161_VERSION and num_bits == 4 ) - satisfy_MatMulFpQ4_condition = ( + satisfy_MatMulFpQ4_condition = ( # noqa: N806 version.Version(ort.__version__) >= constants.ONNXRT116_VERSION and num_bits == 4 and group_size == 32 ) if ("CUDAExecutionProvider" in providers and satisfy_MatMulNBits_condition) or ( @@ -152,7 +156,7 @@ def rtn_quantize( q_weight = np.transpose(q_weight) q_weight = q_weight[: org_w_shape[0], :].astype(dtype) q_weight_tensor = onnx.helper.make_tensor( - name=node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size)), + name=node.input[1] + f"_Q{num_bits!s}G{group_size!s}", data_type=utility.dtype_mapping[str(dtype)], dims=weight.shape, vals=q_weight.tobytes(), @@ -178,7 +182,7 @@ def rtn_quantize( def apply_rtn_on_model( - model: Union[onnx.ModelProto, onnx_model.ONNXModel, pathlib.Path, str], quant_config: dict + model: onnx.ModelProto | onnx_model.ONNXModel | pathlib.Path | str, quant_config: dict ) -> onnx.ModelProto: """Apply RTN on onnx model. 
diff --git a/onnx_neural_compressor/algorithms/weight_only/utility.py b/onnx_neural_compressor/algorithms/weight_only/utility.py index ddb5f990d..04cfe0f97 100644 --- a/onnx_neural_compressor/algorithms/weight_only/utility.py +++ b/onnx_neural_compressor/algorithms/weight_only/utility.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- # # Copyright (c) 2023 MIT HAN Lab # This source code is licensed under the MIT license @@ -82,7 +81,7 @@ def make_matmul_weight_only_node( """ blob_size = _get_blob_size(group_size, zero_point is not None) packed = np.zeros((q_weight.shape[0], blob_size), dtype="uint8") - q_weight_name = node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size)) + q_weight_name = node.input[1] + f"_Q{num_bits!s}G{group_size!s}" input_names = [node.input[0], q_weight_name] new_inits = [] kwargs = {} @@ -187,7 +186,7 @@ def make_matmul_weight_only_node( return matmul_weight_only_node, new_inits -def prepare_inputs(model, data_reader, providers): +def prepare_inputs(model, data_reader, providers): # noqa: ARG001 """Prepare inputs for weight only quantization. Args: @@ -199,7 +198,7 @@ def prepare_inputs(model, data_reader, providers): inputs: prepared inputs. so: session options """ so = ort.SessionOptions() if sys.version_info < (3, 11) and util.find_spec("onnxruntime_extensions"): # pragma: no cover so.register_custom_ops_library(onnxruntime_extensions.get_library_path()) diff --git a/onnx_neural_compressor/config.py b/onnx_neural_compressor/config.py index b6fad923a..b9033101a 100644 --- a/onnx_neural_compressor/config.py +++ b/onnx_neural_compressor/config.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- # # Copyright (c) 2023 Intel Corporation # @@ -24,6 +23,8 @@ import pathlib import re from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import Any, Callable, List, NamedTuple, _GenericAlias import numpy as np import onnx @@ -31,10 +32,7 @@ from onnxruntime import quantization from typing_extensions import Self -from onnx_neural_compressor import constants, data_reader, logger, utility - -from collections import OrderedDict # isort: skip -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type, Union, _GenericAlias # isort: skip +from onnx_neural_compressor import constants, data_reader, logger class ParamLevel(enum.Enum): @@ -105,28 +103,28 @@ def is_tunable(self, value: Any) -> bool: # Use `Pydantic` to validate the input_args. # TODO: refine the implementation in further. assert isinstance(self.tunable_type, _GenericAlias), f"Expected a type hint, got {self.tunable_type} instead." - DynamicInputArgsModel = TuningParam.create_input_args_model(self.tunable_type) + DynamicInputArgsModel = TuningParam.create_input_args_model(self.tunable_type) # noqa: N806 try: - new_args = DynamicInputArgsModel(input_args=value) - return True - except Exception as e: + DynamicInputArgsModel(input_args=value) + return True # noqa: TRY300 + except Exception as e: # noqa: BLE001 logger.debug(f"Failed to validate the input_args: {e}") return False # Config registry to store all registered configs.
-class ConfigRegistry(object): - registered_configs = {} +class ConfigRegistry: + registered_configs = {} # noqa: RUF012 _config_registry = None def __new__(cls) -> Self: if cls._config_registry is None: - cls._config_registry = super(ConfigRegistry, cls).__new__(cls) + cls._config_registry = super().__new__(cls) return cls._config_registry @classmethod - def register_config_impl(cls, algo_name: str, priority: Union[float, int] = 0): + def register_config_impl(cls, algo_name: str, priority: float | int = 0): """Register config decorator. The register the configuration classes for different algorithms. @@ -149,17 +147,17 @@ def decorator(config_cls): return decorator @classmethod - def get_all_configs(cls) -> Dict[str, Dict[str, Dict[str, object]]]: + def get_all_configs(cls) -> dict[str, dict[str, dict[str, object]]]: """Get all registered configurations.""" return cls.registered_configs @classmethod - def get_sorted_configs(cls) -> Dict[str, OrderedDict[str, Dict[str, object]]]: + def get_sorted_configs(cls) -> dict[str, OrderedDict[str, dict[str, object]]]: """Get registered configurations sorted by priority.""" return OrderedDict(sorted(cls.registered_configs.items(), key=lambda x: x[1]["priority"], reverse=True)) @classmethod - def get_cls_configs(cls) -> Dict[str, Dict[str, object]]: + def get_cls_configs(cls) -> dict[str, dict[str, object]]: """Get registered configurations without priority.""" cls_configs = {} for algo_name, config_data in cls.registered_configs.items(): @@ -167,17 +165,17 @@ def get_cls_configs(cls) -> Dict[str, Dict[str, object]]: return cls_configs @classmethod - def get_all_config_cls(cls) -> List[Type[BaseConfig]]: + def get_all_config_cls(cls) -> list[type[BaseConfig]]: configs_cls = [] - for algo_name, config_pairs in cls.registered_configs.items(): - configs_cls.append(config_pairs["cls"]) + for config_pairs in cls.registered_configs.values(): + configs_cls.append(config_pairs["cls"]) # noqa: PERF401 return configs_cls config_registry = ConfigRegistry() -def register_config(algo_name: str, priority: Union[float, int] = 0): +def register_config(algo_name: str, priority: float | int = 0): """Register config decorator. The register the configuration classes for different algorithms. @@ -192,7 +190,6 @@ class ExampleAlgorithmConfig: priority: the priority of the configuration. A larger number indicates a higher priority, which will be tried first at the auto-tune stage. Defaults to 0. """ - return config_registry.register_config_impl(algo_name=algo_name, priority=priority) @@ -200,17 +197,17 @@ class BaseConfig(ABC): """The base config for all algorithm configs.""" name = constants.BASE_CONFIG - params_list: List[Union[str, TuningParam]] = [] + params_list: list[str | TuningParam] = [] # noqa: RUF012 def __init__( self, - white_list: Optional[Union[Union[str, Callable], List[Union[str, Callable]]]] = constants.DEFAULT_WHITE_LIST, + white_list: str | Callable | list[str | Callable] | None = constants.DEFAULT_WHITE_LIST, ) -> None: - self._global_config: Optional[BaseConfig] = None + self._global_config: BaseConfig | None = None # For PyTorch, operator_type is the collective name for module type and functional operation type, # for example, `torch.nn.Linear`, and `torch.nn.functional.linear`. 
# local config is the collections of operator_type configs and operator configs - self._local_config: Dict[str, Optional[BaseConfig]] = {} + self._local_config: dict[str, BaseConfig | None] = {} self._white_list = white_list def _post_init(self): @@ -235,7 +232,7 @@ def white_list(self): return self._white_list @white_list.setter - def white_list(self, op_name_or_type_list: Optional[List[Union[str, Callable]]]): + def white_list(self, op_name_or_type_list: list[str | Callable] | None): self._white_list = op_name_or_type_list @property @@ -274,7 +271,7 @@ def to_dict(self): return result def get_params_dict(self): - result = dict() + result = {} for param, value in self.__dict__.items(): if param not in ["_global_config", "_local_config", "_white_list"]: result[param] = value @@ -302,13 +299,13 @@ def from_dict(cls, config_dict): return config @classmethod - def to_diff_dict(cls, instance) -> Dict[str, Any]: + def to_diff_dict(cls, instance) -> dict[str, Any]: # noqa: ARG003 # TODO (Yi) to implement it return {} @classmethod def from_json_file(cls, filename): - with open(filename, "r", encoding="utf-8") as file: + with open(filename, encoding="utf-8") as file: config_dict = json.load(file) return cls.from_dict(**config_dict) @@ -318,7 +315,7 @@ def to_json_file(self, filename): json.dump(config_dict, file, indent=4) logger.info("Dump the config into %s.", filename) - def to_json_string(self, use_diff: bool = False) -> Union[str, Dict]: + def to_json_string(self, use_diff: bool = False) -> str | dict: """Serializes this instance to a JSON string. Args: @@ -335,7 +332,7 @@ def to_json_string(self, use_diff: bool = False) -> Union[str, Dict]: config_dict = self.to_dict() try: return json.dumps(config_dict, indent=2) + "\n" - except Exception as e: + except Exception as e: # noqa: BLE001 logger.error("Failed to serialize the config to JSON string: %s", e) return config_dict @@ -348,8 +345,8 @@ def register_supported_configs(cls): """Add all supported configs.""" raise NotImplementedError - @classmethod - def validate(self, user_config: BaseConfig): + @classmethod # noqa: B027 + def validate(cls, user_config: BaseConfig): # TODO validate the user config pass @@ -370,7 +367,7 @@ def get_the_default_value_of_param(config: BaseConfig, param: str) -> Any: parameters = signature.parameters return parameters.get(param).default - def expand(self) -> List[BaseConfig]: + def expand(self) -> list[BaseConfig]: """Expand the config. case 1 @@ -404,7 +401,7 @@ def expand(self) -> List[BaseConfig]: } -> ? 
""" - config_list: List[BaseConfig] = [] + config_list: list[BaseConfig] = [] params_list = self.params_list config = self tuning_param_list = [] @@ -420,7 +417,7 @@ def expand(self) -> List[BaseConfig]: elif isinstance(param, TuningParam): tuning_param = param else: - raise ValueError(f"Unsupported param type: {param}") + raise ValueError(f"Unsupported param type: {param}") # noqa: TRY004 # Assign the options to the `tuning.TuningParam` instance param_val = getattr(config, tuning_param.name) if param_val is not None: @@ -445,8 +442,8 @@ def expand(self) -> List[BaseConfig]: return config_list def _get_op_name_op_type_config(self): - op_type_config_dict = dict() - op_name_config_dict = dict() + op_type_config_dict = {} + op_name_config_dict = {} for name, config in self.local_config.items(): if self._is_op_type(name): op_type_config_dict[name] = config @@ -455,8 +452,8 @@ def _get_op_name_op_type_config(self): return op_type_config_dict, op_name_config_dict def to_config_mapping( - self, config_list: Optional[List[BaseConfig]] = None, model_info: List[Tuple[str, str]] = None - ) -> OrderedDict[Tuple[str, str], OrderedDict[str, BaseConfig]]: + self, config_list: list[BaseConfig] | None = None, model_info: list[tuple[str, str]] | None = None + ) -> OrderedDict[tuple[str, str], OrderedDict[str, BaseConfig]]: config_mapping = OrderedDict() if config_list is None: config_list = [self] @@ -489,7 +486,7 @@ def get_config_set_for_tuning(cls): class ComposableConfig(BaseConfig): name = constants.COMPOSABLE_CONFIG - def __init__(self, configs: List[BaseConfig]) -> None: + def __init__(self, configs: list[BaseConfig]) -> None: self.config_list = configs def __add__(self, other: BaseConfig) -> BaseConfig: @@ -506,7 +503,7 @@ def to_dict(self): return result @classmethod - def from_dict(cls, config_dict: OrderedDict[str, Dict], config_registry: Dict[str, BaseConfig]): + def from_dict(cls, config_dict: OrderedDict[str, dict], config_registry: dict[str, BaseConfig]): assert len(config_dict) >= 1, "The config dict must include at least one configuration." 
num_configs = len(config_dict) name, value = next(iter(config_dict.items())) @@ -516,14 +513,14 @@ def from_dict(cls, config_dict: OrderedDict[str, Dict], config_registry: Dict[st config += config_registry[name].from_dict(value) return config - def to_json_string(self, use_diff: bool = False) -> str: + def to_json_string(self, use_diff: bool = False) -> str: # noqa: ARG002 return json.dumps(self.to_dict(), indent=2) + "\n" def __repr__(self) -> str: return f"{self.__class__.__name__} {self.to_json_string()}" def to_config_mapping( - self, config_list: List[BaseConfig] = None, model_info: Dict[str, Any] = None + self, config_list: list[BaseConfig] | None = None, model_info: dict[str, Any] | None = None # noqa: ARG002 ) -> OrderedDict[str, BaseConfig]: config_mapping = OrderedDict() for config in self.config_list: @@ -548,31 +545,31 @@ def get_config_set_for_tuning(cls) -> None: return None def get_model_info(self, model, *args, **kwargs): - model_info_dict = dict() + model_info_dict = {} for config in self.config_list: model_info_dict.update({config.name: config.get_model_info(model, *args, **kwargs)}) return model_info_dict -def get_all_config_set_from_config_registry() -> List[BaseConfig]: - all_registered_config_cls: List[Type[BaseConfig]] = config_registry.get_all_config_cls() +def get_all_config_set_from_config_registry() -> list[BaseConfig]: + all_registered_config_cls: list[type[BaseConfig]] = config_registry.get_all_config_cls() config_set = [] for config_cls in all_registered_config_cls: - config_set.append(config_cls.get_config_set_for_tuning()) + config_set.append(config_cls.get_config_set_for_tuning()) # noqa: PERF401 return config_set def register_supported_configs(): """Register supported configs.""" - all_registered_config_cls: List[Type[BaseConfig]] = config_registry.get_all_config_cls() + all_registered_config_cls: list[type[BaseConfig]] = config_registry.get_all_config_cls() for config_cls in all_registered_config_cls: config_cls.register_supported_configs() class _OperatorConfig(NamedTuple): config: BaseConfig - operators: List[Union[str, Callable]] - valid_func_list: List[Callable] = [] + operators: list[str | Callable] + valid_func_list: list[Callable] = [] # noqa: RUF012 ######################## RNT Config ############################### @@ -582,8 +579,8 @@ class _OperatorConfig(NamedTuple): class RTNConfig(BaseConfig): """Config class for round-to-nearest weight-only quantization.""" - supported_configs: List[_OperatorConfig] = [] - params_list: List[Union[str, TuningParam]] = [ + supported_configs: list[_OperatorConfig] = [] # noqa: RUF012 + params_list: list[str | TuningParam] = [ # noqa: RUF012 "weight_dtype", "weight_bits", "weight_group_size", @@ -592,7 +589,7 @@ class RTNConfig(BaseConfig): "accuracy_level", "ratios", ] - model_params_list: List[str] = [ + model_params_list: list[str] = [ # noqa: RUF012 "providers", "layer_wise_quant", ] @@ -606,11 +603,11 @@ def __init__( weight_sym: bool = True, act_dtype: str = "fp32", accuracy_level: int = 0, - ratios: dict = {}, - providers: List[str] = ["CPUExecutionProvider"], + ratios: dict | None = None, + providers: list[str] | None = None, layer_wise_quant: bool = False, quant_last_matmul: bool = True, - white_list: List[Union[str, Callable]] = constants.DEFAULT_WHITE_LIST, + white_list: list[str | Callable] = constants.DEFAULT_WHITE_LIST, ): """Init RTN weight-only quantization config. @@ -633,6 +630,10 @@ def __init__( white_list (list, optional): op in white_list will be applied current config. 
Defaults to constants.DEFAULT_WHITE_LIST. """ + if providers is None: + providers = ["CPUExecutionProvider"] + if ratios is None: + ratios = {} super().__init__(white_list=white_list) self.weight_bits = weight_bits self.weight_dtype = weight_dtype @@ -647,7 +648,7 @@ def __init__( self._post_init() def get_model_params_dict(self): - result = dict() + result = {} for param in self.model_params_list: result[param] = getattr(self, param) return result @@ -666,7 +667,7 @@ def register_supported_configs(cls) -> None: supported_configs.append(_OperatorConfig(config=linear_rtn_config, operators=operators)) cls.supported_configs = supported_configs - def to_config_mapping(self, config_list: List[BaseConfig] = None, model_info: list = None): + def to_config_mapping(self, config_list: list[BaseConfig] | None = None, model_info: list | None = None): config_mapping = OrderedDict() if config_list is None: config_list = [self] @@ -693,7 +694,7 @@ def to_config_mapping(self, config_list: List[BaseConfig] = None, model_info: li return config_mapping @staticmethod - def get_model_info(model: Union[onnx.ModelProto, pathlib.Path, str]) -> list: + def get_model_info(model: onnx.ModelProto | pathlib.Path | str) -> list: if not isinstance(model, onnx.ModelProto): model = onnx.load(model, load_external_data=False) white_list = ["MatMul"] @@ -706,7 +707,7 @@ def get_model_info(model: Union[onnx.ModelProto, pathlib.Path, str]) -> list: return filter_result @classmethod - def get_config_set_for_tuning(cls) -> Union[None, "RTNConfig", List["RTNConfig"]]: # pragma: no cover + def get_config_set_for_tuning(cls) -> None | RTNConfig | list[RTNConfig]: # pragma: no cover return RTNConfig(weight_bits=[4, 8], weight_sym=[True, False]) @@ -726,8 +727,8 @@ def get_default_rtn_config() -> RTNConfig: class GPTQConfig(BaseConfig): """Config class for gptq weight-only quantization.""" - supported_configs: List[_OperatorConfig] = [] - params_list: List[Union[str, TuningParam]] = [ + supported_configs: list[_OperatorConfig] = [] # noqa: RUF012 + params_list: list[str | TuningParam] = [ # noqa: RUF012 "weight_dtype", "weight_bits", "weight_group_size", @@ -735,7 +736,7 @@ class GPTQConfig(BaseConfig): "act_dtype", "accuracy_level", ] - model_params_list: List[Union[str, TuningParam]] = [ + model_params_list: list[str | TuningParam] = [ # noqa: RUF012 "percdamp", "blocksize", "actorder", @@ -759,10 +760,10 @@ def __init__( actorder: bool = False, mse: bool = False, perchannel: bool = True, - providers: List[str] = ["CPUExecutionProvider"], + providers: list[str] | None = None, layer_wise_quant: bool = False, quant_last_matmul: bool = True, - white_list: List[Union[str, Callable]] = constants.DEFAULT_WHITE_LIST, + white_list: list[str | Callable] = constants.DEFAULT_WHITE_LIST, ): """Init GPTQ weight-only quantization config. @@ -791,6 +792,8 @@ def __init__( white_list (list, optional): op in white_list will be applied current config. Defaults to constants.DEFAULT_WHITE_LIST. 
""" + if providers is None: + providers = ["CPUExecutionProvider"] super().__init__(white_list=white_list) self.weight_bits = weight_bits self.weight_dtype = weight_dtype @@ -809,7 +812,7 @@ def __init__( self._post_init() def get_model_params_dict(self): - result = dict() + result = {} for param in self.model_params_list: result[param] = getattr(self, param) return result @@ -831,7 +834,7 @@ def register_supported_configs(cls) -> None: supported_configs.append(_OperatorConfig(config=linear_gptq_config, operators=operators)) cls.supported_configs = supported_configs - def to_config_mapping(self, config_list: list = None, model_info: list = None) -> OrderedDict: + def to_config_mapping(self, config_list: list | None = None, model_info: list | None = None) -> OrderedDict: config_mapping = OrderedDict() if config_list is None: config_list = [self] @@ -858,7 +861,7 @@ def to_config_mapping(self, config_list: list = None, model_info: list = None) - return config_mapping @staticmethod - def get_model_info(model: Union[onnx.ModelProto, pathlib.Path, str]) -> list: + def get_model_info(model: onnx.ModelProto | pathlib.Path | str) -> list: if not isinstance(model, onnx.ModelProto): model = onnx.load(model, load_external_data=False) white_list = ["MatMul"] @@ -871,7 +874,7 @@ def get_model_info(model: Union[onnx.ModelProto, pathlib.Path, str]) -> list: return filter_result @classmethod - def get_config_set_for_tuning(cls) -> Union[None, "GPTQConfig", List["GPTQConfig"]]: # pragma: no cover + def get_config_set_for_tuning(cls) -> None | GPTQConfig | list[GPTQConfig]: # pragma: no cover return GPTQConfig( weight_bits=[4, 8], weight_sym=[True, False], @@ -897,8 +900,8 @@ def get_default_gptq_config() -> GPTQConfig: class AWQConfig(BaseConfig): """Config class for awq weight-only quantization.""" - supported_configs: List[_OperatorConfig] = [] - params_list: List[str] = [ + supported_configs: list[_OperatorConfig] = [] # noqa: RUF012 + params_list: list[str] = [ # noqa: RUF012 "weight_dtype", "weight_bits", "weight_group_size", @@ -906,7 +909,7 @@ class AWQConfig(BaseConfig): "act_dtype", "accuracy_level", ] - model_params_list: List[str] = [ + model_params_list: list[str] = [ # noqa: RUF012 "enable_auto_scale", "enable_mse_search", "providers", @@ -923,9 +926,9 @@ def __init__( accuracy_level: int = 0, enable_auto_scale: bool = True, enable_mse_search: bool = True, - providers: List[str] = ["CPUExecutionProvider"], + providers: list[str] | None = None, quant_last_matmul: bool = True, - white_list: List[Union[str, Callable]] = constants.DEFAULT_WHITE_LIST, + white_list: list[str | Callable] = constants.DEFAULT_WHITE_LIST, ): """Init AWQ weight-only quantization config. @@ -947,6 +950,8 @@ def __init__( white_list (list, optional): op in white_list will be applied current config. Defaults to constants.DEFAULT_WHITE_LIST. 
""" + if providers is None: + providers = ["CPUExecutionProvider"] super().__init__(white_list=white_list) self.weight_bits = weight_bits self.weight_dtype = weight_dtype @@ -961,13 +966,13 @@ def __init__( self._post_init() def get_model_params_dict(self): - result = dict() + result = {} for param in self.model_params_list: result[param] = getattr(self, param) return result @classmethod - def register_supported_configs(cls) -> List[_OperatorConfig]: + def register_supported_configs(cls) -> list[_OperatorConfig]: supported_configs = [] linear_awq_config = AWQConfig( weight_dtype=["int"], @@ -982,7 +987,7 @@ def register_supported_configs(cls) -> List[_OperatorConfig]: supported_configs.append(_OperatorConfig(config=linear_awq_config, operators=operators)) cls.supported_configs = supported_configs - def to_config_mapping(self, config_list: list = None, model_info: list = None) -> OrderedDict: + def to_config_mapping(self, config_list: list | None = None, model_info: list | None = None) -> OrderedDict: config_mapping = OrderedDict() if config_list is None: config_list = [self] @@ -1009,7 +1014,7 @@ def to_config_mapping(self, config_list: list = None, model_info: list = None) - return config_mapping @staticmethod - def get_model_info(model: Union[onnx.ModelProto, pathlib.Path, str]) -> list: + def get_model_info(model: onnx.ModelProto | pathlib.Path | str) -> list: if not isinstance(model, onnx.ModelProto): model = onnx.load(model, load_external_data=False) white_list = ["MatMul"] @@ -1022,7 +1027,7 @@ def get_model_info(model: Union[onnx.ModelProto, pathlib.Path, str]) -> list: return filter_result @classmethod - def get_config_set_for_tuning(cls) -> Union[None, "AWQConfig", List["AWQConfig"]]: # pragma: no cover + def get_config_set_for_tuning(cls) -> None | AWQConfig | list[AWQConfig]: # pragma: no cover return AWQConfig( weight_bits=[4, 8], weight_sym=[True, False], @@ -1047,8 +1052,8 @@ def get_default_awq_config() -> AWQConfig: class SmoothQuantConfig(BaseConfig, quantization.StaticQuantConfig): """Smooth quant quantization config.""" - supported_configs: List[_OperatorConfig] = [] - params_list: List[str] = [ + supported_configs: list[_OperatorConfig] = [] # noqa: RUF012 + params_list: list[str] = [ # noqa: RUF012 # smooth parameters "alpha", "folding", @@ -1062,12 +1067,12 @@ def __init__( self, alpha: float = 0.5, folding: bool = True, - op_types: List[str] = ["Gemm", "Conv", "MatMul", "FusedConv"], + op_types: list[str] | None = None, calib_iter: int = 100, scales_per_op: bool = True, - auto_alpha_args: dict = {"alpha_min": 0.3, "alpha_max": 0.7, "alpha_step": 0.05, "attn_method": "min"}, - providers: List[str] = ["CPUExecutionProvider"], - white_list: List[Union[str, Callable]] = constants.DEFAULT_WHITE_LIST, + auto_alpha_args: dict | None = None, + providers: list[str] | None = None, + white_list: list[str | Callable] = constants.DEFAULT_WHITE_LIST, **kwargs, ): """Init smooth quant config. 
@@ -1091,6 +1096,12 @@ def __init__( kwargs (dict): kwargs in below link are supported except calibration_data_reader: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/quantize.py#L78 """ + if providers is None: + providers = ["CPUExecutionProvider"] + if auto_alpha_args is None: + auto_alpha_args = {"alpha_min": 0.3, "alpha_max": 0.7, "alpha_step": 0.05, "attn_method": "min"} + if op_types is None: + op_types = ["Gemm", "Conv", "MatMul", "FusedConv"] BaseConfig.__init__(self) kwargs.update({"calibration_data_reader": None}) quantization.StaticQuantConfig.__init__(self, **kwargs) @@ -1113,7 +1124,7 @@ def __init__( self._post_init() @classmethod - def register_supported_configs(cls) -> List[_OperatorConfig]: + def register_supported_configs(cls) -> list[_OperatorConfig]: supported_configs = [] smooth_quant_config = SmoothQuantConfig() operators = ["Gemm", "Conv", "MatMul", "FusedConv"] @@ -1134,7 +1145,7 @@ def get_model_info(model) -> list: @classmethod def get_config_set_for_tuning( cls, - ) -> Union[None, "SmoothQuantConfig", List["SmoothQuantConfig"]]: # pragma: no cover + ) -> None | SmoothQuantConfig | list[SmoothQuantConfig]: # pragma: no cover return SmoothQuantConfig(alpha=np.arange(0.3, 0.7, 0.05)) def convert_to_ort_config(self): @@ -1163,11 +1174,11 @@ def get_woq_tuning_config() -> list: Returns: the list of WOQ quant config. """ - RTN_G32ASYM = RTNConfig(weight_sym=False) - GPTQ_G32ASYM = GPTQConfig(weight_sym=False) - GPTQ_G32ASYM_DISABLE_LAST_MATMUL = GPTQConfig(weight_sym=False, quant_last_matmul=False) - GPTQ_G128ASYM = GPTQConfig(weight_group_size=128, weight_sym=False) - AWQ_G32ASYM = AWQConfig(weight_sym=False) + RTN_G32ASYM = RTNConfig(weight_sym=False) # noqa: N806 + GPTQ_G32ASYM = GPTQConfig(weight_sym=False) # noqa: N806 + GPTQ_G32ASYM_DISABLE_LAST_MATMUL = GPTQConfig(weight_sym=False, quant_last_matmul=False) # noqa: N806 + GPTQ_G128ASYM = GPTQConfig(weight_group_size=128, weight_sym=False) # noqa: N806 + AWQ_G32ASYM = AWQConfig(weight_sym=False) # noqa: N806 return [RTN_G32ASYM, GPTQ_G32ASYM, GPTQ_G32ASYM_DISABLE_LAST_MATMUL, GPTQ_G128ASYM, AWQ_G32ASYM] @@ -1207,7 +1218,9 @@ def __init__(self, calibration_data_reader: data_reader.CalibrationDataReader, e If enabled, each op will have an individual scale, mainlyfor accuracy. If not enabled, ops with the same input will share a scale, mainly for performance. 
""" - super().__init__(calibration_data_reader=calibration_data_reader, extra_options=extra_options, *args, **kwargs) + super().__init__( + calibration_data_reader=calibration_data_reader, extra_options=extra_options, *args, **kwargs # noqa: B026 + ) def to_dict(self): return self.__dict__ diff --git a/onnx_neural_compressor/constants.py b/onnx_neural_compressor/constants.py index d2e0391c6..6ec76bd30 100644 --- a/onnx_neural_compressor/constants.py +++ b/onnx_neural_compressor/constants.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Copyright (c) 2023 Intel Corporation # diff --git a/onnx_neural_compressor/data_reader.py b/onnx_neural_compressor/data_reader.py index 24538ce55..a85a70d4e 100644 --- a/onnx_neural_compressor/data_reader.py +++ b/onnx_neural_compressor/data_reader.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- # # Copyright (c) 2023 Intel Corporation # diff --git a/onnx_neural_compressor/logger.py b/onnx_neural_compressor/logger.py index 9637b0350..e10b45dbf 100644 --- a/onnx_neural_compressor/logger.py +++ b/onnx_neural_compressor/logger.py @@ -37,7 +37,7 @@ def _pretty_dict(value, indent=0): _logger.handlers.clear() _logger.setLevel(LOGLEVEL) formatter = logging.Formatter("%(asctime)s [%(levelname)s][%(filename)s:%(lineno)d] %(message)s", "%Y-%m-%d %H:%M:%S") -streamHandler = logging.StreamHandler() +streamHandler = logging.StreamHandler() # noqa: N816 streamHandler.setFormatter(formatter) _logger.addHandler(streamHandler) _logger.propagate = False diff --git a/onnx_neural_compressor/onnx_model.py b/onnx_neural_compressor/onnx_model.py index 061f7cad8..11d58e531 100644 --- a/onnx_neural_compressor/onnx_model.py +++ b/onnx_neural_compressor/onnx_model.py @@ -29,7 +29,7 @@ class ONNXModel(onnx_model.ONNXModel): """Build ONNX model.""" - def __init__(self, model, **kwargs): + def __init__(self, model, **kwargs): # noqa: D417 """Initialize an ONNX model. 
Args: @@ -84,7 +84,7 @@ def check_is_large_model(self): self._is_large_model = True return else: # pragma: no cover - raise e + raise if init_size > constants.MAXIMUM_PROTOBUF: self._is_large_model = True return @@ -163,8 +163,8 @@ def save(self, root): onnx.save(self.model, root) if self._config is not None: - model_type = "" if not hasattr(self._config, "model_type") else getattr(self._config, "model_type") - setattr(self._config.__class__, "model_type", model_type) + model_type = "" if not hasattr(self._config, "model_type") else self._config.model_type + self._config.__class__.model_type = model_type output_config_file = pathlib.Path(root).parent.joinpath("config.json").as_posix() self._config.to_json_file(output_config_file, use_diff=False) @@ -194,7 +194,7 @@ def get_node_by_weight(self, weight_name): if len(nodes) == 1: return nodes[0] elif len(nodes) == 0: - raise ValueError("{} is not used by any node in this model.".format(weight_name)) + raise ValueError(f"{weight_name} is not used by any node in this model.") else: raise NotImplementedError("Models with shared weights is not supported.") @@ -217,13 +217,13 @@ def get_siblings(self, node): for parent in self.get_parents(node): for child in self.get_children(parent): if child.name != node.name: - siblings.append(child) + siblings.append(child) # noqa: PERF401 return siblings def get_scale_zero(self, tensor): """Help function to get scale and zero_point.""" if not tensor.endswith("_quantized"): - logger.debug("Find {} in the quantized graph is not quantized.".format(tensor)) + logger.debug(f"Find {tensor} in the quantized graph is not quantized.") return None, None if len(self._input_name_to_nodes) == 0: @@ -234,7 +234,7 @@ def get_scale_zero(self, tensor): def _searcher(tensor_name): """Search scale and zero point tensor recursively.""" node = self._input_name_to_nodes[tensor_name][0] - parent = self._output_name_to_node[tensor_name] if tensor_name in self._output_name_to_node else None + parent = self._output_name_to_node.get(tensor_name, None) direct_int8 = ["Reshape", "Transpose", "Squeeze", "Unsqueeze", "MaxPool", "Pad", "Split"] if parent is not None and parent.op_type in direct_int8: fp32_tensor_name = ( @@ -273,12 +273,16 @@ def _searcher(tensor_name): return None, None else: scale_tensor, zo_tensor = _searcher(tensor) - assert scale_tensor, "missing scale for tensor {}".format(tensor) - assert zo_tensor, "missing zero point for tensor {}".format(tensor) + assert scale_tensor, f"missing scale for tensor {tensor}" + assert zo_tensor, f"missing zero point for tensor {tensor}" return scale_tensor, zo_tensor - def replace_input_of_all_nodes(self, old_input_name, new_input_name, white_optype=[], black_optype=[]): + def replace_input_of_all_nodes(self, old_input_name, new_input_name, white_optype=None, black_optype=None): """Replace inputs of all nodes.""" + if black_optype is None: + black_optype = [] + if white_optype is None: + white_optype = [] if len(white_optype) > 0: for node in self.model.graph.node: if node.op_type in white_optype: @@ -288,8 +292,12 @@ def replace_input_of_all_nodes(self, old_input_name, new_input_name, white_optyp if node.op_type not in black_optype: ONNXModel.replace_node_input(node, old_input_name, new_input_name) - def replace_output_of_all_nodes(self, old_output_name, new_output_name, white_optype=[], black_optype=[]): + def replace_output_of_all_nodes(self, old_output_name, new_output_name, white_optype=None, black_optype=None): """Replace outputs of all nodes.""" + if black_optype is None: + 
black_optype = [] + if white_optype is None: + white_optype = [] if len(white_optype) > 0: for node in self.model.graph.node: if node.op_type in white_optype: @@ -330,7 +338,7 @@ def remove_unused_nodes(self): if output in self._input_name_to_nodes or output in self.output(): unused = False break - for input in node.input: + for input in node.input: # noqa: A001 if self.get_initializer(input) is not None: continue elif input in self._output_name_to_node or input in self.input(): @@ -381,12 +389,12 @@ def topological_sort(self, enable_subgraph=False): for inp in self.model.graph.input: q.extend(input_name_to_nodes[inp.name]) for n in self.model.graph.node: - if all([i not in output_name_to_node and i not in self.input() for i in n.input]): + if all(i not in output_name_to_node and i not in self.input() for i in n.input): q.append(n) while q: n = q.popleft() - if not all([output_name_to_node[i].name in all_nodes for i in n.input if i in output_name_to_node]): + if not all(output_name_to_node[i].name in all_nodes for i in n.input if i in output_name_to_node): if n not in wait: wait.append(n) continue @@ -399,13 +407,15 @@ def topological_sort(self, enable_subgraph=False): q = copy.deepcopy(wait) wait.clear() nodes = [i[1] for i in all_nodes.items()] - assert len(list(set([n.name for n in nodes]))) == len(list(set([n.name for n in self.model.graph.node]))) + assert len(list({n.name for n in nodes})) == len(list({n.name for n in self.model.graph.node})) self.model.graph.ClearField("node") self.model.graph.node.extend(nodes) - def get_nodes_chain(self, start, stop, result_chain=[]): + def get_nodes_chain(self, start, stop, result_chain=None): """Get nodes chain with given start node and stop node.""" # process start node list + if result_chain is None: + result_chain = [] start_node = collections.deque() for node in start: if isinstance(node, str): @@ -413,7 +423,7 @@ def get_nodes_chain(self, start, stop, result_chain=[]): elif isinstance(node, onnx.NodeProto): start_node.append(node.name) else: - assert False, "'get_nodes_chain' function only support list[string]" "or list[NodeProto] params" + raise TypeError("'get_nodes_chain' function only support list[string]or list[NodeProto] params") # process stop node list stop_node = [] @@ -423,7 +433,7 @@ def get_nodes_chain(self, start, stop, result_chain=[]): elif isinstance(node, onnx.NodeProto): stop_node.append(node.name) else: - assert False, "'get_nodes_chain' function only support list[string]" "or list[NodeProto] params" + raise TypeError("'get_nodes_chain' function only support list[string]or list[NodeProto] params") while start_node: node_name = start_node.popleft() @@ -585,7 +595,7 @@ def find_qkv_in_attention(self, find_all=False): continue qkv_nodes = [qkv for qkv in qkv_nodes_list if qkv is not None][-1] other_inputs = [] - for input in start_node.input: + for input in start_node.input: # noqa: A001 if input not in self._output_name_to_node: continue if input == qkv_nodes[0].output[0]: @@ -650,11 +660,11 @@ def remove_tensors_from_outputs(self, tensor_names): removed_outputs = [] for tensor in tensor_names: if tensor in self.output(): - removed_outputs.append(self.model.graph.output[self.output().index(tensor)]) + removed_outputs.append(self.model.graph.output[self.output().index(tensor)]) # noqa: PERF401 for output in removed_outputs: self.model.graph.output.remove(output) - def match_first_parent(self, node, parent_op_type, output_name_to_node_dict, exclude=[]): + def match_first_parent(self, node, parent_op_type, 
output_name_to_node_dict, exclude=None): # noqa: D417 """Find parent node based on constraints on op_type. Args: @@ -667,20 +677,22 @@ def match_first_parent(self, node, parent_op_type, output_name_to_node_dict, exc parent: The matched parent node. None if not found. index: The input index of matched parent node. None if not found. """ - for i, input in enumerate(node.input): + if exclude is None: + exclude = [] + for i, input in enumerate(node.input): # noqa: A001 if input in output_name_to_node_dict: parent = output_name_to_node_dict[input] if parent.op_type == parent_op_type and parent not in exclude: return parent, i return None, None - def match_parent( + def match_parent( # noqa: D417 self, node, parent_op_type, input_index=None, output_name_to_node_dict=None, - exclude=[], + exclude=None, return_indice=None, ): """Find parent node based on constraints on op_type and index. @@ -696,6 +708,8 @@ def match_parent( Returns: parent: The matched parent node. """ + if exclude is None: + exclude = [] assert node is not None assert input_index is None or input_index >= 0 @@ -719,7 +733,7 @@ def match_parent( return None - def match_parent_path( + def match_parent_path( # noqa: D417 self, node, parent_op_types, @@ -773,10 +787,7 @@ def is_smoothquant_model(self): Returns: bool: the model is smooth quantized or not. """ - for init in self.model.graph.initializer: - if "_smooth_scale" in init.name: - return True - return False + return any("_smooth_scale" in init.name for init in self.model.graph.initializer) def find_split_nodes(self): """Find split nodes for layer-wise quantization.""" @@ -813,7 +824,7 @@ def split_model_with_node(self, split_node_name, path_of_model_to_split, save_bo unvalid_nodes = [ i for i in self.model.graph.node - if all([out not in self._input_name_to_nodes and not self.is_graph_output(out) for out in i.output]) + if all(out not in self._input_name_to_nodes and not self.is_graph_output(out) for out in i.output) ] self.topological_sort() @@ -839,7 +850,7 @@ def split_model_with_node(self, split_node_name, path_of_model_to_split, save_bo assert len(split_node_output) == 1, ( "Only support split at node with 1 output tensor, while " - "current split node {} has {} output tensors".format(split_node_name, len(split_node_output)) + f"current split node {split_node_name} has {len(split_node_output)} output tensors" ) split_tensor_name = split_node_output[0] @@ -858,8 +869,8 @@ def split_model_with_node(self, split_node_name, path_of_model_to_split, save_bo insert_output_for_model_1 = [] insert_input_for_model_2 = [] - for output in split_model_part_1._output_name_to_node.keys(): - if output in split_model_part_2._input_name_to_nodes.keys(): + for output in split_model_part_1._output_name_to_node: + if output in split_model_part_2._input_name_to_nodes: output_type, output_shape = self._get_output_type_shape_by_tensor_name(output) output_tensor = onnx.helper.make_tensor_value_info(output, output_type, output_shape) if output_tensor not in split_model_part_1.model.graph.output: @@ -872,7 +883,7 @@ def split_model_with_node(self, split_node_name, path_of_model_to_split, save_bo split_model_part_1.model.graph.output.append(output) # insert model 2 input - for input in insert_input_for_model_2: + for input in insert_input_for_model_2: # noqa: A001 split_model_part_2.model.graph.input.append(input) # remove unused init @@ -889,7 +900,7 @@ def split_model_with_node(self, split_node_name, path_of_model_to_split, save_bo split_model_part_1.model_path = split_model_part_1_path 
split_model_part_1._save_split_model(split_model_part_1_path) split_model_part_1.check_is_large_model() - logger.debug("save split model part 1 to {} for layer wise quantization".format(split_model_part_1_path)) + logger.debug(f"save split model part 1 to {split_model_part_1_path} for layer wise quantization") if save_both_split_models: split_model_part_2.load_model_initializer_by_tensor(dir_of_model_to_split) @@ -897,7 +908,7 @@ def split_model_with_node(self, split_node_name, path_of_model_to_split, save_bo split_model_part_2.model_path = split_model_part_2_path split_model_part_2._save_split_model(split_model_part_2_path) split_model_part_2.check_is_large_model() - logger.debug("save split model part 2 to {} for layer wise quantization".format(split_model_part_2_path)) + logger.debug(f"save split model part 2 to {split_model_part_2_path} for layer wise quantization") return split_model_part_1, split_model_part_2 else: return split_model_part_1, split_model_part_2 @@ -947,16 +958,16 @@ def _remove_unused_input_output(self): if len(self._input_name_to_nodes) == 0: self._input_name_to_nodes = self.input_name_to_nodes() for output in self.model.graph.output: - if output.name not in self._output_name_to_node.keys(): - remove_outputs.append(output) + if output.name not in self._output_name_to_node: + remove_outputs.append(output) # noqa: PERF401 - for input in self.model.graph.input: - if input.name not in self._input_name_to_nodes.keys(): - remove_inputs.append(input) + for input in self.model.graph.input: # noqa: A001 + if input.name not in self._input_name_to_nodes: + remove_inputs.append(input) # noqa: PERF401 for output in remove_outputs: self.model.graph.output.remove(output) - for input in remove_inputs: + for input in remove_inputs: # noqa: A001 self.model.graph.input.remove(input) def remove_unused_init(self): @@ -965,8 +976,8 @@ def remove_unused_init(self): if len(self._input_name_to_nodes) == 0: self._input_name_to_nodes = self.input_name_to_nodes() for init in self.model.graph.initializer: - if init.name not in self._input_name_to_nodes.keys(): - remov_inits.append(init) + if init.name not in self._input_name_to_nodes: + remov_inits.append(init) # noqa: PERF401 self.remove_initializers(remov_inits) def load_model_initializer_by_tensor(self, data_path=None): @@ -999,8 +1010,8 @@ def write_external_data_to_new_location(self, external_data_location="external.d def merge_split_models(self, to_merge_model): """Merge two split model into final model.""" to_merge_model.write_external_data_to_new_location() - self.add_nodes([node for node in to_merge_model.nodes()]) - self.add_initializers([init for init in to_merge_model.initializer()]) + self.add_nodes(list(to_merge_model.nodes())) + self.add_initializers(list(to_merge_model.initializer())) self.update() # add new output @@ -1012,16 +1023,16 @@ def merge_split_models(self, to_merge_model): remove_output = [] for output in self.model.graph.output: if output.name in to_merge_model.input(): - remove_output.append(output) + remove_output.append(output) # noqa: PERF401 for output in remove_output: self.model.graph.output.remove(output) # add new input - for input in to_merge_model.graph().input: + for input in to_merge_model.graph().input: # noqa: A001 if ( input.name not in self.input() and input.name not in self.output() - and input.name not in self._output_name_to_node.keys() + and input.name not in self._output_name_to_node ): self.model.graph.input.append(input) diff --git a/onnx_neural_compressor/quantization/__init__.py 
b/onnx_neural_compressor/quantization/__init__.py index 7ef91659a..5f70f9635 100644 --- a/onnx_neural_compressor/quantization/__init__.py +++ b/onnx_neural_compressor/quantization/__init__.py @@ -13,6 +13,6 @@ # limitations under the License. -from onnxruntime.quantization.quant_utils import QuantFormat, QuantType +from onnxruntime.quantization.quant_utils import QuantFormat, QuantType # noqa: F401 -from onnx_neural_compressor.quantization.quantize import quantize +from onnx_neural_compressor.quantization.quantize import quantize # noqa: F401 diff --git a/onnx_neural_compressor/quantization/algorithm_entry.py b/onnx_neural_compressor/quantization/algorithm_entry.py index cd079932c..ed8a5d652 100644 --- a/onnx_neural_compressor/quantization/algorithm_entry.py +++ b/onnx_neural_compressor/quantization/algorithm_entry.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import pathlib import tempfile -from typing import Union import onnx from onnxruntime import quantization @@ -27,13 +27,13 @@ ###################### SmoothQuant Entry ################################## @utility.register_algo(name=constants.SMOOTH_QUANT) def smooth_quant_entry( - model: Union[pathlib.Path, str], + model: pathlib.Path | str, quant_config: config.SmoothQuantConfig, calibration_data_reader: data_reader.CalibrationDataReader, - model_output: Union[pathlib.Path, str] = None, - *args, - **kwargs -) -> Union[pathlib.Path, str, onnx.ModelProto]: + model_output: pathlib.Path | str | None = None, + *args, # noqa: ARG001 + **kwargs, # noqa: ARG001 +) -> pathlib.Path | str | onnx.ModelProto: """Apply smooth quant.""" assert calibration_data_reader is not None, "Please provide calibration_data_reader" assert isinstance( @@ -80,10 +80,9 @@ def smooth_quant_entry( ###################### RTN Algo Entry ################################## @utility.register_algo(name=constants.RTN) -def rtn_quantize_entry( - model: Union[pathlib.Path, str], quant_config: config.RTNConfig, *args, **kwargs -) -> onnx.ModelProto: +def rtn_quantize_entry(model: pathlib.Path | str, quant_config: config.RTNConfig, *args, **kwargs) -> onnx.ModelProto: """The main entry to apply rtn quantization.""" + del args, kwargs # unused # map config to each op model_info = quant_config.get_model_info(model=model) configs_mapping = quant_config.to_config_mapping(model_info=model_info) @@ -95,11 +94,11 @@ def rtn_quantize_entry( ###################### GPTQ Algo Entry ################################## @utility.register_algo(name=constants.GPTQ) def gptq_quantize_entry( - model: Union[pathlib.Path, str], + model: pathlib.Path | str, quant_config: config.GPTQConfig, calibration_data_reader: data_reader.CalibrationDataReader, - *args, - **kwargs + *args, # noqa: ARG001 + **kwargs, # noqa: ARG001 ) -> onnx.ModelProto: """The main entry to apply gptq quantization.""" assert calibration_data_reader is not None, "Please provide calibration_data_reader" @@ -121,11 +120,11 @@ def gptq_quantize_entry( ###################### AWQ Algo Entry ################################## @utility.register_algo(name=constants.AWQ) def awq_quantize_entry( - model: Union[pathlib.Path, str], + model: pathlib.Path | str, quant_config: config.AWQConfig, calibration_data_reader: data_reader.CalibrationDataReader, - *args, - **kwargs + *args, # noqa: ARG001 + **kwargs, # noqa: ARG001 ) -> onnx.ModelProto: """The main 
entry to apply awq quantization.""" assert calibration_data_reader is not None, "Please provide calibration_data_reader" diff --git a/onnx_neural_compressor/quantization/calibrate.py b/onnx_neural_compressor/quantization/calibrate.py index 37bf7d671..1b168513a 100644 --- a/onnx_neural_compressor/quantization/calibrate.py +++ b/onnx_neural_compressor/quantization/calibrate.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Copyright (c) 2023 Intel Corporation # diff --git a/onnx_neural_compressor/quantization/matmul_4bits_quantizer.py b/onnx_neural_compressor/quantization/matmul_4bits_quantizer.py index 62a671fba..0a97e1be4 100644 --- a/onnx_neural_compressor/quantization/matmul_4bits_quantizer.py +++ b/onnx_neural_compressor/quantization/matmul_4bits_quantizer.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import List, Union # isort: skip +from __future__ import annotations import onnx from onnxruntime.quantization import matmul_4bits_quantizer @@ -28,14 +27,16 @@ class MatMul4BitsQuantizer(matmul_nbits_quantizer.MatMulNBitsQuantizer): def __init__( self, - model: Union[onnx.ModelProto, str], + model: onnx.ModelProto | str, block_size: int = 128, is_symmetric: bool = False, accuracy_level: int = 0, nodes_to_exclude=None, algo_config: matmul_4bits_quantizer.WeightOnlyQuantConfig = None, - providers: List[str] = ["CPUExecutionProvider"], + providers: list[str] | None = None, ): + if providers is None: + providers = ["CPUExecutionProvider"] super().__init__( model=model, block_size=block_size, diff --git a/onnx_neural_compressor/quantization/matmul_nbits_quantizer.py b/onnx_neural_compressor/quantization/matmul_nbits_quantizer.py index 0d00bbbc5..7af294467 100644 --- a/onnx_neural_compressor/quantization/matmul_nbits_quantizer.py +++ b/onnx_neural_compressor/quantization/matmul_nbits_quantizer.py @@ -11,13 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -from typing import List, Union # isort: skip +from __future__ import annotations import onnx from onnxruntime.quantization import matmul_4bits_quantizer -from onnx_neural_compressor import config, data_reader, logger, onnx_model, utility +from onnx_neural_compressor import config, data_reader, logger, onnx_model from onnx_neural_compressor.quantization import algorithm_entry as algos @@ -76,15 +75,17 @@ class MatMulNBitsQuantizer: def __init__( self, - model: Union[onnx.ModelProto, str], + model: onnx.ModelProto | str, block_size: int = 128, is_symmetric: bool = False, accuracy_level: int = 0, - nodes_to_exclude: List[str] = None, + nodes_to_exclude: list[str] | None = None, algo_config: matmul_4bits_quantizer.WeightOnlyQuantConfig = None, n_bits: int = 4, - providers: List[str] = ["CPUExecutionProvider"], + providers: list[str] | None = None, ): + if providers is None: + providers = ["CPUExecutionProvider"] if nodes_to_exclude is None: nodes_to_exclude = [] self.model_path = model if isinstance(model, str) else None @@ -102,7 +103,7 @@ def __init__( "RTN", "AWQ", "GPTQ", - ], "Only RTN, GPTQ and AWQ algorithms are supported, but get {} algorithm".format(self.algorithm) + ], f"Only RTN, GPTQ and AWQ algorithms are supported, but get {self.algorithm} algorithm" def _generate_nc_config(self): config_class = config.config_registry.get_cls_configs()[self.algorithm.lower()] diff --git a/onnx_neural_compressor/quantization/quantize.py b/onnx_neural_compressor/quantization/quantize.py index 7e388e3aa..090e65c6a 100644 --- a/onnx_neural_compressor/quantization/quantize.py +++ b/onnx_neural_compressor/quantization/quantize.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import pathlib -from typing import Union import onnx from onnxruntime.quantization.quantize import QuantConfig @@ -24,8 +24,8 @@ # ORT-like user-facing API def quantize( - model_input: Union[str, pathlib.Path, onnx.ModelProto], - model_output: Union[str, pathlib.Path], + model_input: str | pathlib.Path | onnx.ModelProto, + model_output: str | pathlib.Path, quant_config: QuantConfig, ): if isinstance(quant_config, config.StaticQuantConfig): diff --git a/onnx_neural_compressor/quantization/tuning.py b/onnx_neural_compressor/quantization/tuning.py index a6743ad7a..c1e5677e9 100644 --- a/onnx_neural_compressor/quantization/tuning.py +++ b/onnx_neural_compressor/quantization/tuning.py @@ -11,19 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations import copy import os import pathlib +import sys import tempfile import uuid +from typing import Any, Callable, Dict, Generator, Iterator, List, Sized import onnx from onnx_neural_compressor import config, data_reader, logger, utility -from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Sized, Tuple, Union # isort: skip - class EvaluationFuncWrapper: @@ -37,7 +38,7 @@ def __init__(self, eval_fn: Callable, eval_args=None): self.eval_fn = eval_fn self.eval_args = eval_args - def evaluate(self, model) -> Union[float, int]: + def evaluate(self, model) -> float | int: result = self.eval_fn(model, *self.eval_args) if self.eval_args else self.eval_fn(model) return result @@ -67,10 +68,10 @@ def eval_perf(molde): EVAL_FN = "eval_fn" WEIGHT = "weight" FN_NAME = "name" - EVAL_FN_TEMPLATE: Dict[str, Any] = {EVAL_FN: None, WEIGHT: 1.0, FN_NAME: None} + EVAL_FN_TEMPLATE: dict[str, Any] = {EVAL_FN: None, WEIGHT: 1.0, FN_NAME: None} # noqa: RUF012 def __init__(self) -> None: - self.eval_fn_registry: List[Dict[str, Any]] = [] + self.eval_fn_registry: list[dict[str, Any]] = [] def evaluate(self, model) -> float: """Evaluate the model using registered evaluation functions. @@ -94,7 +95,7 @@ def _update_the_objective_score(self, eval_pair, eval_result, overall_result) -> def get_number_of_eval_functions(self) -> int: return len(self.eval_fn_registry) - def _set_eval_fn_registry(self, user_eval_fns: List[Dict]) -> None: + def _set_eval_fn_registry(self, user_eval_fns: list[dict]) -> None: self.eval_fn_registry = [ { self.EVAL_FN: user_eval_fn_pair[self.EVAL_FN], @@ -104,7 +105,7 @@ def _set_eval_fn_registry(self, user_eval_fns: List[Dict]) -> None: for user_eval_fn_pair in user_eval_fns ] - def set_eval_fn_registry(self, eval_fns: Optional[Union[Callable, Dict, List[Dict]]] = None) -> None: + def set_eval_fn_registry(self, eval_fns: Callable | dict | list[dict] | None = None) -> None: # About the eval_fns format, refer the class docstring for details. 
if eval_fns is None: return @@ -117,7 +118,7 @@ def set_eval_fn_registry(self, eval_fns: Optional[Union[Callable, Dict, List[Dic elif isinstance(eval_fns, Dict): eval_fns = [eval_fns] elif isinstance(eval_fns, List): - assert all([isinstance(eval_fn_pair, Dict) for eval_fn_pair in eval_fns]) + assert all(isinstance(eval_fn_pair, Dict) for eval_fn_pair in eval_fns) else: raise NotImplementedError(f"The eval_fns should be a dict or a list of dict, but got {type(eval_fns)}.") self._set_eval_fn_registry(eval_fns) @@ -134,7 +135,7 @@ def self_check(self) -> None: class ConfigSet: - def __init__(self, config_list: List[config.BaseConfig]) -> None: + def __init__(self, config_list: list[config.BaseConfig]) -> None: self.config_list = config_list def __getitem__(self, index) -> config.BaseConfig: @@ -145,20 +146,20 @@ def __len__(self) -> int: return len(self.config_list) @classmethod - def _from_single_config(cls, fwk_config: config.BaseConfig) -> List[config.BaseConfig]: + def _from_single_config(cls, fwk_config: config.BaseConfig) -> list[config.BaseConfig]: config_list = [] config_list = fwk_config.expand() return config_list @classmethod - def _from_list_of_configs(cls, fwk_configs: List[config.BaseConfig]) -> List[config.BaseConfig]: + def _from_list_of_configs(cls, fwk_configs: list[config.BaseConfig]) -> list[config.BaseConfig]: config_list = [] for fwk_config in fwk_configs: config_list += cls._from_single_config(fwk_config) return config_list @classmethod - def generate_config_list(cls, fwk_configs: Union[config.BaseConfig, List[config.BaseConfig]]): + def generate_config_list(cls, fwk_configs: config.BaseConfig | list[config.BaseConfig]): # There are several cases for the input `fwk_configs`: # 1. fwk_configs is a single config # 2. fwk_configs is a list of configs @@ -173,12 +174,13 @@ def generate_config_list(cls, fwk_configs: Union[config.BaseConfig, List[config. return config_list @classmethod - def from_fwk_configs(cls, fwk_configs: Union[config.BaseConfig, List[config.BaseConfig]]) -> "ConfigSet": + def from_fwk_configs(cls, fwk_configs: config.BaseConfig | list[config.BaseConfig]) -> ConfigSet: """Create a ConfigSet object from a single config or a list of configs. Args: fwk_configs: A single config or a list of configs. 
- Examples: + + Examples: 1) single config: config.RTNConfig(weight_group_size=32) 2) single expandable config: config.RTNConfig(weight_group_size=[32, 64]) 3) mixed 1) and 2): [config.RTNConfig(weight_group_size=32), config.RTNConfig(weight_group_size=[32, 64])] @@ -192,7 +194,7 @@ def from_fwk_configs(cls, fwk_configs: Union[config.BaseConfig, List[config.Base class Sampler: - def __init__(self, config_source: Optional[ConfigSet]) -> None: + def __init__(self, config_source: ConfigSet | None) -> None: pass def __iter__(self) -> Iterator[config.BaseConfig]: @@ -260,7 +262,7 @@ class TuningConfig: def __init__( self, - config_set: Union[config.BaseConfig, List[config.BaseConfig]] = None, + config_set: config.BaseConfig | list[config.BaseConfig] = None, sampler: Sampler = default_sampler, tolerable_loss=0.01, max_trials=100, @@ -287,7 +289,7 @@ def _generate_unique_id(): unique_id = str(uuid.uuid4()) return unique_id - def __init__(self, trial_index: int, trial_result: Union[int, float], quant_config: config.BaseConfig): + def __init__(self, trial_index: int, trial_result: int | float, quant_config: config.BaseConfig): # The unique id to refer to one trial self.trial_id = _TrialRecord._generate_unique_id() self.trial_index = trial_index @@ -300,12 +302,10 @@ class TuningMonitor: def __init__(self, tuning_config: TuningConfig) -> None: self.tuning_config = tuning_config self.trial_cnt = 0 - self.tuning_history: List[_TrialRecord] = [] + self.tuning_history: list[_TrialRecord] = [] self.baseline = None - def add_trial_result( - self, trial_index: int, trial_result: Union[int, float], quant_config: config.BaseConfig - ) -> None: + def add_trial_result(self, trial_index: int, trial_result: int | float, quant_config: config.BaseConfig) -> None: self.trial_cnt += 1 trial_record = _TrialRecord(trial_index, trial_result, quant_config) self.tuning_history.append(trial_record) @@ -320,7 +320,7 @@ def get_number_of_trials(self): def get_best_quant_config(self) -> config.BaseConfig: assert self.get_number_of_trials() > 0, "No trial record in tuning monitor." # Put the record with a higher score at the beginning - sorted_trials_records: List[_TrialRecord] = sorted( + sorted_trials_records: list[_TrialRecord] = sorted( self.tuning_history, key=lambda x: x.trial_result, reverse=True ) return sorted_trials_records[0].quant_config @@ -331,7 +331,6 @@ def need_stop(self) -> bool: Returns: stop_flag: True if need to stop, otherwise False. 
""" - # reach max trials reach_max_trials = self.trial_cnt >= self.tuning_config.max_trials # reach accuracy goal @@ -355,7 +354,7 @@ def tuning_start(cls) -> None: logger.info("Tuning started.") @classmethod - def trial_start(cls, trial_index: int = None) -> None: + def trial_start(cls, trial_index: int | None = None) -> None: logger.info("%d-trail started.", trial_index) @classmethod @@ -375,7 +374,7 @@ def evaluation_end(cls) -> None: logger.info("Evaluation end.") @classmethod - def trial_end(cls, trial_index: int = None) -> None: + def trial_end(cls, trial_index: int | None = None) -> None: logger.info("%d-trail end.", trial_index) @classmethod @@ -383,14 +382,14 @@ def tuning_end(cls) -> None: logger.info("Tuning completed.") -def init_tuning(tuning_config: TuningConfig) -> Tuple[ConfigLoader, TuningLogger, TuningMonitor]: +def init_tuning(tuning_config: TuningConfig) -> tuple[ConfigLoader, TuningLogger, TuningMonitor]: config_loader = ConfigLoader(config_set=tuning_config.config_set, sampler=tuning_config.sampler) tuning_logger = TuningLogger() tuning_monitor = TuningMonitor(tuning_config) return config_loader, tuning_logger, tuning_monitor -def get_all_config_set() -> Union[config.BaseConfig, List[config.BaseConfig]]: +def get_all_config_set() -> config.BaseConfig | list[config.BaseConfig]: return config.get_all_config_set_from_config_registry() @@ -401,7 +400,7 @@ def _need_apply(quant_config: config.BaseConfig, algo_name): # * only for internal usage now @utility.log_quant_execution def _quantize( - model_input: Union[pathlib.Path, str], + model_input: pathlib.Path | str, quant_config: config.BaseConfig, calibration_data_reader: data_reader.CalibrationDataReader = None, ) -> onnx.ModelProto: @@ -436,12 +435,12 @@ def _quantize( def autotune( - model_input: Union[pathlib.Path, str], + model_input: pathlib.Path | str, tune_config: TuningConfig, eval_fn: Callable, - eval_args: Optional[Tuple[Any]] = None, + eval_args: tuple[Any] | None = None, calibration_data_reader: data_reader.CalibrationDataReader = None, -) -> Union[None, onnx.ModelProto]: +) -> None | onnx.ModelProto: """The main entry of auto-tune. Args: @@ -463,11 +462,11 @@ def autotune( config_loader, tuning_logger, tuning_monitor = init_tuning(tuning_config=tune_config) try: baseline: float = eval_func_wrapper.evaluate(model_input) - except Exception as e: + except Exception as e: # noqa: BLE001 print(e) if "'str' object has no attribute 'SerializeToString'" in str(e): logger.warning("Please refine your eval_fn to accept model path (str) as input.") - exit(0) + sys.exit(0) tuning_monitor.set_baseline(baseline) tuning_logger.tuning_start() for trial_index, quant_config in enumerate(config_loader): @@ -475,7 +474,7 @@ def autotune( calibration_data_reader.rewind() tuning_logger.trial_start(trial_index=trial_index) tuning_logger.quantization_start() - logger.debug("quant config: {}".format(quant_config)) + logger.debug(f"quant config: {quant_config}") q_model = _quantize(model_input, quant_config=quant_config, calibration_data_reader=calibration_data_reader) tuning_logger.quantization_end() tuning_logger.evaluation_start() diff --git a/onnx_neural_compressor/utility.py b/onnx_neural_compressor/utility.py index cc36b6e8a..f58663ff4 100644 --- a/onnx_neural_compressor/utility.py +++ b/onnx_neural_compressor/utility.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations import importlib -import logging -import os import pathlib import subprocess import time +from typing import Callable import cpuinfo import numpy as np @@ -27,10 +27,8 @@ from onnx_neural_compressor import constants, logger -from typing import Callable, Dict, List, Tuple, Union # isort: skip - # Dictionary to store a mapping between algorithm names and corresponding algo implementation(function) -algos_mapping: Dict[str, Callable] = {} +algos_mapping: dict[str, Callable] = {} ####################################################### @@ -38,33 +36,34 @@ ####################################################### -def check_value(name, src, supported_type, supported_value=[]): +def check_value(name, src, supported_type, supported_value=None): """Check if the given object is the given supported type and in the given supported value. Example:: from onnx_neural_compressor import utility + def datatype(self, datatype): if utility.check_value("datatype", datatype, list, ["fp32", "bf16", "uint8", "int8"]): self._datatype = datatype """ - if isinstance(src, list) and any([not isinstance(i, supported_type) for i in src]): - assert False, "Type of {} items should be {} but not {}".format( - name, str(supported_type), [type(i) for i in src] - ) + if supported_value is None: + supported_value = [] + if isinstance(src, list) and any(not isinstance(i, supported_type) for i in src): + raise AssertionError(f"Type of {name} items should be {supported_type!s} but not {[type(i) for i in src]}") elif not isinstance(src, list) and not isinstance(src, supported_type): - assert False, "Type of {} should be {} but not {}".format(name, str(supported_type), type(src)) + raise AssertionError(f"Type of {name} should be {supported_type!s} but not {type(src)}") if len(supported_value) > 0: if isinstance(src, str) and src not in supported_value: - assert False, "{} is not in supported {}: {}. Skip setting it.".format(src, name, str(supported_value)) + raise AssertionError(f"{src} is not in supported {name}: {supported_value!s}. Skip setting it.") elif ( isinstance(src, list) - and all([isinstance(i, str) for i in src]) - and any([i not in supported_value for i in src]) + and all(isinstance(i, str) for i in src) + and any(i not in supported_value for i in src) ): - assert False, "{} is not in supported {}: {}. Skip setting it.".format(src, name, str(supported_value)) + raise AssertionError(f"{src} is not in supported {name}: {supported_value!s}. 
Skip setting it.") return True @@ -94,6 +93,7 @@ class Options: from onnx_neural_compressor import set_random_seed from onnx_neural_compressor import set_workspace from onnx_neural_compressor import set_resume_from + set_random_seed(2022) set_workspace("workspace_path") set_resume_from("workspace_path") @@ -153,7 +153,7 @@ def tuning_start(cls) -> None: logger.info("Tuning started.") @classmethod - def trial_start(cls, trial_index: int = None) -> None: + def trial_start(cls, trial_index: int | None = None) -> None: logger.info("%d-trail started.", trial_index) @classmethod @@ -173,7 +173,7 @@ def evaluation_end(cls) -> None: logger.info("Evaluation end.") @classmethod - def trial_end(cls, trial_index: int = None) -> None: + def trial_end(cls, trial_index: int | None = None) -> None: logger.info("%d-trail end.", trial_index) @classmethod @@ -183,7 +183,6 @@ def tuning_end(cls) -> None: def singleton(cls): """Singleton decorator.""" - instances = {} def _singleton(*args, **kw): @@ -195,7 +194,7 @@ def _singleton(*args, **kw): return _singleton -class LazyImport(object): +class LazyImport: """Lazy import python module till use.""" def __init__(self, module_name): @@ -212,7 +211,7 @@ def __getattr__(self, name): try: self.module = importlib.import_module(self.module_name) mod = getattr(self.module, name) - except: + except: # noqa: E722 spec = importlib.util.find_spec(str(self.module_name + "." + name)) mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) @@ -228,7 +227,7 @@ def __call__(self, *args, **kwargs): @singleton -class CpuInfo(object): +class CpuInfo: """CPU info collection.""" def __init__(self): @@ -241,13 +240,13 @@ def __init__(self): max_extension_support = cpuid.get_max_extension_support() if max_extension_support >= 7: ecx = cpuid._run_asm( - b"\x31\xC9", # xor ecx, ecx - b"\xB8\x07\x00\x00\x00" b"\x0f\xa2" b"\x89\xC8" b"\xC3", # mov eax, 7 # cpuid # mov ax, cx # ret + b"\x31\xc9", # xor ecx, ecx + b"\xb8\x07\x00\x00\x00\x0f\xa2\x89\xc8\xc3", # mov eax, 7 # cpuid # mov ax, cx # ret ) self._vnni = bool(ecx & (1 << 11)) eax = cpuid._run_asm( - b"\xB9\x01\x00\x00\x00", # mov ecx, 1 - b"\xB8\x07\x00\x00\x00" b"\x0f\xa2" b"\xC3", # mov eax, 7 # cpuid # ret + b"\xb9\x01\x00\x00\x00", # mov ecx, 1 + b"\xb8\x07\x00\x00\x00\x0f\xa2\xc3", # mov eax, 7 # cpuid # ret ) self._bf16 = bool(eax & (1 << 5)) # TODO: The implementation will be refined in the future. 
@@ -304,14 +303,12 @@ def dump_elapsed_time(customized_msg=""): """ def f(func): - def fi(*args, **kwargs): start = time.time() res = func(*args, **kwargs) end = time.time() logger.info( - "%s elapsed time: %s ms" - % (customized_msg if customized_msg else func.__qualname__, round((end - start) * 1000, 2)) + f"{customized_msg if customized_msg else func.__qualname__} elapsed time: {round((end - start) * 1000, 2)} ms" ) return res @@ -377,7 +374,7 @@ def find_by_name(name, item_list): """Helper function to find item by name in a list.""" items = [] for item in item_list: - assert hasattr(item, "name"), "{} should have a 'name' attribute defined".format(item) # pragma: no cover + assert hasattr(item, "name"), f"{item} should have a 'name' attribute defined" # pragma: no cover if item.name == name: items.append(item) if len(items) > 0: @@ -420,8 +417,8 @@ def decorator(algo_func): def get_model_info( - model: Union[onnx.ModelProto, pathlib.Path, str], white_op_type_list: List[Callable] -) -> List[Tuple[str, Callable]]: + model: onnx.ModelProto | pathlib.Path | str, white_op_type_list: list[Callable] +) -> list[tuple[str, Callable]]: if not isinstance(model, onnx.ModelProto): model = onnx.load(model) filter_result = [] @@ -436,15 +433,15 @@ def get_model_info( return filter_result -def is_B_transposed(node): +def is_B_transposed(node): # noqa: N802 """Whether inuput B is transposed.""" - transB = [attr for attr in node.attribute if attr.name == "transB"] + transB = [attr for attr in node.attribute if attr.name == "transB"] # noqa: N806 if len(transB): - return 0 < onnx.helper.get_attribute_value(transB[0]) + return onnx.helper.get_attribute_value(transB[0]) > 0 return False -def get_qrange_for_qType(qType, reduce_range=False): +def get_qrange_for_qType(qType, reduce_range=False): # noqa: N802, N803 """Helper function to get the quantization range for a type. Args: @@ -460,7 +457,7 @@ def get_qrange_for_qType(qType, reduce_range=False): raise ValueError("unsupported quantization data type") -def _quantize_data_with_scale_zero(data, qType, scheme, scale, zero_point): +def _quantize_data_with_scale_zero(data, qType, scheme, scale, zero_point): # noqa: N803 """Quantize data with scale and zero point. To pack weights, we compute a linear transformation @@ -482,11 +479,11 @@ def _quantize_data_with_scale_zero(data, qType, scheme, scale, zero_point): elif qType == onnx.onnx_pb.TensorProto.UINT8 and scheme == "asym": quantized_data = ((data.astype(np.float32) / scale).round() + zero_point).astype("B") else: - raise ValueError("Unexpected combination of data type {} and scheme {}.".format(qType, scheme)) + raise ValueError(f"Unexpected combination of data type {qType} and scheme {scheme}.") return quantized_data -def _calculate_scale_zp(rmin, rmax, quantize_range, qType, scheme): +def _calculate_scale_zp(rmin, rmax, quantize_range, qType, scheme): # noqa: N803 """Calculate scale and zero point.""" if isinstance(rmax, np.ndarray): if scheme == "sym": @@ -536,7 +533,7 @@ def _calculate_scale_zp(rmin, rmax, quantize_range, qType, scheme): return scale, zero_point -def quantize_data(data, quantize_range, qType, scheme): +def quantize_data(data, quantize_range, qType, scheme): # noqa: N803 """Quantize data. To pack weights, we compute a linear transformation @@ -556,8 +553,8 @@ def quantize_data(data, quantize_range, qType, scheme): qType (int): data type to quantize to. Supported types UINT8 and INT8 scheme (string): sym or asym quantization. 
""" - rmin = min(min(data), 0) - rmax = max(max(data), 0) + rmin = min(*data, 0) + rmax = max(*data, 0) scale, zero_point = _calculate_scale_zp(rmin, rmax, quantize_range, qType, scheme) quantized_data = _quantize_data_with_scale_zero(data, qType, scheme, scale, zero_point) diff --git a/onnx_neural_compressor/version.py b/onnx_neural_compressor/version.py index aa0978f16..1b8307cc2 100644 --- a/onnx_neural_compressor/version.py +++ b/onnx_neural_compressor/version.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Copyright (c) 2021 Intel Corporation # diff --git a/pyproject.toml b/pyproject.toml index 9d46c3db1..59783b887 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,53 +50,67 @@ indent-width = 4 # Assume Python 3.8 target-version = "py38" +unsafe-fixes = true [tool.ruff.lint] -# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. -# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or -# McCabe complexity (`C901`) by default. -select = ["E4", "E7", "E9", "F"] +select = [ + "A", # flake8-builtins + "ARG", # flake8-unused-arguments + "B", # flake8-bugbear + "BLE", # flake8-blind-except + "C4", # flake8-comprehensions + "D", # pydocstyle + "E", # pycodestyle + "F", # Pyflakes + "FA", # flake8-future-annotations + "G", # flake8-logging-format + "I002", # isort: required imports + "ISC", # flake8-implicit-str-concat + "LOG", # flake8-logging + "N", # pep8-naming + "NPY", # modern numpy + "PERF", # Perflint + "PIE", # flake8-pie + "PL", # pylint + "PYI", # flake8-pyi + "RUF", # Ruff-specific rules + "SIM", # flake8-simplify + "SLOT", # flake8-slot + "T10", # flake8-debugger + "TID", # Disallow relative imports + "TRY", # flake8-try-except-raise + "UP", # pyupgrade + "W", # pycodestyle + "YTT", # flake8-2020 +] +# NOTE: Refrain from growing the ignore list unless for exceptional cases. +# Always include a comment to explain why. ignore = [ - "E402", # Module level import not at top of file - "E501", # Line too long (121 > 120 characters) - "E721", # Do not compare types, use isinstance() - "E722", # Do not use bare except - "E731", # Do not assign a lambda expression, use a def - "E741", # Do not use variables named ‘l’, ‘O’, or ‘I’ - "F401", # {name} imported but unused - "F403", # from {name} import * used; unable to detect undefined names - "F405", # {name} may be undefined, or defined from star imports - "F841", # Local variable is assigned to but never used{name} + "D1", # D1 is for missing docstrings, which is not yet enforced. + "E501", # Line length controlled by formatter + "NPY002", # np.random.Generator may not be preferred in all cases + "PLR09", # Ignore the pylint "too-many-*" rules + "PLR2004", # Magic numbers + "PYI011", # Allow protobuf enums as defaults to function arguments + "PYI021", # Allow docstrings in pyi files + "PYI041", # int | float is sometimes more clear than float + "RUF015", # next(iter(...)) sometimes obscures the intent when we access the 0th element of a shape + "SIM102", # We don't perfer always combining if branches + "SIM108", # We don't always encourage ternary operators + "SIM114", # Don't always combine if branches for debugability + "SIM116", # Don't use dict lookup to replace if-else + "TRY003", # Messages can be constructed in the exception ] -# Allow fix for all enabled rules (when `--fix`) is provided. -fixable = ["ALL"] -unfixable = [] - -# Allow unused variables when underscore-prefixed. 
-dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" - ignore-init-module-imports = true -[tool.ruff.format] -# Like Black, use double quotes for strings. -quote-style = "double" - -# Like Black, indent with spaces, rather than tabs. -indent-style = "space" - -# Like Black, respect magic trailing commas. -skip-magic-trailing-comma = false - -# Like Black, automatically detect the appropriate line ending. -line-ending = "auto" +[tool.ruff.lint.pydocstyle] +convention = "google" +[tool.ruff.format] # Enable auto-formatting of code examples in docstrings. Markdown, # reStructuredText code/literal blocks and doctests are all supported. -# -# This is currently disabled by default, but it is planned for this -# to be opt-out in the future. -docstring-code-format = false +docstring-code-format = true # Set the line length limit used when formatting code snippets in # docstrings. diff --git a/setup.py b/setup.py index c80178535..74d412370 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,3 @@ -import io import re import subprocess @@ -10,8 +9,7 @@ def is_commit_on_tag(): result = subprocess.run( ["git", "describe", "--exact-match", "--tags"], capture_output=True, text=True, check=True ) - tag_name = result.stdout.strip() - return tag_name + return result.stdout.strip() except subprocess.CalledProcessError: return False @@ -22,17 +20,17 @@ def get_build_version(): try: result = subprocess.run(["git", "describe", "--tags"], capture_output=True, text=True, check=True) _, distance, commit = result.stdout.strip().split("-") - return f"{__version__}.dev{distance}+{commit}" except subprocess.CalledProcessError: return __version__ + return f"{__version__}.dev{distance}+{commit}" try: filepath = "./onnx_neural_compressor/version.py" - with io.open(filepath) as version_file: + with open(filepath) as version_file: (__version__,) = re.findall('__version__ = "(.*)"', version_file.read()) except Exception as error: - assert False, "Error: Could not open '%s' due %s\n" % (filepath, error) + raise AssertionError(f"Error: Could not open '{filepath}'") from error if __name__ == "__main__": @@ -42,7 +40,7 @@ def get_build_version(): version=get_build_version(), author_email="tai.huang@intel.com, mengni.wang@intel.com, yuwen.zhou@intel.com, suyue.chen@intel.com", description="Repository of Neural Compressor ORT", - long_description=io.open("README.md", "r", encoding="utf-8").read(), + long_description=open("README.md", encoding="utf-8").read(), # noqa: SIM115 long_description_content_type="text/markdown", keywords="quantization", license="Apache 2.0", diff --git a/test/quantization/layer_wise/test_layer_wise.py b/test/quantization/layer_wise/test_layer_wise.py index af0bca3e4..69166110a 100644 --- a/test/quantization/layer_wise/test_layer_wise.py +++ b/test/quantization/layer_wise/test_layer_wise.py @@ -5,9 +5,9 @@ import onnx import onnxruntime as ort -import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer import torch import transformers +from onnxruntime.tools import symbolic_shape_infer from optimum.exporters.onnx import main_export from onnx_neural_compressor import config, data_reader, logger @@ -17,7 +17,7 @@ def find_onnx_file(folder_path): # return first .onnx file path in folder_path - for root, dirs, files in os.walk(folder_path): + for root, _dirs, files in os.walk(folder_path): for file in files: if file.endswith(".onnx"): return os.path.join(root, file) diff --git a/test/quantization/test_autotune.py b/test/quantization/test_autotune.py index 0e86c64b9..bf66c57c9 100644 --- 
a/test/quantization/test_autotune.py +++ b/test/quantization/test_autotune.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Copyright (c) 2023 Intel Corporation # @@ -13,12 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import functools import glob import os import shutil import unittest +from typing import Callable from unittest import mock import numpy as np @@ -29,15 +30,13 @@ from onnx_neural_compressor import config, data_reader from onnx_neural_compressor.quantization import tuning -from typing import Callable, Dict, List, Optional, Union # isort: skip - -def fake_eval(model, eval_result_lst): +def fake_eval(model, eval_result_lst): # noqa: ARG001 acc = eval_result_lst.pop(0) return acc -def _create_evaluator_for_eval_fns(eval_fns: Optional[Union[Callable, Dict, List[Dict]]] = None) -> tuning.Evaluator: +def _create_evaluator_for_eval_fns(eval_fns: Callable | dict | list[dict] | None = None) -> tuning.Evaluator: evaluator = tuning.Evaluator() evaluator.set_eval_fn_registry(eval_fns) return evaluator @@ -91,14 +90,14 @@ def test_auto_tune_warning(self, mock_warning): acc_data = iter([1.0, 0.8, 0.99, 1.0, 0.99, 0.99]) def eval_acc_fn(model) -> float: - session = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"]) + ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"]) return next(acc_data) custom_tune_config = tuning.TuningConfig( config_set=[config.SmoothQuantConfig(alpha=0.5), config.SmoothQuantConfig(alpha=0.6)] ) with self.assertRaises(SystemExit): - best_model = tuning.autotune( + tuning.autotune( model_input=self.gptj, tune_config=custom_tune_config, eval_fn=eval_acc_fn, @@ -113,12 +112,12 @@ def eval_acc_fn(model) -> float: def test_sq_auto_tune(self): acc_data = iter([1.0, 0.8, 0.99, 1.0, 0.99, 0.99]) - def eval_acc_fn(model) -> float: + def eval_acc_fn(model) -> float: # noqa: ARG001 return next(acc_data) perf_data = iter([1.0, 0.9, 0.99]) - def eval_perf_fn(model) -> float: + def eval_perf_fn(model) -> float: # noqa: ARG001 return next(perf_data) eval_fns = [ @@ -159,12 +158,12 @@ def eval_fn_wrapper(model): def test_rtn_auto_tune(self): acc_data = iter([1.0, 0.8, 0.6, 1.0, 0.99, 0.9]) - def eval_acc_fn(model) -> float: + def eval_acc_fn(model) -> float: # noqa: ARG001 return next(acc_data) perf_data = iter([1.0, 0.99, 0.99]) - def eval_perf_fn(model) -> float: + def eval_perf_fn(model) -> float: # noqa: ARG001 return next(perf_data) eval_fns = [ @@ -204,19 +203,19 @@ def eval_fn_wrapper(model): op_names = [ i.name for i in best_model.graph.node - if i.op_type.startswith("MatMul") and i.input[1].endswith("_Q{}G{}".format(4, 32)) + if i.op_type.startswith("MatMul") and i.input[1].endswith(f"_Q{4}G{32}") ] self.assertTrue(len(op_names) > 0) def test_awq_auto_tune(self): acc_data = iter([1.0, 0.8, 0.6, 1.0, 0.99, 0.9]) - def eval_acc_fn(model) -> float: + def eval_acc_fn(model) -> float: # noqa: ARG001 return next(acc_data) perf_data = iter([1.0, 0.99, 0.99]) - def eval_perf_fn(model) -> float: + def eval_perf_fn(model) -> float: # noqa: ARG001 return next(perf_data) eval_fns = [ @@ -256,19 +255,19 @@ def eval_fn_wrapper(model): op_names = [ i.name for i in best_model.graph.node - if i.op_type.startswith("MatMul") and i.input[1].endswith("_Q{}G{}".format(4, 32)) + if i.op_type.startswith("MatMul") and i.input[1].endswith(f"_Q{4}G{32}") ] 
self.assertTrue(len(op_names) > 0) def test_gptq_auto_tune(self): acc_data = iter([1.0, 0.8, 0.6, 1.0, 0.99, 0.9]) - def eval_acc_fn(model) -> float: + def eval_acc_fn(model) -> float: # noqa: ARG001 return next(acc_data) perf_data = iter([1.0, 0.99, 0.99]) - def eval_perf_fn(model) -> float: + def eval_perf_fn(model) -> float: # noqa: ARG001 return next(perf_data) eval_fns = [ @@ -307,7 +306,7 @@ def eval_fn_wrapper(model): op_names = [ i.name for i in best_model.graph.node - if i.op_type.startswith("MatMul") and i.input[1].endswith("_Q{}G{}".format(4, 32)) + if i.op_type.startswith("MatMul") and i.input[1].endswith(f"_Q{4}G{32}") ] self.assertTrue(len(op_names) > 0) @@ -327,7 +326,7 @@ def test_woq_auto_tune(self): op_names = [ i.name for i in best_model.graph.node - if i.op_type.startswith("MatMul") and i.input[1].endswith("_Q{}G{}".format(8, 32)) + if i.op_type.startswith("MatMul") and i.input[1].endswith(f"_Q{8}G{32}") ] self.assertTrue(len(op_names) > 0) @@ -347,7 +346,7 @@ def test_woq_auto_tune(self): [ i.name for i in best_model.graph.node - if i.op_type.startswith("MatMul") and i.input[1].endswith("_Q{}G{}".format(4, 32)) + if i.op_type.startswith("MatMul") and i.input[1].endswith(f"_Q{4}G{32}") ] ) + 1, @@ -366,7 +365,7 @@ def test_woq_auto_tune(self): op_names = [ i.name for i in best_model.graph.node - if i.op_type.startswith("MatMul") and i.input[1].endswith("_Q{}G{}".format(4, 128)) + if i.op_type.startswith("MatMul") and i.input[1].endswith(f"_Q{4}G{128}") ] self.assertTrue(len(op_names) > 0) diff --git a/test/quantization/test_config.py b/test/quantization/test_config.py index 50ffc74d0..16620c233 100644 --- a/test/quantization/test_config.py +++ b/test/quantization/test_config.py @@ -1,4 +1,3 @@ -import copy import os import shutil import unittest @@ -13,7 +12,7 @@ def find_onnx_file(folder_path): # return first .onnx file path in folder_path - for root, dirs, files in os.walk(folder_path): + for root, _dirs, files in os.walk(folder_path): for file in files: if file.endswith(".onnx"): return os.path.join(root, file) @@ -21,25 +20,25 @@ def find_onnx_file(folder_path): def build_simple_onnx_model(): - A = onnx.helper.make_tensor_value_info("A", onnx.TensorProto.FLOAT, [1, 5, 5]) - C = onnx.helper.make_tensor_value_info("C", onnx.TensorProto.FLOAT, [1, 5, 2]) - D = onnx.helper.make_tensor_value_info("D", onnx.TensorProto.FLOAT, [1, 5, 2]) - H = onnx.helper.make_tensor_value_info("H", onnx.TensorProto.FLOAT, [1, 5, 2]) + A = onnx.helper.make_tensor_value_info("A", onnx.TensorProto.FLOAT, [1, 5, 5]) # noqa: N806 + onnx.helper.make_tensor_value_info("C", onnx.TensorProto.FLOAT, [1, 5, 2]) + onnx.helper.make_tensor_value_info("D", onnx.TensorProto.FLOAT, [1, 5, 2]) + H = onnx.helper.make_tensor_value_info("H", onnx.TensorProto.FLOAT, [1, 5, 2]) # noqa: N806 e_value = np.random.randint(2, size=(10)).astype(np.float32) - B_init = onnx.helper.make_tensor("B", onnx.TensorProto.FLOAT, [5, 2], e_value.reshape(10).tolist()) - E_init = onnx.helper.make_tensor("E", onnx.TensorProto.FLOAT, [1, 5, 2], e_value.reshape(10).tolist()) + B_init = onnx.helper.make_tensor("B", onnx.TensorProto.FLOAT, [5, 2], e_value.reshape(10).tolist()) # noqa: N806 + E_init = onnx.helper.make_tensor("E", onnx.TensorProto.FLOAT, [1, 5, 2], e_value.reshape(10).tolist()) # noqa: N806 matmul_node = onnx.helper.make_node("MatMul", ["A", "B"], ["C"], name="Matmul") add = onnx.helper.make_node("Add", ["C", "E"], ["D"], name="add") - f_value = np.random.randint(2, size=(10)).astype(np.float32) - F_init = 
onnx.helper.make_tensor("F", onnx.TensorProto.FLOAT, [1, 5, 2], e_value.reshape(10).tolist()) + np.random.randint(2, size=(10)).astype(np.float32) + F_init = onnx.helper.make_tensor("F", onnx.TensorProto.FLOAT, [1, 5, 2], e_value.reshape(10).tolist()) # noqa: N806 add2 = onnx.helper.make_node("Add", ["D", "F"], ["H"], name="add2") graph = onnx.helper.make_graph([matmul_node, add, add2], "test_graph_1", [A], [H], [B_init, E_init, F_init]) model = onnx.helper.make_model(graph) - model = onnx.helper.make_model(graph, **{"opset_imports": [onnx.helper.make_opsetid("", 13)]}) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 13)]) return model @@ -68,7 +67,7 @@ def setUp(self): def _check_node_is_quantized(self, model, node_name): for node in model.graph.node: - if (node.name == node_name or node.name == node_name + "_Q4") and node.op_type in [ + if (node.name in (node_name, node_name + "_Q4")) and node.op_type in [ "MatMulNBits", "MatMulFpQ4", ]: @@ -79,7 +78,7 @@ def _count_woq_matmul(self, q_model, bits=4, group_size=32): op_names = [ i.name for i in q_model.graph.node - if i.op_type.startswith("MatMul") and i.input[1].endswith("_Q{}G{}".format(bits, group_size)) + if i.op_type.startswith("MatMul") and i.input[1].endswith(f"_Q{bits}G{group_size}") ] return len(op_names) diff --git a/test/quantization/test_smooth_quant.py b/test/quantization/test_smooth_quant.py index fed59e142..ed9b49e78 100644 --- a/test/quantization/test_smooth_quant.py +++ b/test/quantization/test_smooth_quant.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Copyright (c) 2023 Intel Corporation # diff --git a/test/quantization/weight_only/test_awq.py b/test/quantization/weight_only/test_awq.py index 2d918cc61..191494c19 100644 --- a/test/quantization/weight_only/test_awq.py +++ b/test/quantization/weight_only/test_awq.py @@ -15,7 +15,7 @@ def find_onnx_file(folder_path): # return first .onnx file path in folder_path - for root, dirs, files in os.walk(folder_path): + for root, _dirs, files in os.walk(folder_path): for file in files: if file.endswith(".onnx"): return os.path.join(root, file) @@ -74,7 +74,7 @@ def _count_woq_matmul(self, q_model, bits=4, group_size=32): op_names = [ i.name for i in q_model.graph.node - if i.op_type.startswith("MatMul") and i.input[1].endswith("_Q{}G{}".format(bits, group_size)) + if i.op_type.startswith("MatMul") and i.input[1].endswith(f"_Q{bits}G{group_size}") ] return len(op_names) @@ -84,7 +84,7 @@ def _check_model_is_quantized(self, model): def _check_node_is_quantized(self, model, node_name): for node in model.graph.node: - if (node.name == node_name or node.name == node_name + "_Q4") and node.op_type in [ + if (node.name in (node_name, node_name + "_Q4")) and node.op_type in [ "MatMulNBits", "MatMulFpQ4", ]: diff --git a/test/quantization/weight_only/test_gptq.py b/test/quantization/weight_only/test_gptq.py index 133e11fd1..6c95df623 100644 --- a/test/quantization/weight_only/test_gptq.py +++ b/test/quantization/weight_only/test_gptq.py @@ -15,7 +15,7 @@ def find_onnx_file(folder_path): # return first .onnx file path in folder_path - for root, dirs, files in os.walk(folder_path): + for root, _dirs, files in os.walk(folder_path): for file in files: if file.endswith(".onnx"): return os.path.join(root, file) @@ -74,7 +74,7 @@ def _count_woq_matmul(self, q_model, bits=4, group_size=32): op_names = [ i.name for i in q_model.graph.node - if i.op_type.startswith("MatMul") and i.input[1].endswith("_Q{}G{}".format(bits, group_size)) + if 
i.op_type.startswith("MatMul") and i.input[1].endswith(f"_Q{bits}G{group_size}") ] return len(op_names) @@ -84,7 +84,7 @@ def _check_model_is_quantized(self, model): def _check_node_is_quantized(self, model, node_name): for node in model.graph.node: - if (node.name == node_name or node.name == node_name + "_Q4") and node.op_type in [ + if (node.name in (node_name, node_name + "_Q4")) and node.op_type in [ "MatMulNBits", "MatMulFpQ4", ]: diff --git a/test/quantization/weight_only/test_rtn.py b/test/quantization/weight_only/test_rtn.py index 86b3c49a3..51b59f026 100644 --- a/test/quantization/weight_only/test_rtn.py +++ b/test/quantization/weight_only/test_rtn.py @@ -13,7 +13,7 @@ def find_onnx_file(folder_path): # return first .onnx file path in folder_path - for root, dirs, files in os.walk(folder_path): + for root, _dirs, files in os.walk(folder_path): for file in files: if file.endswith(".onnx"): return os.path.join(root, file) @@ -44,7 +44,7 @@ def _check_model_is_quantized(self, model): def _check_node_is_quantized(self, model, node_name): for node in model.graph.node: - if (node.name == node_name or node.name == node_name + "_Q4") and node.op_type in [ + if (node.name in (node_name, node_name + "_Q4")) and node.op_type in [ "MatMulNBits", "MatMulFpQ4", ]: @@ -55,7 +55,7 @@ def _count_woq_matmul(self, q_model, bits=4, group_size=32): op_names = [ i.name for i in q_model.graph.node - if i.op_type.startswith("MatMul") and i.input[1].endswith("_Q{}G{}".format(bits, group_size)) + if i.op_type.startswith("MatMul") and i.input[1].endswith(f"_Q{bits}G{group_size}") ] return len(op_names) diff --git a/test/utils/test_general.py b/test/utils/test_general.py index d24392438..702fe43ee 100644 --- a/test/utils/test_general.py +++ b/test/utils/test_general.py @@ -1,13 +1,13 @@ """Tests for general components.""" +from __future__ import annotations + import unittest +from typing import Any, Callable, List from onnx_neural_compressor import config, constants, logger from onnx_neural_compressor.quantization import tuning -from typing import Any, Callable, List, Optional, Tuple, Union # isort: skip - - PRIORITY_FAKE_ALGO = 100 FAKE_CONFIG_NAME = "fake" PRIORITY_FAKE_ALGO_1 = 90 @@ -33,20 +33,20 @@ def __repr__(self) -> str: class FakeAlgoConfig(config.BaseConfig): """Config class for fake algo.""" - supported_configs: List = [] - params_list = [ + supported_configs: list = [] # noqa: RUF012 + params_list = [ # noqa: RUF012 "weight_dtype", "weight_bits", config.TuningParam("target_op_type_list", tunable_type=List[List[str]]), ] name = FAKE_CONFIG_NAME - def __init__( + def __init__( # noqa: D417 self, weight_dtype: str = "int", weight_bits: int = 4, - target_op_type_list: List[str] = ["Conv", "Gemm"], - white_list: Optional[List[Union[str, Callable]]] = constants.DEFAULT_WHITE_LIST, + target_op_type_list: list[str] | None = None, + white_list: list[str | Callable] | None = constants.DEFAULT_WHITE_LIST, ): """Init fake config. @@ -54,6 +54,8 @@ def __init__( weight_dtype (str): Data type for weights, default is "int". weight_bits (int): Number of bits used to represent weights, default is 4. 
""" + if target_op_type_list is None: + target_op_type_list = ["Conv", "Gemm"] super().__init__(white_list=white_list) self.weight_bits = weight_bits self.weight_dtype = weight_dtype @@ -65,18 +67,18 @@ def to_dict(self): @classmethod def from_dict(cls, config_dict): - return super(FakeAlgoConfig, cls).from_dict(config_dict=config_dict) + return super().from_dict(config_dict=config_dict) @classmethod - def register_supported_configs(cls) -> List: + def register_supported_configs(cls) -> list: pass @staticmethod - def get_model_info(model: Any) -> List[Tuple[str, Any]]: + def get_model_info(model: Any) -> list[tuple[str, Any]]: # noqa: ARG004 return FAKE_MODEL_INFO @classmethod - def get_config_set_for_tuning(cls) -> Union[None, "FakeAlgoConfig", List["FakeAlgoConfig"]]: + def get_config_set_for_tuning(cls) -> None | FakeAlgoConfig | list[FakeAlgoConfig]: return FakeAlgoConfig(weight_bits=DEFAULT_WEIGHT_BITS) @@ -93,20 +95,20 @@ def get_default_fake_config() -> FakeAlgoConfig: class FakeAlgoOneConfig(config.BaseConfig): """Config class for fake algo.""" - supported_configs: List = [] - params_list = [ + supported_configs: list = [] # noqa: RUF012 + params_list = [ # noqa: RUF012 "weight_dtype", "weight_bits", config.TuningParam("target_op_type_list", tunable_type=List[List[str]]), ] name = FAKE_CONFIG_NAME_1 - def __init__( + def __init__( # noqa: D417 self, weight_dtype: str = "int", weight_bits: int = 4, - target_op_type_list: List[str] = ["Conv", "Gemm"], - white_list: Optional[List[Union[str, Callable]]] = constants.DEFAULT_WHITE_LIST, + target_op_type_list: list[str] | None = None, + white_list: list[str | Callable] | None = constants.DEFAULT_WHITE_LIST, ): """Init fake config. @@ -114,6 +116,8 @@ def __init__( weight_dtype (str): Data type for weights, default is "int". weight_bits (int): Number of bits used to represent weights, default is 4. 
""" + if target_op_type_list is None: + target_op_type_list = ["Conv", "Gemm"] super().__init__(white_list=white_list) self.weight_bits = weight_bits self.weight_dtype = weight_dtype @@ -125,22 +129,22 @@ def to_dict(self): @classmethod def from_dict(cls, config_dict): - return super(FakeAlgoOneConfig, cls).from_dict(config_dict=config_dict) + return super().from_dict(config_dict=config_dict) @classmethod - def register_supported_configs(cls) -> List: + def register_supported_configs(cls) -> list: pass @staticmethod - def get_model_info(model: Any) -> List[Tuple[str, Any]]: + def get_model_info(model: Any) -> list[tuple[str, Any]]: # noqa: ARG004 return FAKE_MODEL_INFO @classmethod - def get_config_set_for_tuning(cls) -> Union[None, "FakeAlgoOneConfig", List["FakeAlgoOneConfig"]]: + def get_config_set_for_tuning(cls) -> None | FakeAlgoOneConfig | list[FakeAlgoOneConfig]: return FakeAlgoOneConfig(weight_bits=DEFAULT_WEIGHT_BITS) -def get_all_config_set() -> Union[config.BaseConfig, List[config.BaseConfig]]: +def get_all_config_set() -> config.BaseConfig | list[config.BaseConfig]: return config.get_all_config_set_from_config_registry() @@ -151,7 +155,7 @@ class TestEvaluator(unittest.TestCase): def test_single_eval_fn(self): - def fake_eval_fn(model): + def fake_eval_fn(model): # noqa: ARG001 return 1.0 evaluator = tuning.Evaluator() @@ -162,7 +166,7 @@ def fake_eval_fn(model): def test_single_eval_fn_dict(self): acc_data = iter([1.0, 0.8, 0.99, 1.0, 0.99, 0.99]) - def eval_acc_fn(model) -> float: + def eval_acc_fn(model) -> float: # noqa: ARG001 return next(acc_data) eval_fns = {"eval_fn": eval_acc_fn, "weight": 0.5, "name": "accuracy"} @@ -204,8 +208,8 @@ def test_config_expand_complex_tunable_type(self): def test_mixed_two_algos(self): model = FakeModel() - OP1_NAME = "OP1_NAME" - OP2_NAME = "OP2_NAME" + OP1_NAME = "OP1_NAME" # noqa: N806 + OP2_NAME = "OP2_NAME" # noqa: N806 fake_config = FakeAlgoConfig(weight_bits=4, white_list=[OP1_NAME]) fake1_config = FakeAlgoOneConfig(weight_bits=2, white_list=[OP2_NAME]) mixed_config = fake_config + fake1_config