Fix llama.covert_onnx to make it runnable in CI #19372

Merged: 11 commits, Feb 4, 2024 (changes shown from 10 commits)
@@ -781,16 +781,21 @@ def get_args():
action="store_true",
help="Avoid exporting model, only apply quantizations and optimizations to existing model exported from optimum.",
)

parser.add_argument(
"--small_gpu",
action="store_true",
help="Load the llama in GPU every time for parity_check if it's running in a machine which GPU memory < 36GB.",
)

parser.set_defaults(optimize_optimum=False)

args = parser.parse_args()
return args


def main():
if version.parse(torch.__version__) < version.parse("2.2.0") and "2.2.0.dev" not in torch.__version__:
# Second predicate is for comparing nightly (ex: 2.2.0.dev20230920 vs 2.2.0) since first predicate is false
# in that scenario. It can be removed when torch v2.2.0 is released in stable.
if version.parse(torch.__version__) < version.parse("2.2.0"):
logger.error(f"Detected PyTorch version {torch.__version__}. Please upgrade and use v2.2.0 or newer.")
return

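For context on the simplified version gate above: PEP 440 ordering (as implemented by packaging's version.parse) places dev builds before the corresponding final release, which is why the extra "2.2.0.dev" predicate was needed while torch 2.2.0 was still a nightly. A small illustrative check, not part of the diff:

```python
# Illustration only: the ordering behind the version gate above.
from packaging import version

# A nightly sorts *below* the stable release, so "< 2.2.0" alone would have
# rejected it before 2.2.0 shipped -- hence the old second predicate.
assert version.parse("2.2.0.dev20230920") < version.parse("2.2.0")

# With stable torch 2.2.0 now required, the single comparison is enough.
assert not version.parse("2.2.0") < version.parse("2.2.0")
assert version.parse("2.1.2") < version.parse("2.2.0")  # genuinely old builds are still caught
```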
@@ -1021,7 +1026,11 @@ def main():
args.precision,
"--cache_dir",
args.cache_dir,
"--torch_model_directory",
args.input,
]
if args.small_gpu:
parity_cmd.append("--small_gpu")
if "with_past" in filename:
parity_cmd.append("--use_past_kv")
if "merged" in filename:
@@ -1030,7 +1039,7 @@ def main():
parity_cmd.append("--use_gqa")

try:
logger.debug(f"check parity with cmd: {parity_cmd}")
logger.info(f"check parity with cmd: {parity_cmd}")
parity_check(parity_cmd)
except Exception as e:
logger.warning(f"An error occurred while verifying parity: {e}", exc_info=True)
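Putting the convert_to_onnx.py hunks together, a minimal sketch of the flag plumbing this PR adds when the converter builds the parity-check command. The helper name build_parity_cmd and the --execution_provider entry are illustrative; the remaining flags are taken from the diff:

```python
# Sketch of how --small_gpu and --torch_model_directory flow into the parity command.
# build_parity_cmd is a hypothetical helper; the real script assembles the list inline.
import argparse


def build_parity_cmd(args: argparse.Namespace, filename: str) -> list[str]:
    cmd = [
        "--execution_provider", args.execution_provider,  # assumed; precedes the lines shown in the hunk
        "--precision", args.precision,
        "--cache_dir", args.cache_dir,
        "--torch_model_directory", args.input,  # new: reuse the locally downloaded PyTorch weights
    ]
    if args.small_gpu:  # new: forwarded so llama_parity reloads the torch model per check
        cmd.append("--small_gpu")
    if "with_past" in filename:
        cmd.append("--use_past_kv")
    # The hunk is truncated here; the merged/GQA flags follow in the real script.
    return cmd
```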
62 changes: 47 additions & 15 deletions onnxruntime/python/tools/transformers/models/llama/llama_parity.py
@@ -17,7 +17,7 @@
get_sample_with_past_kv_inputs,
)
from llama_torch import setup_torch_model
from transformers import AutoConfig, AutoModelForCausalLM
from transformers import AutoConfig

import onnxruntime as ort

@@ -67,20 +67,40 @@ def get_inputs(args: argparse.Namespace, config: AutoConfig):


def verify_parity(
args: argparse.Namespace, config: AutoConfig, pt_model: AutoModelForCausalLM, kv_cache_ortvalues: dict
args: argparse.Namespace,
location,
use_auth_token,
kv_cache_ortvalues: dict,
pytorch_model: None | torch.nn.Module = None,
config: None | AutoConfig = None,
):
# On machines with less than 36 GB of GPU memory, load the LLaMA model here and unload it promptly so the GPU memory is freed for ORT.
py_model = pytorch_model
if py_model is None:
config, llama = setup_torch_model(
args,
location,
use_auth_token,
torch_dtype=(torch.float16 if args.use_fp16 else torch.float32),
device=args.device,
)
py_model = llama

inputs = get_inputs(args, config)

# Run inference with PyTorch
if args.execution_provider != "cpu":
torch.cuda.synchronize()
start_time = time.time()
pt_outputs = pt_model(**inputs).logits.detach().cpu().numpy()
pt_outputs = py_model(**inputs).logits.detach().cpu().numpy()
if args.execution_provider != "cpu":
torch.cuda.synchronize()
end_time = time.time()
logger.info(f"PyTorch took {end_time - start_time} s")
del pt_model

if pytorch_model is None:
del py_model
torch.cuda.empty_cache()

# Run inference with ORT
past_sequence_length, _, max_sequence_length = get_sequence_lengths(args)
@@ -222,6 +242,13 @@ def get_args(argv: list[str]):
help="model cache dir to override default HF cache dir to avoid overflood the /home dir",
)

# This argument is mainly for CI, where the machine has at most 24 GB of GPU memory.
parser.add_argument(
"--small_gpu",
action="store_true",
help="Load the llama in GPU every time for parity_check if it's running in a machine which GPU memory < 36GB. ",
)

args = parser.parse_args() if argv == [] else parser.parse_args(argv)

# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
@@ -247,25 +274,30 @@ def main(argv: list[str] = []): # noqa: B006
use_auth_token = args.torch_model_directory == os.path.join(".")
location = args.model_name if use_auth_token else args.torch_model_directory

config, llama = setup_torch_model(
args,
location,
use_auth_token,
torch_dtype=(torch.float16 if args.use_fp16 else torch.float32),
device=args.device,
)

kv_cache_ortvalues = {}
if not args.merged:
verify_parity(args, config, llama, kv_cache_ortvalues)
verify_parity(args, location, use_auth_token, kv_cache_ortvalues)
else:
if args.small_gpu:
config = llama = None
else:
config, llama = setup_torch_model(
args,
location,
use_auth_token,
torch_dtype=(torch.float16 if args.use_fp16 else torch.float32),
device=args.device,
)

# Verify prompt generation in merged model (decoder_model.onnx)
args.use_past_kv = False
kv_cache_ortvalues = verify_parity(args, config, llama, kv_cache_ortvalues)
kv_cache_ortvalues = verify_parity(
args, location, use_auth_token, kv_cache_ortvalues, pytorch_model=llama, config=config
)

# Verify token generation in merged model (decoder_with_past_model.onnx)
args.use_past_kv = True
verify_parity(args, config, llama, kv_cache_ortvalues)
verify_parity(args, location, use_auth_token, kv_cache_ortvalues, pytorch_model=llama, config=config)


if __name__ == "__main__":
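In short, with --small_gpu the merged-model path no longer keeps the PyTorch reference resident: verify_parity loads it on demand and frees it before ORT runs. A condensed sketch of that pattern, with a hypothetical function name and the ORT half omitted (setup_torch_model is the helper imported from llama_torch in the diff):

```python
# Condensed sketch of the lazy load -> infer -> free pattern verify_parity uses
# when no pytorch_model is passed in (the --small_gpu path).
import torch
from llama_torch import setup_torch_model  # same import as in llama_parity.py


def torch_reference_logits(args, location, use_auth_token, inputs, pytorch_model=None, config=None):
    model = pytorch_model
    if model is None:
        # Load LLaMA only for this check so the GPU is otherwise free for ORT.
        config, model = setup_torch_model(
            args,
            location,
            use_auth_token,
            torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
            device=args.device,
        )

    logits = model(**inputs).logits.detach().cpu().numpy()

    if pytorch_model is None:
        # Drop the only reference and return the memory before ORT allocates.
        del model
        torch.cuda.empty_cache()
    return logits
```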
@@ -1,4 +1,4 @@
-r requirements.txt
# Please manually install torch>=2.2.0.dev20230920 with CUDA enabled for the CUDA version installed in your system.
# Please manually install torch>=2.2.0 with CUDA enabled for the CUDA version installed in your system.
# Instructions can be found here: https://pytorch.org/get-started/locally/
onnxruntime-gpu>=1.16.2
@@ -1,6 +1,6 @@
optimum>=1.14.1
transformers>=4.33.2
torch>=2.2.0.dev20230920
torch>=2.2.0
onnx>=1.14.0
datasets>=2.8.0
protobuf==3.20.2
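The requirements now target stable torch >= 2.2.0 instead of the old nightly pin, and the CUDA file still asks for a manual, CUDA-enabled install. A quick sanity check along these lines can confirm the environment matches; it is a sketch, not part of the repo:

```python
# Environment sanity check for the requirements above: torch >= 2.2.0 with CUDA.
import torch
from packaging import version

installed = version.parse(torch.__version__.split("+")[0])  # strip local tags such as "+cu118"
assert installed >= version.parse("2.2.0"), f"torch {torch.__version__} is too old"
assert torch.cuda.is_available(), "torch was installed without a usable CUDA backend"
print(torch.__version__, "CUDA", torch.version.cuda)
```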
53 changes: 21 additions & 32 deletions tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml
@@ -233,7 +233,7 @@ stages:
skipComponentGovernanceDetection: true
workspace:
clean: all
pool: onnxruntime-Linux-GPU-T4
pool: Onnxruntime-Linux-A10-24G
steps:
- task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3
displayName: 'Clean Agent Directories'
@@ -243,10 +243,6 @@
clean: true
submodules: none

- checkout: LLaMa2Onnx
clean: true
submodules: none

- template: templates/flex-downloadPipelineArtifact.yml
parameters:
StepName: 'Download Onnxruntime Artifact'
@@ -255,47 +251,40 @@
SpecificArtifact: ${{ parameters.specificArtifact }}
BuildId: ${{ parameters.BuildId }}

- task: DownloadPackage@1
displayName: 'Download Llama2 model'
inputs:
packageType: upack
feed: '/7424c8e4-5c62-490e-95c4-79446f31017c'
version: 1.0.0
definition: '772ebce3-7e06-46d5-b3cc-82040ec4b2ce'
downloadPath: $(Agent.TempDirectory)/llama2_onnx_ft16

- template: templates/get-docker-image-steps.yml
parameters:
Dockerfile: onnxruntime/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6
Context: onnxruntime/tools/ci_build/github/linux/docker/
ScriptName: onnxruntime/tools/ci_build/get_docker_image.py
Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6
Context: tools/ci_build/github/linux/docker/
ScriptName: tools/ci_build/get_docker_image.py
DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )"
Repository: onnxruntimeubi8packagestest
UpdateDepsTxt: false

- task: DownloadPackage@1
displayName: 'Download Meta Llama2 model'
inputs:
packageType: upack
feed: '/7424c8e4-5c62-490e-95c4-79446f31017c'
version: 1.0.0
definition: '6fe0c4ed-9d0e-4d66-94cc-fb6a111d02a5'
downloadPath: $(Agent.TempDirectory)/meta_llama2_7b_hf

- script: |
docker run --rm --gpus all -v $(Build.SourcesDirectory)/Llama-2-Onnx:/workspace \
docker run --rm --gpus all -v $(Build.SourcesDirectory):/workspace \
-v $(Build.BinariesDirectory)/ort-artifact/:/ort-artifact \
-v $(Agent.TempDirectory)/llama2_onnx_ft16:/models \
-v $(Agent.TempDirectory)/meta_llama2_7b_hf:/meta-llama2 \
onnxruntimeubi8packagestest \
bash -c "
set -ex; \
pushd /workspace/onnxruntime/python/tools/transformers/ ; \
python3 -m pip install --upgrade pip ; \
pushd models/llama ; \
python3 -m pip install -r requirements-cuda.txt ; \
popd ; \
python3 -m pip install /ort-artifact/*.whl ; \
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118 ; \
python3 -m pip install sentencepiece ; \
pushd /workspace ; \
python3 MinimumExample/Example_ONNX_LlamaV2.py --onnx_file /models/ONNX/LlamaV2_7B_FT_float16.onnx \
--embedding_file /models/embeddings.pth --tokenizer_path tokenizer.model --prompt 'What is the lightest element?' > /workspace/answer.txt ; \
python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda --input /meta-llama2 --small_gpu ;\
popd ; \
"
displayName: 'Run Llama2 demo'
displayName: 'Run Llama2 to Onnx F16 and parity Test'
workingDirectory: $(Build.SourcesDirectory)

- script: |
set -ex
real=$(cat $(Build.SourcesDirectory)/Llama-2-Onnx/answer.txt)
trim_actual=$(tr -dc '[[:print:]]' <<< "$real")
expected="The lightest element is hydrogen. Hydrogen is the lightest element on the periodic table, with an atomic mass of 1.00794 u (unified atomic mass units)."
[ "$expected" == "$trim_actual" ] && exit 0 || exit 1
displayName: 'Check result'