Skip to content

Commit

Permalink
Clean up Shape calls | chore(torchlib) (#1163)
Browse files Browse the repository at this point in the history
Update calls to `Shape` to use the `start` and `end` arguments to
simplify the graph and avoid `Gather` nodes.
  • Loading branch information
justinchuby authored Nov 17, 2023
1 parent 804ed01 commit 10f9a1f
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4058,12 +4058,9 @@ def aten_index_put_bool(
# change array([F,F,T,F,F]) to array([2])
index = op.ArgMax(index_int) # assume index only have 1 True
# change array([2]) to array([2,2,2,2,2])
self_dim_1 = op.Gather(op.Shape(self), 1)
index_dim_0 = op.Gather(op.Shape(index), 0)
neg_1 = op.Constant(value_ints=[-1])
shape = op.Concat(
op.Reshape(self_dim_1, neg_1), op.Reshape(index_dim_0, neg_1), axis=0
)
self_dim_1 = op.Shape(self, start=1, end=2)
index_dim_0 = op.Shape(index, start=0, end=1)
shape = op.Concat(self_dim_1, index_dim_0, axis=0)
new_ind = op.Expand(index, shape)
new_ind_t = op.Transpose(new_ind)

Expand Down Expand Up @@ -7512,7 +7509,7 @@ def _center_window_around_zeros_if_needed(
window: TFloatOrBFloat16, n_fft: int
) -> TFloatOrBFloat16:
# first dimension
n_win = op.Gather(op.Shape(window), 0)
n_win = op.Shape(window, start=0, end=1)
# Center window around zeros if needed (required by ONNX's STFT)
if n_win < n_fft:
left = (n_fft - n_win) / 2
Expand Down

0 comments on commit 10f9a1f

Please sign in to comment.