Skip to content

Commit

Permalink
Management of all images is now done in AstroImageRepository class
Browse files Browse the repository at this point in the history
  • Loading branch information
Steffenhir committed Mar 23, 2024
1 parent 55a49f1 commit a504e36
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 95 deletions.
54 changes: 54 additions & 0 deletions graxpert/AstroImageRepository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from graxpert.astroimage import AstroImage
from graxpert.stretch import StretchParameters, stretch_all
from typing import Dict

class AstroImageRepository:
images: Dict = {"Original": None, "Gradient-Corrected": None, "Background": None, "Denoised": None}

def set(self, type:str, image:AstroImage):
self.images[type] = image

def get(self, type:str):
return self.images[type]

def stretch_all(self, stretch_params:StretchParameters, saturation:float):
all_image_arrays = []

for key, value in self.images.items():
if (value is not None):
all_image_arrays.append(value.img_array)


stretches = stretch_all(all_image_arrays, stretch_params)

i = 0
for key, value in self.images.items():
if (value is not None):
value.update_display_from_array(stretches[i], saturation)
i = i + 1

def crop_all(self, start_x:float, end_x:float, start_y:float, end_y:float):
for key, astroimg in self.images.items():
if astroimg is not None:
astroimg.crop(start_x, end_x, start_y, end_y)

def update_saturation(self, saturation):
for key, value in self.images.items():
if (value is not None):
value.update_saturation(saturation)

def reset(self):
for key, value in self.images.items():
self.images[key] = None

def display_options(self):
display_options = []

for key, value in self.images.items():
if (self.images[key] is not None):
display_options.append(key)

return display_options



114 changes: 50 additions & 64 deletions graxpert/application/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from graxpert.application.app_events import AppEvents
from graxpert.application.eventbus import eventbus
from graxpert.astroimage import AstroImage
from graxpert.AstroImageRepository import AstroImageRepository
from graxpert.background_extraction import extract_background
from graxpert.commands import INIT_HANDLER, RESET_POINTS_HANDLER, RM_POINT_HANDLER, SEL_POINTS_HANDLER, Command
from graxpert.denoising import denoise
Expand All @@ -34,7 +35,7 @@ def initialize(self):
self.filename = ""
self.data_type = ""

self.images = {"Original": None, "Background": None, "Processed": None}
self.images = AstroImageRepository()
self.display_type = "Original"

self.mat_affine = np.eye(3)
Expand Down Expand Up @@ -98,7 +99,7 @@ def on_bg_tol_changed(self, event):
self.prefs.bg_tol_option = event["bg_tol_option"]

def on_calculate_request(self, event=None):
if self.images["Original"] is None:
if self.images.get("Original") is None:
messagebox.showerror("Error", _("Please load your picture first."))
return

Expand All @@ -125,7 +126,7 @@ def on_calculate_request(self, event=None):

progress = DynamicProgressThread(callback=lambda p: eventbus.emit(AppEvents.CALCULATE_PROGRESS, {"progress": p}))

imarray = np.copy(self.images["Original"].img_array)
imarray = np.copy(self.images.get("Original").img_array)

downscale_factor = 1

Expand All @@ -134,8 +135,8 @@ def on_calculate_request(self, event=None):

try:
self.prefs.images_linked_option = False
self.images["Background"] = AstroImage()
self.images["Background"].set_from_array(
background = AstroImage()
background.set_from_array(
extract_background(
imarray,
np.array(background_points),
Expand All @@ -151,25 +152,24 @@ def on_calculate_request(self, event=None):
)
)

self.images["Processed"] = AstroImage()
self.images["Processed"].set_from_array(imarray)
gradient_corrected = AstroImage()
gradient_corrected.set_from_array(imarray)

# Update fits header and metadata
background_mean = np.mean(self.images["Background"].img_array)
self.images["Processed"].update_fits_header(self.images["Original"].fits_header, background_mean, self.prefs, self.cmd.app_state)
self.images["Background"].update_fits_header(self.images["Original"].fits_header, background_mean, self.prefs, self.cmd.app_state)
background_mean = np.mean(background.img_array)
gradient_corrected.update_fits_header(self.images.get("Original").fits_header, background_mean, self.prefs, self.cmd.app_state)
gradient_corrected.update_fits_header(self.images.get("Original").fits_header, background_mean, self.prefs, self.cmd.app_state)

self.images["Processed"].copy_metadata(self.images["Original"])
self.images["Background"].copy_metadata(self.images["Original"])
gradient_corrected.copy_metadata(self.images.get("Original"))
background.copy_metadata(self.images.get("Original"))

all_images = [self.images["Original"].img_array, self.images["Processed"].img_array, self.images["Background"].img_array]
stretches = stretch_all(all_images, StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option))
self.images["Original"].update_display_from_array(stretches[0], self.prefs.saturation)
self.images["Processed"].update_display_from_array(stretches[1], self.prefs.saturation)
self.images["Background"].update_display_from_array(stretches[2], self.prefs.saturation)
self.images.set("Gradient-Corrected", gradient_corrected)
self.images.set("Background", background)

self.images.stretch_all(StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option), self.prefs.saturation)

eventbus.emit(AppEvents.CALCULATE_SUCCESS)
eventbus.emit(AppEvents.UPDATE_DISPLAY_TYPE_REEQUEST, {"display_type": "Processed"})
eventbus.emit(AppEvents.UPDATE_DISPLAY_TYPE_REEQUEST, {"display_type": "Gradient-Corrected"})

except Exception as e:
logging.exception(e)
Expand All @@ -184,23 +184,21 @@ def on_change_saturation_request(self, event):

eventbus.emit(AppEvents.CHANGE_SATURATION_BEGIN)

for img in self.images.values():
if img is not None:
img.update_saturation(self.prefs.saturation)
self.images.update_saturation(self.prefs.saturation)

eventbus.emit(AppEvents.CHANGE_SATURATION_END)

def on_correction_type_changed(self, event):
self.prefs.corr_type = event["corr_type"]

def on_create_grid_request(self, event=None):
if self.images["Original"] is None:
if self.images.get("Original") is None:
messagebox.showerror("Error", _("Please load your picture first."))
return

eventbus.emit(AppEvents.CREATE_GRID_BEGIN)

self.cmd = Command(SEL_POINTS_HANDLER, self.cmd, data=self.images["Original"].img_array, num_pts=self.prefs.bg_pts_option, tol=self.prefs.bg_tol_option, sample_size=self.prefs.sample_size)
self.cmd = Command(SEL_POINTS_HANDLER, self.cmd, data=self.images.get("Original").img_array, num_pts=self.prefs.bg_pts_option, tol=self.prefs.bg_tol_option, sample_size=self.prefs.sample_size)
self.cmd.execute()

eventbus.emit(AppEvents.CREATE_GRID_END)
Expand Down Expand Up @@ -243,23 +241,22 @@ def on_load_image(self, event):
self.filename = os.path.splitext(os.path.basename(filename))[0]

self.data_type = os.path.splitext(filename)[1]
self.images["Original"] = image
self.images["Processed"] = None
self.images["Background"] = None
self.images.reset()
self.images.set("Original", image)
self.prefs.working_dir = os.path.dirname(filename)

os.chdir(os.path.dirname(filename))

width = self.images["Original"].img_display.width
height = self.images["Original"].img_display.height
width = self.images.get("Original").img_display.width
height = self.images.get("Original").img_display.height

if self.prefs.width != width or self.prefs.height != height:
self.reset_backgroundpts()

self.prefs.width = width
self.prefs.height = height

tmp_state = fitsheader_2_app_state(self, self.cmd.app_state, self.images["Original"].fits_header)
tmp_state = fitsheader_2_app_state(self, self.cmd.app_state, self.images.get("Original").fits_header)
self.cmd: Command = Command(INIT_HANDLER, background_points=tmp_state.background_points)
self.cmd.execute()

Expand Down Expand Up @@ -319,7 +316,7 @@ def on_denoise_strength_changed(self, event):
self.prefs.denoise_strength = event["denoise_strength"]

def on_denoise_request(self, event):
if self.images["Original"] is None:
if self.images.get("Original") is None:
messagebox.showerror("Error", _("Please load your picture first."))
return

Expand All @@ -333,24 +330,23 @@ def on_denoise_request(self, event):
try:
self.prefs.images_linked_option = True
ai_model_path = ai_model_path_from_version(denoise_ai_models_dir, self.prefs.denoise_ai_version)
imarray = denoise(self.images["Original"].img_array, ai_model_path, self.prefs.denoise_strength, progress=progress)
imarray = denoise(self.images.get("Original").img_array, ai_model_path, self.prefs.denoise_strength, progress=progress)

self.images["Processed"] = AstroImage()
self.images["Processed"].set_from_array(imarray)
denoised = AstroImage()
denoised.set_from_array(imarray)

# Update fits header and metadata
background_mean = np.mean(self.images["Original"].img_array)
self.images["Processed"].update_fits_header(self.images["Original"].fits_header, background_mean, self.prefs, self.cmd.app_state)
background_mean = np.mean(self.images.get("Original").img_array)
denoised.update_fits_header(self.images.get("Original").fits_header, background_mean, self.prefs, self.cmd.app_state)

self.images["Processed"].copy_metadata(self.images["Original"])

all_images = [self.images["Original"].img_array, self.images["Processed"].img_array]
stretches = stretch_all(all_images, StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option, self.prefs.images_linked_option))
self.images["Original"].update_display_from_array(stretches[0], self.prefs.saturation)
self.images["Processed"].update_display_from_array(stretches[1], self.prefs.saturation)
denoised.copy_metadata(self.images.get("Original"))

self.images.set("Denoised", denoised)

self.images.stretch_all(StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option, self.prefs.images_linked_option), self.prefs.saturation)

eventbus.emit(AppEvents.DENOISE_SUCCESS)
eventbus.emit(AppEvents.UPDATE_DISPLAY_TYPE_REEQUEST, {"display_type": "Processed"})
eventbus.emit(AppEvents.UPDATE_DISPLAY_TYPE_REEQUEST, {"display_type": "Denoised"})

except Exception as e:
logging.exception(e)
Expand All @@ -374,7 +370,7 @@ def on_save_request(self, event):
eventbus.emit(AppEvents.SAVE_BEGIN)

try:
self.images["Processed"].save(dir, self.prefs.saveas_option)
self.images.get("Gradient-Corrected").save(dir, self.prefs.saveas_option)
except Exception as e:
logging.exception(e)
eventbus.emit(AppEvents.SAVE_ERROR)
Expand All @@ -396,7 +392,7 @@ def on_save_background_request(self, event):
eventbus.emit(AppEvents.SAVE_BEGIN)

try:
self.images["Background"].save(dir, self.prefs.saveas_option)
self.images.get("Background").save(dir, self.prefs.saveas_option)
except Exception as e:
logging.exception(e)
eventbus.emit(AppEvents.SAVE_ERROR)
Expand All @@ -418,10 +414,10 @@ def on_save_stretched_request(self, event):
eventbus.emit(AppEvents.SAVE_BEGIN)

try:
if self.images["Processed"] is None:
self.images["Original"].save_stretched(dir, self.prefs.saveas_option, StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option))
if self.images.get("Gradient-Corrected") is None:
self.images.get("Original").save_stretched(dir, self.prefs.saveas_option, StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option))
else:
self.images["Processed"].save_stretched(dir, self.prefs.saveas_option, StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option))
self.images.get("Gradient-Corrected").save_stretched(dir, self.prefs.saveas_option, StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option))
except Exception as e:
eventbus.emit(AppEvents.SAVE_ERROR)
logging.exception(e)
Expand All @@ -448,17 +444,7 @@ def do_stretch(self):
eventbus.emit(AppEvents.STRETCH_IMAGE_BEGIN)

try:
all_images = []
all_image_arrays = []
stretches = []
for img in self.images.values():
if img is not None:
all_images.append(img)
all_image_arrays.append(img.img_array)
if len(all_images) > 0:
stretches = stretch_all(all_image_arrays, StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option, self.prefs.images_linked_option))
for idx, img in enumerate(all_images):
all_images[idx].update_display_from_array(stretches[idx], self.prefs.saturation)
self.images.stretch_all(StretchParameters(self.prefs.stretch_option, self.prefs.channels_linked_option, self.prefs.images_linked_option), self.prefs.saturation)
except Exception as e:
eventbus.emit(AppEvents.STRETCH_IMAGE_ERROR)
logging.exception(e)
Expand Down Expand Up @@ -522,29 +508,29 @@ def to_canvas_point(self, x, y):
return np.dot(self.mat_affine, (x, y, 1.0))

def to_image_point(self, x, y):
if self.images[self.display_type] is None:
if self.images.get(self.display_type) is None:
return []

mat_inv = np.linalg.inv(self.mat_affine)
image_point = np.dot(mat_inv, (x, y, 1.0))

width = self.images[self.display_type].width
height = self.images[self.display_type].height
width = self.images.get(self.display_type).width
height = self.images.get(self.display_type).height

if image_point[0] < 0 or image_point[1] < 0 or image_point[0] > width or image_point[1] > height:
return []

return image_point

def to_image_point_pinned(self, x, y):
if self.images[self.display_type] is None:
if self.images.get(self.display_type) is None:
return []

mat_inv = np.linalg.inv(self.mat_affine)
image_point = np.dot(mat_inv, (x, y, 1.0))

width = self.images[self.display_type].width
height = self.images[self.display_type].height
width = self.images.get(self.display_type).width
height = self.images.get(self.display_type).height

if image_point[0] < 0:
image_point[0] = 0
Expand Down
Loading

0 comments on commit a504e36

Please sign in to comment.