From 0c3a4ab53e310b29b6ee2a13233fc2cf6e768209 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6k=C3=A7e=20Uludo=C4=9Fan?= Date: Sat, 1 Jun 2024 17:21:35 +0300 Subject: [PATCH] Add kwargs to T5ForClassification forward --- turkish_lm_tuner/t5_classifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/turkish_lm_tuner/t5_classifier.py b/turkish_lm_tuner/t5_classifier.py index 2773453..c05955c 100644 --- a/turkish_lm_tuner/t5_classifier.py +++ b/turkish_lm_tuner/t5_classifier.py @@ -39,7 +39,7 @@ def __init__(self, pretrained_model_name, config, num_labels, problem_type, drop self.model_parallel = False - def forward(self, input_ids, attention_mask=None, labels=None): + def forward(self, input_ids, attention_mask=None, labels=None, **kwargs): encoder_output = self.encoder(input_ids, attention_mask=attention_mask) if self.config.problem_type == "token_classification": sequence_output = encoder_output.last_hidden_state