diff --git a/logix/huggingface/arguments.py b/logix/huggingface/arguments.py index e888082..10d41d9 100644 --- a/logix/huggingface/arguments.py +++ b/logix/huggingface/arguments.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass, field -from typing import List +from typing import List, Optional import torch.nn as nn @@ -47,12 +47,15 @@ class LogIXArguments: input_key: str = field( default="input_ids", metadata={"help": "The dictionary key for 'input_ids'."} ) - influence_damping: float = field( + influence_damping: Optional[float] = field( default=None, metadata={"help": "A damping term in influence functions."} ) influence_mode: str = field( default="dot", metadata={"help": "Influence function mode."} ) + influence_groups: Optional[List[str]] = field( + default=None, metadata={"help": "Influence function groups."} + ) label_key: str = field( default="labels", metadata={"help": "The dictionary key for 'labels'."} ) diff --git a/logix/huggingface/callback.py b/logix/huggingface/callback.py index 959d00e..e7ecc56 100644 --- a/logix/huggingface/callback.py +++ b/logix/huggingface/callback.py @@ -74,6 +74,7 @@ def on_step_end(self, args, state, control, **kwargs): self.log_dataloader(), mode=self.args.influence_mode, damping=self.args.influence_damping, + influence_groups=self.args.influence_groups, save=True, ) @@ -83,7 +84,9 @@ def on_step_end(self, args, state, control, **kwargs): accumulated_log = merge_logs(self.accumulated_log) self.logix.influence.compute_self_influence( - accumulated_log, damping=self.args.influence_damping + accumulated_log, + damping=self.args.influence_damping, + influence_groups=self.args.influence_groups, ) self.accumulated_log = []