Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement rand* ops | feat(torchilb) #1035

Merged
merged 8 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 110 additions & 19 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6232,47 +6232,138 @@ def aten_rad2deg(self: TensorType) -> TensorType:


@torch_op("aten::rand")
def aten_rand(size: Sequence[int], dtype: int = 1) -> TReal:
def aten_rand(size: INT64, dtype: int = FLOAT.dtype) -> TReal:
"""rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

return op.RandomUniform(shape=size, dtype=dtype)
shaper = op.ConstantOfShape(size)
return op.RandomUniformLike(shaper, dtype=dtype)


def aten_rand_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
@torch_op("aten::rand_like")
def aten_rand_like(self: TFloat) -> TFloat:
"""rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""

raise NotImplementedError()
return op.RandomUniformLike(self)


@torch_op("aten::rand_like")
def aten_rand_like_dtype(self: TensorType, dtype: int) -> TensorType:
"""rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""

def aten_randint(high: int, size: INT64) -> TensorType:
"""randint(int high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
return op.RandomUniformLike(self, dtype=dtype)

raise NotImplementedError()

@torch_op("aten::randint")
def aten_randint(high: INT64, size: INT64, dtype: int = INT64.dtype) -> TensorType:
"""randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

def aten_randint_like(
self: TensorType, high: int, memory_format: Optional[str] = None
shaper = op.ConstantOfShape(size)
rand = op.RandomUniformLike(shaper)
# Scale to [0, high] first
rand_scaled = op.Mul(rand, op.CastLike(high, rand))
# Round to ints
rand_int = op.Floor(rand_scaled)
return op.Cast(rand_int, to=dtype)


@torch_op("aten::randint.low")
def aten_randint_low(
low: INT64, high: INT64, size: INT64, dtype: int = INT64.dtype
) -> TensorType:
"""randint_like(Tensor self, int high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
"""randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

raise NotImplementedError()
shaper = op.ConstantOfShape(size)
rand = op.RandomUniformLike(shaper)
# Translate to [low, high] first
high = op.Cast(high, to=FLOAT.dtype)
low = op.Cast(low, to=FLOAT.dtype)
rand_translated = op.Add(op.Mul(rand, op.Sub(high, low)), low)
# Round to ints
rand_int = op.Floor(rand_translated)
return op.Cast(rand_int, to=dtype)


@torch_op("aten::randint_like")
def aten_randint_like(self: TensorType, high: INT64) -> IntType:
"""randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""

self_float = op.Cast(self, to=FLOAT.dtype)
rand = op.RandomUniformLike(self_float)
# Scale to [0, high] first
rand_scaled = op.Mul(rand, op.CastLike(high, rand))
# Round to ints
rand_int = op.Floor(rand_scaled)
return op.CastLike(rand_int, self)


@torch_op("aten::randint_like")
def aten_randint_like_dtype(self: TensorType, high: INT64, dtype: int) -> TensorType:
"""randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""

self_float = op.Cast(self, to=FLOAT.dtype)
rand = op.RandomUniformLike(self_float)
# Scale to [0, high] first
rand_scaled = op.Mul(rand, op.CastLike(high, rand))
# Round to ints
rand_int = op.Floor(rand_scaled)
return op.Cast(rand_int, to=dtype)


@torch_op("aten::randint_like.low_dtype")
def aten_randint_like_low_dtype(self: TensorType, low: INT64, high: INT64) -> IntType:
"""randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

This is the TorchLib overload for aten::randint_like.low_dtype when dtype is None.
"""

self_float = op.Cast(self, to=FLOAT.dtype)
rand = op.RandomUniformLike(self_float)
# Translate to [low, high] first
high = op.Cast(high, to=FLOAT.dtype)
low = op.Cast(low, to=FLOAT.dtype)
rand_translated = op.Add(op.Mul(rand, op.Sub(high, low)), low)
# Round to ints
rand_int = op.Floor(rand_translated)
return op.CastLike(rand_int, self)


@torch_op("aten::randint_like.low_dtype")
def aten_randint_like_low_dtype_dtype(
self: TensorType, low: INT64, high: INT64, dtype: int
) -> TensorType:
"""randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""

self_float = op.Cast(self, to=FLOAT.dtype)
rand = op.RandomUniformLike(self_float)
# Translate to [low, high] first
high = op.Cast(high, to=FLOAT.dtype)
low = op.Cast(low, to=FLOAT.dtype)
rand_translated = op.Add(op.Mul(rand, op.Sub(high, low)), low)
# Round to ints
rand_int = op.Floor(rand_translated)
return op.Cast(rand_int, to=dtype)


@torch_op("aten::randn")
def aten_randn(
size: Sequence[int],
dtype: int = 1,
requires_grad: bool = False, # pylint: disable=unused-argument
) -> TReal:
def aten_randn(size: INT64, dtype: int = FLOAT.dtype) -> TReal:
"""randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

return op.RandomNormal(dtype=dtype, shape=size)
shaper = op.ConstantOfShape(size)
return op.RandomNormalLike(shaper, dtype=dtype)


def aten_randn_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
@torch_op("aten::randn_like")
def aten_randn_like(self: TFloat) -> TFloat:
"""randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""

raise NotImplementedError()
return op.RandomNormalLike(self)


@torch_op("aten::randn_like")
def aten_randn_like_dtype(self: TensorType, dtype: int) -> TensorType:
"""randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""

return op.RandomNormalLike(self, dtype=dtype)


def aten_randperm(n: int) -> TensorType:
Expand Down
Loading
Loading