Skip to content

Commit

Permalink
add axiswise granularity to Float8Tensor (#919)
Browse files Browse the repository at this point in the history
Summary:

This is a copy-paste of pytorch-labs/float8_experimental#352
which never landed.

Test Plan:


Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
vkuzo authored and jainapurva committed Oct 15, 2024
1 parent 89a8403 commit 92feafa
Show file tree
Hide file tree
Showing 7 changed files with 278 additions and 34 deletions.
116 changes: 111 additions & 5 deletions test/float8/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,20 @@
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.config import (
CastConfig,
Float8LinearConfig,
ScalingGranularity,
ScalingType,
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import (
convert_to_float8_training,
linear_requires_sync,
sync_float8_amax_and_scale_history,
)
from torchao.float8.float8_python_api import addmm_float8_unwrapped
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import (
Float8Tensor,
GemmInputRole,
Expand All @@ -51,14 +57,15 @@


is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
assert torch.all(a._scale == b._scale).item(), "scales are not identical"
assert torch.all(a._data == b._data).item(), "data is not identical"
return True


class TestFloat8Tensor(unittest.TestCase):
class TestFloat8Tensor:
def test_preserves_dtype(self) -> None:
# hp means high precision, lp means low precision
hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
Expand All @@ -68,7 +75,7 @@ def test_preserves_dtype(self) -> None:
x1_s = tensor_to_scale(x1_hp, lp_dtype)
x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
x3_hp = x2_lp.to_original_precision()
self.assertTrue(x3_hp.dtype == hp_dtype)
assert x3_hp.dtype == hp_dtype

def test_differentiable_casts(self) -> None:
lp_dtypes = (e4m3_dtype, e5m2_dtype)
Expand Down Expand Up @@ -103,7 +110,7 @@ def test_index_put(self):
fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn)
fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn)

with self.assertRaises(AssertionError):
with pytest.raises(AssertionError):
b[index] = fp8_a
fp8_b[index] = a
fp8_b_bad[index] = fp8_a
Expand All @@ -117,7 +124,7 @@ def test_copy_(self):
b = torch.empty(16, dtype=torch.bfloat16)
b.copy_(fp8_a) # Should work
torch.testing.assert_close(b, fp8_a.to_original_precision())
with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
fp8_a.copy_(b) # Should fail

fp8_b = Float8Tensor(
Expand All @@ -129,6 +136,105 @@ def test_copy_(self):
fp8_b.copy_(fp8_a)
torch.testing.assert_close(fp8_a._data, fp8_b._data)

@pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)])
@pytest.mark.parametrize("axiswise_dim", [0, -1])
def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
a = torch.randn(*shape, dtype=torch.bfloat16)
linear_mm_config = LinearMMConfig()
a_fp8 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=axiswise_dim,
)
a_dq = a_fp8.to_original_precision()
sqnr = compute_error(a, a_dq)
assert sqnr >= 25.0

def test_axiswise_reshape(self):
a = torch.randn(3, 5, 7, dtype=torch.bfloat16)
linear_mm_config = LinearMMConfig()

# if we scale across dim0, we can only reshape to [3, -1]
a_fp8_d0 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=0,
)
assert list(a_fp8_d0._data.shape) == [3, 5, 7]
assert list(a_fp8_d0._scale.shape) == [1, 5, 7]

a_fp8_d0_r = a_fp8_d0.reshape(3, -1)
assert list(a_fp8_d0_r.shape) == [3, 5 * 7]
assert list(a_fp8_d0_r._scale.shape) == [1, 5 * 7]
# verify numerics did not change
assert torch.allclose(
a_fp8_d0.to_original_precision(),
a_fp8_d0_r.to_original_precision().reshape(3, 5, 7),
atol=0,
rtol=0,
)
with pytest.raises(RuntimeError):
a_fp8_d0_r2 = a_fp8_d0.reshape(-1, 7)

# if we scale across dim2, we can only reshape to [-1, 7]
a_fp8_d2 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=-1,
)
assert list(a_fp8_d2._data.shape) == [3, 5, 7]
assert list(a_fp8_d2._scale.shape) == [3, 5, 1]

a_fp8_d2_r = a_fp8_d2.reshape(-1, 7)
assert list(a_fp8_d2_r.shape) == [3 * 5, 7]
assert list(a_fp8_d2_r._scale.shape) == [3 * 5, 1]
# verify numerics did not change
assert torch.allclose(
a_fp8_d2.to_original_precision(),
a_fp8_d2_r.to_original_precision().reshape(3, 5, 7),
atol=0,
rtol=0,
)
with pytest.raises(RuntimeError):
a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1)

@pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0")
def test_axiswise_gemm(self, a_shape):
a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")

linear_mm_config = LinearMMConfig()

a_fp8 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=-1,
)
a_fp8 = a_fp8.reshape(-1, a_shape[-1])
b_fp8 = hp_tensor_to_float8_dynamic(
b,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=-1, # will be transposed
)
c_fp8_compute = torch.mm(a_fp8, b_fp8.t())
a = a.reshape(-1, a_shape[-1])
c_ref = torch.mm(a, b.t())
sqnr = compute_error(c_ref, c_fp8_compute)
assert sqnr >= 25.0


class TestFloat8Linear:
Expand Down
14 changes: 14 additions & 0 deletions torchao/float8/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ def short_str(self):
return "sta"


class ScalingGranularity(enum.Enum):
"""
Defines the granularity of scaling strategies for casting to float8
"""

# A single scaling factor for the entire tensor
TENSORWISE = "tensorwise"
# Scaling factors computed along one axis of the tensor, reducing it to
# size 1.
AXISWISE = "axiswise"


@dataclass(frozen=True)
class CastConfig:
"""
Expand Down Expand Up @@ -146,6 +158,8 @@ class Float8LinearConfig:
# save the fp8_weight_transpose for backward, which is an un-sahrded weight and costs a high memory utilization.
# The longer-term solution is to let compile decide how to partition the graph with optimal computation and memory savings.
# For now, we use the checkpointing api to force the recomputation of fp8 weight in backward.
# TODO(future PR): either enable by default or have a warning and set up the
# tests so that the warning does not spam the CI stdout.

force_recompute_fp8_weight_in_bwd: bool = False

Expand Down
12 changes: 0 additions & 12 deletions torchao/float8/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import dataclasses
import enum
import logging

from typing import Optional

Expand Down Expand Up @@ -50,8 +49,6 @@
WeightWithStaticFloat8CastTensor,
)

logger = logging.getLogger(__name__)


# this code was resurrected from https://github.com/pytorch-labs/torchao.float8/pull/128/files
@torch._dynamo.allow_in_graph
Expand Down Expand Up @@ -191,15 +188,6 @@ def __init__(self, *args, **kwargs):
# would be initialized in every iteration.
self.enable_pre_and_post_forward = self.config.enable_pre_and_post_forward

# See the comments in config.py for more details of this option.
if (
self.config.enable_pre_and_post_forward
and not self.config.force_recompute_fp8_weight_in_bwd
):
logger.warning(
"When using FSDP, it's recommended to enable config.force_recompute_fp8_weight_in_bwd."
)

def create_buffers(self):
# Default values for history buffers, see above TODO
history_len = self.config.delayed_scaling_config.history_len
Expand Down
Loading

0 comments on commit 92feafa

Please sign in to comment.