Skip to content

Commit

Permalink
Try unscale
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Jul 10, 2024
1 parent dd69813 commit c1eed9d
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions kronfluence/module/tracker/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,18 @@ def _preprocess_gradient(self, output_gradient: torch.Tensor, target_dtype: torc
torch.Tensor:
The preprocessed gradient.
"""
original_dtype = output_gradient.dtype
# original_dtype = output_gradient.dtype
# output_gradient = output_gradient.to(dtype=target_dtype)
# if self.module.gradient_scale != 1.0:
# if original_dtype != target_dtype:
# output_gradient.mul_(self.module.gradient_scale)
# else:
# output_gradient = output_gradient * self.module.gradient_scale
# return output_gradient
output_gradient = output_gradient.to(dtype=target_dtype)
if self.module.gradient_scale != 1.0:
if original_dtype != target_dtype:
output_gradient.mul_(self.module.gradient_scale)
else:
output_gradient = output_gradient * self.module.gradient_scale
output_gradient = output_gradient * self.module.gradient_scale
output_gradient = output_gradient.to(dtype=target_dtype)
return output_gradient

def register_hooks(self) -> None:
Expand Down

0 comments on commit c1eed9d

Please sign in to comment.