diff --git a/kronfluence/module/tracked_module.py b/kronfluence/module/tracked_module.py index a96f810..047d28d 100644 --- a/kronfluence/module/tracked_module.py +++ b/kronfluence/module/tracked_module.py @@ -136,8 +136,10 @@ def __init__( def forward(self, inputs: torch.Tensor, *args: Any, **kwargs: Any) -> Any: """A forward pass of the tracked module. This should have identical behavior to that of the original module.""" - # return self.original_module(inputs + self._constant, *args, **kwargs) - return self.original_module(inputs, *args, **kwargs) + self._constant + outputs = self.original_module(inputs, *args, **kwargs) + if outputs.requires_grad: + return outputs + return outputs + self._constant def prepare_storage(self, device: torch.device) -> None: """Performs any necessary operations on storage before computing any metrics.""" diff --git a/kronfluence/module/tracker/factor.py b/kronfluence/module/tracker/factor.py index 917c226..e3798cc 100644 --- a/kronfluence/module/tracker/factor.py +++ b/kronfluence/module/tracker/factor.py @@ -87,7 +87,8 @@ def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch. with torch.no_grad(): # Computes and updates activation covariance during forward pass. input_activation = ( - inputs[0].detach().to(dtype=self.module.factor_args.activation_covariance_dtype, copy=True) + inputs[0].detach().to(dtype=self.module.factor_args.activation_covariance_dtype, + copy=self.module.attention_mask is not None) ) self._update_activation_covariance_matrix(input_activation=input_activation) self.cached_hooks.append(outputs.register_hook(backward_hook)) @@ -107,7 +108,6 @@ def backward_hook(output_gradient: torch.Tensor) -> None: output_gradient = output_gradient * self.module.gradient_scale self._update_gradient_covariance_matrix(output_gradient=output_gradient) - # self.registered_hooks.append(self.module.original_module.register_forward_hook(forward_hook)) self.registered_hooks.append(self.module.register_forward_hook(forward_hook)) def exist(self) -> bool: