From 8f520dde1ad3732dd45f4a641e684358d4febe6f Mon Sep 17 00:00:00 2001
From: Paul Dufour
Date: Sun, 3 Nov 2024 17:17:35 +0000
Subject: [PATCH] add eval_data_collator arg

---
 trl/trainer/sft_trainer.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index c23a5c84c6..c7d797700c 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -130,6 +130,7 @@ def __init__(
         model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
         args: Optional[SFTConfig] = None,
         data_collator: Optional[DataCollator] = None,  # type: ignore
+        eval_data_collator: Optional[DataCollator] = None,  # type: ignore
         train_dataset: Optional[Dataset] = None,
         eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
         processing_class: Optional[
@@ -409,6 +410,7 @@ def make_inputs_require_grad(module, input, output):
             model=model,
             args=args,
             data_collator=data_collator,
+            eval_data_collator=eval_data_collator,
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
             processing_class=processing_class,