Updates to sync_float_amax_history (#211)
Summary:
Update docs, make sure this is friendly to dynamo

### Perf

PyTorch Version | float8 Version | Eager (it/s) | Compile (it/s)
-- | -- | -- | --
Nightly | Main | 1.15 it/s | 2.10 it/s
Nightly | This PR | 1.16 it/s | 2.27 it/s

Trace | Compile URL | Eager URL
-- | -- | --
This PR | https://fburl.com/753ztao4 | https://fburl.com/34yftzao
Main | https://fburl.com/a0gh9iof | https://fburl.com/u9c4ilmp

### Things I have done/changed

#### Commit 1
- [x] Removed the `fp8_classes` argument that could previously be passed in. It existed to support the separate TP/SP classes; since we plan for DTensor to be the solution there, I am removing it for now.
- [x] Put the `child.amax_and_scale_synced` module mutation under the `enable_amax_init` flag, since the unconditional module mutation seemed to be causing graph breaks (a minimal sketch follows this list).
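
A minimal sketch of the guarded mutation, assuming `fp8_config.enable_amax_init` and the per-layer `amax_and_scale_synced` attribute as they appear in the diff below (the wrapper name `mark_layers_synced` is made up for illustration):

```python
import float8_experimental.config as fp8_config

def mark_layers_synced(fp8_layers) -> None:
    for child in fp8_layers:
        # Only mutate the module attribute when the modules will actually check it;
        # unconditional module mutation here appeared to trigger dynamo graph breaks.
        if fp8_config.enable_amax_init:
            child.amax_and_scale_synced = True
```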

#### Commit 2

- [x] The per-tensor amax and scale buffers were previously scalar (0-D) tensors, which meant that constructing the combined tensor required calling `torch.tensor`, causing a host/device sync under `torch.compile`. I added a single dimension of size 1 and piped that shape through all the places that use these buffers (see the sketch after this list).
- [x] Note that this also required updating `to_hp` to cast back to the original precision, because the scale upcasts the `_data` tensor at [this line](f3630d0#diff-94b99416a4df6d75c548de330c1f71505e830b3afff114213d131cf2620597efR57-R59).
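
A minimal sketch (not the PR code) of why the buffers gained a size-1 dimension; the layer count and values are made up:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Old layout: 0-D (scalar) buffers per layer.
amaxes_scalar = [torch.tensor(1.0, device=device) for _ in range(8)]
# New layout: shape-(1,) buffers per layer.
amaxes_1d = [torch.tensor([1.0], device=device) for _ in range(8)]

# Old combine step: torch.tensor() has to read each element back to the host
# to build the new tensor, which forces a sync with the host under torch.compile.
combined_old = torch.tensor(amaxes_scalar, device=device)

# New combine step: torch.cat stays on the device, no sync needed.
combined_new = torch.cat(amaxes_1d)
```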

#### Commit 3
- [x] Rewrote the sync function to do the `torch.roll()` on all the histories at once (see the sketch after this list). Side note: I am not sure whether this is more expensive than doing the same thing with clones, since we really do not care about the wrapping behavior.
- [x] Same for generating the new scales from the grouped histories
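
A rough sketch of the batched update, assuming the "max" scaling recipe; the sizes, e4m3 max constant, and eps are stand-ins for this sketch:

```python
import torch

n_layers, history_len = 4, 16
amax_history_stack = torch.zeros(n_layers, history_len)  # one row per layer
new_amaxes = torch.rand(n_layers, 1)                      # shape (n_layers, 1)

# Single roll over the whole stack instead of one roll per layer; the value
# that wraps around does not matter because slot 0 is overwritten right after.
rolled = torch.roll(amax_history_stack, shifts=1, dims=1)
rolled[:, 0] = new_amaxes.squeeze(-1)
amax_history_stack.copy_(rolled)

# Scales for all layers at once from the grouped histories ("max" recipe).
E4M3_MAX_POS = 448.0  # assumed constant for this sketch
EPS = 1e-12           # assumed eps for this sketch
stacked_amax = torch.max(amax_history_stack, dim=1).values
new_scales = E4M3_MAX_POS / torch.clamp(stacked_amax, min=EPS)
```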

##### Things to do
- There are still two loops, and those are for mutating the actual module values; I am not sure whether there is another way around this.
- Going to try the functional collectives (a condensed sketch follows this list).
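
For reference, a condensed sketch of the fused functional-collective reduction along the lines of the diff below; it assumes an initialized process group, and `reduce_all_amaxes` is a made-up wrapper name:

```python
import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor

def reduce_all_amaxes(x_amaxes, w_amaxes, dL_dY_amaxes):
    # Concatenate all shape-(1,) amax buffers into one tensor, do a single
    # MAX all-reduce across ranks, then split back into the three groups.
    combined = torch.cat(x_amaxes + w_amaxes + dL_dY_amaxes)
    reduced = all_reduce(combined, "MAX", list(range(dist.get_world_size())))
    if isinstance(reduced, AsyncCollectiveTensor):
        reduced = reduced.wait()  # functional collectives may return async tensors
    return torch.split(reduced, len(x_amaxes))
```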

Pull Request resolved: #211

Reviewed By: awgu

Differential Revision: D53779974

Pulled By: drisspg

fbshipit-source-id: 0a07f247d41d58f1934a69d194f81c5dea230eb1
drisspg authored and facebook-github-bot committed Feb 14, 2024
1 parent 0af8433 commit 956195b
Showing 4 changed files with 162 additions and 97 deletions.
13 changes: 7 additions & 6 deletions float8_experimental/float8_linear.py
@@ -138,23 +138,24 @@ def __init__(self, *args, **kwargs):
self.recipe = delayed_scaling_recipe
history_len = self.recipe.history_len

self.register_always_float32_buffer("fp8_amax_x", torch.tensor(E4M3_MAX_POS))
self.register_always_float32_buffer("fp8_amax_x", torch.tensor([E4M3_MAX_POS]))
self.register_always_float32_buffer(
"fp8_amax_history_x", torch.zeros(history_len)
)
self.register_always_float32_buffer("fp8_scale_x", torch.tensor(1.0))
self.register_always_float32_buffer("fp8_amax_w", torch.tensor(E4M3_MAX_POS))
self.register_always_float32_buffer("fp8_scale_x", torch.tensor([1.0]))
self.register_always_float32_buffer("fp8_amax_w", torch.tensor([E4M3_MAX_POS]))
self.register_always_float32_buffer(
"fp8_amax_history_w", torch.zeros(history_len)
)
self.register_always_float32_buffer("fp8_scale_w", torch.tensor(1.0))
self.register_always_float32_buffer("fp8_scale_w", torch.tensor([1.0]))
self.register_always_float32_buffer(
"fp8_amax_dL_dY", torch.tensor(E5M2_MAX_POS)
"fp8_amax_dL_dY", torch.tensor([E5M2_MAX_POS])
)
self.register_always_float32_buffer(
"fp8_amax_history_dL_dY", torch.zeros(history_len)
)
self.register_always_float32_buffer("fp8_scale_dL_dY", torch.tensor(1.0))
self.register_always_float32_buffer("fp8_scale_dL_dY", torch.tensor([1.0]))

# Whether to emulate the fp8 matmul logic in float32
self.emulate = False

226 changes: 136 additions & 90 deletions float8_experimental/float8_linear_utils.py
@@ -4,16 +4,23 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
from enum import auto, Enum
from typing import List, Optional, Type

import float8_experimental.config as fp8_config

import torch
import torch.distributed as dist
import torch.nn as nn
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear

from float8_experimental.float8_utils import amax_history_to_scale
from float8_experimental.float8_utils import amax_history_to_scale_stack
from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor

log = logging.getLogger(__name__)
log.addHandler(logging.NullHandler())


class LinearType(Enum):
@@ -57,14 +64,26 @@ def linear_requires_sync(linear_type: LinearType):
return linear_type in REQUIRES_SYNC


def _update_history_with_new_amax(new_amax, amax_history):
def _update_history_stack(
new_amax: torch.Tensor, amax_history_stack: torch.Tensor
) -> torch.Tensor:
"""
Updates `amax_history` (the last N cur_amax values) inplace with the value
of `new_amax`.
Args:
new_amax (torch.Tensor): The new amax value to add to the history. (n_amaxes, 1)
amax_history_stack (torch.Tensor): The history of amax values. (n_amaxes, history_length)
"""
new_amax_history = torch.roll(amax_history, 1)
new_amax_history[0] = new_amax
amax_history.copy_(new_amax_history)
assert (
amax_history_stack.dim() == 2
), f"Expected amat_history_stack to be 2D, got {amax_history_stack.shape()}"
assert new_amax.size(0) == amax_history_stack.size(
0
), f"Expected new_amax to have the same size as the first dimension of amax_history_stack, got {new_amax.size(0)} and {amax_history_stack.size(0)}"
new_amax_history_stack = torch.roll(amax_history_stack, 1, dims=1)
new_amax_history_stack[:, 0] = new_amax.squeeze(-1)
amax_history_stack.copy_(new_amax_history_stack)


def swap_linear_with_float8_linear(
@@ -121,21 +140,20 @@ def post_order_traversal(
return root_module


def get_float8_layers(model: torch.nn.Module, fp8_classes=None):
if fp8_classes is None:
fp8_classes = Float8Linear
def get_float8_layers(model: torch.nn.Module):
"""Iterates through the model and returns all the Float8Linear layers.
Args:
model (torch.nn.Module): The model to look for Float8Linear layers in.
"""

# Get all fp8 layers and tensors
fp8_layers = [
child for name, child in model.named_modules() if isinstance(child, fp8_classes)
]
fp8_layers = [child for child in model.modules() if isinstance(child, Float8Linear)]

return fp8_layers


def sync_float8_amax_and_scale_history(
model: torch.nn.Module, fp8_classes=None, fp8_layers=None
) -> None:
@torch.no_grad()
def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) -> None:
"""
Manages the float8 amax and scale bookkeeping. In detail, it does the
following:
@@ -147,95 +165,123 @@ def sync_float8_amax_and_scale_history(
TODO(future): design the UX for this (context manager, etc)
PERFORMANCE NOTE:
When you can, it is much more efficient to call get_float8_layers once at
the beginning of the training loop and pass the result to this function,
because of how this interacts with torch.compile.
Args:
model (torch.nn.Module): The model to track amaxes for
fp8_classes (optional): The fp8 classes to look for in the model.
The default is Float8Linear.
When using with TP, users can pass in the customized TP classes instead.
fp8_layers (optional): If fp8_layers are provided, fp8_classes are ignored,
and we loop over all fp8_layers to sync and update amax scale histories.
Users can use get_float8_layers to get all fp8 layers.
"""

# For now, this is written in a naive way to maximize code readability.
# TODO(future): benchmark and optimize as needed, we have combined all
# the reductions into one and we can probably try other optimizations to
# make the history update faster.

if fp8_layers is None:
fp8_layers = get_float8_layers(model, fp8_classes)
fp8_layers = get_float8_layers(model)

if dist.is_initialized():
fp8_amax_x_tensor = torch.tensor(
[child.fp8_amax_x for child in fp8_layers],
dtype=torch.float32,
device="cuda",
requires_grad=False,
if len(fp8_layers) == 0:
log.warn(
"Calling sync_float8_amax_and_scale_history on a module with no Float8Linear layers"
)
fp8_amax_w_tensor = torch.tensor(
[child.fp8_amax_w for child in fp8_layers],
dtype=torch.float32,
device="cuda",
requires_grad=False,
)
fp8_amax_dL_dY_tensor = torch.tensor(
[child.fp8_amax_dL_dY for child in fp8_layers],
dtype=torch.float32,
device="cuda",
requires_grad=False,
)
dist.all_reduce(fp8_amax_x_tensor, op=dist.ReduceOp.MAX)
dist.all_reduce(fp8_amax_w_tensor, op=dist.ReduceOp.MAX)
dist.all_reduce(fp8_amax_dL_dY_tensor, op=dist.ReduceOp.MAX)

for idx in range(len(fp8_layers)):
child = fp8_layers[idx]

#
# 1. in distributed contexts, syncs amax values across workers
#
if dist.is_initialized():
child.fp8_amax_x = fp8_amax_x_tensor[idx].clone()
child.fp8_amax_w = fp8_amax_w_tensor[idx].clone()
child.fp8_amax_dL_dY = fp8_amax_dL_dY_tensor[idx].clone()

#
# 2. adds the `amax` values to history
#
_update_history_with_new_amax(child.fp8_amax_x, child.fp8_amax_history_x)
_update_history_with_new_amax(child.fp8_amax_w, child.fp8_amax_history_w)
_update_history_with_new_amax(
child.fp8_amax_dL_dY, child.fp8_amax_history_dL_dY
return

# Loop over all fp8 layers and grab the needed tensors
fp8_amax_x_tensor_list = [None] * len(fp8_layers)
fp8_amax_w_tensor_list = [None] * len(fp8_layers)
fp8_amax_dL_dY_tensor_list = [None] * len(fp8_layers)

fp8_x_amax_history_stack = [None] * len(fp8_layers)
fp8_w_amax_history_stack = [None] * len(fp8_layers)
fp8_dL_dY_amax_history_stack = [None] * len(fp8_layers)

x_dtypes = set()
scale_fn_recipes = set()

for idx, child in enumerate(fp8_layers):
fp8_amax_x_tensor_list[idx] = child.fp8_amax_x
fp8_amax_w_tensor_list[idx] = child.fp8_amax_w
fp8_amax_dL_dY_tensor_list[idx] = child.fp8_amax_dL_dY

fp8_x_amax_history_stack[idx] = child.fp8_amax_history_x
fp8_w_amax_history_stack[idx] = child.fp8_amax_history_w
fp8_dL_dY_amax_history_stack[idx] = child.fp8_amax_history_dL_dY

x_dtypes.add(child.last_seen_input_dtype)
scale_fn_recipes.add(child.recipe.scale_fn_name)

# TODO This way to get the activation dtype is not ideal
if len(x_dtypes) != 1:
raise ValueError(
f"All layers must have the same last seen input_dtype, got {x_dtypes}"
)
x_dtype = next(iter(x_dtypes))

#
# 3. calculate the scales
#
# TODO what to do with x_dtype
x_dtype = child.last_seen_input_dtype
new_scale = amax_history_to_scale(
child.fp8_amax_history_x,
torch.float8_e4m3fn,
x_dtype,
child.recipe.scale_fn_name,
if len(scale_fn_recipes) != 1:
raise ValueError(
f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}"
)
child.fp8_scale_x.copy_(new_scale)
new_scale = amax_history_to_scale(
child.fp8_amax_history_w,
torch.float8_e4m3fn,
x_dtype,
child.recipe.scale_fn_name,
scale_fn_recipe = next(iter(scale_fn_recipes))

assert (
len(fp8_amax_x_tensor_list)
== len(fp8_amax_w_tensor_list)
== len(fp8_amax_dL_dY_tensor_list)
), "Mismatched lengths of amax tensors."

if dist.is_initialized():
# Combine all the amax tensors into one tensor and reduce it
all_amax_tensors = torch.cat(
fp8_amax_x_tensor_list + fp8_amax_w_tensor_list + fp8_amax_dL_dY_tensor_list
)
child.fp8_scale_w.copy_(new_scale)
new_scale = amax_history_to_scale(
child.fp8_amax_history_dL_dY,
torch.float8_e5m2,
x_dtype,
child.recipe.scale_fn_name,
all_reduced_amax_tensor = all_reduce(
all_amax_tensors, "MAX", list(range(dist.get_world_size()))
)
child.fp8_scale_dL_dY.copy_(new_scale)
if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor):
all_reduced_amax_tensor = all_reduced_amax_tensor.wait()

(
reduced_fp8_amax_tensor,
reduced_fp8_amax_w_tensor,
reduced_fp8_amax_dL_dY_tensor,
) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))

for idx, child in enumerate(fp8_layers):
child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])

# We create two groups of stacked tensors, one for the current amax values and one for the amax histories
fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list)
fp8_amax_w_tensors = torch.vstack(fp8_amax_w_tensor_list)
fp8_amax_dL_dY_tensors = torch.vstack(fp8_amax_dL_dY_tensor_list)

fp8_x_amax_history_stack = torch.vstack(fp8_x_amax_history_stack)
fp8_w_amax_history_stack = torch.vstack(fp8_w_amax_history_stack)
fp8_dL_dY_amax_history_stack = torch.vstack(fp8_dL_dY_amax_history_stack)

# Update the history stacks with the new amax values
_update_history_stack(fp8_amax_x_tensors, fp8_x_amax_history_stack)
_update_history_stack(fp8_amax_w_tensors, fp8_w_amax_history_stack)
_update_history_stack(fp8_amax_dL_dY_tensors, fp8_dL_dY_amax_history_stack)

# Calculate the new scales from the updated history stacks
new_x_scales = amax_history_to_scale_stack(
fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
)
new_w_scales = amax_history_to_scale_stack(
fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
)
new_dL_dY_scales = amax_history_to_scale_stack(
fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe
)

# Iterate through the layers and update the scales, and set the flag to signal that the amaxes/scales are ready
for idx, child in enumerate(fp8_layers):
child.fp8_scale_x.copy_(new_x_scales[idx])
child.fp8_scale_w.copy_(new_w_scales[idx])
child.fp8_scale_dL_dY.copy_(new_dL_dY_scales[idx])

#
# 4. set a flag to signal amaxes/scales are ready
#
child.amax_and_scale_synced = True
# We only update the flag if we know it will be checked by the modules
if fp8_config.enable_amax_init:

**awgu** (Contributor) commented on Feb 15, 2024:
I think adding this if broke enable_amax_init = False because we still check and raise:

raise AssertionError(
"amaxes and scales not synced, please call `sync_float8_amax_and_scale_history` before forward"
)

If I understand correctly, checking if amax and scale are synced is orthogonal to whether we enable amax init in general. Only on the 1st iteration, if we disable amax init, then we can assume that the amax and scale are already synced?

child.amax_and_scale_synced = True
2 changes: 2 additions & 0 deletions float8_experimental/float8_python_api.py
@@ -12,6 +12,8 @@

from typing import Optional, Tuple

import float8_experimental.float8_aten_api # noqa

import torch
from float8_experimental.float8_tensor import Float8Tensor

18 changes: 17 additions & 1 deletion float8_experimental/float8_utils.py
@@ -23,7 +23,7 @@

@torch.no_grad()
def amax_to_scale(amax, float8_dtype, orig_dtype):
scale = torch.empty((), device=amax.device, dtype=torch.float32)
scale = torch.empty_like(amax, dtype=torch.float32)
if float8_dtype == torch.float8_e4m3fn:
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
else: # e5m2
@@ -51,6 +51,22 @@ def amax_history_to_scale(
raise NotImplementedError()


@torch.no_grad()
def amax_history_to_scale_stack(
amax_history: torch.Tensor,
float8_dtype: torch.dtype,
orig_dtype: torch.dtype,
history_to_scale_fn_type: str,
) -> torch.Tensor:
"""Takes in a stack of amax_history tensors and returns a scale tensor."""
if history_to_scale_fn_type == "max":
amax_stack = torch.max(amax_history, dim=1).values
return amax_to_scale(amax_stack, float8_dtype, orig_dtype)
raise NotImplementedError(
f"Invalid history_to_scale_fn_type, only 'max' is supported. Got: {history_to_scale_fn_type}"
)


@torch.no_grad()
def tensor_to_amax(x, distributed_reduction=False):
amax = torch.max(torch.abs(x))
