-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support large model export using multi-gpu (#17990)
### Description This PR is to implemente a exporter which works for large language models(LLM). It works for models like Llama2-70b or gpt-175. The main idea is to utilize multiple-GPU and dispatch differnet layers to different GPU, in short, it symply implemented auto pipeline parallelism. For example : to export Llama2-70b, you need 8x V100-32GB or 4x A100-80G or More GPU memories. It would expect to export decoder-only models. For encoder-decoder arch-like models, we didn't test it yet. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> --------- Co-authored-by: Justin Chu <[email protected]>
- Loading branch information
Showing
1 changed file
with
385 additions
and
0 deletions.
There are no files selected for viewing
385 changes: 385 additions & 0 deletions
385
onnxruntime/python/tools/transformers/large_model_exporter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,385 @@ | ||
# ------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
# -------------------------------------------------------------------------- | ||
|
||
""" | ||
Export LLM to onnx | ||
""" | ||
import argparse | ||
import inspect | ||
import math | ||
import os | ||
import tempfile | ||
from pathlib import Path | ||
from typing import Optional | ||
|
||
import onnx | ||
import torch | ||
import transformers | ||
from torch import nn | ||
|
||
|
||
def disable_huggingface_init(): | ||
"""do not init model twice as it slow initialization""" | ||
|
||
torch.nn.init.kaiming_uniform_ = lambda x, *args, **kwargs: x | ||
torch.nn.init.uniform_ = lambda x, *args, **kwargs: x | ||
torch.nn.init.normal_ = lambda x, *args, **kwargs: x | ||
torch.nn.init.constant_ = lambda x, *args, **kwargs: x | ||
torch.nn.init.xavier_uniform_ = lambda x, *args, **kwargs: x | ||
torch.nn.init.xavier_normal_ = lambda x, *args, **kwargs: x | ||
torch.nn.init.kaiming_normal_ = lambda x, *args, **kwargs: x | ||
torch.nn.init.orthogonal_ = lambda x, *args, **kwargs: x | ||
|
||
|
||
def get_model_parameter_size(model: nn.Module): | ||
"""to calculate how much memory this model needs""" | ||
param_size = 0 | ||
param_sum = 0 | ||
for param in model.parameters(): | ||
param_size += param.nelement() * param.element_size() | ||
param_sum += param.nelement() | ||
buffer_size = 0 | ||
buffer_sum = 0 | ||
for buffer in model.buffers(): | ||
buffer_size += buffer.nelement() * buffer.element_size() | ||
buffer_sum += buffer.nelement() | ||
all_size = (param_size + buffer_size) / 1024 / 1024 | ||
return all_size | ||
|
||
|
||
def initialize_model_and_sample_inputs(hf_model: str, cache_dir: Optional[str], tokenizer=None): | ||
""" | ||
get the pretrained torch model from hugginface, | ||
and sample model-inputs | ||
""" | ||
|
||
disable_huggingface_init() | ||
|
||
model = transformers.AutoModelForCausalLM.from_pretrained( # type: ignore | ||
hf_model, torch_dtype=torch.float16, cache_dir=cache_dir, trust_remote_code=True | ||
) | ||
if tokenizer is None: | ||
tokenizer = hf_model | ||
tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer) # type: ignore | ||
|
||
sample_inputs = tuple(tokenizer("Hello, my dog is cute", return_tensors="pt").values()) | ||
return model, sample_inputs | ||
|
||
|
||
def auto_pipeline_parallel(model: nn.Module, gpulist: list, sample_inputs: tuple): | ||
"""Make the model executable across multiple GPUs.""" | ||
|
||
def input_gpu_device_hook(mod, inputs, kwargs): | ||
modifyed_inputs = [] | ||
first_dev = None | ||
for layer_input in inputs: | ||
if type(layer_input) is not torch.Tensor: | ||
modifyed_inputs.append(layer_input) | ||
elif hasattr(mod, "weight"): | ||
modifyed_inputs.append(layer_input.to(mod.weight.device)) | ||
elif hasattr(mod, "parameters"): | ||
device = next(mod.parameters(), layer_input).device | ||
modifyed_inputs.append(layer_input.to(device)) | ||
elif hasattr(next(mod.children(), None), "weight"): | ||
modifyed_inputs.append(layer_input.to(next(mod.children()).weight.device)) | ||
elif first_dev is not None and layer_input.device != first_dev: | ||
modifyed_inputs.append(layer_input.to(first_dev)) | ||
else: | ||
modifyed_inputs.append(layer_input) | ||
if first_dev is None: | ||
first_dev = modifyed_inputs[0].device | ||
for key, value in kwargs.items(): | ||
if type(value) is torch.Tensor: | ||
kwargs[key] = value.to(first_dev) | ||
|
||
return (tuple(modifyed_inputs), kwargs) | ||
|
||
def move_layer_to_device_rurc(mod, dev): | ||
mod.to(dev) | ||
for layer in mod.named_children(): | ||
move_layer_to_device_rurc(layer[1], dev) | ||
|
||
model = model.half() | ||
all_hooks = [] | ||
all_hooks.append(model.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) | ||
pre_fix = next(iter(model.named_children()))[0] | ||
for top_name, top_module in model.named_children(): | ||
for name, module in top_module.named_children(): | ||
all_hooks.append(module.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) | ||
if type(module) in [torch.nn.ModuleList]: | ||
num_layers_on_each_gpu = math.floor(len(module) / len(gpulist)) | ||
for idx, attn_layer in enumerate(module): | ||
all_hooks.append(attn_layer.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) | ||
|
||
to_dev = gpulist[min(idx // num_layers_on_each_gpu, len(gpulist))] | ||
attn_layer.to(to_dev) | ||
move_layer_to_device_rurc(attn_layer, to_dev) | ||
print(f"move {pre_fix}.{name}.{idx} to {to_dev}") | ||
else: | ||
module.to(gpulist[0]) | ||
print(f"move {pre_fix}.{name} to {gpulist[0]}") | ||
if len(list(top_module.named_children())) == 0: | ||
top_module.to(gpulist[0]) | ||
print(f"move {top_name} to {gpulist[0]}") | ||
|
||
with torch.no_grad(): | ||
model(sample_inputs[0], attention_mask=sample_inputs[1]) | ||
return model | ||
|
||
|
||
def retrieve_onnx_inputs(model: nn.Module, sample_inputs: tuple, with_past: bool): | ||
""" | ||
auto retrieve onnx inputs from torch model as we can't enumlate all possibilities | ||
for all models | ||
""" | ||
user_inputs = [] | ||
|
||
def hook_for_inputs(_, inputs, kwargs): | ||
user_inputs.append((inputs, kwargs)) | ||
return user_inputs[0] | ||
|
||
hook_handle = model.register_forward_pre_hook(hook_for_inputs, with_kwargs=True) | ||
|
||
forward_params = inspect.signature(model.forward).parameters | ||
input_keys = list(forward_params.keys()) | ||
default_values = [forward_params.get(key).default for key in input_keys] | ||
out = model(sample_inputs[0], attention_mask=sample_inputs[1]) | ||
hook_handle.remove() | ||
user_inputs = user_inputs[0] | ||
onnx_inputs = default_values | ||
for idx, _val in enumerate(user_inputs[0]): | ||
onnx_inputs[idx] = user_inputs[0][idx] | ||
for key, value in user_inputs[1].items(): | ||
idx = input_keys.index(key) | ||
onnx_inputs[idx] = value | ||
for idx, (key, value) in enumerate(zip(input_keys, onnx_inputs)): | ||
if type(value) is torch.Tensor: | ||
value.to(model.device) | ||
# Didn't touch past_key_value now, please change it if you want | ||
if "use_cache" in key: | ||
onnx_inputs[idx] = with_past | ||
|
||
return input_keys, onnx_inputs, out.past_key_values | ||
|
||
|
||
def move_to_approprate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module: | ||
""" | ||
According to the model size, we will upload it to | ||
CPU if has no GPU or enough GPU memory, | ||
Single GPU if has only one GPU in local or model size is enough to fit one GPU | ||
Multiple GPU if there is more than one gpu in local and model is too large | ||
""" | ||
total_mem_per_cpu = torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 | ||
|
||
print(f"Model_Size = {get_model_parameter_size(model)/1024} GB") | ||
print(f"total_mem_per_cpu = {total_mem_per_cpu/1024} GB") | ||
if get_model_parameter_size(model) > total_mem_per_cpu * 0.45: | ||
device_collection = [torch.device(i) for i in range(torch.cuda.device_count())] | ||
if len(device_collection) > 1: | ||
print( | ||
f"{len(device_collection)} GPUs are used to export onnx, \ | ||
Please set CUDA_VISIBLE_DEVICES to use specific GPU group" | ||
) | ||
model = auto_pipeline_parallel(model, device_collection, sample_inputs_tp) | ||
else: | ||
print("!!!! convert model to float and export onnx using CPU") | ||
model = model.cpu().float() | ||
else: | ||
print("Export model on a single GPU") | ||
model = model.cuda().half() | ||
return model | ||
|
||
|
||
def adapt_inputs_to_device(sample_inputs: tuple, device: torch.device) -> tuple: | ||
"""move inputs to device""" | ||
sample_inputs_ = [] | ||
for sample_int in sample_inputs: | ||
if isinstance(sample_int, torch.Tensor): | ||
sample_inputs_.append(sample_int.to(device)) | ||
else: | ||
sample_inputs_.append(sample_int) | ||
return tuple(sample_inputs_) | ||
|
||
|
||
def fetch_onnx_inputs_outputs_name( | ||
model: nn.Module, | ||
onnx_inputs: list, | ||
torch_input_names: tuple, | ||
past_key_values: tuple, | ||
with_past: bool, | ||
input_with_past: bool, | ||
): | ||
"""fetch onnx inputs and outputs name""" | ||
num_of_past_key = 0 | ||
kv_cache_axis = {0: "batch_size"} | ||
# try get num_of_past_key and shape of past_key_value | ||
if past_key_values is not None: | ||
num_of_past_key = len(past_key_values) | ||
seq_index = (torch.tensor(past_key_values[0][0].shape) == onnx_inputs[0].shape[-1]).nonzero().view(-1) | ||
assert seq_index.numel() == 1 | ||
kv_cache_axis = {0: "batch_size", seq_index.item(): "seq_len"} | ||
|
||
if not num_of_past_key: | ||
num_of_past_key = model.config.num_hidden_layers | ||
|
||
onnx_inp_names = ("input_ids", "attention_mask") | ||
onnx_out_names = ("logits",) | ||
onnx_dynamic_axes = { | ||
"input_ids": {0: "batch_size", 1: "seq_len"}, | ||
"attention_mask": {0: "batch_size", 1: "seq_len"}, | ||
} | ||
if input_with_past: | ||
for i in range(num_of_past_key): | ||
onnx_inp_names += (f"present_key.{i}",) | ||
onnx_inp_names += (f"present_values.{i}",) | ||
|
||
onnx_dynamic_axes[onnx_inp_names[-1]] = kv_cache_axis | ||
onnx_dynamic_axes[onnx_inp_names[-2]] = kv_cache_axis | ||
|
||
if with_past or input_with_past: | ||
for i in range(num_of_past_key): | ||
onnx_out_names += (f"past_key.{i}",) | ||
onnx_out_names += (f"past_values.{i}",) | ||
onnx_dynamic_axes[onnx_out_names[-1]] = kv_cache_axis | ||
onnx_dynamic_axes[onnx_out_names[-2]] = kv_cache_axis | ||
|
||
for idx, name in enumerate(torch_input_names): | ||
if input_with_past: | ||
if name == "past_key_values": | ||
onnx_inputs[idx] = past_key_values | ||
elif name == "attention_mask": | ||
attn_mask = onnx_inputs[idx] | ||
onnx_inputs[idx] = torch.cat( | ||
(attn_mask, torch.ones((attn_mask.shape[0], 1), device=attn_mask.device)), dim=1 | ||
) | ||
elif name == "input_ids": | ||
input_ids = onnx_inputs[idx] | ||
onnx_inputs[idx] = input_ids[:, -1:] | ||
|
||
return onnx_inp_names, onnx_out_names, onnx_dynamic_axes | ||
|
||
|
||
def do_export_internal(model: nn.Module, onnx_io_tuple: tuple, onnx_inputs: tuple, onnx_path: Path, opset: int): | ||
"""do export with torch.onnx.export""" | ||
onnx_model_name = onnx_path.name | ||
onnx_inp_names, onnx_out_names, onnx_dynamic_axes = onnx_io_tuple | ||
# two step to export onnx | ||
# 1. export onnx with lots of pieces of weights | ||
# 2. save all weights to external data | ||
with tempfile.TemporaryDirectory() as tmpdirname: | ||
tmp_onnx = os.path.join(tmpdirname, "tmp.onnx") | ||
|
||
torch.onnx.export( | ||
model=model, | ||
args=tuple(onnx_inputs), | ||
f=tmp_onnx, | ||
verbose=False, | ||
opset_version=opset, | ||
input_names=onnx_inp_names, | ||
output_names=onnx_out_names, | ||
dynamic_axes=onnx_dynamic_axes, | ||
) | ||
|
||
onnx_path.unlink(missing_ok=True) | ||
(onnx_path.parent / f"{onnx_model_name}_ext.data").unlink(missing_ok=True) | ||
|
||
onnx_model = onnx.load(str(tmp_onnx)) | ||
onnx.save_model( | ||
onnx_model, | ||
str(onnx_path), | ||
save_as_external_data=(len(os.listdir(tmpdirname)) > 1), | ||
all_tensors_to_one_file=True, | ||
location=f"{onnx_model_name}_ext.data", | ||
size_threshold=1024, | ||
convert_attribute=False, | ||
) | ||
|
||
|
||
@torch.no_grad() | ||
def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, with_past: bool, opset: int): | ||
""" | ||
do export | ||
model: torch model | ||
onnx_path: where the onnx model saved to | ||
sample_inputs_tp: inputs for torch model | ||
""" | ||
model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model, cache_dir) | ||
|
||
model = move_to_approprate_device(model, sample_inputs_tp) | ||
|
||
sample_inputs = adapt_inputs_to_device(sample_inputs_tp, next(model.parameters()).device) | ||
|
||
# input_keys would be usesful if the model has some special inputs | ||
input_keys, onnx_inputs, past_key_value = retrieve_onnx_inputs(model, sample_inputs, with_past) | ||
|
||
onnx_io_tuple = fetch_onnx_inputs_outputs_name(model, onnx_inputs, input_keys, past_key_value, with_past, False) | ||
|
||
onnx_model_name = "model.onnx" | ||
onnx_path: Path = Path(onnx_path_str).absolute() | ||
if onnx_path.suffix != ".onnx": | ||
onnx_path = onnx_path / onnx_model_name | ||
|
||
do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path, opset) | ||
if not with_past: | ||
return | ||
|
||
onnx_io_tuple = fetch_onnx_inputs_outputs_name(model, onnx_inputs, input_keys, past_key_value, with_past, True) | ||
|
||
onnx_model_name = "model_with_past.onnx" | ||
onnx_path = onnx_path.parent / onnx_model_name | ||
|
||
do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path, opset) | ||
|
||
|
||
def parse_arguments(): | ||
"""arguments parsing.""" | ||
parser = argparse.ArgumentParser() | ||
|
||
parser.add_argument( | ||
"-m", | ||
"--model", | ||
required=True, | ||
type=str, | ||
default=["meta-llama/Llama-2-70b-hf"], | ||
help="Pre-trained models in huggingface model hub", | ||
) | ||
parser.add_argument( | ||
"-s", | ||
"--saved_path", | ||
required=False, | ||
type=str, | ||
default="./onnx_models/", | ||
help="where the onnx model will be saved", | ||
) | ||
parser.add_argument( | ||
"--cache_dir", | ||
required=False, | ||
type=str, | ||
default=None, | ||
help=("cache directy of huggingface, by setting this to avoid useless downloading if you have one"), | ||
) | ||
parser.add_argument( | ||
"--with_past", | ||
action="store_true", | ||
default=False, | ||
help=("The tool will export onnx without past-key-value by default"), | ||
) | ||
parser.add_argument( | ||
"--opset", | ||
required=False, | ||
type=int, | ||
default=17, | ||
help=( | ||
"the opset to save onnx model, \ | ||
try to increase it if this opset doens't have new features you want" | ||
), | ||
) | ||
return parser.parse_args() | ||
|
||
|
||
if __name__ == "__main__": | ||
args = parse_arguments() | ||
|
||
export_onnx(args.model, args.cache_dir, args.saved_path, args.with_past, args.opset) |