add axiswise granularity to Float8Tensor #919

Merged · 9 commits · Oct 7, 2024 · Changes from 8 commits
116 changes: 111 additions & 5 deletions test/float8/test_base.py
@@ -22,14 +22,20 @@
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.config import (
CastConfig,
Float8LinearConfig,
ScalingGranularity,
ScalingType,
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import (
convert_to_float8_training,
linear_requires_sync,
sync_float8_amax_and_scale_history,
)
from torchao.float8.float8_python_api import addmm_float8_unwrapped
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import (
Float8Tensor,
GemmInputRole,
@@ -51,14 +57,15 @@


is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
assert torch.all(a._scale == b._scale).item(), "scales are not identical"
assert torch.all(a._data == b._data).item(), "data is not identical"
return True


class TestFloat8Tensor(unittest.TestCase):
class TestFloat8Tensor:
def test_preserves_dtype(self) -> None:
# hp means high precision, lp means low precision
hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
@@ -68,7 +75,7 @@ def test_preserves_dtype(self) -> None:
x1_s = tensor_to_scale(x1_hp, lp_dtype)
x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
x3_hp = x2_lp.to_original_precision()
self.assertTrue(x3_hp.dtype == hp_dtype)
assert x3_hp.dtype == hp_dtype

def test_differentiable_casts(self) -> None:
lp_dtypes = (e4m3_dtype, e5m2_dtype)
@@ -103,7 +110,7 @@ def test_index_put(self):
fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn)
fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn)

with self.assertRaises(AssertionError):
with pytest.raises(AssertionError):
b[index] = fp8_a
fp8_b[index] = a
fp8_b_bad[index] = fp8_a
@@ -117,7 +124,7 @@ def test_copy_(self):
b = torch.empty(16, dtype=torch.bfloat16)
b.copy_(fp8_a) # Should work
torch.testing.assert_close(b, fp8_a.to_original_precision())
with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
fp8_a.copy_(b) # Should fail

fp8_b = Float8Tensor(
@@ -129,6 +136,105 @@ def test_copy_(self):
fp8_b.copy_(fp8_a)
torch.testing.assert_close(fp8_a._data, fp8_b._data)

@pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)])
@pytest.mark.parametrize("axiswise_dim", [0, -1])
def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
a = torch.randn(*shape, dtype=torch.bfloat16)
linear_mm_config = LinearMMConfig()
a_fp8 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=axiswise_dim,
Comment on lines +148 to +149

Contributor:

Suggestion for another API: instead of an enum plus extra params on a case-by-case basis, we could reuse the same idea that @drisspg used in the _scaled_mm operator: deduce the kind of scaling from the size/shape of the desired scale tensor.

Concretely, we could add a single scale_shape=... parameter, which for row-wise scaling would be [-1, 1], indicating that:

  • all columns (second dim) should be grouped and reduced into a single scaling factor (because the second element has a value of 1)
  • for the rows (first dim) there should be as many scaling factors as there are rows (because the first element has a value of -1, which gets replaced with the corresponding dim of the input tensor).

The scale shape is right-aligned to the shape of the tensor (thus following PyTorch's standard broadcast semantics) and then left-padded with 1 (again, standard semantics). This means that tensorwise scaling is achieved with scale_shape=[].

Using this convention would later allow expressing block-wise scaling (e.g., 128x128), group-wise scaling (1x128), and maybe even column-wise scaling if that ever becomes a thing!

Contributor Author:

One wrinkle to work through would be that Float8Tensor can be of any rank, but operand inputs to torch._scaled_mm are required to be of rank 2, to match torch.mm|torch.addmm.

I'm definitely open to making this more flexible in the future. We've been careful to keep Float8Tensor and these utility functions out of the public API, to give us the freedom to make these kinds of changes as other scaling types become more important.

Contributor Author:

also, if someone puts up a PR for ^, sgtm!
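[Editor's note] A minimal sketch of the scale_shape convention described above; resolve_scale_shape is a hypothetical helper, not part of this PR or of torchao's API:

import torch

def resolve_scale_shape(tensor_shape, scale_shape):
    # right-align scale_shape to tensor_shape by left-padding with 1s,
    # then replace -1 with the corresponding dim of the input tensor
    padded = [1] * (len(tensor_shape) - len(scale_shape)) + list(scale_shape)
    return [t if s == -1 else s for t, s in zip(tensor_shape, padded)]

x = torch.randn(16, 32)
resolve_scale_shape(x.shape, [-1, 1])  # row-wise scaling    -> [16, 1]
resolve_scale_shape(x.shape, [1, -1])  # column-wise scaling -> [1, 32]
resolve_scale_shape(x.shape, [])       # tensorwise scaling  -> [1, 1]

Under this convention the scale never has higher rank than the tensor, and tensorwise scaling falls out as the degenerate case scale_shape=[].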

)
a_dq = a_fp8.to_original_precision()
sqnr = compute_error(a, a_dq)
assert sqnr >= 25.0

def test_axiswise_reshape(self):
a = torch.randn(3, 5, 7, dtype=torch.bfloat16)
linear_mm_config = LinearMMConfig()

# if we scale across dim0, we can only reshape to [3, -1]
a_fp8_d0 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=0,
)
assert list(a_fp8_d0._data.shape) == [3, 5, 7]
assert list(a_fp8_d0._scale.shape) == [1, 5, 7]

a_fp8_d0_r = a_fp8_d0.reshape(3, -1)
assert list(a_fp8_d0_r.shape) == [3, 5 * 7]
assert list(a_fp8_d0_r._scale.shape) == [1, 5 * 7]
# verify numerics did not change
assert torch.allclose(
a_fp8_d0.to_original_precision(),
a_fp8_d0_r.to_original_precision().reshape(3, 5, 7),
atol=0,
rtol=0,
)
with pytest.raises(RuntimeError):
a_fp8_d0_r2 = a_fp8_d0.reshape(-1, 7)

# if we scale across dim2, we can only reshape to [-1, 7]
a_fp8_d2 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=-1,
)
assert list(a_fp8_d2._data.shape) == [3, 5, 7]
assert list(a_fp8_d2._scale.shape) == [3, 5, 1]

a_fp8_d2_r = a_fp8_d2.reshape(-1, 7)
assert list(a_fp8_d2_r.shape) == [3 * 5, 7]
assert list(a_fp8_d2_r._scale.shape) == [3 * 5, 1]
# verify numerics did not change
assert torch.allclose(
a_fp8_d2.to_original_precision(),
a_fp8_d2_r.to_original_precision().reshape(3, 5, 7),
atol=0,
rtol=0,
)
with pytest.raises(RuntimeError):
a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1)

@pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0")
def test_axiswise_gemm(self, a_shape):
a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")

linear_mm_config = LinearMMConfig()

a_fp8 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=-1,
)
a_fp8 = a_fp8.reshape(-1, a_shape[-1])
b_fp8 = hp_tensor_to_float8_dynamic(
b,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=-1, # will be transposed
)
c_fp8_compute = torch.mm(a_fp8, b_fp8.t())
a = a.reshape(-1, a_shape[-1])
c_ref = torch.mm(a, b.t())
sqnr = compute_error(c_ref, c_fp8_compute)
assert sqnr >= 25.0


class TestFloat8Linear:
12 changes: 12 additions & 0 deletions torchao/float8/config.py
@@ -26,6 +26,18 @@ def short_str(self):
return "sta"


class ScalingGranularity(enum.Enum):
"""
Defines the granularity of scaling strategies for casting to float8
"""

# A single scaling factor for the entire tensor
TENSORWISE = "tensorwise"
# Scaling factors computed along one axis of the tensor, reducing it to
# size 1.
AXISWISE = "axiswise"


@dataclass(frozen=True)
class CastConfig:
"""
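[Editor's note] For illustration, a sketch of the reduction each ScalingGranularity value implies; semantics are inferred from the tests above, and the actual cast path in this PR is hp_tensor_to_float8_dynamic:

import torch

x = torch.randn(3, 5, 7, dtype=torch.bfloat16)

# TENSORWISE: a single amax (and thus a single scale) for the whole tensor
tensorwise_amax = x.abs().max()                            # shape []

# AXISWISE with axiswise_dim=-1: the chosen dim is reduced to size 1,
# matching the _scale shapes asserted in test_axiswise_reshape above
axiswise_amax = torch.amax(x.abs(), dim=-1, keepdim=True)  # shape [3, 5, 1]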
93 changes: 92 additions & 1 deletion torchao/float8/float8_ops.py
@@ -19,6 +19,15 @@
FLOAT8_OPS_TABLE: Dict[Any, Any] = {}


def _assert_tensorwise_scale(aten_op, scale):
assert (
# TODO(future PR): figure out why tensorwise scaling can have
# both rank 0 and rank 1
len(scale.shape)
in (0, 1)
), f"{aten_op} with axiswise scaling is not supported yet"


def implements(aten_ops):
"""Register aten ops to the float8 op table"""

@@ -45,6 +54,7 @@ def decorator(func):
]
)
def float8_desugar_op(aten_op, args, kwargs=None):
_assert_tensorwise_scale(aten_op, args[0]._scale)
new_data = aten_op(args[0]._data, *args[1:], **kwargs)
return Float8Tensor(
new_data,
@@ -55,10 +65,82 @@ def float8_desugar_op(aten_op, args, kwargs=None):
)


@implements(
[
aten.t.default,
aten.transpose.int,
]
)
def float8_desugar_data_and_scale(aten_op, args, kwargs=None):
new_data = aten_op(args[0]._data, *args[1:], **kwargs)
new_scale = aten_op(args[0]._scale, *args[1:], **kwargs)

if aten_op == aten.transpose.int:
_assert_tensorwise_scale(aten_op, args[0]._scale)

old_axiswise_dim = args[0]._axiswise_dim
new_axiswise_dim = old_axiswise_dim
if old_axiswise_dim is not None:
if old_axiswise_dim == 0:
new_axiswise_dim = -1
else:
new_axiswise_dim = 0

return Float8Tensor(
new_data,
new_scale,
args[0]._orig_dtype,
args[0]._linear_mm_config,
args[0]._gemm_input_role,
new_axiswise_dim,
)


@implements([aten.view.default])
def float8_view(aten_op, args, kwargs=None):
if len(args[0]._scale.shape) < 2:
# tensorwise scaling
return float8_desugar_op(aten_op, args, kwargs)

t, new_shape = args[0], args[1]
# for now, only support reshaping to [-1, dim] or [dim, -1]
axiswise_dim = t._axiswise_dim
if len(new_shape) == 2:

if axiswise_dim == 0:
new_data = aten_op(t._data, new_shape, **kwargs)
new_scale_shape = [1, new_shape[-1]]
new_scale = aten_op(t._scale, new_scale_shape, **kwargs)
return Float8Tensor(
new_data,
new_scale,
t._orig_dtype,
t._linear_mm_config,
t._gemm_input_role,
t._axiswise_dim,
)
elif axiswise_dim == -1 or axiswise_dim == (len(t.shape) - 1):
new_data = aten_op(t._data, new_shape, **kwargs)
new_scale_shape = [new_shape[0], 1]
new_scale = aten_op(t._scale, new_scale_shape, **kwargs)
new_axiswise_dim = -1
return Float8Tensor(
new_data,
new_scale,
t._orig_dtype,
t._linear_mm_config,
t._gemm_input_role,
new_axiswise_dim,
)
raise AssertionError(
f"{aten_op} with axiswise scaling and t.shape {t.shape} t._scale.shape {t._scale.shape} t._axiswise_dim {t._axiswise_dim} new_shape {new_shape} is not supported yet."
)


@implements([aten.split.Tensor])
def float8_split(aten_op, args, kwargs=None):
new_data_tensors = aten_op(args[0]._data, *args[1:], **kwargs)

_assert_tensorwise_scale(aten_op, args[0]._scale)
def make_float8(data):
return Float8Tensor(
data,
@@ -102,6 +184,7 @@ def float8_cat(aten_op, args, kwargs=None):
assert (
chunk._gemm_input_role is gemm_input_role
), "Expecting all chunks to have the same gemm_input_role as a result of a split"
_assert_tensorwise_scale(aten_op, chunk._scale)
chunk_data.append(chunk._data.view(torch.uint8))

new_data = aten_op(chunk_data, *args[1:], **kwargs)
@@ -118,6 +201,7 @@ def float8_cast_up_op(aten_op, args, kwargs=None):
"addmm" -> out
"hp_gradBias" <-"sum" <- "identity" <- gradOut <- "hp_gradOut"
"""
_assert_tensorwise_scale(aten_op, args[0]._scale)

def unwrap(x):
if isinstance(x, Float8Tensor):
@@ -230,6 +314,7 @@ def float8_addmm(aten_op, args, kwargs=None):

@implements([aten.is_same_size.default])
def float8_is_same_size(aten_op, args, kwargs=None):
_assert_tensorwise_scale(aten_op, args[0]._scale)
return args[0].shape == args[1].shape


@@ -239,6 +324,7 @@ def autocast_to_copy(aten_op, args, kwargs=None):
when the input is a Float8Tensor, presenting as a fp32
tensor.
"""
_assert_tensorwise_scale(aten_op, args[0]._scale)
assert isinstance(args[0], Float8Tensor)
assert (
len(kwargs) == 1 and "dtype" in kwargs
@@ -266,6 +352,7 @@ def allgather_fp8(aten_op, args, kwargs=None):
"""
override funcol with FP8 handling
"""
_assert_tensorwise_scale(aten_op, args[0]._scale)
fp8_input = args[0]
assert isinstance(
fp8_input, Float8Tensor
@@ -285,6 +372,7 @@

@implements([c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default])
def wait_tensor_fp8(aten_op, args, kwargs=None):
_assert_tensorwise_scale(aten_op, args[0]._scale)
fp8_input = args[0]
assert isinstance(fp8_input, Float8Tensor)

@@ -305,6 +393,7 @@ def index_put_fp8(aten_op, args, kwargs=None):
fp8_values = args[2]
assert isinstance(fp8_self, Float8Tensor)
assert isinstance(fp8_values, Float8Tensor)
_assert_tensorwise_scale(aten_op, args[0]._scale)
assert fp8_self._scale == fp8_values._scale
assert fp8_self.dtype == fp8_values.dtype
assert fp8_self._orig_dtype == fp8_values._orig_dtype
@@ -335,8 +424,10 @@ def copy_fp8(aten_op, args, kwargs=None):

if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
src_hp = src.to_original_precision()
_assert_tensorwise_scale(aten_op, src._scale)
return aten_op(self, src_hp, *args[2:], **kwargs)
elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
_assert_tensorwise_scale(aten_op, src._scale)
assert (
self._orig_dtype == src._orig_dtype
), "Expecting both Float8Tensors to be of the same dtype"