From ab7a4109df5bb7d9d0407765b735ae50ee403cc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6k=C3=A7e=20Uludo=C4=9Fan?= Date: Sun, 4 Feb 2024 21:29:11 +0300 Subject: [PATCH] Set early_stopping False and fix TextPredictor's predict --- turkish_lm_tuner/predictor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/turkish_lm_tuner/predictor.py b/turkish_lm_tuner/predictor.py index 5f29709..4095666 100644 --- a/turkish_lm_tuner/predictor.py +++ b/turkish_lm_tuner/predictor.py @@ -9,7 +9,7 @@ class TaskConfig: max_new_tokens: int = None length_penalty: float = None no_repeat_ngram_size: int = None - early_stopping: bool = True + early_stopping: bool = False decoder_start_token_id: int = None eos_token_id: int = None pad_token_id: int = None @@ -142,6 +142,6 @@ def __init__(self, model_name, task, task_format='conditional_generation', max_i self.task_config = TaskConfig(**task_parameters[task]) def predict(self, text, **kwargs): - generation_config = vars(self.task_config, **kwargs) if self.task_format == 'conditional_generation' else {} + generation_config = {**vars(self.task_config), **kwargs} if self.task_format == 'conditional_generation' else {} return super().predict(text, generation_config) \ No newline at end of file