Skip to content

Commit

Permalink
Fix ORT execution provider test for onnxruntime v1.16.0 (#1406)
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix authored Sep 21, 2023
1 parent 3c4ad78 commit 6d1ae0e
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,25 +357,22 @@ def test_missing_execution_provider(self):

if not is_onnxruntime_gpu_installed:
for provider in ["CUDAExecutionProvider", "TensorrtExecutionProvider"]:
with self.assertRaises(ImportError) as cm:
with self.assertRaises(ValueError) as cm:
_ = ORTModel.from_pretrained(self.ONNX_MODEL_ID, provider=provider)

self.assertTrue(
f"Asked to use {provider}, but `onnxruntime-gpu` package was not found." in str(cm.exception)
)
self.assertTrue("but the available execution providers" in str(cm.exception))

else:
logger.info("Skipping CUDAExecutionProvider/TensorrtExecutionProvider without `onnxruntime-gpu` test")

# need to install first onnxruntime-gpu, then onnxruntime for this test to pass,
# thus overwritting onnxruntime/capi/_ld_preload.py
if is_onnxruntime_installed and is_onnxruntime_gpu_installed:
for provider in ["CUDAExecutionProvider", "TensorrtExecutionProvider"]:
with self.assertRaises(ImportError) as cm:
with self.assertRaises(ValueError) as cm:
_ = ORTModel.from_pretrained(self.ONNX_MODEL_ID, provider=provider)

self.assertTrue(
"`onnxruntime-gpu` is installed, but GPU dependencies are not loaded." in str(cm.exception)
)
self.assertTrue("but the available execution providers" in str(cm.exception))
else:
logger.info("Skipping double onnxruntime + onnxruntime-gpu install test")

Expand Down

0 comments on commit 6d1ae0e

Please sign in to comment.