Skip to content

Commit

Permalink
Images can now be stretched with linked channels
Browse files Browse the repository at this point in the history
  • Loading branch information
Steffenhir committed Jan 21, 2024
1 parent b3b42dd commit 0c59443
Showing 1 changed file with 38 additions and 18 deletions.
56 changes: 38 additions & 18 deletions graxpert/stretch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
from graxpert.mp_logging import get_logging_queue, worker_configurer
from graxpert.parallel_processing import executor

from dataclasses import dataclass

@dataclass
class MTFStretchParameters:
midtone: float
shadow_clipping: float
highlight_clipping: float = 1.0

class StretchParameters:
stretch_option: str
Expand All @@ -19,8 +26,10 @@ class StretchParameters:
channels_linked: bool = False
images_linked: bool = False

def __init__(self, stretch_option: str):
def __init__(self, stretch_option: str, channels_linked: bool = False, images_linked: bool = False):
self.stretch_option = stretch_option
self.channels_linked = channels_linked
self. images_linked = images_linked

if stretch_option == "No Stretch":
self.do_stretch = False
Expand All @@ -42,7 +51,7 @@ def __init__(self, stretch_option: str):
self.sigma = 2.0


def stretch_channel(shm_name, c, bg, sigma, shape, dtype, logging_queue, logging_configurer):
def stretch_channel(shm_name, c, stretch_params, mtf_stretch_params, shape, dtype, logging_queue, logging_configurer):

logging_configurer(logging_queue)
logging.info("stretch.stretch_channel started")
Expand All @@ -52,30 +61,38 @@ def stretch_channel(shm_name, c, bg, sigma, shape, dtype, logging_queue, logging
channel = channels[:,:,c]

try:
indx_clip = np.logical_and(channel < 1.0, channel > 0.0)
median = np.median(channel[indx_clip])
mad = np.median(np.abs(channel[indx_clip]-median))

shadow_clipping = np.clip(median - sigma*mad, 0, 1.0)
highlight_clipping = 1.0

midtone = MTF((median-shadow_clipping)/(highlight_clipping - shadow_clipping), bg)
if not mtf_stretch_params:
mtf_stretch_params = calculate_mtf_stretch_parameters(stretch_params, channel)

channel[channel <= shadow_clipping] = 0.0
channel[channel >= highlight_clipping] = 1.0
channel[channel <= mtf_stretch_params.shadow_clipping] = 0.0
channel[channel >= mtf_stretch_params.highlight_clipping] = 1.0

indx_inside = np.logical_and(channel > shadow_clipping, channel < highlight_clipping)
indx_inside = np.logical_and(channel > mtf_stretch_params.shadow_clipping, channel < mtf_stretch_params.highlight_clipping)

channel[indx_inside] = (channel[indx_inside]-shadow_clipping)/(highlight_clipping - shadow_clipping)
channel[indx_inside] = (channel[indx_inside]-mtf_stretch_params.shadow_clipping)/(mtf_stretch_params.highlight_clipping - mtf_stretch_params.shadow_clipping)

channel = MTF(channel, midtone)
channel = MTF(channel, mtf_stretch_params.midtone)

except:
logging.exception("An error occured while stretching a color channel")
finally:
existing_shm.close()

logging.info("stretch.stretch_channel finished")

def calculate_mtf_stretch_parameters(stretch_params, channel):
channel = channel.flatten()

indx_clip = np.logical_and(channel < 1.0, channel > 0.0)
median = np.median(channel[indx_clip])
mad = np.median(np.abs(channel[indx_clip]-median))

shadow_clipping = np.clip(median - stretch_params.sigma*mad, 0, 1.0)
highlight_clipping = 1.0
midtone = MTF((median-shadow_clipping)/(highlight_clipping - shadow_clipping), stretch_params.bg)

return MTFStretchParameters(midtone, shadow_clipping)


def stretch(data, stretch_params: StretchParameters):
return stretch_all([data], stretch_params)[0]
Expand All @@ -86,8 +103,6 @@ def stretch_all(datas, stretch_params: StretchParameters):
datas = [data.clip(min=0, max=1) for data in datas]
return datas

bg = stretch_params.bg
sigma = stretch_params.sigma
futures = []
shms = []
copies = []
Expand All @@ -100,8 +115,13 @@ def stretch_all(datas, stretch_params: StretchParameters):
np.copyto(copy, data)
shms.append(shm)
copies.append(copy)

mtf_stretch_params = None
if stretch_params.channels_linked:
mtf_stretch_params = calculate_mtf_stretch_parameters(stretch_params, copy)

for c in range(copy.shape[-1]):
futures.insert(c, executor.submit(stretch_channel, shm.name, c, bg, sigma, copy.shape, copy.dtype, logging_queue, worker_configurer))
futures.insert(c, executor.submit(stretch_channel, shm.name, c, stretch_params, mtf_stretch_params, copy.shape, copy.dtype, logging_queue, worker_configurer))
wait(futures)

for copy in copies:
Expand Down

0 comments on commit 0c59443

Please sign in to comment.