diff --git a/celldetection_scripts/cpn_inference.py b/celldetection_scripts/cpn_inference.py
index c7b9aea..6313843 100644
--- a/celldetection_scripts/cpn_inference.py
+++ b/celldetection_scripts/cpn_inference.py
@@ -14,12 +14,10 @@
 from collections import OrderedDict
 from PIL import ImageFile
 import cv2
-from torch.distributed import is_available, all_gather_object, get_world_size, is_initialized, get_rank
-from itertools import chain
+from torch.distributed import is_available, get_world_size, is_initialized, get_rank, gather_object
 import albumentations.augmentations.functional as F
 from typing import Union, List, Optional, Dict, Any
 from warnings import warn
-from skimage import img_as_float, img_as_ubyte
 
 
 def dict_collate_fn(batch, check_padding=True, img_min_ndim=2) -> Union[OrderedDict, None]:
@@ -271,42 +269,48 @@ def apply_model(img, models, trainer, mask=None, point_mask=None, crop_size=(768
         y = trainer.predict(model, data_loader)
 
         is_dist = is_available() and is_initialized()
-        if is_dist:
-            o = ([None] * get_world_size())
-            all_gather_object(obj=y, object_list=o)  # give every rank access to results
-            y = list(chain.from_iterable(o))
+        rank = get_rank()
+        ranks = get_world_size()
 
-        if (is_dist and get_rank() == 0) or not is_dist:
+        pre_results = {}
+        for y_idx, y_ in enumerate(y):
+
+            if y_ is None or y_ is ...:  # skip
+                continue
+
+            # Iterate batch
+            keeps = []
+            for n in range(len(y_['contours'])):
+                # Determine window position
+                h_i, w_i = np.unravel_index(y_['slice_idx'][n], tile_loader.num_slices_per_axis)
+
+                # Remove partial contours
+                top, bottom = h_i > 0, h_i < (h_tiles - 1)
+                right, left = w_i < (w_tiles - 1), w_i > 0
+                keep = cd.ops.cpn.remove_border_contours(y_['contours'][n], tile_loader.crop_size[:2],
+                                                         border_removal,
+                                                         top=top, right=right, bottom=bottom, left=left,
+                                                         offsets=-y_['offsets'][n])
+
+                if stitching_rule != 'nms':
+                    keep = cd.ops.filter_contours_by_stitching_rule(y_['contours'][n], tile_loader.crop_size[:2],
+                                                                    y_['overlaps'][n], rule=stitching_rule,
+                                                                    offsets=-y_['offsets'][n]) & keep
+
+                keeps.append(keep)
+            apply_keep_indices_(y_, keeps, ['offsets', 'overlaps'])
+            concat_results_(pre_results, y_)
+
+        if is_dist:
+            pre_results_ = [None] * ranks if rank == 0 else None
+            gather_object(pre_results, pre_results_, dst=0)
+            pre_results = pre_results_
+        if (is_dist and rank == 0) or not is_dist:
             results_ = {}
 
-            for y_idx, y_ in enumerate(y):
-
-                if y_ is None or y_ is ...:  # skip
-                    continue
-
-                # Iterate batch
-                keeps = []
-
-                for n in range(len(y_['contours'])):
-                    # Determine window position
-                    h_i, w_i = np.unravel_index(y_['slice_idx'][n], tile_loader.num_slices_per_axis)
-
-                    # Remove partial contours
-                    top, bottom = h_i > 0, h_i < (h_tiles - 1)
-                    right, left = w_i < (w_tiles - 1), w_i > 0
-                    keep = cd.ops.cpn.remove_border_contours(y_['contours'][n], tile_loader.crop_size[:2],
-                                                             border_removal,
-                                                             top=top, right=right, bottom=bottom, left=left,
-                                                             offsets=-y_['offsets'][n])
-
-                    if stitching_rule != 'nms':
-                        keep = cd.ops.filter_contours_by_stitching_rule(y_['contours'][n], tile_loader.crop_size[:2],
-                                                                        y_['overlaps'][n], rule=stitching_rule,
-                                                                        offsets=-y_['offsets'][n]) & keep
-
-                    keeps.append(keep)
-                apply_keep_indices_(y_, keeps, ['offsets', 'overlaps'])
-                concat_results_(results_, y_)
+            for r_idx, r in enumerate(pre_results):
+                assert isinstance(r, dict)
+                concat_results_flat_(results_, r)
 
             # Remove duplicates from tiling
             if 'nms' in stitching_rule.split(','):
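
Not part of the patch: a minimal, self-contained sketch of the `torch.distributed.gather_object` pattern the new code relies on, where only the destination rank allocates the output list and receives one object per rank. The `gloo` process-group setup and the payloads below are illustrative, not taken from `cpn_inference.py`.

```python
import os
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    # Illustrative single-node setup; real jobs usually get this from the launcher.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each rank produces its own partial result (stand-in for the per-tile dicts).
    pre_results = {"rank": rank, "payload": list(range(rank + 1))}

    # Only the destination rank provides the gather list; other ranks pass None.
    gathered = [None] * world_size if rank == 0 else None
    dist.gather_object(pre_results, gathered, dst=0)

    if rank == 0:
        # 'gathered' now holds one dict per rank, in rank order.
        print(gathered)

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size)
```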