Skip to content

Commit

Permalink
Add kwargs to T5ForClassification forward
Browse files Browse the repository at this point in the history
  • Loading branch information
gokceuludogan authored Jun 1, 2024
1 parent a0e954b commit 0c3a4ab
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion turkish_lm_tuner/t5_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, pretrained_model_name, config, num_labels, problem_type, drop

self.model_parallel = False

def forward(self, input_ids, attention_mask=None, labels=None):
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
encoder_output = self.encoder(input_ids, attention_mask=attention_mask)
if self.config.problem_type == "token_classification":
sequence_output = encoder_output.last_hidden_state
Expand Down

0 comments on commit 0c3a4ab

Please sign in to comment.