Skip to content

Commit

Permalink
Fix constant folding methods
Browse files Browse the repository at this point in the history
  • Loading branch information
gramalingam committed Jul 30, 2024
1 parent 18f45d8 commit f5afe44
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValu
unsqueezed_inputs = []
for node_input in inputs:
unsqueezed_input = op.Unsqueeze(
node_input, axis_value, outputs=[f"{node_input.name}_unsqueeze"]
node_input, axis_value, _outputs=[f"{node_input.name}_unsqueeze"]
)
unsqueezed_inputs.append(unsqueezed_input)
# Send unsqueezed outputs to Concat
Expand Down Expand Up @@ -427,13 +427,13 @@ def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
num_outputs = math.ceil(split_dimension_size / split_value.item())
split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)]
split_values = op.Split(
input, axis=axis, num_outputs=num_outputs, outputs=split_outputs
input, axis=axis, num_outputs=num_outputs, _outputs=split_outputs
)
elif split_value.ndim == 1:
# split into 'size(split)' chunks
num_outputs = split_value.size
split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)]
split_values = op.Split(input, split, axis=axis, outputs=split_outputs)
split_values = op.Split(input, split, axis=axis, _outputs=split_outputs)
else:
return None

Expand All @@ -442,11 +442,11 @@ def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
return None
if keepdims == 0:
# squeeze the split dimension if keepdims is 0
axis_val = op.Constant(value_int=axis, outputs=[f"{output.name}_axis"])
axis_val = op.Constant(value_int=axis, _outputs=[f"{output.name}_axis"])
squeezed_values = []
for i in range(num_outputs):
squeezed = op.Squeeze(
split_values[i], axis_val, outputs=[f"{split_outputs[i]}_squeeze"]
split_values[i], axis_val, _outputs=[f"{split_outputs[i]}_squeeze"]
)
squeezed_values.append(squeezed)
split_values = squeezed_values
Expand Down

0 comments on commit f5afe44

Please sign in to comment.