From 975979d72a218840e3d22ed8740a46eb5056a05c Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Tue, 7 Nov 2023 01:15:38 +0000 Subject: [PATCH] split python math functions --- .../function_libs/torch_lib/ops/core.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index cd74fa45f..7e108b99c 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1438,13 +1438,20 @@ def aten_cdist( raise NotImplementedError() -@torch_op(("aten::ceil", "math::ceil")) +@torch_op("aten::ceil") def aten_ceil(self: TFloat) -> TFloat: """ceil(Tensor self) -> Tensor""" return op.Ceil(self) +@torch_op("math::ceil") +def python_math_ceil(self: TFloat) -> TInt: + """ceil(Tensor self) -> Tensor""" + ceil = op.Ceil(self) + return op.Cast(ceil, to=INT64.dtype) + + def aten_chain_matmul(matrices: Sequence[TensorType]) -> TensorType: """chain_matmul(Tensor[] matrices) -> Tensor""" @@ -3109,7 +3116,7 @@ def aten_empty_strided( return op.Expand(zero, size) -@torch_op(("aten::eq", "aten::eq.Tensor", "aten::eq.Scalar", "_operator::eq")) +@torch_op(("aten::eq", "aten::eq.Tensor", "aten::eq.Scalar")) def aten_eq(self: TTensor, other: TTensor) -> BOOL: """eq.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -3364,13 +3371,20 @@ def aten_flipud(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op(("aten::floor", "math::floor")) +@torch_op("aten::floor") def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16: """floor(Tensor self) -> Tensor""" return op.Floor(self) +@torch_op("math::floor") +def python_math_floor(self: TFloatOrBFloat16) -> TInt: + """floor(Tensor self) -> Tensor""" + floor = op.Floor(self) + return op.Cast(floor, to=INT64.dtype) + + @torch_op(("aten::floor_divide", "_operator::floordiv")) def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat: """floor_divide(Tensor self, Tensor other) -> Tensor"""