Commit d7f9a40

fix

wejoncy committed Oct 22, 2023
1 parent 0c0c6e3 commit d7f9a40
Showing 1 changed file with 102 additions and 33 deletions.
135 changes: 102 additions & 33 deletions onnxruntime/python/tools/transformers/large_model_exporter.py
@@ -145,7 +145,7 @@ def hook_for_inputs(_, inputs, kwargs):
forward_params = inspect.signature(model.forward).parameters
input_keys = list(forward_params.keys())
default_values = [forward_params.get(key).default for key in input_keys]
model(sample_inputs[0], attention_mask=sample_inputs[1])
out = model(sample_inputs[0], attention_mask=sample_inputs[1])
hook_handle.remove()
user_inputs = user_inputs[0]
onnx_inputs = default_values
@@ -161,7 +161,7 @@ def hook_for_inputs(_, inputs, kwargs):
if "use_cache" in key:
onnx_inputs[idx] = with_past

return input_keys, tuple(onnx_inputs)
return input_keys, onnx_inputs, out.past_key_values
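
A minimal, self-contained sketch of the hook-based capture this function relies on, and of why the call's result is now kept (its past_key_values seeds the with-past export). A toy module stands in for the HF model; all toy names are illustrative, not from this file:

import inspect
import torch
from torch import nn

class ToyModel(nn.Module):  # stand-in for the HuggingFace model
    def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False):
        return input_ids

captured = {}

def hook_for_inputs(_, inputs, kwargs):
    # with_kwargs=True delivers both positional and keyword arguments
    captured["args"], captured["kwargs"] = inputs, kwargs

model = ToyModel()
handle = model.register_forward_pre_hook(hook_for_inputs, with_kwargs=True)
model(torch.ones(1, 4, dtype=torch.long), attention_mask=torch.ones(1, 4))
handle.remove()

# Start from the signature defaults, then fill in what the caller supplied,
# yielding one positional slot per forward parameter (as onnx_inputs above).
params = inspect.signature(model.forward).parameters
onnx_inputs = [p.default for p in params.values()]
for i, value in enumerate(captured["args"]):
    onnx_inputs[i] = value
for key, value in captured["kwargs"].items():
    onnx_inputs[list(params).index(key)] = value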


def move_to_approprate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module:
@@ -192,7 +192,7 @@ def move_to_approprate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.M
return model
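
The body of move_to_approprate_device is collapsed in this view. A plausible minimal sketch of the contract its callers rely on, assuming a simple CUDA-if-available policy; the committed function may instead shard very large models across devices:

import torch
from torch import nn

def move_model_to_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module:
    # Assumption: single-device placement; the real function may do more.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    return model.to(device)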


def adapt_inputs_to_device(sample_inputs: tuple, device: torch.Device) -> tuple:
def adapt_inputs_to_device(sample_inputs: tuple, device: torch.device) -> tuple:
"""move inputs to device"""
sample_inputs_ = []
for sample_int in sample_inputs:
@@ -203,43 +203,76 @@ def adapt_inputs_to_device(sample_inputs: tuple, device: torch.Device) -> tuple:
return tuple(sample_inputs_)
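
The middle of adapt_inputs_to_device is collapsed above. A self-contained equivalent consistent with the visible lines; the isinstance guard is an assumption about the hidden part:

import torch

def adapt_inputs_to_device(sample_inputs: tuple, device: torch.device) -> tuple:
    """move inputs to device"""
    sample_inputs_ = []
    for sample_int in sample_inputs:
        if isinstance(sample_int, torch.Tensor):
            sample_inputs_.append(sample_int.to(device))
        else:
            sample_inputs_.append(sample_int)  # pass non-tensors through unchanged
    return tuple(sample_inputs_)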


@torch.no_grad()
def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, with_past: bool, opset: int):
"""
do export
model: torch model
onnx_path: where the ONNX model is saved to
sample_inputs_tp: inputs for the torch model
"""
model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model, cache_dir)

model = move_to_approprate_device(model, sample_inputs_tp)

sample_inputs = adapt_inputs_to_device(sample_inputs_tp, next(model.parameters()).device)

# input_keys would be useful if the model has some special inputs
input_keys, onnx_inputs = retrieve_onnx_inputs(model, sample_inputs)

onnx_model_name = "model.onnx"
onnx_path: Path = Path(onnx_path_str).absolute()
if onnx_path.suffix != ".onnx":
onnx_path = onnx_path / onnx_model_name

def fetch_onnx_inputs_outputs_name(
model: nn.Module,
onnx_inputs: list,
torch_input_names: tuple,
past_key_values: tuple,
with_past: bool,
input_with_past: bool,
):
"""fetch onnx inputs and outputs name"""
num_of_past_key = 0
# try to get num_of_past_key and the shape of past_key_value
if past_key_values is not None:
num_of_past_key = len(past_key_values)
seq_index = (torch.tensor(past_key_values[0][0].shape) == onnx_inputs[0].shape[-1]).nonzero().view(-1)
assert seq_index.numel() == 1
kv_cache_axis = {0: "batch_size", seq_index.item(): "seq_len"}

if not num_of_past_key:
num_of_past_key = model.config.num_hidden_layers

onnx_inp_names = ("input_ids", "attention_mask")
onnx_out_names = ("logits",)
onnx_dynamic_axes = {
"input_ids": {0: "batch_size", 1: "seq_len"},
"attention_mask": {0: "batch_size", 1: "seq_len"},
}
if input_with_past:
for i in range(num_of_past_key):
onnx_inp_names += (f"present_key.{i}",)
onnx_inp_names += (f"present_values.{i}",)

onnx_dynamic_axes[onnx_inp_names[-1]] = kv_cache_axis
onnx_dynamic_axes[onnx_inp_names[-2]] = kv_cache_axis

if with_past or input_with_past:
for i in range(num_of_past_key):
onnx_out_names += (f"past_key.{i}",)
onnx_out_names += (f"past_values.{i}",)
onnx_dynamic_axes[onnx_out_names[-1]] = kv_cache_axis
onnx_dynamic_axes[onnx_out_names[-2]] = kv_cache_axis

for idx, name in enumerate(torch_input_names):
if input_with_past:
if name == "past_key_values":
onnx_inputs[idx] = past_key_values
elif name == "attention_mask":
attn_mask = onnx_inputs[idx]
onnx_inputs[idx] = torch.cat(
(attn_mask, torch.ones((attn_mask.shape[0], 1), device=attn_mask.device)), dim=1
)
elif name == "input_ids":
input_ids = onnx_inputs[idx]
onnx_inputs[idx] = input_ids[:, -1:]

return onnx_inp_names, onnx_out_names, onnx_dynamic_axes
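
A worked example of what the function above infers for a hypothetical 2-layer decoder; shapes and the layer count are illustrative only. It also shows the decode-step rewrite applied when input_with_past is set: input_ids shrinks to the last token while attention_mask gains one column.

import torch

# KV-cache axis inference: find which axis of past_key matches the sequence length.
input_ids = torch.ones(1, 8, dtype=torch.long)          # (batch, seq_len)
past_key = torch.zeros(1, 32, 8, 128)                   # (batch, heads, seq_len, head_dim), illustrative
seq_index = (torch.tensor(past_key.shape) == input_ids.shape[-1]).nonzero().view(-1)
assert seq_index.item() == 2                            # kv_cache_axis = {0: "batch_size", 2: "seq_len"}

# For 2 layers: the first export maps ("input_ids", "attention_mask") ->
# ("logits", "past_key.0", "past_values.0", "past_key.1", "past_values.1");
# the with-past export additionally takes ("present_key.0", "present_values.0", ...) as inputs.

# Decode-step inputs when input_with_past=True:
attention_mask = torch.ones(1, 8, dtype=torch.long)
decode_ids = input_ids[:, -1:]                          # (1, 1): only the newest token
decode_mask = torch.cat((attention_mask, torch.ones(1, 1, dtype=torch.long)), dim=1)  # (1, 9)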


def do_export_internal(model: nn.Module, onnx_io_tuple: tuple, onnx_inputs: tuple, onnx_path: Path, opset: int):
"""do export with torch.onnx.export"""
onnx_model_name = onnx_path.name
onnx_inp_names, onnx_out_names, onnx_dynamic_axes = onnx_io_tuple
# two steps to export onnx
# 1. export onnx, letting torch scatter the weights into many external pieces
# 2. reload and save all weights into a single external data file
with tempfile.TemporaryDirectory() as tmpdirname:
tmp_onnx = os.path.join(tmpdirname, "tmp.onnx")

onnx_inp_names = ("input_ids", "attention_mask")
onnx_out_names = ("logits",)
onnx_dynamic_axes = {
"input_ids": {0: "batch_size", 1: "seq_len"},
"attention_mask": {0: "batch_size", 1: "seq_len"},
}
torch.onnx.export(
model=model,
args=onnx_inputs,
args=tuple(onnx_inputs),
f=tmp_onnx,
verbose=False,
opset_version=opset,
@@ -255,14 +288,50 @@ def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, wit
onnx.save_model(
onnx_model,
str(onnx_path),
save_as_external_data=True,
save_as_external_data=(len(os.listdir(tmpdirname)) > 1),
all_tensors_to_one_file=True,
location=f"{onnx_model_name}_ext.data",
size_threshold=1024,
convert_attribute=False,
)
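
Why two steps: torch.onnx.export on a >2 GB model spills weights into many files beside the temporary graph; reloading and re-saving consolidates them. A runnable miniature with a toy model; file names are illustrative, and the conditional save_as_external_data mirrors the change in this commit:

import os
import tempfile

import onnx
import torch
from torch import nn

model = nn.Linear(4, 4)          # toy stand-in; real models are the >2 GB case
dummy = torch.randn(1, 4)

with tempfile.TemporaryDirectory() as tmpdirname:
    tmp_onnx = os.path.join(tmpdirname, "tmp.onnx")
    # Step 1: export; large models leave extra weight files in tmpdirname.
    torch.onnx.export(model, (dummy,), tmp_onnx, opset_version=17)
    # Step 2: reload with external data resolved, then save to one .data file.
    onnx_model = onnx.load(tmp_onnx, load_external_data=True)
    onnx.save_model(
        onnx_model,
        "model.onnx",
        # Only externalize when step 1 actually spilled weights.
        save_as_external_data=(len(os.listdir(tmpdirname)) > 1),
        all_tensors_to_one_file=True,
        location="model.onnx_ext.data",
        size_threshold=1024,
        convert_attribute=False,
    )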


@torch.no_grad()
def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, with_past: bool, opset: int):
"""
Export a HuggingFace model to ONNX.
hf_model: the model name or path for huggingface
cache_dir: where the pretrained model is cached
onnx_path_str: where the ONNX model is saved to
with_past: whether to also export a decode-step graph with KV-cache inputs
opset: ONNX opset version
"""
model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model, cache_dir)

model = move_to_approprate_device(model, sample_inputs_tp)

sample_inputs = adapt_inputs_to_device(sample_inputs_tp, next(model.parameters()).device)

# input_keys would be useful if the model has some special inputs
input_keys, onnx_inputs, past_key_value = retrieve_onnx_inputs(model, sample_inputs, with_past)

onnx_io_tuple = fetch_onnx_inputs_outputs_name(model, onnx_inputs, input_keys, past_key_value, with_past, False)

onnx_model_name = "model.onnx"
onnx_path: Path = Path(onnx_path_str).absolute()
if onnx_path.suffix != ".onnx":
onnx_path = onnx_path / onnx_model_name

do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path, opset)
if not with_past:
return

onnx_io_tuple = fetch_onnx_inputs_outputs_name(model, onnx_inputs, input_keys, past_key_value, with_past, True)

onnx_model_name = "model_with_past.onnx"
onnx_path = onnx_path.parent / onnx_model_name

do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path, opset)
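
End to end, the reworked flow writes model.onnx and, when with_past is set, model_with_past.onnx next to it. A hedged usage sketch; the model id and opset are illustrative:

export_onnx(
    hf_model="meta-llama/Llama-2-7b-hf",  # illustrative model id
    cache_dir=None,
    onnx_path_str="./exported",           # a directory: file names are chosen inside
    with_past=True,
    opset=17,
)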


def parse_arguments():
"""arguments parsing."""
parser = argparse.ArgumentParser()