Skip to content

Commit

Permalink
fix supports_device() in python interface (#22473)
Browse files Browse the repository at this point in the history
### Description

`get_device()` returns a string of hyphen connected device names, such
as "GPU-DML". It's a problem that when CUDA is disabled but OpenVino GPU
is enabled in the build, because in this case `get_device()` returns
"CPU-OPENVINO_GPU", so `supports_device("CUDA")` will return `True` in
this build.

Splitting the value of `get_device()` by "-" and check if the input is
in the list is not an option because it seems some code in the code base
stores the value of `get_device()` and use the value to call
`supports_device()`. Using this implementation will cause
`supports_device("GPU-DML")` to return `False` for a build with
`get_device() == "GPU-DML"` because `"GPU-DML" in ["GPU", "DML"]` is
`False`.

This change also helps to avoid further problems when "WebGPU" is
introduced.
  • Loading branch information
fs-eire authored and tianleiwu committed Oct 18, 2024
1 parent 3a2968a commit a1f4b31
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion onnxruntime/python/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def supports_device(cls, device):
"""
if device == "CUDA":
device = "GPU"
return device in get_device()
return "-" + device in get_device() or device + "-" in get_device() or device == get_device()

@classmethod
def prepare(cls, model, device=None, **kwargs):
Expand Down

0 comments on commit a1f4b31

Please sign in to comment.