fix past-kv in general LLM exporter (#18529)
### Description

For some models, we need to re-run `model.forward` to get the past key-values (past-kv).
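As context, here is a minimal sketch of obtaining past-kv by calling `forward` with `use_cache=True`. This is not the exporter's own code; the small `gpt2` model and variable names are assumptions purely for illustration:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed toy model for illustration; the exporter targets arbitrary HF models.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tok = AutoTokenizer.from_pretrained("gpt2")
enc = tok("hello world", return_tensors="pt")

with torch.no_grad():
    # Passing use_cache=True explicitly ensures past_key_values is populated,
    # which is why the exporter re-runs forward when past-kv is requested.
    out = model(enc.input_ids, attention_mask=enc.attention_mask, use_cache=True)

# One (key, value) entry per decoder layer; the key tensor's shape includes
# the cached sequence length.
print(len(out.past_key_values), out.past_key_values[0][0].shape)
```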
wejoncy authored Nov 21, 2023
1 parent c7fd930 commit a608c00
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions onnxruntime/python/tools/transformers/large_model_exporter.py
@@ -157,14 +157,14 @@ def hook_for_inputs(_, inputs, kwargs):
     for idx, (key, value) in enumerate(zip(input_keys, onnx_inputs)):
         if type(value) is torch.Tensor:
             value.to(model.device)
-        # Didn't touch past_key_value now, please change it if you want
         if "use_cache" in key:
             onnx_inputs[idx] = with_past
+    out = model(sample_inputs[0], attention_mask=sample_inputs[1], use_cache=with_past) if with_past else out
 
     return input_keys, onnx_inputs, out.past_key_values
 
 
-def move_to_approprate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module:
+def move_to_appropriate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module:
     """
     According to the model size, we will upload it to
     CPU if has no GPU or enough GPU memory,
@@ -307,7 +307,7 @@ def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, wit
     """
     model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model, cache_dir)
 
-    model = move_to_approprate_device(model, sample_inputs_tp)
+    model = move_to_appropriate_device(model, sample_inputs_tp)
 
     sample_inputs = adapt_inputs_to_device(sample_inputs_tp, next(model.parameters()).device)

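A note on the design: the re-run added above is guarded by `with_past` (`... if with_past else out`), so the extra forward pass is only paid when past-kv inputs are actually being exported; for exports without past-kv, the output already captured through the input hook is reused unchanged.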
