
Commit

Apply manual ruff fixes
akx committed May 11, 2023
1 parent b5ecff6 commit 646818a
Showing 2 changed files with 12 additions and 6 deletions.
k_diffusion/gns.py (4 changes: 2 additions & 2 deletions)

@@ -5,8 +5,8 @@ class DDPGradientStatsHook:
     def __init__(self, ddp_module):
         try:
             ddp_module.register_comm_hook(self, self._hook_fn)
-        except AttributeError:
-            raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules')
+        except AttributeError as ae:
+            raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules') from ae
         self._clear_state()

     def _clear_state(self):
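
The change above is the standard fix for re-raising inside an `except` block (ruff enforces it as flake8-bugbear's B904, raise-without-from-inside-except): without `from`, Python prints the misleading "During handling of the above exception, another exception occurred", whereas explicit chaining records the AttributeError as the direct cause. A minimal sketch of the effect, using hypothetical names (NotDDP and attach_hook are illustrative, not from this commit):

    class NotDDP:
        pass  # deliberately lacks register_comm_hook

    def attach_hook(module):
        try:
            module.register_comm_hook(None, lambda *args: None)
        except AttributeError as ae:
            # `from ae` attaches the original error as __cause__.
            raise ValueError('module is not DDP-wrapped') from ae

    try:
        attach_hook(NotDDP())
    except ValueError as ve:
        print(type(ve.__cause__).__name__)  # -> AttributeError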
k_diffusion/utils.py (14 changes: 10 additions & 4 deletions)

@@ -178,8 +178,11 @@ def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,

     def get_lr(self):
         if not self._get_lr_called_within_step:
-            warnings.warn("To get the last learning rate computed by the scheduler, "
-                          "please use `get_last_lr()`.")
+            warnings.warn(
+                "To get the last learning rate computed by the scheduler, "
+                "please use `get_last_lr()`.",
+                stacklevel=1,
+            )

         return self._get_closed_form_lr()

@@ -219,8 +222,11 @@ def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,

     def get_lr(self):
         if not self._get_lr_called_within_step:
-            warnings.warn("To get the last learning rate computed by the scheduler, "
-                          "please use `get_last_lr()`.")
+            warnings.warn(
+                "To get the last learning rate computed by the scheduler, "
+                "please use `get_last_lr()`.",
+                stacklevel=1,
+            )

         return self._get_closed_form_lr()
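
The two utils.py hunks make the warning's `stacklevel` explicit, which matches ruff's B028 (no-explicit-stacklevel) flagging `warnings.warn` calls that omit it. `stacklevel=1` is already the default, so runtime behavior is unchanged; the explicit argument just records that the warning should point at the `warn` call itself rather than at the caller of `get_lr`. A short sketch of what the parameter controls, with a hypothetical helper:

    import warnings

    def deprecated():
        # stacklevel=2 attributes the warning to whoever called deprecated(),
        # which is usually the line the user can actually fix; stacklevel=1
        # (the default) would point at this warn() call instead.
        warnings.warn("deprecated() is going away", DeprecationWarning, stacklevel=2)

    warnings.simplefilter("always")
    deprecated()  # the warning is reported as coming from this line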
