From f95c0773a129a4605b2161f5f9fddb8116c948d0 Mon Sep 17 00:00:00 2001
From: Yi Zhang
Date: Wed, 28 Feb 2024 10:40:40 +0800
Subject: [PATCH 1/4] Add shared memory flag in docker (#19672)

### Description
Add shared-memory related flags (--shm-size, --ipc=host, --ulimit memlock/stack) to the docker run command used by the Linux GPU CI pipeline.

### Motivation and Context
Ref: https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem

Co-authored-by: Your Name
---
 tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index 822bc559d992d..165bd804a8ad5 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -241,7 +241,7 @@ stages: script: | set -e -x mkdir -p $HOME/.onnx - docker run --gpus all --rm \ + docker run --gpus all --shm-size=1g --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --rm \ --volume $(Build.SourcesDirectory):/onnxruntime_src \ --volume $(Build.BinariesDirectory)/Release:/build/Release \ --volume /data/models:/build/models:ro \

From 026e3178ae71cfcc5cfa2decde9a7d64b935d255 Mon Sep 17 00:00:00 2001
From: pengwa
Date: Wed, 28 Feb 2024 15:57:05 +0800
Subject: [PATCH 2/4] Improve memory matrix for ORTModule (#19620)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### Memory matrix for ORTModule

Also collect parameter/gradient/buffer sizes. The logging is exposed as a standalone function (`log_memory_usage`) so it can be used externally for debugging purposes.

```
2024-02-27 07:18:55,283 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 816 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,322 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 816 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,358 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 816 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,438 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
 0%|▏ | 2/3200 [01:27<32:05:11, 36.12s/it]2024-02-27 07:18:55,498 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,537 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,576 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,657 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
 0%|▏ | 3/3200 [01:27<17:30:57, 19.72s/it]2024-02-27 07:18:55,711 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,750 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,786 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,867 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
[2024-02-27 07:18:55,886] [INFO] [loss_scaler.py:190:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 65536, but hysteresis is 2. Reducing hysteresis to 1
 0%|▎ | 4/3200 [01:28<10:39:52, 12.01s/it]2024-02-27 07:18:55,902 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,944 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,979 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,060 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
 0%|▍ | 5/3200 [01:28<6:53:04, 7.76s/it]2024-02-27 07:18:56,115 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,154 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,190 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,270 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
 0%|▍ | 6/3200 [01:28<4:36:19, 5.19s/it]2024-02-27 07:18:56,323 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,365 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,398 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,478 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
 0%|▌ | 7/3200 [01:28<3:09:33, 3.56s/it]2024-02-27 07:18:56,533 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,572 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,608 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,727 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
 0%|▌ | 8/3200 [01:28<2:13:48, 2.52s/it]2024-02-27 07:18:56,806 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,846 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,882 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,962 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
 0%|▋ | 9/3200 [01:29<1:36:03, 1.81s/it]2024-02-27 07:18:57,053 orttraining.rank-0 [INFO] - rank-0 step 9 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:57,094 orttraining.rank-0 [INFO] - rank-0 step 9 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
```
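For reference, a minimal sketch of calling the new helper outside of ORTModule; the module and logger below are placeholders, and a CUDA device is assumed:

```python
import logging

import torch

from onnxruntime.training.utils import log_memory_usage

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("orttraining")

# Placeholder module; any torch.nn.Module already placed on the CUDA device works.
model = torch.nn.Linear(1024, 1024).to("cuda")

# Logs allocated/cached/inactive CUDA memory plus parameter/gradient/buffer sizes (rank 0 only).
log_memory_usage("my_phase", rank_0_only=True, step_info="step 0", logger=logger, module=model)
```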
---
 .../training/ortmodule/_runtime_inspector.py | 37 +++------
 .../python/training/utils/__init__.py | 2 +
 .../training/utils/torch_profile_utils.py | 76 +++++++++++++++++++
 3 files changed, 88 insertions(+), 27 deletions(-)

diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py index 078ce4d27cd6f..772b9bd9e31ae 100644 --- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py +++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py @@ -14,7 +14,7 @@ from sympy import Symbol, simplify from sympy.parsing.sympy_parser import parse_expr -from onnxruntime.training.utils import PTable +from onnxruntime.training.utils import PTable, log_memory_usage from ._execution_agent import TrainingAgent from .options import _MemoryOptimizationLevel, _RuntimeOptions @@ -509,6 +509,8 @@ def __init__(self, m: torch.nn.Module, logger: Logger): self._is_first_inspect = True + self._m = m + def is_enabled(self) -> bool: """Check if memory inspector is enabled.""" return self._is_enabled @@ -621,29 +623,13 @@ def inspect_memory(self, cur_phase: Phase): need_print = self._current_step < 10 or (self._current_step & (self._current_step - 1) == 0) if need_print: - cur_mem_allocated = self._normalize(torch.cuda.memory_allocated()) - max_mem_allocated = self._normalize(torch.cuda.max_memory_allocated()) - cur_mem_cached = self._normalize(torch.cuda.memory_reserved()) - max_mem_cached = self._normalize(torch.cuda.max_memory_reserved()) - torch_mem_stat = torch.cuda.memory_stats() - cur_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0)) - max_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0)) - - mem_stats = [ - ["phase", _convert_phase_to_string(cur_phase)], - ["allocated", cur_mem_allocated], # current memory allocated for tensors - ["max allocated", max_mem_allocated], # peak memory allocated for tensors - ["cached", cur_mem_cached], # current memory cached for the caching allocator - ["max cached", max_mem_cached], # peak memory cached for caching allocator.
- ["inactive", cur_mem_inactive], # amount of inactive, non-releasable memory - ["max inactive", max_mem_inactive], # peak of inactive, non-releasable memory - ] - - summ = f"{self._rank_info} step {self._current_step} memory ({MemoryObserver.NORMALIZER_UNIT})" - for stat in mem_stats: - summ += f" | {stat[0]}: {stat[1]}" - - self._logger.info(summ) + log_memory_usage( + _convert_phase_to_string(cur_phase), + rank_0_only=True, + step_info=f"step {self._current_step}", + logger=self._logger, + module=self._m, + ) if cur_phase == self._last_phase: self._increase_step() @@ -655,9 +641,6 @@ def inspect_memory(self, cur_phase: Phase): def _increase_step(self): self._current_step += 1 - def _normalize(self, mem_size_in_bytes: Union[float, int]) -> str: - return f"{float(mem_size_in_bytes) / MemoryObserver.NORMALIZER_FACTOR:.0f}" - def display_memory_optimization_plans(self, memory_optimizer_config, details=False) -> Tuple[List[str], PTable]: mem_plan_count = len(self.cluster_id_combination_to_saving_symbolics_map) diff --git a/orttraining/orttraining/python/training/utils/__init__.py b/orttraining/orttraining/python/training/utils/__init__.py index b4a518d573998..ecfb7d7907f3c 100644 --- a/orttraining/orttraining/python/training/utils/__init__.py +++ b/orttraining/orttraining/python/training/utils/__init__.py @@ -12,6 +12,7 @@ unflatten_data_using_schema, ) from onnxruntime.training.utils.torch_profile_utils import ( + log_memory_usage, nvtx_function_decorator, torch_nvtx_range_pop, torch_nvtx_range_push, @@ -31,6 +32,7 @@ "torch_nvtx_range_push", "torch_nvtx_range_pop", "nvtx_function_decorator", + "log_memory_usage", "pytorch_type_to_onnx_dtype", "onnx_dtype_to_pytorch_dtype", "pytorch_scalar_type_to_pytorch_dtype", diff --git a/orttraining/orttraining/python/training/utils/torch_profile_utils.py b/orttraining/orttraining/python/training/utils/torch_profile_utils.py index 382d7dac142fe..9e8a41e0dc7c8 100644 --- a/orttraining/orttraining/python/training/utils/torch_profile_utils.py +++ b/orttraining/orttraining/python/training/utils/torch_profile_utils.py @@ -3,6 +3,8 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- +from __future__ import annotations + import torch @@ -26,3 +28,77 @@ def wrapped_fn(*args, **kwargs): return ret_val return wrapped_fn + + +def log_memory_usage(cur_phase: str, rank_0_only=True, step_info="", logger=None, module=None): + """Log memory usage for the current phase. + Args: + cur_phase (str): The current phase. + rank_0_only (bool, optional): Only log the memory usage for rank 0. Defaults to True. + step_info (str, optional): The step information. Defaults to "". + logger (logging.Logger, optional): The logger to log the memory usage. Defaults to None, which means print to stdout. + module (torch.nn.Module, optional): The module to get parameter, buffer and grad sizes. Defaults to None. 
+ """ + rank = 0 + if rank_0_only is True: + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + if rank != 0: + return + + _normalizer_factor = float(1024 * 1024) + _normalizer_unit = "MiB" + + def _normalize(mem_size_in_bytes: float | int) -> str: + return f"{float(mem_size_in_bytes) / _normalizer_factor:.0f}" + + cur_mem_allocated = _normalize(torch.cuda.memory_allocated()) + max_mem_allocated = _normalize(torch.cuda.max_memory_allocated()) + cur_mem_cached = _normalize(torch.cuda.memory_reserved()) + max_mem_cached = _normalize(torch.cuda.max_memory_reserved()) + torch_mem_stat = torch.cuda.memory_stats() + cur_mem_inactive = _normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0)) + max_mem_inactive = _normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0)) + + mem_stats = [ + ["phase", cur_phase], + ["allocated", cur_mem_allocated], # current memory allocated for tensors + ["max allocated", max_mem_allocated], # peak memory allocated for tensors + ["cached", cur_mem_cached], # current memory cached for the caching allocator + ["max cached", max_mem_cached], # peak memory cached for caching allocator. + ["inactive", cur_mem_inactive], # amount of inactive, non-releasable memory + ["max inactive", max_mem_inactive], # peak of inactive, non-releasable memory + ] + + # Calculate the total size of parameters and gradients in the model + if module: + param_total_size = 0 + grad_total_size = 0 + for p in module.parameters(): + if p.is_cuda: + param_total_size += p.numel() * p.element_size() + if p.grad is not None and p.grad.is_cuda: + grad_total_size += p.grad.numel() * p.grad.element_size() + + # Calculate the total size of buffers in the model + buffer_total_size = 0 + for b in module.buffers(): + if b.is_cuda: + buffer_total_size += b.numel() * b.element_size() + + mem_stats.extend( + [ + ["param", _normalize(param_total_size)], + ["grad", _normalize(grad_total_size)], + ["buffer", _normalize(buffer_total_size)], + ] + ) + + summ = f"rank-{rank} {step_info} memory ({_normalizer_unit})" + for stat in mem_stats: + summ += f" | {stat[0]}: {stat[1]}" + + if logger is None: + print(summ) + else: + logger.info(summ) From 7a147fc6f76a30b8d5875352afced515431ec7e5 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Wed, 28 Feb 2024 02:20:53 -0800 Subject: [PATCH 3/4] Remove a bash task from webgpu CI pipeline (#19682) ### Description It is a "Bash" task that requires running bash on Windows. Most Windows operating systems do not have Bash installed. Given this task is only debugging purposes, we can remove it for now. ### Motivation and Context I am making this change because I am regenerating the VM image in a different manner, and the new image does not contain bash. Once this PR is in, I can switch the images. 
---
 .../github/azure-pipelines/templates/win-web-ci.yml | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index 8ba3517530edd..043da233cc674 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -155,12 +155,7 @@ jobs: path: $(Build.SourcesDirectory)/js/test/ cacheHitVar: CACHE_RESTORED displayName: 'Cache ONNX node test data' - - task: Bash@3 - inputs: - targetType: 'inline' - script: find "$(Build.SourcesDirectory)/js/test/" -type f - condition: and(not(canceled()), eq(variables.CACHE_RESTORED, 'true')) - displayName: 'List ONNX node test data' + - task: PowerShell@2 inputs: filePath: '$(Build.SourcesDirectory)\tools\ci_build\github\js\pack-npm-packages.ps1'

From 913bdc7306e11b65644f733861684a3a460e8db0 Mon Sep 17 00:00:00 2001
From: Adrian Lizarraga
Date: Wed, 28 Feb 2024 08:30:12 -0800
Subject: [PATCH 4/4] [QNN Quant] Handle external data for QNN preprocessing/quant (#19670)

### Description
- Adds parameters to `qnn_preprocess_model()` to allow saving the new model with external data.
- Updates `get_qnn_qdq_config()` to:
  - Load the model without external data (it is not needed).
  - Return a quantization configuration with `use_external_data_format` set to `True` if the model has external data or if the model is >= 2GB.

### Motivation and Context
Update QNN quantization to better handle large models that use external data.
---
 .../execution_providers/qnn/preprocess.py | 51 +++++-
 .../execution_providers/qnn/quant_config.py | 15 +-
 .../quantization/test_qnn_preprocess_model.py | 170 ++++++++++++++++++
 .../test_tensor_quant_overrides_option.py | 30 ++++
 4 files changed, 261 insertions(+), 5 deletions(-)
 create mode 100644 onnxruntime/test/python/quantization/test_qnn_preprocess_model.py

diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py index b1c114fe1f9fd..b0dab81830c8b 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py @@ -3,6 +3,8 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +from __future__ import annotations + import logging from pathlib import Path @@ -13,7 +15,44 @@ from .fusion_lpnorm import FusionLpNormalization -def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm: bool = False) -> bool: +def qnn_preprocess_model( + model_input: Path, + model_output: Path, + fuse_layernorm: bool = False, + save_as_external_data: bool = False, + all_tensors_to_one_file: bool = False, + external_data_location: str | None = None, + external_data_size_threshold: int = 1024, + external_data_convert_attribute: bool = False, +) -> bool: + """ + If necessary, this method creates a new "pre-processed" model in preparation for + quantization of a model to be used in QNN EP. Returns true if a new model was created. + + This method performs the following operations: + - Fuse Erf sequence into a single Gelu node. + - Fuse ReduceL2 sequence into a single LpNormalization node (p == 2). + - (Optional) Fuse ReduceMean sequence into a single LayerNormalization node.
+ + Args: + model_input: Path to the input model file. + model_output: Path to the output model file, which is only created if this method returns True. + fuse_layernorm: True if ReduceMean sequences should be fused into LayerNormalization nodes. + Defaults to False. + save_as_external_data: True if output model should be saved with external data. Defaults to false. + all_tensors_to_one_file: Effective only if save_as_external_data is true. Defaults to false. + If true, save all tensors to one external file specified by external_data_location. + If false, save each tensor to a file named with the tensor name. + external_data_location: Effective only if save_as_external_data is true. Defaults to None. + Specify the external file to which all tensors are saved. Path is relative + to the model path. If not specified, the model's name is used. + external_data_size_threshold: Effective only if save_as_external_data is true. Defaults to 1024. + Tensors with a data size >= external_data_size_threshold are converted to external data. + To convert every tensor with raw data to external data, set to 0. + external_data_convert_attribute: Effective only if save_as_external_data is true. Defaults to false. + If true, convert all tensors to external data. + If false, convert only non-attribute tensors to external data. + """ modified = False model = onnx.load_model(model_input) onnx_model = ONNXModel(model) @@ -57,6 +96,14 @@ def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm: if modified: onnx_model.topological_sort() - onnx.save_model(model, model_output) + onnx.save_model( + model, + model_output, + save_as_external_data=save_as_external_data, + all_tensors_to_one_file=all_tensors_to_one_file, + location=external_data_location, + size_threshold=external_data_size_threshold, + convert_attribute=external_data_convert_attribute, + ) return modified diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py index 7c2fa4f65ae1b..e9affae7ac263 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py @@ -15,6 +15,7 @@ Q16_TYPES = {QuantType.QInt16, QuantType.QUInt16} Q8_TYPES = {QuantType.QInt8, QuantType.QUInt8} OP_TYPES_TO_EXCLUDE = {"Cast"} +MODEL_SIZE_THRESHOLD = 2147483648 # Quant model should use external data if >= 2GB def get_qnn_qdq_config( @@ -28,14 +29,21 @@ def get_qnn_qdq_config( if per_channel: raise ValueError("QNN EP does not yet support per-channel quantization.") - # Process model nodes to setup overrides. - model = onnx.load_model(model_input) + model = onnx.load_model(model_input, load_external_data=False) op_types = set() tensor_quant_overrides = {} + model_has_external_data = False + name_to_initializer = {} - name_to_initializer = {initializer.name: initializer for initializer in model.graph.initializer} + # Build map of initializers (name -> initializer) and + # check if the model has external data.
+ for initializer in model.graph.initializer: + name_to_initializer[initializer.name] = initializer + if onnx.external_data_helper.uses_external_data(initializer): + model_has_external_data = True + # Setup quantization overrides for specific operator types for node in model.graph.node: op_types.add(node.op_type) @@ -89,5 +97,6 @@ def get_qnn_qdq_config( activation_type=activation_type, weight_type=weight_type, op_types_to_quantize=list(op_types.difference(OP_TYPES_TO_EXCLUDE)), + use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD), extra_options=extra_options, ) diff --git a/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py b/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py new file mode 100644 index 0000000000000..9b67fd41caac3 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import math +import unittest +from pathlib import Path + +import numpy as np +import onnx + +from onnxruntime.quantization.execution_providers.qnn import qnn_preprocess_model +from onnxruntime.quantization.quant_utils import model_has_external_data, ms_domain + + +class TestQnnPreprocessModel(unittest.TestCase): + def build_model(self, shape, scale_val, bias_val): + """ + Build a model that supports 3 kinds of fusions: + - Erf sequence to Gelu + - ReduceL2 sequence to LpNormalization + - ReduceMean sequence to LayerNormalization + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + + # Erf sequence + one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const") + half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const") + root2_const = onnx.numpy_helper.from_array(np.array(math.sqrt(2.0), dtype=np.float32), "root2_const") + + e_mul0_node = onnx.helper.make_node("Mul", ["root", "half_const"], ["e_mul0_out"]) + e_div_node = onnx.helper.make_node("Div", ["root", "root2_const"], ["e_div_out"]) + e_erf_node = onnx.helper.make_node("Erf", ["e_div_out"], ["e_erf_out"]) + e_add_node = onnx.helper.make_node("Add", ["e_erf_out", "one_const"], ["e_add_out"]) + e_mul1_node = onnx.helper.make_node("Mul", ["e_add_out", "e_mul0_out"], ["erf_seq_output"]) + + # ReduceL2 sequence + axes_const = onnx.numpy_helper.from_array(np.array([-1], dtype=np.int64), "axes_const") + eps_const = onnx.numpy_helper.from_array(np.array(1e-12, dtype=np.float32), "eps_const") + shape_const = onnx.numpy_helper.from_array(np.array(list(shape), dtype=np.int64), "shape_const") + + l2_rl2_node = onnx.helper.make_node("ReduceL2", ["erf_seq_output", "axes_const"], ["l2_rl2_out"], keepdims=1) + l2_clip_node = onnx.helper.make_node("Clip", ["l2_rl2_out", "eps_const"], ["l2_clip_out"]) + l2_expand_node = onnx.helper.make_node("Expand", ["l2_clip_out", "shape_const"], ["l2_expand_out"]) + l2_div_node = onnx.helper.make_node("Div", ["erf_seq_output", "l2_expand_out"], ["l2_seq_output"]) + + # ReduceMean sequence + scale_const = onnx.numpy_helper.from_array(np.array(scale_val, dtype=np.float32), "scale_const") + 
bias_const = onnx.numpy_helper.from_array(np.array(bias_val, dtype=np.float32), "bias_const") + two_const = onnx.numpy_helper.from_array(np.array(2.0, dtype=np.float32), "two_const") + + m_rm0_node = onnx.helper.make_node("ReduceMean", ["l2_seq_output", "axes_const"], ["m_rm0_out"]) + m_sub_node = onnx.helper.make_node("Sub", ["l2_seq_output", "m_rm0_out"], ["m_sub_out"]) + m_pow_node = onnx.helper.make_node("Pow", ["m_sub_out", "two_const"], ["m_pow_out"]) + m_rm1_node = onnx.helper.make_node("ReduceMean", ["m_pow_out", "axes_const"], ["m_rm1_out"]) + m_add0_node = onnx.helper.make_node("Add", ["m_rm1_out", "eps_const"], ["m_add0_out"]) + m_sqrt_node = onnx.helper.make_node("Sqrt", ["m_add0_out"], ["m_sqrt_out"]) + m_div_node = onnx.helper.make_node("Div", ["m_sub_out", "m_sqrt_out"], ["m_div_out"]) + m_mul_node = onnx.helper.make_node("Mul", ["m_div_out", "scale_const"], ["m_mul_out"]) + m_add1_node = onnx.helper.make_node("Add", ["m_mul_out", "bias_const"], ["output"]) + + graph = onnx.helper.make_graph( + [ + e_mul0_node, + e_div_node, + e_erf_node, + e_add_node, + e_mul1_node, + l2_rl2_node, + l2_clip_node, + l2_expand_node, + l2_div_node, + m_rm0_node, + m_sub_node, + m_pow_node, + m_rm1_node, + m_add0_node, + m_sqrt_node, + m_div_node, + m_mul_node, + m_add1_node, + ], + "qnn_f32_model", + [root_inp], + [output], + initializer=[ + one_const, + half_const, + root2_const, + axes_const, + eps_const, + shape_const, + scale_const, + bias_const, + two_const, + ], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return onnx.shape_inference.infer_shapes(model) + + def test_all_fusions(self): + """ + Test calling qnn_preprocess_model() with a model that supports all 3 fusions. + """ + model = self.build_model((1, 2, 3), [2.0, 2.0, 2.0], [1.0, 1.0, 1.0]) + onnx.save_model(model, "model.onnx") + modified = qnn_preprocess_model("model.onnx", "model.qnn_pp.onnx", fuse_layernorm=True) + + self.assertTrue(modified) + + fused_model = onnx.load_model("model.qnn_pp.onnx") + + # 3 fused Ops: Gelu, LpNorm, LayerNorm + self.assertEqual(len(fused_model.graph.node), 3) + expected_op_types = {"Gelu", "LpNormalization", "LayerNormalization"} + for node in fused_model.graph.node: + self.assertIn(node.op_type, expected_op_types) + + # Should have added "com.microsoft" opset import because we added a Gelu. + ms_domain_opset = next((opset for opset in fused_model.opset_import if opset.domain == ms_domain), None) + self.assertIsNotNone(ms_domain_opset) + self.assertEqual(ms_domain_opset.version, 1) + + def test_external_data(self): + """ + Test calling qnn_preprocess_model() with a model that uses external data. + The new preprocessed model should also have external data. + """ + model = self.build_model((1, 2, 3), [2.0, 2.0, 2.0], [1.0, 1.0, 1.0]) + onnx.save_model( + model, + "model.onnx", + save_as_external_data=True, + all_tensors_to_one_file=True, + location="weights.bin", + size_threshold=0, + ) + modified = qnn_preprocess_model( + "model.onnx", + "model.qnn_pp.onnx", + fuse_layernorm=True, + save_as_external_data=True, + all_tensors_to_one_file=True, + external_data_location="weights2.bin", + external_data_size_threshold=0, + ) + + self.assertTrue(modified) + + # Model should still have external data. 
+ self.assertTrue(model_has_external_data(Path("model.qnn_pp.onnx"))) + + fused_model = onnx.load_model("model.qnn_pp.onnx", load_external_data=False) + + # 3 fused Ops: Gelu, LpNorm, LayerNorm + self.assertEqual(len(fused_model.graph.node), 3) + expected_op_types = {"Gelu", "LpNormalization", "LayerNormalization"} + for node in fused_model.graph.node: + self.assertIn(node.op_type, expected_op_types) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py index 0470953e385b6..cbb6b3ae2e776 100644 --- a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py +++ b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py @@ -555,6 +555,36 @@ def test_get_qnn_qdq_config(self): self.assertEqual(sig_out_zp.data_type, onnx.TensorProto.UINT16) self.assertEqual(sig_out_sc.float_data[0], np.float32(1.0 / 65536.0)) + def test_get_qnn_qdq_config_ext_data(self): + """ + Test that get_qnn_qdq_config() returns a config that enables external data + if the input model has external data. + """ + + # Create model with a weight large enough (> 1024 bytes) to be stored externally. + large_weight = onnx.numpy_helper.from_array(np.random.random((1, 32, 32)).astype(np.float32), "weight") + graph = onnx.helper.make_graph( + [onnx.helper.make_node("Add", ["input", "weight"], ["output"])], + "add_ext_data", + [onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, (1, 32, 32))], + [onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, (1, 32, 32))], + initializer=[large_weight], + ) + model = onnx.helper.make_model( + graph, + opset_imports=[onnx.helper.make_opsetid("", 18)], + ) + onnx.save_model( + model, + "add_ext_data.onnx", + save_as_external_data=True, + all_tensors_to_one_file=True, + location="add_ext_data.bin", + ) + + qnn_config = get_qnn_qdq_config("add_ext_data.onnx", DummyDataReader(self.activations)) + self.assertTrue(qnn_config.use_external_data_format) + if __name__ == "__main__": t = TestTensorQuantOverridesOption()
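For reference, a rough sketch of how the pieces added in PATCH 4/4 compose for a model stored with external data. The model path, input name, and input shape are placeholders, and the final `quantize()` call uses the existing config-based `onnxruntime.quantization` API:

```python
import numpy as np

from onnxruntime.quantization import CalibrationDataReader, quantize
from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config, qnn_preprocess_model


class RandomDataReader(CalibrationDataReader):
    """Placeholder calibration reader; replace with real input samples."""

    def __init__(self, num_samples=4):
        self._samples = iter(
            [{"input": np.random.rand(1, 3, 224, 224).astype(np.float32)} for _ in range(num_samples)]
        )

    def get_next(self):
        return next(self._samples, None)


# Fuse supported patterns and keep weights external so a large model stays under the 2GB protobuf limit.
# qnn_preprocess_model() only writes the output file if it modified the graph.
modified = qnn_preprocess_model(
    "model.onnx",
    "model.qnn_pp.onnx",
    fuse_layernorm=True,
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    external_data_location="model.qnn_pp.bin",
)
model_to_quantize = "model.qnn_pp.onnx" if modified else "model.onnx"

# get_qnn_qdq_config() loads the model without its external data and sets use_external_data_format
# automatically when the model uses external data or is >= 2GB.
qnn_config = get_qnn_qdq_config(model_to_quantize, RandomDataReader())
quantize(model_to_quantize, "model.qdq.onnx", qnn_config)
```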