diff --git a/bnpm/plotting_helpers.py b/bnpm/plotting_helpers.py index 85fb88b..d63ad2a 100644 --- a/bnpm/plotting_helpers.py +++ b/bnpm/plotting_helpers.py @@ -101,8 +101,8 @@ def display_toggle_image_stack( RH 2023 Args: - images (Union[List[np.ndarray], List[torch.Tensor]]): - List of images as numpy arrays or PyTorch tensors. + images (List[np.ndarray]): + List of images as numpy arrays. image_size (Optional[Tuple[int, int]]): Tuple of *(width, height)* for resizing images.\n If ``None``, images are not resized.\n @@ -119,15 +119,15 @@ def display_toggle_image_stack( the Image.Resampling.* methods from PIL. (Default is 'nearest') """ from IPython.display import display, HTML - import numpy as np import base64 from PIL import Image from io import BytesIO - import torch import datetime import hashlib import sys + import numpy as np + # Get the image size for display if image_size is None: image_size = images[0].shape[:2] @@ -140,9 +140,6 @@ def display_toggle_image_stack( def normalize_image(image, clim=None): """Normalize the input image using the min-max scaling method. Optionally, use the given clim values for scaling.""" - if isinstance(image, torch.Tensor): - image = image.detach().cpu().numpy() - if clim is None: clim = (np.min(image), np.max(image)) @@ -151,9 +148,6 @@ def normalize_image(image, clim=None): return (norm_image * 255).astype(np.uint8) def resize_image(image, new_size, interpolation): """Resize the given image to the specified new size using the specified interpolation method.""" - if isinstance(image, torch.Tensor): - image = image.detach().cpu().numpy() - pil_image = Image.fromarray(image.astype(np.uint8)) resized_image = pil_image.resize(new_size, resample=interpolation) return np.array(resized_image)