You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Describe the bug
The current `plot_results()` does not correctly handle the attributes of `predictions`/`ground_truth` or the shapes of their tensors. My updated version is below:
def plot_results(
    self,
    imgs: Union[torch.Tensor, np.ndarray],
    predictions: dict,
    ground_truth: dict,
    img_names: List,
    num_nuclei_classes: int,
    outdir: Union[Path, str],
    scores: List[List[float]] = None,
) -> None:
    """Generate one comparison plot per image (ground truth vs. prediction).

    Each figure is a 2 x 7 panel grid. The top row is the ground truth and the
    bottom row is the prediction. The columns are: Image, Binary-Cells,
    HV-Map-0, HV-Map-1, Instances, Nuclei-Pred and Contours.

    Files written to ``outdir``:
        - ``pred_{name}``: the figure.
        - ``raw_gt_{name}`` and ``raw_pred_{name}``: the contour overlays on
          their own.

    Args:
        imgs (Union[torch.Tensor, np.ndarray]): Normalized input images.
            Shape: (batch_size, 3, H', W')
        predictions (dict): Model predictions. NOTE(review): despite the
            ``dict`` annotation, fields are read as attributes (e.g. a
            dataclass). Attributes used:
            - ``nuclei_type_map``: (batch_size, num_nuclei_classes, H, W).
              Argmax over dim 1 is plotted.
            - ``nuclei_binary_map``: (batch_size, 2, H, W). Channel 1 is
              plotted.
            - ``hv_map``: (batch_size, 2, H, W).
            - ``instance_map``: (batch_size, H, W).
            - ``instance_types``: indexable per image. Each value is a mapping
              with ``"contour"`` (an (N, 2) array of point coordinates) and
              ``"type"`` (an int class index).
        ground_truth (dict): Ground truth. Same attributes as ``predictions``,
            except that ``nuclei_binary_map`` may be either (batch_size, H, W)
            or one-hot (batch_size, 2, H, W).
        img_names (List): Names of the images. They are also used as output
            file names.
        num_nuclei_classes (int): Number of nuclei classes including
            background. Used to scale the type maps into the colormap range.
        outdir (Union[Path, str]): Output directory; created if missing.
        scores (List[List[float]], optional): Per-image
            ``[Dice, Jaccard, bPQ]``, rendered as a text box. Defaults to None.
    """
    outdir = Path(outdir)
    outdir.mkdir(exist_ok=True, parents=True)

    # Panel size comes from the ground truth; images are center-cropped to it.
    h = ground_truth.hv_map.shape[2]
    w = ground_truth.hv_map.shape[3]

    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        # detach() so tensors that still require grad can be converted
        return tensor.detach().cpu().numpy()

    def _normalize_instances(instance_map: np.ndarray) -> np.ndarray:
        # Scale instance ids to [0, 1]; the epsilon avoids 0/0 on empty maps.
        low = np.min(instance_map)
        return (instance_map - low) / (np.max(instance_map) - low + 1e-10)

    cell_colors = ["#ffffff", "#ff0000", "#00ff00", "#1e00ff", "#feff00", "#ffbf00"]

    def _draw_contours(image: np.ndarray, instance_types):
        # Draw every instance contour on a copy of image, colored by class.
        canvas = Image.fromarray((image * 255).astype(np.uint8)).convert("RGB")
        drawing = ImageDraw.Draw(canvas)
        for instance in instance_types.values():
            contour = instance["contour"]
            points = list(zip(contour[:, 0], contour[:, 1]))
            # wrap around so more classes than palette entries do not raise
            color = cell_colors[instance["type"] % len(cell_colors)]
            drawing.polygon(points, outline=color, width=2)
        return canvas

    # NCHW -> NHWC; as_tensor also accepts the np.ndarray the type hint allows
    sample_images = _to_numpy(torch.as_tensor(imgs).permute(0, 2, 3, 1).contiguous())
    sample_images = np.asarray(cropping_center(sample_images, (h, w), True))

    pred_binary = _to_numpy(predictions.nuclei_binary_map[:, 1, :, :])
    pred_hv = _to_numpy(predictions.hv_map)
    pred_instances = _to_numpy(predictions.instance_map)
    pred_types = _to_numpy(torch.argmax(predictions.nuclei_type_map, dim=1))

    gt_binary = _to_numpy(ground_truth.nuclei_binary_map)
    if gt_binary.ndim == 4:
        # one-hot (B, 2, H, W) layout: keep the channel plotted for predictions
        gt_binary = gt_binary[:, 1]
    gt_hv = _to_numpy(ground_truth.hv_map)
    gt_instances = _to_numpy(ground_truth.instance_map)
    gt_types = _to_numpy(torch.argmax(ground_truth.nuclei_type_map, dim=1))

    hv_cmap = plt.get_cmap("jet")
    binary_cmap = plt.get_cmap("jet")
    instance_cmap = plt.get_cmap("viridis")

    # Undo Normalize(mean, std): x_norm = (x - mean) / std  =>  x = x_norm * std + mean
    normalize_settings = self.run_conf["transformations"].get("normalize") or {}
    mean = np.asarray(normalize_settings.get("mean", (0.5, 0.5, 0.5)))
    std = np.asarray(normalize_settings.get("std", (0.5, 0.5, 0.5)))
    sample_images = sample_images * std + mean
    # Clip so the uint8 conversion of the contour images cannot wrap around.
    sample_images = np.clip(sample_images, 0.0, 1.0)

    tick_labels = [
        "Image",
        "Binary-Cells",
        "HV-Map-0",
        "HV-Map-1",
        "Instances",
        "Nuclei-Pred",
        "Contours",
    ]
    n_cols = len(tick_labels)

    for i, img_name in enumerate(img_names):
        fig, ax = plt.subplots(figsize=(6, 2), dpi=300)
        placeholder = np.zeros((2 * h, n_cols * w, 3))

        rows = (
            (gt_binary, gt_hv, gt_instances, gt_types,
             ground_truth.instance_types, "raw_gt"),
            (pred_binary, pred_hv, pred_instances, pred_types,
             predictions.instance_types, "raw_pred"),
        )
        for row, (binary, hv, instances, types, instance_types, prefix) in enumerate(rows):
            contour_image = _draw_contours(sample_images[i], instance_types[i])
            contour_image.save(outdir / f"{prefix}_{img_name}")
            panels = [
                sample_images[i],
                # float input in [0, 1] so the colormap interpolates instead
                # of saturating (ints/bools would be used as LUT indices)
                rgba2rgb(binary_cmap(binary[i].astype(np.float64))),
                rgba2rgb(hv_cmap((hv[i, 0] + 1) / 2)),
                rgba2rgb(hv_cmap((hv[i, 1] + 1) / 2)),
                rgba2rgb(instance_cmap(_normalize_instances(instances[i]))),
                rgba2rgb(binary_cmap(types[i] / num_nuclei_classes)),
                np.asarray(contour_image) / 255,
            ]
            rows_slice = slice(row * h, (row + 1) * h)
            for col, panel in enumerate(panels):
                placeholder[rows_slice, col * w : (col + 1) * w, :3] = panel

        ax.imshow(placeholder)
        ax.set_xticks(np.arange(w / 2, n_cols * w, w))
        ax.set_xticklabels(tick_labels, fontsize=6)
        ax.xaxis.tick_top()
        ax.set_yticks(np.arange(h / 2, 2 * h, h))
        ax.set_yticklabels(["GT", "Pred."], fontsize=6)
        ax.tick_params(axis="both", which="both", length=0)
        # separators between every column (including before "Contours") and rows
        for x_seg in np.arange(w, n_cols * w, w):
            ax.axvline(x_seg, color="black")
        ax.axhline(h, color="black")

        if scores is not None:
            ax.text(
                20,
                1.85 * h,
                f"Dice: {str(np.round(scores[i][0], 2))}\nJac.: {str(np.round(scores[i][1], 2))}\nbPQ: {str(np.round(scores[i][2], 2))}",
                bbox={"facecolor": "white", "pad": 2, "alpha": 0.5},
                fontsize=4,
            )

        fig.suptitle(f"Patch Predictions for {img_name}")
        fig.tight_layout()
        fig.savefig(outdir / f"pred_{img_name}")
        plt.close(fig)
With this version, the generated plots look reasonable on my test images.
The text was updated successfully, but these errors were encountered:
Describe the bug
The current `plot_results()` does not correctly handle the attributes of `predictions`/`ground_truth` or the shapes of their tensors. My updated version is below:
With this version, the generated plots look reasonable on my test images.
The text was updated successfully, but these errors were encountered: