Skip to content

Commit

Permalink
support group-wise influence in HF integration
Browse files Browse the repository at this point in the history
  • Loading branch information
sangkeun00 committed Jun 6, 2024
1 parent bd0de02 commit c928e4b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
7 changes: 5 additions & 2 deletions logix/huggingface/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import dataclass, field
from typing import List
from typing import List, Optional

import torch.nn as nn

Expand Down Expand Up @@ -47,12 +47,15 @@ class LogIXArguments:
input_key: str = field(
default="input_ids", metadata={"help": "The dictionary key for 'input_ids'."}
)
influence_damping: float = field(
influence_damping: Optional[float] = field(
default=None, metadata={"help": "A damping term in influence functions."}
)
influence_mode: str = field(
default="dot", metadata={"help": "Influence function mode."}
)
influence_groups: Optional[List[str]] = field(
default=None, metadata={"help": "Influence function groups."}
)
label_key: str = field(
default="labels", metadata={"help": "The dictionary key for 'labels'."}
)
Expand Down
5 changes: 4 additions & 1 deletion logix/huggingface/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def on_step_end(self, args, state, control, **kwargs):
self.log_dataloader(),
mode=self.args.influence_mode,
damping=self.args.influence_damping,
influence_groups=self.args.influence_groups,
save=True,
)

Expand All @@ -83,7 +84,9 @@ def on_step_end(self, args, state, control, **kwargs):
accumulated_log = merge_logs(self.accumulated_log)

self.logix.influence.compute_self_influence(
accumulated_log, damping=self.args.influence_damping
accumulated_log,
damping=self.args.influence_damping,
influence_groups=self.args.influence_groups,
)

self.accumulated_log = []
Expand Down

0 comments on commit c928e4b

Please sign in to comment.