diff --git a/graxpert/ai_model_handling.py b/graxpert/ai_model_handling.py index fd515e8..998d578 100644 --- a/graxpert/ai_model_handling.py +++ b/graxpert/ai_model_handling.py @@ -28,7 +28,7 @@ try: os.rename(ai_models_dir, bge_ai_models_dir) except Exception as e: - logging.error(f"Renaming {ai_models_dir} to {bge_ai_models_dir} failed. {bge_ai_models_dir} will be newly created. Consider deleting obsolete {ai_models_dir}.") + logging.error(f"Renaming {ai_models_dir} to {bge_ai_models_dir} failed. {bge_ai_models_dir} will be newly created. Consider deleting obsolete {ai_models_dir} manually.") os.makedirs(bge_ai_models_dir, exist_ok=True) @@ -64,7 +64,11 @@ def list_remote_versions(bucket_name): def list_local_versions(ai_models_dir): try: - model_dirs = [{"path": os.path.join(ai_models_dir, f), "version": f} for f in os.listdir(ai_models_dir) if re.search(r"\d\.\d\.\d", f)] # match semantic version + model_dirs = [ + {"path": os.path.join(ai_models_dir, f), "version": f} + for f in os.listdir(ai_models_dir) + if re.search(r"\d\.\d\.\d", f) and len(os.listdir(os.path.join(ai_models_dir, f))) > 0 # match semantic version + ] return model_dirs except Exception as e: logging.exception(e) @@ -122,11 +126,11 @@ def cleanup_orphaned_local_versions(orphaned_local_versions): logging.exception(e) -def download_version(ai_models_dir, bucket_name, remote_version, progress=None): +def download_version(ai_models_dir, bucket_name, target_version, progress=None): try: remote_versions = list_remote_versions(bucket_name) for r in remote_versions: - if remote_version == r["version"]: + if target_version == r["version"]: remote_version = r break @@ -144,11 +148,11 @@ def download_version(ai_models_dir, bucket_name, remote_version, progress=None): with zipfile.ZipFile(ai_model_zip, "r") as zip_ref: zip_ref.extractall(ai_model_dir) - + if not os.path.isfile(ai_model_file): raise ValueError(f"Could not find ai 'model.onnx' file after extracting {ai_model_zip}") os.remove(ai_model_zip) - + except Exception as e: # try to delete (rollback) ai_model_dir in case of errors logging.exception(e) @@ -163,7 +167,6 @@ def validate_local_version(ai_models_dir, local_version): def get_execution_providers_ordered(): - supported_providers = ["DmlExecutionProvider", "CoreMLExecutionProvider", "CUDAExecutionProvider", - "CPUExecutionProvider"] + supported_providers = ["DmlExecutionProvider", "CoreMLExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"] return [provider for provider in supported_providers if provider in ort.get_available_providers()] diff --git a/graxpert/denoising.py b/graxpert/denoising.py index d8484d6..abd3390 100644 --- a/graxpert/denoising.py +++ b/graxpert/denoising.py @@ -6,20 +6,31 @@ import onnxruntime as ort from graxpert.ai_model_handling import get_execution_providers_ordered -from graxpert.application.eventbus import eventbus from graxpert.application.app_events import AppEvents +from graxpert.application.eventbus import eventbus from graxpert.ui.ui_events import UiEvents -def denoise(image, ai_path, strength, batch_size=3, window_size=256, stride=128, progress=None): + +def denoise(image, ai_path, strength, batch_size=4, window_size=256, stride=128, progress=None): logging.info("Starting denoising") + if batch_size < 1: + logging.info(f"mapping batch_size of {batch_size} to 1") + batch_size = 1 + elif batch_size > 32: + logging.info(f"mapping batch_size of {batch_size} to 32") + batch_size = 32 + elif not (batch_size & (batch_size - 1) == 0): # check if batch_size is power of two + logging.info(f"mapping batch_size of {batch_size} to {2 ** (batch_size).bit_length() // 2}") + batch_size = 2 ** (batch_size).bit_length() // 2 # map batch_size to power of two + input = copy.deepcopy(image) global cached_denoised_image if cached_denoised_image is not None: return blend_images(input, cached_denoised_image, strength) - + num_colors = image.shape[-1] if num_colors == 1: image = np.array([image[:, :, 0], image[:, :, 0], image[:, :, 0]]) @@ -51,9 +62,7 @@ def denoise(image, ai_path, strength, batch_size=3, window_size=256, stride=128, output = copy.deepcopy(image) providers = get_execution_providers_ordered() - ort_options = ort.SessionOptions() - ort_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL - session = ort.InferenceSession(ai_path, providers=providers, sess_options=ort_options) + session = ort.InferenceSession(ai_path, providers=providers) logging.info(f"Available inference providers : {providers}") logging.info(f"Used inference providers : {session.get_providers()}") @@ -84,7 +93,7 @@ def denoise(image, ai_path, strength, batch_size=3, window_size=256, stride=128, if not input_tiles: continue - + input_tiles = np.array(input_tiles) output_tiles = [] @@ -131,15 +140,18 @@ def denoise(image, ai_path, strength, batch_size=3, window_size=256, stride=128, return output + def blend_images(original_image, denoised_image, strength): - blend = denoised_image * strength + original_image * (1-strength) + blend = denoised_image * strength + original_image * (1 - strength) return np.clip(blend, 0, 1) + def reset_cached_denoised_image(event): global cached_denoised_image cached_denoised_image = None + cached_denoised_image = None eventbus.add_listener(AppEvents.LOAD_IMAGE_REQUEST, reset_cached_denoised_image) eventbus.add_listener(AppEvents.CALCULATE_REQUEST, reset_cached_denoised_image) -eventbus.add_listener(UiEvents.APPLY_CROP_REQUEST, reset_cached_denoised_image) \ No newline at end of file +eventbus.add_listener(UiEvents.APPLY_CROP_REQUEST, reset_cached_denoised_image) diff --git a/graxpert/main.py b/graxpert/main.py index a8b903f..5c22b34 100644 --- a/graxpert/main.py +++ b/graxpert/main.py @@ -49,6 +49,9 @@ def collect_available_versions(ai_models_dir, bucket_name): def bge_version_type(arg_value, pat=re.compile(r"^\d+\.\d+\.\d+$")): return version_type(bge_ai_models_dir, bge_bucket_name, arg_value, pat=re.compile(r"^\d+\.\d+\.\d+$")) +def denoise_version_type(arg_value, pat=re.compile(r"^\d+\.\d+\.\d+$")): + return version_type(denoise_ai_models_dir, denoise_bucket_name, arg_value, pat=re.compile(r"^\d+\.\d+\.\d+$")) + def version_type(ai_models_dir, bucket_name, arg_value, pat=re.compile(r"^\d+\.\d+\.\d+$")): @@ -227,7 +230,7 @@ def main(): nargs="?", required=False, default=None, - type=bge_version_type, + type=denoise_version_type, help='Version of the Denoising AI model, default: "latest"; available locally: [{}], available remotely: [{}]'.format( ", ".join(available_denoise_versions[0]), ", ".join(available_denoise_versions[1]) ), @@ -239,7 +242,7 @@ def main(): required=False, default=None, type=float, - help='Strength of the desired denoising effect, default: "1.0"', + help='Strength of the desired denoising effect, default: "0.5"', ) denoise_parser.add_argument( "-batch_size", @@ -248,7 +251,7 @@ def main(): required=False, default=None, type=int, - help='Number of image tiles which Graxpert will denoise in parallel. Be careful: increasing this value might result in out-of-memory errors. Valid Range: 1..50, default: "3"', + help='Number of image tiles which Graxpert will denoise in parallel. Be careful: increasing this value might result in out-of-memory errors. Valid Range: 1..32, default: "4"', ) if "-h" in sys.argv or "--help" in sys.argv: diff --git a/graxpert/preferences.py b/graxpert/preferences.py index dd92073..149cfd4 100644 --- a/graxpert/preferences.py +++ b/graxpert/preferences.py @@ -40,7 +40,7 @@ class Prefs: denoise_ai_version: AnyStr = None graxpert_version: AnyStr = graxpert_version denoise_strength: float = 0.5 - ai_batch_size: int = 3 + ai_batch_size: int = 4 def app_state_2_prefs(prefs: Prefs, app_state: AppState) -> Prefs: diff --git a/graxpert/ui/right_menu.py b/graxpert/ui/right_menu.py index 33a43d7..67c8e19 100644 --- a/graxpert/ui/right_menu.py +++ b/graxpert/ui/right_menu.py @@ -187,6 +187,7 @@ def __init__(self, master, **kwargs): self.denoise_ai_version.trace_add("write", lambda a, b, c: eventbus.emit(AppEvents.DENOISE_AI_VERSION_CHANGED, {"denoise_ai_version": self.denoise_ai_version.get()})) # ai settings + self.ai_batch_size_options = ["1","2","4","8","16","32"] self.ai_batch_size = tk.IntVar() self.ai_batch_size.set(graxpert.prefs.ai_batch_size) self.ai_batch_size.trace_add("write", lambda a, b, c: eventbus.emit(AppEvents.AI_BATCH_SIZE_CHANGED, {"ai_batch_size": self.ai_batch_size.get()})) @@ -239,7 +240,8 @@ def lang_change(lang): GraXpertOptionMenu(self, variable=self.denoise_ai_version, values=self.denoise_ai_options).grid(**self.default_grid()) # ai settings - ValueSlider(self, variable=self.ai_batch_size, variable_name=_("AI Batch Size"), min_value=1, max_value=50, precision=0).grid(**self.default_grid()) + CTkLabel(self, text=_("AI inference batch size"), font=self.heading_font2).grid(column=0, row=self.nrow(), pady=pady, sticky=tk.N) + GraXpertOptionMenu(self, variable=self.ai_batch_size, values=self.ai_batch_size_options).grid(**self.default_grid()) def setup_layout(self): self.columnconfigure(0, weight=1)