From f3b3e3f3d2c63e8c442c1960221b4203411517de Mon Sep 17 00:00:00 2001 From: wejoncy Date: Tue, 21 Nov 2023 11:50:54 +0800 Subject: [PATCH 1/3] fix past-kv in general LLM exporter --- onnxruntime/python/tools/transformers/large_model_exporter.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py index 3b344d6dc9342..44b0bd83d49af 100644 --- a/onnxruntime/python/tools/transformers/large_model_exporter.py +++ b/onnxruntime/python/tools/transformers/large_model_exporter.py @@ -157,9 +157,10 @@ def hook_for_inputs(_, inputs, kwargs): for idx, (key, value) in enumerate(zip(input_keys, onnx_inputs)): if type(value) is torch.Tensor: value.to(model.device) - # Didn't touch past_key_value now, please change it if you want if "use_cache" in key: onnx_inputs[idx] = with_past + out = model(sample_inputs[0], attention_mask=sample_inputs[1], + use_cache=with_past) if with_past else out return input_keys, onnx_inputs, out.past_key_values From 80e2275f5e98dab33e4bad891c11426fae6a12a4 Mon Sep 17 00:00:00 2001 From: JiCheng Date: Tue, 21 Nov 2023 04:40:02 +0000 Subject: [PATCH 2/3] format --- onnxruntime/python/tools/transformers/large_model_exporter.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py index 44b0bd83d49af..988a90fbe16e9 100644 --- a/onnxruntime/python/tools/transformers/large_model_exporter.py +++ b/onnxruntime/python/tools/transformers/large_model_exporter.py @@ -159,8 +159,7 @@ def hook_for_inputs(_, inputs, kwargs): value.to(model.device) if "use_cache" in key: onnx_inputs[idx] = with_past - out = model(sample_inputs[0], attention_mask=sample_inputs[1], - use_cache=with_past) if with_past else out + out = model(sample_inputs[0], attention_mask=sample_inputs[1], use_cache=with_past) if with_past else out return input_keys, onnx_inputs, out.past_key_values From a07b4ea1699459dee3412e2340bf1e1335370a6c Mon Sep 17 00:00:00 2001 From: JiCheng Date: Tue, 21 Nov 2023 04:41:38 +0000 Subject: [PATCH 3/3] fix typo --- onnxruntime/python/tools/transformers/large_model_exporter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py index 988a90fbe16e9..407c3b80e153f 100644 --- a/onnxruntime/python/tools/transformers/large_model_exporter.py +++ b/onnxruntime/python/tools/transformers/large_model_exporter.py @@ -164,7 +164,7 @@ def hook_for_inputs(_, inputs, kwargs): return input_keys, onnx_inputs, out.past_key_values -def move_to_approprate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module: +def move_to_appropriate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module: """ According to the model size, we will upload it to CPU if has no GPU or enough GPU memory, @@ -307,7 +307,7 @@ def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, wit """ model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model, cache_dir) - model = move_to_approprate_device(model, sample_inputs_tp) + model = move_to_appropriate_device(model, sample_inputs_tp) sample_inputs = adapt_inputs_to_device(sample_inputs_tp, next(model.parameters()).device)