Implement rand* ops | feat(torchlib) (#1035)
Implement rand* ops. I split all ops that can take a dtype into two
overloads to avoid trace_only logic.

Replaces #875
justinchuby authored Oct 11, 2023
1 parent c74bc6a commit 1ee4ee1
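For context, "two overloads" here means the same ATen op is registered twice under one name: one function without a dtype parameter and one with it, so neither needs a Python-level branch on whether dtype was supplied (the exporter can then pick the matching overload). A hypothetical single-function version, shown below for contrast only, would have to branch at trace time; the trace_only=True flag and the -1 sentinel are illustrative assumptions rather than code from this commit, and the snippet assumes the same module-level names (torch_op, op) as core.py.

# Hypothetical alternative that this commit avoids: a single function that
# branches on dtype in Python. The `if` below runs at trace time rather than
# inside the ONNX graph, so the op would have to be registered as trace-only.
@torch_op("aten::rand_like", trace_only=True)  # trace_only flag assumed for illustration
def aten_rand_like_single(self, dtype: int = -1):  # -1 as a "no dtype given" sentinel
    if dtype == -1:
        return op.RandomUniformLike(self)  # keep the input's dtype
    return op.RandomUniformLike(self, dtype=dtype)  # cast the output to the requested dtype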
Showing 3 changed files with 372 additions and 32 deletions.
129 changes: 110 additions & 19 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -6232,47 +6232,138 @@ def aten_rad2deg(self: TensorType) -> TensorType:
 
 
 @torch_op("aten::rand")
-def aten_rand(size: Sequence[int], dtype: int = 1) -> TReal:
+def aten_rand(size: INT64, dtype: int = FLOAT.dtype) -> TReal:
     """rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
 
-    return op.RandomUniform(shape=size, dtype=dtype)
+    shaper = op.ConstantOfShape(size)
+    return op.RandomUniformLike(shaper, dtype=dtype)
 
 
-def aten_rand_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
+@torch_op("aten::rand_like")
+def aten_rand_like(self: TFloat) -> TFloat:
     """rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
 
-    raise NotImplementedError()
+    return op.RandomUniformLike(self)
+
+
+@torch_op("aten::rand_like")
+def aten_rand_like_dtype(self: TensorType, dtype: int) -> TensorType:
+    """rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
+
+    return op.RandomUniformLike(self, dtype=dtype)
 
 
-def aten_randint(high: int, size: INT64) -> TensorType:
-    """randint(int high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
-
-    raise NotImplementedError()
+@torch_op("aten::randint")
+def aten_randint(high: INT64, size: INT64, dtype: int = INT64.dtype) -> TensorType:
+    """randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
+
+    shaper = op.ConstantOfShape(size)
+    rand = op.RandomUniformLike(shaper)
+    # Scale to [0, high] first
+    rand_scaled = op.Mul(rand, op.CastLike(high, rand))
+    # Round to ints
+    rand_int = op.Floor(rand_scaled)
+    return op.Cast(rand_int, to=dtype)
+
+
+@torch_op("aten::randint.low")
+def aten_randint_low(
+    low: INT64, high: INT64, size: INT64, dtype: int = INT64.dtype
+) -> TensorType:
+    """randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
+
+    shaper = op.ConstantOfShape(size)
+    rand = op.RandomUniformLike(shaper)
+    # Translate to [low, high] first
+    high = op.Cast(high, to=FLOAT.dtype)
+    low = op.Cast(low, to=FLOAT.dtype)
+    rand_translated = op.Add(op.Mul(rand, op.Sub(high, low)), low)
+    # Round to ints
+    rand_int = op.Floor(rand_translated)
+    return op.Cast(rand_int, to=dtype)
 
 
-def aten_randint_like(
-    self: TensorType, high: int, memory_format: Optional[str] = None
-) -> TensorType:
-    """randint_like(Tensor self, int high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
-
-    raise NotImplementedError()
+@torch_op("aten::randint_like")
+def aten_randint_like(self: TensorType, high: INT64) -> IntType:
+    """randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
+
+    self_float = op.Cast(self, to=FLOAT.dtype)
+    rand = op.RandomUniformLike(self_float)
+    # Scale to [0, high] first
+    rand_scaled = op.Mul(rand, op.CastLike(high, rand))
+    # Round to ints
+    rand_int = op.Floor(rand_scaled)
+    return op.CastLike(rand_int, self)
+
+
+@torch_op("aten::randint_like")
+def aten_randint_like_dtype(self: TensorType, high: INT64, dtype: int) -> TensorType:
+    """randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
+
+    self_float = op.Cast(self, to=FLOAT.dtype)
+    rand = op.RandomUniformLike(self_float)
+    # Scale to [0, high] first
+    rand_scaled = op.Mul(rand, op.CastLike(high, rand))
+    # Round to ints
+    rand_int = op.Floor(rand_scaled)
+    return op.Cast(rand_int, to=dtype)
+
+
+@torch_op("aten::randint_like.low_dtype")
+def aten_randint_like_low_dtype(self: TensorType, low: INT64, high: INT64) -> IntType:
+    """randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    This is the TorchLib overload for aten::randint_like.low_dtype when dtype is None.
+    """
+
+    self_float = op.Cast(self, to=FLOAT.dtype)
+    rand = op.RandomUniformLike(self_float)
+    # Translate to [low, high] first
+    high = op.Cast(high, to=FLOAT.dtype)
+    low = op.Cast(low, to=FLOAT.dtype)
+    rand_translated = op.Add(op.Mul(rand, op.Sub(high, low)), low)
+    # Round to ints
+    rand_int = op.Floor(rand_translated)
+    return op.CastLike(rand_int, self)
+
+
+@torch_op("aten::randint_like.low_dtype")
+def aten_randint_like_low_dtype_dtype(
+    self: TensorType, low: INT64, high: INT64, dtype: int
+) -> TensorType:
+    """randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
+
+    self_float = op.Cast(self, to=FLOAT.dtype)
+    rand = op.RandomUniformLike(self_float)
+    # Translate to [low, high] first
+    high = op.Cast(high, to=FLOAT.dtype)
+    low = op.Cast(low, to=FLOAT.dtype)
+    rand_translated = op.Add(op.Mul(rand, op.Sub(high, low)), low)
+    # Round to ints
+    rand_int = op.Floor(rand_translated)
+    return op.Cast(rand_int, to=dtype)
 
 
 @torch_op("aten::randn")
-def aten_randn(
-    size: Sequence[int],
-    dtype: int = 1,
-    requires_grad: bool = False,  # pylint: disable=unused-argument
-) -> TReal:
+def aten_randn(size: INT64, dtype: int = FLOAT.dtype) -> TReal:
     """randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
 
-    return op.RandomNormal(dtype=dtype, shape=size)
+    shaper = op.ConstantOfShape(size)
+    return op.RandomNormalLike(shaper, dtype=dtype)
 
 
-def aten_randn_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
+@torch_op("aten::randn_like")
+def aten_randn_like(self: TFloat) -> TFloat:
     """randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
 
-    raise NotImplementedError()
+    return op.RandomNormalLike(self)
+
+
+@torch_op("aten::randn_like")
+def aten_randn_like_dtype(self: TensorType, dtype: int) -> TensorType:
+    """randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
+
+    return op.RandomNormalLike(self, dtype=dtype)
 
 
 def aten_randperm(n: int) -> TensorType:
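Two details of the lowering above are easy to miss. First, ONNX RandomUniform and RandomNormal only accept their shape as a static attribute, so the dynamic size input is routed through ConstantOfShape to build a dummy tensor, and the *Like variants then sample a tensor of that runtime shape. Second, integer ranges come from uniform floats: Floor(u * (high - low) + low) with u drawn from [0, 1) lands in [low, high), matching torch.randint's exclusive upper bound. The NumPy sketch below mirrors that reference semantics for intuition only; the function name, fixed seed, and example shape are assumptions, not part of the commit.

import numpy as np

def randint_low_reference(low: int, high: int, size: tuple) -> np.ndarray:
    """NumPy mirror of the aten_randint_low lowering above (illustrative only)."""
    rng = np.random.default_rng(0)  # fixed seed for reproducibility; the ONNX ops are unseeded by default
    # ConstantOfShape(size) + RandomUniformLike ~ uniform [0, 1) samples of the runtime shape
    u = rng.random(size=size, dtype=np.float32)
    # Translate [0, 1) to [low, high), then floor and cast, as the ONNX graph does
    translated = u * (float(high) - float(low)) + float(low)
    return np.floor(translated).astype(np.int64)

sample = randint_low_reference(3, 10, (2, 4))
assert ((sample >= 3) & (sample < 10)).all()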
