add axiswise granularity to Float8Tensor #919

Merged · 9 commits · Oct 7, 2024
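
This PR adds a ScalingGranularity knob so that a high-precision tensor can be cast to Float8Tensor with one scale per slice along a chosen axis ("axiswise"), instead of a single tensorwise scale. A minimal usage sketch, adapted from the new tests in the diff below (the LinearMMConfig import location and the hardcoded float8 dtype are assumptions, since the corresponding imports are truncated in the diff):

```python
import torch

from torchao.float8.config import ScalingGranularity
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import LinearMMConfig  # import location assumed

a = torch.randn(8, 16, dtype=torch.bfloat16)

# cast with one scale per row: the last dim is reduced to size 1
# when the scales are computed (axiswise_dim=-1)
a_fp8 = hp_tensor_to_float8_dynamic(
    a,
    torch.float8_e4m3fn,  # the tests use e4m3_dtype; hardcoded here for brevity
    LinearMMConfig(),
    scaling_granularity=ScalingGranularity.AXISWISE,
    axiswise_dim=-1,
)

assert a_fp8._data.shape == (8, 16)
assert a_fp8._scale.shape == (8, 1)   # per-row scales, cf. test_axiswise_reshape
a_hp = a_fp8.to_original_precision()  # dequantize back to bfloat16
```

The axiswise cast itself runs on CPU in the tests; the axiswise gemm test (test_axiswise_gemm) is additionally gated on CUDA compute capability 9.0+, since the float8 matmul ultimately dispatches to torch._scaled_mm with rank-2 operands (see the review thread below).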
116 changes: 111 additions & 5 deletions test/float8/test_base.py
@@ -22,14 +22,20 @@
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.config import (
CastConfig,
Float8LinearConfig,
ScalingGranularity,
ScalingType,
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import (
convert_to_float8_training,
linear_requires_sync,
sync_float8_amax_and_scale_history,
)
from torchao.float8.float8_python_api import addmm_float8_unwrapped
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import (
Float8Tensor,
GemmInputRole,
@@ -51,14 +57,15 @@


is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
assert torch.all(a._scale == b._scale).item(), "scales are not identical"
assert torch.all(a._data == b._data).item(), "data is not identical"
return True


class TestFloat8Tensor(unittest.TestCase):
class TestFloat8Tensor:
def test_preserves_dtype(self) -> None:
# hp means high precision, lp means low precision
hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
@@ -68,7 +75,7 @@ def test_preserves_dtype(self) -> None:
x1_s = tensor_to_scale(x1_hp, lp_dtype)
x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
x3_hp = x2_lp.to_original_precision()
self.assertTrue(x3_hp.dtype == hp_dtype)
assert x3_hp.dtype == hp_dtype

def test_differentiable_casts(self) -> None:
lp_dtypes = (e4m3_dtype, e5m2_dtype)
@@ -103,7 +110,7 @@ def test_index_put(self):
fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn)
fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn)

with self.assertRaises(AssertionError):
with pytest.raises(AssertionError):
b[index] = fp8_a
fp8_b[index] = a
fp8_b_bad[index] = fp8_a
@@ -117,7 +124,7 @@ def test_copy_(self):
b = torch.empty(16, dtype=torch.bfloat16)
b.copy_(fp8_a) # Should work
torch.testing.assert_close(b, fp8_a.to_original_precision())
with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
fp8_a.copy_(b) # Should fail

fp8_b = Float8Tensor(
@@ -129,6 +136,105 @@ def test_copy_(self):
fp8_b.copy_(fp8_a)
torch.testing.assert_close(fp8_a._data, fp8_b._data)

@pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)])
@pytest.mark.parametrize("axiswise_dim", [0, -1])
def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
a = torch.randn(*shape, dtype=torch.bfloat16)
linear_mm_config = LinearMMConfig()
a_fp8 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=axiswise_dim,
Comment on lines +148 to +149 (Contributor):

Suggestion for another API: instead of an enum + extra params on a case-by-case basis, we could reuse the same idea that @drisspg used in the _scaled_mm operator: deduce the kind of scaling based on the size/shape of the desired scale tensor!

Concretely, we could add a single scale_shape=... parameter, which for row-wise would be [-1, 1], indicating that:

  • all columns (second dim) should be grouped and reduced into a single scaling factor (because the second element has a value of 1)
  • but that for the rows (first dim) there should be as many scaling factors as there are rows (because the first element has a value of -1, which gets replaced with the dim of the input tensor).

The scale shape is right-aligned to the shape of the tensor (thus following PyTorch's standard broadcast semantics), and then left-padded with 1 (again, standard semantics). This means that tensor-wise scaling is achieved with scale_shape=[].

Using this convention will later allow us to express block-wise scaling (e.g., 128x128), group-wise scaling (1x128) and maybe even column-wise scaling if that ever becomes a thing!
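
A small sketch of how the suggested convention could be interpreted (the scale_shape parameter and the helper below are hypothetical illustrations of this comment, not part of torchao or this PR; only the -1/1 cases are handled, and block sizes such as 128 would need extra reshaping):

```python
import torch

def scales_from_scale_shape(x: torch.Tensor, scale_shape):
    # Hypothetical helper: -1 keeps a dim (one scale per index), 1 reduces it.
    # scale_shape is right-aligned to x.shape and left-padded with 1s,
    # following standard broadcast semantics.
    full = [1] * (x.dim() - len(scale_shape)) + list(scale_shape)
    reduce_dims = [i for i, s in enumerate(full) if s == 1]
    amax = x.abs().amax(dim=reduce_dims, keepdim=True)
    return torch.finfo(torch.float8_e4m3fn).max / amax.clamp(min=1e-12)

x = torch.randn(16, 32)
print(scales_from_scale_shape(x, [-1, 1]).shape)  # row-wise: torch.Size([16, 1])
print(scales_from_scale_shape(x, []).shape)       # tensor-wise: torch.Size([1, 1])
```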

Contributor Author:

One wrinkle to work through would be that Float8Tensor can be of any rank, but operand inputs to torch._scaled_mm are required to be of rank 2, to match torch.mm|torch.addmm.

I'm definitely open to making this more flexible in the future. We've been careful to keep Float8Tensor and these utility functions out of the public API, to give us the freedom to make these kinds of changes as other scaling types become more important.

Contributor Author:

also, if someone puts up a PR for ^, sgtm!

)
a_dq = a_fp8.to_original_precision()
sqnr = compute_error(a, a_dq)
assert sqnr >= 25.0

def test_axiswise_reshape(self):
a = torch.randn(3, 5, 7, dtype=torch.bfloat16)
linear_mm_config = LinearMMConfig()

# if we scale across dim0, we can only reshape to [3, -1]
a_fp8_d0 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=0,
)
assert list(a_fp8_d0._data.shape) == [3, 5, 7]
assert list(a_fp8_d0._scale.shape) == [1, 5, 7]

a_fp8_d0_r = a_fp8_d0.reshape(3, -1)
assert list(a_fp8_d0_r.shape) == [3, 5 * 7]
assert list(a_fp8_d0_r._scale.shape) == [1, 5 * 7]
# verify numerics did not change
assert torch.allclose(
a_fp8_d0.to_original_precision(),
a_fp8_d0_r.to_original_precision().reshape(3, 5, 7),
atol=0,
rtol=0,
)
with pytest.raises(RuntimeError):
a_fp8_d0_r2 = a_fp8_d0.reshape(-1, 7)

# if we scale across dim2, we can only reshape to [-1, 7]
a_fp8_d2 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=-1,
)
assert list(a_fp8_d2._data.shape) == [3, 5, 7]
assert list(a_fp8_d2._scale.shape) == [3, 5, 1]

a_fp8_d2_r = a_fp8_d2.reshape(-1, 7)
assert list(a_fp8_d2_r.shape) == [3 * 5, 7]
assert list(a_fp8_d2_r._scale.shape) == [3 * 5, 1]
# verify numerics did not change
assert torch.allclose(
a_fp8_d2.to_original_precision(),
a_fp8_d2_r.to_original_precision().reshape(3, 5, 7),
atol=0,
rtol=0,
)
with pytest.raises(RuntimeError):
a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1)

@pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0")
def test_axiswise_gemm(self, a_shape):
a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")

linear_mm_config = LinearMMConfig()

a_fp8 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=-1,
)
a_fp8 = a_fp8.reshape(-1, a_shape[-1])
b_fp8 = hp_tensor_to_float8_dynamic(
b,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=-1, # will be transposed
)
c_fp8_compute = torch.mm(a_fp8, b_fp8.t())
a = a.reshape(-1, a_shape[-1])
c_ref = torch.mm(a, b.t())
sqnr = compute_error(c_ref, c_fp8_compute)
assert sqnr >= 25.0


class TestFloat8Linear:
14 changes: 14 additions & 0 deletions torchao/float8/config.py
@@ -26,6 +26,18 @@ def short_str(self):
return "sta"


class ScalingGranularity(enum.Enum):
"""
Defines the granularity of scaling strategies for casting to float8
"""

# A single scaling factor for the entire tensor
TENSORWISE = "tensorwise"
# Scaling factors computed along one axis of the tensor, reducing it to
# size 1.
AXISWISE = "axiswise"
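
(Illustrative note, not part of the diff.) The practical difference between the two granularities is which dims the max-abs reduction runs over when the scale is computed; a rough amax-based sketch, assuming the usual scale = fp8_max / amax convention (the exact computation in torchao may differ, e.g. with eps clamping):

```python
import torch

x = torch.randn(16, 32, dtype=torch.bfloat16)
fp8_max = torch.finfo(torch.float8_e4m3fn).max

# TENSORWISE: a single scale for the whole tensor
scale_tensorwise = fp8_max / x.abs().amax().float()                    # shape []

# AXISWISE (axiswise_dim=-1): the last dim is reduced to size 1,
# giving one scale per row
scale_axiswise = fp8_max / x.abs().amax(dim=-1, keepdim=True).float()  # shape [16, 1]
```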


@dataclass(frozen=True)
class CastConfig:
"""
@@ -146,6 +158,8 @@ class Float8LinearConfig:
# save the fp8_weight_transpose for backward, which is an un-sharded weight and costs high memory utilization.
# The longer-term solution is to let compile decide how to partition the graph with optimal computation and memory savings.
# For now, we use the checkpointing api to force the recomputation of fp8 weight in backward.
# TODO(future PR): either enable by default or have a warning and set up the
# tests so that the warning does not spam the CI stdout.

force_recompute_fp8_weight_in_bwd: bool = False

12 changes: 0 additions & 12 deletions torchao/float8/float8_linear.py
@@ -9,7 +9,6 @@

import dataclasses
import enum
import logging

from typing import Optional

@@ -50,8 +49,6 @@
WeightWithStaticFloat8CastTensor,
)

logger = logging.getLogger(__name__)


# this code was resurrected from https://github.com/pytorch-labs/torchao.float8/pull/128/files
@torch._dynamo.allow_in_graph
@@ -191,15 +188,6 @@ def __init__(self, *args, **kwargs):
# would be initialized in every iteration.
self.enable_pre_and_post_forward = self.config.enable_pre_and_post_forward

# See the comments in config.py for more details of this option.

Contributor Author:

technically not related to this PR, but making the test logs non-spammy for now and we can add this back in a better way later

if (
self.config.enable_pre_and_post_forward
and not self.config.force_recompute_fp8_weight_in_bwd
):
logger.warning(
"When using FSDP, it's recommended to enable config.force_recompute_fp8_weight_in_bwd."
)

def create_buffers(self):
# Default values for history buffers, see above TODO
history_len = self.config.delayed_scaling_config.history_len