Skip to content

Commit

Permalink
Original and Background image do not use Gradient-Corrected image as …
Browse files Browse the repository at this point in the history
…reference anymore after denoising
  • Loading branch information
Steffenhir committed Apr 1, 2024
1 parent 1bb527f commit 5b84ff3
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 85 deletions.
43 changes: 31 additions & 12 deletions graxpert/AstroImageRepository.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from graxpert.astroimage import AstroImage
from graxpert.stretch import StretchParameters, stretch_all
from graxpert.stretch import StretchParameters, stretch_all, calculate_mtf_stretch_parameters_for_image
from typing import Dict

class AstroImageRepository:
Expand All @@ -12,21 +12,40 @@ def get(self, type:str):
return self.images[type]

def stretch_all(self, stretch_params:StretchParameters, saturation:float):

if self.get("Original") is None:
return

all_image_arrays = []
all_mtf_stretch_params = []

for key, value in self.images.items():
if (value is not None):
all_image_arrays.append(value.img_array)

if self.get("Gradient-Corrected") is not None and self.get("Denoised") is not None:
stretches = stretch_all(all_image_arrays, stretch_params, reference_img_array=self.get("Gradient-Corrected").img_array)
else:
stretches = stretch_all(all_image_arrays, stretch_params)
all_image_arrays.append(self.get("Original").img_array)
all_mtf_stretch_params.append(calculate_mtf_stretch_parameters_for_image(stretch_params, self.get("Original").img_array))

if self.get("Gradient-Corrected") is not None and self.get("Background") is not None:
all_image_arrays.append(self.get("Gradient-Corrected").img_array)
all_mtf_stretch_params.append(calculate_mtf_stretch_parameters_for_image(stretch_params, self.get("Gradient-Corrected").img_array))

all_image_arrays.append(self.get("Background").img_array)
all_mtf_stretch_params.append(all_mtf_stretch_params[0])


if self.get("Denoised") is not None and self.get("Gradient-Corrected") is None:
all_image_arrays.append(self.get("Denoised").img_array)
all_mtf_stretch_params.append(all_mtf_stretch_params[0])

elif self.get("Denoised") is not None and self.get("Gradient-Corrected") is not None:
all_image_arrays.append(self.get("Denoised").img_array)
all_mtf_stretch_params.append(all_mtf_stretch_params[1])


stretches = stretch_all(all_image_arrays, all_mtf_stretch_params)


i = 0
for key, value in self.images.items():
if (value is not None):
value.update_display_from_array(stretches[i], saturation)
for key, image in self.images.items():
if image is not None:
image.update_display_from_array(stretches[i], saturation)
i = i + 1

def crop_all(self, start_x:float, end_x:float, start_y:float, end_y:float):
Expand Down
5 changes: 4 additions & 1 deletion graxpert/astroimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ def update_display_from_array(self, img_display, saturation):
return

def stretch(self, stretch_params: StretchParameters):
return np.clip(stretch(self.img_array, stretch_params), 0.0, 1.0)
if stretch_params.do_stretch:
return np.clip(stretch(self.img_array, stretch_params), 0.0, 1.0)
else:
return np.clip(self.img_array, 0.0, 1.0)

def crop(self, startx, endx, starty, endy):
self.img_array = self.img_array[starty:endy, startx:endx, :]
Expand Down
128 changes: 56 additions & 72 deletions graxpert/stretch.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,96 +49,31 @@ def __init__(self, stretch_option: str, channels_linked: bool = False, images_li
elif stretch_option == "30% Bg, 2 sigma":
self.bg = 0.3
self.sigma = 2.0


def stretch_channel(shm_name, c, stretch_params, mtf_stretch_params, shape, dtype, logging_queue, logging_configurer):

logging_configurer(logging_queue)
logging.info("stretch.stretch_channel started")

existing_shm = shared_memory.SharedMemory(name=shm_name)
channels = np.ndarray(shape, dtype, buffer=existing_shm.buf) #[:,:,channel_idx]
channel = channels[:,:,c]

try:
if not mtf_stretch_params:
mtf_stretch_params = calculate_mtf_stretch_parameters(stretch_params, channel)

channel[channel <= mtf_stretch_params.shadow_clipping] = 0.0
channel[channel >= mtf_stretch_params.highlight_clipping] = 1.0

indx_inside = np.logical_and(channel > mtf_stretch_params.shadow_clipping, channel < mtf_stretch_params.highlight_clipping)

channel[indx_inside] = (channel[indx_inside]-mtf_stretch_params.shadow_clipping)/(mtf_stretch_params.highlight_clipping - mtf_stretch_params.shadow_clipping)

channel = MTF(channel, mtf_stretch_params.midtone)

except:
logging.exception("An error occured while stretching a color channel")
finally:
existing_shm.close()

logging.info("stretch.stretch_channel finished")

def calculate_mtf_stretch_parameters(stretch_params, channel):
channel = channel.flatten()

indx_clip = np.logical_and(channel < 1.0, channel > 0.0)
median = np.median(channel[indx_clip])
mad = np.median(np.abs(channel[indx_clip]-median))

shadow_clipping = np.clip(median - stretch_params.sigma*mad, 0, 1.0)
highlight_clipping = 1.0
midtone = MTF((median-shadow_clipping)/(highlight_clipping - shadow_clipping), stretch_params.bg)

return MTFStretchParameters(midtone, shadow_clipping)


def stretch(data, stretch_params: StretchParameters):
return stretch_all([data], stretch_params)[0]
mtf_stretch_param = calculate_mtf_stretch_parameters_for_image(stretch_params, data)
return stretch_all([data], [mtf_stretch_param])[0]

def stretch_all(datas, stretch_params: StretchParameters, reference_img_array=None):

if not stretch_params.do_stretch:
datas = [data.clip(min=0, max=1) for data in datas]
return datas

if reference_img_array is None:
reference_img_array = datas[0]

def stretch_all(datas, mtf_stretch_params: list[MTFStretchParameters]):

futures = []
shms = []
copies = []
result = []
logging_queue = get_logging_queue()

common_mtf_stretch_params_per_channel = []
if stretch_params.images_linked:
if stretch_params.channels_linked:
mtf_stretch_params_for_all_channel = calculate_mtf_stretch_parameters(stretch_params, reference_img_array)
common_mtf_stretch_params_per_channel = [mtf_stretch_params_for_all_channel] * reference_img_array.shape[-1]
else:
for c in range(datas[0].shape[-1]):
common_mtf_stretch_params_per_channel.append(calculate_mtf_stretch_parameters(stretch_params, reference_img_array[:,:,c]))

logging_queue = get_logging_queue()

for data in datas:
for data, mtf_stretch_param in zip(datas, mtf_stretch_params):
shm = shared_memory.SharedMemory(create=True, size=data.nbytes)
copy = np.ndarray(data.shape, dtype=data.dtype, buffer=shm.buf)
np.copyto(copy, data)
shms.append(shm)
copies.append(copy)

mtf_stretch_params = [None] * data.shape[-1]

if stretch_params.images_linked:
mtf_stretch_params = common_mtf_stretch_params_per_channel
elif stretch_params.channels_linked:
mtf_stretch_params = calculate_mtf_stretch_parameters(stretch_params, copy)
mtf_stretch_params = [mtf_stretch_params] * data.shape[-1]

for c in range(copy.shape[-1]):
futures.insert(c, executor.submit(stretch_channel, shm.name, c, stretch_params, mtf_stretch_params[c], copy.shape, copy.dtype, logging_queue, worker_configurer))
futures.insert(c, executor.submit(stretch_channel, shm.name, c, mtf_stretch_param[c], copy.shape, copy.dtype, logging_queue, worker_configurer))
wait(futures)

for copy in copies:
Expand All @@ -150,6 +85,55 @@ def stretch_all(datas, stretch_params: StretchParameters, reference_img_array=No
shm.unlink()

return result


def calculate_mtf_stretch_parameters_for_image(stretch_params, image):
if stretch_params.channels_linked:
mtf_stretch_param = calculate_mtf_stretch_parameters_for_channel(stretch_params, image)
return [mtf_stretch_param] * image.shape[-1]

else:
return [calculate_mtf_stretch_parameters_for_channel(stretch_params, image[:,:,i]) for i in range(image.shape[-1])]

def calculate_mtf_stretch_parameters_for_channel(stretch_params, channel):
channel = channel.flatten()[::4]

indx_clip = np.logical_and(channel < 1.0, channel > 0.0)
median = np.median(channel[indx_clip])
mad = np.median(np.abs(channel[indx_clip]-median))

shadow_clipping = np.clip(median - stretch_params.sigma*mad, 0, 1.0)
highlight_clipping = 1.0
midtone = MTF((median-shadow_clipping)/(highlight_clipping - shadow_clipping), stretch_params.bg)

return MTFStretchParameters(midtone, shadow_clipping)


def stretch_channel(shm_name, c, mtf_stretch_params, shape, dtype, logging_queue, logging_configurer):

logging_configurer(logging_queue)
logging.info("stretch.stretch_channel started")

existing_shm = shared_memory.SharedMemory(name=shm_name)
channels = np.ndarray(shape, dtype, buffer=existing_shm.buf) #[:,:,channel_idx]
channel = channels[:,:,c]

try:
channel[channel <= mtf_stretch_params.shadow_clipping] = 0.0
channel[channel >= mtf_stretch_params.highlight_clipping] = 1.0

indx_inside = np.logical_and(channel > mtf_stretch_params.shadow_clipping, channel < mtf_stretch_params.highlight_clipping)

channel[indx_inside] = (channel[indx_inside]-mtf_stretch_params.shadow_clipping)/(mtf_stretch_params.highlight_clipping - mtf_stretch_params.shadow_clipping)

channel = MTF(channel, mtf_stretch_params.midtone)

except:
logging.exception("An error occured while stretching a color channel")
finally:
existing_shm.close()

logging.info("stretch.stretch_channel finished")


def MTF(data, midtone):
Expand Down

0 comments on commit 5b84ff3

Please sign in to comment.