Skip to content

Commit

Permalink
Fix num nodes and output collection
Browse files Browse the repository at this point in the history
  • Loading branch information
ericup committed Jul 12, 2024
1 parent 43881f2 commit 4b0613e
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions celldetection_scripts/cpn_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,8 @@ def cpn_inference(
model_kwargs: Union[Dict[str, Any], List[Dict[str, Any]], str, List[str]] = None,
group_level: str = 'job',
continue_on_exception: bool = False,
return_results: bool = False,
num_nodes: Union[str, int] = 'auto',
):
"""
Process contour proposals for instance segmentation using specified parameters.
Expand Down Expand Up @@ -444,6 +446,9 @@ def cpn_inference(
ensured for example via `CUDA_VISIBLE_DEVICES` for GPUs.
continue_on_exception (bool): If ``True``, try to continue processing when certain Exceptions are raised.
Only works for selected stages (e.g. loading of an input file).
return_results (bool): Whether to collect and return results. Should be ``False`` when processing long sequences of
large inputs, as collecting large results in memory can lead to out-of-memory (OOM) errors.
num_nodes (Union[str, int]): Number of nodes. Default is ``'auto'``, in which case the number of nodes is determined via `cd.get_num_nodes()`.
"""

args = dict(locals())
Expand Down Expand Up @@ -572,7 +577,10 @@ def resolve_inputs_(collection, x, tag='inputs'):
]))

# Load model
if num_nodes == 'auto':
num_nodes = cd.get_num_nodes()
trainer = pl.Trainer(
num_nodes=num_nodes,
accelerator=accelerator,
strategy=strategy,
devices=devices,
Expand Down Expand Up @@ -689,7 +697,8 @@ def load_inputs(x, dataset_name, method, tag, idx, ext_checks=('.h5',)):

is_dist = is_available() and is_initialized()
output = cd.asnumpy(y)
output_list.append(output)
if return_results:
output_list.append(output)
out_files = dict()
if (is_dist and get_rank() == 0) or not is_dist:
props = properties
Expand Down Expand Up @@ -756,7 +765,8 @@ def load_inputs(x, dataset_name, method, tag, idx, ext_checks=('.h5',)):
if cd.mpi.has_mpi():
comm.barrier()

return output_list
if return_results:
return output_list


def main():
Expand Down Expand Up @@ -843,7 +853,7 @@ def d(name):
help='Separator string for region properties that are written to multiple columns. '
'Default is "-" as in bbox-0, bbox-1, bbox-2, bbox-4.')

parser.add_argument('--group_level', type=str,
parser.add_argument('--group_level', type=str, default='job',
help='Processing group level. One of `("job", "node", "rank")`, indicating the scope of '
'processing groups that jointly process the same inputs. `"rank"` indicates for example '
'that each input is processed by just one rank.')
Expand Down Expand Up @@ -906,7 +916,8 @@ def d(name):
model_parameters=args.model_parameters,
skip_existing=args.skip_existing,
model_kwargs=args.model_kwargs,
group_level=args.group_level
group_level=args.group_level,
return_results=False,
)

if not (is_available() and is_initialized()) or get_rank() == 0: # because why not
Expand Down

0 comments on commit 4b0613e

Please sign in to comment.