From d7f9a408f39c4318c37d3dfcd32759e8439cd38a Mon Sep 17 00:00:00 2001
From: JiCheng
Date: Thu, 19 Oct 2023 05:33:50 +0000
Subject: [PATCH] Fix large model exporter: also export a decoder variant that
 consumes past_key_values (KV cache)

---
 .../transformers/large_model_exporter.py      | 135 +++++++++++++-----
 1 file changed, 102 insertions(+), 33 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py
index 019e189783815..ceeb5d218e334 100644
--- a/onnxruntime/python/tools/transformers/large_model_exporter.py
+++ b/onnxruntime/python/tools/transformers/large_model_exporter.py
@@ -145,7 +145,7 @@ def hook_for_inputs(_, inputs, kwargs):
     forward_params = inspect.signature(model.forward).parameters
     input_keys = list(forward_params.keys())
     default_values = [forward_params.get(key).default for key in input_keys]
-    model(sample_inputs[0], attention_mask=sample_inputs[1])
+    out = model(sample_inputs[0], attention_mask=sample_inputs[1])
     hook_handle.remove()
     user_inputs = user_inputs[0]
     onnx_inputs = default_values
@@ -161,7 +161,7 @@ def hook_for_inputs(_, inputs, kwargs):
         if "use_cache" in key:
             onnx_inputs[idx] = with_past
 
-    return input_keys, tuple(onnx_inputs)
+    return input_keys, onnx_inputs, out.past_key_values
 
 
 def move_to_approprate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module:
@@ -192,7 +192,7 @@ def move_to_approprate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.M
     return model
 
 
-def adapt_inputs_to_device(sample_inputs: tuple, device: torch.Device) -> tuple:
+def adapt_inputs_to_device(sample_inputs: tuple, device: torch.device) -> tuple:
     """move inputs to device"""
     sample_inputs_ = []
     for sample_int in sample_inputs:
@@ -203,43 +203,76 @@ def adapt_inputs_to_device(sample_inputs: tuple, device: torch.Device) -> tuple:
     return tuple(sample_inputs_)
 
 
-@torch.no_grad()
-def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, with_past: bool, opset: int):
-    """
-    do export
-    model: torch model
-    onnx_path: where the onnx model saved to
-    sample_inputs_tp: inputs for torch model
-    """
-    model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model, cache_dir)
-
-    model = move_to_approprate_device(model, sample_inputs_tp)
-
-    sample_inputs = adapt_inputs_to_device(sample_inputs_tp, next(model.parameters()).device)
-
-    # input_keys would be usesful if the model has some special inputs
-    input_keys, onnx_inputs = retrieve_onnx_inputs(model, sample_inputs)
-
-    onnx_model_name = "model.onnx"
-    onnx_path: Path = Path(onnx_path_str).absolute()
-    if onnx_path.suffix != ".onnx":
-        onnx_path = onnx_path / onnx_model_name
-
+def fetch_onnx_inputs_outputs_name(
+    model: nn.Module,
+    onnx_inputs: list,
+    torch_input_names: tuple,
+    past_key_values: tuple,
+    with_past: bool,
+    input_with_past: bool,
+):
+    """fetch ONNX input and output names"""
+    num_of_past_key = 0
+    # try to get num_of_past_key and the shape of past_key_values
+    if past_key_values is not None:
+        num_of_past_key = len(past_key_values)
+        seq_index = (torch.tensor(past_key_values[0][0].shape) == onnx_inputs[0].shape[-1]).nonzero().view(-1)
+        assert seq_index.numel() == 1
+        kv_cache_axis = {0: "batch_size", seq_index.item(): "seq_len"}
+
+    if not num_of_past_key:
+        num_of_past_key = model.config.num_hidden_layers
+
+    onnx_inp_names = ("input_ids", "attention_mask")
+    onnx_out_names = ("logits",)
+    onnx_dynamic_axes = {
+        "input_ids": {0: "batch_size", 1: "seq_len"},
+        "attention_mask": {0: "batch_size", 1: "seq_len"},
+    }
+    if input_with_past:
+        for i in range(num_of_past_key):
+            onnx_inp_names += (f"present_key.{i}",)
(f"present_key.{i}",) + onnx_inp_names += (f"present_values.{i}",) + + onnx_dynamic_axes[onnx_inp_names[-1]] = kv_cache_axis + onnx_dynamic_axes[onnx_inp_names[-2]] = kv_cache_axis + + if with_past or input_with_past: + for i in range(num_of_past_key): + onnx_out_names += (f"past_key.{i}",) + onnx_out_names += (f"past_values.{i}",) + onnx_dynamic_axes[onnx_out_names[-1]] = kv_cache_axis + onnx_dynamic_axes[onnx_out_names[-2]] = kv_cache_axis + + for idx, name in enumerate(torch_input_names): + if input_with_past: + if name == "past_key_values": + onnx_inputs[idx] = past_key_values + elif name == "attention_mask": + attn_mask = onnx_inputs[idx] + onnx_inputs[idx] = torch.cat( + (attn_mask, torch.ones((attn_mask.shape[0], 1), device=attn_mask.device)), dim=1 + ) + elif name == "input_ids": + input_ids = onnx_inputs[idx] + onnx_inputs[idx] = input_ids[:, -1:] + + return onnx_inp_names, onnx_out_names, onnx_dynamic_axes + + +def do_export_internal(model: nn.Module, onnx_io_tuple: tuple, onnx_inputs: tuple, onnx_path: Path, opset: int): + """do export with torch.onnx.export""" + onnx_model_name = onnx_path.name + onnx_inp_names, onnx_out_names, onnx_dynamic_axes = onnx_io_tuple # two step to export onnx # 1. export onnx with lots of pieces of weights # 2. save all weights to external data with tempfile.TemporaryDirectory() as tmpdirname: tmp_onnx = os.path.join(tmpdirname, "tmp.onnx") - onnx_inp_names = ("input_ids", "attention_mask") - onnx_out_names = ("logits",) - onnx_dynamic_axes = { - "input_ids": {0: "batch_size", 1: "seq_len"}, - "attention_mask": {0: "batch_size", 1: "seq_len"}, - } torch.onnx.export( model=model, - args=onnx_inputs, + args=tuple(onnx_inputs), f=tmp_onnx, verbose=False, opset_version=opset, @@ -255,7 +288,7 @@ def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, wit onnx.save_model( onnx_model, str(onnx_path), - save_as_external_data=True, + save_as_external_data=(len(os.listdir(tmpdirname)) > 1), all_tensors_to_one_file=True, location=f"{onnx_model_name}_ext.data", size_threshold=1024, @@ -263,6 +296,42 @@ def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, wit ) +@torch.no_grad() +def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, with_past: bool, opset: int): + """ + do export + model: torch model + onnx_path: where the onnx model saved to + sample_inputs_tp: inputs for torch model + """ + model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model, cache_dir) + + model = move_to_approprate_device(model, sample_inputs_tp) + + sample_inputs = adapt_inputs_to_device(sample_inputs_tp, next(model.parameters()).device) + + # input_keys would be usesful if the model has some special inputs + input_keys, onnx_inputs, past_key_value = retrieve_onnx_inputs(model, sample_inputs, with_past) + + onnx_io_tuple = fetch_onnx_inputs_outputs_name(model, onnx_inputs, input_keys, past_key_value, with_past, False) + + onnx_model_name = "model.onnx" + onnx_path: Path = Path(onnx_path_str).absolute() + if onnx_path.suffix != ".onnx": + onnx_path = onnx_path / onnx_model_name + + do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path, opset) + if not with_past: + return + + onnx_io_tuple = fetch_onnx_inputs_outputs_name(model, onnx_inputs, input_keys, past_key_value, with_past, True) + + onnx_model_name = "model_with_past.onnx" + onnx_path = onnx_path.parent / onnx_model_name + + do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path, opset) + + def parse_arguments(): """arguments 
parsing.""" parser = argparse.ArgumentParser()