diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index f41ff1c3e..1fc73a220 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -39,7 +39,6 @@ RealType, TFloat, TFloatHighPrecision, - TFloatOrBFloat16, TInt, TReal, TRealOrUInt8, @@ -3564,14 +3563,14 @@ def aten_flipud(self: TensorType) -> TensorType: @torch_op("aten::floor", traceable=True) -def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_floor(self: TFloat) -> TFloat: """floor(Tensor self) -> Tensor""" return op.Floor(self) @torch_op("math::floor", traceable=True) -def python_math_floor(self: TFloatOrBFloat16) -> TInt: +def python_math_floor(self: TFloat) -> TInt: """floor(Tensor self) -> Tensor""" floor = op.Floor(self) return op.Cast(floor, to=INT64.dtype) @@ -4533,7 +4532,7 @@ def aten_isfinite(self: TFloatHighPrecision) -> BOOL: @torch_op("aten::isinf") -def aten_isinf(self: TFloatOrBFloat16) -> BOOL: +def aten_isinf(self: TFloat) -> BOOL: """isinf(Tensor self) -> Tensor""" # Added Cast inside the function so it can support all real dtypes naturally @@ -4542,14 +4541,14 @@ def aten_isinf(self: TFloatOrBFloat16) -> BOOL: @torch_op("aten::isnan") -def aten_isnan(self: TFloatOrBFloat16) -> BOOL: +def aten_isnan(self: TFloat) -> BOOL: """isnan(Tensor self) -> Tensor""" return op.IsNaN(self) @torch_op("aten::isneginf") -def aten_isneginf(self: TFloatOrBFloat16) -> BOOL: +def aten_isneginf(self: TFloat) -> BOOL: """isneginf(Tensor self) -> Tensor""" # Added Cast inside the function so it can support all real dtypes naturally @@ -4558,7 +4557,7 @@ def aten_isneginf(self: TFloatOrBFloat16) -> BOOL: @torch_op("aten::isposinf") -def aten_isposinf(self: TFloatOrBFloat16) -> BOOL: +def aten_isposinf(self: TFloat) -> BOOL: """isposinf(Tensor self) -> Tensor""" # Added Cast inside the function so it can support all real dtypes naturally @@ -4778,42 +4777,42 @@ def aten_linspace( @torch_op("aten::log", traceable=True) -def aten_log(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_log(self: TFloat) -> TFloat: """log(Tensor self) -> Tensor""" return op.Log(self) @torch_op("aten::log10", traceable=True) -def aten_log10(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_log10(self: TFloat) -> TFloat: """log10(Tensor self) -> Tensor""" return op.Div(op.Log(self), op.CastLike(op.Log(10.0), self)) @torch_op("aten::log1p") -def aten_log1p(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_log1p(self: TFloat) -> TFloat: """log1p(Tensor self) -> Tensor""" return op.Log(op.Add(self, 1.0)) @torch_op("aten::log2", traceable=True) -def aten_log2(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_log2(self: TFloat) -> TFloat: """log2(Tensor self) -> Tensor""" return op.Div(op.Log(self), op.CastLike(op.Log(2.0), self)) @torch_op("aten::logaddexp", traceable=True) -def aten_logaddexp(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_logaddexp(self: TFloat, other: TFloat) -> TFloat: """logaddexp(Tensor self, Tensor other) -> Tensor""" return op.Log(op.Add(op.Exp(self), op.Exp(other))) @torch_op("aten::logaddexp2", traceable=True) -def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_logaddexp2(self: TFloat, other: TFloat) -> TFloat: """logaddexp2(Tensor self, Tensor other) -> Tensor""" two = op.CastLike(2.0, self) summation = op.Add(op.Pow(two, self), op.Pow(two, other)) @@ -4822,7 +4821,7 @@ def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOr @torch_op("aten::logcumsumexp", traceable=True) -def aten_logcumsumexp(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16: +def aten_logcumsumexp(self: TFloat, dim: int) -> TFloat: """logcumsumexp(Tensor self, int dim) -> Tensor""" if IsScalar(self): @@ -4908,12 +4907,12 @@ def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL: @torch_op("aten::logit", private=True) -def _aten_logit_onnx(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def _aten_logit_onnx(self: TFloat) -> TFloat: return op.Log(op.Div(self, op.Sub(1.0, self))) @torch_op("aten::logit", private=True) -def _aten_logit_clamp_onnx(self: TFloatOrBFloat16, eps: float) -> TFloatOrBFloat16: +def _aten_logit_clamp_onnx(self: TFloat, eps: float) -> TFloat: eps = op.CastLike(eps, self) one = op.CastLike(1.0, self) temporary_self = op.Where(self <= one - eps, self, one - eps) @@ -4923,7 +4922,7 @@ def _aten_logit_clamp_onnx(self: TFloatOrBFloat16, eps: float) -> TFloatOrBFloat @torch_op("aten::logit", trace_only=True) -def aten_logit(self: TFloatOrBFloat16, eps: Optional[float] = None) -> TFloatOrBFloat16: +def aten_logit(self: TFloat, eps: Optional[float] = None) -> TFloat: """logit(Tensor self, float? eps=None) -> Tensor""" if eps is None: return _aten_logit_onnx(self) @@ -6041,9 +6040,7 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType: @torch_op("aten::native_dropout", trace_only=True) -def aten_native_dropout( - input: TFloatOrBFloat16, p: float, train: bool = True -) -> Tuple[TFloatOrBFloat16, BOOL]: +def aten_native_dropout(input: TFloat, p: float, train: bool = True) -> Tuple[TFloat, BOOL]: """native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)""" result, mask = op.Dropout(input, p, train) @@ -7055,7 +7052,7 @@ def aten_real(self: TensorType) -> TensorType: @torch_op("aten::reciprocal") -def aten_reciprocal(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_reciprocal(self: TFloat) -> TFloat: """reciprocal(Tensor self) -> Tensor""" return op.Reciprocal(self) @@ -7074,7 +7071,7 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType: @torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar")) -def aten_remainder(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_remainder(self: TFloat, other: TFloat) -> TFloat: """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" # TODO(justinchuby): Improve fp16 precision by following the logic in @@ -7355,7 +7352,7 @@ def aten_rrelu( @torch_op("aten::rsqrt", traceable=True) -def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_rsqrt(self: TFloat) -> TFloat: """rsqrt(Tensor self) -> Tensor""" return op.Reciprocal(op.Sqrt(self)) @@ -7562,7 +7559,7 @@ def aten_sgn(self: TensorType) -> TensorType: @torch_op("aten::sigmoid", traceable=True) -def aten_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_sigmoid(self: TFloat) -> TFloat: """sigmoid(Tensor self) -> Tensor""" return op.Sigmoid(self) @@ -7724,7 +7721,7 @@ def aten_smm(self: TensorType, mat2: TensorType) -> TensorType: @torch_op(("aten::softmax.int", "aten::special_softmax"), trace_only=True) -def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrBFloat16: +def aten_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat: """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor""" self_is_scalar = IsScalar(self) @@ -7741,7 +7738,7 @@ def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrB @torch_op(("aten::softmax.int", "aten::special_softmax"), traceable=True) -def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16: +def aten_softmax_no_dtype(self: TFloat, dim: int) -> TFloat: """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor""" self_is_scalar = IsScalar(self) @@ -7812,7 +7809,7 @@ def aten_split_with_sizes_copy( @torch_op("aten::sqrt", traceable=True) -def aten_sqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_sqrt(self: TFloat) -> TFloat: """sqrt(Tensor self) -> Tensor""" return op.Sqrt(self) @@ -8402,7 +8399,7 @@ def aten_triu_indices(row: int, col: int, offset: int = 0) -> TensorType: @torch_op("aten::trunc") -def aten_trunc(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_trunc(self: TFloat) -> TFloat: """trunc(Tensor self) -> Tensor""" # Reference https://github.com/onnx/onnx/issues/4588#issuecomment-1463970126 diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 4687e260a..e963050f5 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -25,7 +25,6 @@ from onnxscript.function_libs.torch_lib.tensor_typing import ( IntType, TFloat, - TFloatOrBFloat16, TFloatOrUInt8, TInt, TReal, @@ -364,13 +363,13 @@ def aten_conv_depthwise3d( @torch_op("aten::cross_entropy_loss", traceable=True) def aten_cross_entropy_loss( - self: TFloatOrBFloat16, + self: TFloat, target: IntType, - weight: Optional[TFloatOrBFloat16] = None, + weight: Optional[TFloat] = None, reduction: int = 1, # default is 'mean' ignore_index: int = -100, label_smoothing: float = 0.0, # this was ignored due to ONNX not support -) -> TFloatOrBFloat16: +) -> TFloat: """cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor""" if reduction == 0: # "none" @@ -812,7 +811,7 @@ def aten_l1_loss(self: TensorType, target: TensorType, reduction: int = 1) -> Te @torch_op("aten::leaky_relu") -def aten_leaky_relu(self: TFloatOrBFloat16, negative_slope: float = 0.01) -> TFloatOrBFloat16: +def aten_leaky_relu(self: TFloat, negative_slope: float = 0.01) -> TFloat: """leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor""" return op.LeakyRelu(self, alpha=negative_slope) @@ -850,7 +849,7 @@ def aten_linear_bias(input: TFloat, weight: TFloat, bias: TFloat) -> TFloat: @torch_op("aten::log_sigmoid") -def aten_log_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_log_sigmoid(self: TFloat) -> TFloat: """log_sigmoid(Tensor self) -> Tensor""" return op.Log(op.Sigmoid(self)) diff --git a/onnxscript/function_libs/torch_lib/ops/special.py b/onnxscript/function_libs/torch_lib/ops/special.py index 6dd9edcd3..c791937b1 100644 --- a/onnxscript/function_libs/torch_lib/ops/special.py +++ b/onnxscript/function_libs/torch_lib/ops/special.py @@ -17,7 +17,7 @@ from onnxscript.function_libs.torch_lib.ops import common as common_ops from onnxscript.function_libs.torch_lib.registration import torch_op -from onnxscript.function_libs.torch_lib.tensor_typing import TFloat, TFloatOrBFloat16 +from onnxscript.function_libs.torch_lib.tensor_typing import TFloat from onnxscript.onnx_opset import opset18 as op from onnxscript.onnx_types import TensorType @@ -92,21 +92,21 @@ def aten_special_entr(self: TensorType) -> TensorType: @torch_op(("aten::erf", "aten::special_erf")) -def aten_special_erf(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_special_erf(self: TFloat) -> TFloat: """erf(Tensor self) -> Tensor""" return op.Erf(self) @torch_op(("aten::erfc", "aten::special_erfc")) -def aten_special_erfc(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_special_erfc(self: TFloat) -> TFloat: """erfc(Tensor self) -> Tensor""" return op.Sub(1, op.Erf(self)) @torch_op("aten::special_erfcx") -def aten_special_erfcx(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_special_erfcx(self: TFloat) -> TFloat: """special_erfcx(Tensor self) -> Tensor""" return op.Mul(op.Exp(op.Pow(self, 2)), op.Sub(1, op.Erf(self))) @@ -131,7 +131,7 @@ def aten_special_expit(self: TensorType) -> TensorType: @torch_op(("aten::expm1", "aten::special_expm1")) -def aten_special_expm1(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_special_expm1(self: TFloat) -> TFloat: """special_expm1(Tensor self) -> Tensor""" return op.Sub(op.Exp(self), 1) @@ -216,9 +216,7 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType: @torch_op(("aten::log_softmax.int", "aten::special_log_softmax"), trace_only=True) -def aten_special_log_softmax( - self: TFloatOrBFloat16, dim: int, dtype: int = -1 -) -> TFloatOrBFloat16: +def aten_special_log_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat: """special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor""" self_is_scalar = IsScalar(self) @@ -366,7 +364,7 @@ def aten_special_xlog1py(self: TensorType, other: TensorType) -> TensorType: @torch_op(("aten::xlogy.Tensor", "aten::xlogy.Scalar_Self", "aten::xlogy.Scalar_Other")) -def aten_special_xlogy(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_special_xlogy(self: TFloat, other: TFloat) -> TFloat: """special_xlogy(Tensor self, Tensor other) -> Tensor""" # https://pytorch.org/docs/stable/special.html#torch.special.xlogy diff --git a/onnxscript/function_libs/torch_lib/tensor_typing.py b/onnxscript/function_libs/torch_lib/tensor_typing.py index 7b5287f41..1f27c0cff 100644 --- a/onnxscript/function_libs/torch_lib/tensor_typing.py +++ b/onnxscript/function_libs/torch_lib/tensor_typing.py @@ -42,7 +42,7 @@ INT64, UINT8, ] -_FloatType = Union[FLOAT16, FLOAT, DOUBLE] +_FloatType = Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16] IntType = Union[INT8, INT16, INT32, INT64] RealType = Union[ BFLOAT16, @@ -61,7 +61,6 @@ TTensor2 = TypeVar("TTensor2", bound=_TensorType) TTensorOrString = TypeVar("TTensorOrString", bound=Union[_TensorType, STRING]) TFloat = TypeVar("TFloat", bound=_FloatType) -TFloatOrBFloat16 = TypeVar("TFloatOrBFloat16", bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16]) TFloatOrUInt8 = TypeVar("TFloatOrUInt8", bound=Union[FLOAT, FLOAT16, DOUBLE, INT8, UINT8]) TInt = TypeVar("TInt", bound=IntType) TReal = TypeVar("TReal", bound=RealType)