Skip to content

Commit

Permalink
Update megatron/utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
saforem2 committed Oct 13, 2024
1 parent 2b31b44 commit 5e9eed0
Showing 1 changed file with 60 additions and 34 deletions.
94 changes: 60 additions & 34 deletions megatron/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,14 @@

import sys
import os
import time
import logging
from typing import ContextManager, Optional
from typing import Optional

import torch
from torch.nn.parallel import DistributedDataParallel as torchDDP

from deepspeed.accelerator import get_accelerator

if get_accelerator().device_name() == "cuda":
try:
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

HAS_APEX = True
except Exception:
HAS_APEX = False

from megatron import get_args, get_adlr_autoresume, get_num_microbatches
from megatron.core import mpu
from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
Expand All @@ -30,70 +20,104 @@

import ezpz as ez

ACCELERATOR = get_accelerator()
assert ACCELERATOR is not None

if ACCELERATOR.device_name() == "cuda":
try:
from apex.multi_tensor_apply import multi_tensor_applier # type:ignore
import amp_C # type:ignore

HAS_APEX = True
except Exception:
HAS_APEX = False

RANK = ez.get_rank()
log = logging.getLogger(__name__)
# log.setLevel("INFO") if RANK == 0 else log.setLevel("CRITICAL")

log.setLevel(os.environ.get("LOG_LEVEL", ("INFO" if RANK == 0 else "CRITICAL")))

_DLIO_PROFILER_EXIST = True
_DFTRACER_EXIST=True
_DFTRACER_EXIST = True

try:
import dftracer
import dftracer # type:ignore
except Exception:
_DFTRACER_EXIST=False
_DFTRACER_EXIST = False

try:
import dlio_profiler
import dlio_profiler # type:ignore
except Exception:
_DLIO_PROFILER_EXIST = False


if _DFTRACER_EXIST:
from dftracer.logger import dftracer as PerfTrace, dft_fn as Profile, DFTRACER_ENABLE as DFTRACER_ENABLE
from dftracer.logger import ( # type:ignore
dftracer as PerfTrace,
dft_fn as Profile,
DFTRACER_ENABLE as DFTRACER_ENABLE,
)
elif _DLIO_PROFILER_EXIST:
from dlio_profiler.logger import fn_interceptor as Profile
from dlio_profiler.logger import dlio_logger as PerfTrace
from dlio_profiler.logger import fn_interceptor as Profile # type:ignore
from dlio_profiler.logger import dlio_logger as PerfTrace # type:ignore
else:
from functools import wraps
# from contextlib import nullcontext
# Profile: ContextManager = nullcontext
#
# class Profile(nullable_schema)

class Profile(object):
def __init__(self, cat, name=None, epoch=None, step=None, image_idx=None, image_size=None):
return
def log(self, func):
def __init__(
self, cat, name=None, epoch=None, step=None, image_idx=None, image_size=None
):
return

def log(self, func):
return func
def log_init(self, func):

def log_init(self, func):
return func
def iter(self, func, iter_name="step"):

def iter(self, func, iter_name="step"):
return func

def __enter__(self):
return

def __exit__(self, type, value, traceback):
return
def update(self, epoch=None, step=None, image_idx=None, image_size=None, args={}):

def update(
self, epoch=None, step=None, image_idx=None, image_size=None, args={}
):
return

def flush(self):
return

def reset(self):
return

def log_static(self, func):
return

class dftracer(object):
def __init__(self,):
def __init__(
self,
):
self.type = None

def initialize_log(self, logfile=None, data_dir=None, process_id=-1):
return

def get_time(self):
return

def enter_event(self):
return

def exit_event(self):
return

def log_event(self, name, cat, start_time, duration, string_args=None):
return

def finalize(self):
return

Expand All @@ -103,16 +127,18 @@ def finalize(self):

def get_logger(
name: str,
level: str = "INFO",
rank_zero_only: Optional[bool] = None,
level: Optional[str] = None,
rank_zero_only: Optional[bool] = True,
) -> logging.Logger:
"""Returns a `logging.Logger` object.
If `rank_zero_only` passed, the level will be set to CRITICAL on all
non-zero ranks (and will be set to `level` on RANK==0).
"""
logger = logging.getLogger(name)
logger.setLevel(level)
logger.setLevel(
str(level if level is not None else os.environ.get("LOG_LEVEL", "INFO")).upper()
)
if rank_zero_only and ez.get_rank() != 0:
logger.setLevel("CRITICAL")
return logger
Expand Down

0 comments on commit 5e9eed0

Please sign in to comment.