Implement rand* ops | feat(torchlib) #1035

Merged: 8 commits, merged on Oct 11, 2023
Changes from 3 commits
129 changes: 110 additions & 19 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -6232,47 +6232,138 @@


@torch_op("aten::rand")
-def aten_rand(size: Sequence[int], dtype: int = 1) -> TReal:
def aten_rand(size: INT64, dtype: int = FLOAT.dtype) -> TReal:
    """rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

-    return op.RandomUniform(shape=size, dtype=dtype)
    shaper = op.ConstantOfShape(size)
    return op.RandomUniformLike(shaper, dtype=dtype)
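
Aside: the switch from RandomUniform to ConstantOfShape + RandomUniformLike is needed because ONNX RandomUniform only accepts its output shape as a static attribute, while `size` here arrives as a dynamic INT64 tensor, so a dummy tensor of the requested shape is materialized and sampled "like". A rough numpy analogy (an illustrative sketch only, not part of the PR):

import numpy as np

size = np.array([2, 3], dtype=np.int64)  # dynamic shape input, as in aten_rand above
shaper = np.zeros(tuple(size))  # ConstantOfShape analogue: dummy tensor of that shape
rand = np.random.uniform(size=shaper.shape).astype(np.float32)  # RandomUniformLike analogue
assert rand.shape == (2, 3)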



-def aten_rand_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
@torch_op("aten::rand_like")
def aten_rand_like(self: TFloat) -> TFloat:
    """rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""

-    raise NotImplementedError()
    return op.RandomUniformLike(self)



@torch_op("aten::rand_like")
def aten_rand_like_dtype(self: TensorType, dtype: int) -> TensorType:
    """rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""

    return op.RandomUniformLike(self, dtype=dtype)


-def aten_randint(high: int, size: INT64) -> TensorType:
-    """randint(int high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
-    raise NotImplementedError()


@torch_op("aten::randint")
def aten_randint(high: INT64, size: INT64, dtype: int = INT64.dtype) -> TensorType:
    """randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

    shaper = op.ConstantOfShape(size)
    rand = op.RandomUniformLike(shaper)

    # Scale to [0, high] first
    rand_scaled = op.Mul(rand, op.CastLike(high, rand))

    # Round to ints
    rand_int = op.Floor(rand_scaled)
    return op.Cast(rand_int, to=dtype)
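
Aside: the scale-and-floor construction above maps uniform samples from [0, 1) to integers in [0, high), which matches torch.randint semantics. A quick numpy check of the math (an illustrative sketch only, not part of the PR):

import numpy as np

high = 10
rand = np.random.uniform(size=(1000,))  # uniform in [0, 1), like RandomUniformLike
rand_scaled = rand * high  # uniform in [0, high)
rand_int = np.floor(rand_scaled).astype(np.int64)  # integers in [0, high)
assert rand_int.min() >= 0 and rand_int.max() < high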



@torch_op("aten::randint.low")
def aten_randint_low(
    low: INT64, high: INT64, size: INT64, dtype: int = INT64.dtype
) -> TensorType:
    """randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

    shaper = op.ConstantOfShape(size)
    rand = op.RandomUniformLike(shaper)

    # Translate to [low, high] first
    high = op.Cast(high, to=FLOAT.dtype)
    low = op.Cast(low, to=FLOAT.dtype)
    rand_translated = op.Add(op.Mul(rand, op.Sub(high, low)), low)

    # Round to ints
    rand_int = op.Floor(rand_translated)
    return op.Cast(rand_int, to=dtype)
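
Aside: here the uniform samples are mapped affinely onto [low, high) before flooring, so the result covers the integers in [low, high); low and high are cast to FLOAT first so that Sub and Mul operate on the float random tensor. Numpy equivalent of the affine step (an illustrative sketch only, not part of the PR):

import numpy as np

low, high = 2.0, 10.0
rand = np.random.uniform(size=(1000,))  # uniform in [0, 1)
rand_translated = rand * (high - low) + low  # uniform in [low, high)
rand_int = np.floor(rand_translated).astype(np.int64)  # integers in [low, high)
assert rand_int.min() >= 2 and rand_int.max() < 10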



-def aten_randint_like(
-    self: TensorType, high: int, memory_format: Optional[str] = None
-    """randint_like(Tensor self, int high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
-    raise NotImplementedError()


@torch_op("aten::randint_like")
def aten_randint_like(self: TensorType, high: INT64) -> IntType:
    """randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""

    self_float = op.Cast(self, to=FLOAT.dtype)
    rand = op.RandomUniformLike(self_float)

    # Scale to [0, high] first
    rand_scaled = op.Mul(rand, op.CastLike(high, rand))

    # Round to ints
    rand_int = op.Floor(rand_scaled)
    return op.CastLike(rand_int, self)
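
Aside: the *_like variants reuse the shape of `self` instead of an explicit size. Casting `self` to FLOAT first appears necessary because RandomUniformLike, absent an explicit dtype, samples in the prototype tensor's type, which must be floating point; CastLike then restores the input's dtype. A rough numpy analogy (an illustrative sketch only, not part of the PR):

import numpy as np

self_ = np.zeros((2, 3), dtype=np.int64)  # an integer input tensor
rand = np.random.uniform(size=self_.shape)  # sample "like" a float view of self_
rand_int = np.floor(rand * 10).astype(self_.dtype)  # scale by high=10, floor, cast back
assert rand_int.shape == self_.shape and rand_int.dtype == np.int64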



@torch_op("aten::randint_like")
def aten_randint_like_dtype(self: TensorType, high: INT64, dtype: int) -> TensorType:
    """randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""

    self_float = op.Cast(self, to=FLOAT.dtype)
    rand = op.RandomUniformLike(self_float)

    # Scale to [0, high] first
    rand_scaled = op.Mul(rand, op.CastLike(high, rand))

    # Round to ints
    rand_int = op.Floor(rand_scaled)
    return op.Cast(rand_int, to=dtype)



@torch_op("aten::randint_like.low_dtype")
def aten_randint_like_low_dtype(self: TensorType, low: INT64, high: INT64) -> IntType:
    """randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

    This is the TorchLib overload for aten::randint_like.low_dtype when dtype is None.
    """

    self_float = op.Cast(self, to=FLOAT.dtype)
    rand = op.RandomUniformLike(self_float)

    # Translate to [low, high] first
    high = op.Cast(high, to=FLOAT.dtype)
    low = op.Cast(low, to=FLOAT.dtype)
    rand_translated = op.Add(op.Mul(rand, op.Sub(high, low)), low)

    # Round to ints
    rand_int = op.Floor(rand_translated)
    return op.CastLike(rand_int, self)



@torch_op("aten::randint_like.low_dtype")
def aten_randint_like_low_dtype_dtype(
    self: TensorType, low: INT64, high: INT64, dtype: int
) -> TensorType:
    """randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""

    self_float = op.Cast(self, to=FLOAT.dtype)
    rand = op.RandomUniformLike(self_float)

    # Translate to [low, high] first
    high = op.Cast(high, to=FLOAT.dtype)
    low = op.Cast(low, to=FLOAT.dtype)
    rand_translated = op.Add(op.Mul(rand, op.Sub(high, low)), low)

    # Round to ints
    rand_int = op.Floor(rand_translated)
    return op.Cast(rand_int, to=dtype)



@torch_op("aten::randn")
-def aten_randn(
-    size: Sequence[int],
-    dtype: int = 1,
-    requires_grad: bool = False,  # pylint: disable=unused-argument
-) -> TReal:
def aten_randn(size: INT64, dtype: int = FLOAT.dtype) -> TReal:
    """randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

-    return op.RandomNormal(dtype=dtype, shape=size)
    shaper = op.ConstantOfShape(size)
    return op.RandomNormalLike(shaper, dtype=dtype)



-def aten_randn_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
@torch_op("aten::randn_like")
def aten_randn_like(self: TFloat) -> TFloat:
    """randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""

-    raise NotImplementedError()
    return op.RandomNormalLike(self)



@torch_op("aten::randn_like")
def aten_randn_like_dtype(self: TensorType, dtype: int) -> TensorType:
    """randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""

    return op.RandomNormalLike(self, dtype=dtype)



def aten_randperm(n: int) -> TensorType:
103 changes: 103 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -516,6 +516,109 @@
        yield opinfo_core.SampleInput(make_arg(case), p=p, train=training)


def sample_inputs_rand(op_info, device, dtype, requires_grad, **kwargs):


    make_arg = functools.partial(
        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    shapes = (
        (M,),
        (S, S),
        (S, S, S),
    )

    for shape in shapes:
        yield opinfo_core.SampleInput(make_arg(shape), kwargs=dict(dtype=dtype))


def sample_inputs_rand_like(op_info, device, dtype, requires_grad, **kwargs):


    make_arg = functools.partial(
        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    shapes = (
        (M,),
        (S, S),
        (S, S, S),
    )

    for shape in shapes:
        yield opinfo_core.SampleInput(make_arg(shape))


def sample_inputs_rand_like_dtype(op_info, device, dtype, requires_grad, **kwargs):


    make_arg = functools.partial(
        torch_testing.make_tensor, device=device, dtype=torch.float32, requires_grad=requires_grad
    )
    shapes = (
        (M,),
        (S, S),
        (S, S, S),
    )

    for shape in shapes:
        yield opinfo_core.SampleInput(make_arg(shape), kwargs=dict(dtype=dtype))


def sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
    inputs = [
        ((), {}),
        ((S, S), {}),
        ((0, S, 0), {}),
        ((S,), {'dtype': dtype}),
        # Hard-code some dtypes/devices. We want to test cases where the
        # (dtype, device) is different from the input's (dtype, device)
        ((S,), {'dtype': torch.double}),
        ((S,), {}),
    ]
    for shape, kwargs in inputs:
        t = torch_testing.make_tensor(shape, dtype=dtype, device=device,
                                      low=None, high=None,
                                      requires_grad=requires_grad)
        yield opinfo_core.SampleInput(t, **kwargs)


def sample_inputs_randint(self, device, dtype, requires_grad, **kwargs):
    high = 10

    for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
        # With high
        yield opinfo_core.SampleInput(high, sample.input.shape, *sample.args, **sample.kwargs)

def sample_inputs_randint_low(self, device, dtype, requires_grad, **kwargs):
    low = 2
    high = 10

    for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
        # With low and high
        yield opinfo_core.SampleInput(low, high, sample.input.shape, *sample.args, **sample.kwargs)


def sample_inputs_randint_like(self, device, dtype, requires_grad, **kwargs):
    low = 2
    high = 10

    for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
        # With high
        yield opinfo_core.SampleInput(
            sample.input,
            high,
            *sample.args,
            **sample.kwargs)
        # With low and high
        yield opinfo_core.SampleInput(
            get_independent_tensor(sample.input),
            low,
            high,
            *sample.args,
            **sample.kwargs)
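
Aside: for orientation, the ATen calls these randint_like samples ultimately exercise look roughly like the following (a self-contained sketch, not part of the PR):

import torch

t = torch.randn(3, 4)
out_high = torch.randint_like(t, 10)  # values in [0, 10), same shape and dtype as t
out_low_high = torch.randint_like(t, 2, 10)  # values in [2, 10)
assert out_high.shape == t.shape
assert out_low_high.min() >= 2 and out_low_high.max() < 10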


def sample_inputs_stft(op_info, device, dtype, requires_grad, **kwargs):
    del op_info
    del kwargs
29 changes: 24 additions & 5 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -1170,16 +1170,35 @@ def _where_input_wrangler(
        trace_only=True,
    ),
    TorchLibOpInfo("pow", core_ops.aten_pow),
-    # TorchLibOpInfo("rand", core_ops.aten_rand), # no test case in OPS_DB
    TorchLibOpInfo("ops.aten.rand", core_ops.aten_rand),
    TorchLibOpInfo("ops.aten.rand_like", core_ops.aten_rand_like, nondeterministic=True),
    TorchLibOpInfo(
-        "randn",
-        core_ops.aten_randn,
-        input_wrangler=_randn_input_wrangler,
        "ops.aten.rand_like__dtype", core_ops.aten_rand_like_dtype, nondeterministic=True
    ),
    TorchLibOpInfo("ops.aten.randint", core_ops.aten_randint, nondeterministic=True),
    TorchLibOpInfo("ops.aten.randint.low", core_ops.aten_randint_low, nondeterministic=True),
    TorchLibOpInfo("ops.aten.randint_like", core_ops.aten_randint_like, nondeterministic=True),
    TorchLibOpInfo(
        "ops.aten.randint_like__dtype", core_ops.aten_randint_like_dtype, nondeterministic=True
    ),
    TorchLibOpInfo(
        "ops.aten.randint_like.low_dtype",
        core_ops.aten_randint_like_low_dtype,
        nondeterministic=True,
-    ).xfail(
    ),
    TorchLibOpInfo(
        "ops.aten.randint_like.low_dtype__dtype",
        core_ops.aten_randint_like_low_dtype_dtype,
        nondeterministic=True,
    ),
    TorchLibOpInfo("ops.aten.randn", core_ops.aten_randn, nondeterministic=True).xfail(
        dtypes=(torch.float16,),
        reason="fixme: Shape inference error",
    ),
    TorchLibOpInfo("ops.aten.randn_like", core_ops.aten_randn_like, nondeterministic=True),
    TorchLibOpInfo(
        "ops.aten.randn_like_dtype", core_ops.aten_randn_like_dtype, nondeterministic=True
    ),
    TorchLibOpInfo("reciprocal", core_ops.aten_reciprocal),
    TorchLibOpInfo(
        "remainder",