Skip to content

Commit

Permalink
Update distributed processing
Browse files Browse the repository at this point in the history
  • Loading branch information
ericup committed Jul 15, 2024
1 parent 4b0613e commit 02fafb3
Showing 1 changed file with 40 additions and 36 deletions.
76 changes: 40 additions & 36 deletions celldetection_scripts/cpn_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@
from collections import OrderedDict
from PIL import ImageFile
import cv2
from torch.distributed import is_available, all_gather_object, get_world_size, is_initialized, get_rank
from itertools import chain
from torch.distributed import is_available, get_world_size, is_initialized, get_rank, gather_object
import albumentations.augmentations.functional as F
from typing import Union, List, Optional, Dict, Any
from warnings import warn
from skimage import img_as_float, img_as_ubyte


def dict_collate_fn(batch, check_padding=True, img_min_ndim=2) -> Union[OrderedDict, None]:
Expand Down Expand Up @@ -271,42 +269,48 @@ def apply_model(img, models, trainer, mask=None, point_mask=None, crop_size=(768
y = trainer.predict(model, data_loader)

is_dist = is_available() and is_initialized()
if is_dist:
o = ([None] * get_world_size())
all_gather_object(obj=y, object_list=o) # give every rank access to results
y = list(chain.from_iterable(o))
rank = get_rank()
ranks = get_world_size()

if (is_dist and get_rank() == 0) or not is_dist:
pre_results = {}
for y_idx, y_ in enumerate(y):

if y_ is None or y_ is ...: # skip
continue

# Iterate batch
keeps = []
for n in range(len(y_['contours'])):
# Determine window position
h_i, w_i = np.unravel_index(y_['slice_idx'][n], tile_loader.num_slices_per_axis)

# Remove partial contours
top, bottom = h_i > 0, h_i < (h_tiles - 1)
right, left = w_i < (w_tiles - 1), w_i > 0
keep = cd.ops.cpn.remove_border_contours(y_['contours'][n], tile_loader.crop_size[:2],
border_removal,
top=top, right=right, bottom=bottom, left=left,
offsets=-y_['offsets'][n])

if stitching_rule != 'nms':
keep = cd.ops.filter_contours_by_stitching_rule(y_['contours'][n], tile_loader.crop_size[:2],
y_['overlaps'][n], rule=stitching_rule,
offsets=-y_['offsets'][n]) & keep

keeps.append(keep)
apply_keep_indices_(y_, keeps, ['offsets', 'overlaps'])
concat_results_(pre_results, y_)

if is_dist:
pre_results_ = [None] * ranks if rank == 0 else None
gather_object(pre_results, pre_results_, dst=0)
pre_results = pre_results_

if (is_dist and rank == 0) or not is_dist:
results_ = {}
for y_idx, y_ in enumerate(y):

if y_ is None or y_ is ...: # skip
continue

# Iterate batch
keeps = []

for n in range(len(y_['contours'])):
# Determine window position
h_i, w_i = np.unravel_index(y_['slice_idx'][n], tile_loader.num_slices_per_axis)

# Remove partial contours
top, bottom = h_i > 0, h_i < (h_tiles - 1)
right, left = w_i < (w_tiles - 1), w_i > 0
keep = cd.ops.cpn.remove_border_contours(y_['contours'][n], tile_loader.crop_size[:2],
border_removal,
top=top, right=right, bottom=bottom, left=left,
offsets=-y_['offsets'][n])

if stitching_rule != 'nms':
keep = cd.ops.filter_contours_by_stitching_rule(y_['contours'][n], tile_loader.crop_size[:2],
y_['overlaps'][n], rule=stitching_rule,
offsets=-y_['offsets'][n]) & keep

keeps.append(keep)
apply_keep_indices_(y_, keeps, ['offsets', 'overlaps'])
concat_results_(results_, y_)
for r_idx, r in enumerate(pre_results):
assert isinstance(r, dict)
concat_results_flat_(results_, r)

# Remove duplicates from tiling
if 'nms' in stitching_rule.split(','):
Expand Down

0 comments on commit 02fafb3

Please sign in to comment.