Skip to content

Commit

Permalink
improve interface
Browse files Browse the repository at this point in the history
  • Loading branch information
sangkeun00 committed Apr 23, 2024
1 parent 28d16b6 commit 73e4fba
Showing 1 changed file with 49 additions and 22 deletions.
71 changes: 49 additions & 22 deletions logix/logix.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os

from typing import Optional, Iterable, Dict, Any, List
from typing import Optional, Iterable, Dict, Any, List, Union
from dataclasses import asdict
import yaml
from functools import reduce
Expand Down Expand Up @@ -73,6 +73,10 @@ def __init__(
config=self.logging_config, state=self.state, binfo=self.binfo
)

# Log data
self.log_dataset = None
self.log_dataloader = None

# Analysis
self.influence = InfluenceFunction(
config=self.influence_config, state=self.state
Expand Down Expand Up @@ -185,7 +189,7 @@ def add_lora(
if watch:
self.watch(model)

def log(self, data_id: Any, mask: Optional[torch.Tensor] = None):
def log(self, data_id: Any, mask: Optional[torch.Tensor] = None) -> None:
"""
Logs the data. This is an experimental feature for now.
Expand Down Expand Up @@ -237,7 +241,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
"""
self.logger.update()

def build_log_dataset(self) -> torch.utils.data.Dataset:
    """
    Construct the dataset of stored logs for analysis or visualization.

    The dataset is created lazily on the first call and cached on the
    instance; subsequent calls return the same object.

    Returns:
        LogDataset:
            An instance of LogDataset containing the logged data.
    """
    # Fast path: reuse the cached dataset if it was already built.
    if self.log_dataset is not None:
        return self.log_dataset
    self.log_dataset = LogDataset(
        log_dir=self.log_dir, config=self.influence_config
    )
    return self.log_dataset

def build_log_dataloader(
    self, batch_size: int = 16, num_workers: int = 0, pin_memory: bool = False
) -> torch.utils.data.DataLoader:
    """
    Construct a DataLoader for the log dataset. This is useful for batch
    processing of logged data during analysis.

    The DataLoader is cached on the instance. If it was previously built
    with a different configuration, it is rebuilt so the returned loader
    always reflects the requested arguments (the original implementation
    cached unconditionally and silently ignored changed arguments).

    Args:
        batch_size (int, optional): Batch size for the loader. Defaults to 16.
        num_workers (int, optional): Number of worker processes. Defaults to 0.
        pin_memory (bool, optional): Whether to pin host memory. Defaults to False.

    Returns:
        DataLoader:
            A DataLoader instance for the log dataset.
    """
    requested = (batch_size, num_workers, pin_memory)
    cached = getattr(self, "_log_dataloader_config", None)
    if self.log_dataloader is None or cached != requested:
        # Nested log dicts need a custom collate function unless logs
        # were flattened at logging time.
        collate_fn = None if self.flatten else collate_nested_dicts
        self.log_dataloader = torch.utils.data.DataLoader(
            self.build_log_dataset(),
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            shuffle=False,  # keep log order deterministic for analysis
            collate_fn=collate_fn,
        )
        self._log_dataloader_config = requested
    return self.log_dataloader

def get_log(self) -> Dict[str, Dict[str, torch.Tensor]]:
"""
Expand All @@ -295,7 +303,13 @@ def get_covariance_svd_state(self) -> Dict[str, Dict[str, torch.Tensor]]:
"""
return self.state.get_covariance_svd_state()

def compute_influence(
    self,
    src_log: Dict[str, Dict[str, torch.Tensor]],
    tgt_log: Dict[str, Dict[str, torch.Tensor]],
    mode: str = "dot",
    precondition: bool = True,
) -> Dict[str, Union[List[str], torch.Tensor, Dict[str, torch.Tensor]]]:
    """
    Front-end interface for computing influence scores. It calls the
    `compute_influence` method of the `InfluenceFunction` class.

    Args:
        src_log (Dict[str, Dict[str, torch.Tensor]]): Log of source gradients.
        tgt_log (Dict[str, Dict[str, torch.Tensor]]): Log of target gradients.
        mode (str, optional): Influence function mode. Defaults to "dot".
            Never None, so annotated as plain `str` (not Optional).
        precondition (bool, optional): Whether to precondition the gradients.
            Defaults to True. Never None, so annotated as plain `bool`.
    """
    return self.influence.compute_influence(
        src_log, tgt_log, mode=mode, precondition=precondition
    )

def compute_influence_all(
    self,
    src_log: Optional[Dict[str, Dict[str, torch.Tensor]]] = None,
    loader: Optional[torch.utils.data.DataLoader] = None,
    mode: str = "dot",
    precondition: bool = True,
) -> Dict[str, Union[List[str], torch.Tensor, Dict[str, torch.Tensor]]]:
    """
    Front-end interface for computing influence scores against all train data in the log.
    It calls the `compute_influence_all` method of the `InfluenceFunction` class.

    Args:
        src_log (Dict[str, Dict[str, torch.Tensor]], optional): Log of source
            gradients. Defaults to the current in-memory log (`self.get_log()`).
        loader (DataLoader, optional): DataLoader over logged train data.
            Defaults to the cached log dataloader.
        mode (str, optional): Influence function mode. Defaults to "dot".
        precondition (bool, optional): Whether to precondition the gradients.
            Defaults to True. The original annotation `Optional[str] = True`
            had the wrong base type; it is a plain bool and never None.
    """
    if src_log is None:
        src_log = self.get_log()
    if loader is None:
        loader = self.build_log_dataloader()
    return self.influence.compute_influence_all(
        src_log, loader, mode=mode, precondition=precondition
    )

def compute_self_influence(
    self,
    src_log: Optional[Dict[str, Dict[str, torch.Tensor]]] = None,
    precondition: bool = True,
) -> Dict[str, Union[List[str], torch.Tensor, Dict[str, torch.Tensor]]]:
    """
    Front-end interface for computing self-influence scores. It calls the
    `compute_self_influence` method of the `InfluenceFunction` class.

    Args:
        src_log (Dict[str, Dict[str, torch.Tensor]], optional): Log of source
            gradients. Defaults to the current in-memory log (`self.get_log()`).
        precondition (bool, optional): Whether to precondition the gradients.
            Defaults to True. Never None, so annotated as plain `bool`
            rather than `Optional[bool]`.
    """
    if src_log is None:
        src_log = self.get_log()
    return self.influence.compute_self_influence(src_log, precondition=precondition)

def save_config(self) -> None:
Expand Down

0 comments on commit 73e4fba

Please sign in to comment.