Skip to content

Commit

Permalink
Update cpn inference
Browse files Browse the repository at this point in the history
  • Loading branch information
ericup committed Jun 25, 2024
1 parent 7536013 commit 43881f2
Showing 1 changed file with 133 additions and 25 deletions.
158 changes: 133 additions & 25 deletions celldetection_scripts/cpn_inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import argparse
import json
import traceback
import tifffile
import torch
import torch.nn as nn
Expand All @@ -15,7 +17,7 @@
from torch.distributed import is_available, all_gather_object, get_world_size, is_initialized, get_rank
from itertools import chain
import albumentations.augmentations.functional as F
from typing import Union, List, Optional
from typing import Union, List, Optional, Dict, Any
from warnings import warn
from skimage import img_as_float, img_as_ubyte

Expand Down Expand Up @@ -133,6 +135,8 @@ def __getitem__(self, item):
def apply_keep_indices_(items: dict, keep, ignore_keys=None):
# Applies keep indices to all Tensors, except keys listed in ignore_keys
for k, v in items.items():
if v is None:
continue
is_tensor = None
if ignore_keys is not None and k in ignore_keys:
continue
Expand All @@ -156,6 +160,8 @@ def apply_keep_indices_flat_(items: dict, keep, ignore_keys=None):
def concat_results_(coll, new):
for k, v in new.items():
is_tensor = None
if v is None:
continue
for v_ in v:
if is_tensor is None:
is_tensor = isinstance(v_, torch.Tensor)
Expand Down Expand Up @@ -188,7 +194,7 @@ def preprocess(img, gamma=1., contrast=1., brightness=0., percentile=None):
return img


def resolve_model(model_name, model_parameters, verbose=True):
def resolve_model(model_name, model_parameters, verbose=True, **kwargs):
if isinstance(model_name, nn.Module):
# Is module already
model = model_name
Expand All @@ -197,14 +203,16 @@ def resolve_model(model_name, model_parameters, verbose=True):
model = model_name(map_location='cpu')
else:
if model_name.endswith('.ckpt'):
if len(kwargs):
warn(f'Cannot use kwargs when loading Lightning Checkpoints. Ignoring the following: {kwargs}')
# Lightning checkpoint
model = cd.models.LitCpn.load_from_checkpoint(model_name, map_location='cpu')
else:
model = cd.load_model(model_name, map_location='cpu')
model = cd.load_model(model_name, map_location='cpu', **kwargs)
if not isinstance(model, cd.models.LitCpn):
if verbose:
print('Wrap model with lightning', end='')
model = cd.models.LitCpn(model)
model = cd.models.LitCpn(model, **kwargs)
model.model.max_imsize = None
model.eval()
model.requires_grad_(False)
Expand All @@ -219,8 +227,8 @@ def resolve_model(model_name, model_parameters, verbose=True):


def apply_model(img, models, trainer, mask=None, point_mask=None, crop_size=(768, 768), strides=(384, 384), reps=1,
transforms=None,
batch_size=1, num_workers=0, pin_memory=False, border_removal=6, min_vote=1, stitching_rule='nms',
transforms=None, model_kwargs_list=None,
batch_size=1, num_workers=0, pin_memory=False, border_removal=4, min_vote=1, stitching_rule='nms',
gamma=1., contrast=1., brightness=0., percentile=None, model_parameters=None,
point_mask_exclusive=False, verbose=True, **kwargs):
assert len(models) >= 1, 'Please specify at least one model.'
Expand Down Expand Up @@ -256,8 +264,8 @@ def apply_model(img, models, trainer, mask=None, point_mask=None, crop_size=(768
results = {}
h_tiles, w_tiles = tile_loader.num_slices_per_axis
nms_thresh = None
for model_name in models:
model = resolve_model(model_name, model_parameters, verbose=verbose)
for model_name, model_kwargs in zip(models, model_kwargs_list):
model = resolve_model(model_name, model_parameters, verbose=verbose, **model_kwargs)
nms_thresh = kwargs.get('nms_thresh', model.model.nms_thresh)

y = trainer.predict(model, data_loader)
Expand Down Expand Up @@ -365,7 +373,10 @@ def cpn_inference(
percentile: Optional[List[float]] = None,
model_parameters: str = '',
verbose: bool = True,
skip_existing: bool = False
skip_existing: bool = False,
model_kwargs: Union[Dict[str, Any], List[Dict[str, Any]], str, List[str]] = None,
group_level: str = 'job',
continue_on_exception: bool = False,
):
"""
Process contour proposals for instance segmentation using specified parameters.
Expand Down Expand Up @@ -425,11 +436,49 @@ def cpn_inference(
percentile (list[float], optional): Percentile norm. Performs min-max normalization with specified percentiles. Default is None.
model_parameters (str): Model parameters. Pass as string in "key=value,key1=value1" format. Default is ''.
verbose (bool): Verbosity toggle.
skip_existing(bool): Whether to inputs with existing output files.
skip_existing (bool): Whether to inputs with existing output files.
model_kwargs (str, dict, list[str], list[dict]): Model kwargs. If passed as string, JSON format is expected.
group_level (str): Processing group level. One of `("job", "node", "rank")`, indicating the scope of processing
groups that jointly process the same inputs. `"rank"` indicates for example that each input is processed
by just one rank. Note that each rank is assumed to only have access to a single device, which can be
ensured for example via `CUDA_VISIBLE_DEVICES` for GPUs.
continue_on_exception (bool): If ``True``, try to continue processing when certain Exceptions are raised.
Only works for selected stages (e.g. loading of an input file).
"""

args = dict(locals())

if isinstance(devices, str) and devices.isnumeric():
devices = int(devices)

# Group level
assert group_level in ('node', 'job', 'rank'), '`group_level` must be one of "node", "job", "rank"'
comm, mpi_rank, mpi_ranks = cd.mpi.get_comm(return_ranks=True)
mpi_local_rank = mpi_local_ranks = None
if group_level != 'job':
assert cd.mpi.has_mpi(), f'To use `group_level={group_level}` MPI must be available.'
if group_level == 'node':
raise NotImplementedError(f'`group_level={group_level}` is not yet available.')
local_comm, mpi_local_rank, mpi_local_ranks = cd.mpi.get_local_comm(comm, return_ranks=True)

if group_level == 'rank':
# Check strategy
if strategy != 'auto':
warn(f'Strategy is being set to `"auto"` to comply with `group_level={group_level}`. '
f'It was initially set to {strategy}.')
strategy = 'auto'

# Check devices
if isinstance(devices, int) and devices != 1:
warn(f'Devices is being set to `1` to comply with `group_level={group_level}`. '
f'It was initially set to {devices}.')
devices = 1
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
warn(f'Group level was set `group_level={group_level}`, but found multiple devices.\n'
'By default each rank will only use one device in this mode.\n'
'To ensure that each rank has its own dedicated device please change visibility settings, e.g. '
'via `CUDA_VISIBLE_DEVICES`.')

if truncated_images:
ImageFile.LOAD_TRUNCATED_IMAGES = True

Expand All @@ -447,6 +496,11 @@ def resolve_inputs_(collection, x, tag='inputs'):
masks = [masks]
if models is not None and not isinstance(models, (tuple, list)):
models = [models]
if not isinstance(model_kwargs, (tuple, list)):
model_kwargs = [model_kwargs] * len(models)
else:
assert model_kwargs is None or len(models) == len(model_kwargs), ('Please provide one keyword argument '
'dict per model.')

# Prepare input args
input_list = []
Expand All @@ -470,38 +524,50 @@ def resolve_inputs_(collection, x, tag='inputs'):

# Prepare model args
model_list = []
for m in models:
model_kwargs_list = []
for m, m_kwargs in zip(models, model_kwargs):
if m_kwargs is None:
m_kwargs = dict()
elif isinstance(m_kwargs, str):
m_kwargs = json.loads(m_kwargs)
else:
assert isinstance(m_kwargs, dict), 'Please provide `model_kwargs` as a dictionary ' \
'(JSON string of dictionary also supported).'

if isinstance(m, nn.Module):
model_list.append(m)
model_kwargs_list.append(m_kwargs)
else:
assert isinstance(m, str)
if m.startswith('http://') or m.startswith('https://') or m.startswith('cd://') or (
not isfile(m) and not splitext(m)[1]):
# Either URL (leading http(s)) or hosted model (leading cd or just no file extension as a fallback)
model_list.append(lambda _m=m, **kwargs: cd.fetch_model(_m, **kwargs))
model_list.append(lambda _m=m, _mkw=m_kwargs, **kwargs: cd.fetch_model(_m, **kwargs, **_mkw))
model_kwargs_list.append(dict())
else:
files = sorted(glob(m))
if len(files) == 0 and sep not in m and '.' not in m:
files = [lambda _m=m, **kwargs: cd.fetch_model(_m, **kwargs)] # fallback: try cd-hosted
assert len(files), f'Could not find models: {m}'
model_list += files
model_kwargs_list += [m_kwargs] * len(files)

# Prepare model parameters
model_parameters = [i.strip().split('=') for i in model_parameters.split(',') if len(i.strip())]
model_parameters = {k: v for k, v in model_parameters}
if verbose and model_parameters is not None and len(model_parameters):
print('Changing the following model parameters:', model_parameters)

if isinstance(devices, str) and devices.isnumeric():
devices = int(devices)

if verbose:
print('Summary:\n ', '\n '.join([
f'Number of inputs: {len(input_list)}',
f'Number of models: {len(model_list)}',
f'Output path: {outputs}' + ' (newly created)' * (not isdir(outputs)),
f'Workers: {num_workers}',
f'Devices: {devices}',
f'Cuda available: {torch.cuda.is_available()}',
f'Cuda device count: {torch.cuda.device_count() if torch.cuda.is_available() else 0}',
f'Accelerator: {accelerator}',
f'Strategy: {strategy}',
]))

Expand Down Expand Up @@ -545,12 +611,33 @@ def load_inputs(x, dataset_name, method, tag, idx, ext_checks=('.h5',)):

output_list = []
for src_idx, src in enumerate(input_list):
# Group level: Make sure inputs are assigned correctly
if group_level == 'rank':
if (src_idx % mpi_ranks) != mpi_rank:
continue
elif group_level == 'node':
if (src_idx % mpi_local_ranks) != mpi_local_rank:
continue

print(f'Next input: {src_idx} (rank {mpi_rank}/{mpi_ranks})', src)

# Load inputs
try:
img, dst = load_inputs(src, inputs_dataset, inputs_method, 'inputs', idx=src_idx)
except FileExistsError:
if verbose:
print('Skipping input, because output exists already:', src)
continue
except Exception as e:
if continue_on_exception:
# assuming that all ranks fail to load input
warn(f"An exception occurred: {e}\nTraceback:\n{traceback.format_exc()}")
if cd.mpi.has_mpi():
comm.barrier()
continue
else:
raise e

dst_h5 = dst.format(ext='.h5')

if isinstance(src, np.ndarray):
Expand All @@ -575,7 +662,7 @@ def load_inputs(x, dataset_name, method, tag, idx, ext_checks=('.h5',)):

# Resolve model now if it's just one
if len(model_list) == 1:
model_list[0] = resolve_model(model_list[0], model_parameters, verbose=verbose)
model_list[0] = resolve_model(model_list[0], model_parameters, verbose=verbose, **model_kwargs_list[0])

y = cd.asnumpy(apply_model(
img, model_list, trainer,
Expand All @@ -595,6 +682,7 @@ def load_inputs(x, dataset_name, method, tag, idx, ext_checks=('.h5',)):
brightness=brightness,
percentile=percentile,
model_parameters=model_parameters,
model_kwargs_list=model_kwargs_list,
point_mask_exclusive=point_mask_exclusive,
verbose=verbose
))
Expand All @@ -610,7 +698,10 @@ def load_inputs(x, dataset_name, method, tag, idx, ext_checks=('.h5',)):

labels_ = flat_labels_ = None
if do_labels:
labels_ = cd.data.contours2labels(y['contours'], img.shape[:2])
if 'contours' in y:
labels_ = cd.data.contours2labels(y['contours'], img.shape[:2])
else:
labels_ = np.zeros(tuple(img.shape[:2]) + (1,), dtype='uint8')
if labels:
y['labels'] = output['labels'] = labels_
if flat_labels_:
Expand Down Expand Up @@ -639,9 +730,9 @@ def load_inputs(x, dataset_name, method, tag, idx, ext_checks=('.h5',)):
if overlay:
if do_labels:
assert labels_ is not None or flat_labels_ is not None
label_vis = img_as_ubyte(cd.label_cmap(flat_labels_ if labels_ is None else labels_))
label_vis = cd.label_cmap(flat_labels_ if labels_ is None else labels_, ubyte=True)
else:
label_vis = cd.data.contours2overlay(y['contours'], img.shape[:2])
label_vis = cd.data.contours2overlay(y.get('contours'), img.shape[:2])
dst_ove_tif = dst.format(ext='_overlay.tif')
tifffile.imwrite(dst_ove_tif, label_vis, compression='ZLIB')
output['overlay'] = label_vis
Expand All @@ -650,14 +741,21 @@ def load_inputs(x, dataset_name, method, tag, idx, ext_checks=('.h5',)):
if demo_figure:
from matplotlib import pyplot as plt
cd.imshow_row(img, img, figsize=(30, 15), titles=('input', 'contours'))
cd.plot_contours(y['contours'])
cd.plot_boxes(y['boxes'])
loc = cd.asnumpy(y['locations'])
plt.scatter(loc[:, 0], loc[:, 1], marker='x')
if 'contours' in y:
cd.plot_contours(y['contours'])
if 'boxes' in y:
cd.plot_boxes(y['boxes'])
if 'locations' in y:
loc = cd.asnumpy(y['locations'])
plt.scatter(loc[:, 0], loc[:, 1], marker='x')
out_files['demo_figure'] = dst_demo = dst.format(ext='_demo.png')
cd.save_fig(dst_demo)
if len(out_files):
output['files'] = out_files

if cd.mpi.has_mpi():
comm.barrier()

return output_list


Expand All @@ -681,7 +779,10 @@ def d(name):
parser.add_argument('-m', '--models', nargs='+',
help='Model. Either filename, name pattern (glob), URL (leading http:// or https://), or '
'hosted model name (leading cd://). '
'Example: `--model \'cd://ginoro_CpnResNeXt101UNet-fbe875f1a3e5ce2c\'`')
'Example: `--models \'cd://ginoro_CpnResNeXt101UNet-fbe875f1a3e5ce2c\'`')
parser.add_argument('--model_kwargs', nargs='+',
help='Model kwargs in JSON format. '
'Example: `--model_kwargs \'{"augment": true}\'')
parser.add_argument('--masks', default=d('masks'), nargs='+', type=str,
help='Masks. Either filename, name pattern (glob), or URL (leading http:// or https://). '
'A mask determines where the model searches for objects. Regions with values <= 0'
Expand Down Expand Up @@ -742,6 +843,11 @@ def d(name):
help='Separator string for region properties that are written to multiple columns. '
'Default is "-" as in bbox-0, bbox-1, bbox-2, bbox-4.')

parser.add_argument('--group_level', type=str,
help='Processing group level. One of `("job", "node", "rank")`, indicating the scope of '
'processing groups that jointly process the same inputs. `"rank"` indicates for example '
'that each input is processed by just one rank.')

parser.add_argument('--gamma', default=d('gamma'), type=float, help='Gamma value for gamma transform.')
parser.add_argument('--contrast', default=d('contrast'), type=float, help='Factor for contrast adjustment.')
parser.add_argument('--brightness', default=d('brightness'), type=float, help='Factor for brightness adjustment.')
Expand Down Expand Up @@ -798,7 +904,9 @@ def d(name):
brightness=args.brightness,
percentile=args.percentile,
model_parameters=args.model_parameters,
skip_existing=args.skip_existing
skip_existing=args.skip_existing,
model_kwargs=args.model_kwargs,
group_level=args.group_level
)

if not (is_available() and is_initialized()) or get_rank() == 0: # because why not
Expand Down

0 comments on commit 43881f2

Please sign in to comment.