diff --git a/graxpert/ai_model_handling.py b/graxpert/ai_model_handling.py index 998d578..c5dab1d 100644 --- a/graxpert/ai_model_handling.py +++ b/graxpert/ai_model_handling.py @@ -4,10 +4,10 @@ import shutil import zipfile +import onnxruntime as ort from appdirs import user_data_dir from minio import Minio from packaging import version -import onnxruntime as ort try: from graxpert.s3_secrets import endpoint, ro_access_key, ro_secret_key @@ -166,7 +166,11 @@ def validate_local_version(ai_models_dir, local_version): return os.path.isfile(os.path.join(ai_models_dir, local_version, "model.onnx")) -def get_execution_providers_ordered(): - supported_providers = ["DmlExecutionProvider", "CoreMLExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"] +def get_execution_providers_ordered(gpu_acceleration=True): + + if gpu_acceleration: + supported_providers = ["DmlExecutionProvider", "CoreMLExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"] + else: + supported_providers = ["CPUExecutionProvider"] return [provider for provider in supported_providers if provider in ort.get_available_providers()] diff --git a/graxpert/application/__init__.py b/graxpert/application/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/graxpert/application/app.py b/graxpert/application/app.py index 637d476..23d9533 100644 --- a/graxpert/application/app.py +++ b/graxpert/application/app.py @@ -86,11 +86,15 @@ def initialize(self): eventbus.add_listener(AppEvents.DENOISE_AI_VERSION_CHANGED, self.on_denoise_ai_version_changed) eventbus.add_listener(AppEvents.SCALING_CHANGED, self.on_scaling_changed) eventbus.add_listener(AppEvents.AI_BATCH_SIZE_CHANGED, self.on_ai_batch_size_changed) + eventbus.add_listener(AppEvents.AI_GPU_ACCELERATION_CHANGED, self.on_ai_gpu_acceleration_changed) # event handling def on_ai_batch_size_changed(self, event): self.prefs.ai_batch_size = event["ai_batch_size"] + def on_ai_gpu_acceleration_changed(self, event): + self.prefs.ai_gpu_acceleration = event["ai_gpu_acceleration"] + def on_bge_ai_version_changed(self, event): self.prefs.bge_ai_version = event["bge_ai_version"] @@ -155,6 +159,7 @@ def on_calculate_request(self, event=None): self.prefs.corr_type, ai_model_path_from_version(bge_ai_models_dir, self.prefs.bge_ai_version), progress, + self.prefs.ai_gpu_acceleration, ) ) @@ -188,7 +193,7 @@ def on_calculate_request(self, event=None): def on_change_saturation_request(self, event): if self.images.get("Original") is None: return - + self.prefs.saturation = event["saturation"] eventbus.emit(AppEvents.CHANGE_SATURATION_BEGIN) @@ -323,7 +328,7 @@ def on_smoothing_changed(self, event): def on_denoise_strength_changed(self, event): self.prefs.denoise_strength = event["denoise_strength"] - + def on_denoise_threshold_changed(self, event): self.prefs.denoise_threshold = event["denoise_threshold"] @@ -346,23 +351,33 @@ def on_denoise_request(self, event): self.prefs.images_linked_option = True ai_model_path = ai_model_path_from_version(denoise_ai_models_dir, self.prefs.denoise_ai_version) - imarray = denoise(img_array_to_be_processed, ai_model_path, self.prefs.denoise_strength, batch_size=self.prefs.ai_batch_size, threshold=self.prefs.denoise_threshold, progress=progress) + imarray = denoise( + img_array_to_be_processed, + ai_model_path, + self.prefs.denoise_strength, + batch_size=self.prefs.ai_batch_size, + threshold=self.prefs.denoise_threshold, + progress=progress, + ai_gpu_acceleration=self.prefs.ai_gpu_acceleration, + ) - denoised = AstroImage() - denoised.set_from_array(imarray) + if imarray is not None: - # Update fits header and metadata - background_mean = np.mean(self.images.get("Original").img_array) - denoised.update_fits_header(self.images.get("Original").fits_header, background_mean, self.prefs, self.cmd.app_state) + denoised = AstroImage() + denoised.set_from_array(imarray) - denoised.copy_metadata(self.images.get("Original")) + # Update fits header and metadata + background_mean = np.mean(self.images.get("Original").img_array) + denoised.update_fits_header(self.images.get("Original").fits_header, background_mean, self.prefs, self.cmd.app_state) - self.images.set("Denoised", denoised) + denoised.copy_metadata(self.images.get("Original")) - self.images.stretch_all(StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option, self.prefs.images_linked_option), self.prefs.saturation) + self.images.set("Denoised", denoised) + + self.images.stretch_all(StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option, self.prefs.images_linked_option), self.prefs.saturation) - eventbus.emit(AppEvents.DENOISE_SUCCESS) - eventbus.emit(AppEvents.UPDATE_DISPLAY_TYPE_REEQUEST, {"display_type": "Denoised"}) + eventbus.emit(AppEvents.DENOISE_SUCCESS) + eventbus.emit(AppEvents.UPDATE_DISPLAY_TYPE_REEQUEST, {"display_type": "Denoised"}) except Exception as e: logging.exception(e) diff --git a/graxpert/application/app_events.py b/graxpert/application/app_events.py index e5649b3..b8d0a84 100644 --- a/graxpert/application/app_events.py +++ b/graxpert/application/app_events.py @@ -78,3 +78,6 @@ class AppEvents(Enum): LANGUAGE_CHANGED = auto() SCALING_CHANGED = auto() AI_BATCH_SIZE_CHANGED = auto() + AI_GPU_ACCELERATION_CHANGED = auto() + # process control + CANCEL_PROCESSING = auto() diff --git a/graxpert/background_extraction.py b/graxpert/background_extraction.py index 27d0343..bd1a896 100644 --- a/graxpert/background_extraction.py +++ b/graxpert/background_extraction.py @@ -24,7 +24,7 @@ def gaussian_kernel(sigma=1.0, truncate=4.0): # follow simulate skimage.filters return (ksize, ksize) -def extract_background(in_imarray, background_points, interpolation_type, smoothing, downscale_factor, sample_size, RBF_kernel, spline_order, corr_type, ai_path, progress=None): +def extract_background(in_imarray, background_points, interpolation_type, smoothing, downscale_factor, sample_size, RBF_kernel, spline_order, corr_type, ai_path, progress=None, ai_gpu_acceleration=True): num_colors = in_imarray.shape[-1] @@ -71,7 +71,7 @@ def extract_background(in_imarray, background_points, interpolation_type, smooth if progress is not None: progress.update(8) - providers = get_execution_providers_ordered() + providers = get_execution_providers_ordered(ai_gpu_acceleration) session = ort.InferenceSession(ai_path, providers=providers) logging.info(f"Providers : {providers}") diff --git a/graxpert/cmdline_tools.py b/graxpert/cmdline_tools.py index 4dd9487..3bb97d4 100644 --- a/graxpert/cmdline_tools.py +++ b/graxpert/cmdline_tools.py @@ -10,9 +10,9 @@ from graxpert.ai_model_handling import ai_model_path_from_version, bge_ai_models_dir, denoise_ai_models_dir, download_version, latest_version, list_local_versions from graxpert.astroimage import AstroImage from graxpert.background_extraction import extract_background +from graxpert.denoising import denoise from graxpert.preferences import Prefs, load_preferences, save_preferences from graxpert.s3_secrets import bge_bucket_name, denoise_bucket_name -from graxpert.denoising import denoise user_preferences_filename = os.path.join(user_config_dir(appname="GraXpert"), "preferences.json") @@ -85,6 +85,8 @@ def execute(self): preferences.corr_type = json_prefs["corr_type"] if "ai_version" in json_prefs: preferences.ai_version = json_prefs["ai_version"] + if "ai_gpu_acceleration" in json_prefs: + preferences.ai_gpu_acceleration = json_prefs["ai_gpu_acceleration"] if preferences.interpol_type_option == "Kriging" or preferences.interpol_type_option == "RBF": downscale_factor = 4 @@ -109,6 +111,12 @@ def execute(self): else: logging.info(f"Using stored correction type {preferences.corr_type}.") + if self.args.gpu_acceleration is not None: + preferences.ai_gpu_acceleration = True if self.args.gpu_acceleration == "true" else False + logging.info(f"Using user-supplied gpu acceleration setting {preferences.ai_gpu_acceleration}.") + else: + logging.info(f"Using stored gpu acceleration setting {preferences.ai_gpu_acceleration}.") + if preferences.interpol_type_option == "AI": ai_model_path = ai_model_path_from_version(bge_ai_models_dir, self.get_ai_version(preferences)) else: @@ -153,6 +161,7 @@ def execute(self): preferences.spline_order, preferences.corr_type, ai_model_path, + ai_gpu_acceleration=preferences.ai_gpu_acceleration, ) ) @@ -222,6 +231,8 @@ def execute(self): preferences.denoise_strength = json_prefs["denoise_strength"] if "ai_batch_size" in json_prefs: preferences.ai_batch_size = json_prefs["ai_batch_size"] + if "ai_gpu_acceleration" in json_prefs: + preferences.ai_gpu_acceleration = json_prefs["ai_gpu_acceleration"] except Exception as e: logging.exception(e) @@ -229,19 +240,25 @@ def execute(self): sys.exit(1) else: preferences = Prefs() - + if self.args.denoise_strength is not None: preferences.denoise_strength = self.args.denoise_strength logging.info(f"Using user-supplied denoise strength value {preferences.denoise_strength}.") else: logging.info(f"Using stored denoise strength value {preferences.denoise_strength}.") - + if self.args.ai_batch_size is not None: preferences.ai_batch_size = self.args.ai_batch_size logging.info(f"Using user-supplied batch size value {preferences.ai_batch_size}.") else: logging.info(f"Using stored batch size value {preferences.ai_batch_size}.") + if self.args.gpu_acceleration is not None: + preferences.ai_gpu_acceleration = True if self.args.gpu_acceleration == "true" else False + logging.info(f"Using user-supplied gpu acceleration setting {preferences.ai_gpu_acceleration}.") + else: + logging.info(f"Using stored gpu acceleration setting {preferences.ai_gpu_acceleration}.") + ai_model_path = ai_model_path_from_version(denoise_ai_models_dir, self.get_ai_version(preferences)) logging.info( @@ -254,12 +271,8 @@ def execute(self): ) processed_Astro_Image.set_from_array( - denoise( - astro_Image.img_array, - ai_model_path, - preferences.denoise_strength, - batch_size=preferences.ai_batch_size - )) + denoise(astro_Image.img_array, ai_model_path, preferences.denoise_strength, batch_size=preferences.ai_batch_size, ai_gpu_acceleration=preferences.ai_gpu_acceleration) + ) processed_Astro_Image.save(self.get_save_path(), self.get_output_file_format()) def get_ai_version(self, prefs): diff --git a/graxpert/denoising.py b/graxpert/denoising.py index d0cfc92..e6075ad 100644 --- a/graxpert/denoising.py +++ b/graxpert/denoising.py @@ -11,7 +11,7 @@ from graxpert.ui.ui_events import UiEvents -def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128, threshold=1.0, progress=None): +def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128, threshold=1.0, progress=None, ai_gpu_acceleration=True): logging.info("Starting denoising") @@ -26,7 +26,7 @@ def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128, batch_size = 2 ** (batch_size).bit_length() // 2 # map batch_size to power of two input = copy.deepcopy(image) - + median = np.median(image[::4, ::4, :], axis=[0, 1]) mad = np.median(np.abs(image[::4, ::4, :] - median), axis=[0, 1]) @@ -61,20 +61,33 @@ def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128, output = copy.deepcopy(image) - providers = get_execution_providers_ordered() + providers = get_execution_providers_ordered(ai_gpu_acceleration) session = ort.InferenceSession(ai_path, providers=providers) logging.info(f"Available inference providers : {providers}") logging.info(f"Used inference providers : {session.get_providers()}") - + if "1.0.0" in ai_path or "1.1.0" in ai_path: model_threshold = 1.0 else: model_threshold = 10.0 + cancel_flag = False + + def cancel_listener(event): + nonlocal cancel_flag + cancel_flag = True + + eventbus.add_listener(AppEvents.CANCEL_PROCESSING, cancel_listener) + last_progress = 0 for b in range(0, ith * itw + batch_size, batch_size): + if cancel_flag: + logging.info("Denoising cancelled") + eventbus.remove_listener(AppEvents.CANCEL_PROCESSING, cancel_listener) + return None + input_tiles = [] input_tile_copies = [] for t_idx in range(0, batch_size): @@ -141,6 +154,7 @@ def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128, cached_denoised_image = output output = blend_images(input, output, strength, threshold, median, mad) + eventbus.remove_listener(AppEvents.CANCEL_PROCESSING, cancel_listener) logging.info("Finished denoising") return output @@ -153,7 +167,6 @@ def blend_images(original_image, denoised_image, strength, threshold, median, ma return np.clip(blend, 0, 1) - def reset_cached_denoised_image(event): global cached_denoised_image cached_denoised_image = None diff --git a/graxpert/main.py b/graxpert/main.py index 5c22b34..dbc5ae4 100644 --- a/graxpert/main.py +++ b/graxpert/main.py @@ -49,6 +49,7 @@ def collect_available_versions(ai_models_dir, bucket_name): def bge_version_type(arg_value, pat=re.compile(r"^\d+\.\d+\.\d+$")): return version_type(bge_ai_models_dir, bge_bucket_name, arg_value, pat=re.compile(r"^\d+\.\d+\.\d+$")) + def denoise_version_type(arg_value, pat=re.compile(r"^\d+\.\d+\.\d+$")): return version_type(denoise_ai_models_dir, denoise_bucket_name, arg_value, pat=re.compile(r"^\d+\.\d+\.\d+$")) @@ -205,6 +206,7 @@ def main(): type=str, help="Allows GraXpert commandline to run all extraction methods based on a preferences file that contains background grid points", ) + parser.add_argument("-gpu", "--gpu_acceleration", type=str, choices=["true", "false"], default=None, help="Set to 'false' in order to disable gpu acceleration during AI inference.") parser.add_argument("-v", "--version", action="version", version=f"GraXpert version: {graxpert_version} release: {graxpert_release}") bge_parser = argparse.ArgumentParser("GraXpert Background Extraction", parents=[parser], description="GraXpert, the astronomical background extraction tool") diff --git a/graxpert/preferences.py b/graxpert/preferences.py index 581dec7..5af1a59 100644 --- a/graxpert/preferences.py +++ b/graxpert/preferences.py @@ -42,6 +42,7 @@ class Prefs: denoise_strength: float = 0.5 denoise_threshold: float = 10.0 ai_batch_size: int = 4 + ai_gpu_acceleration: bool = True def app_state_2_prefs(prefs: Prefs, app_state: AppState) -> Prefs: diff --git a/graxpert/ui/canvas.py b/graxpert/ui/canvas.py index 2e9b050..747806b 100644 --- a/graxpert/ui/canvas.py +++ b/graxpert/ui/canvas.py @@ -183,12 +183,14 @@ def on_calculate_end(self, event=None): def on_denoise_begin(self, event=None): self.dynamic_progress_frame.text.set(_("Denoising")) + self.dynamic_progress_frame.cancellable = True self.show_progress_frame(True) def on_denoise_progress(self, event=None): self.dynamic_progress_frame.update_progress(event["progress"]) def on_denoise_success(self, event=None): + self.dynamic_progress_frame.cancellable = False if not "Denoised" in self.display_options: self.display_options.append("Denoised") self.display_menu.grid_forget() @@ -196,6 +198,7 @@ def on_denoise_success(self, event=None): self.display_menu.grid(column=0, row=0, sticky=tk.N) def on_denoise_end(self, event=None): + self.dynamic_progress_frame.cancellable = False self.dynamic_progress_frame.text.set("") self.dynamic_progress_frame.variable.set(0.0) self.show_progress_frame(False) @@ -497,6 +500,7 @@ def show_loading_frame(self, show): def show_progress_frame(self, show): if show: + self.dynamic_progress_frame.place_children() self.dynamic_progress_frame.grid(column=0, row=0, rowspan=2) else: self.dynamic_progress_frame.grid_forget() diff --git a/graxpert/ui/left_menu.py b/graxpert/ui/left_menu.py index e55c94e..7ea7098 100644 --- a/graxpert/ui/left_menu.py +++ b/graxpert/ui/left_menu.py @@ -1,6 +1,6 @@ import tkinter as tk -from customtkinter import StringVar, ThemeManager +from customtkinter import ThemeManager import graxpert.ui.tooltip as tooltip from graxpert.application.app import graxpert @@ -224,7 +224,7 @@ def __init__(self, parent, **kwargs): self.denoise_strength = tk.DoubleVar() self.denoise_strength.set(graxpert.prefs.denoise_strength) self.denoise_strength.trace_add("write", lambda a, b, c: eventbus.emit(AppEvents.DENOISE_STRENGTH_CHANGED, {"denoise_strength": self.denoise_strength.get()})) - + self.denoise_threshold = tk.DoubleVar() self.denoise_threshold.set(graxpert.prefs.denoise_threshold) self.denoise_threshold.trace_add("write", lambda a, b, c: eventbus.emit(AppEvents.DENOISE_THRESHOLD_CHANGED, {"denoise_threshold": self.denoise_threshold.get()})) @@ -251,7 +251,7 @@ def create_children(self): self.sub_frame, width=default_label_width, variable_name=_("Denoise Strength"), variable=self.denoise_strength, min_value=0.0, max_value=1.0, precision=2 ) tooltip.Tooltip(self.denoise_strength_slider, text=tooltip.denoise_strength_text) - + self.denoise_threshold_slider = ValueSlider( self.sub_frame, width=default_label_width, variable_name=_("Denoise Threshold"), variable=self.denoise_threshold, min_value=0.1, max_value=10.0, precision=1 ) diff --git a/graxpert/ui/loadingframe.py b/graxpert/ui/loadingframe.py index bfc7772..7148dd4 100644 --- a/graxpert/ui/loadingframe.py +++ b/graxpert/ui/loadingframe.py @@ -1,13 +1,14 @@ import logging -import sys import tkinter as tk from os import path from queue import Empty, Queue from threading import Thread -from customtkinter import CTkFont, CTkFrame, CTkImage, CTkLabel, CTkProgressBar, DoubleVar, StringVar +from customtkinter import CTkButton, CTkFont, CTkFrame, CTkImage, CTkLabel, CTkProgressBar, DoubleVar, StringVar, ThemeManager from PIL import Image +from graxpert.application.app_events import AppEvents +from graxpert.application.eventbus import eventbus from graxpert.localization import _ from graxpert.resource_utils import resource_path @@ -38,11 +39,12 @@ def place_children(self): class DynamicProgressFrame(CTkFrame): - def __init__(self, parent, label_lext=_("Progress:"), **kwargs): + def __init__(self, parent, label_lext=_("Progress:"), cancellable=False, **kwargs): super().__init__(parent, **kwargs) self.text = StringVar(self, value=label_lext) self.variable = DoubleVar(self, value=0.0) + self.cancellable = cancellable self.create_children() self.setup_layout() @@ -55,6 +57,13 @@ def create_children(self): width=280, ) self.pb = CTkProgressBar(self, variable=self.variable) + self.cancel_button = CTkButton( + self, + text=_("Cancel"), + command=lambda: eventbus.emit(AppEvents.CANCEL_PROCESSING), + fg_color=ThemeManager.theme["Accent.CTkButton"]["fg_color"], + hover_color=ThemeManager.theme["Accent.CTkButton"]["hover_color"], + ) def setup_layout(self): self.columnconfigure(0, weight=1) @@ -63,9 +72,14 @@ def setup_layout(self): def place_children(self): self.label.grid(column=0, row=0, sticky=tk.NSEW) self.pb.grid(column=0, row=1, sticky=tk.NSEW) + if self.cancellable: + self.cancel_button.grid(column=0, row=2) + else: + self.cancel_button.grid_forget() def close(self): - self.pb.pack_forget() + self.pb.grid_forget() + self.cancel_button.grid_forget() self.update() self.destroy() diff --git a/graxpert/ui/right_menu.py b/graxpert/ui/right_menu.py index 67c8e19..27d8111 100644 --- a/graxpert/ui/right_menu.py +++ b/graxpert/ui/right_menu.py @@ -3,7 +3,7 @@ from tkinter import messagebox import customtkinter as ctk -from customtkinter import CTkFont, CTkImage, CTkLabel, CTkTextbox +from customtkinter import CTkFont, CTkImage, CTkLabel, CTkSwitch, CTkTextbox from packaging import version from PIL import Image @@ -187,11 +187,15 @@ def __init__(self, master, **kwargs): self.denoise_ai_version.trace_add("write", lambda a, b, c: eventbus.emit(AppEvents.DENOISE_AI_VERSION_CHANGED, {"denoise_ai_version": self.denoise_ai_version.get()})) # ai settings - self.ai_batch_size_options = ["1","2","4","8","16","32"] + self.ai_batch_size_options = ["1", "2", "4", "8", "16", "32"] self.ai_batch_size = tk.IntVar() self.ai_batch_size.set(graxpert.prefs.ai_batch_size) self.ai_batch_size.trace_add("write", lambda a, b, c: eventbus.emit(AppEvents.AI_BATCH_SIZE_CHANGED, {"ai_batch_size": self.ai_batch_size.get()})) + self.ai_gpu_acceleration = tk.BooleanVar() + self.ai_gpu_acceleration.set(graxpert.prefs.ai_gpu_acceleration) + self.ai_gpu_acceleration.trace_add("write", lambda a, b, c: eventbus.emit(AppEvents.AI_GPU_ACCELERATION_CHANGED, {"ai_gpu_acceleration": self.ai_gpu_acceleration.get()})) + self.create_and_place_children() self.setup_layout() @@ -243,6 +247,9 @@ def lang_change(lang): CTkLabel(self, text=_("AI inference batch size"), font=self.heading_font2).grid(column=0, row=self.nrow(), pady=pady, sticky=tk.N) GraXpertOptionMenu(self, variable=self.ai_batch_size, values=self.ai_batch_size_options).grid(**self.default_grid()) + CTkLabel(self, text=_("AI Hardware Acceleration"), font=self.heading_font2).grid(column=0, row=self.nrow(), pady=pady, sticky=tk.N) + CTkSwitch(self, text=_("Enable Acceleration"), variable=self.ai_gpu_acceleration).grid(column=0, row=self.nrow(), pady=pady, sticky=tk.N) + def setup_layout(self): self.columnconfigure(0, weight=1) self.rowconfigure(0, weight=1)