From 435e19953ea54115124fd637a67a87681a7fc8eb Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Mon, 5 Feb 2024 07:26:24 +0800 Subject: [PATCH] Fix llama.covert_onnx to make it runnable in CI (#19372) ### Description 1. make parity_check use local model to avoid using hf token 2. del the model didn't work because it tried to del the object define out of the function scope. So it caused out of memory in A10. 3. In fact, 16G GPU memory (one T4) is enough. But the conversion process always be killed in T4 and it works on A10/24G. Standard_NC4as_T4_v3 has 28G CPU memory Standard_NV36ads_A10_v5 has 440G memory. It looks that the model conversion needs very huge memory. ### Motivation and Context Last time, I came across some issues in convert_to_onnx.py so I use the onnx model in https://github.com/microsoft/Llama-2-Onnx for testing. Now, these issues could be fixed. So I use onnx model generated by this repo and the CI can cover the model conversion. --- .../models/llama/convert_to_onnx.py | 17 +++-- .../transformers/models/llama/llama_parity.py | 62 ++++++++++++++----- .../models/llama/requirements-cuda.txt | 4 +- .../models/llama/requirements.txt | 4 +- .../azure-pipelines/bigmodels-ci-pipeline.yml | 53 +++++++--------- 5 files changed, 84 insertions(+), 56 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 71f52faa2c1e6..c9ff384a4c856 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -781,6 +781,13 @@ def get_args(): action="store_true", help="Avoid exporting model, only apply quantizations and optimizations to existing model exported from optimum.", ) + + parser.add_argument( + "--small_gpu", + action="store_true", + help="Load the llama in GPU every time for parity_check if it's running in a machine which GPU memory < 36GB.", + ) + parser.set_defaults(optimize_optimum=False) args = parser.parse_args() @@ -788,9 +795,7 @@ def get_args(): def main(): - if version.parse(torch.__version__) < version.parse("2.2.0") and "2.2.0.dev" not in torch.__version__: - # Second predicate is for comparing nightly (ex: 2.2.0.dev20230920 vs 2.2.0) since first predicate is false - # in that scenario. It can be removed when torch v2.2.0 is released in stable. + if version.parse(torch.__version__) < version.parse("2.2.0"): logger.error(f"Detected PyTorch version {torch.__version__}. Please upgrade and use v2.2.0 or newer.") return @@ -1021,7 +1026,11 @@ def main(): args.precision, "--cache_dir", args.cache_dir, + "--torch_model_directory", + args.input, ] + if args.small_gpu: + parity_cmd.append("--small_gpu") if "with_past" in filename: parity_cmd.append("--use_past_kv") if "merged" in filename: @@ -1030,7 +1039,7 @@ def main(): parity_cmd.append("--use_gqa") try: - logger.debug(f"check parity with cmd: {parity_cmd}") + logger.info(f"check parity with cmd: {parity_cmd}") parity_check(parity_cmd) except Exception as e: logger.warning(f"An error occurred while verifying parity: {e}", exc_info=True) diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index 25d7519769604..f41a90208c51b 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -17,7 +17,7 @@ get_sample_with_past_kv_inputs, ) from llama_torch import setup_torch_model -from transformers import AutoConfig, AutoModelForCausalLM +from transformers import AutoConfig import onnxruntime as ort @@ -67,20 +67,39 @@ def get_inputs(args: argparse.Namespace, config: AutoConfig): def verify_parity( - args: argparse.Namespace, config: AutoConfig, pt_model: AutoModelForCausalLM, kv_cache_ortvalues: dict + args: argparse.Namespace, + location: str, + use_auth_token: bool, + kv_cache_ortvalues: dict, + pytorch_model: None | torch.nn.Module = None, + config: None | AutoConfig = None, ): + # If it's running in a machine which GPU memory < 36GB, it should unload the llama in GPU in time and free the GPU memory for ORT. + py_model = pytorch_model + if py_model is None: + config, py_model = setup_torch_model( + args, + location, + use_auth_token, + torch_dtype=(torch.float16 if args.use_fp16 else torch.float32), + device=args.device, + ) + inputs = get_inputs(args, config) # Run inference with PyTorch if args.execution_provider != "cpu": torch.cuda.synchronize() start_time = time.time() - pt_outputs = pt_model(**inputs).logits.detach().cpu().numpy() + pt_outputs = py_model(**inputs).logits.detach().cpu().numpy() if args.execution_provider != "cpu": torch.cuda.synchronize() end_time = time.time() logger.info(f"PyTorch took {end_time - start_time} s") - del pt_model + + if args.small_gpu and py_model is not None: + del py_model + torch.cuda.empty_cache() # Run inference with ORT past_sequence_length, _, max_sequence_length = get_sequence_lengths(args) @@ -222,6 +241,13 @@ def get_args(argv: list[str]): help="model cache dir to override default HF cache dir to avoid overflood the /home dir", ) + # The argument is used for CI mainly, because the CI machine has 24G GPU memory at most. + parser.add_argument( + "--small_gpu", + action="store_true", + help="Load the llama in GPU every time for parity_check if it's running in a machine which GPU memory < 36GB. ", + ) + args = parser.parse_args() if argv == [] else parser.parse_args(argv) # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models @@ -247,25 +273,29 @@ def main(argv: list[str] = []): # noqa: B006 use_auth_token = args.torch_model_directory == os.path.join(".") location = args.model_name if use_auth_token else args.torch_model_directory - config, llama = setup_torch_model( - args, - location, - use_auth_token, - torch_dtype=(torch.float16 if args.use_fp16 else torch.float32), - device=args.device, - ) - kv_cache_ortvalues = {} if not args.merged: - verify_parity(args, config, llama, kv_cache_ortvalues) + verify_parity(args, location, use_auth_token, kv_cache_ortvalues) else: - # Verify prompt generation in merged model (decoder_model.onnx) + config = llama = None + if not args.small_gpu: + config, llama = setup_torch_model( + args, + location, + use_auth_token, + torch_dtype=(torch.float16 if args.use_fp16 else torch.float32), + device=args.device, + ) + + # Verify prompt processing in merged model (decoder_model.onnx) args.use_past_kv = False - kv_cache_ortvalues = verify_parity(args, config, llama, kv_cache_ortvalues) + kv_cache_ortvalues = verify_parity( + args, location, use_auth_token, kv_cache_ortvalues, pytorch_model=llama, config=config + ) # Verify token generation in merged model (decoder_with_past_model.onnx) args.use_past_kv = True - verify_parity(args, config, llama, kv_cache_ortvalues) + verify_parity(args, location, use_auth_token, kv_cache_ortvalues, pytorch_model=llama, config=config) if __name__ == "__main__": diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt index b634bcc50f6e4..acd9c23aa42d0 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt @@ -1,4 +1,4 @@ -r requirements.txt -# Please manually install torch>=2.2.0.dev20230920 with CUDA enabled for the CUDA version installed in your system. +# Please manually install torch>=2.2.0 with CUDA enabled for the CUDA version installed in your system. # Instructions can be found here: https://pytorch.org/get-started/locally/ -onnxruntime-gpu>=1.16.2 \ No newline at end of file +onnxruntime-gpu>=1.16.2 diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt index b72c972e7a16a..8b57279295e35 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt @@ -1,6 +1,6 @@ optimum>=1.14.1 transformers>=4.33.2 -torch>=2.2.0.dev20230920 +torch>=2.2.0 onnx>=1.14.0 datasets>=2.8.0 -protobuf==3.20.2 \ No newline at end of file +protobuf==3.20.2 diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml index 0de2ac44215c4..65866fc9827a5 100644 --- a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml @@ -268,7 +268,7 @@ stages: skipComponentGovernanceDetection: true workspace: clean: all - pool: onnxruntime-Linux-GPU-T4 + pool: Onnxruntime-Linux-A10-24G steps: - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' @@ -278,10 +278,6 @@ stages: clean: true submodules: none - - checkout: LLaMa2Onnx - clean: true - submodules: none - - template: templates/flex-downloadPipelineArtifact.yml parameters: StepName: 'Download Onnxruntime Artifact' @@ -290,47 +286,40 @@ stages: SpecificArtifact: ${{ parameters.specificArtifact }} BuildId: ${{ parameters.BuildId }} - - task: DownloadPackage@1 - displayName: 'Download Llama2 model' - inputs: - packageType: upack - feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' - version: 1.0.0 - definition: '772ebce3-7e06-46d5-b3cc-82040ec4b2ce' - downloadPath: $(Agent.TempDirectory)/llama2_onnx_ft16 - - template: templates/get-docker-image-steps.yml parameters: - Dockerfile: onnxruntime/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 - Context: onnxruntime/tools/ci_build/github/linux/docker/ - ScriptName: onnxruntime/tools/ci_build/get_docker_image.py + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 + Context: tools/ci_build/github/linux/docker/ + ScriptName: tools/ci_build/get_docker_image.py DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" Repository: onnxruntimeubi8packagestest UpdateDepsTxt: false + - task: DownloadPackage@1 + displayName: 'Download Meta Llama2 model' + inputs: + packageType: upack + feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' + version: 1.0.0 + definition: '6fe0c4ed-9d0e-4d66-94cc-fb6a111d02a5' + downloadPath: $(Agent.TempDirectory)/meta_llama2_7b_hf + - script: | - docker run --rm --gpus all -v $(Build.SourcesDirectory)/Llama-2-Onnx:/workspace \ + docker run --rm --gpus all -v $(Build.SourcesDirectory):/workspace \ -v $(Build.BinariesDirectory)/ort-artifact/:/ort-artifact \ - -v $(Agent.TempDirectory)/llama2_onnx_ft16:/models \ + -v $(Agent.TempDirectory)/meta_llama2_7b_hf:/meta-llama2 \ onnxruntimeubi8packagestest \ bash -c " set -ex; \ + pushd /workspace/onnxruntime/python/tools/transformers/ ; \ python3 -m pip install --upgrade pip ; \ + pushd models/llama ; \ + python3 -m pip install -r requirements-cuda.txt ; \ + popd ; \ python3 -m pip install /ort-artifact/*.whl ; \ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118 ; \ - python3 -m pip install sentencepiece ; \ - pushd /workspace ; \ - python3 MinimumExample/Example_ONNX_LlamaV2.py --onnx_file /models/ONNX/LlamaV2_7B_FT_float16.onnx \ - --embedding_file /models/embeddings.pth --tokenizer_path tokenizer.model --prompt 'What is the lightest element?' > /workspace/answer.txt ; \ + python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda --input /meta-llama2 --small_gpu ;\ popd ; \ " - displayName: 'Run Llama2 demo' + displayName: 'Run Llama2 to Onnx F16 and parity Test' workingDirectory: $(Build.SourcesDirectory) - - - script: | - set -ex - real=$(cat $(Build.SourcesDirectory)/Llama-2-Onnx/answer.txt) - trim_actual=$(tr -dc '[[:print:]]' <<< "$real") - expected="The lightest element is hydrogen. Hydrogen is the lightest element on the periodic table, with an atomic mass of 1.00794 u (unified atomic mass units)." - [ "$expected" == "$trim_actual" ] && exit 0 || exit 1 - displayName: 'Check result'