Skip to content

Commit

Permalink
Add grayscale option
Browse files Browse the repository at this point in the history
  • Loading branch information
ericup committed Oct 29, 2024
1 parent f3bf775 commit 579f8d6
Showing 1 changed file with 23 additions and 3 deletions.
26 changes: 23 additions & 3 deletions celldetection_scripts/cpn_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,13 +193,25 @@ def on_oom():
return target_device


def preprocess(img, gamma=1., contrast=1., brightness=0., percentile=None):
def preprocess(img, gamma=1., contrast=1., brightness=0., percentile=None, grayscale=False):
# TODO: Add more options
if percentile is not None:
img = cd.data.normalize_percentile(img, percentile)
if img.itemsize > 1:
warn('Performing implicit percentile normalization, since input is not uint8.')
img = cd.data.normalize_percentile(img)
if grayscale and img.ndim == 3:
channels = img.shape[-1]
if channels == 1:
img = img.squeeze(-1)
elif channels == 2:
img = img.mean(-1)
elif channels == 3:
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
elif channels == 4:
img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
else:
raise ValueError(f'Unsupported number of channels: {channels}')
if img.ndim == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

Expand Down Expand Up @@ -300,7 +312,7 @@ def apply_model(img, models, trainer, mask=None, point_mask=None, crop_size=(768
transforms=None, model_kwargs_list=None,
batch_size=1, num_workers=0, pin_memory=False, border_removal=4, min_vote=1, stitching_rule='nms',
gamma=1., contrast=1., brightness=0., percentile=None, model_parameters=None,
point_mask_exclusive=False, verbose=True, **kwargs):
point_mask_exclusive=False, verbose=True, grayscale=False, **kwargs):
assert len(models) >= 1, 'Please specify at least one model.'
assert min_vote >= 1, f'Min vote smaller than minimum: {min_vote}'
assert len(models) >= min_vote, f'Min vote greater than number of models: {min_vote}'
Expand All @@ -313,7 +325,8 @@ def apply_model(img, models, trainer, mask=None, point_mask=None, crop_size=(768
strides = (strides,) * 2
elif len(strides) == 1:
strides *= 2
img = preprocess(img, gamma=gamma, contrast=contrast, brightness=brightness, percentile=percentile)
img = preprocess(img, gamma=gamma, contrast=contrast, brightness=brightness, percentile=percentile,
grayscale=grayscale)
if img.ndim == 2 or (img.ndim == 3 and img.shape[-1] == 1): # todo: should depend on model
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
x = img # uint8 converted in lightning_base from now on
Expand Down Expand Up @@ -452,6 +465,7 @@ def cpn_inference(
contrast: float = 1.0,
brightness: float = 0.0,
percentile: Optional[List[float]] = None,
grayscale: bool = False,
model_parameters: str = '',
verbose: bool = True,
skip_existing: bool = False,
Expand Down Expand Up @@ -517,6 +531,7 @@ def cpn_inference(
contrast (float): Factor for contrast adjustment. Default is 1.0.
brightness (float): Factor for brightness adjustment. Default is 0.0.
percentile (list[float], optional): Percentile norm. Performs min-max normalization with specified percentiles. Default is None.
grayscale (bool): Whether to convert multi-channel inputs to grayscale before processing.
model_parameters (str): Model parameters. Pass as string in "key=value,key1=value1" format. Default is ''.
verbose (bool): Verbosity toggle.
skip_existing (bool): Whether to inputs with existing output files.
Expand Down Expand Up @@ -773,6 +788,7 @@ def load_inputs(x, dataset_name, method, tag, idx, ext_checks=('.h5',)):
contrast=contrast,
brightness=brightness,
percentile=percentile,
grayscale=grayscale,
model_parameters=model_parameters,
model_kwargs_list=model_kwargs_list,
point_mask_exclusive=point_mask_exclusive,
Expand Down Expand Up @@ -949,6 +965,9 @@ def d(name):
help='Percentile norm. Performs min-max normalization with specified percentiles.'
'Specify either two values `(min, max)` or just `max` interpreted as '
'(1 - max, max).')
parser.add_argument('--grayscale', action='store_true',
help='Whether to convert all inputs to grayscale images before processing them. Note that this '
'works independently from the required number of input channels of the model.')
parser.add_argument('--model_parameters', default=d('model_parameters'), type=str,
help='Model parameters. Pass as string in "key=value,key1=value1" format')
parser.add_argument('--skip_existing', action='store_true',
Expand Down Expand Up @@ -999,6 +1018,7 @@ def d(name):
contrast=args.contrast,
brightness=args.brightness,
percentile=args.percentile,
grayscale=args.grayscale,
model_parameters=args.model_parameters,
skip_existing=args.skip_existing,
model_kwargs=args.model_kwargs,
Expand Down

0 comments on commit 579f8d6

Please sign in to comment.