Commit 689908f
fix flattened log loading
sangkeun00 committed May 28, 2024
1 parent 8aae35c commit 689908f
Showing 5 changed files with 28 additions and 28 deletions.
examples/mnist/compute_influences.py (4 additions, 1 deletion)
@@ -20,6 +20,7 @@
parser.add_argument("--hessian", type=str, default="none")
parser.add_argument("--lora", type=str, default="none")
parser.add_argument("--save", type=str, default="grad")
parser.add_argument("--flatten", action="store_true")
args = parser.parse_args()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -56,7 +57,9 @@
 logix.finalize()
 
 # Influence Analysis
-log_loader = logix.build_log_dataloader(batch_size=64, num_workers=0)
+log_loader = logix.build_log_dataloader(
+    batch_size=64, num_workers=0, flatten=args.flatten
+)
 
 # logix.add_analysis({"influence": InfluenceFunction})
 logix.setup({"log": "grad"})
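The new --flatten flag is simply threaded through to build_log_dataloader. A minimal usage sketch, assuming the rest of the example script is unchanged (running it as `python examples/mnist/compute_influences.py --flatten`):

    log_loader = logix.build_log_dataloader(
        batch_size=64, num_workers=0, flatten=True
    )
    # With flatten=True, each batch of logs is a single 2-D tensor of shape
    # (batch_size, total_grad_dim) rather than a nested {module: {"grad": ...}}
    # dict -- see flatten_log in logix/utils.py below.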
logix/analysis/influence_function.py (3 additions, 10 deletions)
@@ -116,7 +116,7 @@ def compute_influence(
                 log=src, path=self._state.get_state("model_module")["path"]
             )
             tgt = tgt.to(device=src.device)
-            total_influence += cross_dot_product(src, tgt)
+            total_influence["total"] += cross_dot_product(src, tgt)
         else:
             synchronize_device(src, tgt)
             for module_name in src.keys():
@@ -158,8 +158,8 @@ def compute_influence(
             assert value.shape[1] == len(tgt_ids)
             total_influence[key] = value.cpu()
 
-        result["src_ids"] = src_ids
-        result["tgt_ids"] = tgt_ids
+        result["src_ids"] = list(src_ids)
+        result["tgt_ids"] = list(tgt_ids)
         result["influence"] = (
             total_influence.pop("total")
             if influence_groups is None
@@ -266,16 +266,9 @@ def compute_influence_all(
                 influence_groups=influence_groups,
                 damping=damping,
             )
-
-            # if result_all is None:
-            #     result_all = result
-            # else:
             merge_influence_results(result_all, result, axis="tgt")
 
         if save:
-            # if self.influence_scores is None:
-            #     self.influence_scores = result_all
-            # else:
             merge_influence_results(self.influence_scores, result_all, axis="src")
 
         return result_all
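These changes adapt compute_influence to flattened logs: when src and tgt are plain 2-D tensors, the per-module loop is skipped and a single call fills total_influence["total"] (the old `total_influence += ...` predates total_influence being a dict keyed by influence group, and would now raise a TypeError). A plausible reading of cross_dot_product -- its body is not shown in this diff, so treat the sketch as an assumption:

    import torch

    def cross_dot_product(src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        # src: (n_src, d), tgt: (n_tgt, d) flattened per-example gradients;
        # returns the (n_src, n_tgt) matrix of pairwise dot products.
        return src @ tgt.T

Casting src_ids and tgt_ids to lists likewise keeps merge_influence_results consistent, which presumably concatenates ID lists when merging per-batch results.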
logix/analysis/influence_function_utils.py (1 addition, 1 deletion)
@@ -60,7 +60,7 @@ def precondition_raw(
     damping: Optional[float] = None,
 ) -> Dict[str, Dict[str, torch.Tensor]]:
     preconditioned = nested_dict()
-    cov_inverse = state.get_covariance_inverse_state()
+    cov_inverse = state.get_covariance_inverse_state(damping=damping)
     for module_name in src.keys():
         device = src[module_name]["grad"].device
         grad_cov_inverse = cov_inverse[module_name]["grad"].to(device=device)
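Here the caller's damping is forwarded instead of being silently dropped. Note that, per the state.py change below, the value only matters the first time the inverse is computed; a cached covariance_inverse_state is returned as-is on later calls. Illustrative call (the damping value is arbitrary):

    cov_inverse = state.get_covariance_inverse_state(damping=1e-5)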
logix/state.py (17 additions, 13 deletions)
@@ -1,5 +1,6 @@
 import logging
 import os
+from typing import Optional, Dict, Tuple
 
 import torch
 import torch.distributed as dist
@@ -84,38 +85,41 @@ def covariance_svd(self) -> None:
         )
 
     @torch.no_grad()
-    def covariance_inverse(self, set_attr: bool = False) -> None:
+    def covariance_inverse(self, damping: Optional[float] = None) -> None:
         """
         Compute the inverse of the covariance.
         """
         self.register_state("covariance_inverse_state", save=True)
 
         for module_name, module_state in self.covariance_state.items():
-            for mode, covariance in module_state.items():
+            for mode, cov in module_state.items():
+                damping_module = (
+                    0.1 * torch.trace(cov) / cov.size(0) if damping is None else damping
+                )
                 self.covariance_inverse_state[module_name][mode] = torch.inverse(
-                    covariance
-                    + 0.1
-                    * torch.trace(covariance)
-                    * torch.eye(covariance.shape[0]).to(device=covariance.device)
-                    / covariance.shape[0]
+                    cov + damping_module * torch.eye(cov.size(0)).to(device=cov.device)
                 )
 
-    def get_covariance_state(self):
+    def get_covariance_state(self) -> Dict[str, Dict[str, torch.Tensor]]:
         """
         Return the covariance state.
         """
         return self.covariance_state
 
-    def get_covariance_inverse_state(self):
+    def get_covariance_inverse_state(
+        self, damping: Optional[float] = None
+    ) -> Dict[str, Dict[str, torch.Tensor]]:
         """
         Return the covariance inverse state. If the state is not computed, compute
         it first.
         """
         if not hasattr(self, "covariance_inverse_state"):
-            self.covariance_inverse()
+            self.covariance_inverse(damping=damping)
         return self.covariance_inverse_state
 
-    def get_covariance_svd_state(self):
+    def get_covariance_svd_state(
+        self,
+    ) -> Tuple[Dict[str, Dict[str, torch.Tensor]], Dict[str, Dict[str, torch.Tensor]]]:
         """
         Return the covariance SVD state. If the state is not computed, compute
         it first.
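The default damping is the usual trace-based heuristic: lambda = 0.1 * trace(C) / dim(C), i.e. a tenth of the mean eigenvalue of the covariance, used only when no explicit value is supplied. A standalone sketch of the computation (the function name is illustrative):

    from typing import Optional

    import torch

    def damped_inverse(cov: torch.Tensor, damping: Optional[float] = None) -> torch.Tensor:
        # Default: 0.1 * mean eigenvalue (= 0.1 * trace / dim), matching the
        # new covariance_inverse; an explicit damping overrides it.
        lam = 0.1 * torch.trace(cov) / cov.size(0) if damping is None else damping
        return torch.inverse(cov + lam * torch.eye(cov.size(0), device=cov.device))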
@@ -213,7 +217,7 @@ def load_state(self, log_dir: str) -> None:
             state_dict = torch.load(os.path.join(state_log_dir, f"{state_name}.pt"))
             setattr(self, state_name, state_dict)
 
-    def set_state(self, state_name, **kwargs):
+    def set_state(self, state_name: str, **kwargs) -> None:
         """
         set_state sets the state for the given state_name with input kwargs.
         """
@@ -223,7 +227,7 @@ def set_state(self, state_name, **kwargs):
             state = getattr(self, state_name)
             state[key] = value
 
-    def get_state(self, state_name):
+    def get_state(self, state_name: str) -> Dict[str, Dict[str, torch.Tensor]]:
         return getattr(self, state_name)
 
     def clear_log_state(self) -> None:
logix/utils.py (3 additions, 3 deletions)
@@ -174,9 +174,9 @@ def merge_logs(log_list):
 def flatten_log(log, path) -> torch.Tensor:
     flat_log_list = []
     for module, log_type in path:
-        log = log[module][log_type]
-        bsz = log.shape[0]
-        flat_log_list.append(log.view(bsz, -1))
+        log_module = log[module][log_type]
+        bsz = log_module.shape[0]
+        flat_log_list.append(log_module.reshape(bsz, -1))
     flat_log = torch.cat(flat_log_list, dim=1)
 
     return flat_log
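Two fixes hide in this hunk, and together they are likely the "flattened log loading" failure the commit title refers to: the loop no longer rebinds log (after the first iteration the old code indexed into a tensor instead of the nested log dict), and view becomes reshape, which also handles non-contiguous tensors. A minimal repro of the second issue (tensors are illustrative):

    import torch

    # Tensor.view requires contiguous memory, so a transposed or sliced
    # gradient raises a RuntimeError; reshape copies when necessary.
    g = torch.randn(8, 5, 4).transpose(1, 2)  # non-contiguous, shape (8, 4, 5)
    flat = g.reshape(8, -1)                   # OK: falls back to a copy
    # g.view(8, -1)                           # RuntimeError here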
