Commit 0a51319

[fp8] zero support fp8 linear. (#6006)
* fix

* fix

* fix

* zero fp8

* zero fp8

* Update requirements.txt
flybird11111 authored Aug 16, 2024
1 parent 3f09a61 commit 0a51319
Showing 2 changed files with 16 additions and 4 deletions.
19 changes: 16 additions & 3 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -35,6 +35,7 @@
 from colossalai.interface.optimizer import DistributedOptim
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
+from colossalai.quantization.fp8_hook import FP8Hook
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.zero import LowLevelZeroOptimizer
@@ -62,7 +63,9 @@ class OptimizerParamCheckState(enum.Enum):


 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
-    def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
+    def __init__(
+        self, module: nn.Module, precision: str, overlap_allgather: bool = False, use_fp8: bool = False
+    ) -> None:
         super().__init__(module)
         self.dtype = None
         if precision == "fp16":
@@ -74,11 +77,16 @@ def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool =
         module = module.to(get_accelerator().get_current_device())
         self.module = module
         self.convert_fn = None
+        self.use_fp8 = use_fp8
         if self.dtype is not None:
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
         self.overlap_allgather = overlap_allgather
+        self.op_hooks = []
         if overlap_allgather:
-            self.op_hook = ZeroOpHook()
+            self.op_hooks.append(ZeroOpHook())
+        if use_fp8:
+            self.op_hooks.append(FP8Hook())
+        if overlap_allgather or use_fp8:
             for p in module.parameters():
                 if p.requires_grad and type(p) is not ColoParameter:
                     p.__class__ = ColoParameter
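Note (not part of the diff): the hooks collected in self.op_hooks above are meant to be activated around the wrapped module's forward pass so that operations on ColoParameter (e.g. the matmul inside nn.Linear) can be intercepted; with FP8Hook registered, eligible linears are dispatched to fp8 compute. A simplified sketch of that mechanism, using the ColoParamOpHookManager import already present in this file but a hypothetical run_forward helper (the commit's actual forward implementation is not shown in this diff):

    from contextlib import nullcontext

    from colossalai.tensor.param_op_hook import ColoParamOpHookManager


    def run_forward(wrapped, *args, **kwargs):
        # Activate the registered op hooks (ZeroOpHook and/or FP8Hook) only while
        # the underlying module runs, then fall back to normal dispatch afterwards.
        ctx = ColoParamOpHookManager.use_hooks(*wrapped.op_hooks) if wrapped.op_hooks else nullcontext()
        with ctx:
            return wrapped.module(*args, **kwargs)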
@@ -335,6 +343,7 @@ def __init__(
         master_weights: bool = True,
         verbose: bool = False,
         fp8_communication: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@@ -362,6 +371,7 @@
         )
         self.lora_enabled = False
         self.verbose = verbose
+        self.use_fp8 = use_fp8

         # set class name with stage, for better error message
         setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
@@ -476,7 +486,10 @@ def configure(

         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(
-                model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
+                model,
+                self.precision,
+                overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
+                use_fp8=self.use_fp8,
             )

         # TODO: Support Galore + ZeRO
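Taken together, the plugin changes let users opt into fp8 linear layers with a single flag. A minimal usage sketch under the standard Booster workflow, with a placeholder model and optimizer (assumes distributed init, e.g. via colossalai.launch_from_torch, has already run):

    import torch
    import torch.nn as nn

    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin

    # use_fp8=True makes LowLevelZeroModel register FP8Hook for the wrapped module.
    plugin = LowLevelZeroPlugin(stage=2, precision="bf16", use_fp8=True)
    booster = Booster(plugin=plugin)

    model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # booster.boost wraps the model in LowLevelZeroModel, where the new flag is consumed.
    model, optimizer, *_ = booster.boost(model, optimizer)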
1 change: 0 additions & 1 deletion examples/language/llama/benchmark.py
@@ -259,7 +259,6 @@ def empty_init():
         if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
         else nullcontext()
     )
-
     init_kwargs = {}
     if config.model_type == "chatglm":
         init_kwargs["empty_init"] = False
