From 0c5944303d217bbcaae7b623f4932d19c5a1b0bf Mon Sep 17 00:00:00 2001 From: Steffen Hirtle Date: Sun, 21 Jan 2024 14:05:26 +0100 Subject: [PATCH] Images can now be stretched with linked channels --- graxpert/stretch.py | 56 ++++++++++++++++++++++++++++++--------------- 1 file changed, 38 insertions(+), 18 deletions(-) diff --git a/graxpert/stretch.py b/graxpert/stretch.py index 77aac71..cb91cb8 100644 --- a/graxpert/stretch.py +++ b/graxpert/stretch.py @@ -10,6 +10,13 @@ from graxpert.mp_logging import get_logging_queue, worker_configurer from graxpert.parallel_processing import executor +from dataclasses import dataclass + +@dataclass +class MTFStretchParameters: + midtone: float + shadow_clipping: float + highlight_clipping: float = 1.0 class StretchParameters: stretch_option: str @@ -19,8 +26,10 @@ class StretchParameters: channels_linked: bool = False images_linked: bool = False - def __init__(self, stretch_option: str): + def __init__(self, stretch_option: str, channels_linked: bool = False, images_linked: bool = False): self.stretch_option = stretch_option + self.channels_linked = channels_linked + self. images_linked = images_linked if stretch_option == "No Stretch": self.do_stretch = False @@ -42,7 +51,7 @@ def __init__(self, stretch_option: str): self.sigma = 2.0 -def stretch_channel(shm_name, c, bg, sigma, shape, dtype, logging_queue, logging_configurer): +def stretch_channel(shm_name, c, stretch_params, mtf_stretch_params, shape, dtype, logging_queue, logging_configurer): logging_configurer(logging_queue) logging.info("stretch.stretch_channel started") @@ -52,23 +61,17 @@ def stretch_channel(shm_name, c, bg, sigma, shape, dtype, logging_queue, logging channel = channels[:,:,c] try: - indx_clip = np.logical_and(channel < 1.0, channel > 0.0) - median = np.median(channel[indx_clip]) - mad = np.median(np.abs(channel[indx_clip]-median)) - - shadow_clipping = np.clip(median - sigma*mad, 0, 1.0) - highlight_clipping = 1.0 - - midtone = MTF((median-shadow_clipping)/(highlight_clipping - shadow_clipping), bg) + if not mtf_stretch_params: + mtf_stretch_params = calculate_mtf_stretch_parameters(stretch_params, channel) - channel[channel <= shadow_clipping] = 0.0 - channel[channel >= highlight_clipping] = 1.0 + channel[channel <= mtf_stretch_params.shadow_clipping] = 0.0 + channel[channel >= mtf_stretch_params.highlight_clipping] = 1.0 - indx_inside = np.logical_and(channel > shadow_clipping, channel < highlight_clipping) + indx_inside = np.logical_and(channel > mtf_stretch_params.shadow_clipping, channel < mtf_stretch_params.highlight_clipping) - channel[indx_inside] = (channel[indx_inside]-shadow_clipping)/(highlight_clipping - shadow_clipping) + channel[indx_inside] = (channel[indx_inside]-mtf_stretch_params.shadow_clipping)/(mtf_stretch_params.highlight_clipping - mtf_stretch_params.shadow_clipping) - channel = MTF(channel, midtone) + channel = MTF(channel, mtf_stretch_params.midtone) except: logging.exception("An error occured while stretching a color channel") @@ -76,6 +79,20 @@ def stretch_channel(shm_name, c, bg, sigma, shape, dtype, logging_queue, logging existing_shm.close() logging.info("stretch.stretch_channel finished") + +def calculate_mtf_stretch_parameters(stretch_params, channel): + channel = channel.flatten() + + indx_clip = np.logical_and(channel < 1.0, channel > 0.0) + median = np.median(channel[indx_clip]) + mad = np.median(np.abs(channel[indx_clip]-median)) + + shadow_clipping = np.clip(median - stretch_params.sigma*mad, 0, 1.0) + highlight_clipping = 1.0 + midtone = MTF((median-shadow_clipping)/(highlight_clipping - shadow_clipping), stretch_params.bg) + + return MTFStretchParameters(midtone, shadow_clipping) + def stretch(data, stretch_params: StretchParameters): return stretch_all([data], stretch_params)[0] @@ -86,8 +103,6 @@ def stretch_all(datas, stretch_params: StretchParameters): datas = [data.clip(min=0, max=1) for data in datas] return datas - bg = stretch_params.bg - sigma = stretch_params.sigma futures = [] shms = [] copies = [] @@ -100,8 +115,13 @@ def stretch_all(datas, stretch_params: StretchParameters): np.copyto(copy, data) shms.append(shm) copies.append(copy) + + mtf_stretch_params = None + if stretch_params.channels_linked: + mtf_stretch_params = calculate_mtf_stretch_parameters(stretch_params, copy) + for c in range(copy.shape[-1]): - futures.insert(c, executor.submit(stretch_channel, shm.name, c, bg, sigma, copy.shape, copy.dtype, logging_queue, worker_configurer)) + futures.insert(c, executor.submit(stretch_channel, shm.name, c, stretch_params, mtf_stretch_params, copy.shape, copy.dtype, logging_queue, worker_configurer)) wait(futures) for copy in copies: