Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
pengwa committed Oct 25, 2023
1 parent f12aeba commit 2cd979c
Showing 1 changed file with 5 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from __future__ import annotations

import warnings
from typing import ClassVar, Dict, Optional
from typing import ClassVar

from ._apex_amp_modifier import ApexAMPModifier
from ._ds_modifier import DeepSpeedZeROModifier
Expand All @@ -31,15 +33,15 @@ def get_full_qualified_type_name(o):


class OptimizerModifierTypeRegistry:
_MAP: ClassVar[Dict[str, FP16OptimizerModifier]] = {
_MAP: ClassVar[dict[str, FP16OptimizerModifier]] = {
"megatron.fp16.fp16.FP16_Optimizer": LegacyMegatronLMModifier,
"deepspeed.runtime.zero.stage2.FP16_DeepSpeedZeroOptimizer": DeepSpeedZeROModifier,
"deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer": DeepSpeedZeROModifier,
"apex.amp.optimizer.unique_name_as_id": ApexAMPModifier,
}

@staticmethod
def create_modifier(optimizer_full_qualified_name: str, optimizer, **kwargs) -> Optional[FP16OptimizerModifier]:
def create_modifier(optimizer_full_qualified_name: str, optimizer, **kwargs) -> FP16OptimizerModifier | None:
"""Create modifier for optimizer."""
if optimizer_full_qualified_name in OptimizerModifierTypeRegistry._MAP:
return OptimizerModifierTypeRegistry._MAP[optimizer_full_qualified_name](optimizer, **kwargs)
Expand Down

0 comments on commit 2cd979c

Please sign in to comment.