From 183dbec75635482fcecd57839f0fd571651d176b Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 23 Sep 2024 09:52:24 -0700 Subject: [PATCH 01/32] Update [ghstack-poisoned] --- test/float8/test_base.py | 119 +++++++++++++++++++++++-- torchao/float8/config.py | 12 +++ torchao/float8/float8_ops.py | 93 ++++++++++++++++++- torchao/float8/float8_scaling_utils.py | 15 +++- torchao/float8/float8_tensor.py | 34 ++++--- torchao/float8/float8_utils.py | 25 ++++-- 6 files changed, 275 insertions(+), 23 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 2a875c44d6..e6dd67951c 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -22,7 +22,12 @@ pytest.skip("Unsupported PyTorch version", allow_module_level=True) -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.config import ( + CastConfig, + Float8LinearConfig, + ScalingGranularity, + ScalingType, +) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( convert_to_float8_training, @@ -30,6 +35,7 @@ sync_float8_amax_and_scale_history, ) from torchao.float8.float8_python_api import addmm_float8_unwrapped +from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic from torchao.float8.float8_tensor import ( Float8Tensor, GemmInputRole, @@ -58,7 +64,7 @@ def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: return True -class TestFloat8Tensor(unittest.TestCase): +class TestFloat8Tensor: def test_preserves_dtype(self) -> None: # hp means high precision, lp means low precision hp_dtypes = (torch.float32, torch.float16, torch.bfloat16) @@ -68,7 +74,7 @@ def test_preserves_dtype(self) -> None: x1_s = tensor_to_scale(x1_hp, lp_dtype) x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype) x3_hp = x2_lp.to_original_precision() - self.assertTrue(x3_hp.dtype == hp_dtype) + assert x3_hp.dtype == hp_dtype def test_differentiable_casts(self) -> None: lp_dtypes = (e4m3_dtype, e5m2_dtype) @@ -103,7 +109,7 @@ def test_index_put(self): fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn) fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn) - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): b[index] = fp8_a fp8_b[index] = a fp8_b_bad[index] = fp8_a @@ -117,7 +123,7 @@ def test_copy_(self): b = torch.empty(16, dtype=torch.bfloat16) b.copy_(fp8_a) # Should work torch.testing.assert_close(b, fp8_a.to_original_precision()) - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): fp8_a.copy_(b) # Should fail fp8_b = Float8Tensor( @@ -129,6 +135,109 @@ def test_copy_(self): fp8_b.copy_(fp8_a) torch.testing.assert_close(fp8_a._data, fp8_b._data) + @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)]) + @pytest.mark.parametrize("dim_name", ["first", "last"]) + def test_axiswise_dynamic_cast(self, shape, dim_name): + a = torch.randn(*shape, dtype=torch.bfloat16) + + if dim_name == "first": + dim = 0 + elif dim_name == "last": + dim = len(a.shape) - 1 + + linear_mm_config = LinearMMConfig() + a_fp8 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, + linear_mm_config, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=dim, + ) + a_dq = a_fp8.to_original_precision() + sqnr = compute_error(a, a_dq) + assert sqnr >= 25.0 + + def test_axiswise_reshape(self): + a = torch.randn(3, 5, 7, dtype=torch.bfloat16, device="cuda") + linear_mm_config = LinearMMConfig() + + # if we scale across dim0, 
we can only reshape to [3, -1] + a_fp8_d0 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, + linear_mm_config, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=0, + ) + assert list(a_fp8_d0._data.shape) == [3, 5, 7] + assert list(a_fp8_d0._scale.shape) == [1, 5, 7] + + a_fp8_d0_r = a_fp8_d0.reshape(3, -1) + assert list(a_fp8_d0_r.shape) == [3, 5 * 7] + assert list(a_fp8_d0_r._scale.shape) == [1, 5 * 7] + # verify numerics did not change + assert torch.allclose( + a_fp8_d0.to_original_precision(), + a_fp8_d0_r.to_original_precision().reshape(3, 5, 7), + atol=0, + rtol=0, + ) + with pytest.raises(RuntimeError): + a_fp8_d0_r2 = a_fp8_d0.reshape(-1, 7) + + # if we scale across dim2, we can only reshape to [-1, 7] + a_fp8_d2 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, + linear_mm_config, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=2, + ) + assert list(a_fp8_d2._data.shape) == [3, 5, 7] + assert list(a_fp8_d2._scale.shape) == [3, 5, 1] + + a_fp8_d2_r = a_fp8_d2.reshape(-1, 7) + assert list(a_fp8_d2_r.shape) == [3 * 5, 7] + assert list(a_fp8_d2_r._scale.shape) == [3 * 5, 1] + # verify numerics did not change + assert torch.allclose( + a_fp8_d2.to_original_precision(), + a_fp8_d2_r.to_original_precision().reshape(3, 5, 7), + atol=0, + rtol=0, + ) + with pytest.raises(RuntimeError): + a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1) + + @pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)]) + def test_axiswise_gemm(self, a_shape): + a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda") + b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda") + + linear_mm_config = LinearMMConfig() + + a_fp8 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=-1, + ) + a_fp8 = a_fp8.reshape(-1, a_shape[-1]) + b_fp8 = hp_tensor_to_float8_dynamic( + b, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=1, # will be transposed + ) + c_fp8_compute = torch.mm(a_fp8, b_fp8.t()) + a = a.reshape(-1, a_shape[-1]) + c_ref = torch.mm(a, b.t()) + sqnr = compute_error(c_ref, c_fp8_compute) + assert sqnr >= 25.0 class TestFloat8Linear: diff --git a/torchao/float8/config.py b/torchao/float8/config.py index eb28dcbd8e..b24b5ba749 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -26,6 +26,18 @@ def short_str(self): return "sta" +class ScalingGranularity(enum.Enum): + """ + Defines the granularity of scaling strategies for casting to float8 + """ + + # A single scaling factor for the entire tensor + TENSORWISE = "tensorwise" + # Scaling factors computed along one axis of the tensor, reducing it to + # size 1. 
+ AXISWISE = "axiswise" + + @dataclass(frozen=True) class CastConfig: """ diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py index f8115649b3..1bf9faaa4c 100644 --- a/torchao/float8/float8_ops.py +++ b/torchao/float8/float8_ops.py @@ -19,6 +19,15 @@ FLOAT8_OPS_TABLE: Dict[Any, Any] = {} +def _assert_tensorwise_scale(aten_op, scale): + assert ( + # TODO(future PR): figure out why tensorwise scaling can have + # both rank 0 and rank 1 + len(scale.shape) + in (0, 1) + ), f"{aten_op} with axiswise scaling is not supported yet" + + def implements(aten_ops): """Register aten ops to the float8 op table""" @@ -45,6 +54,7 @@ def decorator(func): ] ) def float8_desugar_op(aten_op, args, kwargs=None): + _assert_tensorwise_scale(aten_op, args[0]._scale) new_data = aten_op(args[0]._data, *args[1:], **kwargs) return Float8Tensor( new_data, @@ -55,10 +65,82 @@ def float8_desugar_op(aten_op, args, kwargs=None): ) +@implements( + [ + aten.t.default, + aten.transpose.int, + ] +) +def float8_desugar_data_and_scale(aten_op, args, kwargs=None): + new_data = aten_op(args[0]._data, *args[1:], **kwargs) + new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) + + if aten_op == aten.transpose.int: + _assert_tensorwise_scale(aten_op, args[0]._scale) + + old_axiswise_dim = args[0]._axiswise_dim + new_axiswise_dim = old_axiswise_dim + if old_axiswise_dim is not None: + if old_axiswise_dim == 0: + new_axiswise_dim == -1 + else: + new_axiswise_dim == 0 + + return Float8Tensor( + new_data, + new_scale, + args[0]._orig_dtype, + args[0]._linear_mm_config, + args[0]._gemm_input_role, + new_axiswise_dim, + ) + + +@implements([aten.view.default]) +def float8_view(aten_op, args, kwargs=None): + if len(args[0]._scale.shape) < 2: + # tensorwise scaling + return float8_desugar_op(aten_op, args, kwargs) + + t, new_shape = args[0], args[1] + # for now, only support reshaping to [-1, dim] or [dim, -1] + axiswise_dim = t._axiswise_dim + if len(new_shape) == 2: + + if axiswise_dim == 0: + new_data = aten_op(t._data, new_shape, **kwargs) + new_scale_shape = [1, new_shape[-1]] + new_scale = aten_op(t._scale, new_scale_shape, **kwargs) + return Float8Tensor( + new_data, + new_scale, + t._orig_dtype, + t._linear_mm_config, + t._gemm_input_role, + t._axiswise_dim, + ) + elif axiswise_dim == -1 or axiswise_dim == (len(t.shape) - 1): + new_data = aten_op(t._data, new_shape, **kwargs) + new_scale_shape = [new_shape[0], 1] + new_scale = aten_op(t._scale, new_scale_shape, **kwargs) + new_axiswise_dim = -1 + return Float8Tensor( + new_data, + new_scale, + t._orig_dtype, + t._linear_mm_config, + t._gemm_input_role, + new_axiswise_dim, + ) + raise AssertionError( + f"{aten_op} with axiswise scaling and t.shape {t.shape} t._scale.shape {t._scale.shape} t._axiswise_dim {t._axiswise_dim} new_shape {new_shape} is not supported yet." 
+ ) + + @implements([aten.split.Tensor]) def float8_split(aten_op, args, kwargs=None): new_data_tensors = aten_op(args[0]._data, *args[1:], **kwargs) - + _assert_tensorwise_scale(aten_op, args[0]._scale) def make_float8(data): return Float8Tensor( data, @@ -102,6 +184,7 @@ def float8_cat(aten_op, args, kwargs=None): assert ( chunk._gemm_input_role is gemm_input_role ), "Expecting all chunks to have the same gemm_input_role as a result of a split" + _assert_tensorwise_scale(aten_op, chunk._scale) chunk_data.append(chunk._data.view(torch.uint8)) new_data = aten_op(chunk_data, *args[1:], **kwargs) @@ -118,6 +201,7 @@ def float8_cast_up_op(aten_op, args, kwargs=None): "addmm" -> out "hp_gradBias" <-"sum" <- "identity" <- gradOut <- "hp_gradOut" """ + _assert_tensorwise_scale(aten_op, args[0]._scale) def unwrap(x): if isinstance(x, Float8Tensor): @@ -230,6 +314,7 @@ def float8_addmm(aten_op, args, kwargs=None): @implements([aten.is_same_size.default]) def float8_is_same_size(aten_op, args, kwargs=None): + _assert_tensorwise_scale(aten_op, args[0]._scale) return args[0].shape == args[1].shape @@ -239,6 +324,7 @@ def autocast_to_copy(aten_op, args, kwargs=None): when the input is a Float8Tensor, presenting as a fp32 tensor. """ + _assert_tensorwise_scale(aten_op, args[0]._scale) assert isinstance(args[0], Float8Tensor) assert ( len(kwargs) == 1 and "dtype" in kwargs @@ -266,6 +352,7 @@ def allgather_fp8(aten_op, args, kwargs=None): """ override funcol with FP8 handling """ + _assert_tensorwise_scale(aten_op, args[0]._scale) fp8_input = args[0] assert isinstance( fp8_input, Float8Tensor @@ -285,6 +372,7 @@ def allgather_fp8(aten_op, args, kwargs=None): @implements([c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default]) def wait_tensor_fp8(aten_op, args, kwargs=None): + _assert_tensorwise_scale(aten_op, args[0]._scale) fp8_input = args[0] assert isinstance(fp8_input, Float8Tensor) @@ -305,6 +393,7 @@ def index_put_fp8(aten_op, args, kwargs=None): fp8_values = args[2] assert isinstance(fp8_self, Float8Tensor) assert isinstance(fp8_values, Float8Tensor) + _assert_tensorwise_scale(fp8_self, args[0]._scale) assert fp8_self._scale == fp8_values._scale assert fp8_self.dtype == fp8_values.dtype assert fp8_self._orig_dtype == fp8_values._orig_dtype @@ -335,8 +424,10 @@ def copy_fp8(aten_op, args, kwargs=None): if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor): src_hp = src.to_original_precision() + _assert_tensorwise_scale(aten_op, src._scale) return aten_op(self, src_hp, *args[2:], **kwargs) elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor): + _assert_tensorwise_scale(aten_op, src._scale) assert ( self._orig_dtype == src._orig_dtype ), "Expecting both Float8Tensors to be of the same dtype" diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index d2ae896320..f46293d616 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -12,6 +12,8 @@ import torch +from torchao.float8.config import ScalingGranularity + from torchao.float8.float8_tensor import ( Float8Tensor, GemmInputRole, @@ -36,6 +38,8 @@ def hp_tensor_to_float8_dynamic( linear_mm_config: LinearMMConfig, reduce_amax: bool = False, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, + scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, + axiswise_dim: Optional[int] = None, ) -> Float8Tensor: """ Given a high precision tensor `hp_tensor`, @@ -49,16 +53,25 @@ def 
hp_tensor_to_float8_dynamic( reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in the 3 fwd/bwd gemms of linear + scaling_granularity: Defines the scaling granularity + axiswise_dim: if axiswise granularity is used, defines the dim to scale across """ if tensor_already_casted_to_fp8(hp_tensor): return hp_tensor - scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax) + scale = tensor_to_scale( + hp_tensor, + float8_dtype, + reduce_amax, + scaling_granularity, + axiswise_dim, + ) return hp_tensor_and_scale_to_float8( hp_tensor, scale, float8_dtype, linear_mm_config, gemm_input_role, + axiswise_dim, ) diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py index 63110101a5..c8b68586c0 100644 --- a/torchao/float8/float8_tensor.py +++ b/torchao/float8/float8_tensor.py @@ -152,6 +152,7 @@ def forward( float8_dtype=e4m3_dtype, linear_mm_config: Optional[LinearMMConfig] = None, gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, + axiswise_dim: Optional[int] = None, ): """ This function will apply the scaling, and then convert to a Float8Tensor @@ -180,6 +181,7 @@ def forward( tensor.dtype, linear_mm_config=linear_mm_config, gemm_input_role=gemm_input_role, + axiswise_dim=axiswise_dim, ) return DTensor.from_local( inner_float8_tensor, @@ -196,6 +198,7 @@ def forward( tensor.dtype, linear_mm_config=linear_mm_config, gemm_input_role=gemm_input_role, + axiswise_dim=axiswise_dim, ) @staticmethod @@ -226,6 +229,7 @@ def hp_tensor_and_scale_to_float8( float8_dtype=e4m3_dtype, linear_mm_config: Optional[LinearMMConfig] = None, gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, + axiswise_dim: Optional[int] = None, ): """ Given a high precision tensor `hp_tensor` and a precalculated scale `s`, @@ -242,9 +246,10 @@ def hp_tensor_and_scale_to_float8( the 3 fwd/bwd gemms of linear gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in the 3 fwd/bwd gemms of linear + axiswise_dim: for rowwise scaling, contains the axis scaled across """ return _ToFloat8ConstrFunc.apply( - hp_tensor, s, float8_dtype, linear_mm_config, gemm_input_role + hp_tensor, s, float8_dtype, linear_mm_config, gemm_input_role, axiswise_dim ) @@ -258,11 +263,19 @@ class Float8Tensor(torch.Tensor): * `_data`: the underlying e4m3 or e5m2 data * `_scale`: the scale used to scale the original fp32 tensor. We multiply by scale to go from fp32 range to fp8 range, and divide by scale to go - from fp8 range to fp32 range. + from fp8 range to fp32 range. Scale is guaranteed to have a shape compatible + with `_data`. For example: + - if scaling is tensorwise, `_scale` is a scalar tensor + - if scaling is axiswise and _data.shape is [3, 5], `_scale` could have + shape [1, 5] or [3, 1]. The dim of the non-one entry defines the scaling + axis. + - if scaling is axiswise and _data.shape is [2, 3, 5], `_scale` could have + shape [1, 1, 5] or [2, 1, 1]. The dim of the non-one entry defines the scaling + axis. Non-one entries which are not the first or last element are not + supported. * `_orig_dtype`: the original dtype of the tensor used to create this tensor. - * `_emulate`: if true using fp32 emulation for the matmuls, helpful - if you don't have access to h100 hardware. + * `_axiswise_dim`: for axiswise scaling only, contains the axis scales across Intended usage of this abstraction: 1. 
to bundle raw data + fp8 metadata together for easy passing through @@ -277,6 +290,7 @@ class Float8Tensor(torch.Tensor): _scale: torch.Tensor _orig_dtype: torch.dtype _linear_mm_config: LinearMMConfig + _axiswise_dim: Optional[int] __slots__ = ["_data", "_scale", "_orig_dtype", "_linear_mm_config"] def __new__( @@ -286,13 +300,8 @@ def __new__( orig_dtype: torch.dtype, linear_mm_config: Optional[LinearMMConfig], gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, + axiswise_dim: Optional[int] = None, ): - assert ( - scale.numel() == 1 - ), "Scale should contain a single value, but got: {} elements".format( - scale.numel() - ) - self = torch.Tensor._make_wrapper_subclass( cls, data.size(), @@ -310,17 +319,19 @@ def __new__( linear_mm_config if linear_mm_config is not None else LinearMMConfig() ) self._gemm_input_role = gemm_input_role + self._axiswise_dim = axiswise_dim return self def __repr__(self): - return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, linear_mm_config={self._linear_mm_config}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}" + return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, linear_mm_config={self._linear_mm_config}, axiswise_dim={self._axiswise_dim}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}" def __tensor_flatten__(self): ctx = { "_orig_dtype": self._orig_dtype, "_linear_mm_config": self._linear_mm_config, "_gemm_input_role": self._gemm_input_role, + "_axiswise_dim": self._axiswise_dim, } return ["_data", "_scale"], ctx @@ -333,6 +344,7 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride metadata["_orig_dtype"], metadata["_linear_mm_config"], metadata["_gemm_input_role"], + metadata["_axiswise_dim"], ) def to_original_precision(self): diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 54613e5b05..55e520f8ca 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -4,12 +4,13 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from typing import Iterable, Literal, Tuple, Union +from typing import Iterable, Literal, Optional, Tuple, Union import torchao.float8.config as config import torch import torch.distributed as dist +from torchao.float8.config import ScalingGranularity # Helpful visualizer for debugging (only supports fp32): # https://www.h-schmidt.net/FloatConverter/IEEE754.html @@ -98,8 +99,18 @@ def amax_history_to_scale_stack( @torch.no_grad() -def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor: - amax = torch.max(torch.abs(x)) +def tensor_to_amax( + x: torch.Tensor, + reduce_amax: bool = False, + scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, + axiswise_dim: Optional[int] = None, +) -> torch.Tensor: + if scaling_granularity is ScalingGranularity.TENSORWISE: + amax = torch.max(torch.abs(x)) + else: + assert scaling_granularity is ScalingGranularity.AXISWISE, "unsupported" + assert axiswise_dim is not None, "unsupported" + amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True) # If the user asked for distributed reduction, do it. 
# If the user did not ask for it, assume that it will @@ -112,9 +123,13 @@ def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor: @torch.no_grad() def tensor_to_scale( - x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False + x: torch.Tensor, + float8_dtype: torch.dtype, + reduce_amax: bool = False, + scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, + axiswise_dim: Optional[int] = None, ) -> torch.Tensor: - amax = tensor_to_amax(x, reduce_amax=reduce_amax) + amax = tensor_to_amax(x, reduce_amax, scaling_granularity, axiswise_dim) return amax_to_scale(amax, float8_dtype, x.dtype) From 241f815f73b85f27a59fdf9f8e9d0a80cedb5f2c Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 23 Sep 2024 09:52:27 -0700 Subject: [PATCH 02/32] Update [ghstack-poisoned] --- benchmarks/float8/bench_linear_float8.py | 32 +++- benchmarks/float8/profile_linear_float8.py | 27 ++- test/float8/test_base.py | 32 +++- test/float8/test_compile.py | 103 ++++++++++- test/float8/test_numerics_integration.py | 37 +++- torchao/float8/config.py | 32 ++++ torchao/float8/float8_linear.py | 188 +++++++++++++++++++-- torchao/float8/float8_ops.py | 22 ++- 8 files changed, 434 insertions(+), 39 deletions(-) diff --git a/benchmarks/float8/bench_linear_float8.py b/benchmarks/float8/bench_linear_float8.py index e18006f0e4..f92303c627 100644 --- a/benchmarks/float8/bench_linear_float8.py +++ b/benchmarks/float8/bench_linear_float8.py @@ -14,7 +14,12 @@ import torch import torch.utils.benchmark as benchmark -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.config import ( + CastConfig, + Float8LinearConfig, + ScalingType, + ScalingGranularity, +) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( linear_requires_sync, @@ -107,6 +112,7 @@ def main( scaling_type_input: str = "dynamic", scaling_type_weight: str = "dynamic", scaling_type_grad_output: str = "dynamic", + scaling_granularity: str = "tensorwise", ): device = "cuda" print(f"Compile is set to | {compile}") @@ -114,28 +120,41 @@ def main( scaling_type_input = ScalingType(scaling_type_input) scaling_type_weight = ScalingType(scaling_type_weight) scaling_type_grad_output = ScalingType(scaling_type_grad_output) + scaling_granularity = ScalingGranularity(scaling_granularity) if scaling_type_input is ScalingType.STATIC: cast_config_input=CastConfig( scaling_type=scaling_type_input, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_input=CastConfig(scaling_type=scaling_type_input) + cast_config_input=CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) if scaling_type_weight is ScalingType.STATIC: cast_config_weight=CastConfig( scaling_type=scaling_type_weight, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_weight=CastConfig(scaling_type=scaling_type_weight) + cast_config_weight=CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) if scaling_type_grad_output is ScalingType.STATIC: cast_config_grad_output=CastConfig( scaling_type=scaling_type_grad_output, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output) + cast_config_grad_output=CastConfig( + scaling_type=scaling_type_grad_output, + 
scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, @@ -167,7 +186,7 @@ def main( copy.deepcopy(linear_ref), config=config, ) - scaling_repr = linear_float8.scaling_repr() + scaling_repr = f"{linear_float8.scaling_type_repr()},{linear_float8.scaling_granularity_repr()}" if fast_accum: linear_float8.forward_config = ScaledMMConfig(False, True, False) @@ -310,6 +329,7 @@ def invoke_main() -> None: parser.add_argument("--scaling_type_input", type=str, required=False) parser.add_argument("--scaling_type_weight", type=str, required=False) parser.add_argument("--scaling_type_grad_output", type=str, required=False) + parser.add_argument("--scaling_granularity", type=str, required=False) args = parser.parse_args() output_path = Path(args.output_path) if args.output_path is not None else None kwargs = {} @@ -327,6 +347,8 @@ def invoke_main() -> None: kwargs["scaling_type_weight"] = args.scaling_type_weight if args.scaling_type_grad_output is not None: kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output + if args.scaling_granularity is not None: + kwargs["scaling_granularity"] = args.scaling_granularity main( output_path, not args.disable_compile, diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index c204d49b03..6afefa0096 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -22,7 +22,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.config import ( + CastConfig, + Float8LinearConfig, + ScalingType, + ScalingGranularity, +) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, linear_requires_sync, @@ -252,6 +257,7 @@ def main( scaling_type_input: str = "dynamic", scaling_type_weight: str = "dynamic", scaling_type_grad_output: str = "dynamic", + scaling_granularity: str = "tensorwise", model_type: str = "linear", dtype_filter: str = "both", add_inductor_metadata_to_trace: bool = True, @@ -263,28 +269,41 @@ def main( scaling_type_input = ScalingType(scaling_type_input) scaling_type_weight = ScalingType(scaling_type_weight) scaling_type_grad_output = ScalingType(scaling_type_grad_output) + scaling_granularity = ScalingGranularity(scaling_granularity) if scaling_type_input is ScalingType.STATIC: cast_config_input=CastConfig( scaling_type=scaling_type_input, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_input=CastConfig(scaling_type=scaling_type_input) + cast_config_input=CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) if scaling_type_weight is ScalingType.STATIC: cast_config_weight=CastConfig( scaling_type=scaling_type_weight, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_weight=CastConfig(scaling_type=scaling_type_weight) + cast_config_weight=CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) if scaling_type_grad_output is ScalingType.STATIC: cast_config_grad_output=CastConfig( scaling_type=scaling_type_grad_output, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output) + cast_config_grad_output=CastConfig( + 
scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, diff --git a/test/float8/test_base.py b/test/float8/test_base.py index e6dd67951c..2fa5394b81 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -327,6 +327,10 @@ def _test_linear_impl( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) + @pytest.mark.parametrize( + "scaling_granularity", + [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE], + ) @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32]) @pytest.mark.parametrize("linear_bias", [False, True]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @@ -337,6 +341,7 @@ def test_linear( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, + scaling_granularity: ScalingGranularity, linear_dtype: torch.dtype, linear_bias: bool, ): @@ -349,30 +354,51 @@ def test_linear( f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)" ) pytest.skip() + if scaling_granularity is ScalingGranularity.AXISWISE: + if ( + scaling_type_input != ScalingType.DYNAMIC or + scaling_type_weight != ScalingType.DYNAMIC or + scaling_type_grad_output != ScalingType.DYNAMIC or + linear_dtype != torch.bfloat16 + ): + pytest.skip() + x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) if scaling_type_input is ScalingType.STATIC: cast_config_input = CastConfig( scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_input = CastConfig(scaling_type=scaling_type_input) + cast_config_input = CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) if scaling_type_weight is ScalingType.STATIC: cast_config_weight = CastConfig( scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_weight = CastConfig(scaling_type=scaling_type_weight) + cast_config_weight = CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) if scaling_type_grad_output is ScalingType.STATIC: cast_config_grad_output = CastConfig( scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_grad_output = CastConfig(scaling_type=scaling_type_grad_output) + cast_config_grad_output = CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 8a0458bec3..899f63bdb3 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -18,7 +18,12 @@ import torch import torch.nn as nn -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.config import ( + CastConfig, + Float8LinearConfig, + ScalingType, + ScalingGranularity, +) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( convert_to_float8_training, @@ -60,6 +65,8 @@ def _test_compile_base( y_fp8.sum().backward() y_ref = m_ref(x) y_ref.sum().backward() + # TODO(future PR): can also test fp8 eager vs compile here with a 
tigher + # tolerance torch.testing.assert_close(y_fp8, y_ref, atol=9.5e-2, rtol=9.5e-2) torch.testing.assert_close( m_fp8.weight.grad, m_ref.weight.grad, atol=2e-1, rtol=2e-1 @@ -70,29 +77,42 @@ def _get_config( scaling_type_input, scaling_type_weight, scaling_type_grad_output, + scaling_granularity, emulate, ): if scaling_type_input is ScalingType.STATIC: cast_config_input = CastConfig( scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_input = CastConfig(scaling_type=scaling_type_input) + cast_config_input = CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) if scaling_type_weight is ScalingType.STATIC: cast_config_weight = CastConfig( scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_weight = CastConfig(scaling_type=scaling_type_weight) + cast_config_weight = CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) if scaling_type_grad_output is ScalingType.STATIC: cast_config_grad_output = CastConfig( scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_grad_output = CastConfig(scaling_type=scaling_type_grad_output) + cast_config_grad_output = CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, @@ -103,6 +123,24 @@ def _get_config( return config +def is_supported( + scaling_granularity, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + dtype, +) -> bool: + if scaling_granularity is ScalingGranularity.AXISWISE: + if ( + scaling_type_input != ScalingType.DYNAMIC or + scaling_type_weight != ScalingType.DYNAMIC or + scaling_type_grad_output != ScalingType.DYNAMIC or + dtype != torch.bfloat16 + ): + return False + return True + + @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize( "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] @@ -113,6 +151,9 @@ def _get_config( @pytest.mark.parametrize( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) +@pytest.mark.parametrize( + "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] +) @pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @@ -122,11 +163,25 @@ def test_eager_only( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, + scaling_granularity: ScalingGranularity, dtype: torch.dtype, ): + if not is_supported( + scaling_granularity, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + dtype, + ): + pytest.skip() + torch._dynamo.reset() config = _get_config( - scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + scaling_granularity, + emulate, ) _test_compile_base( "eager", @@ -147,6 +202,9 @@ def test_eager_only( @pytest.mark.parametrize( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) +@pytest.mark.parametrize( + 
"scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] +) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_aot_eager( @@ -155,11 +213,25 @@ def test_aot_eager( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, + scaling_granularity: ScalingGranularity, dtype: torch.dtype, ): + if not is_supported( + scaling_granularity, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + dtype, + ): + pytest.skip() + torch._dynamo.reset() config = _get_config( - scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + scaling_granularity, + emulate, ) _test_compile_base( "aot_eager", @@ -180,6 +252,9 @@ def test_aot_eager( @pytest.mark.parametrize( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) +@pytest.mark.parametrize( + "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] +) @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available") @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_inductor( @@ -188,11 +263,25 @@ def test_inductor( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, + scaling_granularity: ScalingGranularity, dtype: torch.dtype, ): + if not is_supported( + scaling_granularity, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + dtype, + ): + pytest.skip() + torch._dynamo.reset() config = _get_config( - scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + scaling_granularity, + emulate, ) _test_compile_base( "inductor", diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index 6db05dc56d..396de7efa8 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -19,7 +19,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.config import ( + CastConfig, + Float8LinearConfig, + ScalingType, + ScalingGranularity, +) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, linear_requires_sync, @@ -90,6 +95,10 @@ class TestFloat8NumericsIntegrationTest: "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) + @pytest.mark.parametrize( + "scaling_granularity", + [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE], + ) @pytest.mark.skipif(not is_cuda_8_9, reason="requires SM89 compatible machine") @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") def test_encoder_fw_bw( @@ -97,10 +106,20 @@ def test_encoder_fw_bw( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, + scaling_granularity: ScalingGranularity, ): # TODO(later): maybe add float16 back if it becomes important data_dtype = torch.bfloat16 + if scaling_granularity is ScalingGranularity.AXISWISE: + if ( + scaling_type_input != ScalingType.DYNAMIC or + scaling_type_weight != ScalingType.DYNAMIC or + scaling_type_grad_output != 
ScalingType.DYNAMIC or + data_dtype != torch.bfloat16 + ): + pytest.skip() + # LLaMa 3 70B shapes model_ref = ( FeedForward( @@ -119,24 +138,34 @@ def test_encoder_fw_bw( if scaling_type_input is ScalingType.STATIC: cast_config_input = CastConfig( scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_input = CastConfig(scaling_type=scaling_type_input) + cast_config_input = CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) if scaling_type_weight is ScalingType.STATIC: cast_config_weight = CastConfig( scaling_type=scaling_type_weight, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_weight = CastConfig(scaling_type=scaling_type_weight) + cast_config_weight = CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) if scaling_type_grad_output is ScalingType.STATIC: cast_config_grad_output = CastConfig( scaling_type=scaling_type_grad_output, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_grad_output = CastConfig(scaling_type=scaling_type_grad_output) + cast_config_grad_output = CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, diff --git a/torchao/float8/config.py b/torchao/float8/config.py index b24b5ba749..4d82bd1118 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -37,6 +37,13 @@ class ScalingGranularity(enum.Enum): # size 1. AXISWISE = "axiswise" + def short_str(self): + if self is ScalingGranularity.TENSORWISE: + return "ten" + else: + assert self is ScalingGranularity.AXISWISE + return "axs" + @dataclass(frozen=True) class CastConfig: @@ -45,12 +52,16 @@ class CastConfig: """ scaling_type: ScalingType = ScalingType.DYNAMIC + scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE static_scale: Optional[torch.Tensor] = None def __post_init__(self): if self.scaling_type is ScalingType.STATIC: assert self.static_scale is not None, \ "static_scale must be specified for static scaling" + if self.scaling_granularity is ScalingGranularity.AXISWISE: + assert self.scaling_type is ScalingType.DYNAMIC, \ + "only dynamic scaling type is supported for axiswise scaling granularity" @dataclass(frozen=True) class DelayedScalingConfig: @@ -144,6 +155,27 @@ class Float8LinearConfig: # configuration, this field may move to per-tensor configs. 
delayed_scaling_config: DelayedScalingConfig = DelayedScalingConfig() + def __post_init__(self): + # float8 all-gather only supports tensorwise, in the future may support blockwise + if self.cast_config_weight.scaling_granularity != ScalingGranularity.TENSORWISE: + assert not self.enable_fsdp_float8_all_gather, \ + f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}" + + # for now, axiswise granularity is all-or-nothing + # TODO(future PR): enable more granular setting per-gemm-input + has_any_axiswise_scaling = ( + self.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE or + self.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE or + self.cast_config_grad_output.scaling_granularity is ScalingGranularity.AXISWISE + ) + has_all_axiswise_scaling = ( + self.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE and + self.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE and + self.cast_config_grad_output.scaling_granularity is ScalingGranularity.AXISWISE + ) + if has_any_axiswise_scaling: + assert has_all_axiswise_scaling, \ + "for now, axiswise scaling must be enabled for either all casts or none of the casts" # If True, use 'fnuz' float8 types for calculations. # Currently, ROCm only supports fnuz variants. diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index cb0ff7afb0..5f87e82fe4 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -14,7 +14,7 @@ import torch -from torchao.float8.config import Float8LinearConfig, ScalingType +from torchao.float8.config import Float8LinearConfig, ScalingType, ScalingGranularity from torchao.float8.float8_scaling_utils import ( _maybe_initialize_amaxes_scales_for_float8_cast, @@ -42,11 +42,17 @@ ) -# this code was resurrected from https://github.com/pytorch-labs/torchao.float8/pull/128/files @torch._dynamo.allow_in_graph -class manual_float8_matmul(torch.autograd.Function): +class manual_float8_matmul_with_args_in_float8(torch.autograd.Function): """ Like torch.matmul, but with the arguments in float8 + + Note: this function requires all arguments to already be Float8Tensor objects, + which only supports tensorwise scaling granularity. The reason we didn't just make this + function support axiswise scaling granularity is because that would need very + careful testing of delayed scaling, as delayed scaling modifies buffers inplace. + + In the future we'll probably have to unify, just postponing that until a future PR. """ @staticmethod @@ -97,6 +103,133 @@ def backward(ctx, grad_output_fp8): return grad_input, grad_weight.t() +@torch._dynamo.allow_in_graph +class manual_float8_matmul_with_args_in_hp(torch.autograd.Function): + """ + Like torch.matmul, but with the arguments in high precision and the cast to float8 + defined inside of this function. + + Note: this function currently only supports dynamic scaling type and + axiswise granularity. We will have to unify this with other scaling types + and other granularities in a separate PR. 
+ """ + + # TODO(this PR): types of inputs + @staticmethod + def forward( + ctx, + input_hp: torch.Tensor, + weight_hp_t: torch.Tensor, + linear_mm_config: LinearMMConfig, + input_scaling_granularity: ScalingGranularity, + weight_scaling_granularity: ScalingGranularity, + grad_output_scaling_granularity: ScalingGranularity, + ): + ctx.save_for_backward(input_hp, weight_hp_t) + ctx.linear_mm_config = linear_mm_config + ctx.input_scaling_granularity = input_scaling_granularity + ctx.weight_scaling_granularity = weight_scaling_granularity + ctx.grad_output_scaling_granularity = grad_output_scaling_granularity + + input_fp8 = hp_tensor_to_float8_dynamic( + input_hp, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + scaling_granularity=input_scaling_granularity, + axiswise_dim=-1, + ) + + weight_fp8_t = hp_tensor_to_float8_dynamic( + weight_hp_t, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + scaling_granularity=weight_scaling_granularity, + axiswise_dim=0, + ) + + # the reshapes are needed in order to make the shapes compatible with + # torch.mm + orig_shape = input_fp8.shape + input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1]) + res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t) + res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1]) + return res_bits + + @staticmethod + def backward(ctx, grad_output): + input_hp, weight_hp_t = ctx.saved_tensors + + # TODO scaling + + # the reshapes are needed in order to make the shapes compatible with + # torch.mm + grad_output_orig_shape = grad_output.shape + grad_output_reshaped = grad_output.reshape( + -1, grad_output_orig_shape[-1] + ) + + # + # calculate grad_input + # + + grad_output_reshaped_fp8_dim0 = hp_tensor_to_float8_dynamic( + grad_output_reshaped, + e5m2_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.GRAD_OUTPUT, + scaling_granularity=ctx.grad_output_scaling_granularity, + axiswise_dim=-1, + ) + weight_t_fp8_dim0 = hp_tensor_to_float8_dynamic( + weight_hp_t, + e4m3_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + scaling_granularity=ctx.weight_scaling_granularity, + axiswise_dim=1, # will be transposed + ) + + grad_input = torch.mm( + grad_output_reshaped_fp8_dim0, + weight_t_fp8_dim0.t(), + ) + grad_input = grad_input.reshape( + *grad_output_orig_shape[:-1], grad_input.shape[-1] + ) + + input_hp_orig_shape = input_hp.shape + input_hp_reshaped = input_hp.reshape(-1, input_hp_orig_shape[-1]) + + # + # calculate grad_weight + # + + grad_output_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( + grad_output_reshaped, + e5m2_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.GRAD_OUTPUT, + scaling_granularity=ctx.grad_output_scaling_granularity, + axiswise_dim=0, # will be transposed + ) + input_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( + input_hp_reshaped, + e4m3_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + scaling_granularity=ctx.input_scaling_granularity, + axiswise_dim=0, + ) + + grad_weight = torch.mm( + grad_output_reshaped_fp8_dim1.t(), + input_reshaped_fp8_dim1, + ) + + return grad_input, grad_weight.t(), None, None, None, None + class Float8Linear(torch.nn.Linear): """ @@ -289,7 +422,10 @@ def cast_input_to_float8( ) elif self.scaling_type_input is ScalingType.DYNAMIC: input_fp8 = hp_tensor_to_float8_dynamic( - input, e4m3_dtype, self.linear_mm_config + input, + e4m3_dtype, + self.linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, ) else: assert self.scaling_type_input is 
ScalingType.STATIC @@ -395,13 +531,33 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if self.has_any_delayed_scaling: self.float8_pre_forward(input) - input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized) - weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized) + # TODO(this PR): reuse with config, make a property + has_all_axiswise_scaling = ( + self.config.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE and + self.config.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE and + self.config.cast_config_grad_output.scaling_granularity is ScalingGranularity.AXISWISE + ) + + if not has_all_axiswise_scaling: + input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized) + weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized) - output = manual_float8_matmul.apply(input_fp8, weight_fp8.t()) + output = manual_float8_matmul_with_args_in_float8.apply(input_fp8, weight_fp8.t()) - # Cast grad_output to float8_e5m2 during backward - output = self.cast_output_to_float8_in_bw(output) + # Cast grad_output to float8_e5m2 during backward + output = self.cast_output_to_float8_in_bw(output) + + else: + # for now, axiswise path is separate + # TODO(future PR): unify to support mix and match + output = manual_float8_matmul_with_args_in_hp.apply( + input, + self.weight.t(), + self.linear_mm_config, + self.config.cast_config_input.scaling_granularity, + self.config.cast_config_weight.scaling_granularity, + self.config.cast_config_grad_output.scaling_granularity, + ) if self.bias is not None: output = output + self.bias.to(output.dtype) @@ -410,13 +566,21 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.float8_post_forward() return output - def scaling_repr(self): - # add scaling settings without using too many characters + def scaling_type_repr(self): + # add scaling type settings without using too many characters # example: "i:del,w:del,go:dyn" return f"i:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},go:{self.scaling_type_grad_output.short_str()}" + def scaling_granularity_repr(self): + # add scaling granularity settings without using too many characters + # example: "i:ten,w:ten,g:ten" or "i:axs,w:axs,g:axs" + gi = self.config.cast_config_input.scaling_granularity.short_str() + gw = self.config.cast_config_weight.scaling_granularity.short_str() + ggo = self.config.cast_config_grad_output.scaling_granularity.short_str() + return f"i:{gi},w:{gw},go:{ggo}" + def extra_repr(self): - s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"' + s = f'{super().extra_repr()}, scaling_type="{self.scaling_type_repr()}", scaling_granularity="{self.scaling_granularity_repr()}"' return s @classmethod diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py index 1bf9faaa4c..b97d032113 100644 --- a/torchao/float8/float8_ops.py +++ b/torchao/float8/float8_ops.py @@ -43,12 +43,9 @@ def decorator(func): [ aten.view.default, aten._unsafe_view.default, - aten.t.default, aten.as_strided.default, aten.clone.default, - aten.detach.default, aten.slice.Tensor, - aten.transpose.int, aten.fill_.Scalar, aten.reshape.default, ] @@ -65,13 +62,30 @@ def float8_desugar_op(aten_op, args, kwargs=None): ) +@implements( + [ + aten.detach.default, + ] +) +def float8_desugar_data_and_scale_op(aten_op, args, kwargs=None): + new_data = aten_op(args[0]._data, *args[1:], **kwargs) + new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) + return 
Float8Tensor( + new_data, + new_scale, + args[0]._orig_dtype, + args[0]._linear_mm_config, + args[0]._gemm_input_role, + ) + + @implements( [ aten.t.default, aten.transpose.int, ] ) -def float8_desugar_data_and_scale(aten_op, args, kwargs=None): +def float8_transpose(aten_op, args, kwargs=None): new_data = aten_op(args[0]._data, *args[1:], **kwargs) new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) From f15c2a02b0822784f718c51480c1fbe422ea5bb8 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 23 Sep 2024 09:54:13 -0700 Subject: [PATCH 03/32] Update [ghstack-poisoned] --- torchao/float8/float8_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 55e520f8ca..e79cf27d88 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -100,7 +100,7 @@ def amax_history_to_scale_stack( @torch.no_grad() def tensor_to_amax( - x: torch.Tensor, + x: torch.Tensor, reduce_amax: bool = False, scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, axiswise_dim: Optional[int] = None, @@ -110,7 +110,7 @@ def tensor_to_amax( else: assert scaling_granularity is ScalingGranularity.AXISWISE, "unsupported" assert axiswise_dim is not None, "unsupported" - amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True) + amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True) # If the user asked for distributed reduction, do it. # If the user did not ask for it, assume that it will @@ -123,8 +123,8 @@ def tensor_to_amax( @torch.no_grad() def tensor_to_scale( - x: torch.Tensor, - float8_dtype: torch.dtype, + x: torch.Tensor, + float8_dtype: torch.dtype, reduce_amax: bool = False, scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, axiswise_dim: Optional[int] = None, From 9150b4f2b2f0e13d9cd79f876067b9907037841c Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 23 Sep 2024 10:55:16 -0700 Subject: [PATCH 04/32] Update [ghstack-poisoned] --- test/float8/test_base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index e6dd67951c..d3f48a7153 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -57,6 +57,7 @@ is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) +is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: assert torch.all(a._data == b._data).item(), "scales are not identical" @@ -210,6 +211,8 @@ def test_axiswise_reshape(self): a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1) @pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)]) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0") def test_axiswise_gemm(self, a_shape): a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda") b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda") From 459e92c434378eb4d8b70d69299af94a69e4d45c Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 23 Sep 2024 12:25:42 -0700 Subject: [PATCH 05/32] Update [ghstack-poisoned] --- test/float8/test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index d3f48a7153..ebc33f0372 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -159,7 +159,7 @@ def test_axiswise_dynamic_cast(self, shape, dim_name): assert 
sqnr >= 25.0 def test_axiswise_reshape(self): - a = torch.randn(3, 5, 7, dtype=torch.bfloat16, device="cuda") + a = torch.randn(3, 5, 7, dtype=torch.bfloat16) linear_mm_config = LinearMMConfig() # if we scale across dim0, we can only reshape to [3, -1] From 732b231dffe87bad13a76ccb7cc859b8714a6239 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 23 Sep 2024 13:22:08 -0700 Subject: [PATCH 06/32] Update [ghstack-poisoned] --- test/float8/test_compile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 74cc6faa5c..eacd317b1a 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -137,7 +137,7 @@ def is_supported( scaling_type_weight != ScalingType.DYNAMIC or scaling_type_grad_output != ScalingType.DYNAMIC or dtype != torch.bfloat16 or - (not IS_H100) + (not is_H100) ): return False return True From e7c15d1675ec3aaa81c3035a7f06dd716db9c2f6 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Tue, 24 Sep 2024 16:25:42 -0700 Subject: [PATCH 07/32] Update [ghstack-poisoned] --- test/float8/test_base.py | 190 ++++++++++++++++++++++++-------- torchao/float8/config.py | 70 +++++++++--- torchao/float8/float8_linear.py | 82 +++++++++++--- torchao/float8/float8_ops.py | 11 ++ 4 files changed, 279 insertions(+), 74 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index f0e0ac0a9d..b8d4a04f01 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -8,6 +8,7 @@ import itertools import random import re +from typing import List, Tuple import unittest import warnings @@ -59,6 +60,41 @@ is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) +# scaling granularity and keep_in_original_precision to test by gemm arguments in this +# order: output, grad_input, grad_weight +scaling_granularities_by_gemm = [ + # TODO(before land): move this last + # @lcw's recipe + # TODO(future): add leaving one of the matmuls in bfloat16 + [ + # output = input @ weight_t + # input: axiswise + # weight_t: axiswise + (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE, False, False), + # grad_input = grad_output @ weight + # grad_output: axiswise + # weight: tensorwise (but that can be computed from axiswise done in the forward) + (ScalingGranularity.AXISWISE, ScalingGranularity.TENSORWISE, False, False), + # grad_weight = input_t @ grad_output, in high precision (bfloat16) + # input_t: high precision + # grad_output: high precision + (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE, True, True), + ], + # all tensorwise + #[ + # (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE), + # (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE), + # (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE), + #], + # all axiswise + #[ + # (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE), + # (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE), + # (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE), + #], +] + + def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: assert torch.all(a._data == b._data).item(), "scales are not identical" assert torch.all(a._data == b._data).item(), "data is not identical" @@ -211,31 +247,52 @@ def test_axiswise_reshape(self): a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1) @pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)]) + @pytest.mark.parametrize( + 
"a_granularity,b_granularity", + [ + (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE), + (ScalingGranularity.AXISWISE, ScalingGranularity.TENSORWISE), + (ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE), + ] + ) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0") - def test_axiswise_gemm(self, a_shape): + def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity): a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda") b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda") linear_mm_config = LinearMMConfig() + if a_granularity is ScalingGranularity.AXISWISE: + a_axiswise_dim = -1 + else: + assert a_granularity is ScalingGranularity.TENSORWISE + a_axiswise_dim = None a_fp8 = hp_tensor_to_float8_dynamic( a, e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.INPUT, - scaling_granularity=ScalingGranularity.AXISWISE, - axiswise_dim=-1, + scaling_granularity=a_granularity, + axiswise_dim=a_axiswise_dim, ) a_fp8 = a_fp8.reshape(-1, a_shape[-1]) + + b_axiswise_dim = 1 if b_granularity is ScalingGranularity.AXISWISE else None + if b_granularity is ScalingGranularity.AXISWISE: + b_axiswise_dim = 1 # will be transposed + else: + assert b_granularity is ScalingGranularity.TENSORWISE + b_axiswise_dim = None b_fp8 = hp_tensor_to_float8_dynamic( b, e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, - scaling_granularity=ScalingGranularity.AXISWISE, - axiswise_dim=1, # will be transposed + scaling_granularity=b_granularity, + axiswise_dim=b_axiswise_dim, ) + c_fp8_compute = torch.mm(a_fp8, b_fp8.t()) a = a.reshape(-1, a_shape[-1]) c_ref = torch.mm(a, b.t()) @@ -316,26 +373,33 @@ def _test_linear_impl( # verify initialization flags got updated assert m_fp8.is_amax_initialized, "Amax was not properly initialized" - @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True]) - @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) + # @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True]) + @pytest.mark.parametrize("emulate", [False] if is_cuda_8_9 else [True]) + # @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) + @pytest.mark.parametrize("x_shape", [(16, 16),]) @pytest.mark.parametrize( "scaling_type_input", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + # [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + [ScalingType.DYNAMIC] ) @pytest.mark.parametrize( "scaling_type_weight", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + # [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + [ScalingType.DYNAMIC] ) @pytest.mark.parametrize( "scaling_type_grad_output", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + # [ScalingType.DELAYED, ScalingType.DYNAMIC], + [ScalingType.DYNAMIC] ) @pytest.mark.parametrize( - "scaling_granularity", - [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE], + "scaling_granularities_by_gemm", + scaling_granularities_by_gemm ) - @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32]) - @pytest.mark.parametrize("linear_bias", [False, True]) + # @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32]) + @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, ]) + # @pytest.mark.parametrize("linear_bias", [False, True]) + @pytest.mark.parametrize("linear_bias", [False, ]) @unittest.skipIf(not torch.cuda.is_available(), 
"CUDA not available") def test_linear( self, @@ -344,7 +408,7 @@ def test_linear( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, - scaling_granularity: ScalingGranularity, + scaling_granularities_by_gemm: List[List[Tuple[ScalingGranularity, ScalingGranularity]]], linear_dtype: torch.dtype, linear_bias: bool, ): @@ -357,7 +421,23 @@ def test_linear( f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)" ) pytest.skip() - if scaling_granularity is ScalingGranularity.AXISWISE: + + ( + (scaling_granularity_input, scaling_granularity_weight, original_prec_input, original_prec_weight), + (scaling_granularity_grad_output, scaling_granularity_weight_for_grad_input, original_prec_grad_output, original_prec_weight_for_grad_input), + (scaling_granularity_input_for_grad_weight, scaling_granularity_grad_output_for_grad_weight, original_prec_input_for_grad_weight, original_prec_grad_output_for_grad_weight), + ) = scaling_granularities_by_gemm + + has_any_axiswise_scaling = ( + scaling_granularity_input is ScalingGranularity.AXISWISE or + scaling_granularity_weight is ScalingGranularity.AXISWISE or + scaling_granularity_grad_output is ScalingGranularity.AXISWISE or + scaling_granularity_input_for_grad_weight is ScalingGranularity.AXISWISE or + scaling_granularity_weight_for_grad_input is ScalingGranularity.AXISWISE or + scaling_granularity_grad_output_for_grad_weight is ScalingGranularity.AXISWISE + ) + + if has_any_axiswise_scaling: if ( scaling_type_input != ScalingType.DYNAMIC or scaling_type_weight != ScalingType.DYNAMIC or @@ -370,46 +450,70 @@ def test_linear( x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) + static_scale_one = torch.tensor([1.0], device="cuda") + if scaling_type_input is ScalingType.STATIC: - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - static_scale=torch.tensor([1.0], device="cuda"), - ) + static_scale_input = static_scale_one else: - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - ) + static_scale_input = None if scaling_type_weight is ScalingType.STATIC: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity, - static_scale=torch.tensor([1.0], device="cuda"), - ) + static_scale_weight = static_scale_one else: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity, - ) + static_scale_weight = None if scaling_type_grad_output is ScalingType.STATIC: - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity, - static_scale=torch.tensor([1.0], device="cuda"), - ) + static_scale_grad_output = static_scale_one else: - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity, - ) + static_scale_grad_output = None + + cast_config_input = CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity_input, + static_scale=static_scale_input, + keep_in_original_precision=original_prec_input, + ) + cast_config_input_for_grad_weight = CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity_input_for_grad_weight, + static_scale=static_scale_input, + 
keep_in_original_precision=original_prec_input_for_grad_weight, + ) + + cast_config_weight = CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity_weight, + static_scale=static_scale_weight, + keep_in_original_precision=original_prec_weight, + ) + cast_config_weight_for_grad_input = CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity_weight_for_grad_input, + static_scale=static_scale_weight, + keep_in_original_precision=original_prec_weight_for_grad_input, + ) + + cast_config_grad_output = CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity_grad_output, + static_scale=static_scale_grad_output, + keep_in_original_precision=original_prec_grad_output, + ) + cast_config_grad_output_for_grad_weight = CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity_grad_output_for_grad_weight, + static_scale=static_scale_grad_output, + keep_in_original_precision=original_prec_grad_output_for_grad_weight, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, cast_config_weight=cast_config_weight, cast_config_grad_output=cast_config_grad_output, + cast_config_input_for_grad_weight=cast_config_input_for_grad_weight, + cast_config_weight_for_grad_input=cast_config_weight_for_grad_input, + cast_config_grad_output_for_grad_weight=cast_config_grad_output_for_grad_weight, emulate=emulate, ) + self._test_linear_impl( x, m_ref, diff --git a/torchao/float8/config.py b/torchao/float8/config.py index 4d82bd1118..cf3c8463db 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -48,12 +48,16 @@ def short_str(self): @dataclass(frozen=True) class CastConfig: """ - Configuration for casting a single tensor to float8 + Configuration for maybe casting a single tensor to float8 """ scaling_type: ScalingType = ScalingType.DYNAMIC scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE static_scale: Optional[torch.Tensor] = None + # If True, this tensor is not scaled to float8 and left in its original + # precision. + # TODO(ideally before this PR lands): a better name for this + keep_in_original_precision: bool = False def __post_init__(self): if self.scaling_type is ScalingType.STATIC: @@ -98,7 +102,7 @@ class Float8GemmConfig: use_fast_accum: bool = False -@dataclass(frozen=True) +@dataclass(frozen=False) class Float8LinearConfig: """ Configuration for converting a `torch.nn.Linear` module to float8 @@ -112,6 +116,18 @@ class Float8LinearConfig: cast_config_weight: CastConfig = CastConfig() cast_config_grad_output: CastConfig = CastConfig() + # + # Optional per-tensor configuration for `input`, `weight`, `grad_output` to + # calculate `grad_weight`, `grad_input`, and `grad_weight` respectively. + # If not specified, then the configuration from the is reused. + # TODO(future PR): maybe rename `cast_config_input` to + # `cast_config_input_for_output`, etc, to make the names consistent, + # will be BC-breaking. 
+ # + cast_config_input_for_grad_weight: Optional[CastConfig] = None + cast_config_weight_for_grad_input: Optional[CastConfig] = None + cast_config_grad_output_for_grad_weight: Optional[CastConfig] = None + # # Per-gemm configuration for gemms calculating `output`, `grad_input` and # `grad_weight` @@ -156,26 +172,46 @@ class Float8LinearConfig: delayed_scaling_config: DelayedScalingConfig = DelayedScalingConfig() def __post_init__(self): + # populate the additional cast overrides, if the user did not specify them + if self.cast_config_input_for_grad_weight is None: + self.cast_config_input_for_grad_weight = self.cast_config_input + if self.cast_config_weight_for_grad_input is None: + self.cast_config_weight_for_grad_input = self.cast_config_weight + if self.cast_config_grad_output_for_grad_weight is None: + self.cast_config_grad_output_for_grad_weight = self.cast_config_grad_output + # float8 all-gather only supports tensorwise, in the future may support blockwise if self.cast_config_weight.scaling_granularity != ScalingGranularity.TENSORWISE: assert not self.enable_fsdp_float8_all_gather, \ f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}" - # for now, axiswise granularity is all-or-nothing - # TODO(future PR): enable more granular setting per-gemm-input - has_any_axiswise_scaling = ( - self.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE or - self.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE or - self.cast_config_grad_output.scaling_granularity is ScalingGranularity.AXISWISE - ) - has_all_axiswise_scaling = ( - self.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE and - self.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE and - self.cast_config_grad_output.scaling_granularity is ScalingGranularity.AXISWISE - ) - if has_any_axiswise_scaling: - assert has_all_axiswise_scaling, \ - "for now, axiswise scaling must be enabled for either all casts or none of the casts" + # save some characters in the compatibility checks below + cc_i = self.cast_config_input + cc_w = self.cast_config_weight + cc_go = self.cast_config_grad_output + cc_i2 = self.cast_config_input_for_grad_weight + cc_w2 = self.cast_config_weight_for_grad_input + cc_go2 = self.cast_config_grad_output_for_grad_weight + + # for now, we only have gemm kernels where both operands are scaled with the same + # granularity. In the future this may be relaxed. + assert cc_i.scaling_granularity == cc_w.scaling_granularity, \ + "incompatible scaling granularity for output" + # assert cc_go.scaling_granularity == cc_w2.scaling_granularity, \ + # "incompatible scaling granularity for grad_input" + assert cc_i2.scaling_granularity == cc_go2.scaling_granularity, \ + "incompatible scaling granularity for grad_weight" + + # for now, we only have gemm kernels where both operands are either both + # in high precision, or both in float8. In the future, this may be relaxed. + # TODO(future): make the float8 check more precise with the specific dtypes. + assert cc_i.keep_in_original_precision == cc_w.keep_in_original_precision, \ + "incompatible operand precision for output" + assert cc_go.keep_in_original_precision == cc_w2.keep_in_original_precision, \ + "incompatible operand precision for grad_input" + assert cc_i2.keep_in_original_precision == cc_go2.keep_in_original_precision, \ + "incompatible operand precision for grad_weight" + # If True, use 'fnuz' float8 types for calculations. 
# Currently, ROCm only supports fnuz variants. diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 5f87e82fe4..ced1c64316 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -114,7 +114,6 @@ class manual_float8_matmul_with_args_in_hp(torch.autograd.Function): and other granularities in a separate PR. """ - # TODO(this PR): types of inputs @staticmethod def forward( ctx, @@ -124,13 +123,26 @@ def forward( input_scaling_granularity: ScalingGranularity, weight_scaling_granularity: ScalingGranularity, grad_output_scaling_granularity: ScalingGranularity, + input_for_grad_weight_scaling_granularity: ScalingGranularity, + weight_for_grad_input_scaling_granularity: ScalingGranularity, + grad_output_for_grad_weight_scaling_granularity: ScalingGranularity, ): ctx.save_for_backward(input_hp, weight_hp_t) ctx.linear_mm_config = linear_mm_config ctx.input_scaling_granularity = input_scaling_granularity ctx.weight_scaling_granularity = weight_scaling_granularity ctx.grad_output_scaling_granularity = grad_output_scaling_granularity - + ctx.input_for_grad_weight_scaling_granularity = input_for_grad_weight_scaling_granularity + ctx.weight_for_grad_input_scaling_granularity = weight_for_grad_input_scaling_granularity + ctx.grad_output_for_grad_weight_scaling_granularity = grad_output_for_grad_weight_scaling_granularity + + # TODO(next): make the tensorwise+axiswise combination work, by + # broadcasting the tensorwise max to axiswise + if input_scaling_granularity == ScalingGranularity.TENSORWISE: + axiswise_dim = None + else: + assert input_scaling_granularity == ScalingGranularity.AXISWISE + axiswise_dim = -1 input_fp8 = hp_tensor_to_float8_dynamic( input_hp, e4m3_dtype, @@ -140,6 +152,11 @@ def forward( axiswise_dim=-1, ) + if weight_scaling_granularity == ScalingGranularity.TENSORWISE: + axiswise_dim = None + else: + assert weight_scaling_granularity == ScalingGranularity.AXISWISE + axiswise_dim = 0 weight_fp8_t = hp_tensor_to_float8_dynamic( weight_hp_t, e4m3_dtype, @@ -174,21 +191,34 @@ def backward(ctx, grad_output): # calculate grad_input # + if ctx.grad_output_scaling_granularity == ScalingGranularity.TENSORWISE: + axiswise_dim1 = None + else: + assert ctx.grad_output_scaling_granularity == ScalingGranularity.AXISWISE + axiswise_dim1 = -1 grad_output_reshaped_fp8_dim0 = hp_tensor_to_float8_dynamic( grad_output_reshaped, e5m2_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.GRAD_OUTPUT, scaling_granularity=ctx.grad_output_scaling_granularity, - axiswise_dim=-1, + axiswise_dim=axiswise_dim1, ) + + if ctx.weight_for_grad_input_scaling_granularity == ScalingGranularity.TENSORWISE: + axiswise_dim2 = None + # TODO(future PR): if the weight was scaled axiswise in the forward, can get the + # tensorwise max from that. 
+ else: + assert ctx.grad_weight_for_grad_input_scaling_granularity == ScalingGranularity.AXISWISE + axiswise_dim2 = 1 # will be transposed weight_t_fp8_dim0 = hp_tensor_to_float8_dynamic( weight_hp_t, e4m3_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, - scaling_granularity=ctx.weight_scaling_granularity, - axiswise_dim=1, # will be transposed + scaling_granularity=ctx.weight_for_grad_input_scaling_granularity, + axiswise_dim=axiswise_dim2, ) grad_input = torch.mm( @@ -206,21 +236,36 @@ def backward(ctx, grad_output): # calculate grad_weight # + # TODO(next): do the gemm in high precision if the config says so + + if ctx.grad_output_for_grad_weight_scaling_granularity == ScalingGranularity.TENSORWISE: + axiswise_dim3 = None + else: + assert ctx.grad_output_for_grad_weight_scaling_granularity == ScalingGranularity.AXISWISE + axiswise_dim3 = 0 # will be transposed + grad_output_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( grad_output_reshaped, e5m2_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.GRAD_OUTPUT, - scaling_granularity=ctx.grad_output_scaling_granularity, - axiswise_dim=0, # will be transposed + scaling_granularity=ctx.grad_output_for_grad_weight_scaling_granularity, + axiswise_dim=axiswise_dim3, ) + + if ctx.input_for_grad_weight_scaling_granularity == ScalingGranularity.TENSORWISE: + axiswise_dim4 = None + else: + assert ctx.input_for_grad_weight_scaling_granularity == ScalingGranularity.AXISWISE + axiswise_dim4 = 0 + input_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( input_hp_reshaped, e4m3_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, - scaling_granularity=ctx.input_scaling_granularity, - axiswise_dim=0, + scaling_granularity=ctx.input_for_grad_weight_scaling_granularity, + axiswise_dim=axiswise_dim4, ) grad_weight = torch.mm( @@ -228,7 +273,9 @@ def backward(ctx, grad_output): input_reshaped_fp8_dim1, ) - return grad_input, grad_weight.t(), None, None, None, None + empty_grads = None, None, None, None, None, None, None + + return grad_input, grad_weight.t(), *empty_grads class Float8Linear(torch.nn.Linear): @@ -557,6 +604,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.config.cast_config_input.scaling_granularity, self.config.cast_config_weight.scaling_granularity, self.config.cast_config_grad_output.scaling_granularity, + self.config.cast_config_input_for_grad_weight.scaling_granularity, + self.config.cast_config_weight_for_grad_input.scaling_granularity, + self.config.cast_config_grad_output_for_grad_weight.scaling_granularity, ) if self.bias is not None: @@ -574,10 +624,14 @@ def scaling_type_repr(self): def scaling_granularity_repr(self): # add scaling granularity settings without using too many characters # example: "i:ten,w:ten,g:ten" or "i:axs,w:axs,g:axs" - gi = self.config.cast_config_input.scaling_granularity.short_str() - gw = self.config.cast_config_weight.scaling_granularity.short_str() - ggo = self.config.cast_config_grad_output.scaling_granularity.short_str() - return f"i:{gi},w:{gw},go:{ggo}" + c = self.config + gi = c.cast_config_input.scaling_granularity.short_str() + gw = c.cast_config_weight.scaling_granularity.short_str() + ggo = c.cast_config_grad_output.scaling_granularity.short_str() + gi2 = c.cast_config_input_for_grad_weight.scaling_granularity.short_str() + gw2 = c.cast_config_weight_for_grad_input.scaling_granularity.short_str() + ggo2 = c.cast_config_grad_output_for_grad_weight.scaling_granularity.short_str() + return 
f"i:{gi},w:{gw},go:{ggo},i2:{gi2},w2:{gw2},go2:{ggo2}" def extra_repr(self): s = f'{super().extra_repr()}, scaling_type="{self.scaling_type_repr()}", scaling_granularity="{self.scaling_granularity_repr()}"' diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py index b97d032113..0a79ca3291 100644 --- a/torchao/float8/float8_ops.py +++ b/torchao/float8/float8_ops.py @@ -251,6 +251,17 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor): if is_row_major(b_data.stride()): b_data = b_data.t().contiguous().t() b_scale = b._scale + + # Today, torch._scaled_mm only supports both operands using the + # same granularity. The code below checks for cases where one + # operand is scaled axiswise and one tensorwise. If this case is found, + # we reshape the tensorwise scale to be repeat along the needed axis, + # so that torch._scaled_mm can call the axiswise-axiswise kernel. + if len(a_scale.shape) == 0 and len(b_scale.shape) > 0: + a_scale = a_scale.repeat(a_data.shape[0]).reshape(-1, 1) + elif len(a_scale.shape) > 0 and len(b_scale.shape) == 0: + b_scale = b_scale.repeat(b_data.shape[1]).reshape(1, -1) + return a_data, a_scale, b_data, b_scale From 1d01df33f1d52fcbb7f9a1f16fa449fd2faf61dd Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 25 Sep 2024 12:37:49 -0700 Subject: [PATCH 08/32] Update [ghstack-poisoned] --- test/float8/test_base.py | 105 ++-------------- test/float8/test_compile.py | 176 +++++++++++++-------------- torchao/float8/float8_linear.py | 71 ++++++----- torchao/float8/float8_ops.py | 7 +- torchao/testing/float8/test_utils.py | 114 +++++++++++++++++ 5 files changed, 256 insertions(+), 217 deletions(-) create mode 100644 torchao/testing/float8/test_utils.py diff --git a/test/float8/test_base.py b/test/float8/test_base.py index b8d4a04f01..9967de69bc 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -52,6 +52,10 @@ FP8_TYPES, tensor_to_scale, ) +from torchao.testing.float8.test_utils import ( + scaling_granularities_by_gemm, + get_test_float8_linear_config, +) random.seed(0) torch.manual_seed(0) @@ -60,39 +64,6 @@ is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) -# scaling granularity and keep_in_original_precision to test by gemm arguments in this -# order: output, grad_input, grad_weight -scaling_granularities_by_gemm = [ - # TODO(before land): move this last - # @lcw's recipe - # TODO(future): add leaving one of the matmuls in bfloat16 - [ - # output = input @ weight_t - # input: axiswise - # weight_t: axiswise - (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE, False, False), - # grad_input = grad_output @ weight - # grad_output: axiswise - # weight: tensorwise (but that can be computed from axiswise done in the forward) - (ScalingGranularity.AXISWISE, ScalingGranularity.TENSORWISE, False, False), - # grad_weight = input_t @ grad_output, in high precision (bfloat16) - # input_t: high precision - # grad_output: high precision - (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE, True, True), - ], - # all tensorwise - #[ - # (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE), - # (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE), - # (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE), - #], - # all axiswise - #[ - # (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE), - # (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE), - # 
(ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE), - #], -] def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: @@ -450,68 +421,12 @@ def test_linear( x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) - static_scale_one = torch.tensor([1.0], device="cuda") - - if scaling_type_input is ScalingType.STATIC: - static_scale_input = static_scale_one - else: - static_scale_input = None - if scaling_type_weight is ScalingType.STATIC: - static_scale_weight = static_scale_one - else: - static_scale_weight = None - if scaling_type_grad_output is ScalingType.STATIC: - static_scale_grad_output = static_scale_one - else: - static_scale_grad_output = None - - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity_input, - static_scale=static_scale_input, - keep_in_original_precision=original_prec_input, - ) - cast_config_input_for_grad_weight = CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity_input_for_grad_weight, - static_scale=static_scale_input, - keep_in_original_precision=original_prec_input_for_grad_weight, - ) - - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity_weight, - static_scale=static_scale_weight, - keep_in_original_precision=original_prec_weight, - ) - cast_config_weight_for_grad_input = CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity_weight_for_grad_input, - static_scale=static_scale_weight, - keep_in_original_precision=original_prec_weight_for_grad_input, - ) - - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity_grad_output, - static_scale=static_scale_grad_output, - keep_in_original_precision=original_prec_grad_output, - ) - cast_config_grad_output_for_grad_weight = CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity_grad_output_for_grad_weight, - static_scale=static_scale_grad_output, - keep_in_original_precision=original_prec_grad_output_for_grad_weight, - ) - - config = Float8LinearConfig( - cast_config_input=cast_config_input, - cast_config_weight=cast_config_weight, - cast_config_grad_output=cast_config_grad_output, - cast_config_input_for_grad_weight=cast_config_input_for_grad_weight, - cast_config_weight_for_grad_input=cast_config_weight_for_grad_input, - cast_config_grad_output_for_grad_weight=cast_config_grad_output_for_grad_weight, - emulate=emulate, + config = get_test_float8_linear_config( + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + scaling_granularities_by_gemm, + emulate, ) self._test_linear_impl( diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index eacd317b1a..5a41e29728 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import copy import random +from typing import List, Tuple import sys import unittest from io import StringIO @@ -33,6 +34,10 @@ from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_delayed from torchao.float8.float8_tensor import LinearMMConfig from torchao.float8.float8_utils import e4m3_dtype +from torchao.testing.float8.test_utils import ( + scaling_granularities_by_gemm, + get_test_float8_linear_config, +) from torch._dynamo.test_case import TestCase as DynamoTestCase from torch._dynamo.testing import CompileCounterWithBackend @@ -52,7 +57,8 @@ def _test_compile_base( x_shape = (16, 16) linear_dtype = torch.bfloat16 - x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) + x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype).requires_grad_() + x_ref = copy.deepcopy(x) m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) m_fp8 = Float8Linear.from_float( @@ -64,7 +70,7 @@ def _test_compile_base( m_ref = torch.compile(m_ref, backend=backend, fullgraph=fullgraph) y_fp8 = m_fp8(x) y_fp8.sum().backward() - y_ref = m_ref(x) + y_ref = m_ref(x_ref) y_ref.sum().backward() # TODO(future PR): can also test fp8 eager vs compile here with a tigher # tolerance @@ -73,65 +79,33 @@ def _test_compile_base( m_fp8.weight.grad, m_ref.weight.grad, atol=2e-1, rtol=2e-1 ) torch.testing.assert_close(m_fp8.bias.grad, m_ref.bias.grad, atol=8e-2, rtol=8e-2) - -def _get_config( - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - scaling_granularity, - emulate, -): - if scaling_type_input is ScalingType.STATIC: - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - ) - if scaling_type_weight is ScalingType.STATIC: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity, - ) - if scaling_type_grad_output is ScalingType.STATIC: - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity, - ) - - config = Float8LinearConfig( - cast_config_input=cast_config_input, - cast_config_weight=cast_config_weight, - cast_config_grad_output=cast_config_grad_output, - emulate=emulate, - ) - return config + torch.testing.assert_close(x.grad, x_ref.grad, atol=8e-2, rtol=8e-2) def is_supported( - scaling_granularity, + scaling_granularities_by_gemm, scaling_type_input, scaling_type_weight, scaling_type_grad_output, dtype, ) -> bool: - if scaling_granularity is ScalingGranularity.AXISWISE: + + ( + (scaling_granularity_input, scaling_granularity_weight, original_prec_input, original_prec_weight), + (scaling_granularity_grad_output, scaling_granularity_weight_for_grad_input, original_prec_grad_output, original_prec_weight_for_grad_input), + (scaling_granularity_input_for_grad_weight, scaling_granularity_grad_output_for_grad_weight, original_prec_input_for_grad_weight, original_prec_grad_output_for_grad_weight), + ) = scaling_granularities_by_gemm + + 
has_any_axiswise_scaling = ( + scaling_granularity_input is ScalingGranularity.AXISWISE or + scaling_granularity_weight is ScalingGranularity.AXISWISE or + scaling_granularity_grad_output is ScalingGranularity.AXISWISE or + scaling_granularity_input_for_grad_weight is ScalingGranularity.AXISWISE or + scaling_granularity_weight_for_grad_input is ScalingGranularity.AXISWISE or + scaling_granularity_grad_output_for_grad_weight is ScalingGranularity.AXISWISE + ) + + if has_any_axiswise_scaling: if ( scaling_type_input != ScalingType.DYNAMIC or scaling_type_weight != ScalingType.DYNAMIC or @@ -145,19 +119,28 @@ def is_supported( @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize( - "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + # "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_input", [ScalingType.DYNAMIC,] ) @pytest.mark.parametrize( - "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + # "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_weight", [ScalingType.DYNAMIC,] ) @pytest.mark.parametrize( - "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + # "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_grad_output", [ScalingType.DYNAMIC,] ) +# @pytest.mark.parametrize( +# "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] +# ) @pytest.mark.parametrize( - "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] + "scaling_granularities_by_gemm", + scaling_granularities_by_gemm ) -@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +# @pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) +@pytest.mark.parametrize("emulate", [False, ]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, ]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_eager_only( fullgraph, @@ -165,11 +148,11 @@ def test_eager_only( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, - scaling_granularity: ScalingGranularity, + scaling_granularities_by_gemm: List[List[Tuple[ScalingGranularity, ScalingGranularity]]], dtype: torch.dtype, ): if not is_supported( - scaling_granularity, + scaling_granularities_by_gemm, scaling_type_input, scaling_type_weight, scaling_type_grad_output, @@ -178,11 +161,11 @@ def test_eager_only( pytest.skip() torch._dynamo.reset() - config = _get_config( - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - scaling_granularity, + config = get_test_float8_linear_config( + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + scaling_granularities_by_gemm, emulate, ) _test_compile_base( @@ -194,20 +177,26 @@ def test_eager_only( @pytest.mark.parametrize("fullgraph", [True]) -@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) +# @pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) +@pytest.mark.parametrize("emulate", [False,]) @pytest.mark.parametrize( - "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + # "scaling_type_input", 
[ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_input", [ScalingType.DYNAMIC,] ) @pytest.mark.parametrize( - "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + # "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_weight", [ScalingType.DYNAMIC,] ) @pytest.mark.parametrize( - "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + # "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_grad_output", [ScalingType.DYNAMIC,] ) @pytest.mark.parametrize( - "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] + "scaling_granularities_by_gemm", + scaling_granularities_by_gemm ) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +@pytest.mark.parametrize("dtype", [torch.bfloat16,]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_aot_eager( fullgraph, @@ -215,11 +204,11 @@ def test_aot_eager( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, - scaling_granularity: ScalingGranularity, + scaling_granularities_by_gemm: List[List[Tuple[ScalingGranularity, ScalingGranularity]]], dtype: torch.dtype, ): if not is_supported( - scaling_granularity, + scaling_granularities_by_gemm, scaling_type_input, scaling_type_weight, scaling_type_grad_output, @@ -228,11 +217,11 @@ def test_aot_eager( pytest.skip() torch._dynamo.reset() - config = _get_config( - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - scaling_granularity, + config = get_test_float8_linear_config( + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + scaling_granularities_by_gemm, emulate, ) _test_compile_base( @@ -246,30 +235,35 @@ def test_aot_eager( @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize("emulate", [False]) @pytest.mark.parametrize( - "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + # "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_input", [ScalingType.DYNAMIC, ] ) @pytest.mark.parametrize( - "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + # "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_weight", [ScalingType.DYNAMIC, ] ) @pytest.mark.parametrize( - "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + # "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_grad_output", [ScalingType.DYNAMIC, ] ) @pytest.mark.parametrize( - "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] + "scaling_granularities_by_gemm", + scaling_granularities_by_gemm ) @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available") -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +@pytest.mark.parametrize("dtype", [torch.bfloat16,]) def test_inductor( fullgraph, emulate: bool, scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, - 
scaling_granularity: ScalingGranularity, + scaling_granularities_by_gemm: List[List[Tuple[ScalingGranularity, ScalingGranularity]]], dtype: torch.dtype, ): if not is_supported( - scaling_granularity, + scaling_granularities_by_gemm, scaling_type_input, scaling_type_weight, scaling_type_grad_output, @@ -278,11 +272,11 @@ def test_inductor( pytest.skip() torch._dynamo.reset() - config = _get_config( - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - scaling_granularity, + config = get_test_float8_linear_config( + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + scaling_granularities_by_gemm, emulate, ) _test_compile_base( diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index ced1c64316..09b5ab58f7 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -126,6 +126,9 @@ def forward( input_for_grad_weight_scaling_granularity: ScalingGranularity, weight_for_grad_input_scaling_granularity: ScalingGranularity, grad_output_for_grad_weight_scaling_granularity: ScalingGranularity, + # TODO(this PR): add all the others + input_for_grad_weight_keep_in_original_precision: bool, + grad_output_for_grad_weight_keep_in_original_precision: bool, ): ctx.save_for_backward(input_hp, weight_hp_t) ctx.linear_mm_config = linear_mm_config @@ -135,9 +138,9 @@ def forward( ctx.input_for_grad_weight_scaling_granularity = input_for_grad_weight_scaling_granularity ctx.weight_for_grad_input_scaling_granularity = weight_for_grad_input_scaling_granularity ctx.grad_output_for_grad_weight_scaling_granularity = grad_output_for_grad_weight_scaling_granularity + ctx.input_for_grad_weight_keep_in_original_precision = input_for_grad_weight_keep_in_original_precision + ctx.grad_output_for_grad_weight_keep_in_original_precision = grad_output_for_grad_weight_keep_in_original_precision - # TODO(next): make the tensorwise+axiswise combination work, by - # broadcasting the tensorwise max to axiswise if input_scaling_granularity == ScalingGranularity.TENSORWISE: axiswise_dim = None else: @@ -210,7 +213,7 @@ def backward(ctx, grad_output): # TODO(future PR): if the weight was scaled axiswise in the forward, can get the # tensorwise max from that. 
else: - assert ctx.grad_weight_for_grad_input_scaling_granularity == ScalingGranularity.AXISWISE + assert ctx.weight_for_grad_input_scaling_granularity == ScalingGranularity.AXISWISE axiswise_dim2 = 1 # will be transposed weight_t_fp8_dim0 = hp_tensor_to_float8_dynamic( weight_hp_t, @@ -236,44 +239,52 @@ def backward(ctx, grad_output): # calculate grad_weight # - # TODO(next): do the gemm in high precision if the config says so + # TODO(this PR): also respect keep_in_original_precision for the other gemms + if ctx.grad_output_for_grad_weight_keep_in_original_precision: + # TODO(this PR): more sensical variable name, now this isn't always fp8 + grad_output_reshaped_fp8_dim1 = grad_output_reshaped - if ctx.grad_output_for_grad_weight_scaling_granularity == ScalingGranularity.TENSORWISE: - axiswise_dim3 = None else: - assert ctx.grad_output_for_grad_weight_scaling_granularity == ScalingGranularity.AXISWISE - axiswise_dim3 = 0 # will be transposed + if ctx.grad_output_for_grad_weight_scaling_granularity == ScalingGranularity.TENSORWISE: + axiswise_dim3 = None + else: + assert ctx.grad_output_for_grad_weight_scaling_granularity == ScalingGranularity.AXISWISE + axiswise_dim3 = 0 # will be transposed + + grad_output_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( + grad_output_reshaped, + e5m2_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.GRAD_OUTPUT, + scaling_granularity=ctx.grad_output_for_grad_weight_scaling_granularity, + axiswise_dim=axiswise_dim3, + ) - grad_output_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( - grad_output_reshaped, - e5m2_dtype, - ctx.linear_mm_config, - gemm_input_role=GemmInputRole.GRAD_OUTPUT, - scaling_granularity=ctx.grad_output_for_grad_weight_scaling_granularity, - axiswise_dim=axiswise_dim3, - ) + if ctx.input_for_grad_weight_keep_in_original_precision: + input_reshaped_fp8_dim1 = input_hp_reshaped - if ctx.input_for_grad_weight_scaling_granularity == ScalingGranularity.TENSORWISE: - axiswise_dim4 = None else: - assert ctx.input_for_grad_weight_scaling_granularity == ScalingGranularity.AXISWISE - axiswise_dim4 = 0 + if ctx.input_for_grad_weight_scaling_granularity == ScalingGranularity.TENSORWISE: + axiswise_dim4 = None + else: + assert ctx.input_for_grad_weight_scaling_granularity == ScalingGranularity.AXISWISE + axiswise_dim4 = 0 - input_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( - input_hp_reshaped, - e4m3_dtype, - ctx.linear_mm_config, - gemm_input_role=GemmInputRole.INPUT, - scaling_granularity=ctx.input_for_grad_weight_scaling_granularity, - axiswise_dim=axiswise_dim4, - ) + input_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( + input_hp_reshaped, + e4m3_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + scaling_granularity=ctx.input_for_grad_weight_scaling_granularity, + axiswise_dim=axiswise_dim4, + ) grad_weight = torch.mm( grad_output_reshaped_fp8_dim1.t(), input_reshaped_fp8_dim1, ) - empty_grads = None, None, None, None, None, None, None + empty_grads = None, None, None, None, None, None, None, None, None return grad_input, grad_weight.t(), *empty_grads @@ -607,6 +618,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.config.cast_config_input_for_grad_weight.scaling_granularity, self.config.cast_config_weight_for_grad_input.scaling_granularity, self.config.cast_config_grad_output_for_grad_weight.scaling_granularity, + self.config.cast_config_input_for_grad_weight.keep_in_original_precision, + self.config.cast_config_grad_output_for_grad_weight.keep_in_original_precision, ) if self.bias 
is not None: diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py index 0a79ca3291..8f5bc768eb 100644 --- a/torchao/float8/float8_ops.py +++ b/torchao/float8/float8_ops.py @@ -257,9 +257,12 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor): # operand is scaled axiswise and one tensorwise. If this case is found, # we reshape the tensorwise scale to be repeat along the needed axis, # so that torch._scaled_mm can call the axiswise-axiswise kernel. - if len(a_scale.shape) == 0 and len(b_scale.shape) > 0: + # Note: using shape/size info does not work with compile here, which is + # why we are using inferring scaling type from the presence of + # axiswise_dim. + if a._axiswise_dim is None and b._axiswise_dim is not None: a_scale = a_scale.repeat(a_data.shape[0]).reshape(-1, 1) - elif len(a_scale.shape) > 0 and len(b_scale.shape) == 0: + elif a._axiswise_dim is not None and b._axiswise_dim is None: b_scale = b_scale.repeat(b_data.shape[1]).reshape(1, -1) return a_data, a_scale, b_data, b_scale diff --git a/torchao/testing/float8/test_utils.py b/torchao/testing/float8/test_utils.py new file mode 100644 index 0000000000..497fc2a03b --- /dev/null +++ b/torchao/testing/float8/test_utils.py @@ -0,0 +1,114 @@ +import torch +from torchao.float8.config import ScalingGranularity, ScalingType, CastConfig, Float8LinearConfig + +# scaling granularity and keep_in_original_precision to test by gemm arguments in this +# order: output, grad_input, grad_weight +scaling_granularities_by_gemm = [ + # TODO(before land): move this last + # @lcw's recipe + # TODO(future): add leaving one of the matmuls in bfloat16 + [ + # output = input @ weight_t + # input: axiswise + # weight_t: axiswise + (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE, False, False), + # grad_input = grad_output @ weight + # grad_output: axiswise + # weight: tensorwise (but that can be computed from axiswise done in the forward) + (ScalingGranularity.AXISWISE, ScalingGranularity.TENSORWISE, False, False), + # grad_weight = input_t @ grad_output, in high precision (bfloat16) + # input_t: high precision + # grad_output: high precision + (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE, True, True), + ], + # all tensorwise + #[ + # (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE, False, False), + # (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE, False, False), + # (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE, False, False), + #], + # all axiswise + #[ + # (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE, False, False), + # (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE, False, False), + # (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE, False, False), + #], +] + +def get_test_float8_linear_config( + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + scaling_granularities_by_gemm, + emulate: bool, +): + ( + (scaling_granularity_input, scaling_granularity_weight, original_prec_input, original_prec_weight), + (scaling_granularity_grad_output, scaling_granularity_weight_for_grad_input, original_prec_grad_output, original_prec_weight_for_grad_input), + (scaling_granularity_input_for_grad_weight, scaling_granularity_grad_output_for_grad_weight, original_prec_input_for_grad_weight, original_prec_grad_output_for_grad_weight), + ) = scaling_granularities_by_gemm + + static_scale_one = torch.tensor([1.0], device="cuda") + + if scaling_type_input is ScalingType.STATIC: + static_scale_input = static_scale_one + 
else: + static_scale_input = None + if scaling_type_weight is ScalingType.STATIC: + static_scale_weight = static_scale_one + else: + static_scale_weight = None + if scaling_type_grad_output is ScalingType.STATIC: + static_scale_grad_output = static_scale_one + else: + static_scale_grad_output = None + + cast_config_input = CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity_input, + static_scale=static_scale_input, + keep_in_original_precision=original_prec_input, + ) + cast_config_input_for_grad_weight = CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity_input_for_grad_weight, + static_scale=static_scale_input, + keep_in_original_precision=original_prec_input_for_grad_weight, + ) + + cast_config_weight = CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity_weight, + static_scale=static_scale_weight, + keep_in_original_precision=original_prec_weight, + ) + cast_config_weight_for_grad_input = CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity_weight_for_grad_input, + static_scale=static_scale_weight, + keep_in_original_precision=original_prec_weight_for_grad_input, + ) + + cast_config_grad_output = CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity_grad_output, + static_scale=static_scale_grad_output, + keep_in_original_precision=original_prec_grad_output, + ) + cast_config_grad_output_for_grad_weight = CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity_grad_output_for_grad_weight, + static_scale=static_scale_grad_output, + keep_in_original_precision=original_prec_grad_output_for_grad_weight, + ) + + config = Float8LinearConfig( + cast_config_input=cast_config_input, + cast_config_weight=cast_config_weight, + cast_config_grad_output=cast_config_grad_output, + cast_config_input_for_grad_weight=cast_config_input_for_grad_weight, + cast_config_weight_for_grad_input=cast_config_weight_for_grad_input, + cast_config_grad_output_for_grad_weight=cast_config_grad_output_for_grad_weight, + emulate=emulate, + ) + return config From 62acdaf7d5169c9c9eadd33b0876fc9e0b4197fa Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 25 Sep 2024 13:01:35 -0700 Subject: [PATCH 09/32] Update [ghstack-poisoned] --- benchmarks/float8/profile_linear_float8.py | 92 +++++++++++++--------- torchao/float8/config.py | 1 + torchao/testing/float8/test_utils.py | 75 +++++++++++------- 3 files changed, 102 insertions(+), 66 deletions(-) diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index 6afefa0096..912f4a1c1b 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -33,6 +33,10 @@ linear_requires_sync, sync_float8_amax_and_scale_history, ) +from torchao.testing.float8.test_utils import ( + scaling_granularities_by_gemm_lcw_recipe, + get_test_float8_linear_config, +) from torch.profiler import profile, ProfilerActivity, record_function from utils import ( kernel_name_to_category, @@ -258,6 +262,8 @@ def main( scaling_type_weight: str = "dynamic", scaling_type_grad_output: str = "dynamic", scaling_granularity: str = "tensorwise", + # TODO(future PR): clean up the override, it's confusing + recipe_override: Optional[str] = None, model_type: str = "linear", dtype_filter: str = "both", add_inductor_metadata_to_trace: bool = True, @@ -271,45 +277,57 @@ def main( scaling_type_grad_output = 
ScalingType(scaling_type_grad_output) scaling_granularity = ScalingGranularity(scaling_granularity) - if scaling_type_input is ScalingType.STATIC: - cast_config_input=CastConfig( - scaling_type=scaling_type_input, - static_scale=torch.tensor([1.0], device="cuda"), - scaling_granularity=scaling_granularity, - ) - else: - cast_config_input=CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - ) - if scaling_type_weight is ScalingType.STATIC: - cast_config_weight=CastConfig( - scaling_type=scaling_type_weight, - static_scale=torch.tensor([1.0], device="cuda"), - scaling_granularity=scaling_granularity, - ) - else: - cast_config_weight=CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity, - ) - if scaling_type_grad_output is ScalingType.STATIC: - cast_config_grad_output=CastConfig( - scaling_type=scaling_type_grad_output, - static_scale=torch.tensor([1.0], device="cuda"), - scaling_granularity=scaling_granularity, - ) - else: - cast_config_grad_output=CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity, + if recipe_override is None: + + if scaling_type_input is ScalingType.STATIC: + cast_config_input=CastConfig( + scaling_type=scaling_type_input, + static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, + ) + else: + cast_config_input=CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) + if scaling_type_weight is ScalingType.STATIC: + cast_config_weight=CastConfig( + scaling_type=scaling_type_weight, + static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, + ) + else: + cast_config_weight=CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) + if scaling_type_grad_output is ScalingType.STATIC: + cast_config_grad_output=CastConfig( + scaling_type=scaling_type_grad_output, + static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, + ) + else: + cast_config_grad_output=CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) + + config = Float8LinearConfig( + cast_config_input=cast_config_input, + cast_config_weight=cast_config_weight, + cast_config_grad_output=cast_config_grad_output, ) - config = Float8LinearConfig( - cast_config_input=cast_config_input, - cast_config_weight=cast_config_weight, - cast_config_grad_output=cast_config_grad_output, - ) + elif recipe_override == "lcw": + scaling_granularities_by_gemm = scaling_granularities_by_gemm_lcw_recipe + config = get_test_float8_linear_config( + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + scaling_granularities_by_gemm, + False, # emulate + ) scaling_repr = "_".join( [ diff --git a/torchao/float8/config.py b/torchao/float8/config.py index cf3c8463db..6442067157 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -131,6 +131,7 @@ class Float8LinearConfig: # # Per-gemm configuration for gemms calculating `output`, `grad_input` and # `grad_weight` + # TODO(this PR): throw warning if fast_accum False is used with axiswise scaling # gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True) gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig() diff --git a/torchao/testing/float8/test_utils.py b/torchao/testing/float8/test_utils.py index 497fc2a03b..6aa5e78e91 100644 --- a/torchao/testing/float8/test_utils.py +++ 
b/torchao/testing/float8/test_utils.py @@ -1,38 +1,47 @@ import torch -from torchao.float8.config import ScalingGranularity, ScalingType, CastConfig, Float8LinearConfig +from torchao.float8.config import ( + ScalingGranularity, + ScalingType, + CastConfig, + Float8LinearConfig, + Float8GemmConfig, +) + +scaling_granularities_by_gemm_lcw_recipe = [ + # @lcw's recipe + # output = input @ weight_t + # input: axiswise + # weight_t: axiswise + (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE, False, False), + # grad_input = grad_output @ weight + # grad_output: axiswise + # weight: tensorwise (but that can be computed from axiswise done in the forward) + (ScalingGranularity.AXISWISE, ScalingGranularity.TENSORWISE, False, False), + # grad_weight = input_t @ grad_output, in high precision (bfloat16) + # input_t: high precision + # grad_output: high precision + (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE, True, True), +] + +scaling_granularities_by_gemm_all_tensorwise = [ + (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE, False, False), + (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE, False, False), + (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE, False, False), +] + +scaling_granularities_by_gemm_all_axiswise = [ + (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE, False, False), + (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE, False, False), + (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE, False, False), +] # scaling granularity and keep_in_original_precision to test by gemm arguments in this # order: output, grad_input, grad_weight scaling_granularities_by_gemm = [ # TODO(before land): move this last - # @lcw's recipe - # TODO(future): add leaving one of the matmuls in bfloat16 - [ - # output = input @ weight_t - # input: axiswise - # weight_t: axiswise - (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE, False, False), - # grad_input = grad_output @ weight - # grad_output: axiswise - # weight: tensorwise (but that can be computed from axiswise done in the forward) - (ScalingGranularity.AXISWISE, ScalingGranularity.TENSORWISE, False, False), - # grad_weight = input_t @ grad_output, in high precision (bfloat16) - # input_t: high precision - # grad_output: high precision - (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE, True, True), - ], - # all tensorwise - #[ - # (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE, False, False), - # (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE, False, False), - # (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE, False, False), - #], - # all axiswise - #[ - # (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE, False, False), - # (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE, False, False), - # (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE, False, False), - #], + scaling_granularities_by_gemm_lcw_recipe, + # scaling_granularities_by_gemm_all_tensorwise, + # scaling_granularities_by_gemm_all_axiswise, ] def get_test_float8_linear_config( @@ -102,6 +111,11 @@ def get_test_float8_linear_config( keep_in_original_precision=original_prec_grad_output_for_grad_weight, ) + gemm_config_output = Float8GemmConfig(use_fast_accum=True) + # TODO(this PR): toggle fast accum by axiswise scaling presence + gemm_config_grad_input = Float8GemmConfig(use_fast_accum=True) + gemm_config_grad_weight = Float8GemmConfig(use_fast_accum=True) + config = Float8LinearConfig( 
cast_config_input=cast_config_input, cast_config_weight=cast_config_weight, @@ -109,6 +123,9 @@ def get_test_float8_linear_config( cast_config_input_for_grad_weight=cast_config_input_for_grad_weight, cast_config_weight_for_grad_input=cast_config_weight_for_grad_input, cast_config_grad_output_for_grad_weight=cast_config_grad_output_for_grad_weight, + gemm_config_output=gemm_config_output, + gemm_config_grad_input=gemm_config_grad_input, + gemm_config_grad_weight=gemm_config_grad_weight, emulate=emulate, ) return config From 381e16e2a569980797f84e945ac9cb8b5d17aa13 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 27 Sep 2024 09:27:08 -0700 Subject: [PATCH 10/32] Update [ghstack-poisoned] --- torchao/float8/config.py | 14 ++--- torchao/float8/float8_linear.py | 90 +++++++++++++++------------------ torchao/float8/float8_utils.py | 57 ++++++++++++++++++++- 3 files changed, 104 insertions(+), 57 deletions(-) diff --git a/torchao/float8/config.py b/torchao/float8/config.py index 6442067157..1470d07930 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -190,17 +190,17 @@ def __post_init__(self): cc_i = self.cast_config_input cc_w = self.cast_config_weight cc_go = self.cast_config_grad_output - cc_i2 = self.cast_config_input_for_grad_weight - cc_w2 = self.cast_config_weight_for_grad_input - cc_go2 = self.cast_config_grad_output_for_grad_weight + cc_i_gw = self.cast_config_input_for_grad_weight + cc_w_gi = self.cast_config_weight_for_grad_input + cc_go_gw = self.cast_config_grad_output_for_grad_weight # for now, we only have gemm kernels where both operands are scaled with the same # granularity. In the future this may be relaxed. assert cc_i.scaling_granularity == cc_w.scaling_granularity, \ "incompatible scaling granularity for output" - # assert cc_go.scaling_granularity == cc_w2.scaling_granularity, \ + # assert cc_go.scaling_granularity == cc_w_gi.scaling_granularity, \ # "incompatible scaling granularity for grad_input" - assert cc_i2.scaling_granularity == cc_go2.scaling_granularity, \ + assert cc_i_gw.scaling_granularity == cc_go_gw.scaling_granularity, \ "incompatible scaling granularity for grad_weight" # for now, we only have gemm kernels where both operands are either both @@ -208,9 +208,9 @@ def __post_init__(self): # TODO(future): make the float8 check more precise with the specific dtypes. 
assert cc_i.keep_in_original_precision == cc_w.keep_in_original_precision, \ "incompatible operand precision for output" - assert cc_go.keep_in_original_precision == cc_w2.keep_in_original_precision, \ + assert cc_go.keep_in_original_precision == cc_w_gi.keep_in_original_precision, \ "incompatible operand precision for grad_input" - assert cc_i2.keep_in_original_precision == cc_go2.keep_in_original_precision, \ + assert cc_i_gw.keep_in_original_precision == cc_go_gw.keep_in_original_precision, \ "incompatible operand precision for grad_weight" diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 09b5ab58f7..dc70f8eece 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -33,7 +33,13 @@ ScaledMMConfig, ) -from torchao.float8.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_amax +from torchao.float8.float8_utils import ( + e4m3_dtype, + e5m2_dtype, + tensor_to_amax, + float8_linear_config_to_concise_casts_config, + Float8LinearConciseCastsConfig, +) from torchao.float8.fsdp_utils import ( WeightWithDelayedFloat8CastTensor, @@ -120,52 +126,39 @@ def forward( input_hp: torch.Tensor, weight_hp_t: torch.Tensor, linear_mm_config: LinearMMConfig, - input_scaling_granularity: ScalingGranularity, - weight_scaling_granularity: ScalingGranularity, - grad_output_scaling_granularity: ScalingGranularity, - input_for_grad_weight_scaling_granularity: ScalingGranularity, - weight_for_grad_input_scaling_granularity: ScalingGranularity, - grad_output_for_grad_weight_scaling_granularity: ScalingGranularity, - # TODO(this PR): add all the others - input_for_grad_weight_keep_in_original_precision: bool, - grad_output_for_grad_weight_keep_in_original_precision: bool, + concise_casts_config: Float8LinearConciseCastsConfig, ): ctx.save_for_backward(input_hp, weight_hp_t) ctx.linear_mm_config = linear_mm_config - ctx.input_scaling_granularity = input_scaling_granularity - ctx.weight_scaling_granularity = weight_scaling_granularity - ctx.grad_output_scaling_granularity = grad_output_scaling_granularity - ctx.input_for_grad_weight_scaling_granularity = input_for_grad_weight_scaling_granularity - ctx.weight_for_grad_input_scaling_granularity = weight_for_grad_input_scaling_granularity - ctx.grad_output_for_grad_weight_scaling_granularity = grad_output_for_grad_weight_scaling_granularity - ctx.input_for_grad_weight_keep_in_original_precision = input_for_grad_weight_keep_in_original_precision - ctx.grad_output_for_grad_weight_keep_in_original_precision = grad_output_for_grad_weight_keep_in_original_precision - - if input_scaling_granularity == ScalingGranularity.TENSORWISE: + ctx.concise_casts_config = concise_casts_config + + c = concise_casts_config + + if c.cc_i.sc_gr == ScalingGranularity.TENSORWISE: axiswise_dim = None else: - assert input_scaling_granularity == ScalingGranularity.AXISWISE + assert c.cc_i.sc_gr == ScalingGranularity.AXISWISE axiswise_dim = -1 input_fp8 = hp_tensor_to_float8_dynamic( input_hp, e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.INPUT, - scaling_granularity=input_scaling_granularity, + scaling_granularity=c.cc_i.sc_gr, axiswise_dim=-1, ) - if weight_scaling_granularity == ScalingGranularity.TENSORWISE: + if c.cc_w.sc_gr == ScalingGranularity.TENSORWISE: axiswise_dim = None else: - assert weight_scaling_granularity == ScalingGranularity.AXISWISE + assert c.cc_w.sc_gr == ScalingGranularity.AXISWISE axiswise_dim = 0 weight_fp8_t = hp_tensor_to_float8_dynamic( weight_hp_t, e4m3_dtype, linear_mm_config, 
gemm_input_role=GemmInputRole.WEIGHT, - scaling_granularity=weight_scaling_granularity, + scaling_granularity=c.cc_w.sc_gr, axiswise_dim=0, ) @@ -182,6 +175,7 @@ def backward(ctx, grad_output): input_hp, weight_hp_t = ctx.saved_tensors # TODO scaling + c = ctx.concise_casts_config # the reshapes are needed in order to make the shapes compatible with # torch.mm @@ -194,33 +188,34 @@ def backward(ctx, grad_output): # calculate grad_input # - if ctx.grad_output_scaling_granularity == ScalingGranularity.TENSORWISE: + + if c.cc_go.sc_gr == ScalingGranularity.TENSORWISE: axiswise_dim1 = None else: - assert ctx.grad_output_scaling_granularity == ScalingGranularity.AXISWISE + assert c.cc_go.sc_gr == ScalingGranularity.AXISWISE axiswise_dim1 = -1 grad_output_reshaped_fp8_dim0 = hp_tensor_to_float8_dynamic( grad_output_reshaped, e5m2_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.GRAD_OUTPUT, - scaling_granularity=ctx.grad_output_scaling_granularity, + scaling_granularity=c.cc_go.sc_gr, axiswise_dim=axiswise_dim1, ) - - if ctx.weight_for_grad_input_scaling_granularity == ScalingGranularity.TENSORWISE: + + if c.cc_w_gi.sc_gr == ScalingGranularity.TENSORWISE: axiswise_dim2 = None # TODO(future PR): if the weight was scaled axiswise in the forward, can get the # tensorwise max from that. else: - assert ctx.weight_for_grad_input_scaling_granularity == ScalingGranularity.AXISWISE + assert c.cc_w_gi.sc_gr == ScalingGranularity.AXISWISE axiswise_dim2 = 1 # will be transposed weight_t_fp8_dim0 = hp_tensor_to_float8_dynamic( weight_hp_t, e4m3_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, - scaling_granularity=ctx.weight_for_grad_input_scaling_granularity, + scaling_granularity=c.cc_w_gi.sc_gr, axiswise_dim=axiswise_dim2, ) @@ -240,15 +235,15 @@ def backward(ctx, grad_output): # # TODO(this PR): also respect keep_in_original_precision for the other gemms - if ctx.grad_output_for_grad_weight_keep_in_original_precision: + if c.cc_go_gw.orig_prec: # TODO(this PR): more sensical variable name, now this isn't always fp8 grad_output_reshaped_fp8_dim1 = grad_output_reshaped else: - if ctx.grad_output_for_grad_weight_scaling_granularity == ScalingGranularity.TENSORWISE: + if c.cc_go_gw.sc_gr == ScalingGranularity.TENSORWISE: axiswise_dim3 = None else: - assert ctx.grad_output_for_grad_weight_scaling_granularity == ScalingGranularity.AXISWISE + assert c.cc_go_gw.sc_gr == ScalingGranularity.AXISWISE axiswise_dim3 = 0 # will be transposed grad_output_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( @@ -256,18 +251,19 @@ def backward(ctx, grad_output): e5m2_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.GRAD_OUTPUT, - scaling_granularity=ctx.grad_output_for_grad_weight_scaling_granularity, + scaling_granularity=c.cc_go_gw.sc_gr, axiswise_dim=axiswise_dim3, ) - if ctx.input_for_grad_weight_keep_in_original_precision: + + if c.cc_i_gw.orig_prec: input_reshaped_fp8_dim1 = input_hp_reshaped else: - if ctx.input_for_grad_weight_scaling_granularity == ScalingGranularity.TENSORWISE: + if c.cc_i_gw.sc_gr == ScalingGranularity.TENSORWISE: axiswise_dim4 = None else: - assert ctx.input_for_grad_weight_scaling_granularity == ScalingGranularity.AXISWISE + assert c.cc_i_gw.sc_gr == ScalingGranularity.AXISWISE axiswise_dim4 = 0 input_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( @@ -275,7 +271,7 @@ def backward(ctx, grad_output): e4m3_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, - scaling_granularity=ctx.input_for_grad_weight_scaling_granularity, + 
scaling_granularity=c.cc_i_gw.sc_gr, axiswise_dim=axiswise_dim4, ) @@ -284,7 +280,7 @@ def backward(ctx, grad_output): input_reshaped_fp8_dim1, ) - empty_grads = None, None, None, None, None, None, None, None, None + empty_grads = None, None return grad_input, grad_weight.t(), *empty_grads @@ -371,6 +367,9 @@ def __init__(self, *args, **kwargs): # would be initialized in every iteration. self.enable_pre_and_post_forward = self.config.enable_pre_and_post_forward + self.concise_casts_config: Float8LinearConciseCastsConfig = \ + float8_linear_config_to_concise_casts_config(self.config) + def create_buffers(self): # Default values for history buffers, see above TODO history_len = self.config.delayed_scaling_config.history_len @@ -612,14 +611,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: input, self.weight.t(), self.linear_mm_config, - self.config.cast_config_input.scaling_granularity, - self.config.cast_config_weight.scaling_granularity, - self.config.cast_config_grad_output.scaling_granularity, - self.config.cast_config_input_for_grad_weight.scaling_granularity, - self.config.cast_config_weight_for_grad_input.scaling_granularity, - self.config.cast_config_grad_output_for_grad_weight.scaling_granularity, - self.config.cast_config_input_for_grad_weight.keep_in_original_precision, - self.config.cast_config_grad_output_for_grad_weight.keep_in_original_precision, + self.concise_casts_config, ) if self.bias is not None: diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index e79cf27d88..d42ae37ca2 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from typing import Iterable, Literal, Optional, Tuple, Union +from typing import Iterable, Literal, NamedTuple, Optional, Tuple, Union import torchao.float8.config as config @@ -258,3 +258,58 @@ def pad_tensor_for_matmul( pad_dim2 = dim2_aligned - dim2 return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1)) + + +# The code below introduces a bit of duplication with Float8LinearConfig in +# order to improve readability of the implementation of how Float8Linear +# uses the config. Specifically, we do two things: +# 1. wrap the relevant parts of configs in namedtuple, so we can pass +# them around in compile-friendly code. +# 2. make the tuple key names more brief, to make the implementation +# code less verbose (the code was so verbose that I felt the need +# to add this workaround). +# As I was writing this, it became less and less clear on why not just have +# a namedtuple as a top level config. Punting that to a future PR as +# that might be BC-breaking, but probably worth exploring. +# Note: I also think below is pretty hacky, it's good enough to unblock +# further prototyping, but IMO pretty important to clean up sooner rather +# than later. 
+ +class ConciseCastConfig(NamedTuple): + sc_tp: config.ScalingType + sc_gr: config.ScalingGranularity + st_sc: Optional[torch.Tensor] + orig_prec: bool + + @classmethod + def from_cast_config(cls, c: config.CastConfig): + return cls( + sc_tp=c.scaling_type, + sc_gr=c.scaling_granularity, + st_sc=c.static_scale, + orig_prec=c.keep_in_original_precision, + ) + +class Float8LinearConciseCastsConfig(NamedTuple): + cc_i: ConciseCastConfig + cc_w: ConciseCastConfig + cc_go: ConciseCastConfig + cc_i_gw: ConciseCastConfig + cc_w_gi: ConciseCastConfig + cc_go_gw: ConciseCastConfig + + +def float8_linear_config_to_concise_casts_config( + c: config.Float8LinearConfig, +) -> Float8LinearConciseCastsConfig: + + concise_config = Float8LinearConciseCastsConfig( + cc_i = ConciseCastConfig.from_cast_config(c.cast_config_input), + cc_w = ConciseCastConfig.from_cast_config(c.cast_config_weight), + cc_go = ConciseCastConfig.from_cast_config(c.cast_config_grad_output), + cc_i_gw = ConciseCastConfig.from_cast_config(c.cast_config_input_for_grad_weight), + cc_w_gi = ConciseCastConfig.from_cast_config(c.cast_config_weight_for_grad_input), + cc_go_gw = ConciseCastConfig.from_cast_config(c.cast_config_grad_output_for_grad_weight), + ) + + return concise_config From afdf660d04fd6d8364a39df0163d5b0fd319f17d Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 27 Sep 2024 09:38:08 -0700 Subject: [PATCH 11/32] Update [ghstack-poisoned] --- torchao/float8/float8_utils.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index d42ae37ca2..665373c91f 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -260,21 +260,22 @@ def pad_tensor_for_matmul( return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1)) -# The code below introduces a bit of duplication with Float8LinearConfig in -# order to improve readability of the implementation of how Float8Linear +# The code below introduces a bit of duplication with Float8LinearConfig in +# order to improve readability of the implementation of how Float8Linear # uses the config. Specifically, we do two things: # 1. wrap the relevant parts of configs in namedtuple, so we can pass # them around in compile-friendly code. # 2. make the tuple key names more brief, to make the implementation # code less verbose (the code was so verbose that I felt the need -# to add this workaround). +# to add this workaround). # As I was writing this, it became less and less clear on why not just have -# a namedtuple as a top level config. Punting that to a future PR as +# a namedtuple as a top level config. Punting that to a future PR as # that might be BC-breaking, but probably worth exploring. # Note: I also think below is pretty hacky, it's good enough to unblock # further prototyping, but IMO pretty important to clean up sooner rather # than later. 
+ class ConciseCastConfig(NamedTuple): sc_tp: config.ScalingType sc_gr: config.ScalingGranularity @@ -290,6 +291,7 @@ def from_cast_config(cls, c: config.CastConfig): orig_prec=c.keep_in_original_precision, ) + class Float8LinearConciseCastsConfig(NamedTuple): cc_i: ConciseCastConfig cc_w: ConciseCastConfig @@ -302,14 +304,15 @@ class Float8LinearConciseCastsConfig(NamedTuple): def float8_linear_config_to_concise_casts_config( c: config.Float8LinearConfig, ) -> Float8LinearConciseCastsConfig: - concise_config = Float8LinearConciseCastsConfig( - cc_i = ConciseCastConfig.from_cast_config(c.cast_config_input), - cc_w = ConciseCastConfig.from_cast_config(c.cast_config_weight), - cc_go = ConciseCastConfig.from_cast_config(c.cast_config_grad_output), - cc_i_gw = ConciseCastConfig.from_cast_config(c.cast_config_input_for_grad_weight), - cc_w_gi = ConciseCastConfig.from_cast_config(c.cast_config_weight_for_grad_input), - cc_go_gw = ConciseCastConfig.from_cast_config(c.cast_config_grad_output_for_grad_weight), + cc_i=ConciseCastConfig.from_cast_config(c.cast_config_input), + cc_w=ConciseCastConfig.from_cast_config(c.cast_config_weight), + cc_go=ConciseCastConfig.from_cast_config(c.cast_config_grad_output), + cc_i_gw=ConciseCastConfig.from_cast_config(c.cast_config_input_for_grad_weight), + cc_w_gi=ConciseCastConfig.from_cast_config(c.cast_config_weight_for_grad_input), + cc_go_gw=ConciseCastConfig.from_cast_config( + c.cast_config_grad_output_for_grad_weight + ), ) return concise_config From d53f2ce5c643658295de108f1f00d45338f9b8a7 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 27 Sep 2024 09:53:07 -0700 Subject: [PATCH 12/32] Update [ghstack-poisoned] --- torchao/float8/float8_linear.py | 56 +++++--------------------- torchao/float8/float8_scaling_utils.py | 15 +++++++ 2 files changed, 26 insertions(+), 45 deletions(-) diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index dc70f8eece..3b8ac25a04 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -21,6 +21,7 @@ hp_tensor_to_float8_delayed, hp_tensor_to_float8_dynamic, hp_tensor_to_float8_static, + get_maybe_axiswise_dim, NoopFwToFloat8E5M2BwDelayed, NoopFwToFloat8E5M2BwDynamic, NoopFwToFloat8E5M2BwStatic, @@ -134,32 +135,22 @@ def forward( c = concise_casts_config - if c.cc_i.sc_gr == ScalingGranularity.TENSORWISE: - axiswise_dim = None - else: - assert c.cc_i.sc_gr == ScalingGranularity.AXISWISE - axiswise_dim = -1 input_fp8 = hp_tensor_to_float8_dynamic( input_hp, e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.INPUT, scaling_granularity=c.cc_i.sc_gr, - axiswise_dim=-1, + axiswise_dim=get_maybe_axiswise_dim(-1, c.cc_i.sc_gr), ) - if c.cc_w.sc_gr == ScalingGranularity.TENSORWISE: - axiswise_dim = None - else: - assert c.cc_w.sc_gr == ScalingGranularity.AXISWISE - axiswise_dim = 0 weight_fp8_t = hp_tensor_to_float8_dynamic( weight_hp_t, e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, scaling_granularity=c.cc_w.sc_gr, - axiswise_dim=0, + axiswise_dim=get_maybe_axiswise_dim(0, c.cc_w.sc_gr), ) # the reshapes are needed in order to make the shapes compatible with @@ -188,35 +179,25 @@ def backward(ctx, grad_output): # calculate grad_input # - - if c.cc_go.sc_gr == ScalingGranularity.TENSORWISE: - axiswise_dim1 = None - else: - assert c.cc_go.sc_gr == ScalingGranularity.AXISWISE - axiswise_dim1 = -1 grad_output_reshaped_fp8_dim0 = hp_tensor_to_float8_dynamic( grad_output_reshaped, e5m2_dtype, ctx.linear_mm_config, 
gemm_input_role=GemmInputRole.GRAD_OUTPUT, scaling_granularity=c.cc_go.sc_gr, - axiswise_dim=axiswise_dim1, + axiswise_dim=get_maybe_axiswise_dim(-1, c.cc_go.sc_gr), ) - if c.cc_w_gi.sc_gr == ScalingGranularity.TENSORWISE: - axiswise_dim2 = None - # TODO(future PR): if the weight was scaled axiswise in the forward, can get the - # tensorwise max from that. - else: - assert c.cc_w_gi.sc_gr == ScalingGranularity.AXISWISE - axiswise_dim2 = 1 # will be transposed + # TODO(future PR): if the weight was scaled axiswise in the forward, can get the + # tensorwise max from that. + # TODO link the pytorch core issue weight_t_fp8_dim0 = hp_tensor_to_float8_dynamic( weight_hp_t, e4m3_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, scaling_granularity=c.cc_w_gi.sc_gr, - axiswise_dim=axiswise_dim2, + axiswise_dim=get_maybe_axiswise_dim(1, c.cc_w_gi.sc_gr), ) grad_input = torch.mm( @@ -235,44 +216,29 @@ def backward(ctx, grad_output): # # TODO(this PR): also respect keep_in_original_precision for the other gemms + # TODO(this PR): more sensical variable name, now this isn't always fp8 if c.cc_go_gw.orig_prec: - # TODO(this PR): more sensical variable name, now this isn't always fp8 grad_output_reshaped_fp8_dim1 = grad_output_reshaped - else: - if c.cc_go_gw.sc_gr == ScalingGranularity.TENSORWISE: - axiswise_dim3 = None - else: - assert c.cc_go_gw.sc_gr == ScalingGranularity.AXISWISE - axiswise_dim3 = 0 # will be transposed - grad_output_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( grad_output_reshaped, e5m2_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.GRAD_OUTPUT, scaling_granularity=c.cc_go_gw.sc_gr, - axiswise_dim=axiswise_dim3, + axiswise_dim=get_maybe_axiswise_dim(0, c.cc_go_gw.sc_gr), ) - if c.cc_i_gw.orig_prec: input_reshaped_fp8_dim1 = input_hp_reshaped - else: - if c.cc_i_gw.sc_gr == ScalingGranularity.TENSORWISE: - axiswise_dim4 = None - else: - assert c.cc_i_gw.sc_gr == ScalingGranularity.AXISWISE - axiswise_dim4 = 0 - input_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( input_hp_reshaped, e4m3_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, scaling_granularity=c.cc_i_gw.sc_gr, - axiswise_dim=axiswise_dim4, + axiswise_dim=get_maybe_axiswise_dim(0, c.cc_i_gw.sc_gr), ) grad_weight = torch.mm( diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index f46293d616..a8ee92f286 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -141,6 +141,21 @@ def hp_tensor_to_float8_static( ) +def get_maybe_axiswise_dim( + axiswise_dim: int, + scaling_granularity: ScalingGranularity, +) -> Optional[int]: + """ + Convenience function which takes in an axiswise dim which is only relevant + for axiswise scaing, and a scaling type. The output is pass-through + if scaling type is axiswise, and None otherwise. This is done to keep the + logic from choosing the axiswise dim out of the scaling function. 
+ """ + if scaling_granularity is ScalingGranularity.AXISWISE: + return axiswise_dim + return None + + def _maybe_initialize_amaxes_scales_for_float8_cast( x, cur_amax, From 0737eb821649509c6a3df36aade71bb369a0d0f7 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 27 Sep 2024 11:08:34 -0700 Subject: [PATCH 13/32] Update [ghstack-poisoned] --- torchao/float8/float8_linear.py | 109 +++++++++++++++++--------------- 1 file changed, 59 insertions(+), 50 deletions(-) diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 3b8ac25a04..2767b7d1e9 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -135,37 +135,41 @@ def forward( c = concise_casts_config - input_fp8 = hp_tensor_to_float8_dynamic( - input_hp, - e4m3_dtype, - linear_mm_config, - gemm_input_role=GemmInputRole.INPUT, - scaling_granularity=c.cc_i.sc_gr, - axiswise_dim=get_maybe_axiswise_dim(-1, c.cc_i.sc_gr), - ) + if c.cc_i.orig_prec: + input_maybe_fp8 = input_hp + else: + input_maybe_fp8 = hp_tensor_to_float8_dynamic( + input_hp, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + scaling_granularity=c.cc_i.sc_gr, + axiswise_dim=get_maybe_axiswise_dim(-1, c.cc_i.sc_gr), + ) - weight_fp8_t = hp_tensor_to_float8_dynamic( - weight_hp_t, - e4m3_dtype, - linear_mm_config, - gemm_input_role=GemmInputRole.WEIGHT, - scaling_granularity=c.cc_w.sc_gr, - axiswise_dim=get_maybe_axiswise_dim(0, c.cc_w.sc_gr), - ) + if c.cc_w.orig_prec: + weight_maybe_fp8_t = weight_hp_t + else: + weight_maybe_fp8_t = hp_tensor_to_float8_dynamic( + weight_hp_t, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + scaling_granularity=c.cc_w.sc_gr, + axiswise_dim=get_maybe_axiswise_dim(0, c.cc_w.sc_gr), + ) # the reshapes are needed in order to make the shapes compatible with # torch.mm - orig_shape = input_fp8.shape - input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1]) - res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t) + orig_shape = input_maybe_fp8.shape + input_maybe_fp8_reshaped = input_maybe_fp8.reshape(-1, orig_shape[-1]) + res_bits = torch.mm(input_maybe_fp8_reshaped, weight_maybe_fp8_t) res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1]) return res_bits @staticmethod def backward(ctx, grad_output): input_hp, weight_hp_t = ctx.saved_tensors - - # TODO scaling c = ctx.concise_casts_config # the reshapes are needed in order to make the shapes compatible with @@ -179,30 +183,37 @@ def backward(ctx, grad_output): # calculate grad_input # - grad_output_reshaped_fp8_dim0 = hp_tensor_to_float8_dynamic( - grad_output_reshaped, - e5m2_dtype, - ctx.linear_mm_config, - gemm_input_role=GemmInputRole.GRAD_OUTPUT, - scaling_granularity=c.cc_go.sc_gr, - axiswise_dim=get_maybe_axiswise_dim(-1, c.cc_go.sc_gr), - ) + if c.cc_go.orig_prec: + grad_output_reshaped_maybe_fp8_dim0 = grad_output_reshaped + else: + grad_output_reshaped_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic( + grad_output_reshaped, + e5m2_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.GRAD_OUTPUT, + scaling_granularity=c.cc_go.sc_gr, + axiswise_dim=get_maybe_axiswise_dim(-1, c.cc_go.sc_gr), + ) - # TODO(future PR): if the weight was scaled axiswise in the forward, can get the - # tensorwise max from that. 
- # TODO link the pytorch core issue - weight_t_fp8_dim0 = hp_tensor_to_float8_dynamic( - weight_hp_t, - e4m3_dtype, - ctx.linear_mm_config, - gemm_input_role=GemmInputRole.WEIGHT, - scaling_granularity=c.cc_w_gi.sc_gr, - axiswise_dim=get_maybe_axiswise_dim(1, c.cc_w_gi.sc_gr), - ) + if c.cc_w_gi.orig_prec: + weight_t_maybe_fp8_dim0 = weight_hp_t + else: + # Note: we need https://github.com/pytorch/pytorch/issues/136267 + # to be solved to have a chance to reuse max(abs(weight, dim=...)) + # from the forward to get max(abs(weight)) here without reading + # the entire tensor. + weight_t_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic( + weight_hp_t, + e4m3_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + scaling_granularity=c.cc_w_gi.sc_gr, + axiswise_dim=get_maybe_axiswise_dim(1, c.cc_w_gi.sc_gr), + ) grad_input = torch.mm( - grad_output_reshaped_fp8_dim0, - weight_t_fp8_dim0.t(), + grad_output_reshaped_maybe_fp8_dim0, + weight_t_maybe_fp8_dim0.t(), ) grad_input = grad_input.reshape( *grad_output_orig_shape[:-1], grad_input.shape[-1] @@ -215,12 +226,10 @@ def backward(ctx, grad_output): # calculate grad_weight # - # TODO(this PR): also respect keep_in_original_precision for the other gemms - # TODO(this PR): more sensical variable name, now this isn't always fp8 if c.cc_go_gw.orig_prec: - grad_output_reshaped_fp8_dim1 = grad_output_reshaped + grad_output_reshaped_maybe_fp8_dim1 = grad_output_reshaped else: - grad_output_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( + grad_output_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic( grad_output_reshaped, e5m2_dtype, ctx.linear_mm_config, @@ -230,9 +239,9 @@ def backward(ctx, grad_output): ) if c.cc_i_gw.orig_prec: - input_reshaped_fp8_dim1 = input_hp_reshaped + input_reshaped_maybe_fp8_dim1 = input_hp_reshaped else: - input_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( + input_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic( input_hp_reshaped, e4m3_dtype, ctx.linear_mm_config, @@ -242,8 +251,8 @@ def backward(ctx, grad_output): ) grad_weight = torch.mm( - grad_output_reshaped_fp8_dim1.t(), - input_reshaped_fp8_dim1, + grad_output_reshaped_maybe_fp8_dim1.t(), + input_reshaped_maybe_fp8_dim1, ) empty_grads = None, None From 2791eb3cc811160238044a4873b868c296842d49 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 27 Sep 2024 12:51:08 -0700 Subject: [PATCH 14/32] Update [ghstack-poisoned] --- benchmarks/float8/profile_linear_float8.py | 12 +--- torchao/float8/config.py | 82 ++++++++++++++++++++++ 2 files changed, 85 insertions(+), 9 deletions(-) diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index 912f4a1c1b..721461dded 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -27,6 +27,7 @@ Float8LinearConfig, ScalingType, ScalingGranularity, + _get_recipe, ) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, @@ -319,15 +320,8 @@ def main( cast_config_grad_output=cast_config_grad_output, ) - elif recipe_override == "lcw": - scaling_granularities_by_gemm = scaling_granularities_by_gemm_lcw_recipe - config = get_test_float8_linear_config( - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - scaling_granularities_by_gemm, - False, # emulate - ) + elif recipe_override is not None: + config = _get_recipe(recipe_override) scaling_repr = "_".join( [ diff --git a/torchao/float8/config.py b/torchao/float8/config.py index 1470d07930..cd59fa9bc9 100644 --- 
a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -218,3 +218,85 @@ def __post_init__(self): # Currently, ROCm only supports fnuz variants. # TODO(future PR): move this to Float8LinearConfig use_fnuz_dtype = False + + +# Pre-made recipes for common configurations +# TODO(future PR): go through a round of design on this, and eventually expose +# as a top level public API. +def _get_recipe(recipe_name: str) -> Float8LinearConfig: + if recipe_name == "all_tensorwise": + # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel + return Float8LinearConfig() + + elif recipe_name == "all_axiswise": + # dynamic axiswise scaling with the CUTLASS rowwise kernel + cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + + # The current rowwise CUTLASS kernels in `torch._scaled_mm` are only + # fast with `use_fast_accum=True`. Note that rowwise scaling is more + # accurate than tensorwise scaling, so the overall impact on accuracy + # of tensorwise vs rowwise taking this flag into account will vary. + gc_o = Float8GemmConfig(use_fast_accum=True) + gc_gi = Float8GemmConfig(use_fast_accum=True) + gc_gw = Float8GemmConfig(use_fast_accum=True) + + return Float8LinearConfig( + cast_config_input=cc_i, + cast_config_weight=cc_w, + cast_config_grad_output=cc_go, + gemm_config_output=gc_o, + gemm_config_grad_input=gc_gi, + gemm_config_grad_weight=gc_gw, + ) + + elif recipe_name == "lw_axiswise_with_gw_hp": + + # lw's recipe for a modification on all-axiswise: + # + # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1 + # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise + # grad_weight_hp = input_t_hp @ grad_output_hp + # + # key characteristics: + # * increased accuracy for grad_weight + # * `output` and `weight` now only need to be scaled axiswise across a + # single dim compared to vanilla all-axiswise, which is more + # amenable to fast kernels + + # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1 + cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + + # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise + cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE) + + # grad_weight_hp = input_t_hp @ grad_output_hp + cc_i_gw = CastConfig(keep_in_original_precision=True) + cc_go_gw = CastConfig(keep_in_original_precision=True) + + # The current rowwise CUTLASS kernels in `torch._scaled_mm` are only + # fast with `use_fast_accum=True`. Note that rowwise scaling is more + # accurate than tensorwise scaling, so the overall impact on accuracy + # of tensorwise vs rowwise taking this flag into account will vary. 
+ gc_o = Float8GemmConfig(use_fast_accum=True) + gc_gi = Float8GemmConfig(use_fast_accum=True) + gc_gw = Float8GemmConfig(use_fast_accum=True) + + return Float8LinearConfig( + cast_config_input=cc_i, + cast_config_weight=cc_w, + cast_config_grad_output=cc_go, + cast_config_input_for_grad_weight=cc_i_gw, + cast_config_weight_for_grad_input=cc_w_gi, + cast_config_grad_output_for_grad_weight=cc_go_gw, + gemm_config_output=gc_o, + gemm_config_grad_input=gc_gi, + gemm_config_grad_weight=gc_gw, + ) + + else: + # TODO(before land): make recipe_name an enum and tell users what the options are + raise AssertionError(f"unknown recipe_name {recipe_name}") From 6cfd1cdd2a400d42dc6ce99e61ccab1ac6bd20fc Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 27 Sep 2024 12:59:40 -0700 Subject: [PATCH 15/32] Update [ghstack-poisoned] --- test/float8/test_base.py | 20 ++++++-------------- torchao/float8/config.py | 3 ++- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 9967de69bc..38e3f1a9d7 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -36,7 +36,10 @@ sync_float8_amax_and_scale_history, ) from torchao.float8.float8_python_api import addmm_float8_unwrapped -from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic +from torchao.float8.float8_scaling_utils import ( + hp_tensor_to_float8_dynamic, + get_maybe_axiswise_dim, +) from torchao.float8.float8_tensor import ( Float8Tensor, GemmInputRole, @@ -234,34 +237,23 @@ def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity): linear_mm_config = LinearMMConfig() - if a_granularity is ScalingGranularity.AXISWISE: - a_axiswise_dim = -1 - else: - assert a_granularity is ScalingGranularity.TENSORWISE - a_axiswise_dim = None a_fp8 = hp_tensor_to_float8_dynamic( a, e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.INPUT, scaling_granularity=a_granularity, - axiswise_dim=a_axiswise_dim, + axiswise_dim=get_maybe_axiswise_dim(-1, a_granularity), ) a_fp8 = a_fp8.reshape(-1, a_shape[-1]) - b_axiswise_dim = 1 if b_granularity is ScalingGranularity.AXISWISE else None - if b_granularity is ScalingGranularity.AXISWISE: - b_axiswise_dim = 1 # will be transposed - else: - assert b_granularity is ScalingGranularity.TENSORWISE - b_axiswise_dim = None b_fp8 = hp_tensor_to_float8_dynamic( b, e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, scaling_granularity=b_granularity, - axiswise_dim=b_axiswise_dim, + axiswise_dim=get_maybe_axiswise_dim(1, b_granularity), ) c_fp8_compute = torch.mm(a_fp8, b_fp8.t()) diff --git a/torchao/float8/config.py b/torchao/float8/config.py index cd59fa9bc9..234a9c01e8 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -119,7 +119,8 @@ class Float8LinearConfig: # # Optional per-tensor configuration for `input`, `weight`, `grad_output` to # calculate `grad_weight`, `grad_input`, and `grad_weight` respectively. - # If not specified, then the configuration from the is reused. + # If not specified, then the configuration from `cast_config_input`, + # `cast_config_weight` and `cast_config_grad_output`, respectively, is reused. # TODO(future PR): maybe rename `cast_config_input` to # `cast_config_input_for_output`, etc, to make the names consistent, # will be BC-breaking. 
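Note: the pre-made recipes added above are what the new tests feed into `convert_to_float8_training`. A minimal usage sketch follows; it is illustrative only and not part of the patch series, and it assumes the private `_get_recipe` helper stays importable from `torchao.float8.config` as in the test diffs and that a CUDA device with fp8 support (SM90-class for the axiswise kernels) is available.

import torch
import torch.nn as nn

from torchao.float8.config import _get_recipe
from torchao.float8.float8_linear_utils import convert_to_float8_training

# toy bfloat16 model; Float8Linear currently targets CUDA devices
m = nn.Sequential(nn.Linear(32, 64, bias=False)).cuda().to(torch.bfloat16)

# recipe names defined in this patch: "all_tensorwise", "all_axiswise",
# "lw_axiswise_with_gw_hp"
config = _get_recipe("all_axiswise")

# swap eligible nn.Linear modules for Float8Linear configured by the recipe
convert_to_float8_training(m, config=config)

x = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16)
m(x).sum().backward()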
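The dynamic-cast path the recipes rely on can also be driven directly through `hp_tensor_to_float8_dynamic` together with the `get_maybe_axiswise_dim` helper introduced in patch 12. The sketch below mirrors the call pattern used in the updated tests; the import paths follow the diffs above, and a CUDA device with fp8 support is assumed.

import torch

from torchao.float8.config import ScalingGranularity
from torchao.float8.float8_scaling_utils import (
    hp_tensor_to_float8_dynamic,
    get_maybe_axiswise_dim,
)
from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig
from torchao.float8.float8_utils import e4m3_dtype

a = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16)
granularity = ScalingGranularity.AXISWISE

# get_maybe_axiswise_dim passes -1 through for axiswise scaling and returns
# None for tensorwise scaling, so the call site does not need to branch
a_fp8 = hp_tensor_to_float8_dynamic(
    a,
    e4m3_dtype,
    LinearMMConfig(),
    gemm_input_role=GemmInputRole.INPUT,
    scaling_granularity=granularity,
    axiswise_dim=get_maybe_axiswise_dim(-1, granularity),
)

# round-trip back to the original precision to inspect quantization error
a_dq = a_fp8.to_original_precision()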
From 94907e58bb038da92c5da69ffb30521a082331e3 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 27 Sep 2024 13:46:29 -0700 Subject: [PATCH 16/32] Update [ghstack-poisoned] --- test/float8/test_base.py | 88 ++++++++++----------- test/float8/test_compile.py | 114 +++++++++------------------ torchao/testing/float8/test_utils.py | 44 ----------- 3 files changed, 79 insertions(+), 167 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 38e3f1a9d7..e0628fa171 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -28,6 +28,7 @@ Float8LinearConfig, ScalingGranularity, ScalingType, + _get_recipe, ) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( @@ -336,42 +337,30 @@ def _test_linear_impl( # verify initialization flags got updated assert m_fp8.is_amax_initialized, "Amax was not properly initialized" - # @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True]) - @pytest.mark.parametrize("emulate", [False] if is_cuda_8_9 else [True]) - # @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) - @pytest.mark.parametrize("x_shape", [(16, 16),]) + @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True]) + @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize( "scaling_type_input", - # [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] - [ScalingType.DYNAMIC] + [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) @pytest.mark.parametrize( "scaling_type_weight", - # [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] - [ScalingType.DYNAMIC] + [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) @pytest.mark.parametrize( "scaling_type_grad_output", - # [ScalingType.DELAYED, ScalingType.DYNAMIC], - [ScalingType.DYNAMIC] - ) - @pytest.mark.parametrize( - "scaling_granularities_by_gemm", - scaling_granularities_by_gemm + [ScalingType.DELAYED, ScalingType.DYNAMIC], ) - # @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32]) - @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, ]) - # @pytest.mark.parametrize("linear_bias", [False, True]) - @pytest.mark.parametrize("linear_bias", [False, ]) + @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32]) + @pytest.mark.parametrize("linear_bias", [False, True]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - def test_linear( + def test_linear_from_config_params( self, x_shape, emulate: bool, scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, - scaling_granularities_by_gemm: List[List[Tuple[ScalingGranularity, ScalingGranularity]]], linear_dtype: torch.dtype, linear_bias: bool, ): @@ -385,31 +374,6 @@ def test_linear( ) pytest.skip() - ( - (scaling_granularity_input, scaling_granularity_weight, original_prec_input, original_prec_weight), - (scaling_granularity_grad_output, scaling_granularity_weight_for_grad_input, original_prec_grad_output, original_prec_weight_for_grad_input), - (scaling_granularity_input_for_grad_weight, scaling_granularity_grad_output_for_grad_weight, original_prec_input_for_grad_weight, original_prec_grad_output_for_grad_weight), - ) = scaling_granularities_by_gemm - - has_any_axiswise_scaling = ( - scaling_granularity_input is ScalingGranularity.AXISWISE or - scaling_granularity_weight is ScalingGranularity.AXISWISE or - scaling_granularity_grad_output is 
ScalingGranularity.AXISWISE or - scaling_granularity_input_for_grad_weight is ScalingGranularity.AXISWISE or - scaling_granularity_weight_for_grad_input is ScalingGranularity.AXISWISE or - scaling_granularity_grad_output_for_grad_weight is ScalingGranularity.AXISWISE - ) - - if has_any_axiswise_scaling: - if ( - scaling_type_input != ScalingType.DYNAMIC or - scaling_type_weight != ScalingType.DYNAMIC or - scaling_type_grad_output != ScalingType.DYNAMIC or - linear_dtype != torch.bfloat16 or - (not is_cuda_9_0) - ): - pytest.skip() - x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) @@ -417,7 +381,6 @@ def test_linear( scaling_type_input, scaling_type_weight, scaling_type_grad_output, - scaling_granularities_by_gemm, emulate, ) @@ -427,6 +390,39 @@ def test_linear( config, ) + # Note: there are now too many config combinations to test all of + # them, so this function factors out some of the recipes which are annoying + # to combine with the main testing function. + # TODO(future PR): make this cleaner. + @pytest.mark.parametrize( + "recipe_name", + ["all_axiswise", "lw_axiswise_with_gw_hp"], + ) + @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) + @pytest.mark.parametrize("linear_bias", [True, False]) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_linear_from_recipe( + self, + recipe_name, + x_shape, + linear_bias: bool, + ): + if torch.cuda.get_device_capability() < (9, 0): + warnings.warn( + f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)" + ) + pytest.skip() + + linear_dtype = torch.bfloat16 + x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) + m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) + config = _get_recipe(recipe_name) + self._test_linear_impl( + x, + m_ref, + config, + ) + @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True]) @pytest.mark.parametrize( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 5a41e29728..001b04b0d7 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -24,6 +24,7 @@ Float8LinearConfig, ScalingType, ScalingGranularity, + _get_recipe, ) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( @@ -119,28 +120,16 @@ def is_supported( @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize( - # "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] - "scaling_type_input", [ScalingType.DYNAMIC,] + "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) @pytest.mark.parametrize( - # "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] - "scaling_type_weight", [ScalingType.DYNAMIC,] + "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) @pytest.mark.parametrize( - # "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] - "scaling_type_grad_output", [ScalingType.DYNAMIC,] + "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) -# @pytest.mark.parametrize( -# "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] -# ) -@pytest.mark.parametrize( - "scaling_granularities_by_gemm", - scaling_granularities_by_gemm -) -# 
@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) -@pytest.mark.parametrize("emulate", [False, ]) -# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) -@pytest.mark.parametrize("dtype", [torch.bfloat16, ]) +@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_eager_only( fullgraph, @@ -148,24 +137,13 @@ def test_eager_only( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, - scaling_granularities_by_gemm: List[List[Tuple[ScalingGranularity, ScalingGranularity]]], dtype: torch.dtype, ): - if not is_supported( - scaling_granularities_by_gemm, - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - dtype, - ): - pytest.skip() - torch._dynamo.reset() config = get_test_float8_linear_config( scaling_type_input, scaling_type_weight, scaling_type_grad_output, - scaling_granularities_by_gemm, emulate, ) _test_compile_base( @@ -177,26 +155,17 @@ def test_eager_only( @pytest.mark.parametrize("fullgraph", [True]) -# @pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) -@pytest.mark.parametrize("emulate", [False,]) -@pytest.mark.parametrize( - # "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] - "scaling_type_input", [ScalingType.DYNAMIC,] -) +@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) @pytest.mark.parametrize( - # "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] - "scaling_type_weight", [ScalingType.DYNAMIC,] + "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) @pytest.mark.parametrize( - # "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] - "scaling_type_grad_output", [ScalingType.DYNAMIC,] + "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) @pytest.mark.parametrize( - "scaling_granularities_by_gemm", - scaling_granularities_by_gemm + "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) -# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) -@pytest.mark.parametrize("dtype", [torch.bfloat16,]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_aot_eager( fullgraph, @@ -204,24 +173,13 @@ def test_aot_eager( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, - scaling_granularities_by_gemm: List[List[Tuple[ScalingGranularity, ScalingGranularity]]], dtype: torch.dtype, ): - if not is_supported( - scaling_granularities_by_gemm, - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - dtype, - ): - pytest.skip() - torch._dynamo.reset() config = get_test_float8_linear_config( scaling_type_input, scaling_type_weight, scaling_type_grad_output, - scaling_granularities_by_gemm, emulate, ) _test_compile_base( @@ -235,48 +193,29 @@ def test_aot_eager( @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize("emulate", [False]) @pytest.mark.parametrize( - # "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] - "scaling_type_input", [ScalingType.DYNAMIC, ] + "scaling_type_input", 
[ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) @pytest.mark.parametrize( - # "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] - "scaling_type_weight", [ScalingType.DYNAMIC, ] + "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) @pytest.mark.parametrize( - # "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] - "scaling_type_grad_output", [ScalingType.DYNAMIC, ] -) -@pytest.mark.parametrize( - "scaling_granularities_by_gemm", - scaling_granularities_by_gemm + "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available") -# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) -@pytest.mark.parametrize("dtype", [torch.bfloat16,]) -def test_inductor( +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +def test_inductor_from_config_params( fullgraph, emulate: bool, scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, - scaling_granularities_by_gemm: List[List[Tuple[ScalingGranularity, ScalingGranularity]]], dtype: torch.dtype, ): - if not is_supported( - scaling_granularities_by_gemm, - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - dtype, - ): - pytest.skip() - torch._dynamo.reset() config = get_test_float8_linear_config( scaling_type_input, scaling_type_weight, scaling_type_grad_output, - scaling_granularities_by_gemm, emulate, ) _test_compile_base( @@ -286,6 +225,27 @@ def test_inductor( dtype, ) +# Note: there are now too many config combinations to test all of +# them, so this function factors out some of the recipes which are annoying +# to combine with the main testing function. +# TODO(future PR): make this cleaner. 
+@pytest.mark.parametrize( + "recipe_name", + ["all_axiswise", "lw_axiswise_with_gw_hp"], +) +@unittest.skipIf(not is_H100, "CUDA with capability 9.0 or greater not available") +def test_inductor_from_recipe(recipe_name): + torch._dynamo.reset() + config = _get_recipe(recipe_name) + fullgraph = True + dtype = torch.bfloat16 + _test_compile_base( + "inductor", + fullgraph, + config, + dtype, + ) + class TestGraphBreaks(DynamoTestCase): class MockLinear(torch.nn.Module): diff --git a/torchao/testing/float8/test_utils.py b/torchao/testing/float8/test_utils.py index 6aa5e78e91..a108647b97 100644 --- a/torchao/testing/float8/test_utils.py +++ b/torchao/testing/float8/test_utils.py @@ -48,15 +48,8 @@ def get_test_float8_linear_config( scaling_type_input, scaling_type_weight, scaling_type_grad_output, - scaling_granularities_by_gemm, emulate: bool, ): - ( - (scaling_granularity_input, scaling_granularity_weight, original_prec_input, original_prec_weight), - (scaling_granularity_grad_output, scaling_granularity_weight_for_grad_input, original_prec_grad_output, original_prec_weight_for_grad_input), - (scaling_granularity_input_for_grad_weight, scaling_granularity_grad_output_for_grad_weight, original_prec_input_for_grad_weight, original_prec_grad_output_for_grad_weight), - ) = scaling_granularities_by_gemm - static_scale_one = torch.tensor([1.0], device="cuda") if scaling_type_input is ScalingType.STATIC: @@ -74,58 +67,21 @@ def get_test_float8_linear_config( cast_config_input = CastConfig( scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity_input, - static_scale=static_scale_input, - keep_in_original_precision=original_prec_input, - ) - cast_config_input_for_grad_weight = CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity_input_for_grad_weight, static_scale=static_scale_input, - keep_in_original_precision=original_prec_input_for_grad_weight, ) - cast_config_weight = CastConfig( scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity_weight, static_scale=static_scale_weight, - keep_in_original_precision=original_prec_weight, ) - cast_config_weight_for_grad_input = CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity_weight_for_grad_input, - static_scale=static_scale_weight, - keep_in_original_precision=original_prec_weight_for_grad_input, - ) - cast_config_grad_output = CastConfig( scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity_grad_output, - static_scale=static_scale_grad_output, - keep_in_original_precision=original_prec_grad_output, - ) - cast_config_grad_output_for_grad_weight = CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity_grad_output_for_grad_weight, static_scale=static_scale_grad_output, - keep_in_original_precision=original_prec_grad_output_for_grad_weight, ) - gemm_config_output = Float8GemmConfig(use_fast_accum=True) - # TODO(this PR): toggle fast accum by axiswise scaling presence - gemm_config_grad_input = Float8GemmConfig(use_fast_accum=True) - gemm_config_grad_weight = Float8GemmConfig(use_fast_accum=True) - config = Float8LinearConfig( cast_config_input=cast_config_input, cast_config_weight=cast_config_weight, cast_config_grad_output=cast_config_grad_output, - cast_config_input_for_grad_weight=cast_config_input_for_grad_weight, - cast_config_weight_for_grad_input=cast_config_weight_for_grad_input, - cast_config_grad_output_for_grad_weight=cast_config_grad_output_for_grad_weight, - 
gemm_config_output=gemm_config_output, - gemm_config_grad_input=gemm_config_grad_input, - gemm_config_grad_weight=gemm_config_grad_weight, emulate=emulate, ) return config From 423760af9cf03fdcafd7c4bdfecbd03995632f9e Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 27 Sep 2024 13:49:52 -0700 Subject: [PATCH 17/32] Update [ghstack-poisoned] --- test/float8/test_compile.py | 36 ------------------------------------ 1 file changed, 36 deletions(-) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 001b04b0d7..387d96c3a3 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -23,7 +23,6 @@ CastConfig, Float8LinearConfig, ScalingType, - ScalingGranularity, _get_recipe, ) from torchao.float8.float8_linear import Float8Linear @@ -83,41 +82,6 @@ def _test_compile_base( torch.testing.assert_close(x.grad, x_ref.grad, atol=8e-2, rtol=8e-2) -def is_supported( - scaling_granularities_by_gemm, - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - dtype, -) -> bool: - - ( - (scaling_granularity_input, scaling_granularity_weight, original_prec_input, original_prec_weight), - (scaling_granularity_grad_output, scaling_granularity_weight_for_grad_input, original_prec_grad_output, original_prec_weight_for_grad_input), - (scaling_granularity_input_for_grad_weight, scaling_granularity_grad_output_for_grad_weight, original_prec_input_for_grad_weight, original_prec_grad_output_for_grad_weight), - ) = scaling_granularities_by_gemm - - has_any_axiswise_scaling = ( - scaling_granularity_input is ScalingGranularity.AXISWISE or - scaling_granularity_weight is ScalingGranularity.AXISWISE or - scaling_granularity_grad_output is ScalingGranularity.AXISWISE or - scaling_granularity_input_for_grad_weight is ScalingGranularity.AXISWISE or - scaling_granularity_weight_for_grad_input is ScalingGranularity.AXISWISE or - scaling_granularity_grad_output_for_grad_weight is ScalingGranularity.AXISWISE - ) - - if has_any_axiswise_scaling: - if ( - scaling_type_input != ScalingType.DYNAMIC or - scaling_type_weight != ScalingType.DYNAMIC or - scaling_type_grad_output != ScalingType.DYNAMIC or - dtype != torch.bfloat16 or - (not is_H100) - ): - return False - return True - - @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize( "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] From 552db23d7aca79ea39e168ba08cbd664014bd3a8 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 27 Sep 2024 13:51:21 -0700 Subject: [PATCH 18/32] Update [ghstack-poisoned] --- test/float8/test_base.py | 5 +---- test/float8/test_compile.py | 5 +---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index e0628fa171..9e5d8ddada 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -56,10 +56,7 @@ FP8_TYPES, tensor_to_scale, ) -from torchao.testing.float8.test_utils import ( - scaling_granularities_by_gemm, - get_test_float8_linear_config, -) +from torchao.testing.float8.test_utils import get_test_float8_linear_config random.seed(0) torch.manual_seed(0) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 387d96c3a3..27a9680e9d 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -34,10 +34,7 @@ from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_delayed from torchao.float8.float8_tensor import LinearMMConfig from torchao.float8.float8_utils import e4m3_dtype -from torchao.testing.float8.test_utils 
import ( - scaling_granularities_by_gemm, - get_test_float8_linear_config, -) +from torchao.testing.float8.test_utils import get_test_float8_linear_config, from torch._dynamo.test_case import TestCase as DynamoTestCase from torch._dynamo.testing import CompileCounterWithBackend From 953bc2fd47bb6895a5c117c3dd61f10b04a4c536 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 27 Sep 2024 13:55:08 -0700 Subject: [PATCH 19/32] Update [ghstack-poisoned] --- benchmarks/float8/profile_linear_float8.py | 1 - test/float8/test_compile.py | 2 +- torchao/testing/float8/test_utils.py | 37 ---------------------- 3 files changed, 1 insertion(+), 39 deletions(-) diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index 721461dded..136a8b6921 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -263,7 +263,6 @@ def main( scaling_type_weight: str = "dynamic", scaling_type_grad_output: str = "dynamic", scaling_granularity: str = "tensorwise", - # TODO(future PR): clean up the override, it's confusing recipe_override: Optional[str] = None, model_type: str = "linear", dtype_filter: str = "both", diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 27a9680e9d..7ba0fb4f56 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -34,7 +34,7 @@ from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_delayed from torchao.float8.float8_tensor import LinearMMConfig from torchao.float8.float8_utils import e4m3_dtype -from torchao.testing.float8.test_utils import get_test_float8_linear_config, +from torchao.testing.float8.test_utils import get_test_float8_linear_config from torch._dynamo.test_case import TestCase as DynamoTestCase from torch._dynamo.testing import CompileCounterWithBackend diff --git a/torchao/testing/float8/test_utils.py b/torchao/testing/float8/test_utils.py index a108647b97..7f37c3f30a 100644 --- a/torchao/testing/float8/test_utils.py +++ b/torchao/testing/float8/test_utils.py @@ -4,45 +4,8 @@ ScalingType, CastConfig, Float8LinearConfig, - Float8GemmConfig, ) -scaling_granularities_by_gemm_lcw_recipe = [ - # @lcw's recipe - # output = input @ weight_t - # input: axiswise - # weight_t: axiswise - (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE, False, False), - # grad_input = grad_output @ weight - # grad_output: axiswise - # weight: tensorwise (but that can be computed from axiswise done in the forward) - (ScalingGranularity.AXISWISE, ScalingGranularity.TENSORWISE, False, False), - # grad_weight = input_t @ grad_output, in high precision (bfloat16) - # input_t: high precision - # grad_output: high precision - (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE, True, True), -] - -scaling_granularities_by_gemm_all_tensorwise = [ - (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE, False, False), - (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE, False, False), - (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE, False, False), -] - -scaling_granularities_by_gemm_all_axiswise = [ - (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE, False, False), - (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE, False, False), - (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE, False, False), -] - -# scaling granularity and keep_in_original_precision to test by gemm arguments in this -# order: output, grad_input, grad_weight -scaling_granularities_by_gemm = [ - # TODO(before 
land): move this last - scaling_granularities_by_gemm_lcw_recipe, - # scaling_granularities_by_gemm_all_tensorwise, - # scaling_granularities_by_gemm_all_axiswise, -] def get_test_float8_linear_config( scaling_type_input, From 10f2e0fbad3066f3489814005a78729729f58ecf Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 27 Sep 2024 15:55:49 -0700 Subject: [PATCH 20/32] Update [ghstack-poisoned] --- test/float8/test_numerics_integration.py | 124 +++++++++-------------- torchao/float8/float8_utils.py | 14 +-- 2 files changed, 53 insertions(+), 85 deletions(-) diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index 07fcddaad6..2ec4d3999f 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -24,6 +24,7 @@ Float8LinearConfig, ScalingType, ScalingGranularity, + _get_recipe, ) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, @@ -31,6 +32,7 @@ sync_float8_amax_and_scale_history, ) from torchao.float8.float8_utils import compute_error, IS_ROCM +from torchao.testing.float8.test_utils import get_test_float8_linear_config is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) @@ -84,44 +86,9 @@ def init_weights(self, init_std: float): class TestFloat8NumericsIntegrationTest: - @pytest.mark.parametrize( - "scaling_type_input", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], - ) - @pytest.mark.parametrize( - "scaling_type_weight", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], - ) - @pytest.mark.parametrize( - "scaling_type_grad_output", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], - ) - @pytest.mark.parametrize( - "scaling_granularity", - [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE], - ) - @pytest.mark.skipif(not is_cuda_8_9, reason="requires SM89 compatible machine") - @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") - def test_encoder_fw_bw( - self, - scaling_type_input: ScalingType, - scaling_type_weight: ScalingType, - scaling_type_grad_output: ScalingType, - scaling_granularity: ScalingGranularity, - ): - # TODO(later): maybe add float16 back if it becomes important - data_dtype = torch.bfloat16 - - if scaling_granularity is ScalingGranularity.AXISWISE: - if ( - scaling_type_input != ScalingType.DYNAMIC or - scaling_type_weight != ScalingType.DYNAMIC or - scaling_type_grad_output != ScalingType.DYNAMIC or - data_dtype != torch.bfloat16 or - (not is_cuda_9_0) - ): - pytest.skip() + def _test_impl(self, config: Float8LinearConfig) -> None: + data_dtype = torch.bfloat16 # LLaMa 3 70B shapes model_ref = ( FeedForward( @@ -137,44 +104,6 @@ def test_encoder_fw_bw( # for now just test the encoder to simplify things model_fp8 = copy.deepcopy(model_ref) - if scaling_type_input is ScalingType.STATIC: - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - ) - if scaling_type_weight is ScalingType.STATIC: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - 
scaling_granularity=scaling_granularity, - ) - if scaling_type_grad_output is ScalingType.STATIC: - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity, - ) - - config = Float8LinearConfig( - cast_config_input=cast_config_input, - cast_config_weight=cast_config_weight, - cast_config_grad_output=cast_config_grad_output, - ) - convert_to_float8_training( model_fp8, config=config, @@ -212,9 +141,9 @@ def test_encoder_fw_bw( out_sqnr = compute_error(model_ref_out, model_fp8_out) any_static_scaling = ( - scaling_type_input is ScalingType.STATIC - or scaling_type_weight is ScalingType.STATIC - or scaling_type_grad_output is ScalingType.STATIC + config.cast_config_input.scaling_type is ScalingType.STATIC + or config.cast_config_weight.scaling_type is ScalingType.STATIC + or config.cast_config_grad_output.scaling_type is ScalingType.STATIC ) if any_static_scaling: assert out_sqnr > 10.0 @@ -236,6 +165,45 @@ def test_encoder_fw_bw( sqnr = compute_error(ref_grad, cur_grad) assert sqnr > grad_sqnr_threshold + @pytest.mark.parametrize( + "scaling_type_input", + [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + ) + @pytest.mark.parametrize( + "scaling_type_weight", + [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + ) + @pytest.mark.parametrize( + "scaling_type_grad_output", + [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + ) + @pytest.mark.skipif(not is_cuda_8_9, reason="requires SM89 compatible machine") + @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") + def test_encoder_fw_bw_from_config_params( + self, + scaling_type_input: ScalingType, + scaling_type_weight: ScalingType, + scaling_type_grad_output: ScalingType, + ): + config = get_test_float8_linear_config( + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + emulate=False, + ) + self._test_impl(config) + + @pytest.mark.parametrize( + "recipe_name", + ["all_axiswise", "lw_axiswise_with_gw_hp"], + ) + def test_encoder_fw_bw_from_recipe( + self, + recipe_name: str, + ): + config = _get_recipe(recipe_name) + self._test_impl(config) + if __name__ == "__main__": pytest.main([__file__]) diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index d4ac3950df..812ba776b5 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -100,9 +100,9 @@ def amax_history_to_scale_stack( @torch.no_grad() def tensor_to_amax( - x: torch.Tensor, - reduce_amax: bool = False, - device_mesh = None, + x: torch.Tensor, + reduce_amax: bool = False, + device_mesh=None, scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, axiswise_dim: Optional[int] = None, ) -> torch.Tensor: @@ -133,10 +133,10 @@ def tensor_to_scale( axiswise_dim: Optional[int] = None, ) -> torch.Tensor: amax = tensor_to_amax( - x, - reduce_amax, - device_mesh, - scaling_granularity, + x, + reduce_amax, + device_mesh, + scaling_granularity, axiswise_dim, ) return amax_to_scale(amax, float8_dtype, x.dtype) From 4437054cb3afe325a6e68bca58224d23421633fc Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 30 Sep 2024 12:39:29 -0700 Subject: [PATCH 21/32] Update [ghstack-poisoned] --- benchmarks/float8/profile_linear_float8.py | 4 ++-- test/float8/test_base.py | 4 ++-- test/float8/test_compile.py | 4 ++-- 
test/float8/test_numerics_integration.py | 4 ++-- torchao/float8/config.py | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index 136a8b6921..a110cfcf8c 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -27,7 +27,7 @@ Float8LinearConfig, ScalingType, ScalingGranularity, - _get_recipe, + _recipe_name_to_linear_config, ) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, @@ -320,7 +320,7 @@ def main( ) elif recipe_override is not None: - config = _get_recipe(recipe_override) + config = _recipe_name_to_linear_config(recipe_override) scaling_repr = "_".join( [ diff --git a/test/float8/test_base.py b/test/float8/test_base.py index fa4314aed2..4154148716 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -28,7 +28,7 @@ Float8LinearConfig, ScalingGranularity, ScalingType, - _get_recipe, + _recipe_name_to_linear_config, ) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( @@ -403,7 +403,7 @@ def test_linear_from_recipe( linear_dtype = torch.bfloat16 x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) - config = _get_recipe(recipe_name) + config = _recipe_name_to_linear_config(recipe_name) self._test_linear_impl( x, m_ref, diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 63ab0dd277..8d936aea17 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -23,7 +23,7 @@ CastConfig, Float8LinearConfig, ScalingType, - _get_recipe, + _recipe_name_to_linear_config, ) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( @@ -197,7 +197,7 @@ def test_inductor_from_config_params( @unittest.skipIf(not is_H100, "CUDA with capability 9.0 or greater not available") def test_inductor_from_recipe(recipe_name): torch._dynamo.reset() - config = _get_recipe(recipe_name) + config = _recipe_name_to_linear_config(recipe_name) fullgraph = True dtype = torch.bfloat16 _test_compile_base( diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index 2ec4d3999f..882aeb4922 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -24,7 +24,7 @@ Float8LinearConfig, ScalingType, ScalingGranularity, - _get_recipe, + _recipe_name_to_linear_config, ) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, @@ -201,7 +201,7 @@ def test_encoder_fw_bw_from_recipe( self, recipe_name: str, ): - config = _get_recipe(recipe_name) + config = _recipe_name_to_linear_config(recipe_name) self._test_impl(config) diff --git a/torchao/float8/config.py b/torchao/float8/config.py index 234a9c01e8..9f4435c0e5 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -224,7 +224,7 @@ def __post_init__(self): # Pre-made recipes for common configurations # TODO(future PR): go through a round of design on this, and eventually expose # as a top level public API. 
-def _get_recipe(recipe_name: str) -> Float8LinearConfig: +def _recipe_name_to_linear_config(recipe_name: str) -> Float8LinearConfig: if recipe_name == "all_tensorwise": # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel return Float8LinearConfig() From 31a017bcd2d3f54a0172ef303fbdef9fcfe0ef73 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 30 Sep 2024 12:56:40 -0700 Subject: [PATCH 22/32] Update [ghstack-poisoned] --- benchmarks/float8/profile_linear_float8.py | 14 ++++++-------- test/float8/test_base.py | 3 ++- test/float8/test_compile.py | 3 ++- test/float8/test_numerics_integration.py | 3 ++- torchao/float8/config.py | 21 +++++++++++++++++---- 5 files changed, 29 insertions(+), 15 deletions(-) diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index a110cfcf8c..a42d4467c7 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -27,6 +27,7 @@ Float8LinearConfig, ScalingType, ScalingGranularity, + _Float8LinearRecipeName, _recipe_name_to_linear_config, ) from torchao.float8.float8_linear_utils import ( @@ -34,10 +35,6 @@ linear_requires_sync, sync_float8_amax_and_scale_history, ) -from torchao.testing.float8.test_utils import ( - scaling_granularities_by_gemm_lcw_recipe, - get_test_float8_linear_config, -) from torch.profiler import profile, ProfilerActivity, record_function from utils import ( kernel_name_to_category, @@ -263,7 +260,7 @@ def main( scaling_type_weight: str = "dynamic", scaling_type_grad_output: str = "dynamic", scaling_granularity: str = "tensorwise", - recipe_override: Optional[str] = None, + recipe_name: Optional[str] = None, model_type: str = "linear", dtype_filter: str = "both", add_inductor_metadata_to_trace: bool = True, @@ -277,7 +274,7 @@ def main( scaling_type_grad_output = ScalingType(scaling_type_grad_output) scaling_granularity = ScalingGranularity(scaling_granularity) - if recipe_override is None: + if recipe_name is None: if scaling_type_input is ScalingType.STATIC: cast_config_input=CastConfig( @@ -319,8 +316,9 @@ def main( cast_config_grad_output=cast_config_grad_output, ) - elif recipe_override is not None: - config = _recipe_name_to_linear_config(recipe_override) + elif recipe_name is not None: + recipe_name = _Float8LinearRecipeName(recipe_name) + config = _recipe_name_to_linear_config(recipe_name) scaling_repr = "_".join( [ diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 4154148716..30a3755d04 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -28,6 +28,7 @@ Float8LinearConfig, ScalingGranularity, ScalingType, + _Float8LinearRecipeName, _recipe_name_to_linear_config, ) from torchao.float8.float8_linear import Float8Linear @@ -383,7 +384,7 @@ def test_linear_from_config_params( # TODO(future PR): make this cleaner. 
@pytest.mark.parametrize( "recipe_name", - ["all_axiswise", "lw_axiswise_with_gw_hp"], + [_Float8LinearRecipeName.ALL_AXISWISE, _Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP], ) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize("linear_bias", [True, False]) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 8d936aea17..3e0aabd166 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -23,6 +23,7 @@ CastConfig, Float8LinearConfig, ScalingType, + _Float8LinearRecipeName, _recipe_name_to_linear_config, ) from torchao.float8.float8_linear import Float8Linear @@ -192,7 +193,7 @@ def test_inductor_from_config_params( # TODO(future PR): make this cleaner. @pytest.mark.parametrize( "recipe_name", - ["all_axiswise", "lw_axiswise_with_gw_hp"], + [_Float8LinearRecipeName.ALL_AXISWISE, _Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP], ) @unittest.skipIf(not is_H100, "CUDA with capability 9.0 or greater not available") def test_inductor_from_recipe(recipe_name): diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index 882aeb4922..6cfb860af7 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -24,6 +24,7 @@ Float8LinearConfig, ScalingType, ScalingGranularity, + _Float8LinearRecipeName, _recipe_name_to_linear_config, ) from torchao.float8.float8_linear_utils import ( @@ -195,7 +196,7 @@ def test_encoder_fw_bw_from_config_params( @pytest.mark.parametrize( "recipe_name", - ["all_axiswise", "lw_axiswise_with_gw_hp"], + [_Float8LinearRecipeName.ALL_AXISWISE, _Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP], ) def test_encoder_fw_bw_from_recipe( self, diff --git a/torchao/float8/config.py b/torchao/float8/config.py index 9f4435c0e5..55b93f5687 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -224,12 +224,25 @@ def __post_init__(self): # Pre-made recipes for common configurations # TODO(future PR): go through a round of design on this, and eventually expose # as a top level public API. 
-def _recipe_name_to_linear_config(recipe_name: str) -> Float8LinearConfig: - if recipe_name == "all_tensorwise": +class _Float8LinearRecipeName(enum.Enum): + ALL_TENSORWISE = "all_tensorwise" + ALL_AXISWISE = "all_axiswise" + LW_AXISWISE_WITH_GW_HP = "lw_axiswise_with_gw_hp" + + +def _recipe_name_to_linear_config( + recipe_name: _Float8LinearRecipeName, +) -> Float8LinearConfig: + """ + Input: `_Float8LinearRecipeName` value + Output: a `Float8LinearConfig` configured to implement the recipe + """ + + if recipe_name is _Float8LinearRecipeName.ALL_TENSORWISE: # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel return Float8LinearConfig() - elif recipe_name == "all_axiswise": + elif recipe_name is _Float8LinearRecipeName.ALL_AXISWISE: # dynamic axiswise scaling with the CUTLASS rowwise kernel cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) @@ -252,7 +265,7 @@ def _recipe_name_to_linear_config(recipe_name: str) -> Float8LinearConfig: gemm_config_grad_weight=gc_gw, ) - elif recipe_name == "lw_axiswise_with_gw_hp": + elif recipe_name is _Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP: # lw's recipe for a modification on all-axiswise: # From 743b4c185192732795f1037c2f6f700088eec5bc Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 30 Sep 2024 12:57:47 -0700 Subject: [PATCH 23/32] Update [ghstack-poisoned] --- torchao/float8/config.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/torchao/float8/config.py b/torchao/float8/config.py index 55b93f5687..0187ec134d 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -195,15 +195,6 @@ def __post_init__(self): cc_w_gi = self.cast_config_weight_for_grad_input cc_go_gw = self.cast_config_grad_output_for_grad_weight - # for now, we only have gemm kernels where both operands are scaled with the same - # granularity. In the future this may be relaxed. - assert cc_i.scaling_granularity == cc_w.scaling_granularity, \ - "incompatible scaling granularity for output" - # assert cc_go.scaling_granularity == cc_w_gi.scaling_granularity, \ - # "incompatible scaling granularity for grad_input" - assert cc_i_gw.scaling_granularity == cc_go_gw.scaling_granularity, \ - "incompatible scaling granularity for grad_weight" - # for now, we only have gemm kernels where both operands are either both # in high precision, or both in float8. In the future, this may be relaxed. # TODO(future): make the float8 check more precise with the specific dtypes. 
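For reference, the "all_axiswise" recipe introduced in the patches above corresponds roughly to building the config by hand as sketched below. This is a minimal sketch using only keyword arguments visible in the diffs; the recipe function additionally sets the per-gemm `gemm_config_*` fields (e.g. `use_fast_accum`), which are omitted here, and `CastConfig` defaults to dynamic scaling when `scaling_type` is not passed:

    from torchao.float8.config import (
        CastConfig,
        Float8LinearConfig,
        ScalingGranularity,
    )

    # dynamic axiswise scaling for input, weight and grad_output, mirroring
    # the ALL_AXISWISE branch of the recipe function above (gemm kernel
    # settings omitted)
    config = Float8LinearConfig(
        cast_config_input=CastConfig(scaling_granularity=ScalingGranularity.AXISWISE),
        cast_config_weight=CastConfig(scaling_granularity=ScalingGranularity.AXISWISE),
        cast_config_grad_output=CastConfig(scaling_granularity=ScalingGranularity.AXISWISE),
    )
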
From fc8d4efc031c16453f8430cbb7a172a1121a819a Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 2 Oct 2024 16:06:33 -0700 Subject: [PATCH 24/32] Update [ghstack-poisoned] --- torchao/ops.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/torchao/ops.py b/torchao/ops.py index 99e19dbbd4..79c02dfd85 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -5,11 +5,10 @@ lib = torch.library.Library("torchao", "FRAGMENT") -# TODO(before land): undo this, this is to work around https://github.com/pytorch/ao/issues/991 -# lib.define("quant_llm_linear(int EXPONENT, int MANTISSA, Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor") -# lib.define("unpack_tensor_core_tiled_layout(Tensor packed_w, int inner_k_tiles) -> Tensor") -# lib.define("dequantize_tensor_core_tiled_layout(Tensor packed_w, Tensor scales_and_zeros, int group_size, int inner_k_tiles) -> Tensor") -# lib.define("marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor") +lib.define("quant_llm_linear(int EXPONENT, int MANTISSA, Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor") +lib.define("unpack_tensor_core_tiled_layout(Tensor packed_w, int inner_k_tiles) -> Tensor") +lib.define("dequantize_tensor_core_tiled_layout(Tensor packed_w, Tensor scales_and_zeros, int group_size, int inner_k_tiles) -> Tensor") +lib.define("marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor") def register_custom_op(name): From ac6f768ba2fbc189381967e35bd0796d60256a06 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 2 Oct 2024 16:09:58 -0700 Subject: [PATCH 25/32] Update [ghstack-poisoned] --- torchao/float8/float8_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index acb814404c..b6f42c5081 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -8,9 +8,9 @@ import torch import torch.distributed as dist -from torchao.float8.config import ScalingGranularity import torchao.float8.config as config +from torchao.float8.config import ScalingGranularity # Helpful visualizer for debugging (only supports fp32): # https://www.h-schmidt.net/FloatConverter/IEEE754.html From 4bb59a6b083f2d561726344320510235da5dae46 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 4 Oct 2024 10:11:55 -0700 Subject: [PATCH 26/32] Update [ghstack-poisoned] --- benchmarks/float8/profile_linear_float8.py | 48 +++------------------- 1 file changed, 6 insertions(+), 42 deletions(-) diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index a42d4467c7..d3b5b40823 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -35,6 +35,7 @@ linear_requires_sync, sync_float8_amax_and_scale_history, ) +from torchao.testing.float8.test_utils import get_test_float8_linear_config from torch.profiler import profile, ProfilerActivity, record_function from utils import ( kernel_name_to_category, @@ -259,7 +260,6 @@ def main( scaling_type_input: str = "dynamic", scaling_type_weight: str = "dynamic", scaling_type_grad_output: str = "dynamic", - scaling_granularity: str = "tensorwise", recipe_name: Optional[str] = None, model_type: str = "linear", dtype_filter: str = "both", @@ -272,50 +272,14 @@ def main( scaling_type_input = 
ScalingType(scaling_type_input) scaling_type_weight = ScalingType(scaling_type_weight) scaling_type_grad_output = ScalingType(scaling_type_grad_output) - scaling_granularity = ScalingGranularity(scaling_granularity) if recipe_name is None: - - if scaling_type_input is ScalingType.STATIC: - cast_config_input=CastConfig( - scaling_type=scaling_type_input, - static_scale=torch.tensor([1.0], device="cuda"), - scaling_granularity=scaling_granularity, - ) - else: - cast_config_input=CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - ) - if scaling_type_weight is ScalingType.STATIC: - cast_config_weight=CastConfig( - scaling_type=scaling_type_weight, - static_scale=torch.tensor([1.0], device="cuda"), - scaling_granularity=scaling_granularity, - ) - else: - cast_config_weight=CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity, - ) - if scaling_type_grad_output is ScalingType.STATIC: - cast_config_grad_output=CastConfig( - scaling_type=scaling_type_grad_output, - static_scale=torch.tensor([1.0], device="cuda"), - scaling_granularity=scaling_granularity, - ) - else: - cast_config_grad_output=CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity, - ) - - config = Float8LinearConfig( - cast_config_input=cast_config_input, - cast_config_weight=cast_config_weight, - cast_config_grad_output=cast_config_grad_output, + config = get_test_float8_linear_config( + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + emulate=False, ) - elif recipe_name is not None: recipe_name = _Float8LinearRecipeName(recipe_name) config = _recipe_name_to_linear_config(recipe_name) From c1c218f7c4cc6218daeca189d2cd8c948ac8e7e3 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 4 Oct 2024 10:37:46 -0700 Subject: [PATCH 27/32] Update [ghstack-poisoned] --- benchmarks/float8/profile_linear_float8.py | 8 ++++---- test/float8/test_base.py | 8 ++++---- test/float8/test_compile.py | 8 ++++---- test/float8/test_numerics_integration.py | 8 ++++---- torchao/float8/config.py | 20 ++++++++++---------- 5 files changed, 26 insertions(+), 26 deletions(-) diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index d3b5b40823..f4f2813a37 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -27,8 +27,8 @@ Float8LinearConfig, ScalingType, ScalingGranularity, - _Float8LinearRecipeName, - _recipe_name_to_linear_config, + Float8LinearRecipeName, + recipe_name_to_linear_config, ) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, @@ -281,8 +281,8 @@ def main( emulate=False, ) elif recipe_name is not None: - recipe_name = _Float8LinearRecipeName(recipe_name) - config = _recipe_name_to_linear_config(recipe_name) + recipe_name = Float8LinearRecipeName(recipe_name) + config = recipe_name_to_linear_config(recipe_name) scaling_repr = "_".join( [ diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 66d152f2dc..c383fbc873 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -28,8 +28,8 @@ Float8LinearConfig, ScalingGranularity, ScalingType, - _Float8LinearRecipeName, - _recipe_name_to_linear_config, + Float8LinearRecipeName, + recipe_name_to_linear_config, ) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( @@ -378,7 +378,7 @@ def test_linear_from_config_params( # TODO(future PR): make this cleaner. 
@pytest.mark.parametrize( "recipe_name", - [_Float8LinearRecipeName.ALL_AXISWISE, _Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP], + [Float8LinearRecipeName.ALL_AXISWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP], ) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize("linear_bias", [True, False]) @@ -398,7 +398,7 @@ def test_linear_from_recipe( linear_dtype = torch.bfloat16 x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) - config = _recipe_name_to_linear_config(recipe_name) + config = recipe_name_to_linear_config(recipe_name) self._test_linear_impl( x, m_ref, diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index c235685e6f..7c445f8803 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -23,8 +23,8 @@ CastConfig, Float8LinearConfig, ScalingType, - _Float8LinearRecipeName, - _recipe_name_to_linear_config, + Float8LinearRecipeName, + recipe_name_to_linear_config, ) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( @@ -200,12 +200,12 @@ def test_inductor_from_config_params( # TODO(future PR): make this cleaner. @pytest.mark.parametrize( "recipe_name", - [_Float8LinearRecipeName.ALL_AXISWISE, _Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP], + [Float8LinearRecipeName.ALL_AXISWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP], ) @unittest.skipIf(not is_H100, "CUDA with capability 9.0 or greater not available") def test_inductor_from_recipe(recipe_name): torch._dynamo.reset() - config = _recipe_name_to_linear_config(recipe_name) + config = recipe_name_to_linear_config(recipe_name) fullgraph = True dtype = torch.bfloat16 _test_compile_base( diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index 0cdd3a56dd..a91b784c85 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -24,8 +24,8 @@ Float8LinearConfig, ScalingType, ScalingGranularity, - _Float8LinearRecipeName, - _recipe_name_to_linear_config, + Float8LinearRecipeName, + recipe_name_to_linear_config, ) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, @@ -196,7 +196,7 @@ def test_encoder_fw_bw_from_config_params( @pytest.mark.parametrize( "recipe_name", - [_Float8LinearRecipeName.ALL_AXISWISE, _Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP], + [Float8LinearRecipeName.ALL_AXISWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP], ) @pytest.mark.skipif(not is_cuda_9_0, reason="requires SM90 compatible machine") @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") @@ -204,7 +204,7 @@ def test_encoder_fw_bw_from_recipe( self, recipe_name: str, ): - config = _recipe_name_to_linear_config(recipe_name) + config = recipe_name_to_linear_config(recipe_name) self._test_impl(config) diff --git a/torchao/float8/config.py b/torchao/float8/config.py index 37a017683d..2b0250c7ef 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -233,25 +233,25 @@ def __post_init__(self): # Pre-made recipes for common configurations # TODO(future PR): go through a round of design on this, and eventually expose # as a top level public API. 
-class _Float8LinearRecipeName(enum.Enum): +class Float8LinearRecipeName(enum.Enum): ALL_TENSORWISE = "all_tensorwise" ALL_AXISWISE = "all_axiswise" LW_AXISWISE_WITH_GW_HP = "lw_axiswise_with_gw_hp" -def _recipe_name_to_linear_config( - recipe_name: _Float8LinearRecipeName, +def recipe_name_to_linear_config( + recipe_name: Float8LinearRecipeName, ) -> Float8LinearConfig: """ - Input: `_Float8LinearRecipeName` value + Input: `Float8LinearRecipeName` value Output: a `Float8LinearConfig` configured to implement the recipe """ - if recipe_name is _Float8LinearRecipeName.ALL_TENSORWISE: + if recipe_name is Float8LinearRecipeName.ALL_TENSORWISE: # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel return Float8LinearConfig() - elif recipe_name is _Float8LinearRecipeName.ALL_AXISWISE: + elif recipe_name is Float8LinearRecipeName.ALL_AXISWISE: # dynamic axiswise scaling with the CUTLASS rowwise kernel cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) @@ -274,7 +274,7 @@ def _recipe_name_to_linear_config( gemm_config_grad_weight=gc_gw, ) - elif recipe_name is _Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP: + elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP: # lw's recipe for a modification on all-axiswise: # @@ -284,9 +284,9 @@ def _recipe_name_to_linear_config( # # key characteristics: # * increased accuracy for grad_weight - # * `output` and `weight` now only need to be scaled axiswise across a - # single dim compared to vanilla all-axiswise, which is more - # amenable to fast kernels + # * `input`, `weight` and `grad_output` now only need to be scaled + # axiswise across a single dim compared to vanilla all-axiswise, + # which is more amenable to fast kernels # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1 cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) From 024fe94fd86a94f43be1f0dd62201bd92a313dcb Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 4 Oct 2024 11:19:04 -0700 Subject: [PATCH 28/32] Update [ghstack-poisoned] --- torchao/float8/config.py | 15 ++++++--- torchao/float8/float8_linear.py | 51 +++++++++++++---------------- torchao/float8/float8_utils.py | 58 --------------------------------- 3 files changed, 33 insertions(+), 91 deletions(-) diff --git a/torchao/float8/config.py b/torchao/float8/config.py index 2b0250c7ef..bb0ed3ad76 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -103,7 +103,7 @@ class Float8GemmConfig: use_fast_accum: bool = False -@dataclass(frozen=False) +@dataclass(frozen=True) class Float8LinearConfig: """ Configuration for converting a `torch.nn.Linear` module to float8 @@ -192,13 +192,18 @@ class Float8LinearConfig: force_recompute_fp8_weight_in_bwd: bool = False def __post_init__(self): - # populate the additional cast overrides, if the user did not specify them + # Populate the additional cast overrides, if the user did not specify them + # Note: this hacks around the frozen-ness of this dataclass + # by using `object.__setattr__`. This is fine, as what we really need + # is for this object to be frozen after `__post_init__` for torch.compile + # to work. 
+ # Source of hack: https://stackoverflow.com/a/65959419/ if self.cast_config_input_for_grad_weight is None: - self.cast_config_input_for_grad_weight = self.cast_config_input + object.__setattr__(self, "cast_config_input_for_grad_weight", self.cast_config_input) if self.cast_config_weight_for_grad_input is None: - self.cast_config_weight_for_grad_input = self.cast_config_weight + object.__setattr__(self, "cast_config_weight_for_grad_input", self.cast_config_weight) if self.cast_config_grad_output_for_grad_weight is None: - self.cast_config_grad_output_for_grad_weight = self.cast_config_grad_output + object.__setattr__(self, "cast_config_grad_output_for_grad_weight", self.cast_config_grad_output) # float8 all-gather only supports tensorwise, in the future may support blockwise if self.cast_config_weight.scaling_granularity != ScalingGranularity.TENSORWISE: diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index e8a5bf4191..b99e5274d4 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -42,8 +42,6 @@ e5m2_dtype, tensor_to_amax, tensor_to_scale, - float8_linear_config_to_concise_casts_config, - Float8LinearConciseCastsConfig, ) from torchao.float8.fsdp_utils import ( @@ -131,15 +129,15 @@ def forward( input_hp: torch.Tensor, weight_hp_t: torch.Tensor, linear_mm_config: LinearMMConfig, - concise_casts_config: Float8LinearConciseCastsConfig, + config: Float8LinearConfig, ): ctx.save_for_backward(input_hp, weight_hp_t) ctx.linear_mm_config = linear_mm_config - ctx.concise_casts_config = concise_casts_config + ctx.config = config - c = concise_casts_config + c = config - if c.cc_i.orig_prec: + if c.cast_config_input.keep_in_original_precision: input_maybe_fp8 = input_hp else: input_maybe_fp8 = hp_tensor_to_float8_dynamic( @@ -147,11 +145,11 @@ def forward( e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.INPUT, - scaling_granularity=c.cc_i.sc_gr, - axiswise_dim=get_maybe_axiswise_dim(-1, c.cc_i.sc_gr), + scaling_granularity=c.cast_config_input.scaling_granularity, + axiswise_dim=get_maybe_axiswise_dim(-1, c.cast_config_input.scaling_granularity), ) - if c.cc_w.orig_prec: + if c.cast_config_weight.keep_in_original_precision: weight_maybe_fp8_t = weight_hp_t else: weight_maybe_fp8_t = hp_tensor_to_float8_dynamic( @@ -159,8 +157,8 @@ def forward( e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, - scaling_granularity=c.cc_w.sc_gr, - axiswise_dim=get_maybe_axiswise_dim(0, c.cc_w.sc_gr), + scaling_granularity=c.cast_config_weight.scaling_granularity, + axiswise_dim=get_maybe_axiswise_dim(0, c.cast_config_weight.scaling_granularity), ) # the reshapes are needed in order to make the shapes compatible with @@ -174,7 +172,7 @@ def forward( @staticmethod def backward(ctx, grad_output): input_hp, weight_hp_t = ctx.saved_tensors - c = ctx.concise_casts_config + c = ctx.config # the reshapes are needed in order to make the shapes compatible with # torch.mm @@ -187,7 +185,7 @@ def backward(ctx, grad_output): # calculate grad_input # - if c.cc_go.orig_prec: + if c.cast_config_grad_output.keep_in_original_precision: grad_output_reshaped_maybe_fp8_dim0 = grad_output_reshaped else: grad_output_reshaped_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic( @@ -195,11 +193,11 @@ def backward(ctx, grad_output): e5m2_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.GRAD_OUTPUT, - scaling_granularity=c.cc_go.sc_gr, - axiswise_dim=get_maybe_axiswise_dim(-1, c.cc_go.sc_gr), + 
scaling_granularity=c.cast_config_grad_output.scaling_granularity, + axiswise_dim=get_maybe_axiswise_dim(-1, c.cast_config_grad_output.scaling_granularity), ) - if c.cc_w_gi.orig_prec: + if c.cast_config_weight_for_grad_input.keep_in_original_precision: weight_t_maybe_fp8_dim0 = weight_hp_t else: # Note: we need https://github.com/pytorch/pytorch/issues/136267 @@ -211,8 +209,8 @@ def backward(ctx, grad_output): e4m3_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, - scaling_granularity=c.cc_w_gi.sc_gr, - axiswise_dim=get_maybe_axiswise_dim(-1, c.cc_w_gi.sc_gr), + scaling_granularity=c.cast_config_weight_for_grad_input.scaling_granularity, + axiswise_dim=get_maybe_axiswise_dim(-1, c.cast_config_weight_for_grad_input.scaling_granularity), ) grad_input = torch.mm( @@ -230,7 +228,7 @@ def backward(ctx, grad_output): # calculate grad_weight # - if c.cc_go_gw.orig_prec: + if c.cast_config_grad_output_for_grad_weight.keep_in_original_precision: grad_output_reshaped_maybe_fp8_dim1 = grad_output_reshaped else: grad_output_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic( @@ -238,11 +236,11 @@ def backward(ctx, grad_output): e5m2_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.GRAD_OUTPUT, - scaling_granularity=c.cc_go_gw.sc_gr, - axiswise_dim=get_maybe_axiswise_dim(0, c.cc_go_gw.sc_gr), + scaling_granularity=c.cast_config_grad_output_for_grad_weight.scaling_granularity, + axiswise_dim=get_maybe_axiswise_dim(0, c.cast_config_grad_output_for_grad_weight.scaling_granularity), ) - if c.cc_i_gw.orig_prec: + if c.cast_config_input_for_grad_weight.keep_in_original_precision: input_reshaped_maybe_fp8_dim1 = input_hp_reshaped else: input_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic( @@ -250,8 +248,8 @@ def backward(ctx, grad_output): e4m3_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, - scaling_granularity=c.cc_i_gw.sc_gr, - axiswise_dim=get_maybe_axiswise_dim(0, c.cc_i_gw.sc_gr), + scaling_granularity=c.cast_config_input_for_grad_weight.scaling_granularity, + axiswise_dim=get_maybe_axiswise_dim(0, c.cast_config_input_for_grad_weight.scaling_granularity), ) grad_weight = torch.mm( @@ -346,9 +344,6 @@ def __init__(self, *args, **kwargs): # would be initialized in every iteration. self.enable_pre_and_post_forward = self.config.enable_pre_and_post_forward - self.concise_casts_config: Float8LinearConciseCastsConfig = \ - float8_linear_config_to_concise_casts_config(self.config) - def create_buffers(self): # Default values for history buffers, see above TODO history_len = self.config.delayed_scaling_config.history_len @@ -596,7 +591,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: input, self.weight.t(), self.linear_mm_config, - self.concise_casts_config, + self.config, ) if self.bias is not None: diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 1ee0101ad1..a0c9c08ed4 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -264,61 +264,3 @@ def pad_tensor_for_matmul( pad_dim2 = dim2_aligned - dim2 return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1)) - - -# The code below introduces a bit of duplication with Float8LinearConfig in -# order to improve readability of the implementation of how Float8Linear -# uses the config. Specifically, we do two things: -# 1. wrap the relevant parts of configs in namedtuple, so we can pass -# them around in compile-friendly code. -# 2. 
make the tuple key names more brief, to make the implementation -# code less verbose (the code was so verbose that I felt the need -# to add this workaround). -# As I was writing this, it became less and less clear on why not just have -# a namedtuple as a top level config. Punting that to a future PR as -# that might be BC-breaking, but probably worth exploring. -# Note: I also think below is pretty hacky, it's good enough to unblock -# further prototyping, but IMO pretty important to clean up sooner rather -# than later. - - -class ConciseCastConfig(NamedTuple): - sc_tp: config.ScalingType - sc_gr: config.ScalingGranularity - st_sc: Optional[torch.Tensor] - orig_prec: bool - - @classmethod - def from_cast_config(cls, c: config.CastConfig): - return cls( - sc_tp=c.scaling_type, - sc_gr=c.scaling_granularity, - st_sc=c.static_scale, - orig_prec=c.keep_in_original_precision, - ) - - -class Float8LinearConciseCastsConfig(NamedTuple): - cc_i: ConciseCastConfig - cc_w: ConciseCastConfig - cc_go: ConciseCastConfig - cc_i_gw: ConciseCastConfig - cc_w_gi: ConciseCastConfig - cc_go_gw: ConciseCastConfig - - -def float8_linear_config_to_concise_casts_config( - c: config.Float8LinearConfig, -) -> Float8LinearConciseCastsConfig: - concise_config = Float8LinearConciseCastsConfig( - cc_i=ConciseCastConfig.from_cast_config(c.cast_config_input), - cc_w=ConciseCastConfig.from_cast_config(c.cast_config_weight), - cc_go=ConciseCastConfig.from_cast_config(c.cast_config_grad_output), - cc_i_gw=ConciseCastConfig.from_cast_config(c.cast_config_input_for_grad_weight), - cc_w_gi=ConciseCastConfig.from_cast_config(c.cast_config_weight_for_grad_input), - cc_go_gw=ConciseCastConfig.from_cast_config( - c.cast_config_grad_output_for_grad_weight - ), - ) - - return concise_config From 402769458b889efb21ec0b4bab0ec838e61cbe44 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 4 Oct 2024 12:11:25 -0700 Subject: [PATCH 29/32] Update [ghstack-poisoned] --- torchao/float8/config.py | 32 ++++++++++++++++++-------------- torchao/float8/float8_linear.py | 14 ++++++++------ 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/torchao/float8/config.py b/torchao/float8/config.py index bb0ed3ad76..f9c0373d8d 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -111,23 +111,28 @@ class Float8LinearConfig: """ # - # Per-tensor configuration for `input`, `weight`, `grad_output` + # Per-tensor configuration for casting of `input`, `weight`, `grad_output` + # for the operands of gemms calculating `output`, `grad_weight`, and `grad_input`. # - cast_config_input: CastConfig = CastConfig() - cast_config_weight: CastConfig = CastConfig() - cast_config_grad_output: CastConfig = CastConfig() - - # - # Optional per-tensor configuration for `input`, `weight`, `grad_output` to - # calculate `grad_weight`, `grad_input`, and `grad_weight` respectively. - # If not specified, then the configuration from `cast_config_input`, - # `cast_config_weight` and `cast_config_grad_output`, respectively, is reused. - # TODO(future PR): maybe rename `cast_config_input` to - # `cast_config_input_for_output`, etc, to make the names consistent, - # will be BC-breaking. + # Note: + # 1. if `cast_config_input_for_grad_weight` is None, then + # `cast_config_input` is used for scaling `input` for both gemms that + # use `input. + # 2. if `cast_config_input_for_grad_weight` is specified, then + # a. `cast_config_input` is used for scaling `input` for the gemm that calculates + # `output` + # b. 
`cast_config_input_for_grad_weight` is used for scaling `input` for + # the gemm that calculates `grad_weight` + # 3. the same behavior holds for `cast_config_weight` and `cast_config_grad_output`. # + # `input` + cast_config_input: CastConfig = CastConfig() cast_config_input_for_grad_weight: Optional[CastConfig] = None + # `weight` + cast_config_weight: CastConfig = CastConfig() cast_config_weight_for_grad_input: Optional[CastConfig] = None + # `grad_output` + cast_config_grad_output: CastConfig = CastConfig() cast_config_grad_output_for_grad_weight: Optional[CastConfig] = None # @@ -326,5 +331,4 @@ def recipe_name_to_linear_config( ) else: - # TODO(before land): make recipe_name an enum and tell users what the options are raise AssertionError(f"unknown recipe_name {recipe_name}") diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index b99e5274d4..ef2f30d276 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -554,14 +554,16 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if self.has_any_delayed_scaling: self.float8_pre_forward(input) - # TODO(this PR): reuse with config, make a property - has_all_axiswise_scaling = ( - self.config.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE and - self.config.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE and - self.config.cast_config_grad_output.scaling_granularity is ScalingGranularity.AXISWISE + has_any_axiswise_scaling = ( + self.config.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE or + self.config.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE or + self.config.cast_config_grad_output.scaling_granularity is ScalingGranularity.AXISWISE or + self.config.cast_config_input_for_grad_weight.scaling_granularity is ScalingGranularity.AXISWISE or + self.config.cast_config_weight_for_grad_input.scaling_granularity is ScalingGranularity.AXISWISE or + self.config.cast_config_grad_output_for_grad_weight.scaling_granularity is ScalingGranularity.AXISWISE ) - if not has_all_axiswise_scaling: + if not has_any_axiswise_scaling: input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized) # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight, # weight_scale should be saved. From 076de91110aa7d5663d8ebabad7161b042674d43 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 4 Oct 2024 12:29:43 -0700 Subject: [PATCH 30/32] Update [ghstack-poisoned] --- torchao/float8/float8_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index a0c9c08ed4..b6f42c5081 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Iterable, Literal, NamedTuple, Optional, Tuple, Union +from typing import Iterable, Literal, Optional, Tuple, Union import torch import torch.distributed as dist From ca127f01b44c13bb8dbef71b5bdec4a19d223167 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 4 Oct 2024 12:53:55 -0700 Subject: [PATCH 31/32] Update [ghstack-poisoned] --- test/float8/test_base.py | 2 +- torchao/float8/config.py | 5 +++++ torchao/float8/float8_linear.py | 31 +++++++++++++------------------ 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index c383fbc873..478f89bf49 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -508,7 +508,7 @@ def test_repr(self): config=config, ) s = m.__repr__() - assert "i:dyn,w:del,go:dyn" in s + assert "i:dyn_ten,w:del_ten,go:dyn_ten" in s @unittest.skipIf(not is_cuda_8_9, "CUDA 8.9 not available") def test_inference_mode(self): diff --git a/torchao/float8/config.py b/torchao/float8/config.py index f9c0373d8d..17d8793429 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -59,6 +59,11 @@ class CastConfig: # TODO(ideally before this PR lands): a better name for this keep_in_original_precision: bool = False + def short_str(self): + if self.keep_in_original_precision: + return "orig_prec" + return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}" + def __post_init__(self): if self.scaling_type is ScalingType.STATIC: assert ( diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index ef2f30d276..4838f170e7 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -603,25 +603,20 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.float8_post_forward() return output - def scaling_type_repr(self): - # add scaling type settings without using too many characters - # example: "i:del,w:del,go:dyn" - return f"i:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},go:{self.scaling_type_grad_output.short_str()}" - - def scaling_granularity_repr(self): - # add scaling granularity settings without using too many characters - # example: "i:ten,w:ten,g:ten" or "i:axs,w:axs,g:axs" - c = self.config - gi = c.cast_config_input.scaling_granularity.short_str() - gw = c.cast_config_weight.scaling_granularity.short_str() - ggo = c.cast_config_grad_output.scaling_granularity.short_str() - gi2 = c.cast_config_input_for_grad_weight.scaling_granularity.short_str() - gw2 = c.cast_config_weight_for_grad_input.scaling_granularity.short_str() - ggo2 = c.cast_config_grad_output_for_grad_weight.scaling_granularity.short_str() - return f"i:{gi},w:{gw},go:{ggo},i2:{gi2},w2:{gw2},go2:{ggo2}" - def extra_repr(self): - s = f'{super().extra_repr()}, scaling_type="{self.scaling_type_repr()}", scaling_granularity="{self.scaling_granularity_repr()}"' + c = self.config + ci = f"i:{c.cast_config_input.short_str()}" + cw = f"w:{c.cast_config_weight.short_str()}" + cgo = f"go:{c.cast_config_grad_output.short_str()}" + parts = [ci, cw, cgo] + if c.cast_config_input_for_grad_weight != c.cast_config_input: + parts.append(f"i_gw:{c.cast_config_input_for_grad_weight.short_str()}") + if c.cast_config_weight_for_grad_input != c.cast_config_weight: + parts.append(f"w_gi:{c.cast_config_weight_for_grad_input.short_str()}") + if c.cast_config_grad_output_for_grad_weight != c.cast_config_grad_output: + parts.append(f"go_gw:{c.cast_config_grad_output_for_grad_weight.short_str()}") + cast_config_str = 
",".join(parts) + s = f'{super().extra_repr()}, cast_configs={cast_config_str}"' return s @classmethod From 712fd5d212a9cf85ef4e63f3550ba44e29d67a5f Mon Sep 17 00:00:00 2001 From: vasiliy Date: Sat, 5 Oct 2024 13:27:53 -0700 Subject: [PATCH 32/32] Update [ghstack-poisoned] --- benchmarks/float8/float8_roofline.py | 53 ++++++++++++++++++++++------ 1 file changed, 43 insertions(+), 10 deletions(-) diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py index 2f04b8ee8d..19c6cc21bc 100644 --- a/benchmarks/float8/float8_roofline.py +++ b/benchmarks/float8/float8_roofline.py @@ -70,6 +70,7 @@ ScalingType, CastConfig, ) +from torchao.float8.config import recipe_name_to_linear_config, Float8LinearRecipeName class LNLinearSigmoid(torch.nn.Module): @@ -129,6 +130,8 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None): else: # cache does not exist yet, create it cache = dict() + else: + cache = dict() key = f"{M},{K},{N},{fast_accum}" if key in cache: return cache[key] @@ -153,13 +156,18 @@ def do_matmul(A, B): ) f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B) + scale_a = torch.ones(M, 1, device=device) + scale_b = torch.ones(1, N, device=device) + fast_accum = True # for axiswise + f8_axs_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B) + # save to cache if needed if cache_filename is not None: - cache[key] = [bf16_time_s, f8_time_s] + cache[key] = [bf16_time_s, f8_time_s, f8_axs_time_s] with open(cache_filename, 'w') as f: json.dump(cache, f) - return bf16_time_s, f8_time_s + return bf16_time_s, f8_time_s, f8_axs_time_s def run( outfile: str, @@ -231,13 +239,15 @@ def run( headers = [ 'fwd_M', 'fwd_K', 'fwd_N', # gemm microbenchmarks - 'bf16_gemm_s', 'fp8_gemm_s', + 'bf16_gemm_s', 'fp8_gemm_s', 'fp8_axs_gemm_time_s', # roofline memory overhead estimates 'fp8_oh_dyn_limit', 'fp8_oh_dyn_nolimit', 'fp8_oh_del_limit', 'fp8_oh_del_nolimit', # actual e2e measurements - 'bf16_e2e_s', 'fp8_dyn_e2e_s', 'fp8_del_e2e_s', - 'fp8_dyn_speedup', 'fp8_del_speedup', + 'bf16_s', 'fp8_dyn_s', 'fp8_del_s', 'fp8_dyn_axs_s', + # 'fp8_lw_s', + 'fp8_dyn_sp', 'fp8_del_sp', 'fp8_dyn_axs_sp', + # 'fp8_lw_sp', ] results = [] @@ -248,15 +258,18 @@ def run( break if gemm_time_strategy == "benchmarks": - bf16_g1, f8_g1 = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename) - bf16_g2, f8_g2 = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename) - bf16_g3, f8_g3 = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename) + bf16_g1, f8_g1, f8_g1_axs = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename) + bf16_g2, f8_g2, f8_g2_axs = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename) + bf16_g3, f8_g3, f8_g3_axs = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename) bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3 fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3 + fp8_axs_gemm_time_s = f8_g1_axs + f8_g2_axs + f8_g3_axs else: assert gemm_time_strategy == "roofline", "unsupported" bf16_time_val = bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val) fp8_gemm_time_s = fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val) + # for now, assume axiswise gemm is similar to tensorwise + fp8_axs_gemm_time_s = fp8_gemm_time_s fp8_mem_time_dyn_limit_s = \ fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val) @@ -291,14 +304,30 @@ def run( cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), ) - m_fp8_del = 
convert_to_float8_training(m_orig) + m_fp8_del = convert_to_float8_training(copy.deepcopy(m_orig), config=config) m_fp8_del = torch.compile(m_fp8_del) fp8_del_time_actual_s = get_gpu_kernel_time(m_fp8_del, x) + # get the float8 dynamic axiswise scaling gpu kernel time + torch._dynamo.reset() + config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE) + m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config) + m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs) + fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x) + + # get the lw recipe scaling gpu kernel time + # TODO(future PR): enable below once basic performance issues + # are fixed + # torch._dynamo.reset() + # config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP) + # m_fp8_lw = convert_to_float8_training(m_orig, config=config) + # m_fp8_lw = torch.compile(m_fp8_lw) + # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x) + results.append([ M_val, K_val, N_val, # gemm microbenchmarks - bf16_time_val, fp8_gemm_time_s, + bf16_time_val, fp8_gemm_time_s, fp8_axs_gemm_time_s, # roofline overhead estimates fp8_mem_time_dyn_limit_s, fp8_mem_time_dyn_nolimit_s, @@ -306,8 +335,12 @@ def run( fp8_mem_time_del_nolimit_s, # e2e numbers bf16_time_actual_s, fp8_dyn_time_actual_s, fp8_del_time_actual_s, + fp8_dyn_axs_time_actual_s, + # fp8_lw_time_actual_s, bf16_time_actual_s / fp8_dyn_time_actual_s, bf16_time_actual_s / fp8_del_time_actual_s, + bf16_time_actual_s / fp8_dyn_axs_time_actual_s, + # bf16_time_actual_s / fp8_lw_time_actual_s, ]) df = pd.DataFrame(results, columns=headers)
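
Putting the stack together, the intended end-user flow for the recipe API is roughly the sketch below. This is a minimal sketch rather than a definitive usage guide: the model, shapes and device are hypothetical placeholders, axiswise scaling is gated on recent GPUs in the tests above (compute capability 9.0), and only the names that appear in the diffs (`Float8LinearRecipeName`, `recipe_name_to_linear_config`, `convert_to_float8_training`) are assumed to exist:

    import copy

    import torch
    import torch.nn as nn

    from torchao.float8.config import (
        Float8LinearRecipeName,
        recipe_name_to_linear_config,
    )
    from torchao.float8.float8_linear_utils import convert_to_float8_training

    # hypothetical bfloat16 model; any module tree containing nn.Linear
    # children is converted the same way
    m = nn.Sequential(
        nn.Linear(2048, 4096, bias=False),
        nn.ReLU(),
        nn.Linear(4096, 2048, bias=False),
    ).to(device="cuda", dtype=torch.bfloat16)

    # map a pre-made recipe name to a Float8LinearConfig, then swap nn.Linear
    # modules for Float8Linear instances configured with that recipe
    config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
    m_fp8 = convert_to_float8_training(copy.deepcopy(m), config=config)
    m_fp8 = torch.compile(m_fp8)

    x = torch.randn(16, 2048, device="cuda", dtype=torch.bfloat16)
    y = m_fp8(x)
    y.sum().backward()

Swapping in `Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP` selects the recipe described in the config.py comments above, which keeps the `grad_weight` gemm in high precision for increased accuracy.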