Skip to content

Commit

Permalink
Minor fix (pylint)
Browse files Browse the repository at this point in the history
  • Loading branch information
xeon27 committed Mar 16, 2024
1 parent ab153ad commit 3737318
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions kronfluence/module/tracked_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))

if self.factor_args.immediate_gradient_removal:
self._registered_hooks.append(self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook))
self._registered_hooks.append(
self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook)
)

def _release_covariance_matrices(self) -> None:
"""Clears the stored activation and pseudo-gradient covariance matrices from memory."""
Expand Down Expand Up @@ -491,7 +493,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))

if self.factor_args.immediate_gradient_removal:
self._registered_hooks.append(self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook))
self._registered_hooks.append(
self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook)
)

def _release_lambda_matrix(self) -> None:
"""Clears the stored Lambda matrix from memory."""
Expand Down Expand Up @@ -607,7 +611,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))

if self.factor_args.immediate_gradient_removal:
self._registered_hooks.append(self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook))
self._registered_hooks.append(
self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook)
)

def _release_preconditioned_gradient(self) -> None:
"""Clears the preconditioned per-sample-gradient from memory."""
Expand Down Expand Up @@ -727,7 +733,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))

if self.factor_args.immediate_gradient_removal:
self._registered_hooks.append(self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook))
self._registered_hooks.append(
self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook)
)

def _register_self_score_hooks(self) -> None:
"""Installs forward and backward hooks for computation of self-influence scores."""
Expand Down Expand Up @@ -785,7 +793,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))

if self.factor_args.immediate_gradient_removal:
self._registered_hooks.append(self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook))
self._registered_hooks.append(
self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook)
)

def release_scores(self) -> None:
"""Clears the influence scores from memory."""
Expand Down

0 comments on commit 3737318

Please sign in to comment.