plot_results() function in "cell_segmentation/inference/inference_cellvit_experiment_pannuke.py" #60

PingjunChen opened this issue Aug 28, 2024 · 0 comments
Describe the bug
The current plot_results() does not handle the attribute-style access on predictions/ground_truth and the tensor shapes correctly. Here is my updated version:

def plot_results(
    self,
    imgs: Union[torch.Tensor, np.ndarray],
    predictions: dict,
    ground_truth: dict,
    img_names: List,
    num_nuclei_classes: int,
    outdir: Union[Path, str],
    scores: List[List[float]] = None,
) -> None:
    """Generate example plot with image, binary_pred, hv-map and instance map from prediction and ground-truth

    Args:
        imgs (Union[torch.Tensor, np.ndarray]): Images to process
            Shape: (batch_size, 3, H', W')
        predictions (dict): Predictions of models. Keys:
            "nuclei_type_map": Shape: (batch_size, num_nuclei_classes, H', W')
            "nuclei_binary_map": Shape: (batch_size, 2, H', W')
            "hv_map": Shape: (batch_size, 2, H', W')
            "instance_map": Shape: (batch_size, H', W')
        ground_truth (dict): Ground truth values. Keys:
            "nuclei_type_map": Shape: (batch_size, num_nuclei_classes, H', W')
            "nuclei_binary_map": Shape: (batch_size, 2, H', W')
            "hv_map": Shape: (batch_size, 2, H', W')
            "instance_map": Shape: (batch_size, H', W')
        img_names (List): Names of images as list
        num_nuclei_classes (int): Number of total nuclei classes including background
        outdir (Union[Path, str]): Output directory where images should be stored
        scores (List[List[float]], optional): List with scores for each image.
            Each list entry is a list with 3 scores: Dice, Jaccard and bPQ for the image.
            Defaults to None.
    """
    outdir = Path(outdir)
    outdir.mkdir(exist_ok=True, parents=True)

    h = ground_truth.hv_map.shape[2]
    w = ground_truth.hv_map.shape[3]

    # convert to rgb and crop to selection
    sample_images = (
        imgs.permute(0, 2, 3, 1).contiguous().cpu().numpy()
    )  # convert to rgb
    sample_images = cropping_center(sample_images, (h, w), True)

    pred_sample_binary_map = (
        predictions.nuclei_binary_map[:, 1, :, :].detach().cpu().numpy()
    )
    pred_sample_hv_map = predictions.hv_map.detach().cpu().numpy()
    pred_sample_instance_maps = predictions.instance_map.detach().cpu().numpy()
    pred_sample_type_maps = (
        torch.argmax(predictions.nuclei_type_map, dim=1).detach().cpu().numpy()
    )

    gt_sample_binary_map = ground_truth.nuclei_binary_map.detach().cpu().numpy()
    gt_sample_hv_map = ground_truth.hv_map.detach().cpu().numpy()
    gt_sample_instance_map = ground_truth.instance_map.detach().cpu().numpy()
    gt_sample_type_map = (
        torch.argmax(ground_truth.nuclei_type_map, dim=1).detach().cpu().numpy()
    )

    # create colormaps
    hv_cmap = plt.get_cmap("jet")
    binary_cmap = plt.get_cmap("jet")
    instance_map = plt.get_cmap("viridis")
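    # outline colour per nucleus type id, used for the contour overlays below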
    cell_colors = ["#ffffff", "#ff0000", "#00ff00", "#1e00ff", "#feff00", "#ffbf00"]

    # invert the normalization of the sample images
    transform_settings = self.run_conf["transformations"]
    if "normalize" in transform_settings:
        mean = transform_settings["normalize"].get("mean", (0.5, 0.5, 0.5))
        std = transform_settings["normalize"].get("std", (0.5, 0.5, 0.5))
    else:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)
    # invert x_norm = (x - mean) / std  ->  x = x_norm * std + mean
    inv_normalize = transforms.Normalize(
        mean=[-mean[0] / std[0], -mean[1] / std[1], -mean[2] / std[2]],
        std=[1 / std[0], 1 / std[1], 1 / std[2]],
    )
    inv_samples = inv_normalize(torch.tensor(sample_images).permute(0, 3, 1, 2))
    sample_images = inv_samples.permute(0, 2, 3, 1).detach().cpu().numpy()

    for i in range(len(img_names)):
        fig, axs = plt.subplots(figsize=(6, 2), dpi=300)
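        # assemble a 2-row x 7-column tile grid: ground truth in the top row, prediction below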
        placeholder = np.zeros((2 * h, 7 * w, 3))
        # orig image
        placeholder[:h, :w, :3] = sample_images[i]
        placeholder[h : 2 * h, :w, :3] = sample_images[i]
        # binary nuclei maps (ground truth top, prediction bottom)
        placeholder[:h, w : 2 * w, :3] = rgba2rgb(
            binary_cmap(gt_sample_binary_map[i] * 255)
        )
        placeholder[h : 2 * h, w : 2 * w, :3] = rgba2rgb(
            binary_cmap(pred_sample_binary_map[i] * 255)
        )
        # hv maps
        placeholder[:h, 2 * w : 3 * w, :3] = rgba2rgb(
            hv_cmap((gt_sample_hv_map[i, 0, :, :] + 1) / 2)
        )
        placeholder[h : 2 * h, 2 * w : 3 * w, :3] = rgba2rgb(
            hv_cmap((pred_sample_hv_map[i, 0, :, :] + 1) / 2)
        )
        placeholder[:h, 3 * w : 4 * w, :3] = rgba2rgb(
            hv_cmap((gt_sample_hv_map[i, 1, :, :] + 1) / 2)
        )
        placeholder[h : 2 * h, 3 * w : 4 * w, :3] = rgba2rgb(
            hv_cmap((pred_sample_hv_map[i, 1, :, :] + 1) / 2)
        )
        # instance maps, min-max normalised for display (epsilon avoids division by zero)
        placeholder[:h, 4 * w : 5 * w, :3] = rgba2rgb(
            instance_map(
                (gt_sample_instance_map[i] - np.min(gt_sample_instance_map[i]))
                / (
                    np.max(gt_sample_instance_map[i])
                    - np.min(gt_sample_instance_map[i])
                    + 1e-10
                )
            )
        )
        placeholder[h : 2 * h, 4 * w : 5 * w, :3] = rgba2rgb(
            instance_map(
                (
                    pred_sample_instance_maps[i]
                    - np.min(pred_sample_instance_maps[i])
                )
                / (
                    np.max(pred_sample_instance_maps[i])
                    - np.min(pred_sample_instance_maps[i])
                    + 1e-10
                )
            )
        )
        # type_predictions
        placeholder[:h, 5 * w : 6 * w, :3] = rgba2rgb(
            binary_cmap(gt_sample_type_map[i] / num_nuclei_classes)
        )
        placeholder[h : 2 * h, 5 * w : 6 * w, :3] = rgba2rgb(
            binary_cmap(pred_sample_type_maps[i] / num_nuclei_classes)
        )

        # contours
        # gt
        gt_contours_polygon = [
            v["contour"] for v in ground_truth.instance_types[i].values()
        ]
        gt_contours_polygon = [
            list(zip(poly[:, 0], poly[:, 1])) for poly in gt_contours_polygon
        ]
        gt_contour_colors_polygon = [
            cell_colors[v["type"]]
            for v in ground_truth.instance_types[i].values()
        ]
        gt_cell_image = Image.fromarray(
            (sample_images[i] * 255).astype(np.uint8)
        ).convert("RGB")
        gt_drawing = ImageDraw.Draw(gt_cell_image)
        for poly, color in zip(gt_contours_polygon, gt_contour_colors_polygon):
            gt_drawing.polygon(poly, outline=color, width=2)
        gt_cell_image.save(outdir / f"raw_gt_{img_names[i]}")
        placeholder[:h, 6 * w : 7 * w, :3] = np.asarray(gt_cell_image) / 255
        # pred
        pred_contours_polygon = [
            v["contour"] for v in predictions.instance_types[i].values()
        ]
        pred_contours_polygon = [
            list(zip(poly[:, 0], poly[:, 1])) for poly in pred_contours_polygon
        ]
        pred_contour_colors_polygon = [
            cell_colors[v["type"]]
            for v in predictions.instance_types[i].values()
        ]
        pred_cell_image = Image.fromarray(
            (sample_images[i] * 255).astype(np.uint8)
        ).convert("RGB")
        pred_drawing = ImageDraw.Draw(pred_cell_image)
        for poly, color in zip(pred_contours_polygon, pred_contour_colors_polygon):
            pred_drawing.polygon(poly, outline=color, width=2)
        pred_cell_image.save(outdir / f"raw_pred_{img_names[i]}")
        placeholder[h : 2 * h, 6 * w : 7 * w, :3] = (
            np.asarray(pred_cell_image) / 255
        )

        # plotting
        axs.imshow(placeholder)
        axs.set_xticks(np.arange(w / 2, 7 * w, w))
        axs.set_xticklabels(
            [
                "Image",
                "Binary-Cells",
                "HV-Map-0",
                "HV-Map-1",
                "Instances",
                "Nuclei-Pred",
                "Countours",
            ],
            fontsize=6,
        )
        axs.xaxis.tick_top()

        axs.set_yticks(np.arange(h / 2, 2 * h, h))
        axs.set_yticklabels(["GT", "Pred."], fontsize=6)
        axs.tick_params(axis="both", which="both", length=0)
        grid_x = np.arange(w, 7 * w, w)  # one separator between each of the 7 columns
        grid_y = np.arange(h, 2 * h, h)

        for x_seg in grid_x:
            axs.axvline(x_seg, color="black")
        for y_seg in grid_y:
            axs.axhline(y_seg, color="black")

        if scores is not None:
            axs.text(
                20,
                1.85 * h,
                f"Dice: {str(np.round(scores[i][0], 2))}\nJac.: {str(np.round(scores[i][1], 2))}\nbPQ: {str(np.round(scores[i][2], 2))}",
                bbox={"facecolor": "white", "pad": 2, "alpha": 0.5},
                fontsize=4,
            )
        fig.suptitle(f"Patch Predictions for {img_names[i]}")
        fig.tight_layout()
        fig.savefig(outdir / f"pred_{img_names[i]}")
        plt.close()
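
For reference, the snippet assumes the imports that are already present in the inference script; roughly the following (cropping_center and rgba2rgb come from the repository's own utility modules, so I have not listed their import paths here):

from pathlib import Path
from typing import List, Union

import numpy as np
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
# plus the repo helpers used above: cropping_center, rgba2rgb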

With this version, the generated plots look reasonable on my test images.
[Example output plots: pred_4_3, pred_4_57]
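
For anyone who wants to try the patch, the call site is unchanged; a minimal sketch (the variable names are just placeholders for whatever the inference loop already provides):

self.plot_results(
    imgs=batch_images,          # (B, 3, H', W') image tensor
    predictions=predictions,    # needs nuclei_type_map, nuclei_binary_map, hv_map, instance_map, instance_types
    ground_truth=ground_truth,
    img_names=batch_names,
    num_nuclei_classes=6,       # e.g. PanNuke: 5 nucleus types + background
    outdir="./example_plots",
    scores=batch_scores,        # optional, [Dice, Jaccard, bPQ] per image
)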
