Skip to content

Commit

Permalink
split python math functions
Browse files Browse the repository at this point in the history
  • Loading branch information
titaiwangms committed Nov 7, 2023
1 parent efe2a58 commit 975979d
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1438,13 +1438,20 @@ def aten_cdist(
raise NotImplementedError()


@torch_op(("aten::ceil", "math::ceil"))
@torch_op("aten::ceil")
def aten_ceil(self: TFloat) -> TFloat:
"""ceil(Tensor self) -> Tensor"""

return op.Ceil(self)


@torch_op("math::ceil")
def python_math_ceil(self: TFloat) -> TInt:
"""ceil(Tensor self) -> Tensor"""
ceil = op.Ceil(self)
return op.Cast(ceil, to=INT64.dtype)


def aten_chain_matmul(matrices: Sequence[TensorType]) -> TensorType:
"""chain_matmul(Tensor[] matrices) -> Tensor"""

Expand Down Expand Up @@ -3109,7 +3116,7 @@ def aten_empty_strided(
return op.Expand(zero, size)


@torch_op(("aten::eq", "aten::eq.Tensor", "aten::eq.Scalar", "_operator::eq"))
@torch_op(("aten::eq", "aten::eq.Tensor", "aten::eq.Scalar"))
def aten_eq(self: TTensor, other: TTensor) -> BOOL:
"""eq.Tensor(Tensor self, Tensor other) -> Tensor"""

Expand Down Expand Up @@ -3364,13 +3371,20 @@ def aten_flipud(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op(("aten::floor", "math::floor"))
@torch_op("aten::floor")
def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
"""floor(Tensor self) -> Tensor"""

return op.Floor(self)


@torch_op("math::floor")
def python_math_floor(self: TFloatOrBFloat16) -> TInt:
"""floor(Tensor self) -> Tensor"""
floor = op.Floor(self)
return op.Cast(floor, to=INT64.dtype)


@torch_op(("aten::floor_divide", "_operator::floordiv"))
def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat:
"""floor_divide(Tensor self, Tensor other) -> Tensor"""
Expand Down

0 comments on commit 975979d

Please sign in to comment.