From ce48fbc212044bf33b062f304f9a3984af451f1f Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 29 Jul 2024 09:14:35 -0700 Subject: [PATCH] [wip] add axiswise granularity to Float8Tensor Summary: This PR adds the axiswise scaling granularity to `Float8Tensor` and ensures that basic ops like transpose and `torch._scaled_mm` work as expected. A future PR will add integration with `Float8Linear`. Test Plan: TODO Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 163452e97ed6a26fa5dcba01c36f49eb744484a6 Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/352 --- float8_experimental/config.py | 12 ++ float8_experimental/float8_ops.py | 75 ++++++++++++- float8_experimental/float8_scaling_utils.py | 14 ++- float8_experimental/float8_tensor.py | 13 +-- float8_experimental/float8_utils.py | 25 ++++- test/test_base.py | 117 +++++++++++++++++++- 6 files changed, 235 insertions(+), 21 deletions(-) diff --git a/float8_experimental/config.py b/float8_experimental/config.py index 5d1bf9f..217fca1 100644 --- a/float8_experimental/config.py +++ b/float8_experimental/config.py @@ -21,6 +21,18 @@ def short_str(self): return "dyn" +class ScalingGranularity(enum.Enum): + """ + Defines the granularity of scaling strategies for casting to float8 + """ + + # A single scaling factor for the entire tensor + TENSORWISE = "tensorwise" + # Scaling factors computed along one axis of the tensor, reducing it to + # size 1. + AXISWISE = "axiswise" + + @dataclass(frozen=True) class CastConfig: """ diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index 2a11726..3f3af10 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -19,6 +19,15 @@ FLOAT8_OPS_TABLE: Dict[Any, Any] = {} +def _assert_tensorwise_scale(aten_op, scale): + assert ( + # TODO(future PR): figure out why tensorwise scaling can have + # both rank 0 and rank 1 + len(scale.shape) + in (0, 1) + ), f"{aten_op} with axiswise scaling is not supported yet" + + def implements(aten_ops): """Register aten ops to the float8 op table""" @@ -32,18 +41,16 @@ def decorator(func): @implements( [ - aten.view.default, aten._unsafe_view.default, - aten.t.default, aten.as_strided.default, aten.clone.default, aten.detach.default, aten.slice.Tensor, - aten.transpose.int, aten.fill_.Scalar, ] ) def float8_desugar_op(aten_op, args, kwargs=None): + _assert_tensorwise_scale(aten_op, args[0]._scale) new_data = aten_op(args[0]._data, *args[1:], **kwargs) return Float8Tensor( new_data, @@ -54,8 +61,61 @@ def float8_desugar_op(aten_op, args, kwargs=None): ) +@implements( + [ + aten.t.default, + aten.transpose.int, + ] +) +def float8_desugar_data_and_scale(aten_op, args, kwargs=None): + new_data = aten_op(args[0]._data, *args[1:], **kwargs) + new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) + return Float8Tensor( + new_data, + new_scale, + args[0]._orig_dtype, + args[0]._linear_mm_config, + args[0]._gemm_input_role, + ) + + +@implements([aten.view.default]) +def float8_view(aten_op, args, kwargs=None): + if len(args[0]._scale.shape) < 2: + # tensorwise scaling + return float8_desugar_op(aten_op, args, kwargs) + + t, new_shape = args[0], args[1] + # for now, only support reshaping to [-1, dim] or [dim, -1] + if len(new_shape) == 2: + if new_shape == [t.shape[0], -1] and t._scale.shape[0] == 1: + new_data = aten_op(t._data, new_shape, **kwargs) + new_scale = aten_op(t._scale, [1, -1], **kwargs) + return Float8Tensor( + new_data, + new_scale, + t._orig_dtype, + t._linear_mm_config, + t._gemm_input_role, + ) + elif new_shape == [-1, t.shape[-1]] and t._scale.shape[-1] == 1: + new_data = aten_op(t._data, new_shape, **kwargs) + new_scale = aten_op(t._scale, [-1, 1], **kwargs) + return Float8Tensor( + new_data, + new_scale, + t._orig_dtype, + t._linear_mm_config, + t._gemm_input_role, + ) + raise AssertionError( + f"{aten_op} with axiswise scaling and t.shape {t.shape} t._scale.shape {t._scale.shape} new_shape {new_shape} is not supported yet." + ) + + @implements([aten.split.Tensor]) def float8_split(aten_op, args, kwargs=None): + _assert_tensorwise_scale(aten_op, args[0]._scale) new_data_tensors = aten_op(args[0]._data, *args[1:], **kwargs) def make_float8(data): @@ -101,6 +161,7 @@ def float8_cat(aten_op, args, kwargs=None): assert ( chunk._gemm_input_role is gemm_input_role ), "Expecting all chunks to have the same gemm_input_role as a result of a split" + _assert_tensorwise_scale(aten_op, chunk._scale) chunk_data.append(chunk._data.view(torch.uint8)) new_data = aten_op(chunk_data, *args[1:], **kwargs) @@ -117,6 +178,7 @@ def float8_cast_up_op(aten_op, args, kwargs=None): "addmm" -> out "hp_gradBias" <-"sum" <- "identity" <- gradOut <- "hp_gradOut" """ + _assert_tensorwise_scale(aten_op, args[0]._scale) def unwrap(x): if isinstance(x, Float8Tensor): @@ -229,6 +291,7 @@ def float8_addmm(aten_op, args, kwargs=None): @implements([aten.is_same_size.default]) def float8_is_same_size(aten_op, args, kwargs=None): + _assert_tensorwise_scale(aten_op, args[0]._scale) return args[0].shape == args[1].shape @@ -238,6 +301,7 @@ def autocast_to_copy(aten_op, args, kwargs=None): when the input is a Float8Tensor, presenting as a fp32 tensor. """ + _assert_tensorwise_scale(aten_op, args[0]._scale) assert isinstance(args[0], Float8Tensor) assert ( len(kwargs) == 1 and "dtype" in kwargs @@ -265,6 +329,7 @@ def allgather_fp8(aten_op, args, kwargs=None): """ override funcol with FP8 handling """ + _assert_tensorwise_scale(aten_op, args[0]._scale) fp8_input = args[0] assert isinstance( fp8_input, Float8Tensor @@ -284,6 +349,7 @@ def allgather_fp8(aten_op, args, kwargs=None): @implements([c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default]) def wait_tensor_fp8(aten_op, args, kwargs=None): + _assert_tensorwise_scale(aten_op, args[0]._scale) fp8_input = args[0] assert isinstance(fp8_input, Float8Tensor) @@ -304,6 +370,7 @@ def index_put_fp8(aten_op, args, kwargs=None): fp8_values = args[2] assert isinstance(fp8_self, Float8Tensor) assert isinstance(fp8_values, Float8Tensor) + _assert_tensorwise_scale(fp8_self, args[0]._scale) assert fp8_self._scale == fp8_values._scale assert fp8_self.dtype == fp8_values.dtype assert fp8_self._orig_dtype == fp8_values._orig_dtype @@ -334,8 +401,10 @@ def copy_fp8(aten_op, args, kwargs=None): if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor): src_hp = src.to_original_precision() + _assert_tensorwise_scale(aten_op, src._scale) return aten_op(self, src_hp, *args[2:], **kwargs) elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor): + _assert_tensorwise_scale(aten_op, src._scale) assert ( self._orig_dtype == src._orig_dtype ), "Expecting both Float8Tensors to be of the same dtype" diff --git a/float8_experimental/float8_scaling_utils.py b/float8_experimental/float8_scaling_utils.py index ce6422f..06c93c1 100644 --- a/float8_experimental/float8_scaling_utils.py +++ b/float8_experimental/float8_scaling_utils.py @@ -12,6 +12,8 @@ import torch +from float8_experimental.config import ScalingGranularity + from float8_experimental.float8_tensor import ( Float8Tensor, GemmInputRole, @@ -36,6 +38,8 @@ def hp_tensor_to_float8_dynamic( linear_mm_config: LinearMMConfig, reduce_amax: bool = False, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, + scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, + axiswise_dim: Optional[int] = None, ) -> Float8Tensor: """ Given a high precision tensor `hp_tensor`, @@ -49,10 +53,18 @@ def hp_tensor_to_float8_dynamic( reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in the 3 fwd/bwd gemms of linear + scaling_granularity: Defines the scaling granularity + axiswise_dim: if axiswise granularity is used, defines the dim to scale across """ if tensor_already_casted_to_fp8(hp_tensor): return hp_tensor - scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax) + scale = tensor_to_scale( + hp_tensor, + float8_dtype, + reduce_amax, + scaling_granularity, + axiswise_dim, + ) return hp_tensor_and_scale_to_float8( hp_tensor, scale, diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 641f972..22c2a32 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -250,7 +250,12 @@ class Float8Tensor(torch.Tensor): * `_data`: the underlying e4m3 or e5m2 data * `_scale`: the scale used to scale the original fp32 tensor. We multiply by scale to go from fp32 range to fp8 range, and divide by scale to go - from fp8 range to fp32 range. + from fp8 range to fp32 range. Scale is guaranteed to have a shape compatible + with `_data`. For example: + - if scaling is tensorwise, `_scale` is a scalar tensor + - if scaling is axiswise and _data.shape is [3, 5], `_scale` could have + shape [1, 5] or [5, 1]. The dim of the non-one entry defines the scaling + axis. * `_orig_dtype`: the original dtype of the tensor used to create this tensor. * `_emulate`: if true using fp32 emulation for the matmuls, helpful @@ -279,12 +284,6 @@ def __new__( linear_mm_config: Optional[LinearMMConfig], gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, ): - assert ( - scale.numel() == 1 - ), "Scale should contain a single value, but got: {} elements".format( - scale.numel() - ) - self = torch.Tensor._make_wrapper_subclass( cls, data.size(), diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index 2be568e..fdd9189 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -4,12 +4,13 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from typing import Iterable, Literal, Tuple, Union +from typing import Iterable, Literal, Optional, Tuple, Union import float8_experimental.config as config import torch import torch.distributed as dist +from float8_experimental.config import ScalingGranularity # Helpful visualizer for debugging (only supports fp32): # https://www.h-schmidt.net/FloatConverter/IEEE754.html @@ -100,8 +101,18 @@ def amax_history_to_scale_stack( @torch.no_grad() -def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor: - amax = torch.max(torch.abs(x)) +def tensor_to_amax( + x: torch.Tensor, + reduce_amax: bool = False, + scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, + axiswise_dim: Optional[int] = None, +) -> torch.Tensor: + if scaling_granularity is ScalingGranularity.TENSORWISE: + amax = torch.max(torch.abs(x)) + else: + assert scaling_granularity is ScalingGranularity.AXISWISE, "unsupported" + assert axiswise_dim is not None, "unsupported" + amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True) # If the user asked for distributed reduction, do it. # If the user did not ask for it, assume that it will @@ -114,9 +125,13 @@ def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor: @torch.no_grad() def tensor_to_scale( - x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False + x: torch.Tensor, + float8_dtype: torch.dtype, + reduce_amax: bool = False, + scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, + axiswise_dim: Optional[int] = None, ) -> torch.Tensor: - amax = tensor_to_amax(x, reduce_amax=reduce_amax) + amax = tensor_to_amax(x, reduce_amax, scaling_granularity, axiswise_dim) return amax_to_scale(amax, float8_dtype, x.dtype) diff --git a/test/test_base.py b/test/test_base.py index 4e0c685..739d6b0 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -16,7 +16,12 @@ import torch import torch.nn as nn -from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType +from float8_experimental.config import ( + CastConfig, + Float8LinearConfig, + ScalingGranularity, + ScalingType, +) from float8_experimental.float8_linear import Float8Linear from float8_experimental.float8_linear_utils import ( convert_to_float8_training, @@ -24,6 +29,7 @@ sync_float8_amax_and_scale_history, ) from float8_experimental.float8_python_api import addmm_float8_unwrapped +from float8_experimental.float8_scaling_utils import hp_tensor_to_float8_dynamic from float8_experimental.float8_tensor import ( Float8Tensor, GemmInputRole, @@ -57,7 +63,7 @@ def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: return True -class TestFloat8Tensor(unittest.TestCase): +class TestFloat8Tensor: def test_preserves_dtype(self) -> None: # hp means high precision, lp means low precision hp_dtypes = (torch.float32, torch.float16, torch.bfloat16) @@ -67,7 +73,7 @@ def test_preserves_dtype(self) -> None: x1_s = tensor_to_scale(x1_hp, lp_dtype) x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype) x3_hp = x2_lp.to_original_precision() - self.assertTrue(x3_hp.dtype == hp_dtype) + assert x3_hp.dtype == hp_dtype def test_differentiable_casts(self) -> None: lp_dtypes = (e4m3_dtype, e5m2_dtype) @@ -102,7 +108,7 @@ def test_index_put(self): fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn) fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn) - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): b[index] = fp8_a fp8_b[index] = a fp8_b_bad[index] = fp8_a @@ -116,7 +122,7 @@ def test_copy_(self): b = torch.empty(16, dtype=torch.bfloat16) b.copy_(fp8_a) # Should work torch.testing.assert_close(b, fp8_a.to_original_precision()) - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): fp8_a.copy_(b) # Should fail fp8_b = Float8Tensor( @@ -143,6 +149,107 @@ def test_weights_only_load(self): buffer.seek(0) _ = torch.load(buffer, weights_only=True) + @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)]) + @pytest.mark.parametrize("dim_name", ["first", "last"]) + def test_axiswise_dynamic_cast(self, shape, dim_name): + a = torch.randn(*shape, dtype=torch.bfloat16) + + if dim_name == "first": + dim = 0 + elif dim_name == "last": + dim = len(a.shape) - 1 + + linear_mm_config = LinearMMConfig() + a_fp8 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, + linear_mm_config, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=dim, + ) + a_dq = a_fp8.to_original_precision() + sqnr = compute_error(a, a_dq) + assert sqnr >= 25.0 + + def test_axiswise_reshape(self): + a = torch.randn(3, 5, 7, dtype=torch.bfloat16, device="cuda") + linear_mm_config = LinearMMConfig() + + # if we scale across dim0, we can only reshape to [3, -1] + a_fp8_d0 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, + linear_mm_config, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=0, + ) + assert list(a_fp8_d0._data.shape) == [3, 5, 7] + assert list(a_fp8_d0._scale.shape) == [1, 5, 7] + + a_fp8_d0_r = a_fp8_d0.reshape(3, -1) + assert list(a_fp8_d0_r.shape) == [3, 5 * 7] + assert list(a_fp8_d0_r._scale.shape) == [1, 5 * 7] + # verify numerics did not change + assert torch.allclose( + a_fp8_d0.to_original_precision(), + a_fp8_d0_r.to_original_precision().reshape(3, 5, 7), + atol=0, + rtol=0, + ) + with pytest.raises(AssertionError): + a_fp8_d0_r2 = a_fp8_d0.reshape(-1, 7) + + # if we scale across dim2, we can only reshape to [-1, 7] + a_fp8_d2 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, + linear_mm_config, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=2, + ) + assert list(a_fp8_d2._data.shape) == [3, 5, 7] + assert list(a_fp8_d2._scale.shape) == [3, 5, 1] + + a_fp8_d2_r = a_fp8_d2.reshape(-1, 7) + assert list(a_fp8_d2_r.shape) == [3 * 5, 7] + assert list(a_fp8_d2_r._scale.shape) == [3 * 5, 1] + # verify numerics did not change + assert torch.allclose( + a_fp8_d2.to_original_precision(), + a_fp8_d2_r.to_original_precision().reshape(3, 5, 7), + atol=0, + rtol=0, + ) + with pytest.raises(AssertionError): + a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1) + + def test_axiswise_gemm(self): + a = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda") + b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda") + + linear_mm_config = LinearMMConfig() + + a_fp8 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=1, + ) + b_fp8 = hp_tensor_to_float8_dynamic( + b, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=1, + ) + c_fp8_compute = torch.mm(a_fp8, b_fp8.t()) + c_ref = torch.mm(a, b.t()) + sqnr = compute_error(c_ref, c_fp8_compute) + assert sqnr >= 25.0 + class TestFloat8Linear: def _test_linear_impl(