Skip to content

Commit

Permalink
Merge pull request #131 from Rikyf3/feature/onnxruntime_implementation
Browse files Browse the repository at this point in the history
onnxruntime implementation with hardware accelerators
  • Loading branch information
Steffenhir authored Mar 16, 2024
2 parents 104b70c + 3fe3a6d commit d5a82f5
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 4 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/build-release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ jobs:
run: |
sudo apt install alien -y && \
pip install "cx_freeze>=6.16.0.dev" && \
pip install onnxruntime-gpu && \
pip install -r requirements.txt
- name: patch version
run: |
Expand Down Expand Up @@ -68,6 +69,7 @@ jobs:
- name: install dependencies
run: |
pip install setuptools wheel cx_freeze && \
pip install onnxruntime-gpu && \
pip install -r requirements.txt
- name: patch version
run: |
Expand Down Expand Up @@ -105,6 +107,7 @@ jobs:
- name: install dependencies
run: |
pip install setuptools wheel cx_freeze; `
pip install onnxruntime-directml; `
pip install -r requirements.txt
- name: patch version
run: ./releng/patch_version.ps1
Expand Down Expand Up @@ -138,6 +141,7 @@ jobs:
run: |
brew install python-tk && \
pip3 install setuptools wheel pyinstaller && \
pip3 install onnxruntime && \
pip3 install -r requirements.txt
- name: patch version
run: |
Expand Down
8 changes: 8 additions & 0 deletions graxpert/ai_model_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from appdirs import user_data_dir
from minio import Minio
from packaging import version
import onnxruntime as ort

try:
from graxpert.s3_secrets import endpoint, ro_access_key, ro_secret_key
Expand Down Expand Up @@ -150,3 +151,10 @@ def download_version(ai_models_dir, bucket_name, remote_version, progress=None):

def validate_local_version(ai_models_dir, local_version):
    """Return True if the given local AI model version is usable.

    A version is considered valid when a ``model.onnx`` file exists inside
    the version's subdirectory of *ai_models_dir*.
    """
    model_path = os.path.join(ai_models_dir, local_version, "model.onnx")
    return os.path.isfile(model_path)


def get_execution_providers_ordered():
    """Return the ONNX Runtime execution providers to use, in priority order.

    Hardware-accelerated providers (DirectML, CoreML, CUDA) come first and
    the CPU provider acts as the universal fallback; only providers actually
    available in the installed onnxruntime build are returned.
    """
    preferred_order = (
        "DmlExecutionProvider",
        "CoreMLExecutionProvider",
        "CUDAExecutionProvider",
        "CPUExecutionProvider",
    )

    ordered = []
    for candidate in preferred_order:
        if candidate in ort.get_available_providers():
            ordered.append(candidate)
    return ordered
7 changes: 6 additions & 1 deletion graxpert/background_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from graxpert.mp_logging import get_logging_queue, worker_configurer
from graxpert.parallel_processing import executor
from graxpert.radialbasisinterpolation import RadialBasisInterpolation
from graxpert.ai_model_handling import get_execution_providers_ordered


def extract_background(in_imarray, background_points, interpolation_type, smoothing, downscale_factor, sample_size, RBF_kernel, spline_order, corr_type, ai_path, progress=None):
Expand Down Expand Up @@ -61,7 +62,11 @@ def extract_background(in_imarray, background_points, interpolation_type, smooth
if progress is not None:
progress.update(8)

session = ort.InferenceSession(ai_path, providers=ort.get_available_providers())
providers = get_execution_providers_ordered()
session = ort.InferenceSession(ai_path, providers=providers)

logging.info(f"Providers : {providers}")
logging.info(f"Used providers : {session.get_providers()}")

background = session.run(None, {"gen_input_image": np.expand_dims(imarray_shrink, axis=0)})[0][0]

Expand Down
8 changes: 7 additions & 1 deletion graxpert/denoising.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import numpy as np
import onnxruntime as ort

from graxpert.ai_model_handling import get_execution_providers_ordered


def denoise(image, ai_path, strength, window_size=256, stride=128, progress=None):

Expand Down Expand Up @@ -34,7 +36,11 @@ def denoise(image, ai_path, strength, window_size=256, stride=128, progress=None

output = copy.deepcopy(image)

session = ort.InferenceSession(ai_path, providers=ort.get_available_providers())
providers = get_execution_providers_ordered()
session = ort.InferenceSession(ai_path, providers=providers)

logging.info(f"Providers : {providers}")
logging.info(f"Used providers : {session.get_providers()}")

for i in range(ith):
for j in range(itw):
Expand Down
2 changes: 0 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,3 @@ requests
scikit-image == 0.21.0
scipy
xisf
onnxruntime-silicon; sys_platform == "darwin"
onnxruntime-directml; sys_platform != "darwin"
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from graxpert.version import release, version

sys.setrecursionlimit(15_000)

astropy_path = os.path.dirname(os.path.abspath(astropy.__file__))

directory_table = [("ProgramMenuFolder", "TARGETDIR", "."), ("GraXpert", "ProgramMenuFolder", "GraXpert")]
Expand Down

0 comments on commit d5a82f5

Please sign in to comment.