diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py
index 3b344d6dc9342..407c3b80e153f 100644
--- a/onnxruntime/python/tools/transformers/large_model_exporter.py
+++ b/onnxruntime/python/tools/transformers/large_model_exporter.py
@@ -157,14 +157,14 @@ def hook_for_inputs(_, inputs, kwargs):
     for idx, (key, value) in enumerate(zip(input_keys, onnx_inputs)):
         if type(value) is torch.Tensor:
             value.to(model.device)
-        # Didn't touch past_key_value now, please change it if you want
         if "use_cache" in key:
            onnx_inputs[idx] = with_past
 
+    out = model(sample_inputs[0], attention_mask=sample_inputs[1], use_cache=with_past) if with_past else out
     return input_keys, onnx_inputs, out.past_key_values
 
 
-def move_to_approprate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module:
+def move_to_appropriate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module:
     """
     According to the model size, we will upload it to
     CPU if has no GPU or enough GPU memory,
@@ -307,7 +307,7 @@ def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, wit
     """
     model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model, cache_dir)
 
-    model = move_to_approprate_device(model, sample_inputs_tp)
+    model = move_to_appropriate_device(model, sample_inputs_tp)
 
     sample_inputs = adapt_inputs_to_device(sample_inputs_tp, next(model.parameters()).device)
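
Note on the added `use_cache` call: the function returns `out.past_key_values`, but the original forward pass never requests the KV cache, so when exporting with past the extra `model(..., use_cache=with_past)` call is what populates it. The snippet below is a minimal standalone sketch of that behavior, assuming a generic Hugging Face causal LM; the `gpt2` model name and the sample prompt are illustrative and not taken from this patch.

# Minimal sketch (illustrative, not part of the patch): past_key_values is only
# populated when the forward pass is made with use_cache=True.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # assumption: any decoder-only HF model behaves the same way
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

enc = tokenizer("hello world", return_tensors="pt")

with torch.no_grad():
    out_no_cache = model(**enc, use_cache=False)
    out_with_cache = model(**enc, use_cache=True)

# Without use_cache there is nothing usable as past-KV example inputs ...
print(out_no_cache.past_key_values)           # None
# ... with use_cache there is one (key, value) pair per decoder layer,
# which an exporter can feed back in as example past-KV inputs.
print(len(out_with_cache.past_key_values))    # e.g. 12 for gpt2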