Skip to content

Commit

Permalink
add ai_batch_size property to preferences, advanced settings, and cli
Browse files Browse the repository at this point in the history
  • Loading branch information
schmelly committed Mar 29, 2024
1 parent 1fbb3a4 commit 26e8247
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 18 deletions.
36 changes: 20 additions & 16 deletions graxpert/application/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,12 @@ def initialize(self):
eventbus.add_listener(AppEvents.BGE_AI_VERSION_CHANGED, self.on_bge_ai_version_changed)
eventbus.add_listener(AppEvents.DENOISE_AI_VERSION_CHANGED, self.on_denoise_ai_version_changed)
eventbus.add_listener(AppEvents.SCALING_CHANGED, self.on_scaling_changed)
eventbus.add_listener(AppEvents.AI_BATCH_SIZE_CHANGED, self.on_ai_batch_size_changed)

# event handling
def on_ai_batch_size_changed(self, event):
self.prefs.ai_batch_size = event["ai_batch_size"]

def on_bge_ai_version_changed(self, event):
self.prefs.bge_ai_version = event["bge_ai_version"]

Expand Down Expand Up @@ -133,9 +137,9 @@ def on_calculate_request(self, event=None):

try:
self.prefs.images_linked_option = False

img_array_to_be_processed = np.copy(self.images.get("Original").img_array)

background = AstroImage()
background.set_from_array(
extract_background(
Expand Down Expand Up @@ -166,7 +170,7 @@ def on_calculate_request(self, event=None):

self.images.set("Gradient-Corrected", gradient_corrected)
self.images.set("Background", background)

self.images.stretch_all(StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option), self.prefs.saturation)

eventbus.emit(AppEvents.CALCULATE_SUCCESS)
Expand Down Expand Up @@ -312,10 +316,10 @@ def on_save_as_changed(self, event):

def on_smoothing_changed(self, event):
self.prefs.smoothing_option = event["smoothing_option"]

def on_denoise_strength_changed(self, event):
self.prefs.denoise_strength = event["denoise_strength"]

def on_denoise_request(self, event):
if self.images.get("Original") is None:
messagebox.showerror("Error", _("Please load your picture first."))
Expand All @@ -330,12 +334,12 @@ def on_denoise_request(self, event):

try:
img_array_to_be_processed = np.copy(self.images.get("Original").img_array)
if (self.images.get("Gradient-Corrected") is not None):
if self.images.get("Gradient-Corrected") is not None:
img_array_to_be_processed = np.copy(self.images.get("Gradient-Corrected").img_array)

self.prefs.images_linked_option = True
ai_model_path = ai_model_path_from_version(denoise_ai_models_dir, self.prefs.denoise_ai_version)
imarray = denoise(img_array_to_be_processed, ai_model_path, self.prefs.denoise_strength, progress=progress)
imarray = denoise(img_array_to_be_processed, ai_model_path, self.prefs.denoise_strength, batch_size=self.prefs.ai_batch_size, progress=progress)

denoised = AstroImage()
denoised.set_from_array(imarray)
Expand All @@ -345,9 +349,9 @@ def on_denoise_request(self, event):
denoised.update_fits_header(self.images.get("Original").fits_header, background_mean, self.prefs, self.cmd.app_state)

denoised.copy_metadata(self.images.get("Original"))

self.images.set("Denoised", denoised)

self.images.stretch_all(StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option, self.prefs.images_linked_option), self.prefs.saturation)

eventbus.emit(AppEvents.DENOISE_SUCCESS)
Expand Down Expand Up @@ -375,13 +379,13 @@ def on_save_request(self, event):
eventbus.emit(AppEvents.SAVE_BEGIN)

try:
if (self.images.get("Denoised") is not None):
if self.images.get("Denoised") is not None:
self.images.get("Denoised").save(dir, self.prefs.saveas_option)
elif (self.images.get("Gradient-Corrected") is not None):
elif self.images.get("Gradient-Corrected") is not None:
self.images.get("Gradient-Corrected").save(dir, self.prefs.saveas_option)
else:
self.images.get("Original").save(dir, self.prefs.saveas_option)

except Exception as e:
logging.exception(e)
eventbus.emit(AppEvents.SAVE_ERROR)
Expand Down Expand Up @@ -425,13 +429,13 @@ def on_save_stretched_request(self, event):
eventbus.emit(AppEvents.SAVE_BEGIN)

try:
if (self.images.get("Denoised") is not None):
if self.images.get("Denoised") is not None:
self.images.get("Denoised").save_stretched(dir, self.prefs.saveas_option, StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option))
elif (self.images.get("Gradient-Corrected") is not None):
elif self.images.get("Gradient-Corrected") is not None:
self.images.get("Gradient-Corrected").save_stretched(dir, self.prefs.saveas_option, StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option))
else:
self.images.get("Original").save_stretched(dir, self.prefs.saveas_option, StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option))

except Exception as e:
eventbus.emit(AppEvents.SAVE_ERROR)
logging.exception(e)
Expand Down
1 change: 1 addition & 0 deletions graxpert/application/app_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,4 @@ class AppEvents(Enum):
CORRECTION_TYPE_CHANGED = auto()
LANGUAGE_CHANGED = auto()
SCALING_CHANGED = auto()
AI_BATCH_SIZE_CHANGED = auto()
11 changes: 10 additions & 1 deletion graxpert/cmdline_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ def execute(self):
preferences.ai_version = json_prefs["ai_version"]
if "denoise_strength" in json_prefs:
preferences.denoise_strength = json_prefs["denoise_strength"]
if "ai_batch_size" in json_prefs:
preferences.ai_batch_size = json_prefs["ai_batch_size"]

except Exception as e:
logging.exception(e)
Expand All @@ -233,6 +235,12 @@ def execute(self):
logging.info(f"Using user-supplied denoise strength value {preferences.denoise_strength}.")
else:
logging.info(f"Using stored denoise strength value {preferences.denoise_strength}.")

if self.args.ai_batch_size is not None:
preferences.ai_batch_size = self.args.ai_batch_size
logging.info(f"Using user-supplied batch size value {preferences.ai_batch_size}.")
else:
logging.info(f"Using stored batch size value {preferences.ai_batch_size}.")

ai_model_path = ai_model_path_from_version(denoise_ai_models_dir, self.get_ai_version(preferences))

Expand All @@ -249,7 +257,8 @@ def execute(self):
denoise(
astro_Image.img_array,
ai_model_path,
preferences.denoise_strength
preferences.denoise_strength,
batch_size=preferences.ai_batch_size
))
processed_Astro_Image.save(self.get_save_path(), self.get_output_file_format())

Expand Down
2 changes: 1 addition & 1 deletion graxpert/denoising.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from graxpert.ai_model_handling import get_execution_providers_ordered


def denoise(image, ai_path, strength, batch_size=5, window_size=256, stride=128, progress=None):
def denoise(image, ai_path, strength, batch_size=3, window_size=256, stride=128, progress=None):

logging.info("Starting denoising")

Expand Down
9 changes: 9 additions & 0 deletions graxpert/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,15 @@ def main():
type=float,
help='Strength of the desired denoising effect, default: "1.0"',
)
denoise_parser.add_argument(
"-batch_size",
"--ai_batch_size",
nargs="?",
required=False,
default=None,
type=int,
help='Number of image tiles which Graxpert will denoise in parallel. Be careful: increasing this value might result in out-of-memory errors. Valid Range: 1..50, default: "3"',
)

if "-h" in sys.argv or "--help" in sys.argv:
if "denoising" in sys.argv:
Expand Down
1 change: 1 addition & 0 deletions graxpert/preferences.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class Prefs:
denoise_ai_version: AnyStr = None
graxpert_version: AnyStr = graxpert_version
denoise_strength: float = 0.5
ai_batch_size: int = 3


def app_state_2_prefs(prefs: Prefs, app_state: AppState) -> Prefs:
Expand Down
8 changes: 8 additions & 0 deletions graxpert/ui/right_menu.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,11 @@ def __init__(self, master, **kwargs):
self.denoise_ai_options.insert(0, "None")
self.denoise_ai_version.trace_add("write", lambda a, b, c: eventbus.emit(AppEvents.DENOISE_AI_VERSION_CHANGED, {"denoise_ai_version": self.denoise_ai_version.get()}))

# ai settings
self.ai_batch_size = tk.IntVar()
self.ai_batch_size.set(graxpert.prefs.ai_batch_size)
self.ai_batch_size.trace_add("write", lambda a, b, c: eventbus.emit(AppEvents.AI_BATCH_SIZE_CHANGED, {"ai_batch_size": self.ai_batch_size.get()}))

self.create_and_place_children()
self.setup_layout()

Expand Down Expand Up @@ -206,6 +211,9 @@ def lang_change(lang):
CTkLabel(self, text=_("Denoising AI-Model"), font=self.heading_font2).grid(column=0, row=self.nrow(), pady=pady, sticky=tk.N)
GraXpertOptionMenu(self, variable=self.denoise_ai_version, values=self.denoise_ai_options).grid(**self.default_grid())

# ai settings
ValueSlider(self, variable=self.ai_batch_size, variable_name=_("AI Batch Size"), min_value=1, max_value=50, precision=0).grid(**self.default_grid())

def setup_layout(self):
self.columnconfigure(0, weight=1)
self.rowconfigure(0, weight=1)
Expand Down

0 comments on commit 26e8247

Please sign in to comment.