diff --git a/examples/mnist/compute_influences.py b/examples/mnist/compute_influences.py index 9046ad2..6405f35 100644 --- a/examples/mnist/compute_influences.py +++ b/examples/mnist/compute_influences.py @@ -20,6 +20,7 @@ parser.add_argument("--hessian", type=str, default="none") parser.add_argument("--lora", type=str, default="none") parser.add_argument("--save", type=str, default="grad") +parser.add_argument("--flatten", action="store_true") args = parser.parse_args() DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -56,7 +57,9 @@ logix.finalize() # Influence Analysis -log_loader = logix.build_log_dataloader(batch_size=64, num_workers=0) +log_loader = logix.build_log_dataloader( + batch_size=64, num_workers=0, flatten=args.flatten +) # logix.add_analysis({"influence": InfluenceFunction}) logix.setup({"log": "grad"}) diff --git a/logix/analysis/influence_function.py b/logix/analysis/influence_function.py index ba2fe11..7526221 100644 --- a/logix/analysis/influence_function.py +++ b/logix/analysis/influence_function.py @@ -116,7 +116,7 @@ def compute_influence( log=src, path=self._state.get_state("model_module")["path"] ) tgt = tgt.to(device=src.device) - total_influence += cross_dot_product(src, tgt) + total_influence["total"] += cross_dot_product(src, tgt) else: synchronize_device(src, tgt) for module_name in src.keys(): @@ -158,8 +158,8 @@ def compute_influence( assert value.shape[1] == len(tgt_ids) total_influence[key] = value.cpu() - result["src_ids"] = src_ids - result["tgt_ids"] = tgt_ids + result["src_ids"] = list(src_ids) + result["tgt_ids"] = list(tgt_ids) result["influence"] = ( total_influence.pop("total") if influence_groups is None @@ -266,16 +266,9 @@ def compute_influence_all( influence_groups=influence_groups, damping=damping, ) - - # if result_all is None: - # result_all = result - # else: merge_influence_results(result_all, result, axis="tgt") if save: - # if self.influence_scores is None: - # self.influence_scores = result_all - # else: merge_influence_results(self.influence_scores, result_all, axis="src") return result_all diff --git a/logix/analysis/influence_function_utils.py b/logix/analysis/influence_function_utils.py index c5120e3..a7351af 100644 --- a/logix/analysis/influence_function_utils.py +++ b/logix/analysis/influence_function_utils.py @@ -60,7 +60,7 @@ def precondition_raw( damping: Optional[float] = None, ) -> Dict[str, Dict[str, torch.Tensor]]: preconditioned = nested_dict() - cov_inverse = state.get_covariance_inverse_state() + cov_inverse = state.get_covariance_inverse_state(damping=damping) for module_name in src.keys(): device = src[module_name]["grad"].device grad_cov_inverse = cov_inverse[module_name]["grad"].to(device=device) diff --git a/logix/state.py b/logix/state.py index 3690758..a9877c7 100644 --- a/logix/state.py +++ b/logix/state.py @@ -1,5 +1,6 @@ import logging import os +from typing import Optional, Dict, Tuple import torch import torch.distributed as dist @@ -84,38 +85,41 @@ def covariance_svd(self) -> None: ) @torch.no_grad() - def covariance_inverse(self, set_attr: bool = False) -> None: + def covariance_inverse(self, damping: Optional[float] = None) -> None: """ Compute the inverse of the covariance. """ self.register_state("covariance_inverse_state", save=True) for module_name, module_state in self.covariance_state.items(): - for mode, covariance in module_state.items(): + for mode, cov in module_state.items(): + damping_module = ( + 0.1 * torch.trace(cov) / cov.size(0) if damping is None else damping + ) self.covariance_inverse_state[module_name][mode] = torch.inverse( - covariance - + 0.1 - * torch.trace(covariance) - * torch.eye(covariance.shape[0]).to(device=covariance.device) - / covariance.shape[0] + cov + damping_module * torch.eye(cov.size(0)).to(device=cov.device) ) - def get_covariance_state(self): + def get_covariance_state(self) -> Dict[str, Dict[str, torch.Tensor]]: """ Return the covariance state. """ return self.covariance_state - def get_covariance_inverse_state(self): + def get_covariance_inverse_state( + self, damping: Optional[float] = None + ) -> Dict[str, Dict[str, torch.Tensor]]: """ Return the covariance inverse state. If the state is not computed, compute it first. """ if not hasattr(self, "covariance_inverse_state"): - self.covariance_inverse() + self.covariance_inverse(damping=damping) return self.covariance_inverse_state - def get_covariance_svd_state(self): + def get_covariance_svd_state( + self, + ) -> Tuple[Dict[str, Dict[str, torch.Tensor]], Dict[str, Dict[str, torch.Tensor]]]: """ Return the covariance SVD state. If the state is not computed, compute it first. @@ -213,7 +217,7 @@ def load_state(self, log_dir: str) -> None: state_dict = torch.load(os.path.join(state_log_dir, f"{state_name}.pt")) setattr(self, state_name, state_dict) - def set_state(self, state_name, **kwargs): + def set_state(self, state_name: str, **kwargs) -> None: """ set_state sets the state for the given state_name with input kwargs. """ @@ -223,7 +227,7 @@ def set_state(self, state_name, **kwargs): state = getattr(self, state_name) state[key] = value - def get_state(self, state_name): + def get_state(self, state_name: str) -> Dict[str, Dict[str, torch.Tensor]]: return getattr(self, state_name) def clear_log_state(self) -> None: diff --git a/logix/utils.py b/logix/utils.py index 55d29ff..4417712 100644 --- a/logix/utils.py +++ b/logix/utils.py @@ -174,9 +174,9 @@ def merge_logs(log_list): def flatten_log(log, path) -> torch.Tensor: flat_log_list = [] for module, log_type in path: - log = log[module][log_type] - bsz = log.shape[0] - flat_log_list.append(log.view(bsz, -1)) + log_module = log[module][log_type] + bsz = log_module.shape[0] + flat_log_list.append(log_module.reshape(bsz, -1)) flat_log = torch.cat(flat_log_list, dim=1) return flat_log