From 10f9a1fd91f0358bcc2a61e2c3cf6639641be419 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 16 Nov 2023 21:45:54 -0800 Subject: [PATCH] Clean up Shape calls | chore(torchlib) (#1163) Update calls to `Shape` to use the `start` and `end` arguments to simplify the graph and avoid `Gather` nodes. --- onnxscript/function_libs/torch_lib/ops/core.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 8d61be90b..cf2fae3db 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4058,12 +4058,9 @@ def aten_index_put_bool( # change array([F,F,T,F,F]) to array([2]) index = op.ArgMax(index_int) # assume index only have 1 True # change array([2]) to array([2,2,2,2,2]) - self_dim_1 = op.Gather(op.Shape(self), 1) - index_dim_0 = op.Gather(op.Shape(index), 0) - neg_1 = op.Constant(value_ints=[-1]) - shape = op.Concat( - op.Reshape(self_dim_1, neg_1), op.Reshape(index_dim_0, neg_1), axis=0 - ) + self_dim_1 = op.Shape(self, start=1, end=2) + index_dim_0 = op.Shape(index, start=0, end=1) + shape = op.Concat(self_dim_1, index_dim_0, axis=0) new_ind = op.Expand(index, shape) new_ind_t = op.Transpose(new_ind) @@ -7512,7 +7509,7 @@ def _center_window_around_zeros_if_needed( window: TFloatOrBFloat16, n_fft: int ) -> TFloatOrBFloat16: # first dimension - n_win = op.Gather(op.Shape(window), 0) + n_win = op.Shape(window, start=0, end=1) # Center window around zeros if needed (required by ONNX's STFT) if n_win < n_fft: left = (n_fft - n_win) / 2