From 43881f22eccad7b05b11bb2aa8c5cdf9fa3a37ef Mon Sep 17 00:00:00 2001 From: ericup Date: Tue, 25 Jun 2024 15:54:26 +0200 Subject: [PATCH] Update cpn inference --- celldetection_scripts/cpn_inference.py | 158 +++++++++++++++++++++---- 1 file changed, 133 insertions(+), 25 deletions(-) diff --git a/celldetection_scripts/cpn_inference.py b/celldetection_scripts/cpn_inference.py index a0b88f8..2a87d62 100644 --- a/celldetection_scripts/cpn_inference.py +++ b/celldetection_scripts/cpn_inference.py @@ -1,4 +1,6 @@ import argparse +import json +import traceback import tifffile import torch import torch.nn as nn @@ -15,7 +17,7 @@ from torch.distributed import is_available, all_gather_object, get_world_size, is_initialized, get_rank from itertools import chain import albumentations.augmentations.functional as F -from typing import Union, List, Optional +from typing import Union, List, Optional, Dict, Any from warnings import warn from skimage import img_as_float, img_as_ubyte @@ -133,6 +135,8 @@ def __getitem__(self, item): def apply_keep_indices_(items: dict, keep, ignore_keys=None): # Applies keep indices to all Tensors, except keys listed in ignore_keys for k, v in items.items(): + if v is None: + continue is_tensor = None if ignore_keys is not None and k in ignore_keys: continue @@ -156,6 +160,8 @@ def apply_keep_indices_flat_(items: dict, keep, ignore_keys=None): def concat_results_(coll, new): for k, v in new.items(): is_tensor = None + if v is None: + continue for v_ in v: if is_tensor is None: is_tensor = isinstance(v_, torch.Tensor) @@ -188,7 +194,7 @@ def preprocess(img, gamma=1., contrast=1., brightness=0., percentile=None): return img -def resolve_model(model_name, model_parameters, verbose=True): +def resolve_model(model_name, model_parameters, verbose=True, **kwargs): if isinstance(model_name, nn.Module): # Is module already model = model_name @@ -197,14 +203,16 @@ def resolve_model(model_name, model_parameters, verbose=True): model = model_name(map_location='cpu') else: if model_name.endswith('.ckpt'): + if len(kwargs): + warn(f'Cannot use kwargs when loading Lightning Checkpoints. Ignoring the following: {kwargs}') # Lightning checkpoint model = cd.models.LitCpn.load_from_checkpoint(model_name, map_location='cpu') else: - model = cd.load_model(model_name, map_location='cpu') + model = cd.load_model(model_name, map_location='cpu', **kwargs) if not isinstance(model, cd.models.LitCpn): if verbose: print('Wrap model with lightning', end='') - model = cd.models.LitCpn(model) + model = cd.models.LitCpn(model, **kwargs) model.model.max_imsize = None model.eval() model.requires_grad_(False) @@ -219,8 +227,8 @@ def resolve_model(model_name, model_parameters, verbose=True): def apply_model(img, models, trainer, mask=None, point_mask=None, crop_size=(768, 768), strides=(384, 384), reps=1, - transforms=None, - batch_size=1, num_workers=0, pin_memory=False, border_removal=6, min_vote=1, stitching_rule='nms', + transforms=None, model_kwargs_list=None, + batch_size=1, num_workers=0, pin_memory=False, border_removal=4, min_vote=1, stitching_rule='nms', gamma=1., contrast=1., brightness=0., percentile=None, model_parameters=None, point_mask_exclusive=False, verbose=True, **kwargs): assert len(models) >= 1, 'Please specify at least one model.' @@ -256,8 +264,8 @@ def apply_model(img, models, trainer, mask=None, point_mask=None, crop_size=(768 results = {} h_tiles, w_tiles = tile_loader.num_slices_per_axis nms_thresh = None - for model_name in models: - model = resolve_model(model_name, model_parameters, verbose=verbose) + for model_name, model_kwargs in zip(models, model_kwargs_list): + model = resolve_model(model_name, model_parameters, verbose=verbose, **model_kwargs) nms_thresh = kwargs.get('nms_thresh', model.model.nms_thresh) y = trainer.predict(model, data_loader) @@ -365,7 +373,10 @@ def cpn_inference( percentile: Optional[List[float]] = None, model_parameters: str = '', verbose: bool = True, - skip_existing: bool = False + skip_existing: bool = False, + model_kwargs: Union[Dict[str, Any], List[Dict[str, Any]], str, List[str]] = None, + group_level: str = 'job', + continue_on_exception: bool = False, ): """ Process contour proposals for instance segmentation using specified parameters. @@ -425,11 +436,49 @@ def cpn_inference( percentile (list[float], optional): Percentile norm. Performs min-max normalization with specified percentiles. Default is None. model_parameters (str): Model parameters. Pass as string in "key=value,key1=value1" format. Default is ''. verbose (bool): Verbosity toggle. - skip_existing(bool): Whether to inputs with existing output files. + skip_existing (bool): Whether to inputs with existing output files. + model_kwargs (str, dict, list[str], list[dict]): Model kwargs. If passed as string, JSON format is expected. + group_level (str): Processing group level. One of `("job", "node", "rank")`, indicating the scope of processing + groups that jointly process the same inputs. `"rank"` indicates for example that each input is processed + by just one rank. Note that each rank is assumed to only have access to a single device, which can be + ensured for example via `CUDA_VISIBLE_DEVICES` for GPUs. + continue_on_exception (bool): If ``True``, try to continue processing when certain Exceptions are raised. + Only works for selected stages (e.g. loading of an input file). """ args = dict(locals()) + if isinstance(devices, str) and devices.isnumeric(): + devices = int(devices) + + # Group level + assert group_level in ('node', 'job', 'rank'), '`group_level` must be one of "node", "job", "rank"' + comm, mpi_rank, mpi_ranks = cd.mpi.get_comm(return_ranks=True) + mpi_local_rank = mpi_local_ranks = None + if group_level != 'job': + assert cd.mpi.has_mpi(), f'To use `group_level={group_level}` MPI must be available.' + if group_level == 'node': + raise NotImplementedError(f'`group_level={group_level}` is not yet available.') + local_comm, mpi_local_rank, mpi_local_ranks = cd.mpi.get_local_comm(comm, return_ranks=True) + + if group_level == 'rank': + # Check strategy + if strategy != 'auto': + warn(f'Strategy is being set to `"auto"` to comply with `group_level={group_level}`. ' + f'It was initially set to {strategy}.') + strategy = 'auto' + + # Check devices + if isinstance(devices, int) and devices != 1: + warn(f'Devices is being set to `1` to comply with `group_level={group_level}`. ' + f'It was initially set to {devices}.') + devices = 1 + if torch.cuda.is_available() and torch.cuda.device_count() > 1: + warn(f'Group level was set `group_level={group_level}`, but found multiple devices.\n' + 'By default each rank will only use one device in this mode.\n' + 'To ensure that each rank has its own dedicated device please change visibility settings, e.g. ' + 'via `CUDA_VISIBLE_DEVICES`.') + if truncated_images: ImageFile.LOAD_TRUNCATED_IMAGES = True @@ -447,6 +496,11 @@ def resolve_inputs_(collection, x, tag='inputs'): masks = [masks] if models is not None and not isinstance(models, (tuple, list)): models = [models] + if not isinstance(model_kwargs, (tuple, list)): + model_kwargs = [model_kwargs] * len(models) + else: + assert model_kwargs is None or len(models) == len(model_kwargs), ('Please provide one keyword argument ' + 'dict per model.') # Prepare input args input_list = [] @@ -470,21 +524,33 @@ def resolve_inputs_(collection, x, tag='inputs'): # Prepare model args model_list = [] - for m in models: + model_kwargs_list = [] + for m, m_kwargs in zip(models, model_kwargs): + if m_kwargs is None: + m_kwargs = dict() + elif isinstance(m_kwargs, str): + m_kwargs = json.loads(m_kwargs) + else: + assert isinstance(m_kwargs, dict), 'Please provide `model_kwargs` as a dictionary ' \ + '(JSON string of dictionary also supported).' + if isinstance(m, nn.Module): model_list.append(m) + model_kwargs_list.append(m_kwargs) else: assert isinstance(m, str) if m.startswith('http://') or m.startswith('https://') or m.startswith('cd://') or ( not isfile(m) and not splitext(m)[1]): # Either URL (leading http(s)) or hosted model (leading cd or just no file extension as a fallback) - model_list.append(lambda _m=m, **kwargs: cd.fetch_model(_m, **kwargs)) + model_list.append(lambda _m=m, _mkw=m_kwargs, **kwargs: cd.fetch_model(_m, **kwargs, **_mkw)) + model_kwargs_list.append(dict()) else: files = sorted(glob(m)) if len(files) == 0 and sep not in m and '.' not in m: files = [lambda _m=m, **kwargs: cd.fetch_model(_m, **kwargs)] # fallback: try cd-hosted assert len(files), f'Could not find models: {m}' model_list += files + model_kwargs_list += [m_kwargs] * len(files) # Prepare model parameters model_parameters = [i.strip().split('=') for i in model_parameters.split(',') if len(i.strip())] @@ -492,9 +558,6 @@ def resolve_inputs_(collection, x, tag='inputs'): if verbose and model_parameters is not None and len(model_parameters): print('Changing the following model parameters:', model_parameters) - if isinstance(devices, str) and devices.isnumeric(): - devices = int(devices) - if verbose: print('Summary:\n ', '\n '.join([ f'Number of inputs: {len(input_list)}', @@ -502,6 +565,9 @@ def resolve_inputs_(collection, x, tag='inputs'): f'Output path: {outputs}' + ' (newly created)' * (not isdir(outputs)), f'Workers: {num_workers}', f'Devices: {devices}', + f'Cuda available: {torch.cuda.is_available()}', + f'Cuda device count: {torch.cuda.device_count() if torch.cuda.is_available() else 0}', + f'Accelerator: {accelerator}', f'Strategy: {strategy}', ])) @@ -545,12 +611,33 @@ def load_inputs(x, dataset_name, method, tag, idx, ext_checks=('.h5',)): output_list = [] for src_idx, src in enumerate(input_list): + # Group level: Make sure inputs are assigned correctly + if group_level == 'rank': + if (src_idx % mpi_ranks) != mpi_rank: + continue + elif group_level == 'node': + if (src_idx % mpi_local_ranks) != mpi_local_rank: + continue + + print(f'Next input: {src_idx} (rank {mpi_rank}/{mpi_ranks})', src) + + # Load inputs try: img, dst = load_inputs(src, inputs_dataset, inputs_method, 'inputs', idx=src_idx) except FileExistsError: if verbose: print('Skipping input, because output exists already:', src) continue + except Exception as e: + if continue_on_exception: + # assuming that all ranks fail to load input + warn(f"An exception occurred: {e}\nTraceback:\n{traceback.format_exc()}") + if cd.mpi.has_mpi(): + comm.barrier() + continue + else: + raise e + dst_h5 = dst.format(ext='.h5') if isinstance(src, np.ndarray): @@ -575,7 +662,7 @@ def load_inputs(x, dataset_name, method, tag, idx, ext_checks=('.h5',)): # Resolve model now if it's just one if len(model_list) == 1: - model_list[0] = resolve_model(model_list[0], model_parameters, verbose=verbose) + model_list[0] = resolve_model(model_list[0], model_parameters, verbose=verbose, **model_kwargs_list[0]) y = cd.asnumpy(apply_model( img, model_list, trainer, @@ -595,6 +682,7 @@ def load_inputs(x, dataset_name, method, tag, idx, ext_checks=('.h5',)): brightness=brightness, percentile=percentile, model_parameters=model_parameters, + model_kwargs_list=model_kwargs_list, point_mask_exclusive=point_mask_exclusive, verbose=verbose )) @@ -610,7 +698,10 @@ def load_inputs(x, dataset_name, method, tag, idx, ext_checks=('.h5',)): labels_ = flat_labels_ = None if do_labels: - labels_ = cd.data.contours2labels(y['contours'], img.shape[:2]) + if 'contours' in y: + labels_ = cd.data.contours2labels(y['contours'], img.shape[:2]) + else: + labels_ = np.zeros(tuple(img.shape[:2]) + (1,), dtype='uint8') if labels: y['labels'] = output['labels'] = labels_ if flat_labels_: @@ -639,9 +730,9 @@ def load_inputs(x, dataset_name, method, tag, idx, ext_checks=('.h5',)): if overlay: if do_labels: assert labels_ is not None or flat_labels_ is not None - label_vis = img_as_ubyte(cd.label_cmap(flat_labels_ if labels_ is None else labels_)) + label_vis = cd.label_cmap(flat_labels_ if labels_ is None else labels_, ubyte=True) else: - label_vis = cd.data.contours2overlay(y['contours'], img.shape[:2]) + label_vis = cd.data.contours2overlay(y.get('contours'), img.shape[:2]) dst_ove_tif = dst.format(ext='_overlay.tif') tifffile.imwrite(dst_ove_tif, label_vis, compression='ZLIB') output['overlay'] = label_vis @@ -650,14 +741,21 @@ def load_inputs(x, dataset_name, method, tag, idx, ext_checks=('.h5',)): if demo_figure: from matplotlib import pyplot as plt cd.imshow_row(img, img, figsize=(30, 15), titles=('input', 'contours')) - cd.plot_contours(y['contours']) - cd.plot_boxes(y['boxes']) - loc = cd.asnumpy(y['locations']) - plt.scatter(loc[:, 0], loc[:, 1], marker='x') + if 'contours' in y: + cd.plot_contours(y['contours']) + if 'boxes' in y: + cd.plot_boxes(y['boxes']) + if 'locations' in y: + loc = cd.asnumpy(y['locations']) + plt.scatter(loc[:, 0], loc[:, 1], marker='x') out_files['demo_figure'] = dst_demo = dst.format(ext='_demo.png') cd.save_fig(dst_demo) if len(out_files): output['files'] = out_files + + if cd.mpi.has_mpi(): + comm.barrier() + return output_list @@ -681,7 +779,10 @@ def d(name): parser.add_argument('-m', '--models', nargs='+', help='Model. Either filename, name pattern (glob), URL (leading http:// or https://), or ' 'hosted model name (leading cd://). ' - 'Example: `--model \'cd://ginoro_CpnResNeXt101UNet-fbe875f1a3e5ce2c\'`') + 'Example: `--models \'cd://ginoro_CpnResNeXt101UNet-fbe875f1a3e5ce2c\'`') + parser.add_argument('--model_kwargs', nargs='+', + help='Model kwargs in JSON format. ' + 'Example: `--model_kwargs \'{"augment": true}\'') parser.add_argument('--masks', default=d('masks'), nargs='+', type=str, help='Masks. Either filename, name pattern (glob), or URL (leading http:// or https://). ' 'A mask determines where the model searches for objects. Regions with values <= 0' @@ -742,6 +843,11 @@ def d(name): help='Separator string for region properties that are written to multiple columns. ' 'Default is "-" as in bbox-0, bbox-1, bbox-2, bbox-4.') + parser.add_argument('--group_level', type=str, + help='Processing group level. One of `("job", "node", "rank")`, indicating the scope of ' + 'processing groups that jointly process the same inputs. `"rank"` indicates for example ' + 'that each input is processed by just one rank.') + parser.add_argument('--gamma', default=d('gamma'), type=float, help='Gamma value for gamma transform.') parser.add_argument('--contrast', default=d('contrast'), type=float, help='Factor for contrast adjustment.') parser.add_argument('--brightness', default=d('brightness'), type=float, help='Factor for brightness adjustment.') @@ -798,7 +904,9 @@ def d(name): brightness=args.brightness, percentile=args.percentile, model_parameters=args.model_parameters, - skip_existing=args.skip_existing + skip_existing=args.skip_existing, + model_kwargs=args.model_kwargs, + group_level=args.group_level ) if not (is_available() and is_initialized()) or get_rank() == 0: # because why not