Commit

Final covariance cleanup
pomonam committed Jul 6, 2024
1 parent 288427e commit d0154f1
Showing 2 changed files with 6 additions and 4 deletions.
kronfluence/module/tracked_module.py (4 additions, 2 deletions)
@@ -136,8 +136,10 @@ def __init__(
 
     def forward(self, inputs: torch.Tensor, *args: Any, **kwargs: Any) -> Any:
         """A forward pass of the tracked module. This should have identical behavior to that of the original module."""
-        # return self.original_module(inputs + self._constant, *args, **kwargs)
-        return self.original_module(inputs, *args, **kwargs) + self._constant
+        outputs = self.original_module(inputs, *args, **kwargs)
+        if outputs.requires_grad:
+            return outputs
+        return outputs + self._constant
 
     def prepare_storage(self, device: torch.device) -> None:
         """Performs any necessary operations on storage before computing any metrics."""
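For orientation, a minimal sketch (not part of this commit) of how the new forward branch behaves, assuming _constant is a zero tensor registered elsewhere in TrackedModule: the wrapped module's output is returned untouched whenever it already requires grad, and the constant is only added to detached outputs.

import torch
import torch.nn as nn


class ToyTrackedModule(nn.Module):
    """Hypothetical stand-in for TrackedModule, reduced to the forward branch above."""

    def __init__(self, original_module: nn.Module) -> None:
        super().__init__()
        self.original_module = original_module
        self._constant = torch.zeros(1)  # assumed zero placeholder; defined outside this hunk

    def forward(self, inputs: torch.Tensor, *args, **kwargs):
        outputs = self.original_module(inputs, *args, **kwargs)
        if outputs.requires_grad:
            # Output already participates in the autograd graph: return it untouched.
            return outputs
        # Detached output (e.g. produced under torch.no_grad()): add the zero constant.
        return outputs + self._constant


layer = ToyTrackedModule(nn.Linear(4, 4))
print(layer(torch.randn(2, 4)).requires_grad)      # True: first branch, output returned as-is
with torch.no_grad():
    print(layer(torch.randn(2, 4)).requires_grad)  # False: second branch, constant added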
kronfluence/module/tracker/factor.py (2 additions, 2 deletions)
@@ -87,7 +87,8 @@ def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.
             with torch.no_grad():
                 # Computes and updates activation covariance during forward pass.
                 input_activation = (
-                    inputs[0].detach().to(dtype=self.module.factor_args.activation_covariance_dtype, copy=True)
+                    inputs[0].detach().to(dtype=self.module.factor_args.activation_covariance_dtype,
+                                          copy=self.module.attention_mask is not None)
                 )
                 self._update_activation_covariance_matrix(input_activation=input_activation)
                 self.cached_hooks.append(outputs.register_hook(backward_hook))
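The behavioral change in this hunk is the copy flag. A generic reminder of the Tensor.to semantics it relies on (standard PyTorch, not code from this repository): with copy=False the call returns the original tensor when dtype and device already match, so no extra allocation happens, while copy=True always materializes a fresh buffer. Presumably the copy is now only forced when an attention mask is present and the activation will be modified afterwards; that rationale is an inference from the diff, not stated in the commit message.

import torch

x = torch.randn(8, 16, dtype=torch.float32)

# copy=False: same dtype and device, so .to() hands back the detached tensor's storage.
no_copy = x.detach().to(dtype=torch.float32, copy=False)
# copy=True: a new buffer is always allocated, even though nothing changed.
forced = x.detach().to(dtype=torch.float32, copy=True)

print(no_copy.data_ptr() == x.data_ptr())  # True: storage shared, no copy made
print(forced.data_ptr() == x.data_ptr())   # False: fresh allocation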
@@ -107,7 +108,6 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
                     output_gradient = output_gradient * self.module.gradient_scale
                 self._update_gradient_covariance_matrix(output_gradient=output_gradient)
 
-        # self.registered_hooks.append(self.module.original_module.register_forward_hook(forward_hook))
         self.registered_hooks.append(self.module.register_forward_hook(forward_hook))
 
     def exist(self) -> bool:
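The line removed here was a leftover comment from when the hook was registered on the wrapped original_module rather than on the TrackedModule wrapper. As a generic illustration of why the registration point matters (standard nn.Module hook semantics, with a hypothetical Wrapper standing in for TrackedModule): a forward hook fires with the outputs of whichever module it is registered on, so the wrapper-level hook sees the wrapper's final output while a hook on the inner module only sees the inner output.

import torch
import torch.nn as nn


class Wrapper(nn.Module):
    """Hypothetical wrapper; TrackedModule plays this role in kronfluence."""

    def __init__(self, inner: nn.Module) -> None:
        super().__init__()
        self.inner = inner

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.inner(x) + 1.0  # wrapper-level post-processing


def hook(module: nn.Module, inputs, outputs) -> None:
    # Prints which module fired and the mean of the outputs it observed.
    print(type(module).__name__, outputs.mean().item())


wrapped = Wrapper(nn.Linear(3, 3))
wrapped.register_forward_hook(hook)        # observes the wrapper's final output
wrapped.inner.register_forward_hook(hook)  # observes the inner module's output only
wrapped(torch.randn(2, 3))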
