From cb20873e9fe89dd2bac88c1ec68b47deb980c354 Mon Sep 17 00:00:00 2001 From: David Schmelter Date: Mon, 25 Mar 2024 17:49:04 +0100 Subject: [PATCH 1/5] add batch processing for denoising --- .vscode/launch.json | 4 +-- graxpert/denoising.py | 78 +++++++++++++++++++++++++++++++++++-------- 2 files changed, 66 insertions(+), 16 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 82bada6..fd654f6 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -3,10 +3,10 @@ "configurations": [ { "name": "Run GraXpert", - "type": "python", + "type": "debugpy", "request": "launch", "module": "graxpert.main", - "justMyCode": true + "justMyCode": true, } ] } \ No newline at end of file diff --git a/graxpert/denoising.py b/graxpert/denoising.py index f3fbcfd..1b134d3 100644 --- a/graxpert/denoising.py +++ b/graxpert/denoising.py @@ -1,5 +1,6 @@ import copy import logging +import time import numpy as np import onnxruntime as ort @@ -7,11 +8,13 @@ from graxpert.ai_model_handling import get_execution_providers_ordered -def denoise(image, ai_path, strength, window_size=256, stride=128, progress=None): +def denoise(image, ai_path, strength, batch_size=5, window_size=256, stride=128, progress=None): + + logging.info("Starting denoising") input = copy.deepcopy(image) num_colors = image.shape[-1] - + if num_colors == 1: image = np.array([image[:, :, 0], image[:, :, 0], image[:, :, 0]]) image = np.moveaxis(image, 0, -1) @@ -42,11 +45,16 @@ def denoise(image, ai_path, strength, window_size=256, stride=128, progress=None output = copy.deepcopy(image) providers = get_execution_providers_ordered() - session = ort.InferenceSession(ai_path, providers=providers) + ort_options = ort.SessionOptions() + ort_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL + session = ort.InferenceSession(ai_path, providers=providers, sess_options=ort_options) - logging.info(f"Providers : {providers}") - logging.info(f"Used providers : {session.get_providers()}") + logging.info(f"Available inference providers : {providers}") + logging.info(f"Used inference providers : {session.get_providers()}") + last_progress = 0 + + input_tiles = [] for i in range(ith): for j in range(itw): x = stride * i @@ -54,28 +62,70 @@ def denoise(image, ai_path, strength, window_size=256, stride=128, progress=None tile = image[x : x + window_size, y : y + window_size, :] tile = (tile - median) / mad * 0.04 - tile_copy = tile.copy() + # input_tile_copies.append(tile.copy()) tile = np.clip(tile, -1.0, 1.0) - tile = np.expand_dims(tile, axis=0) - tile = np.array(session.run(None, {"gen_input_image": tile})[0][0]) + input_tiles.append(tile) + + p = int(i / ith * 10) + if p > last_progress: + if progress is not None: + progress.update(p - last_progress) + else: + logging.info(f"Progress: {p}%") + last_progress = p + + input_tiles = np.array(input_tiles) + input_tile_copies = np.copy(input_tiles).reshape((ith, itw, window_size, window_size, 3)) + + output_tiles = [] + + elapsed_time = 0 + for i in range(0, ith * itw, batch_size): + start = time.time() + session_result = session.run(None, {"gen_input_image": input_tiles[i : i + batch_size]})[0] + elapsed_time += time.time() - start + for e in session_result: + output_tiles.append(e) + + p = int(10 + i / (ith * itw) * 80) + if p > last_progress: + if progress is not None: + progress.update(p - last_progress) + else: + logging.info(f"Progress: {p}%") + last_progress = p - tile = np.where(tile_copy < 0.95, tile, tile_copy) + output_tiles = np.array(output_tiles) + output_tiles = output_tiles.reshape((ith, itw, window_size, window_size, 3)) + + for i in range(ith): + for j in range(itw): + x = stride * i + y = stride * j + + tile = output_tiles[i, j, :] + tile = np.where(input_tile_copies[i, j] < 0.95, tile, input_tile_copies[i, j]) tile = tile / 0.04 * mad + median tile = tile[offset : offset + stride, offset : offset + stride, :] output[x + offset : stride * (i + 1) + offset, y + offset : stride * (j + 1) + offset, :] = tile - if progress is not None: - progress.update(int(100 / ith)) - else: - logging.info(f"Progress: {int(i/ith*100)}%") + p = int(90 + i / ith * 10) + if p > last_progress: + if progress is not None: + progress.update(p - last_progress) + else: + logging.info(f"Progress: {p}%") + last_progress = p output = np.clip(output, 0, 1) output = output[offset : H + offset, offset : W + offset, :] output = output * strength + input * (1 - strength) - + # if num_colors == 1: output = np.array([output[:, :, 0]]) output = np.moveaxis(output, 0, -1) + logging.info("Finished denoising") + return output From f210734ea78bffff456f2f41b767bfbbc2b9491c Mon Sep 17 00:00:00 2001 From: David Schmelter Date: Fri, 29 Mar 2024 15:14:58 +0100 Subject: [PATCH 2/5] optimize denoising memory footprint, optimize bge runtime --- graxpert/background_extraction.py | 52 +++++++++++++++--------- graxpert/denoising.py | 67 ++++++++++++++----------------- requirements.txt | 1 + 3 files changed, 64 insertions(+), 56 deletions(-) diff --git a/graxpert/background_extraction.py b/graxpert/background_extraction.py index e4d43ad..f5e689a 100644 --- a/graxpert/background_extraction.py +++ b/graxpert/background_extraction.py @@ -6,34 +6,39 @@ from concurrent.futures import wait from multiprocessing import shared_memory +import cv2 import numpy as np import onnxruntime as ort from astropy.stats import sigma_clipped_stats from pykrige.ok import OrdinaryKriging from scipy import interpolate, linalg -from skimage.filters import gaussian -from skimage.transform import resize +from graxpert.ai_model_handling import get_execution_providers_ordered from graxpert.mp_logging import get_logging_queue, worker_configurer from graxpert.parallel_processing import executor from graxpert.radialbasisinterpolation import RadialBasisInterpolation -from graxpert.ai_model_handling import get_execution_providers_ordered + + +def gaussian_kernel(sigma=1.0, truncate=4.0): # follow simulate skimage.filters.gaussian defaults + ksize = round(sigma * truncate) - 1 if round(sigma * truncate) % 2 == 0 else round(sigma * truncate) + return (ksize, ksize) def extract_background(in_imarray, background_points, interpolation_type, smoothing, downscale_factor, sample_size, RBF_kernel, spline_order, corr_type, ai_path, progress=None): - shm_imarray = shared_memory.SharedMemory(create=True, size=in_imarray.nbytes) - shm_background = shared_memory.SharedMemory(create=True, size=in_imarray.nbytes) - imarray = np.ndarray(in_imarray.shape, dtype=np.float32, buffer=shm_imarray.buf) - background = np.ndarray(in_imarray.shape, dtype=np.float32, buffer=shm_background.buf) - np.copyto(imarray, in_imarray) + num_colors = in_imarray.shape[-1] - num_colors = imarray.shape[-1] + shm_imarray = None + shm_background = None if interpolation_type == "AI": + imarray = np.ndarray(in_imarray.shape, dtype=np.float32) + background = np.ndarray(in_imarray.shape, dtype=np.float32) + np.copyto(imarray, in_imarray) + # Shrink and pad to avoid artifacts on borders padding = 8 - imarray_shrink = resize(imarray, output_shape=(256 - 2 * padding, 256 - 2 * padding)) + imarray_shrink = cv2.resize(imarray, dsize=(256 - 2 * padding, 256 - 2 * padding), interpolation=cv2.INTER_LINEAR) imarray_shrink = np.pad(imarray_shrink, ((padding, padding), (padding, padding), (0, 0)), mode="edge") median = [] @@ -77,7 +82,7 @@ def extract_background(in_imarray, background_points, interpolation_type, smooth if smoothing != 0: sigma = smoothing * 20 - background = gaussian(image=background, sigma=sigma, channel_axis=-1) + background = cv2.GaussianBlur(background, ksize=gaussian_kernel(sigma), sigmaX=sigma, sigmaY=sigma) if progress is not None: progress.update(8) @@ -96,13 +101,20 @@ def extract_background(in_imarray, background_points, interpolation_type, smooth if progress is not None: progress.update(8) - background = gaussian(background, sigma=3.0) # To simulate tensorflow method='gaussian' - background = resize(background, output_shape=(in_imarray.shape[0], in_imarray.shape[1])) + sigma = 3.0 + background = cv2.GaussianBlur(background, ksize=gaussian_kernel(sigma), sigmaX=sigma, sigmaY=sigma) + background = cv2.resize(background, dsize=(in_imarray.shape[1], in_imarray.shape[0]), interpolation=cv2.INTER_LINEAR) if progress is not None: progress.update(8) else: + shm_imarray = shared_memory.SharedMemory(create=True, size=in_imarray.nbytes) + shm_background = shared_memory.SharedMemory(create=True, size=in_imarray.nbytes) + imarray = np.ndarray(in_imarray.shape, dtype=np.float32, buffer=shm_imarray.buf) + background = np.ndarray(in_imarray.shape, dtype=np.float32, buffer=shm_background.buf) + np.copyto(imarray, in_imarray) + x_sub = np.array(background_points[:, 0], dtype=int) y_sub = np.array(background_points[:, 1], dtype=int) @@ -154,15 +166,17 @@ def extract_background(in_imarray, background_points, interpolation_type, smooth imarray[:, :, :] = imarray.clip(min=0.0, max=1.0) in_imarray[:] = imarray[:] - background = np.copy(background) if progress is not None: progress.update(8) - shm_imarray.close() - shm_background.close() - shm_imarray.unlink() - shm_background.unlink() + if shm_imarray is not None: + shm_imarray.close() + shm_imarray.unlink() + if shm_background is not None: + background = np.copy(background) + shm_background.close() + shm_background.unlink() return background @@ -258,7 +272,7 @@ def interpol(shm_imarray_name, shm_background_name, c, x_sub, y_sub, shape, kind return if downscale_factor != 1: - result = resize(result, shape, preserve_range=True) + result = cv2.resize(src=result, dsize=(shape[1], shape[0]), interpolation=cv2.INTER_LINEAR) background[:, :, c] = result except Exception as e: diff --git a/graxpert/denoising.py b/graxpert/denoising.py index 1b134d3..3b86358 100644 --- a/graxpert/denoising.py +++ b/graxpert/denoising.py @@ -53,64 +53,57 @@ def denoise(image, ai_path, strength, batch_size=5, window_size=256, stride=128, logging.info(f"Used inference providers : {session.get_providers()}") last_progress = 0 + for b in range(0, ith * itw + batch_size, batch_size): + + input_tiles = [] + for t_idx in range(0, batch_size): + + index = b + t_idx + i = index % ith + j = index // ith + + if i >= ith or j >= itw: + break - input_tiles = [] - for i in range(ith): - for j in range(itw): x = stride * i y = stride * j tile = image[x : x + window_size, y : y + window_size, :] tile = (tile - median) / mad * 0.04 - # input_tile_copies.append(tile.copy()) tile = np.clip(tile, -1.0, 1.0) input_tiles.append(tile) - p = int(i / ith * 10) - if p > last_progress: - if progress is not None: - progress.update(p - last_progress) - else: - logging.info(f"Progress: {p}%") - last_progress = p - - input_tiles = np.array(input_tiles) - input_tile_copies = np.copy(input_tiles).reshape((ith, itw, window_size, window_size, 3)) + if not input_tiles: + continue + + input_tiles = np.array(input_tiles) + input_tile_copies = np.copy(input_tiles) - output_tiles = [] - - elapsed_time = 0 - for i in range(0, ith * itw, batch_size): - start = time.time() - session_result = session.run(None, {"gen_input_image": input_tiles[i : i + batch_size]})[0] - elapsed_time += time.time() - start + output_tiles = [] + session_result = session.run(None, {"gen_input_image": input_tiles})[0] for e in session_result: output_tiles.append(e) - p = int(10 + i / (ith * itw) * 80) - if p > last_progress: - if progress is not None: - progress.update(p - last_progress) - else: - logging.info(f"Progress: {p}%") - last_progress = p + output_tiles = np.array(output_tiles) + + for t_idx, tile in enumerate(output_tiles): - output_tiles = np.array(output_tiles) - output_tiles = output_tiles.reshape((ith, itw, window_size, window_size, 3)) + index = b + t_idx + i = index % ith + j = index // ith + + if i >= ith or j >= itw: + break - for i in range(ith): - for j in range(itw): x = stride * i y = stride * j - - tile = output_tiles[i, j, :] - tile = np.where(input_tile_copies[i, j] < 0.95, tile, input_tile_copies[i, j]) + tile = np.where(input_tile_copies[t_idx] < 0.95, tile, input_tile_copies[t_idx]) tile = tile / 0.04 * mad + median tile = tile[offset : offset + stride, offset : offset + stride, :] output[x + offset : stride * (i + 1) + offset, y + offset : stride * (j + 1) + offset, :] = tile - p = int(90 + i / ith * 10) + p = int(b / (ith * itw + batch_size) * 100) if p > last_progress: if progress is not None: progress.update(p - last_progress) @@ -121,7 +114,7 @@ def denoise(image, ai_path, strength, batch_size=5, window_size=256, stride=128, output = np.clip(output, 0, 1) output = output[offset : H + offset, offset : W + offset, :] output = output * strength + input * (1 - strength) - # + if num_colors == 1: output = np.array([output[:, :, 0]]) output = np.moveaxis(output, 0, -1) diff --git a/requirements.txt b/requirements.txt index a9c5f27..8855640 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ ml_dtypes numpy<=1.24.3,>=1.22 Pillow pykrige +opencv-Python requests scikit-image == 0.21.0 scipy From 1b67343eccbde764c4b8e16c692b84baac0e59d2 Mon Sep 17 00:00:00 2001 From: David Schmelter Date: Fri, 29 Mar 2024 15:39:24 +0100 Subject: [PATCH 3/5] fix conflicting cmdline arguments --- graxpert/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graxpert/main.py b/graxpert/main.py index 113e23e..36de75b 100644 --- a/graxpert/main.py +++ b/graxpert/main.py @@ -183,7 +183,7 @@ def main(): parser = argparse.ArgumentParser(add_help=False) parser.add_argument("-cli", "--cli", required=False, action="store_true", help="Has to be added when using the command line integration of GraXpert") parser.add_argument( - "-c", + "-cmd", "--command", required=False, default="background-extraction", From 1fbb3a4824b97f003bd2325b0409d7cb12aeca62 Mon Sep 17 00:00:00 2001 From: David Schmelter Date: Fri, 29 Mar 2024 15:52:53 +0100 Subject: [PATCH 4/5] fix incorrect clipping of tiles during denoising --- graxpert/denoising.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/graxpert/denoising.py b/graxpert/denoising.py index 3b86358..76345de 100644 --- a/graxpert/denoising.py +++ b/graxpert/denoising.py @@ -56,6 +56,7 @@ def denoise(image, ai_path, strength, batch_size=5, window_size=256, stride=128, for b in range(0, ith * itw + batch_size, batch_size): input_tiles = [] + input_tile_copies = [] for t_idx in range(0, batch_size): index = b + t_idx @@ -70,6 +71,7 @@ def denoise(image, ai_path, strength, batch_size=5, window_size=256, stride=128, tile = image[x : x + window_size, y : y + window_size, :] tile = (tile - median) / mad * 0.04 + input_tile_copies.append(np.copy(tile)) tile = np.clip(tile, -1.0, 1.0) input_tiles.append(tile) @@ -78,7 +80,6 @@ def denoise(image, ai_path, strength, batch_size=5, window_size=256, stride=128, continue input_tiles = np.array(input_tiles) - input_tile_copies = np.copy(input_tiles) output_tiles = [] session_result = session.run(None, {"gen_input_image": input_tiles})[0] From 26e8247dcce0dc6a142ad2f8c7c0b8db8d79c2af Mon Sep 17 00:00:00 2001 From: David Schmelter Date: Fri, 29 Mar 2024 22:19:23 +0100 Subject: [PATCH 5/5] add ai_batch_size property to preferences, advanced settings, and cli --- graxpert/application/app.py | 36 +++++++++++++++++------------- graxpert/application/app_events.py | 1 + graxpert/cmdline_tools.py | 11 ++++++++- graxpert/denoising.py | 2 +- graxpert/main.py | 9 ++++++++ graxpert/preferences.py | 1 + graxpert/ui/right_menu.py | 8 +++++++ 7 files changed, 50 insertions(+), 18 deletions(-) diff --git a/graxpert/application/app.py b/graxpert/application/app.py index 3514757..eb8c74c 100644 --- a/graxpert/application/app.py +++ b/graxpert/application/app.py @@ -84,8 +84,12 @@ def initialize(self): eventbus.add_listener(AppEvents.BGE_AI_VERSION_CHANGED, self.on_bge_ai_version_changed) eventbus.add_listener(AppEvents.DENOISE_AI_VERSION_CHANGED, self.on_denoise_ai_version_changed) eventbus.add_listener(AppEvents.SCALING_CHANGED, self.on_scaling_changed) + eventbus.add_listener(AppEvents.AI_BATCH_SIZE_CHANGED, self.on_ai_batch_size_changed) # event handling + def on_ai_batch_size_changed(self, event): + self.prefs.ai_batch_size = event["ai_batch_size"] + def on_bge_ai_version_changed(self, event): self.prefs.bge_ai_version = event["bge_ai_version"] @@ -133,9 +137,9 @@ def on_calculate_request(self, event=None): try: self.prefs.images_linked_option = False - + img_array_to_be_processed = np.copy(self.images.get("Original").img_array) - + background = AstroImage() background.set_from_array( extract_background( @@ -166,7 +170,7 @@ def on_calculate_request(self, event=None): self.images.set("Gradient-Corrected", gradient_corrected) self.images.set("Background", background) - + self.images.stretch_all(StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option), self.prefs.saturation) eventbus.emit(AppEvents.CALCULATE_SUCCESS) @@ -312,10 +316,10 @@ def on_save_as_changed(self, event): def on_smoothing_changed(self, event): self.prefs.smoothing_option = event["smoothing_option"] - + def on_denoise_strength_changed(self, event): self.prefs.denoise_strength = event["denoise_strength"] - + def on_denoise_request(self, event): if self.images.get("Original") is None: messagebox.showerror("Error", _("Please load your picture first.")) @@ -330,12 +334,12 @@ def on_denoise_request(self, event): try: img_array_to_be_processed = np.copy(self.images.get("Original").img_array) - if (self.images.get("Gradient-Corrected") is not None): + if self.images.get("Gradient-Corrected") is not None: img_array_to_be_processed = np.copy(self.images.get("Gradient-Corrected").img_array) - + self.prefs.images_linked_option = True ai_model_path = ai_model_path_from_version(denoise_ai_models_dir, self.prefs.denoise_ai_version) - imarray = denoise(img_array_to_be_processed, ai_model_path, self.prefs.denoise_strength, progress=progress) + imarray = denoise(img_array_to_be_processed, ai_model_path, self.prefs.denoise_strength, batch_size=self.prefs.ai_batch_size, progress=progress) denoised = AstroImage() denoised.set_from_array(imarray) @@ -345,9 +349,9 @@ def on_denoise_request(self, event): denoised.update_fits_header(self.images.get("Original").fits_header, background_mean, self.prefs, self.cmd.app_state) denoised.copy_metadata(self.images.get("Original")) - + self.images.set("Denoised", denoised) - + self.images.stretch_all(StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option, self.prefs.images_linked_option), self.prefs.saturation) eventbus.emit(AppEvents.DENOISE_SUCCESS) @@ -375,13 +379,13 @@ def on_save_request(self, event): eventbus.emit(AppEvents.SAVE_BEGIN) try: - if (self.images.get("Denoised") is not None): + if self.images.get("Denoised") is not None: self.images.get("Denoised").save(dir, self.prefs.saveas_option) - elif (self.images.get("Gradient-Corrected") is not None): + elif self.images.get("Gradient-Corrected") is not None: self.images.get("Gradient-Corrected").save(dir, self.prefs.saveas_option) else: self.images.get("Original").save(dir, self.prefs.saveas_option) - + except Exception as e: logging.exception(e) eventbus.emit(AppEvents.SAVE_ERROR) @@ -425,13 +429,13 @@ def on_save_stretched_request(self, event): eventbus.emit(AppEvents.SAVE_BEGIN) try: - if (self.images.get("Denoised") is not None): + if self.images.get("Denoised") is not None: self.images.get("Denoised").save_stretched(dir, self.prefs.saveas_option, StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option)) - elif (self.images.get("Gradient-Corrected") is not None): + elif self.images.get("Gradient-Corrected") is not None: self.images.get("Gradient-Corrected").save_stretched(dir, self.prefs.saveas_option, StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option)) else: self.images.get("Original").save_stretched(dir, self.prefs.saveas_option, StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option)) - + except Exception as e: eventbus.emit(AppEvents.SAVE_ERROR) logging.exception(e) diff --git a/graxpert/application/app_events.py b/graxpert/application/app_events.py index ad5a1b1..147d585 100644 --- a/graxpert/application/app_events.py +++ b/graxpert/application/app_events.py @@ -76,3 +76,4 @@ class AppEvents(Enum): CORRECTION_TYPE_CHANGED = auto() LANGUAGE_CHANGED = auto() SCALING_CHANGED = auto() + AI_BATCH_SIZE_CHANGED = auto() diff --git a/graxpert/cmdline_tools.py b/graxpert/cmdline_tools.py index b602944..4dd9487 100644 --- a/graxpert/cmdline_tools.py +++ b/graxpert/cmdline_tools.py @@ -220,6 +220,8 @@ def execute(self): preferences.ai_version = json_prefs["ai_version"] if "denoise_strength" in json_prefs: preferences.denoise_strength = json_prefs["denoise_strength"] + if "ai_batch_size" in json_prefs: + preferences.ai_batch_size = json_prefs["ai_batch_size"] except Exception as e: logging.exception(e) @@ -233,6 +235,12 @@ def execute(self): logging.info(f"Using user-supplied denoise strength value {preferences.denoise_strength}.") else: logging.info(f"Using stored denoise strength value {preferences.denoise_strength}.") + + if self.args.ai_batch_size is not None: + preferences.ai_batch_size = self.args.ai_batch_size + logging.info(f"Using user-supplied batch size value {preferences.ai_batch_size}.") + else: + logging.info(f"Using stored batch size value {preferences.ai_batch_size}.") ai_model_path = ai_model_path_from_version(denoise_ai_models_dir, self.get_ai_version(preferences)) @@ -249,7 +257,8 @@ def execute(self): denoise( astro_Image.img_array, ai_model_path, - preferences.denoise_strength + preferences.denoise_strength, + batch_size=preferences.ai_batch_size )) processed_Astro_Image.save(self.get_save_path(), self.get_output_file_format()) diff --git a/graxpert/denoising.py b/graxpert/denoising.py index 76345de..17d9935 100644 --- a/graxpert/denoising.py +++ b/graxpert/denoising.py @@ -8,7 +8,7 @@ from graxpert.ai_model_handling import get_execution_providers_ordered -def denoise(image, ai_path, strength, batch_size=5, window_size=256, stride=128, progress=None): +def denoise(image, ai_path, strength, batch_size=3, window_size=256, stride=128, progress=None): logging.info("Starting denoising") diff --git a/graxpert/main.py b/graxpert/main.py index 36de75b..a8b903f 100644 --- a/graxpert/main.py +++ b/graxpert/main.py @@ -241,6 +241,15 @@ def main(): type=float, help='Strength of the desired denoising effect, default: "1.0"', ) + denoise_parser.add_argument( + "-batch_size", + "--ai_batch_size", + nargs="?", + required=False, + default=None, + type=int, + help='Number of image tiles which Graxpert will denoise in parallel. Be careful: increasing this value might result in out-of-memory errors. Valid Range: 1..50, default: "3"', + ) if "-h" in sys.argv or "--help" in sys.argv: if "denoising" in sys.argv: diff --git a/graxpert/preferences.py b/graxpert/preferences.py index bc25e93..dd92073 100644 --- a/graxpert/preferences.py +++ b/graxpert/preferences.py @@ -40,6 +40,7 @@ class Prefs: denoise_ai_version: AnyStr = None graxpert_version: AnyStr = graxpert_version denoise_strength: float = 0.5 + ai_batch_size: int = 3 def app_state_2_prefs(prefs: Prefs, app_state: AppState) -> Prefs: diff --git a/graxpert/ui/right_menu.py b/graxpert/ui/right_menu.py index 2b516db..4951b7b 100644 --- a/graxpert/ui/right_menu.py +++ b/graxpert/ui/right_menu.py @@ -159,6 +159,11 @@ def __init__(self, master, **kwargs): self.denoise_ai_options.insert(0, "None") self.denoise_ai_version.trace_add("write", lambda a, b, c: eventbus.emit(AppEvents.DENOISE_AI_VERSION_CHANGED, {"denoise_ai_version": self.denoise_ai_version.get()})) + # ai settings + self.ai_batch_size = tk.IntVar() + self.ai_batch_size.set(graxpert.prefs.ai_batch_size) + self.ai_batch_size.trace_add("write", lambda a, b, c: eventbus.emit(AppEvents.AI_BATCH_SIZE_CHANGED, {"ai_batch_size": self.ai_batch_size.get()})) + self.create_and_place_children() self.setup_layout() @@ -206,6 +211,9 @@ def lang_change(lang): CTkLabel(self, text=_("Denoising AI-Model"), font=self.heading_font2).grid(column=0, row=self.nrow(), pady=pady, sticky=tk.N) GraXpertOptionMenu(self, variable=self.denoise_ai_version, values=self.denoise_ai_options).grid(**self.default_grid()) + # ai settings + ValueSlider(self, variable=self.ai_batch_size, variable_name=_("AI Batch Size"), min_value=1, max_value=50, precision=0).grid(**self.default_grid()) + def setup_layout(self): self.columnconfigure(0, weight=1) self.rowconfigure(0, weight=1)