Skip to content

Commit

Permalink
Merge branch 'main' into justinchu/fix-varmean
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Nov 17, 2023
2 parents 9790f09 + 10f9a1f commit 8b93e81
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2048,13 +2048,12 @@ def aten_convolution_overrideable(
@torch_op("aten::copy")
def aten_copy(
self: TTensor,
src: TTensor,
src: TTensor2,
non_blocking: bool = False, # pylint: disable=unused-argument
) -> TTensor:
"""copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor"""

self = op.Identity(src)
return self
return op.CastLike(src, self)


@torch_op("aten::_to_copy", trace_only=True)
Expand Down Expand Up @@ -4059,12 +4058,9 @@ def aten_index_put_bool(
# change array([F,F,T,F,F]) to array([2])
index = op.ArgMax(index_int) # assume index only have 1 True
# change array([2]) to array([2,2,2,2,2])
self_dim_1 = op.Gather(op.Shape(self), 1)
index_dim_0 = op.Gather(op.Shape(index), 0)
neg_1 = op.Constant(value_ints=[-1])
shape = op.Concat(
op.Reshape(self_dim_1, neg_1), op.Reshape(index_dim_0, neg_1), axis=0
)
self_dim_1 = op.Shape(self, start=1, end=2)
index_dim_0 = op.Shape(index, start=0, end=1)
shape = op.Concat(self_dim_1, index_dim_0, axis=0)
new_ind = op.Expand(index, shape)
new_ind_t = op.Transpose(new_ind)

Expand Down Expand Up @@ -7513,7 +7509,7 @@ def _center_window_around_zeros_if_needed(
window: TFloatOrBFloat16, n_fft: int
) -> TFloatOrBFloat16:
# first dimension
n_win = op.Gather(op.Shape(window), 0)
n_win = op.Shape(window, start=0, end=1)
# Center window around zeros if needed (required by ONNX's STFT)
if n_win < n_fft:
left = (n_fft - n_win) / 2
Expand Down

0 comments on commit 8b93e81

Please sign in to comment.