From db374a0e745fd554ea7725a8412c35e8d77cc7a7 Mon Sep 17 00:00:00 2001 From: Riccardo Alberghi Date: Sat, 16 Mar 2024 10:59:44 +0100 Subject: [PATCH 1/3] Added hardware acceleration to bkg extraction and denoise --- graxpert/ai_model_handling.py | 8 ++++++++ graxpert/background_extraction.py | 7 ++++++- graxpert/denoising.py | 8 +++++++- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/graxpert/ai_model_handling.py b/graxpert/ai_model_handling.py index e9fd95a..d2fc19d 100644 --- a/graxpert/ai_model_handling.py +++ b/graxpert/ai_model_handling.py @@ -7,6 +7,7 @@ from appdirs import user_data_dir from minio import Minio from packaging import version +import onnxruntime as ort try: from graxpert.s3_secrets import endpoint, ro_access_key, ro_secret_key @@ -150,3 +151,10 @@ def download_version(ai_models_dir, bucket_name, remote_version, progress=None): def validate_local_version(ai_models_dir, local_version): return os.path.isfile(os.path.join(ai_models_dir, local_version, "model.onnx")) + + +def get_execution_providers_ordered(): + supported_providers = ["DmlExecutionProvider", "CoreMLExecutionProvider", "CUDAExecutionProvider", + "CPUExecutionProvider"] + + return [provider for provider in supported_providers if provider in ort.get_available_providers()] diff --git a/graxpert/background_extraction.py b/graxpert/background_extraction.py index 3b3da31..e4d43ad 100644 --- a/graxpert/background_extraction.py +++ b/graxpert/background_extraction.py @@ -17,6 +17,7 @@ from graxpert.mp_logging import get_logging_queue, worker_configurer from graxpert.parallel_processing import executor from graxpert.radialbasisinterpolation import RadialBasisInterpolation +from graxpert.ai_model_handling import get_execution_providers_ordered def extract_background(in_imarray, background_points, interpolation_type, smoothing, downscale_factor, sample_size, RBF_kernel, spline_order, corr_type, ai_path, progress=None): @@ -61,7 +62,11 @@ def extract_background(in_imarray, background_points, interpolation_type, smooth if progress is not None: progress.update(8) - session = ort.InferenceSession(ai_path, providers=ort.get_available_providers()) + providers = get_execution_providers_ordered() + session = ort.InferenceSession(ai_path, providers=providers) + + logging.info(f"Providers : {providers}") + logging.info(f"Used providers : {session.get_providers()}") background = session.run(None, {"gen_input_image": np.expand_dims(imarray_shrink, axis=0)})[0][0] diff --git a/graxpert/denoising.py b/graxpert/denoising.py index fbe6611..6ada9eb 100644 --- a/graxpert/denoising.py +++ b/graxpert/denoising.py @@ -4,6 +4,8 @@ import numpy as np import onnxruntime as ort +from graxpert.ai_model_handling import get_execution_providers_ordered + def denoise(image, ai_path, strength, window_size=256, stride=128, progress=None): @@ -34,7 +36,11 @@ def denoise(image, ai_path, strength, window_size=256, stride=128, progress=None output = copy.deepcopy(image) - session = ort.InferenceSession(ai_path, providers=ort.get_available_providers()) + providers = get_execution_providers_ordered() + session = ort.InferenceSession(ai_path, providers=providers) + + logging.info(f"Providers : {providers}") + logging.info(f"Used providers : {session.get_providers()}") for i in range(ith): for j in range(itw): From f020b7a878153b87218e36fb55e351802678071b Mon Sep 17 00:00:00 2001 From: Riccardo Alberghi Date: Sat, 16 Mar 2024 11:05:11 +0100 Subject: [PATCH 2/3] Changed onnxruntime installation from requirements.txt to build-release.yml --- .github/workflows/build-release.yml | 4 ++++ requirements.txt | 2 -- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build-release.yml b/.github/workflows/build-release.yml index 4fedd3a..4e558ba 100644 --- a/.github/workflows/build-release.yml +++ b/.github/workflows/build-release.yml @@ -36,6 +36,7 @@ jobs: run: | sudo apt install alien -y && \ pip install "cx_freeze>=6.16.0.dev" && \ + pip install onnxruntime-gpu && \ pip install -r requirements.txt - name: patch version run: | @@ -68,6 +69,7 @@ jobs: - name: install dependencies run: | pip install setuptools wheel cx_freeze && \ + pip install onnxruntime-gpu && \ pip install -r requirements.txt - name: patch version run: | @@ -105,6 +107,7 @@ jobs: - name: install dependencies run: | pip install setuptools wheel cx_freeze; ` + pip install onnxruntime-directml; ` pip install -r requirements.txt - name: patch version run: ./releng/patch_version.ps1 @@ -138,6 +141,7 @@ jobs: run: | brew install python-tk && \ pip3 install setuptools wheel pyinstaller && \ + pip3 install onnxruntime && \ pip3 install -r requirements.txt - name: patch version run: | diff --git a/requirements.txt b/requirements.txt index 5687b1b..a9c5f27 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,5 +10,3 @@ requests scikit-image == 0.21.0 scipy xisf -onnxruntime-silicon; sys_platform == "darwin" -onnxruntime-directml; sys_platform != "darwin" From 3fe3a6db47d9f2dbeb874ee5aaef38768b458445 Mon Sep 17 00:00:00 2001 From: Riccardo Alberghi Date: Sat, 16 Mar 2024 11:11:48 +0100 Subject: [PATCH 3/3] Fixed recursion error of cx_freeze --- setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.py b/setup.py index 166ca25..ab60703 100644 --- a/setup.py +++ b/setup.py @@ -6,6 +6,8 @@ from graxpert.version import release, version +sys.setrecursionlimit(15_000) + astropy_path = os.path.dirname(os.path.abspath(astropy.__file__)) directory_table = [("ProgramMenuFolder", "TARGETDIR", "."), ("GraXpert", "ProgramMenuFolder", "GraXpert")]