diff --git a/kronfluence/module/tracked_module.py b/kronfluence/module/tracked_module.py
index 607484a..30523ef 100644
--- a/kronfluence/module/tracked_module.py
+++ b/kronfluence/module/tracked_module.py
@@ -316,7 +316,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
 
         self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))
         if self.factor_args.immediate_gradient_removal:
-            self._registered_hooks.append(self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook))
+            self._registered_hooks.append(
+                self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook)
+            )
 
     def _release_covariance_matrices(self) -> None:
         """Clears the stored activation and pseudo-gradient covariance matrices from memory."""
@@ -491,7 +493,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
 
         self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))
         if self.factor_args.immediate_gradient_removal:
-            self._registered_hooks.append(self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook))
+            self._registered_hooks.append(
+                self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook)
+            )
 
     def _release_lambda_matrix(self) -> None:
         """Clears the stored Lambda matrix from memory."""
@@ -607,7 +611,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
 
         self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))
         if self.factor_args.immediate_gradient_removal:
-            self._registered_hooks.append(self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook))
+            self._registered_hooks.append(
+                self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook)
+            )
 
     def _release_preconditioned_gradient(self) -> None:
         """Clears the preconditioned per-sample-gradient from memory."""
@@ -727,7 +733,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
 
         self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))
         if self.factor_args.immediate_gradient_removal:
-            self._registered_hooks.append(self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook))
+            self._registered_hooks.append(
+                self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook)
+            )
 
     def _register_self_score_hooks(self) -> None:
         """Installs forward and backward hooks for computation of self-influence scores."""
@@ -785,7 +793,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
 
         self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))
         if self.factor_args.immediate_gradient_removal:
-            self._registered_hooks.append(self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook))
+            self._registered_hooks.append(
+                self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook)
+            )
 
     def release_scores(self) -> None:
         """Clears the influence scores from memory."""
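For reviewers unfamiliar with the hook bookkeeping these hunks re-wrap: the sketch below is a minimal, self-contained reconstruction of the registration pattern, not kronfluence's actual implementation. The wrapper class, both hook bodies, and the bare `immediate_gradient_removal` flag are assumptions for illustration; only the `original_module` / `_registered_hooks` names and the `register_forward_hook` / `register_full_backward_hook` calls come from the diff itself.

```python
# Minimal sketch of the hook-registration idiom (assumptions noted above).
import torch.nn as nn


class TrackedModuleSketch(nn.Module):
    """Wraps a module and registers removable forward/backward hooks."""

    def __init__(self, original_module: nn.Module, immediate_gradient_removal: bool = False) -> None:
        super().__init__()
        self.original_module = original_module
        self._registered_hooks = []  # handles kept so every hook can be removed later

        def forward_hook(module, inputs, output) -> None:
            # Hypothetical stand-in for the statistics the real hooks collect.
            pass

        def full_backward_gradient_removal_hook(module, grad_input, grad_output) -> None:
            # Hypothetical stand-in: drop accumulated parameter gradients as soon
            # as this module's backward pass has fired, to cap peak memory.
            for parameter in module.parameters():
                parameter.grad = None

        self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))
        if immediate_gradient_removal:
            self._registered_hooks.append(
                self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook)
            )

    def forward(self, *args, **kwargs):
        return self.original_module(*args, **kwargs)

    def remove_hooks(self) -> None:
        # Each register_* call returned a RemovableHandle; popping and removing
        # them detaches every hook installed in __init__.
        while self._registered_hooks:
            self._registered_hooks.pop().remove()
```

Usage would look like `wrapped = TrackedModuleSketch(nn.Linear(4, 2), immediate_gradient_removal=True)`. The `+` lines in the hunks above match this call shape; the change is purely cosmetic (wrapping a call that exceeded the line-length limit), so behavior is unchanged.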