From 6c9ccd14da0d040e7a6c9a5202100095f0227300 Mon Sep 17 00:00:00 2001 From: sangkeun00 Date: Fri, 31 May 2024 15:54:24 -0400 Subject: [PATCH] add gradient accumulation in HF Trainer integration --- examples/huggingface/bert_influence.py | 2 ++ examples/huggingface/gpt_influence.py | 2 ++ logix/huggingface/callback.py | 35 +++++++++++++------------- logix/logix.py | 15 ----------- 4 files changed, 21 insertions(+), 33 deletions(-) diff --git a/examples/huggingface/bert_influence.py b/examples/huggingface/bert_influence.py index ea3d8ff..237bdd8 100644 --- a/examples/huggingface/bert_influence.py +++ b/examples/huggingface/bert_influence.py @@ -17,6 +17,7 @@ def main(): parser.add_argument("--config_path", type=str, default="./config.yaml") parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--data_name", type=str, default="sst2") + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) args = parser.parse_args() set_seed(0) @@ -41,6 +42,7 @@ def main(): output_dir="./output", num_train_epochs=1, per_device_train_batch_size=args.batch_size, + gradient_accumulation_steps=args.gradient_accumulation_steps, report_to="none", ) diff --git a/examples/huggingface/gpt_influence.py b/examples/huggingface/gpt_influence.py index 26ffa09..ba8b25d 100644 --- a/examples/huggingface/gpt_influence.py +++ b/examples/huggingface/gpt_influence.py @@ -17,6 +17,7 @@ def main(): parser.add_argument("--config_path", type=str, default="./config.yaml") parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--data_name", type=str, default="sst2") + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) args = parser.parse_args() set_seed(0) @@ -41,6 +42,7 @@ def main(): num_train_epochs=1, per_device_train_batch_size=args.batch_size, report_to="none", + gradient_accumulation_steps=args.gradient_accumulation_steps, ) LogIXTrainer = patch_trainer(Trainer) diff --git a/logix/huggingface/callback.py 
b/logix/huggingface/callback.py index 0ad960b..a052923 100644 --- a/logix/huggingface/callback.py +++ b/logix/huggingface/callback.py @@ -3,6 +3,7 @@ from transformers.trainer import TrainerCallback from logix import LogIX, LogIXScheduler +from logix.utils import merge_logs from logix.huggingface.arguments import LogIXArguments @@ -17,6 +18,8 @@ def __init__( self.logix_scheduler = iter(logix_scheduler) self.args = args + self.accumulated_log = [] + self._log_dataloader = None def on_init_end(self, args, state, control, **kwargs): @@ -51,35 +54,31 @@ def on_train_begin(self, args, state, control, **kwargs): def on_step_end(self, args, state, control, **kwargs): if self.args.mode == "influence": - test_log = self.logix.get_log() + self.accumulated_log.append(self.logix.get_log(copy=True)) + accumulated_log = merge_logs(self.accumulated_log) + self.logix.influence.compute_influence_all( - test_log, + accumulated_log, self.log_dataloader(), mode=self.args.influence_mode, damping=self.args.influence_damping, save=True, ) + + self.accumulated_log = [] elif self.args.mode == "self_influence": - test_log = self.logix.get_log() + self.accumulated_log.append(self.logix.get_log(copy=True)) + accumulated_log = merge_logs(self.accumulated_log) + self.logix.influence.compute_self_influence( - test_log, damping=self.args.influence_damping + accumulated_log, damping=self.args.influence_damping ) + self.accumulated_log = [] + def on_substep_end(self, args, state, control, **kwargs): - if self.args.mode == "influence": - test_log = self.logix.get_log() - self.logix.influence.compute_influence_all( - test_log, - self.log_dataloader(), - mode=self.args.influence_mode, - damping=self.args.influence_damping, - save=True, - ) - elif self.args.mode == "self_influence": - test_log = self.logix.get_log() - self.logix.influence.compute_self_influence( - test_log, damping=self.args.influence_damping - ) + if self.args.mode in ["influence", "self_influence"]: + 
self.accumulated_log.append(self.logix.get_log(copy=True)) def log_dataloader(self): if self._log_dataloader is None: diff --git a/logix/logix.py b/logix/logix.py index 9daff68..1ced0f2 100644 --- a/logix/logix.py +++ b/logix/logix.py @@ -434,21 +434,6 @@ def compute_self_influence( ) return result - def save_config(self) -> None: - """ - Save LogIX state to disk. - """ - config_file = os.path.join(self.log_dir, "config.yaml") - config_dict = asdict(self.config) - with open(config_file, "w", encoding="utf-8") as f: - yaml.dump(config_dict, f, default_flow_style=False) - - def save_state(self) -> None: - """ - Save Hessian state to disk. - """ - self.state.save_state(self.log_dir) - def save_lora(self) -> None: """ Save LoRA state to disk.