diff --git a/megatron/utils.py b/megatron/utils.py index d00f4cd0ef..3d5eef4672 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -4,24 +4,14 @@ import sys import os -import time import logging -from typing import ContextManager, Optional +from typing import Optional import torch from torch.nn.parallel import DistributedDataParallel as torchDDP from deepspeed.accelerator import get_accelerator -if get_accelerator().device_name() == "cuda": - try: - from apex.multi_tensor_apply import multi_tensor_applier - import amp_C - - HAS_APEX = True - except Exception: - HAS_APEX = False - from megatron import get_args, get_adlr_autoresume, get_num_microbatches from megatron.core import mpu from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate @@ -30,70 +20,104 @@ import ezpz as ez +ACCELERATOR = get_accelerator() +assert ACCELERATOR is not None + +if ACCELERATOR.device_name() == "cuda": + try: + from apex.multi_tensor_apply import multi_tensor_applier # type:ignore + import amp_C # type:ignore + + HAS_APEX = True + except Exception: + HAS_APEX = False + RANK = ez.get_rank() log = logging.getLogger(__name__) -# log.setLevel("INFO") if RANK == 0 else log.setLevel("CRITICAL") - +log.setLevel(os.environ.get("LOG_LEVEL", ("INFO" if RANK == 0 else "CRITICAL"))) _DLIO_PROFILER_EXIST = True -_DFTRACER_EXIST=True +_DFTRACER_EXIST = True try: - import dftracer + import dftracer # type:ignore except Exception: - _DFTRACER_EXIST=False + _DFTRACER_EXIST = False try: - import dlio_profiler + import dlio_profiler # type:ignore except Exception: _DLIO_PROFILER_EXIST = False if _DFTRACER_EXIST: - from dftracer.logger import dftracer as PerfTrace, dft_fn as Profile, DFTRACER_ENABLE as DFTRACER_ENABLE + from dftracer.logger import ( # type:ignore + dftracer as PerfTrace, + dft_fn as Profile, + DFTRACER_ENABLE as DFTRACER_ENABLE, + ) elif _DLIO_PROFILER_EXIST: - from dlio_profiler.logger import fn_interceptor as Profile - from dlio_profiler.logger import dlio_logger as PerfTrace + from dlio_profiler.logger import fn_interceptor as Profile # type:ignore + from dlio_profiler.logger import dlio_logger as PerfTrace # type:ignore else: from functools import wraps - # from contextlib import nullcontext - # Profile: ContextManager = nullcontext - # - # class Profile(nullable_schema) + class Profile(object): - def __init__(self, cat, name=None, epoch=None, step=None, image_idx=None, image_size=None): - return - def log(self, func): + def __init__( + self, cat, name=None, epoch=None, step=None, image_idx=None, image_size=None + ): + return + + def log(self, func): return func - def log_init(self, func): + + def log_init(self, func): return func - def iter(self, func, iter_name="step"): + + def iter(self, func, iter_name="step"): return func + def __enter__(self): return + def __exit__(self, type, value, traceback): return - def update(self, epoch=None, step=None, image_idx=None, image_size=None, args={}): + + def update( + self, epoch=None, step=None, image_idx=None, image_size=None, args={} + ): return + def flush(self): return + def reset(self): return + def log_static(self, func): return + class dftracer(object): - def __init__(self,): + def __init__( + self, + ): self.type = None + def initialize_log(self, logfile=None, data_dir=None, process_id=-1): return + def get_time(self): return + def enter_event(self): return + def exit_event(self): return + def log_event(self, name, cat, start_time, duration, string_args=None): return + def finalize(self): return @@ -103,8 +127,8 @@ def finalize(self): def get_logger( name: str, - level: str = "INFO", - rank_zero_only: Optional[bool] = None, + level: Optional[str] = None, + rank_zero_only: Optional[bool] = True, ) -> logging.Logger: """Returns a `logging.Logger` object. @@ -112,7 +136,9 @@ def get_logger( non-zero ranks (and will be set to `level` on RANK==0). """ logger = logging.getLogger(name) - logger.setLevel(level) + logger.setLevel( + str(level if level is not None else os.environ.get("LOG_LEVEL", "INFO")).upper() + ) if rank_zero_only and ez.get_rank() != 0: logger.setLevel("CRITICAL") return logger