Implement aten::normal overloads | feat(torchlib) (#1094)
Implement

- aten::normal.Tensor_Tensor
- aten::normal.float_Tensor
- aten::normal.Tensor_float

Created custom test cases for these overloads.
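For context, the three ATen overloads differ only in which of mean/std is a tensor. A quick illustration (not part of this commit) of how they are invoked from PyTorch:

import torch

mean_t = torch.zeros(3, 4)
std_t = torch.ones(3, 4)

torch.ops.aten.normal.Tensor_Tensor(mean_t, std_t)  # Tensor mean, Tensor std
torch.ops.aten.normal.float_Tensor(0.0, std_t)      # float mean, Tensor std
torch.ops.aten.normal.Tensor_float(mean_t, 1.0)     # Tensor mean, float std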
justinchuby authored Oct 17, 2023
1 parent 75e61a1 commit 87029cf
Showing 3 changed files with 125 additions and 3 deletions.
32 changes: 30 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -5832,13 +5832,13 @@ def aten_norm_except_dim(v: TensorType, pow: int = 2, dim: int = 0) -> TensorTyp
    raise NotImplementedError()


@torch_op("aten::normal")
@torch_op(("aten::normal", "aten::normal_functional"))
def aten_normal(
self: TTensor,
mean: float = 0.0,
std: float = 1.0,
) -> TFloat: # type: ignore[type-var]
# normal_functional(Tensor self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor
"""normal_functional(Tensor self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor"""

self_rank = op.Size(op.Shape(self))
if self_rank == 0:
@@ -5860,6 +5860,34 @@ def aten_normal_float_float(
    return op.Cast(result, to=dtype)


@torch_op("aten::normal.float_Tensor")
def aten_normal_float_tensor(mean: FLOAT, std: TFloat) -> TFloat:
"""normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor"""

mean_casted = op.CastLike(mean, std)
sampled = op.RandomNormalLike(mean_casted, mean=0.0, scale=1.0)
# Transform the distribution to the mean and std
return op.Add(op.Mul(mean_casted, sampled), std)


@torch_op("aten::normal.Tensor_float")
def aten_normal_tensor_float(mean: TFloat, std: FLOAT) -> TFloat:
"""normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor"""

sampled = op.RandomNormalLike(mean, mean=0.0, scale=1.0)
# Transform the distribution to the mean and std
return op.Add(op.Mul(mean, sampled), op.CastLike(std, sampled))


@torch_op("aten::normal.Tensor_Tensor")
def aten_normal_tensor_tensor(mean: TFloat, std: TFloat) -> TFloat:
"""normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor"""

sampled = op.RandomNormalLike(mean, mean=0.0, scale=1.0)
# Transform the distribution to the mean and std
return op.Add(op.Mul(mean, sampled), std)
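
Note: ONNX RandomNormalLike accepts mean and scale only as static float attributes, so all three overloads sample from N(0, 1) and apply the shift/scale at runtime with Mul/Add. A minimal NumPy sketch (not part of the diff) of that reparameterization:

import numpy as np

rng = np.random.default_rng(0)
z = rng.standard_normal(1_000_000)  # z ~ N(0, 1)
x = 2.0 + 3.0 * z                   # x ~ N(2, 3**2)
print(x.mean(), x.std())            # approximately 2.0 and 3.0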


def aten_not_equal(self: TensorType, other: TensorType) -> TensorType:
    """not_equal.Tensor(Tensor self, Tensor other) -> Tensor"""

66 changes: 66 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -516,6 +516,51 @@ def sample_inputs_native_dropout(
        yield opinfo_core.SampleInput(make_arg(case), p=p, train=training)


def sample_inputs_normal_tensor_float(op_info, device, dtype, requires_grad, **kwargs):
    del op_info
    del requires_grad
    del kwargs
    make_arg = functools.partial(
        torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=False
    )
    samples = (
        ((S, S), 0.0),
        ((S, S, S), 4.2),
    )
    for mean, std in samples:
        yield opinfo_core.SampleInput(make_arg(mean), std)


def sample_inputs_normal_float_tensor(op_info, device, dtype, requires_grad, **kwargs):
    del op_info
    del requires_grad
    del kwargs
    make_arg = functools.partial(
        torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=False
    )
    samples = (
        (4.2, (S, S)),
        (-2.0, (S, S, S)),
    )
    for mean, std in samples:
        yield opinfo_core.SampleInput(mean, make_arg(std, low=0.0))


def sample_inputs_normal_tensor_tensor(op_info, device, dtype, requires_grad, **kwargs):
    del op_info
    del requires_grad
    del kwargs
    make_arg = functools.partial(
        torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=False
    )
    samples = (
        ((S, S), (S, S)),
        ((S, S, S), (S, S, S)),
    )
    for mean, std in samples:
        yield opinfo_core.SampleInput(make_arg(mean), make_arg(std, low=0.0))
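
In the two generators that build a std tensor, low=0.0 keeps the generated values nonnegative, since a negative standard deviation is invalid. A hypothetical illustration (assuming S == 5, as in the torch test utilities):

import torch
from torch.testing import make_tensor

std = make_tensor((5, 5), dtype=torch.float32, device="cpu", low=0.0)
assert (std >= 0).all()  # std must be nonnegative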


def sample_inputs_rand(op_info, device, dtype, requires_grad, **kwargs):
    del op_info  # Unused
    del device  # Unused
@@ -1262,6 +1307,27 @@ def sample_inputs_scaled_dot_product_flash_attention(
        sample_inputs_func=sample_inputs_native_group_norm,
        supports_out=False,
    ),
    opinfo_core.OpInfo(
        "ops.aten.normal.float_Tensor",
        aten_name="normal.float_Tensor",
        dtypes=common_dtype.floating_types_and_half(),
        sample_inputs_func=sample_inputs_normal_float_tensor,
        supports_out=False,
    ),
    opinfo_core.OpInfo(
        "ops.aten.normal.Tensor_float",
        aten_name="normal.Tensor_float",
        dtypes=common_dtype.floating_types_and_half(),
        sample_inputs_func=sample_inputs_normal_tensor_float,
        supports_out=False,
    ),
    opinfo_core.OpInfo(
        "ops.aten.normal.Tensor_Tensor",
        aten_name="normal.Tensor_Tensor",
        dtypes=common_dtype.floating_types_and_half(),
        sample_inputs_func=sample_inputs_normal_tensor_tensor,
        supports_out=False,
    ),
    opinfo_core.OpInfo(
        "nn.functional.max_pool1d_with_indices",
        aten_name="max_pool1d_with_indices",
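These OpInfo entries wire the generators above into the test machinery. A hedged sketch (assuming the list above is exposed as extra_opinfo.OP_DB) of how one entry's samples are iterated:

import torch
from onnxscript.tests.function_libs.torch_lib import extra_opinfo

op = next(o for o in extra_opinfo.OP_DB if o.name == "ops.aten.normal.Tensor_Tensor")
for sample in op.sample_inputs("cpu", torch.float32, requires_grad=False):
    print(sample.input.shape, sample.args[0].shape)  # shapes of mean and std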
30 changes: 29 additions & 1 deletion onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -836,7 +836,8 @@ def _where_input_wrangler(
    TorchLibOpInfo(
        "matmul",
        core_ops.aten_matmul,
-        tolerance={torch.float32: (2e-5, 2e-5)},  # Windows requires a more relaxed tolerance
+        # Windows requires a more relaxed tolerance
+        tolerance={torch.float32: (2e-5, 2e-5)},
    ).skip(
        matcher=lambda sample: torch.numel(sample.input) == 0,
        reason="values of matmul of [m, 0] and [0, n] matrices are undefined",
@@ -1154,6 +1155,33 @@ def _where_input_wrangler(
reason="This variant does not support dtype as an argument",
matcher=lambda sample: sample.kwargs.get("dtype") is not None,
),
    TorchLibOpInfo(
        "ops.aten.normal.float_Tensor",
        core_ops.aten_normal_float_tensor,
        nondeterministic=True,
    ).xfail(
        reason="ORT fails on a cast node it inserts for float16. https://github.com/microsoft/onnxruntime/issues/16449",
        dtypes=(torch.float16,),
        test_class_name="TestOutputConsistencyEager",
    ),
    TorchLibOpInfo(
        "ops.aten.normal.Tensor_float",
        core_ops.aten_normal_tensor_float,
        nondeterministic=True,
    ).xfail(
        reason="ORT fails on a cast node it inserts for float16. https://github.com/microsoft/onnxruntime/issues/16449",
        dtypes=(torch.float16,),
        test_class_name="TestOutputConsistencyEager",
    ),
    TorchLibOpInfo(
        "ops.aten.normal.Tensor_Tensor",
        core_ops.aten_normal_tensor_tensor,
        nondeterministic=True,
    ).xfail(
        reason="ORT fails on a cast node it inserts for float16. https://github.com/microsoft/onnxruntime/issues/16449",
        dtypes=(torch.float16,),
        test_class_name="TestOutputConsistencyEager",
    ),
TorchLibOpInfo("ones", core_ops.aten_ones),
TorchLibOpInfo(
"permute",
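The nondeterministic=True flag tells the test harness that outputs cannot be compared elementwise against PyTorch. A hedged sketch (not the harness's actual logic) of how such a sampler can instead be checked statistically:

import torch

mean = torch.full((1000, 1000), 3.0)
std = torch.full((1000, 1000), 0.5)
out = torch.ops.aten.normal.Tensor_Tensor(mean, std)
assert abs(out.mean().item() - 3.0) < 0.01  # sample mean near 3.0
assert abs(out.std().item() - 0.5) < 0.01   # sample std near 0.5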
