From cbedb73f80b3ee038fe04e7617dcbfd4b4910567 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Tue, 14 Jan 2025 09:54:52 -0800
Subject: [PATCH] add handling for selective per layer ac in float8nocompile

---
 torchtitan/float8.py | 41 ++++++++++++++++++++++++++++++++++-------
 1 file changed, 34 insertions(+), 7 deletions(-)

diff --git a/torchtitan/float8.py b/torchtitan/float8.py
index d5aa8f28..b9d60959 100644
--- a/torchtitan/float8.py
+++ b/torchtitan/float8.py
@@ -13,7 +13,7 @@
 # Note: Performance
 # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
 
-from typing import List, Union
+from typing import Callable, List, Union
 
 import torch
 import torch.nn as nn
@@ -48,9 +48,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             ) from e
 
         self.use_float8nocompile = float8_config.float8nocompile
-        self.use_float8nocompile_no_precompute_for_backward = (
-            float8_config.float8nocompile_no_precompute_for_backward
-        )
+        self.ac_config = job_config.activation_checkpoint
 
         # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
         enable_fsdp_float8_all_gather = (
@@ -95,20 +93,30 @@ def convert_to_float8_training(self, model: nn.Module):
         if not self.enabled:
             return
 
-        # TODO: should we implicitly use this if self.compile is False, rather
-        # than having an explicit flag?
         if self.use_float8nocompile:
             logger.info("Using float8nocompile prototype")
             from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
                 convert_to_float8_nocompile_training,
             )
 
+            # for full AC or no AC
+            no_precompute_for_backward = self.ac_config.mode == "full"
             convert_to_float8_nocompile_training(
                 model,
                 config=self.config,
                 module_filter_fn=lambda mod, fqn: fqn != "output",
-                no_precompute_for_backward=self.use_float8nocompile_no_precompute_for_backward,
+                no_precompute_for_backward=no_precompute_for_backward,
             )
+
+            # for selective per layer AC
+            if (
+                self.ac_config.mode == "selective"
+                and self.ac_config.selective_ac_option.isdigit()
+            ):
+                no_precompute_for_backward_every_nth_layer(
+                    model,
+                    int(self.ac_config.selective_ac_option),
+                )
         else:
             logger.info("Using float8 training")
             from torchao.float8 import convert_to_float8_training
@@ -166,3 +174,22 @@ def sync_float8_amax_and_scale_history(
         models = [model] if isinstance(model, nn.Module) else model
         for m in models:
             self._sync_float8_amax_and_scale_history(m)
+
+
+def no_precompute_for_backward_every_nth_layer(model: nn.Module, n: int):
+    """Set no_precompute_for_backward to True for every nth layer in the model."""
+    for layer_idx, (layer_id, transformer_block) in enumerate(
+        model.layers.named_children()
+    ):
+        if layer_idx % n == 0:
+            logger.info(f"Enabling no_precompute_for_backward to layer {layer_id}")
+            _enable_no_precompute_for_backward(transformer_block)
+
+
+def _enable_no_precompute_for_backward(model: nn.Module):
+    """Recursively set no_precompute_for_backward to True for all linear layers in the given model."""
+    for layer in model.children():
+        if isinstance(layer, nn.Linear):
+            layer.no_precompute_for_backward = True
+        else:
+            _enable_no_precompute_for_backward(layer)
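
Not part of the patch: a minimal standalone sketch of how the new every-nth-layer helper walks a model. ToyBlock and ToyModel are hypothetical stand-ins for torchtitan's TransformerBlock and Transformer; the helper bodies mirror the patch's logic, with the logger call dropped so the sketch is self-contained.

# Standalone sketch (illustration only, not torchtitan code): exercises the
# every-nth-layer tagging against a toy model whose `.layers` mirrors
# torchtitan's ModuleDict of transformer blocks.
import torch.nn as nn


def _enable_no_precompute_for_backward(model: nn.Module):
    # Recursively tag every nn.Linear under `model`, same traversal as the patch.
    for layer in model.children():
        if isinstance(layer, nn.Linear):
            layer.no_precompute_for_backward = True
        else:
            _enable_no_precompute_for_backward(layer)


def no_precompute_for_backward_every_nth_layer(model: nn.Module, n: int):
    # Tag every nth transformer block, matching the selective per-layer AC stride.
    for layer_idx, (_, block) in enumerate(model.layers.named_children()):
        if layer_idx % n == 0:
            _enable_no_precompute_for_backward(block)


class ToyBlock(nn.Module):  # hypothetical stand-in for a TransformerBlock
    def __init__(self):
        super().__init__()
        self.wq = nn.Linear(16, 16)
        self.wk = nn.Linear(16, 16)


class ToyModel(nn.Module):  # hypothetical stand-in for the Transformer
    def __init__(self, num_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleDict(
            {str(i): ToyBlock() for i in range(num_layers)}
        )


model = ToyModel()
no_precompute_for_backward_every_nth_layer(model, 2)
# Blocks 0 and 2 are tagged; blocks 1 and 3 are left untouched.
assert getattr(model.layers["0"].wq, "no_precompute_for_backward", False)
assert not getattr(model.layers["1"].wq, "no_precompute_for_backward", False)

The stride in the sketch follows the patch: when selective_ac_option is a digit, every nth block is the one being checkpointed, so the intent appears to be that exactly those blocks set no_precompute_for_backward on their linear layers while the remaining blocks keep the precompute path.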