Skip to content

Commit

Permalink
Merge pull request #75 from boun-tabi-LMG/gokceuludogan-patch-t5-clas…
Browse files Browse the repository at this point in the history
…sification

Add kwargs to T5ForClassification forward
  • Loading branch information
gokceuludogan authored Nov 17, 2024
2 parents 2c82393 + 0c3a4ab commit 096a71d
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion turkish_lm_tuner/t5_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, pretrained_model_name, config, num_labels, problem_type, drop

self.model_parallel = False

def forward(self, input_ids, attention_mask=None, labels=None):
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
encoder_output = self.encoder(input_ids, attention_mask=attention_mask)
if self.config.problem_type == "token_classification":
sequence_output = encoder_output.last_hidden_state
Expand Down

0 comments on commit 096a71d

Please sign in to comment.