diff --git a/kronfluence/module/conv2d.py b/kronfluence/module/conv2d.py index 2eb9b32..062a561 100644 --- a/kronfluence/module/conv2d.py +++ b/kronfluence/module/conv2d.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from einconv.utils import get_conv_paddings from einops import rearrange, reduce -from opt_einsum import DynamicProgramming, contract_expression +from opt_einsum import DynamicProgramming, contract_expression, contract from torch import nn from torch.nn.modules.utils import _pair @@ -160,7 +160,7 @@ def compute_summed_gradient(self, input_activation: torch.Tensor, output_gradien input_activation = self._flatten_input_activation(input_activation=input_activation) input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1)) output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o") - summed_gradient = torch.einsum("bci,bco->io", output_gradient, input_activation) + summed_gradient = contract("bci,bco->io", output_gradient, input_activation) return summed_gradient.view((1, *summed_gradient.size())) def compute_per_sample_gradient( @@ -171,7 +171,7 @@ def compute_per_sample_gradient( input_activation = self._flatten_input_activation(input_activation=input_activation) input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1)) output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o") - per_sample_gradient = torch.einsum("bci,bco->bio", output_gradient, input_activation) + per_sample_gradient = contract("bci,bco->bio", output_gradient, input_activation) if self.per_sample_gradient_process_fnc is not None: per_sample_gradient = self.per_sample_gradient_process_fnc( module_name=self.name, gradient=per_sample_gradient diff --git a/kronfluence/module/linear.py b/kronfluence/module/linear.py index 35aeadf..bb2b739 100644 --- a/kronfluence/module/linear.py +++ b/kronfluence/module/linear.py @@ -2,7 +2,7 @@ import torch from einops import rearrange -from opt_einsum import DynamicProgramming, contract_expression +from opt_einsum import DynamicProgramming, contract, contract_expression from torch import nn from kronfluence.module.tracked_module import TrackedModule @@ -62,14 +62,14 @@ def _flatten_input_activation(self, input_activation: torch.Tensor) -> torch.Ten def compute_summed_gradient(self, input_activation: torch.Tensor, output_gradient: torch.Tensor) -> torch.Tensor: input_activation = self._flatten_input_activation(input_activation=input_activation) - summed_gradient = torch.einsum("b...i,b...o->io", output_gradient, input_activation) - return summed_gradient.view((1, *summed_gradient.size())) + summed_gradient = contract("b...i,b...o->io", output_gradient, input_activation).unsqueeze_(0) + return summed_gradient def compute_per_sample_gradient( self, input_activation: torch.Tensor, output_gradient: torch.Tensor ) -> torch.Tensor: input_activation = self._flatten_input_activation(input_activation=input_activation) - per_sample_gradient = torch.einsum("b...i,b...o->bio", output_gradient, input_activation) + per_sample_gradient = contract("b...i,b...o->bio", output_gradient, input_activation) if self.per_sample_gradient_process_fnc is not None: per_sample_gradient = self.per_sample_gradient_process_fnc( module_name=self.name, gradient=per_sample_gradient @@ -82,53 +82,37 @@ def compute_pairwise_score( input_activation = self._flatten_input_activation(input_activation=input_activation) if isinstance(preconditioned_gradient, list): left_mat, right_mat 
= preconditioned_gradient - if self.einsum_expression is None: if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3: - self.einsum_expression = contract_expression( - "qik,qko,bti,bto->qbt", - left_mat.shape, - right_mat.shape, - output_gradient.shape, - input_activation.shape, - optimize=DynamicProgramming( - search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops" - ), - ) + expr = "qik,qko,bti,bto->qbt" else: - self.einsum_expression = contract_expression( - "qik,qko,b...i,b...o->qb", - left_mat.shape, - right_mat.shape, - output_gradient.shape, - input_activation.shape, - optimize=DynamicProgramming( - search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops" - ), - ) - return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation) - - if self.einsum_expression is None: - if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3: + expr = "qik,qko,b...i,b...o->qb" self.einsum_expression = contract_expression( - "qio,bti,bto->qbt", - preconditioned_gradient.shape, + expr, + left_mat.shape, + right_mat.shape, output_gradient.shape, input_activation.shape, optimize=DynamicProgramming( search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops" ), ) + return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation) + + if self.einsum_expression is None: + if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3: + expr = "qio,bti,bto->qbt" else: - self.einsum_expression = contract_expression( - "qio,b...i,b...o->qb", - preconditioned_gradient.shape, - output_gradient.shape, - input_activation.shape, - optimize=DynamicProgramming( - search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops" - ), - ) + expr = "qio,b...i,b...o->qb" + self.einsum_expression = contract_expression( + expr, + preconditioned_gradient.shape, + output_gradient.shape, + input_activation.shape, + optimize=DynamicProgramming( + search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops" + ), + ) return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation) def compute_self_measurement_score( diff --git a/kronfluence/module/tracker/factor.py b/kronfluence/module/tracker/factor.py index bd42afd..6d0f969 100644 --- a/kronfluence/module/tracker/factor.py +++ b/kronfluence/module/tracker/factor.py @@ -259,9 +259,8 @@ def backward_hook(output_gradient: torch.Tensor) -> None: output_gradient.mul_(self.module.gradient_scale) else: output_gradient = output_gradient * self.module.gradient_scale - self.cached_activations = self.cached_activations.to(device=output_gradient.device) per_sample_gradient = self.module.compute_per_sample_gradient( - input_activation=self.cached_activations, + input_activation=self.cached_activations.to(device=output_gradient.device), output_gradient=output_gradient, ).to(dtype=self.module.factor_args.lambda_dtype) self.clear_all_cache() @@ -281,9 +280,8 @@ def shared_backward_hook(output_gradient: torch.Tensor) -> None: else: output_gradient = output_gradient * self.module.gradient_scale cached_activation = self.cached_activations.pop() - cached_activation = cached_activation.to(device=output_gradient.device) per_sample_gradient = self.module.compute_per_sample_gradient( - input_activation=cached_activation, + input_activation=cached_activation.to(device=output_gradient.device), output_gradient=output_gradient, ) 
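The conv2d.py and linear.py hunks above swap torch.einsum for opt_einsum's contract and cache a contract_expression built with the DynamicProgramming optimizer. A minimal standalone sketch of those two APIs (not part of this patch; tensor shapes are illustrative):

import torch
from opt_einsum import DynamicProgramming, contract, contract_expression

batch, positions, in_features, out_features = 8, 16, 32, 64  # illustrative sizes
output_gradient = torch.randn(batch, positions, out_features)
input_activation = torch.randn(batch, positions, in_features)

# Same per-sample contraction through two front ends; results should agree numerically.
reference = torch.einsum("bco,bci->bio", output_gradient, input_activation)
candidate = contract("bco,bci->bio", output_gradient, input_activation)
assert torch.allclose(reference, candidate, atol=1e-5)

# Building the expression once (as the patch does with self.einsum_expression)
# amortizes path optimization across repeated calls with identically shaped operands.
expression = contract_expression(
    "bco,bci->bio",
    output_gradient.shape,
    input_activation.shape,
    optimize=DynamicProgramming(search_outer=True, minimize="flops"),
)
assert torch.allclose(reference, expression(output_gradient, input_activation), atol=1e-5)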
if self.cached_per_sample_gradient is None: diff --git a/kronfluence/module/tracker/gradient.py b/kronfluence/module/tracker/gradient.py index e88f835..82863ec 100644 --- a/kronfluence/module/tracker/gradient.py +++ b/kronfluence/module/tracker/gradient.py @@ -1,62 +1,65 @@ -from typing import List, Tuple +from typing import Tuple import torch import torch.distributed as dist import torch.nn as nn -from kronfluence.factor.config import FactorConfig from kronfluence.module.tracker.base import BaseTracker -from kronfluence.utils.constants import ( - ACCUMULATED_PRECONDITIONED_GRADIENT_NAME, - AGGREGATED_GRADIENT_NAME, - PRECONDITIONED_GRADIENT_NAME, -) +from kronfluence.utils.constants import AGGREGATED_GRADIENT_NAME class GradientTracker(BaseTracker): - """Tracks and computes summed gradient for a given module.""" + """Tracks and computes aggregated gradient for a given module.""" def register_hooks(self) -> None: - """Sets up hooks to compute and keep track of summed gradient.""" + """Sets up hooks to compute and keep track of aggregated gradient.""" - @torch.no_grad() def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: del module - cached_activation = inputs[0].detach() - device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device - cached_activation = cached_activation.to( - device=device, - dtype=self.module.score_args.per_sample_gradient_dtype, - copy=True, - ) - - if self.module.factor_args.has_shared_parameters: - if self.cached_activations is None: - self.cached_activations = [] - self.cached_activations.append(cached_activation) - else: - self.cached_activations = cached_activation + with torch.no_grad(): + cached_activation = inputs[0].detach() + device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device + cached_activation = cached_activation.to( + device=device, + dtype=self.module.score_args.per_sample_gradient_dtype, + copy=True, + ) + if self.module.factor_args.has_shared_parameters: + if self.cached_activations is None: + self.cached_activations = [] + self.cached_activations.append(cached_activation) + else: + self.cached_activations = cached_activation - outputs.register_hook( - shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook - ) + self.cached_hooks.append(outputs.register_hook(backward_hook)) @torch.no_grad() def backward_hook(output_gradient: torch.Tensor) -> None: if self.cached_activations is None: self._raise_cache_not_found_exception() - - output_gradient = self._scale_output_gradient( - output_gradient=output_gradient, target_dtype=self.module.score_args.per_sample_gradient_dtype - ) + handle = self.cached_hooks.pop() + handle.remove() + original_dtype = output_gradient.dtype + target_dtype = self.module.score_args.per_sample_gradient_dtype + output_gradient = output_gradient.detach().to(dtype=target_dtype) + if self.module.gradient_scale != 1.0: + if original_dtype != target_dtype: + output_gradient.mul_(self.module.gradient_scale) + else: + output_gradient = output_gradient * self.module.gradient_scale + + if isinstance(self.cached_activations, list): + cached_activation = self.cached_activations.pop() + else: + cached_activation = self.cached_activations if self.module.per_sample_gradient_process_fnc is None: summed_gradient = self.module.compute_summed_gradient( - input_activation=self.cached_activations.to(device=output_gradient.device), + 
input_activation=cached_activation.to(device=output_gradient.device), output_gradient=output_gradient, ) else: summed_gradient = self.module.compute_per_sample_gradient( - input_activation=self.cached_activations.to(device=output_gradient.device), + input_activation=cached_activation.to(device=output_gradient.device), output_gradient=output_gradient, ).sum(dim=0, keepdim=True) self.clear_all_cache() @@ -65,49 +68,16 @@ def backward_hook(output_gradient: torch.Tensor) -> None: self.module.storage[AGGREGATED_GRADIENT_NAME] = torch.zeros_like(summed_gradient, requires_grad=False) self.module.storage[AGGREGATED_GRADIENT_NAME].add_(summed_gradient) - @torch.no_grad() - def shared_backward_hook(output_gradient: torch.Tensor) -> None: - output_gradient = self._scale_output_gradient( - output_gradient=output_gradient, target_dtype=self.module.score_args.per_sample_gradient_dtype - ) - cached_activation = self.cached_activations.pop() - if self.module.per_sample_gradient_process_fnc is None: - summed_gradient = self.module.compute_summed_gradient( - input_activation=cached_activation.to(device=output_gradient.device), - output_gradient=output_gradient, - ) - else: - summed_gradient = self.module.comute_per_sample_gradient( - input_activation=cached_activation.to(device=output_gradient.device), - output_gradient=output_gradient, - ).sum(dim=0, keepdim=True) - - if self.cached_per_sample_gradient is None: - self.cached_per_sample_gradient = torch.zeros_like(summed_gradient, requires_grad=False) - self.cached_per_sample_gradient.add_(summed_gradient) - - self.registered_hooks.append(self.module.original_module.register_forward_hook(forward_hook)) - - def exist(self) -> bool: - return self.module.storage[AGGREGATED_GRADIENT_NAME] is not None + self.registered_hooks.append(self.module.register_forward_hook(forward_hook)) @torch.no_grad() def finalize_iteration(self): - """Computes preconditioned gradient using cached per-sample gradients.""" - if not self.module.factor_args.has_shared_parameters: - return - if self.module.storage[AGGREGATED_GRADIENT_NAME] is None: - self.module.storage[AGGREGATED_GRADIENT_NAME] = torch.zeros_like( - self.cached_per_sample_gradient, requires_grad=False - ) - self.module.storage[AGGREGATED_GRADIENT_NAME].add_(self.cached_per_sample_gradient) + """Clears all cached activations from memory.""" self.clear_all_cache() - def release_memory(self) -> None: - """Clears summed gradients from memory.""" - del self.module.storage[AGGREGATED_GRADIENT_NAME] - self.module.storage[AGGREGATED_GRADIENT_NAME] = None - self.clear_all_cache() + def exist(self) -> bool: + """Checks if aggregated gradient is available.""" + return self.module.storage[AGGREGATED_GRADIENT_NAME] is not None def synchronize(self, num_processes: int = 1) -> None: """Aggregates summed gradient across multiple devices or nodes in a distributed setting.""" @@ -124,3 +94,8 @@ def synchronize(self, num_processes: int = 1) -> None: tensor=self.module.storage[AGGREGATED_GRADIENT_NAME], op=dist.ReduceOp.SUM, ) + + def release_memory(self) -> None: + """Clears aggregated gradients from memory.""" + self.clear_all_cache() + self.module.storage[AGGREGATED_GRADIENT_NAME] = None diff --git a/kronfluence/module/tracker/pairwise_score.py b/kronfluence/module/tracker/pairwise_score.py index 029afc4..a8c400c 100644 --- a/kronfluence/module/tracker/pairwise_score.py +++ b/kronfluence/module/tracker/pairwise_score.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import Tuple import torch import torch.nn as nn @@ -23,8 
+23,9 @@ def _compute_pairwise_score_with_gradient(self, per_sample_gradient: torch.Tenso per_sample_gradient (torch.Tensor): The per-sample-gradient tensor for the given batch. """ - if isinstance(self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME], list): - left_mat, right_mat = self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] + precondition_name = ACCUMULATED_PRECONDITIONED_GRADIENT_NAME + if isinstance(self.module.storage[precondition_name], list): + left_mat, right_mat = self.module.storage[precondition_name] if self.module.einsum_expression is None: self.module.einsum_expression = contract_expression( "qki,toi,qok->qt", @@ -35,85 +36,81 @@ def _compute_pairwise_score_with_gradient(self, per_sample_gradient: torch.Tenso search_outer=True, minimize="size" if self.module.score_args.einsum_minimize_size else "flops" ), ) - self.module.storage[PAIRWISE_SCORE_MATRIX_NAME] = self.module.einsum_expression( - right_mat, per_sample_gradient, left_mat - ) + scores = self.module.einsum_expression(right_mat, per_sample_gradient, left_mat) else: - self.module.storage[PAIRWISE_SCORE_MATRIX_NAME] = torch.einsum( + scores = torch.einsum( "qio,tio->qt", - self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME], + self.module.storage[precondition_name], per_sample_gradient, ) + if self.module.storage[PAIRWISE_SCORE_MATRIX_NAME] is not None: + self.module.storage[PAIRWISE_SCORE_MATRIX_NAME].add_(scores) + else: + self.module.storage[PAIRWISE_SCORE_MATRIX_NAME] = scores + def register_hooks(self) -> None: """Sets up hooks to compute pairwise influence scores.""" - @torch.no_grad() def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: del module - cached_activation = inputs[0].detach() - device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device - cached_activation = cached_activation.to( - device=device, - dtype=self.module.score_args.score_dtype, - copy=True, - ) - - if self.module.factor_args.has_shared_parameters: - if self.cached_activations is None: - self.cached_activations = [] - self.cached_activations.append(cached_activation) - else: - self.cached_activations = cached_activation + with torch.no_grad(): + cached_activation = inputs[0].detach() + device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device + cached_activation = cached_activation.to( + device=device, + dtype=self.module.score_args.score_dtype, + copy=True, + ) + if self.module.factor_args.has_shared_parameters: + if self.cached_activations is None: + self.cached_activations = [] + self.cached_activations.append(cached_activation) + else: + self.cached_activations = cached_activation - outputs.register_hook( - shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook - ) + self.cached_hooks.append(outputs.register_hook(backward_hook)) @torch.no_grad() def backward_hook(output_gradient: torch.Tensor) -> None: if self.cached_activations is None: self._raise_cache_not_found_exception() - - output_gradient = self._scale_output_gradient( - output_gradient=output_gradient, target_dtype=self.module.score_args.score_dtype - ) + handle = self.cached_hooks.pop() + handle.remove() + original_dtype = output_gradient.dtype + target_dtype = self.module.score_args.score_dtype + output_gradient = output_gradient.detach().to(dtype=target_dtype) + if self.module.gradient_scale != 1.0: + if original_dtype != target_dtype: + output_gradient.mul_(self.module.gradient_scale) + 
else: + output_gradient = output_gradient * self.module.gradient_scale + + if isinstance(self.cached_activations, list): + cached_activation = self.cached_activations.pop() + else: + cached_activation = self.cached_activations + # Computes pairwise influence scores during backward pass. if self.module.per_sample_gradient_process_fnc is None: self.module.storage[PAIRWISE_SCORE_MATRIX_NAME] = self.module.compute_pairwise_score( preconditioned_gradient=self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME], - input_activation=self.cached_activations.to(device=output_gradient.device), + input_activation=cached_activation.to(device=output_gradient.device), output_gradient=output_gradient, ) self.clear_all_cache() else: per_sample_gradient = self.module.compute_per_sample_gradient( - input_activation=self.cached_activations.to(device=output_gradient.device), + input_activation=cached_activation.to(device=output_gradient.device), output_gradient=output_gradient, ) self.clear_all_cache() self._compute_pairwise_score_with_gradient(per_sample_gradient=per_sample_gradient) - @torch.no_grad() - def shared_backward_hook(output_gradient: torch.Tensor) -> None: - output_gradient = self._scale_output_gradient( - output_gradient=output_gradient, target_dtype=self.module.score_args.score_dtype - ) - cached_activation = self.cached_activations.pop() - per_sample_gradient = self.module.compute_per_sample_gradient( - input_activation=cached_activation.to(device=output_gradient.device), - output_gradient=output_gradient, - ) - if self.cached_per_sample_gradient is None: - self.cached_per_sample_gradient = torch.zeros_like(per_sample_gradient, requires_grad=False) - self.cached_per_sample_gradient.add_(per_sample_gradient) - - self.registered_hooks.append(self.module.original_module.register_forward_hook(forward_hook)) + self.registered_hooks.append(self.module.register_forward_hook(forward_hook)) @torch.no_grad() def finalize_iteration(self) -> None: - """Computes pairwise influence scores using cached per-sample gradients.""" - if self.module.factor_args.has_shared_parameters: - self._compute_pairwise_score_with_gradient(per_sample_gradient=self.cached_per_sample_gradient) + """Clears all cached activations from memory.""" self.clear_all_cache() def exist(self) -> bool: @@ -134,10 +131,6 @@ def finalize_all_iterations(self) -> None: self._compute_pairwise_score_with_gradient( per_sample_gradient=self.module.storage[AGGREGATED_GRADIENT_NAME] ) - del ( - self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME], - self.module.storage[PRECONDITIONED_GRADIENT_NAME], - ) self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = None self.module.storage[PRECONDITIONED_GRADIENT_NAME] = None self.clear_all_cache() @@ -145,5 +138,4 @@ def finalize_all_iterations(self) -> None: def release_memory(self) -> None: """Releases pairwise scores from memory.""" self.clear_all_cache() - del self.module.storage[PAIRWISE_SCORE_MATRIX_NAME] self.module.storage[PAIRWISE_SCORE_MATRIX_NAME] = None diff --git a/kronfluence/module/tracker/precondition.py b/kronfluence/module/tracker/precondition.py index f9ccd60..a566bfe 100644 --- a/kronfluence/module/tracker/precondition.py +++ b/kronfluence/module/tracker/precondition.py @@ -83,48 +83,63 @@ def _compute_preconditioned_gradient(self, per_sample_gradient: torch.Tensor) -> def register_hooks(self) -> None: """Sets up hooks to compute preconditioned per-sample gradient.""" - @torch.no_grad() def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: 
torch.Tensor) -> None: del module - cached_activation = inputs[0].detach() - device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device - cached_activation = cached_activation.to( - device=device, - dtype=self.module.score_args.per_sample_gradient_dtype, - copy=True, - ) - - if self.module.factor_args.has_shared_parameters: - if self.cached_activations is None: - self.cached_activations = [] - self.cached_activations.append(cached_activation) - else: - self.cached_activations = cached_activation + with torch.no_grad(): + cached_activation = inputs[0].detach() + device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device + cached_activation = cached_activation.to( + device=device, + dtype=self.module.score_args.per_sample_gradient_dtype, + copy=True, + ) + if self.module.factor_args.has_shared_parameters: + if self.cached_activations is None: + self.cached_activations = [] + self.cached_activations.append(cached_activation) + else: + self.cached_activations = cached_activation - outputs.register_hook( - shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook + self.cached_hooks.append( + outputs.register_hook( + shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook + ) ) @torch.no_grad() def backward_hook(output_gradient: torch.Tensor) -> None: if self.cached_activations is None: self._raise_cache_not_found_exception() - - output_gradient = self._scale_output_gradient( - output_gradient=output_gradient, target_dtype=self.module.score_args.per_sample_gradient_dtype - ) + handle = self.cached_hooks.pop() + handle.remove() + original_dtype = output_gradient.dtype + target_dtype = self.module.score_args.per_sample_gradient_dtype + output_gradient = output_gradient.detach().to(dtype=target_dtype) + if self.module.gradient_scale != 1.0: + if original_dtype != target_dtype: + output_gradient.mul_(self.module.gradient_scale) + else: + output_gradient = output_gradient * self.module.gradient_scale per_sample_gradient = self.module.compute_per_sample_gradient( input_activation=self.cached_activations.to(device=output_gradient.device), output_gradient=output_gradient, ).to(dtype=self.module.score_args.precondition_dtype) self.clear_all_cache() + # Computes preconditioned per-sample gradient during backward pass. 
self._compute_preconditioned_gradient(per_sample_gradient=per_sample_gradient) @torch.no_grad() def shared_backward_hook(output_gradient: torch.Tensor) -> None: - output_gradient = self._scale_output_gradient( - output_gradient=output_gradient, target_dtype=self.module.score_args.per_sample_gradient_dtype - ) + handle = self.cached_hooks.pop() + handle.remove() + original_dtype = output_gradient.dtype + target_dtype = self.module.score_args.per_sample_gradient_dtype + output_gradient = output_gradient.detach().to(dtype=target_dtype) + if self.module.gradient_scale != 1.0: + if original_dtype != target_dtype: + output_gradient.mul_(self.module.gradient_scale) + else: + output_gradient = output_gradient * self.module.gradient_scale cached_activation = self.cached_activations.pop() per_sample_gradient = self.module.compute_per_sample_gradient( input_activation=cached_activation.to(device=output_gradient.device), @@ -132,9 +147,10 @@ def shared_backward_hook(output_gradient: torch.Tensor) -> None: ) if self.cached_per_sample_gradient is None: self.cached_per_sample_gradient = torch.zeros_like(per_sample_gradient, requires_grad=False) + # Aggregates per-sample gradients during backward pass. self.cached_per_sample_gradient.add_(per_sample_gradient) - self.registered_hooks.append(self.module.original_module.register_forward_hook(forward_hook)) + self.registered_hooks.append(self.module.register_forward_hook(forward_hook)) @torch.no_grad() def finalize_iteration(self) -> None: @@ -148,7 +164,10 @@ def finalize_iteration(self) -> None: def exist(self) -> bool: """Checks if preconditioned gradient is available.""" - return self.module.storage[PRECONDITIONED_GRADIENT_NAME] is not None + return ( + self.module.storage[PRECONDITIONED_GRADIENT_NAME] is not None + or self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] is not None + ) def synchronize(self, num_processes: int = 1) -> None: """Stacks preconditioned gradient across multiple devices or nodes in a distributed setting.""" @@ -223,7 +242,7 @@ def accumulate_iterations(self) -> None: self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = torch.cat( (accumulated_gradient, gradient), dim=0 ).contiguous() - del self.module.storage[PRECONDITIONED_GRADIENT_NAME] + del gradient, self.module.storage[PRECONDITIONED_GRADIENT_NAME] self.module.storage[PRECONDITIONED_GRADIENT_NAME] = None def finalize_all_iterations(self) -> None: @@ -233,16 +252,11 @@ def finalize_all_iterations(self) -> None: dtype=self.module.score_args.precondition_dtype ) self._compute_preconditioned_gradient(per_sample_gradient=self.module.storage[AGGREGATED_GRADIENT_NAME]) - del self.module.storage[AGGREGATED_GRADIENT_NAME] self.module.storage[AGGREGATED_GRADIENT_NAME] = None self.accumulate_iterations() def release_memory(self) -> None: """Clears preconditioned gradients from memory.""" - del ( - self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME], - self.module.storage[PRECONDITIONED_GRADIENT_NAME], - ) self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = None self.module.storage[PRECONDITIONED_GRADIENT_NAME] = None self.clear_all_cache() diff --git a/kronfluence/module/tracker/self_score.py b/kronfluence/module/tracker/self_score.py index 1d84175..a08abd6 100644 --- a/kronfluence/module/tracker/self_score.py +++ b/kronfluence/module/tracker/self_score.py @@ -60,36 +60,41 @@ def _compute_self_score(self, per_sample_gradient: torch.Tensor) -> None: def register_hooks(self) -> None: """Sets up hooks to compute self-influence scores.""" - 
@torch.no_grad() def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: del module - cached_activation = inputs[0].detach() - device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device - cached_activation = cached_activation.to( - device=device, - dtype=self.module.score_args.per_sample_gradient_dtype, - copy=True, - ) - - if self.module.factor_args.has_shared_parameters: - if self.cached_activations is None: - self.cached_activations = [] - self.cached_activations.append(cached_activation) - else: - self.cached_activations = cached_activation - - outputs.register_hook( + with torch.no_grad(): + cached_activation = inputs[0].detach() + device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device + cached_activation = cached_activation.to( + device=device, + dtype=self.module.score_args.per_sample_gradient_dtype, + copy=True, + ) + if self.module.factor_args.has_shared_parameters: + if self.cached_activations is None: + self.cached_activations = [] + self.cached_activations.append(cached_activation) + else: + self.cached_activations = cached_activation + + self.cached_hooks.append(outputs.register_hook( shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook - ) + )) @torch.no_grad() def backward_hook(output_gradient: torch.Tensor) -> None: if self.cached_activations is None: self._raise_cache_not_found_exception() - - output_gradient = self._scale_output_gradient( - output_gradient=output_gradient, target_dtype=self.module.score_args.per_sample_gradient_dtype - ) + handle = self.cached_hooks.pop() + handle.remove() + original_dtype = output_gradient.dtype + target_dtype = self.module.score_args.per_sample_gradient_dtype + output_gradient = output_gradient.detach().to(dtype=target_dtype) + if self.module.gradient_scale != 1.0: + if original_dtype != target_dtype: + output_gradient.mul_(self.module.gradient_scale) + else: + output_gradient = output_gradient * self.module.gradient_scale per_sample_gradient = self.module.compute_per_sample_gradient( input_activation=self.cached_activations.to(device=output_gradient.device), output_gradient=output_gradient, @@ -99,9 +104,16 @@ def backward_hook(output_gradient: torch.Tensor) -> None: @torch.no_grad() def shared_backward_hook(output_gradient: torch.Tensor) -> None: - output_gradient = self._scale_output_gradient( - output_gradient=output_gradient, target_dtype=self.module.score_args.per_sample_gradient_dtype - ) + handle = self.cached_hooks.pop() + handle.remove() + original_dtype = output_gradient.dtype + target_dtype = self.module.score_args.per_sample_gradient_dtype + output_gradient = output_gradient.detach().to(dtype=target_dtype) + if self.module.gradient_scale != 1.0: + if original_dtype != target_dtype: + output_gradient.mul_(self.module.gradient_scale) + else: + output_gradient = output_gradient * self.module.gradient_scale cached_activation = self.cached_activations.pop() per_sample_gradient = self.module.compute_per_sample_gradient( input_activation=cached_activation.to(device=output_gradient.device), @@ -111,7 +123,7 @@ def shared_backward_hook(output_gradient: torch.Tensor) -> None: self.cached_per_sample_gradient = torch.zeros_like(per_sample_gradient, requires_grad=False) self.cached_per_sample_gradient.add_(per_sample_gradient) - self.registered_hooks.append(self.module.original_module.register_forward_hook(forward_hook)) + 
self.registered_hooks.append(self.module.register_forward_hook(forward_hook)) @torch.no_grad() def finalize_iteration(self) -> None: @@ -128,12 +140,13 @@ def exist(self) -> bool: return self.module.storage[SELF_SCORE_VECTOR_NAME] is not None def accumulate_iterations(self) -> None: - """Removes self-scores from memory after a single iteration.""" + """Removes self-influence scores from memory after a single iteration.""" self.release_memory() def release_memory(self) -> None: - """Releases pairwise scores from memory.""" + """Releases self-influence scores from memory.""" self.clear_all_cache() + self.storage_at_device = False del self.module.storage[SELF_SCORE_VECTOR_NAME] self.module.storage[SELF_SCORE_VECTOR_NAME] = None @@ -144,42 +157,41 @@ class SelfScoreWithMeasurementTracker(BaseTracker): storage_at_device: bool = False def _compute_self_measurement_score_with_gradient(self, per_sample_gradient: torch.Tensor) -> None: - """Computes self-influence scores using per-sample-gradients. + """Computes self-influence scores with measurement using per-sample-gradients. Args: per_sample_gradient (torch.Tensor): The per-sample-gradient tensor for the given batch. """ - self.module.storage[SELF_SCORE_VECTOR_NAME] = per_sample_gradient.mul_( + scores = per_sample_gradient.mul_( self.module.storage[PRECONDITIONED_GRADIENT_NAME] ).sum(dim=(1, 2)) - del self.module.storage[PRECONDITIONED_GRADIENT_NAME] self.module.storage[PRECONDITIONED_GRADIENT_NAME] = None + if self.module.storage[SELF_SCORE_VECTOR_NAME] is None: + self.module.storage[SELF_SCORE_VECTOR_NAME] = scores + else: + self.module.storage[SELF_SCORE_VECTOR_NAME].add_(scores) def register_hooks(self) -> None: """Sets up hooks to compute pairwise influence scores.""" - @torch.no_grad() def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: del module - cached_activation = inputs[0].detach() - device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device - cached_activation = cached_activation.to( - device=device, - dtype=self.module.score_args.score_dtype, - copy=True, - ) - - if self.module.factor_args.has_shared_parameters: - if self.cached_activations is None: - self.cached_activations = [] - self.cached_activations.append(cached_activation) - else: - self.cached_activations = cached_activation - - outputs.register_hook( - shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook - ) + with torch.no_grad(): + cached_activation = inputs[0].detach() + device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device + cached_activation = cached_activation.to( + device=device, + dtype=self.module.score_args.score_dtype, + copy=True, + ) + if self.module.factor_args.has_shared_parameters: + if self.cached_activations is None: + self.cached_activations = [] + self.cached_activations.append(cached_activation) + else: + self.cached_activations = cached_activation + self.cached_hooks.append(outputs.register_hook(backward_hook)) @torch.no_grad() def backward_hook(output_gradient: torch.Tensor) -> None: @@ -192,22 +204,37 @@ def backward_hook(output_gradient: torch.Tensor) -> None: target_device=output_gradient.device, ) self.storage_at_device = True + handle = self.cached_hooks.pop() + handle.remove() + original_dtype = output_gradient.dtype + target_dtype = self.module.score_args.score_dtype + output_gradient = output_gradient.detach().to(dtype=target_dtype) + if self.module.gradient_scale != 1.0: + if 
original_dtype != target_dtype: + output_gradient.mul_(self.module.gradient_scale) + else: + output_gradient = output_gradient * self.module.gradient_scale + + if isinstance(self.cached_activations, list): + cached_activation = self.cached_activations.pop() + else: + cached_activation = self.cached_activations - output_gradient = self._scale_output_gradient( - output_gradient=output_gradient, target_dtype=self.module.score_args.score_dtype - ) if self.module.per_sample_gradient_process_fnc is None: - self.module.storage[SELF_SCORE_VECTOR_NAME] = self.module.compute_self_measurement_score( + scores = self.module.compute_self_measurement_score( preconditioned_gradient=self.module.storage[PRECONDITIONED_GRADIENT_NAME], - input_activation=self.cached_activations.to(device=output_gradient.device), + input_activation=cached_activation.to(device=output_gradient.device), output_gradient=output_gradient, ) - del self.module.storage[PRECONDITIONED_GRADIENT_NAME] self.module.storage[PRECONDITIONED_GRADIENT_NAME] = None self.clear_all_cache() + if self.module.storage[SELF_SCORE_VECTOR_NAME] is None: + self.module.storage[SELF_SCORE_VECTOR_NAME] = scores + else: + self.module.storage[SELF_SCORE_VECTOR_NAME].add_(scores) else: per_sample_gradient = self.module.compute_per_sample_gradient( - input_activation=self.cached_activations.to(device=output_gradient.device), + input_activation=cached_activation.to(device=output_gradient.device), output_gradient=output_gradient, ) self.clear_all_cache() @@ -227,25 +254,24 @@ def shared_backward_hook(output_gradient: torch.Tensor) -> None: self.cached_per_sample_gradient = torch.zeros_like(per_sample_gradient, requires_grad=False) self.cached_per_sample_gradient.add_(per_sample_gradient) - self.registered_hooks.append(self.module.original_module.register_forward_hook(forward_hook)) + self.registered_hooks.append(self.module.register_forward_hook(forward_hook)) @torch.no_grad() def finalize_iteration(self) -> None: - """Computes pairwise influence scores using cached per-sample gradients.""" - if self.module.factor_args.has_shared_parameters: - self._compute_self_measurement_score_with_gradient(per_sample_gradient=self.cached_per_sample_gradient) + """Removes all cached activations from memory.""" self.clear_all_cache() def exist(self) -> bool: - """Checks if pairwise score is available.""" + """Checks if self-influence score is available.""" return self.module.storage[SELF_SCORE_VECTOR_NAME] is not None def accumulate_iterations(self) -> None: - """Removes pairwise scores from memory after a single iteration.""" + """Removes self-influence scores from memory after a single iteration.""" self.release_memory() def release_memory(self) -> None: - """Releases pairwise scores from memory.""" + """Releases self-influence scores from memory.""" self.clear_all_cache() + self.storage_at_device = False del self.module.storage[SELF_SCORE_VECTOR_NAME] self.module.storage[SELF_SCORE_VECTOR_NAME] = None diff --git a/kronfluence/score/dot_product.py b/kronfluence/score/dot_product.py index 23c1774..e401a12 100644 --- a/kronfluence/score/dot_product.py +++ b/kronfluence/score/dot_product.py @@ -66,6 +66,11 @@ def compute_dot_products_with_loader( else: score_chunks[ALL_MODULE_NAME] = [] + cached_module_lst = [] + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + cached_module_lst.append(module) + total_steps = 0 enable_amp = score_args.amp_dtype is not None @@ -92,31 +97,30 @@ def compute_dot_products_with_loader( 
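The tracker hunks above all follow the same pattern: the forward hook caches the detached input activation and registers a one-shot hook on the output tensor, and the backward hook pops and removes its own handle before computing per-sample gradients. A standalone sketch of that pattern on a toy linear layer (not the library's API; names and shapes are made up):

import torch
import torch.nn as nn

module = nn.Linear(4, 3)
cached = {}
hook_handles = []

def forward_hook(mod, inputs, output):
    # Cache the detached input activation, mirroring the trackers above.
    with torch.no_grad():
        cached["activation"] = inputs[0].detach().clone()
    # Register the backward hook on the output *tensor* and keep its handle.
    hook_handles.append(output.register_hook(backward_hook))

@torch.no_grad()
def backward_hook(output_gradient):
    hook_handles.pop().remove()  # one-shot: detach the tensor hook after it fires
    # Per-sample gradient of the linear weight: outer product of the output
    # gradient and the cached input activation for each batch element.
    cached["per_sample_grad"] = torch.einsum(
        "bo,bi->boi", output_gradient, cached["activation"]
    )

module.register_forward_hook(forward_hook)
x = torch.randn(5, 4)
module(x).sum().backward()
print(cached["per_sample_grad"].shape)  # torch.Size([5, 3, 4])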
finalize_iteration(model=model, tracked_module_names=tracked_module_names) if score_args.compute_per_module_scores: - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - score_chunks[module.name].append( - module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME).clone().cpu() - ) + for module in cached_module_lst: + score_chunks[module.name].append( + module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME).clone().cpu() + ) else: pairwise_scores = None - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - if pairwise_scores is None: - pairwise_scores = torch.zeros_like( - module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME), requires_grad=False - ) - try: - pairwise_scores.add_(module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME)) - except RuntimeError: - if score_args.compute_per_token_scores: - raise RuntimeError(DIMENSION_NOT_MATCH_ERROR_MSG) - raise + for module in cached_module_lst: + if pairwise_scores is None: + pairwise_scores = torch.zeros_like( + module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME), requires_grad=False + ) + try: + pairwise_scores.add_(module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME)) + except RuntimeError: + if score_args.compute_per_token_scores: + raise RuntimeError(DIMENSION_NOT_MATCH_ERROR_MSG) + raise score_chunks[ALL_MODULE_NAME].append(pairwise_scores.cpu()) accumulate_iterations(model=model, tracked_module_names=tracked_module_names) if state.use_distributed and total_steps % DISTRIBUTED_SYNC_INTERVAL == 0: state.wait_for_everyone() + del batch, loss total_steps += 1 pbar.update(1) @@ -200,6 +204,7 @@ def compute_aggregated_dot_products_with_loader( if factor_args.has_shared_parameters: finalize_iteration(model=model, tracked_module_names=tracked_module_names) + del batch, loss pbar.update(1) if state.use_distributed: diff --git a/kronfluence/score/pairwise.py b/kronfluence/score/pairwise.py index b391c4b..2a993cb 100644 --- a/kronfluence/score/pairwise.py +++ b/kronfluence/score/pairwise.py @@ -43,7 +43,7 @@ def pairwise_scores_save_path( Args: output_dir (Path): - Directory to save the matrices. + Directory to save or load the matrices. partition (PARTITION_TYPE, optional): Partition information, if any. @@ -97,7 +97,7 @@ def load_pairwise_scores( Partition information, if any. Returns: - FACTOR_TYPE: + SCORE_TYPE: Dictionary of loaded scores. """ save_path = pairwise_scores_save_path( @@ -190,9 +190,7 @@ def compute_pairwise_scores_with_loaders( model=model, factor_name=name, factors=loaded_factors[name], - clone=True, ) - del loaded_factors prepare_modules(model=model, tracked_module_names=tracked_module_names, device=state.device) total_scores_chunks: Dict[str, Union[List[torch.Tensor], torch.Tensor]] = {} @@ -245,6 +243,7 @@ def compute_pairwise_scores_with_loaders( # Removes duplicate data points if the dataset is not evenly divisible by the current batch size. 
truncate(model=model, tracked_module_names=tracked_module_names, keep_size=query_remainder) accumulate_iterations(model=model, tracked_module_names=tracked_module_names) + del query_batch, measurement num_accumulations += 1 if ( @@ -321,7 +320,6 @@ compute_pairwise_query_aggregated_scores_with_loaders( for name in loaded_factors: set_factors(model=model, factor_name=name, factors=loaded_factors[name]) prepare_modules(model=model, tracked_module_names=tracked_module_names, device=state.device) - release_memory() enable_amp = score_args.amp_dtype is not None scaler = GradScaler(enabled=enable_amp) @@ -356,6 +354,7 @@ compute_pairwise_query_aggregated_scores_with_loaders( if factor_args.has_shared_parameters: finalize_iteration(model=model, tracked_module_names=tracked_module_names) + del query_batch, measurement pbar.update(1) if state.use_distributed: diff --git a/kronfluence/score/self.py b/kronfluence/score/self.py index 40c6af7..e501b17 100644 --- a/kronfluence/score/self.py +++ b/kronfluence/score/self.py @@ -44,7 +44,18 @@ def self_scores_save_path( output_dir: Path, partition: Optional[PARTITION_TYPE] = None, ) -> Path: - """Generates the path for saving/loading self-influence scores.""" + """Generates the path for saving or loading self-influence scores. + + Args: + output_dir (Path): + Directory to save or load the matrices. + partition (PARTITION_TYPE, optional): + Partition information, if any. + + Returns: + Path: + The full path for the score file. + """ if partition is not None: data_partition, module_partition = partition return output_dir / ( @@ -53,25 +64,24 @@ return output_dir / "self_scores.safetensors" -def self_scores_exist( - output_dir: Path, - partition: Optional[PARTITION_TYPE] = None, -) -> bool: - """Checks if the self-influence scores exist at the specified path.""" - save_path = self_scores_save_path( - output_dir=output_dir, - partition=partition, - ) - return save_path.exists() - - def save_self_scores( output_dir: Path, scores: SCORE_TYPE, partition: Optional[PARTITION_TYPE] = None, metadata: Optional[Dict[str, str]] = None, ) -> None: - """Saves self-influence scores to disk.""" + """Saves self-influence scores to disk. + + Args: + output_dir (Path): + Directory to save the scores. + scores (SCORE_TYPE): + Dictionary of scores to save. + partition (PARTITION_TYPE, optional): + Partition information, if any. + metadata (Dict[str, str], optional): + Additional metadata to save with the scores. + """ save_path = self_scores_save_path( output_dir=output_dir, partition=partition, @@ -83,7 +93,18 @@ def load_self_scores( output_dir: Path, partition: Optional[PARTITION_TYPE] = None, ) -> Dict[str, torch.Tensor]: - """Loads self-influence scores from disk.""" + """Loads self-influence scores from disk. + + Args: + output_dir (Path): + Directory to load the scores from. + partition (PARTITION_TYPE, optional): + Partition information, if any. + + Returns: + SCORE_TYPE: + Dictionary of loaded scores. + """ save_path = self_scores_save_path( output_dir=output_dir, partition=partition, @@ -91,6 +112,29 @@ return load_file(filename=save_path) +def self_scores_exist( + output_dir: Path, + partition: Optional[PARTITION_TYPE] = None, +) -> bool: + """Checks if self-influence scores exist at the specified directory. + + Args: + output_dir (Path): + Directory to check for scores. + partition (PARTITION_TYPE, optional): + Partition information, if any. + + Returns: + bool: + `True` if scores exist, `False` otherwise.
+ """ + save_path = self_scores_save_path( + output_dir=output_dir, + partition=partition, + ) + return save_path.exists() + + def compute_self_scores_with_loaders( loaded_factors: FACTOR_TYPE, model: nn.Module, @@ -106,7 +150,7 @@ def compute_self_scores_with_loaders( Args: loaded_factors (FACTOR_TYPE): - The factor results to load from, before computing the self-influence scores. + Computed factors. model (nn.Module): The model for which self-influence scores will be computed. state (State): @@ -116,14 +160,14 @@ def compute_self_scores_with_loaders( train_loader (data.DataLoader): The data loader that will be used to compute training gradients. score_args (ScoreArguments): - Arguments related to computing self-influence scores. + Arguments for computing self-influence scores. factor_args (FactorArguments): - Arguments related to computing preconditioning factors. + Arguments used to compute factors. tracked_module_names (List[str], optional): A list of module names that self-influence scores will be computed. If not specified, scores will be computed for all available tracked modules. disable_tqdm (bool, optional): - Disables TQDM progress bars. Defaults to False. + Whether to disable the progress bar. Defaults to `False`. Returns: Dict[str, torch.Tensor]: @@ -141,8 +185,7 @@ def compute_self_scores_with_loaders( ) if len(loaded_factors) > 0: for name in loaded_factors: - set_factors(model=model, factor_name=name, factors=loaded_factors[name], clone=True) - del loaded_factors + set_factors(model=model, factor_name=name, factors=loaded_factors[name]) prepare_modules(model=model, tracked_module_names=tracked_module_names, device=state.device) dataset_size = len(train_loader.dataset) @@ -154,6 +197,11 @@ def compute_self_scores_with_loaders( else: score_chunks[ALL_MODULE_NAME] = [] + cached_module_lst = [] + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + cached_module_lst.append(module) + total_steps = 0 enable_amp = score_args.amp_dtype is not None scaler = GradScaler(enabled=enable_amp) @@ -187,26 +235,25 @@ def compute_self_scores_with_loaders( finalize_iteration(model=model, tracked_module_names=tracked_module_names) if score_args.compute_per_module_scores: - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - score_chunks[module.name].append( - module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME).clone().cpu() - ) + for module in cached_module_lst: + score_chunks[module.name].append( + module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME).clone().cpu() + ) else: self_scores = None - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - if self_scores is None: - self_scores = torch.zeros_like( - module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME), requires_grad=False - ) - self_scores.add_(module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME)) + for module in cached_module_lst: + if self_scores is None: + self_scores = torch.zeros_like( + module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME), requires_grad=False + ) + self_scores.add_(module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME)) score_chunks[ALL_MODULE_NAME].append(self_scores.cpu()) accumulate_iterations(model=model, tracked_module_names=tracked_module_names) if state.use_distributed and total_steps % DISTRIBUTED_SYNC_INTERVAL == 0: state.wait_for_everyone() + del batch, loss total_steps += 1 pbar.update(1) @@ -260,9 +307,7 @@ def 
compute_self_measurement_scores_with_loaders( model=model, factor_name=name, factors=loaded_factors[name], - clone=True, ) - del loaded_factors prepare_modules(model=model, tracked_module_names=tracked_module_names, device=state.device) dataset_size = len(train_loader.dataset) @@ -274,6 +319,11 @@ def compute_self_measurement_scores_with_loaders( else: score_chunks[ALL_MODULE_NAME] = [] + cached_module_lst = [] + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + cached_module_lst.append(module) + total_steps = 0 enable_amp = score_args.amp_dtype is not None scaler = GradScaler(enabled=enable_amp) @@ -307,6 +357,7 @@ def compute_self_measurement_scores_with_loaders( if factor_args.has_shared_parameters: finalize_iteration(model=model, tracked_module_names=tracked_module_names) + del measurement set_mode( model=model, @@ -326,22 +377,21 @@ def compute_self_measurement_scores_with_loaders( if factor_args.has_shared_parameters: finalize_iteration(model=model, tracked_module_names=tracked_module_names) + del batch, loss if score_args.compute_per_module_scores: - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - score_chunks[module.name].append( - module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME).clone().cpu() - ) + for module in cached_module_lst: + score_chunks[module.name].append( + module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME).clone().cpu() + ) else: self_scores = None - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - if self_scores is None: - self_scores = torch.zeros_like( - module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME), requires_grad=False - ) - self_scores.add_(module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME)) + for module in cached_module_lst: + if self_scores is None: + self_scores = torch.zeros_like( + module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME), requires_grad=False + ) + self_scores.add_(module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME)) score_chunks[ALL_MODULE_NAME].append(self_scores.cpu()) accumulate_iterations(model=model, tracked_module_names=tracked_module_names) diff --git a/tests/modules/test_matmul.py b/tests/modules/test_matmul.py new file mode 100644 index 0000000..e69de29
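The patch adds tests/modules/test_matmul.py as an empty file. Purely as an illustration of the kind of equivalence check such a test could hold (not part of the patch; shapes are made up), a pytest-style comparison of opt_einsum's contract against torch.einsum on the pairwise-score contraction used above:

import torch
from opt_einsum import contract


def test_pairwise_score_contraction_matches_einsum() -> None:
    # Hypothetical sizes: query batch, train batch, input and output features.
    query_size, train_size, in_features, out_features = 3, 5, 7, 11
    preconditioned = torch.randn(query_size, in_features, out_features, dtype=torch.float64)
    per_sample = torch.randn(train_size, in_features, out_features, dtype=torch.float64)

    # The "qio,tio->qt" contraction mirrors the pairwise-score einsum in the trackers.
    expected = torch.einsum("qio,tio->qt", preconditioned, per_sample)
    actual = contract("qio,tio->qt", preconditioned, per_sample)

    assert torch.allclose(expected, actual)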