Several Improvements for the latest PyTorch Framework #1564

Closed · wants to merge 15 commits
3 changes: 2 additions & 1 deletion mmengine/_strategy/base.py
@@ -322,7 +322,8 @@ def compile_model(
         Returns:
             nn.Module: Compiled model.
         """
-        if isinstance(compile, bool) and not compile:
+        if isinstance(compile, bool) and not compile or \
+                isinstance(compile, dict) and not compile.get('disable', False):
             return model
 
         assert digit_version(TORCH_VERSION) >= digit_version('2.0.0'), (
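A quick usage sketch of the two forms this guard is meant to accept (the `strategy` instance and the kwargs below are illustrative, not from the PR). One caveat worth flagging: `and` binds tighter than `or` in Python, so as written the dict branch returns the uncompiled model whenever `disable` is falsy, which may be the opposite of the intent; explicit parentheses would make the grouping unambiguous.

```python
# Hedged sketch, assuming BaseStrategy.compile_model as patched above.
model = strategy.compile_model(model, compile=False)               # bool: skip torch.compile
model = strategy.compile_model(model, compile=dict(disable=True))  # dict: torch.compile options
```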
3 changes: 3 additions & 0 deletions mmengine/model/wrappers/distributed.py
@@ -95,6 +95,7 @@ def __init__(self,
 
     def train_step(self, data: Union[dict, tuple, list],
                    optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
+        return self.module.train_step(data, optim_wrapper)
         """Interface for model forward, backward and parameters updating during
         training process.
 
@@ -126,6 +127,7 @@ def train_step(self, data: Union[dict, tuple, list],
         return log_vars
 
     def val_step(self, data: Union[dict, tuple, list]) -> list:
+        return self.module.val_step(data)
         """Gets the prediction of module during validation process.
 
         Args:
@@ -137,6 +139,7 @@ def val_step(self, data: Union[dict, tuple, list]) -> list:
         return self.module.val_step(data)
 
     def test_step(self, data: Union[dict, tuple, list]) -> list:
+        return self.module.test_step(data)
         """Gets the predictions of module during testing process.
 
         Args:
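Note that each inserted `return` sits before the method's docstring, so everything after it becomes dead code. For context, here is a paraphrase of the `train_step` logic the early return now bypasses (reconstructed from the surrounding diff; details may differ from the actual mmengine source):

```python
# Paraphrased wrapper body that the early return skips:
def train_step(self, data, optim_wrapper):
    # optim_context() handles DDP gradient-sync bookkeeping (e.g. no_sync
    # during gradient accumulation); bypassing it changes that behavior.
    with optim_wrapper.optim_context(self):
        data = self.module.data_preprocessor(data, training=True)
        losses = self._run_forward(data, mode='loss')
    parsed_losses, log_vars = self.module.parse_losses(losses)
    optim_wrapper.update_params(parsed_losses)
    return log_vars
```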
4 changes: 3 additions & 1 deletion mmengine/optim/optimizer/amp_optimizer_wrapper.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from contextlib import contextmanager
+from functools import partial
 from typing import Union
 
 import torch
@@ -17,7 +18,8 @@
 elif is_mlu_available():
     from torch.mlu.amp import GradScaler
 else:
-    from torch.cuda.amp import GradScaler
+    from torch.amp import GradScaler as amp_GradScaler
+    GradScaler = partial(amp_GradScaler, device='cuda')
 
 
 @OPTIM_WRAPPERS.register_module()
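The `partial` shim exists because newer PyTorch releases deprecate `torch.cuda.amp.GradScaler` in favor of the device-agnostic `torch.amp.GradScaler`. Binding `device='cuda'` keeps existing zero-argument call sites working unchanged; a minimal sketch:

```python
from functools import partial
from torch.amp import GradScaler as amp_GradScaler

# Existing code can keep calling GradScaler() with no device argument.
GradScaler = partial(amp_GradScaler, device='cuda')
scaler = GradScaler(init_scale=2.**16)  # same as torch.amp.GradScaler('cuda', init_scale=65536.0)
```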
4 changes: 3 additions & 1 deletion mmengine/optim/optimizer/builder.py
@@ -170,7 +170,8 @@ def register_transformers_optimizers():
     except ImportError:
         pass
     else:
-        OPTIMIZERS.register_module(name='Adafactor', module=Adafactor)
+        # KeyError: 'Adafactor is already registered in optimizer at torch.optim'
+        # OPTIMIZERS.register_module(name='Adafactor', module=Adafactor)
         transformer_optimizers.append('Adafactor')
     return transformer_optimizers
 
@@ -211,5 +212,6 @@ def build_optim_wrapper(model: nn.Module,
             type=constructor_type,
             optim_wrapper_cfg=optim_wrapper_cfg,
             paramwise_cfg=paramwise_cfg))
+
     optim_wrapper = optim_wrapper_constructor(model)
     return optim_wrapper
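The commented-out registration dodges the name clash noted in the inline comment: recent PyTorch versions ship their own `torch.optim.Adafactor`, which mmengine auto-registers first. One possible alternative (an assumption, not part of this PR) is to register the transformers implementation under a distinct name so both stay usable:

```python
from transformers import Adafactor
from mmengine.registry import OPTIMIZERS

# Hypothetical name; avoids colliding with the auto-registered torch.optim.Adafactor.
OPTIMIZERS.register_module(name='TransformersAdafactor', module=Adafactor)
```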
2 changes: 1 addition & 1 deletion mmengine/runner/checkpoint.py
@@ -344,7 +344,7 @@ def load_from_local(filename, map_location):
     filename = osp.expanduser(filename)
     if not osp.isfile(filename):
         raise FileNotFoundError(f'{filename} can not be found.')
-    checkpoint = torch.load(filename, map_location=map_location)
+    checkpoint = torch.load(filename, map_location=map_location, weights_only=False)
     return checkpoint
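`weights_only=False` restores the pre-2.6 `torch.load` behavior: PyTorch 2.6 flipped the default to `weights_only=True`, which rejects mmengine checkpoints that pickle objects such as `HistoryBuffer`. A sketch of the allow-list alternative that keeps the safer default (assumes PyTorch >= 2.4; which classes must be allowed depends on the checkpoint):

```python
import torch
from mmengine.logging import HistoryBuffer

# Permit only the specific classes this checkpoint needs, then load safely.
torch.serialization.add_safe_globals([HistoryBuffer])
checkpoint = torch.load('epoch_12.pth', map_location='cpu', weights_only=True)
```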
21 changes: 19 additions & 2 deletions mmengine/runner/loops.py
@@ -12,6 +12,7 @@
 from mmengine.registry import LOOPS
 from mmengine.structures import BaseDataElement
 from mmengine.utils import is_list_of
+from mmengine.dataset.sampler import InfiniteSampler
 from .amp import autocast
 from .base_loop import BaseLoop
 from .utils import calc_dynamic_intervals
@@ -274,14 +275,28 @@ def run(self) -> None:
         # In iteration-based training loop, we treat the whole training process
         # as a big epoch and execute the corresponding hook.
         self.runner.call_hook('before_train_epoch')
-        if self._iter > 0:
+        if self._iter > 0 and not isinstance(self.dataloader.sampler, InfiniteSampler):
             print_log(
                 f'Advance dataloader {self._iter} steps to skip data '
                 'that has already been trained',
                 logger='current',
                 level=logging.WARNING)
             for _ in range(self._iter):
+                break  # NOTE MGAM: override all preprocessing steps during resume.
                 next(self.dataloader_iterator)
 
+        # with torch.profiler.profile(
+        #         activities=[torch.profiler.ProfilerActivity.CPU,
+        #                     torch.profiler.ProfilerActivity.CUDA],
+        #         schedule=torch.profiler.schedule(wait=1, warmup=2, active=3),
+        #         on_trace_ready=torch.profiler.tensorboard_trace_handler('./profiler_log'),
+        #         record_shapes=True,
+        #         profile_memory=True,
+        #         with_stack=True,
+        #         with_flops=True,
+        #         with_modules=True,
+        # ) as p:
+
         while self._iter < self._max_iters and not self.stop_training:
             self.runner.model.train()
 
@@ -292,8 +307,10 @@
             if (self.runner.val_loop is not None
                     and self._iter >= self.val_begin
                     and (self._iter % self.val_interval == 0
-                         or self._iter == self._max_iters)):
+                         or self._iter == self._max_iters)):
                 self.runner.val_loop.run()
 
+            # p.step()
+
         self.runner.call_hook('after_train_epoch')
         self.runner.call_hook('after_train')
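The commented-out block sketches how `torch.profiler` could wrap the training loop, with `p.step()` advancing the schedule once per iteration. A self-contained version of the same idea (`train_one_iter` is a hypothetical placeholder for the loop body; the schedule mirrors the comment above):

```python
import torch

def train_one_iter():
    pass  # placeholder for run_iter(...)

with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA],
        schedule=torch.profiler.schedule(wait=1, warmup=2, active=3),
        on_trace_ready=torch.profiler.tensorboard_trace_handler('./profiler_log'),
        record_shapes=True,
        profile_memory=True,
) as p:
    for _ in range(10):
        train_one_iter()
        p.step()  # advance the wait/warmup/active schedule each iteration
```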
19 changes: 16 additions & 3 deletions mmengine/runner/runner.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+import inspect
 import logging
 import os
 import os.path as osp
@@ -902,8 +903,20 @@ def wrap_model(
                 find_unused_parameters=find_unused_parameters)
         else:
             model_wrapper_cfg.setdefault('type', 'MMDistributedDataParallel')
-            model_wrapper_type = MODEL_WRAPPERS.get(
-                model_wrapper_cfg.get('type'))  # type: ignore
+
+            model_wrapper_type = model_wrapper_cfg.get('type')
+            if isinstance(model_wrapper_type, str):
+                model_wrapper_type = MODEL_WRAPPERS.get(
+                    model_wrapper_type)  # type: ignore
+            elif inspect.isclass(model_wrapper_type):
+                pass
+            else:
+                raise KeyError(
+                    f'{model_wrapper_type} is not in the '
+                    'registry. Please check whether the value of '
+                    f'`{model_wrapper_type}` is correct or it was registered '
+                    'as expected. More details can be found at https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module'  # noqa: E501
+                )
             default_args: dict = dict()
             if issubclass(
                     model_wrapper_type,  # type: ignore
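With the new branch, `model_wrapper_cfg` can carry either a registered name or the wrapper class itself. A hedged sketch of the two forms (the `broadcast_buffers` kwarg is illustrative):

```python
from mmengine.model import MMDistributedDataParallel

# Registered-name form (previously the only option):
cfg_by_name = dict(type='MMDistributedDataParallel', broadcast_buffers=False)
# Class form, now accepted via the inspect.isclass() branch:
cfg_by_class = dict(type=MMDistributedDataParallel, broadcast_buffers=False)
```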
@@ -1838,7 +1851,7 @@ def call_hook(self, fn_name: str, **kwargs) -> None:
             try:
                 getattr(hook, fn_name)(self, **kwargs)
             except TypeError as e:
-                raise TypeError(f'{e} in {hook}') from None
+                raise TypeError(f'{e} in {hook}') from e
 
     def register_hook(
         self,
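`from e` preserves the original traceback as `__cause__` instead of suppressing it, which makes hook failures easier to debug. A minimal illustration (not mmengine code):

```python
try:
    try:
        raise TypeError("missing 1 required positional argument: 'runner'")
    except TypeError as e:
        raise TypeError(f'{e} in <Hook>') from e  # keeps the original as __cause__
except TypeError as chained:
    assert chained.__cause__ is not None  # 'from None' would have erased it
```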
3 changes: 2 additions & 1 deletion mmengine/visualization/vis_backend.py
@@ -604,7 +604,8 @@ def add_scalar(self,
                       (int, float, torch.Tensor, np.ndarray, np.number)):
             self._tensorboard.add_scalar(name, value, step)
         else:
-            warnings.warn(f'Got {type(value)}, but numpy array, torch tensor, '
+            warnings.warn(f'Got type {type(value)} with name {name}, '
+                          'but numpy array, torch tensor, '
                           f'int or float are expected. skip it!')
 
     @force_init_env
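The reworded warning now names the offending scalar. A hypothetical trigger, assuming a `TensorboardVisBackend` instance named `vis_backend`:

```python
vis_backend.add_scalar('lr', '0.01')  # a str is not an accepted scalar type
# UserWarning: Got type <class 'str'> with name lr, but numpy array, torch
# tensor, int or float are expected. skip it!
```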