[torchlib] Include bfloat16 as part of the float types (#1894)
Since ONNX enabled bfloat16 for most relevant ops in opset 20 or so, we just allow it in torchlib as well (even though torchlib still targets opset 18 for now) to unblock bfloat16 model export.
justinchuby authored Oct 8, 2024
1 parent db30dbb commit 3be8fc4
Showing 4 changed files with 38 additions and 45 deletions.
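
The heart of the change is in onnxscript/function_libs/torch_lib/tensor_typing.py: BFLOAT16 joins the generic float constraint, so the dedicated TFloatOrBFloat16 alias becomes redundant and the op signatures in the diffs below switch to plain TFloat. A minimal sketch of the resulting typing, assuming only the names visible in the diff (the real tensor_typing.py defines many more aliases, and the import path onnxscript.onnx_types is where these type classes normally live):

    from typing import TypeVar, Union

    from onnxscript.onnx_types import BFLOAT16, DOUBLE, FLOAT, FLOAT16

    # bfloat16 now satisfies the generic float constraint ...
    _FloatType = Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16]
    TFloat = TypeVar("TFloat", bound=_FloatType)
    # ... so functions previously annotated with TFloatOrBFloat16 can be
    # annotated with TFloat instead, and the separate alias is deleted.

Everything else in the commit is the mechanical rename of TFloatOrBFloat16 to TFloat across core.py, nn.py, and special.py.
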
53 changes: 25 additions & 28 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -39,7 +39,6 @@
RealType,
TFloat,
TFloatHighPrecision,
-    TFloatOrBFloat16,
TInt,
TReal,
TRealOrUInt8,
@@ -3564,14 +3563,14 @@ def aten_flipud(self: TensorType) -> TensorType:


@torch_op("aten::floor", traceable=True)
-def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_floor(self: TFloat) -> TFloat:
"""floor(Tensor self) -> Tensor"""

return op.Floor(self)


@torch_op("math::floor", traceable=True)
-def python_math_floor(self: TFloatOrBFloat16) -> TInt:
+def python_math_floor(self: TFloat) -> TInt:
"""floor(Tensor self) -> Tensor"""
floor = op.Floor(self)
return op.Cast(floor, to=INT64.dtype)
@@ -4533,7 +4532,7 @@ def aten_isfinite(self: TFloatHighPrecision) -> BOOL:


@torch_op("aten::isinf")
-def aten_isinf(self: TFloatOrBFloat16) -> BOOL:
+def aten_isinf(self: TFloat) -> BOOL:
"""isinf(Tensor self) -> Tensor"""

# Added Cast inside the function so it can support all real dtypes naturally
@@ -4542,14 +4541,14 @@ def aten_isinf(self: TFloatOrBFloat16) -> BOOL:


@torch_op("aten::isnan")
-def aten_isnan(self: TFloatOrBFloat16) -> BOOL:
+def aten_isnan(self: TFloat) -> BOOL:
"""isnan(Tensor self) -> Tensor"""

return op.IsNaN(self)


@torch_op("aten::isneginf")
-def aten_isneginf(self: TFloatOrBFloat16) -> BOOL:
+def aten_isneginf(self: TFloat) -> BOOL:
"""isneginf(Tensor self) -> Tensor"""

# Added Cast inside the function so it can support all real dtypes naturally
@@ -4558,7 +4557,7 @@ def aten_isneginf(self: TFloatOrBFloat16) -> BOOL:


@torch_op("aten::isposinf")
-def aten_isposinf(self: TFloatOrBFloat16) -> BOOL:
+def aten_isposinf(self: TFloat) -> BOOL:
"""isposinf(Tensor self) -> Tensor"""

# Added Cast inside the function so it can support all real dtypes naturally
@@ -4778,42 +4777,42 @@ def aten_linspace(


@torch_op("aten::log", traceable=True)
-def aten_log(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_log(self: TFloat) -> TFloat:
"""log(Tensor self) -> Tensor"""

return op.Log(self)


@torch_op("aten::log10", traceable=True)
-def aten_log10(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_log10(self: TFloat) -> TFloat:
"""log10(Tensor self) -> Tensor"""

return op.Div(op.Log(self), op.CastLike(op.Log(10.0), self))


@torch_op("aten::log1p")
-def aten_log1p(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_log1p(self: TFloat) -> TFloat:
"""log1p(Tensor self) -> Tensor"""

return op.Log(op.Add(self, 1.0))


@torch_op("aten::log2", traceable=True)
-def aten_log2(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_log2(self: TFloat) -> TFloat:
"""log2(Tensor self) -> Tensor"""

return op.Div(op.Log(self), op.CastLike(op.Log(2.0), self))


@torch_op("aten::logaddexp", traceable=True)
-def aten_logaddexp(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_logaddexp(self: TFloat, other: TFloat) -> TFloat:
"""logaddexp(Tensor self, Tensor other) -> Tensor"""

return op.Log(op.Add(op.Exp(self), op.Exp(other)))


@torch_op("aten::logaddexp2", traceable=True)
-def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_logaddexp2(self: TFloat, other: TFloat) -> TFloat:
"""logaddexp2(Tensor self, Tensor other) -> Tensor"""
two = op.CastLike(2.0, self)
summation = op.Add(op.Pow(two, self), op.Pow(two, other))
@@ -4822,7 +4821,7 @@ def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:


@torch_op("aten::logcumsumexp", traceable=True)
-def aten_logcumsumexp(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
+def aten_logcumsumexp(self: TFloat, dim: int) -> TFloat:
"""logcumsumexp(Tensor self, int dim) -> Tensor"""

if IsScalar(self):
@@ -4908,12 +4907,12 @@ def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL:


@torch_op("aten::logit", private=True)
-def _aten_logit_onnx(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def _aten_logit_onnx(self: TFloat) -> TFloat:
return op.Log(op.Div(self, op.Sub(1.0, self)))


@torch_op("aten::logit", private=True)
-def _aten_logit_clamp_onnx(self: TFloatOrBFloat16, eps: float) -> TFloatOrBFloat16:
+def _aten_logit_clamp_onnx(self: TFloat, eps: float) -> TFloat:
eps = op.CastLike(eps, self)
one = op.CastLike(1.0, self)
temporary_self = op.Where(self <= one - eps, self, one - eps)
@@ -4923,7 +4922,7 @@ def _aten_logit_clamp_onnx(self: TFloatOrBFloat16, eps: float) -> TFloatOrBFloat16:


@torch_op("aten::logit", trace_only=True)
-def aten_logit(self: TFloatOrBFloat16, eps: Optional[float] = None) -> TFloatOrBFloat16:
+def aten_logit(self: TFloat, eps: Optional[float] = None) -> TFloat:
"""logit(Tensor self, float? eps=None) -> Tensor"""
if eps is None:
return _aten_logit_onnx(self)
@@ -6041,9 +6040,7 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType:


@torch_op("aten::native_dropout", trace_only=True)
-def aten_native_dropout(
-    input: TFloatOrBFloat16, p: float, train: bool = True
-) -> Tuple[TFloatOrBFloat16, BOOL]:
+def aten_native_dropout(input: TFloat, p: float, train: bool = True) -> Tuple[TFloat, BOOL]:
"""native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)"""

result, mask = op.Dropout(input, p, train)
@@ -7055,7 +7052,7 @@ def aten_real(self: TensorType) -> TensorType:


@torch_op("aten::reciprocal")
-def aten_reciprocal(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_reciprocal(self: TFloat) -> TFloat:
"""reciprocal(Tensor self) -> Tensor"""

return op.Reciprocal(self)
@@ -7074,7 +7071,7 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType:


@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"))
-def aten_remainder(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_remainder(self: TFloat, other: TFloat) -> TFloat:
"""remainder.Tensor(Tensor self, Tensor other) -> Tensor"""

# TODO(justinchuby): Improve fp16 precision by following the logic in
@@ -7355,7 +7352,7 @@ def aten_rrelu(


@torch_op("aten::rsqrt", traceable=True)
-def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_rsqrt(self: TFloat) -> TFloat:
"""rsqrt(Tensor self) -> Tensor"""

return op.Reciprocal(op.Sqrt(self))
@@ -7562,7 +7559,7 @@ def aten_sgn(self: TensorType) -> TensorType:


@torch_op("aten::sigmoid", traceable=True)
-def aten_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_sigmoid(self: TFloat) -> TFloat:
"""sigmoid(Tensor self) -> Tensor"""

return op.Sigmoid(self)
@@ -7724,7 +7721,7 @@ def aten_smm(self: TensorType, mat2: TensorType) -> TensorType:


@torch_op(("aten::softmax.int", "aten::special_softmax"), trace_only=True)
-def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrBFloat16:
+def aten_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat:
"""softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""

self_is_scalar = IsScalar(self)
@@ -7741,7 +7738,7 @@ def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrBFloat16:


@torch_op(("aten::softmax.int", "aten::special_softmax"), traceable=True)
-def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
+def aten_softmax_no_dtype(self: TFloat, dim: int) -> TFloat:
"""softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""

self_is_scalar = IsScalar(self)
@@ -7812,7 +7809,7 @@ def aten_split_with_sizes_copy(


@torch_op("aten::sqrt", traceable=True)
-def aten_sqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_sqrt(self: TFloat) -> TFloat:
"""sqrt(Tensor self) -> Tensor"""

return op.Sqrt(self)
@@ -8402,7 +8399,7 @@ def aten_triu_indices(row: int, col: int, offset: int = 0) -> TensorType:


@torch_op("aten::trunc")
-def aten_trunc(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_trunc(self: TFloat) -> TFloat:
"""trunc(Tensor self) -> Tensor"""

# Reference https://github.com/onnx/onnx/issues/4588#issuecomment-1463970126
11 changes: 5 additions & 6 deletions onnxscript/function_libs/torch_lib/ops/nn.py
@@ -25,7 +25,6 @@
from onnxscript.function_libs.torch_lib.tensor_typing import (
IntType,
TFloat,
-    TFloatOrBFloat16,
TFloatOrUInt8,
TInt,
TReal,
@@ -364,13 +363,13 @@ def aten_conv_depthwise3d(

@torch_op("aten::cross_entropy_loss", traceable=True)
def aten_cross_entropy_loss(
-    self: TFloatOrBFloat16,
+    self: TFloat,
target: IntType,
-    weight: Optional[TFloatOrBFloat16] = None,
+    weight: Optional[TFloat] = None,
reduction: int = 1, # default is 'mean'
ignore_index: int = -100,
label_smoothing: float = 0.0, # this was ignored due to ONNX not support
-) -> TFloatOrBFloat16:
+) -> TFloat:
"""cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor"""

if reduction == 0: # "none"
@@ -812,7 +811,7 @@ def aten_l1_loss(self: TensorType, target: TensorType, reduction: int = 1) -> Te


@torch_op("aten::leaky_relu")
-def aten_leaky_relu(self: TFloatOrBFloat16, negative_slope: float = 0.01) -> TFloatOrBFloat16:
+def aten_leaky_relu(self: TFloat, negative_slope: float = 0.01) -> TFloat:
"""leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor"""

return op.LeakyRelu(self, alpha=negative_slope)
@@ -850,7 +849,7 @@ def aten_linear_bias(input: TFloat, weight: TFloat, bias: TFloat) -> TFloat:


@torch_op("aten::log_sigmoid")
-def aten_log_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_log_sigmoid(self: TFloat) -> TFloat:
"""log_sigmoid(Tensor self) -> Tensor"""

return op.Log(op.Sigmoid(self))
16 changes: 7 additions & 9 deletions onnxscript/function_libs/torch_lib/ops/special.py
@@ -17,7 +17,7 @@

from onnxscript.function_libs.torch_lib.ops import common as common_ops
from onnxscript.function_libs.torch_lib.registration import torch_op
-from onnxscript.function_libs.torch_lib.tensor_typing import TFloat, TFloatOrBFloat16
+from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType

@@ -92,21 +92,21 @@ def aten_special_entr(self: TensorType) -> TensorType:


@torch_op(("aten::erf", "aten::special_erf"))
-def aten_special_erf(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_special_erf(self: TFloat) -> TFloat:
"""erf(Tensor self) -> Tensor"""

return op.Erf(self)


@torch_op(("aten::erfc", "aten::special_erfc"))
-def aten_special_erfc(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_special_erfc(self: TFloat) -> TFloat:
"""erfc(Tensor self) -> Tensor"""

return op.Sub(1, op.Erf(self))


@torch_op("aten::special_erfcx")
-def aten_special_erfcx(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_special_erfcx(self: TFloat) -> TFloat:
"""special_erfcx(Tensor self) -> Tensor"""

return op.Mul(op.Exp(op.Pow(self, 2)), op.Sub(1, op.Erf(self)))
@@ -131,7 +131,7 @@ def aten_special_expit(self: TensorType) -> TensorType:


@torch_op(("aten::expm1", "aten::special_expm1"))
-def aten_special_expm1(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_special_expm1(self: TFloat) -> TFloat:
"""special_expm1(Tensor self) -> Tensor"""

return op.Sub(op.Exp(self), 1)
@@ -216,9 +216,7 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType:


@torch_op(("aten::log_softmax.int", "aten::special_log_softmax"), trace_only=True)
-def aten_special_log_softmax(
-    self: TFloatOrBFloat16, dim: int, dtype: int = -1
-) -> TFloatOrBFloat16:
+def aten_special_log_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat:
"""special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor"""

self_is_scalar = IsScalar(self)
@@ -366,7 +364,7 @@ def aten_special_xlog1py(self: TensorType, other: TensorType) -> TensorType:


@torch_op(("aten::xlogy.Tensor", "aten::xlogy.Scalar_Self", "aten::xlogy.Scalar_Other"))
-def aten_special_xlogy(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_special_xlogy(self: TFloat, other: TFloat) -> TFloat:
"""special_xlogy(Tensor self, Tensor other) -> Tensor"""

# https://pytorch.org/docs/stable/special.html#torch.special.xlogy
3 changes: 1 addition & 2 deletions onnxscript/function_libs/torch_lib/tensor_typing.py
@@ -42,7 +42,7 @@
INT64,
UINT8,
]
-_FloatType = Union[FLOAT16, FLOAT, DOUBLE]
+_FloatType = Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16]
IntType = Union[INT8, INT16, INT32, INT64]
RealType = Union[
BFLOAT16,
@@ -61,7 +61,6 @@
TTensor2 = TypeVar("TTensor2", bound=_TensorType)
TTensorOrString = TypeVar("TTensorOrString", bound=Union[_TensorType, STRING])
TFloat = TypeVar("TFloat", bound=_FloatType)
-TFloatOrBFloat16 = TypeVar("TFloatOrBFloat16", bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16])
TFloatOrUInt8 = TypeVar("TFloatOrUInt8", bound=Union[FLOAT, FLOAT16, DOUBLE, INT8, UINT8])
TInt = TypeVar("TInt", bound=IntType)
TReal = TypeVar("TReal", bound=RealType)
