add handling for selective per layer ac in float8nocompile
danielvegamyhre committed Jan 14, 2025
1 parent 73715c6 commit cbedb73
Showing 1 changed file with 34 additions and 7 deletions.
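
In short, the handler now derives no_precompute_for_backward from the activation checkpointing config instead of a dedicated float8nocompile flag: full AC flags all converted linears, no AC flags none, and a numeric selective per-layer AC option flags every nth transformer block (see the diff below). The following is a minimal sketch of that decision logic only, not the actual handler code; resolve_no_precompute is a hypothetical helper and SimpleNamespace stands in for job_config.activation_checkpoint.

# Hedged sketch of the new decision logic (not part of this commit).
# ac_config stands in for job_config.activation_checkpoint.
from types import SimpleNamespace


def resolve_no_precompute(ac_config) -> tuple[bool, int | None]:
    """Return (flag applied to all converted linears, optional 'every nth layer')."""
    # full AC (or no AC): a single flag passed to convert_to_float8_nocompile_training
    flag_for_all = ac_config.mode == "full"
    # selective per-layer AC ("1", "2", ...): flag only every nth transformer block
    every_nth = None
    if ac_config.mode == "selective" and ac_config.selective_ac_option.isdigit():
        every_nth = int(ac_config.selective_ac_option)
    return flag_for_all, every_nth


print(resolve_no_precompute(SimpleNamespace(mode="full", selective_ac_option="op")))
# -> (True, None)
print(resolve_no_precompute(SimpleNamespace(mode="selective", selective_ac_option="2")))
# -> (False, 2)
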
41 changes: 34 additions & 7 deletions torchtitan/float8.py
@@ -13,7 +13,7 @@
 # Note: Performance
 # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
 
-from typing import List, Union
+from typing import Callable, List, Union
 
 import torch
 import torch.nn as nn
@@ -48,9 +48,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             ) from e
 
         self.use_float8nocompile = float8_config.float8nocompile
-        self.use_float8nocompile_no_precompute_for_backward = (
-            float8_config.float8nocompile_no_precompute_for_backward
-        )
+        self.ac_config = job_config.activation_checkpoint
 
         # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
         enable_fsdp_float8_all_gather = (
@@ -95,20 +93,30 @@ def convert_to_float8_training(self, model: nn.Module):
         if not self.enabled:
             return
 
-        # TODO: should we implicitly use this if self.compile is False, rather
-        # than having an explicit flag?
         if self.use_float8nocompile:
             logger.info("Using float8nocompile prototype")
             from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
                 convert_to_float8_nocompile_training,
             )
 
+            # for full AC or no AC
+            no_precompute_for_backward = self.ac_config.mode == "full"
             convert_to_float8_nocompile_training(
                 model,
                 config=self.config,
                 module_filter_fn=lambda mod, fqn: fqn != "output",
-                no_precompute_for_backward=self.use_float8nocompile_no_precompute_for_backward,
+                no_precompute_for_backward=no_precompute_for_backward,
             )
+
+            # for selective per layer AC
+            if (
+                self.ac_config.mode == "selective"
+                and self.ac_config.selective_ac_option.isdigit()
+            ):
+                no_precompute_for_backward_every_nth_layer(
+                    model,
+                    int(self.ac_config.selective_ac_option),
+                )
         else:
             logger.info("Using float8 training")
             from torchao.float8 import convert_to_float8_training
@@ -166,3 +174,22 @@ def sync_float8_amax_and_scale_history(
         models = [model] if isinstance(model, nn.Module) else model
         for m in models:
             self._sync_float8_amax_and_scale_history(m)
+
+
+def no_precompute_for_backward_every_nth_layer(model: nn.Module, n: int):
+    """Set no_precompute_for_backward to True for every nth layer in the model."""
+    for layer_idx, (layer_id, transformer_block) in enumerate(
+        model.layers.named_children()
+    ):
+        if layer_idx % n == 0:
+            logger.info(f"Enabling no_precompute_for_backward to layer {layer_id}")
+            _enable_no_precompute_for_backward(transformer_block)
+
+
+def _enable_no_precompute_for_backward(model: nn.Module):
+    """Recursively set no_precompute_for_backward to True for all linear layers in the given model."""
+    for layer in model.children():
+        if isinstance(layer, nn.Linear):
+            layer.no_precompute_for_backward = True
+        else:
+            _enable_no_precompute_for_backward(layer)

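With this change, full AC sets no_precompute_for_backward=True for every converted linear at conversion time, no AC leaves it False, and selective per-layer AC (a numeric selective_ac_option) re-flags only every nth transformer block after conversion via the new helper. Below is a minimal demo of that helper; it is not part of the commit. ToyBlock and ToyModel are hypothetical stand-ins for torchtitan's TransformerBlock/Transformer (only a .layers container of blocks built from nn.Linear matters), and the import assumes a torchtitan install that includes this commit.

# Hedged demo (not part of this commit): flag every 2nd block's linears,
# mirroring selective per-layer AC with selective_ac_option = "2".
import torch.nn as nn

# assumes torchtitan (with this commit) is installed
from torchtitan.float8 import no_precompute_for_backward_every_nth_layer


class ToyBlock(nn.Module):  # hypothetical stand-in for a TransformerBlock
    def __init__(self, dim: int = 16):
        super().__init__()
        self.wq = nn.Linear(dim, dim)
        self.wo = nn.Linear(dim, dim)


class ToyModel(nn.Module):  # hypothetical stand-in for Transformer
    def __init__(self, num_layers: int = 4):
        super().__init__()
        # mirrors torchtitan's Transformer.layers (blocks keyed by string id)
        self.layers = nn.ModuleDict({str(i): ToyBlock() for i in range(num_layers)})


model = ToyModel(num_layers=4)
no_precompute_for_backward_every_nth_layer(model, 2)

for name, block in model.layers.named_children():
    flags = [
        getattr(m, "no_precompute_for_backward", False)
        for m in block.modules()
        if isinstance(m, nn.Linear)
    ]
    print(f"layer {name}: {flags}")
# expected: layers "0" and "2" -> [True, True]; layers "1" and "3" -> [False, False]
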