diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 69dcdb72038b5..f0261c776609e 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -443,7 +443,7 @@ def _get_bert_for_sequence_classification_sample_data_with_random_shapes(device) return input_ids, input_mask, labels -def _get_bert_for_sequence_classification_sample_data(device, shape): +def _get_bert_for_sequence_classification_sample_data(device): """Returns sample data to be used with BertForSequenceClassification model""" input_ids = torch.randint(0, 100, (32, 64), dtype=torch.long, device=device)