diff --git a/turkish_lm_tuner/predictor.py b/turkish_lm_tuner/predictor.py index 5f29709..4095666 100644 --- a/turkish_lm_tuner/predictor.py +++ b/turkish_lm_tuner/predictor.py @@ -9,7 +9,7 @@ class TaskConfig: max_new_tokens: int = None length_penalty: float = None no_repeat_ngram_size: int = None - early_stopping: bool = True + early_stopping: bool = False decoder_start_token_id: int = None eos_token_id: int = None pad_token_id: int = None @@ -142,6 +142,6 @@ def __init__(self, model_name, task, task_format='conditional_generation', max_i self.task_config = TaskConfig(**task_parameters[task]) def predict(self, text, **kwargs): - generation_config = vars(self.task_config, **kwargs) if self.task_format == 'conditional_generation' else {} + generation_config = {**vars(self.task_config), **kwargs} if self.task_format == 'conditional_generation' else {} return super().predict(text, generation_config) \ No newline at end of file