Merge branch 'main' into more_urls
take-cheeze authored Nov 20, 2023
2 parents ab9b10c + 10f9a1f commit 881788b
Showing 3 changed files with 16 additions and 12 deletions.
16 changes: 6 additions & 10 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -2048,13 +2048,12 @@ def aten_convolution_overrideable(
 @torch_op("aten::copy")
 def aten_copy(
     self: TTensor,
-    src: TTensor,
+    src: TTensor2,
     non_blocking: bool = False,  # pylint: disable=unused-argument
 ) -> TTensor:
     """copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor"""
 
-    self = op.Identity(src)
-    return self
+    return op.CastLike(src, self)
 
 
 @torch_op("aten::_to_copy", trace_only=True)
@@ -4059,12 +4058,9 @@ def aten_index_put_bool(
     # change array([F,F,T,F,F]) to array([2])
     index = op.ArgMax(index_int)  # assume index only have 1 True
     # change array([2]) to array([2,2,2,2,2])
-    self_dim_1 = op.Gather(op.Shape(self), 1)
-    index_dim_0 = op.Gather(op.Shape(index), 0)
-    neg_1 = op.Constant(value_ints=[-1])
-    shape = op.Concat(
-        op.Reshape(self_dim_1, neg_1), op.Reshape(index_dim_0, neg_1), axis=0
-    )
+    self_dim_1 = op.Shape(self, start=1, end=2)
+    index_dim_0 = op.Shape(index, start=0, end=1)
+    shape = op.Concat(self_dim_1, index_dim_0, axis=0)
     new_ind = op.Expand(index, shape)
     new_ind_t = op.Transpose(new_ind)
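Both here and in _center_window_around_zeros_if_needed below, op.Gather(op.Shape(x), i) plus a Reshape is replaced by op.Shape(x, start=i, end=i+1), which already returns the selected dimension as a 1-D tensor. A numpy sketch of the equivalent slicing, with a hypothetical helper name:

import numpy as np

# Hypothetical illustration of ONNX Shape(x, start=i, end=j): the shape of x,
# sliced to [i:j] and returned as a 1-D int64 tensor, so it can feed
# Concat/Expand without an extra Gather or Reshape.
def shape_slice_sketch(x: np.ndarray, start: int, end: int) -> np.ndarray:
    return np.asarray(x.shape[start:end], dtype=np.int64)

self_ = np.zeros((5, 3))
index = np.zeros(4)
shape = np.concatenate([shape_slice_sketch(self_, 1, 2), shape_slice_sketch(index, 0, 1)])
print(shape)  # [3 4]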

@@ -7513,7 +7509,7 @@ def _center_window_around_zeros_if_needed(
     window: TFloatOrBFloat16, n_fft: int
 ) -> TFloatOrBFloat16:
     # first dimension
-    n_win = op.Gather(op.Shape(window), 0)
+    n_win = op.Shape(window, start=0, end=1)
     # Center window around zeros if needed (required by ONNX's STFT)
     if n_win < n_fft:
         left = (n_fft - n_win) / 2
7 changes: 5 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/nn.py
@@ -2038,10 +2038,13 @@ def aten_soft_margin_loss_backward(
     raise NotImplementedError()
 
 
-def aten_softplus(self: TensorType, beta: float = 1.0, threshold: float = 20.0) -> TensorType:
+@torch_op("aten::softplus")
+def aten_softplus(self: TFloat, beta: float = 1.0, threshold: float = 20.0) -> TFloat:
     """softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor"""
 
-    raise NotImplementedError()
+    self_scaled = self * beta
+    softplus = op.Softplus(self_scaled) / beta
+    return op.Where(self_scaled > threshold, self, softplus)
 
 
 def aten_softplus_backward(
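The new aten_softplus above mirrors torch.nn.functional.softplus: it computes log(1 + exp(beta * x)) / beta and falls back to the identity once beta * x exceeds threshold, which keeps the result stable for large inputs. A numpy sketch of the same formula, with a hypothetical helper name:

import numpy as np

# Hypothetical illustration of the formula above: scaled softplus with a
# threshold beyond which the function returns the input unchanged.
def softplus_sketch(x: np.ndarray, beta: float = 1.0, threshold: float = 20.0) -> np.ndarray:
    scaled = x * beta
    soft = np.log1p(np.exp(scaled)) / beta
    return np.where(scaled > threshold, x, soft)

x = np.array([-1.0, 0.0, 25.0], dtype=np.float32)
print(softplus_sketch(x))  # approximately [0.3133, 0.6931, 25.0]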
5 changes: 5 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -1345,6 +1345,11 @@ def _where_input_wrangler(
         reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
         test_class_name="TestOutputConsistencyFullGraph",
     ),
+    TorchLibOpInfo("nn.functional.softplus", nn_ops.aten_softplus).xfail(
+        dtypes=(torch.float16,),
+        reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16449",
+        test_class_name="TestOutputConsistencyEager",
+    ),
     TorchLibOpInfo(
         "split_with_sizes",
         core_ops.aten_split_with_sizes,