From 73e4fba4c112ec93e328dbad27f071d75a6e8122 Mon Sep 17 00:00:00 2001 From: sangkeun00 Date: Tue, 23 Apr 2024 00:15:59 -0400 Subject: [PATCH] improve interface --- logix/logix.py | 71 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 49 insertions(+), 22 deletions(-) diff --git a/logix/logix.py b/logix/logix.py index 54920488..0fb732e0 100644 --- a/logix/logix.py +++ b/logix/logix.py @@ -1,6 +1,6 @@ import os -from typing import Optional, Iterable, Dict, Any, List +from typing import Optional, Iterable, Dict, Any, List, Union from dataclasses import asdict import yaml from functools import reduce @@ -73,6 +73,10 @@ def __init__( config=self.logging_config, state=self.state, binfo=self.binfo ) + # Log data + self.log_dataset = None + self.log_dataloader = None + # Analysis self.influence = InfluenceFunction( config=self.influence_config, state=self.state @@ -185,7 +189,7 @@ def add_lora( if watch: self.watch(model) - def log(self, data_id: Any, mask: Optional[torch.Tensor] = None): + def log(self, data_id: Any, mask: Optional[torch.Tensor] = None) -> None: """ Logs the data. This is an experimental feature for now. @@ -237,7 +241,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None: """ self.logger.update() - def build_log_dataset(self): + def build_log_dataset(self) -> torch.utils.data.Dataset: """ Constructs the log dataset from the stored logs. This dataset can then be used for analysis or visualization. @@ -246,12 +250,15 @@ def build_log_dataset(self): LogDataset: An instance of LogDataset containing the logged data. 
""" - log_dataset = LogDataset(log_dir=self.log_dir, config=self.influence_config) - return log_dataset + if self.log_dataset is None: + self.log_dataset = LogDataset( + log_dir=self.log_dir, config=self.influence_config + ) + return self.log_dataset def build_log_dataloader( self, batch_size: int = 16, num_workers: int = 0, pin_memory: bool = False - ): + ) -> torch.utils.data.DataLoader: """ Constructs a DataLoader for the log dataset. This is useful for batch processing of logged data during analysis. @@ -260,19 +267,20 @@ def build_log_dataloader( DataLoader: A DataLoader instance for the log dataset. """ - log_dataset = self.build_log_dataset() - collate_fn = None - if not self.flatten: - collate_fn = collate_nested_dicts - log_dataloader = torch.utils.data.DataLoader( - log_dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - shuffle=False, - collate_fn=collate_fn, - ) - return log_dataloader + if self.log_dataloader is None: + log_dataset = self.build_log_dataset() + collate_fn = None + if not self.flatten: + collate_fn = collate_nested_dicts + self.log_dataloader = torch.utils.data.DataLoader( + log_dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=False, + collate_fn=collate_fn, + ) + return self.log_dataloader def get_log(self) -> Dict[str, Dict[str, torch.Tensor]]: """ @@ -295,7 +303,13 @@ def get_covariance_svd_state(self) -> Dict[str, Dict[str, torch.Tensor]]: """ return self.state.get_covariance_svd_state() - def compute_influence(self, src_log, tgt_log, mode="dot", precondition=True): + def compute_influence( + self, + src_log: Dict[str, Dict[str, torch.Tensor]], + tgt_log: Dict[str, Dict[str, torch.Tensor]], + mode: Optional[str] = "dot", + precondition: Optional[bool] = True, + ) -> Dict[str, Union[List[str], torch.Tensor, Dict[str, torch.Tensor]]]: """ Front-end interface for computing influence scores. 
It calls the `compute_influence` method of the `InfluenceFunction` class. @@ -310,7 +324,13 @@ def compute_influence(self, src_log, tgt_log, mode="dot", precondition=True): src_log, tgt_log, mode=mode, precondition=precondition ) - def compute_influence_all(self, src_log, loader, mode="dot", precondition=True): + def compute_influence_all( + self, + src_log: Optional[Dict[str, Dict[str, torch.Tensor]]] = None, + loader: Optional[torch.utils.data.DataLoader] = None, + mode: Optional[str] = "dot", + precondition: Optional[bool] = True, + ) -> Dict[str, Union[List[str], torch.Tensor, Dict[str, torch.Tensor]]]: """ Front-end interface for computing influence scores against all train data in the log. It calls the `compute_influence_all` method of the `InfluenceFunction` class. @@ -321,11 +341,17 @@ def compute_influence_all(self, src_log, loader, mode="dot", precondition=True): mode (str, optional): Influence function mode. Defaults to "dot". precondition (bool, optional): Whether to precondition the gradients. Defaults to True. """ + src_log = src_log if src_log is not None else self.get_log() + loader = loader if loader is not None else self.build_log_dataloader() return self.influence.compute_influence_all( src_log, loader, mode=mode, precondition=precondition ) - def compute_self_influence(self, src_log, precondition=True): + def compute_self_influence( + self, + src_log: Optional[Dict[str, Dict[str, torch.Tensor]]] = None, + precondition: Optional[bool] = True, + ) -> Dict[str, Union[List[str], torch.Tensor, Dict[str, torch.Tensor]]]: """ Front-end interface for computing self-influence scores. It calls the `compute_self_influence` method of the `InfluenceFunction` class. @@ -334,6 +360,7 @@ def compute_self_influence(self, src_log, precondition=True): src_log (Tuple[str, Dict[str, Dict[str, torch.Tensor]]]): Log of source gradients precondition (bool, optional): Whether to precondition the gradients. Defaults to True. 
""" + src_log = src_log if src_log is not None else self.get_log() return self.influence.compute_self_influence(src_log, precondition=precondition) def save_config(self) -> None: