From 5f36e5c063629e83693ffd8f83d5a6469f05abda Mon Sep 17 00:00:00 2001 From: ericup Date: Mon, 15 Jul 2024 13:43:43 +0200 Subject: [PATCH] Fix non-distributed --- celldetection_scripts/cpn_inference.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/celldetection_scripts/cpn_inference.py b/celldetection_scripts/cpn_inference.py index 6313843..65c6b5c 100644 --- a/celldetection_scripts/cpn_inference.py +++ b/celldetection_scripts/cpn_inference.py @@ -269,8 +269,7 @@ def apply_model(img, models, trainer, mask=None, point_mask=None, crop_size=(768 y = trainer.predict(model, data_loader) is_dist = is_available() and is_initialized() - rank = get_rank() - ranks = get_world_size() + rank, ranks = cd.get_rank(return_world_size=True) pre_results = {} for y_idx, y_ in enumerate(y): @@ -308,6 +307,8 @@ def apply_model(img, models, trainer, mask=None, point_mask=None, crop_size=(768 if (is_dist and rank == 0) or not is_dist: results_ = {} + if isinstance(pre_results, dict): + pre_results = pre_results, for r_idx, r in enumerate(pre_results): assert isinstance(r, dict) concat_results_flat_(results_, r)