Skip to content

Commit

Permalink
perf improve in python code
Browse files Browse the repository at this point in the history
get buffer_names_dict only when it's used
  • Loading branch information
zhijxu-MS committed Nov 6, 2023
1 parent dfafcb5 commit de3c24b
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion orttraining/orttraining/python/training/ortmodule/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def _expand_inputs(current_input, non_none_inputs, name=""):
_expand_inputs(inputs, non_none_inputs)
flattened_kwargs_inputs = {}
_expand_inputs(kwargs, flattened_kwargs_inputs)
buffer_names_dict = {buffer_name: inp for buffer_name, inp in named_buffer}
buffer_names_dict = None
result = []
embed_sparsity_results = OrderedDict()
label_sparsity_results = OrderedDict()
Expand All @@ -232,6 +232,8 @@ def _expand_inputs(current_input, non_none_inputs, name=""):

if inp is None:
# Registered buffers are translated to user_input+initializer in ONNX
if buffer_names_dict is None:
buffer_names_dict = {buffer_name: inp for buffer_name, inp in named_buffer}
try: # noqa: SIM105
inp = buffer_names_dict[name]
except KeyError:
Expand Down

0 comments on commit de3c24b

Please sign in to comment.