Skip to content

Commit

Permalink
Merge pull request #81 from ArneBinder/add-task-lr-bartasptr
Browse files Browse the repository at this point in the history
add task lr config for bart_as_pointer_network
  • Loading branch information
ArneBinder authored Apr 8, 2024
2 parents 63e77fe + a4b54e2 commit 1c9fce1
Showing 1 changed file with 3 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
use_constraints_encoder_mlp: bool = False,
# optimizer
lr: float = 5e-5,
task_lr: Optional[float] = None,
weight_decay: float = 1e-2,
head_decay: Optional[float] = None,
shared_decay: Optional[float] = None,
Expand All @@ -74,6 +75,7 @@ def __init__(
self.decoder_position_id_mapping = decoder_position_id_mapping

self.lr = lr
self.task_lr = task_lr
self.weight_decay = weight_decay
self.head_decay = head_decay
self.shared_decay = shared_decay
Expand Down Expand Up @@ -408,7 +410,7 @@ def configure_optimizer(self) -> Optimizer:
else self.config.weight_decay
)
params = {
"lr": self.config.lr,
"lr": self.config.task_lr if self.config.task_lr is not None else self.config.lr,
"weight_decay": head_decay,
"params": dict(self.head_named_params()).values(),
}
Expand Down

0 comments on commit 1c9fce1

Please sign in to comment.