From 2cd979c54c6ef5c0467799c20a525ecd66a65bae Mon Sep 17 00:00:00 2001 From: "pengwa@microsoft.com" Date: Wed, 25 Oct 2023 01:30:27 +0000 Subject: [PATCH] minor --- .../python/training/optim/_modifier_registry.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/orttraining/orttraining/python/training/optim/_modifier_registry.py b/orttraining/orttraining/python/training/optim/_modifier_registry.py index 3de30203cb1a9..a88740dac60b7 100644 --- a/orttraining/orttraining/python/training/optim/_modifier_registry.py +++ b/orttraining/orttraining/python/training/optim/_modifier_registry.py @@ -3,8 +3,10 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- +from __future__ import annotations + import warnings -from typing import ClassVar, Dict, Optional +from typing import ClassVar from ._apex_amp_modifier import ApexAMPModifier from ._ds_modifier import DeepSpeedZeROModifier @@ -31,7 +33,7 @@ def get_full_qualified_type_name(o): class OptimizerModifierTypeRegistry: - _MAP: ClassVar[Dict[str, FP16OptimizerModifier]] = { + _MAP: ClassVar[dict[str, FP16OptimizerModifier]] = { "megatron.fp16.fp16.FP16_Optimizer": LegacyMegatronLMModifier, "deepspeed.runtime.zero.stage2.FP16_DeepSpeedZeroOptimizer": DeepSpeedZeROModifier, "deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer": DeepSpeedZeROModifier, @@ -39,7 +41,7 @@ class OptimizerModifierTypeRegistry: } @staticmethod - def create_modifier(optimizer_full_qualified_name: str, optimizer, **kwargs) -> Optional[FP16OptimizerModifier]: + def create_modifier(optimizer_full_qualified_name: str, optimizer, **kwargs) -> FP16OptimizerModifier | None: """Create modifier for optimizer.""" if optimizer_full_qualified_name in OptimizerModifierTypeRegistry._MAP: return OptimizerModifierTypeRegistry._MAP[optimizer_full_qualified_name](optimizer, **kwargs)