diff --git a/src/pie_modules/models/base_models/bart_as_pointer_network.py b/src/pie_modules/models/base_models/bart_as_pointer_network.py index 5599e1401..ffb1ce9bf 100644 --- a/src/pie_modules/models/base_models/bart_as_pointer_network.py +++ b/src/pie_modules/models/base_models/bart_as_pointer_network.py @@ -52,6 +52,7 @@ def __init__( use_constraints_encoder_mlp: bool = False, # optimizer lr: float = 5e-5, + task_lr: Optional[float] = None, weight_decay: float = 1e-2, head_decay: Optional[float] = None, shared_decay: Optional[float] = None, @@ -74,6 +75,7 @@ def __init__( self.decoder_position_id_mapping = decoder_position_id_mapping self.lr = lr + self.task_lr = task_lr self.weight_decay = weight_decay self.head_decay = head_decay self.shared_decay = shared_decay @@ -408,7 +410,7 @@ def configure_optimizer(self) -> Optimizer: else self.config.weight_decay ) params = { - "lr": self.config.lr, + "lr": self.config.task_lr if self.config.task_lr is not None else self.config.lr, "weight_decay": head_decay, "params": dict(self.head_named_params()).values(), }