diff --git a/celldetection_scripts/cpn_inference.py b/celldetection_scripts/cpn_inference.py index 2a87d62..c7b9aea 100644 --- a/celldetection_scripts/cpn_inference.py +++ b/celldetection_scripts/cpn_inference.py @@ -377,6 +377,8 @@ def cpn_inference( model_kwargs: Union[Dict[str, Any], List[Dict[str, Any]], str, List[str]] = None, group_level: str = 'job', continue_on_exception: bool = False, + return_results: bool = False, + num_nodes: Union[str, int] = 'auto', ): """ Process contour proposals for instance segmentation using specified parameters. @@ -444,6 +446,9 @@ def cpn_inference( ensured for example via `CUDA_VISIBLE_DEVICES` for GPUs. continue_on_exception (bool): If ``True``, try to continue processing when certain Exceptions are raised. Only works for selected stages (e.g. loading of an input file). + return_results (bool): Whether to return results. Should be False when used for long sequences of large inputs, + as collection of large results can lead to OOM exception. + num_nodes (int): Number of nodes. Default is 'auto'. """ args = dict(locals()) @@ -572,7 +577,10 @@ def resolve_inputs_(collection, x, tag='inputs'): ])) # Load model + if num_nodes == 'auto': + num_nodes = cd.get_num_nodes() trainer = pl.Trainer( + num_nodes=num_nodes, accelerator=accelerator, strategy=strategy, devices=devices, @@ -689,7 +697,8 @@ def load_inputs(x, dataset_name, method, tag, idx, ext_checks=('.h5',)): is_dist = is_available() and is_initialized() output = cd.asnumpy(y) - output_list.append(output) + if return_results: + output_list.append(output) out_files = dict() if (is_dist and get_rank() == 0) or not is_dist: props = properties @@ -756,7 +765,8 @@ def load_inputs(x, dataset_name, method, tag, idx, ext_checks=('.h5',)): if cd.mpi.has_mpi(): comm.barrier() - return output_list + if return_results: + return output_list def main(): @@ -843,7 +853,7 @@ def d(name): help='Separator string for region properties that are written to multiple columns. ' 'Default is "-" as in bbox-0, bbox-1, bbox-2, bbox-4.') - parser.add_argument('--group_level', type=str, + parser.add_argument('--group_level', type=str, default='job', help='Processing group level. One of `("job", "node", "rank")`, indicating the scope of ' 'processing groups that jointly process the same inputs. `"rank"` indicates for example ' 'that each input is processed by just one rank.') @@ -906,7 +916,8 @@ def d(name): model_parameters=args.model_parameters, skip_existing=args.skip_existing, model_kwargs=args.model_kwargs, - group_level=args.group_level + group_level=args.group_level, + return_results=False, ) if not (is_available() and is_initialized()) or get_rank() == 0: # because why not