diff --git a/examples/bert_influence/compute_influence.py b/examples/bert_influence/compute_influence.py index 3fa56b4e..735f2f89 100644 --- a/examples/bert_influence/compute_influence.py +++ b/examples/bert_influence/compute_influence.py @@ -34,7 +34,10 @@ def single_checkpoint_influence( _, eval_train_loader, test_loader = get_loaders(data_name=data_name) # Set-up - analog = AnaLog(project="test", config="/data/tir/projects/tir6/general/hahn2/analog/examples/bert_influence/config.yaml") + analog = AnaLog( + project="test", + config="/data/tir/projects/tir6/general/hahn2/analog/examples/bert_influence/config.yaml", + ) # Hessian logging analog.watch(model, type_filter=[torch.nn.Linear], lora=False) @@ -52,9 +55,7 @@ def single_checkpoint_influence( logits = outputs.view(-1, outputs.shape[-1]) labels = batch["labels"].view(-1).to(DEVICE) - loss = F.cross_entropy( - logits, labels, reduction="sum", ignore_index=-100 - ) + loss = F.cross_entropy(logits, labels, reduction="sum", ignore_index=-100) loss.backward() analog.finalize() @@ -74,7 +75,10 @@ def single_checkpoint_influence( logits = outputs.view(-1, outputs.shape[-1]) labels = batch["labels"].view(-1).to(DEVICE) loss = F.cross_entropy( - logits, labels, reduction="sum", ignore_index=-100, + logits, + labels, + reduction="sum", + ignore_index=-100, ) loss.backward() analog.finalize() @@ -97,7 +101,10 @@ def single_checkpoint_influence( test_logits = test_outputs.view(-1, outputs.shape[-1]) test_labels = test_batch["labels"].view(-1).to(DEVICE) test_loss = F.cross_entropy( - test_logits, test_labels, reduction="sum", ignore_index=-100, + test_logits, + test_labels, + reduction="sum", + ignore_index=-100, ) test_loss.backward() diff --git a/examples/bert_influence/train.py b/examples/bert_influence/train.py index 2ea5fd97..fd7ee740 100644 --- a/examples/bert_influence/train.py +++ b/examples/bert_influence/train.py @@ -26,6 +26,7 @@ train_loader, _, valid_loader = get_loaders(data_name=args.data_name) model = construct_model(data_name=args.data_name).to(device) + def train( model: nn.Module, loader: torch.utils.data.DataLoader, diff --git a/examples/bert_influence/utils.py b/examples/bert_influence/utils.py index a91a76db..cb3dcefb 100644 --- a/examples/bert_influence/utils.py +++ b/examples/bert_influence/utils.py @@ -84,9 +84,7 @@ def forward( ).logits -def construct_model( - data_name: str, ckpt_path: Union[None, str] = None -) -> nn.Module: +def construct_model(data_name: str, ckpt_path: Union[None, str] = None) -> nn.Module: model = SequenceClassificationModel(data_name) if ckpt_path is not None: model.load_state_dict(torch.load(ckpt_path, map_location="cpu")) @@ -190,4 +188,4 @@ def preprocess_function(examples): batch_size=batch_size, shuffle=split == "train", collate_fn=default_data_collator, - ) \ No newline at end of file + )