diff --git a/spotiflow/cli/predict.py b/spotiflow/cli/predict.py index fcb492b..88847b4 100644 --- a/spotiflow/cli/predict.py +++ b/spotiflow/cli/predict.py @@ -4,13 +4,15 @@ from itertools import chain from pathlib import Path +import numpy as np +import pandas as pd import torch from skimage.io import imread from tqdm.auto import tqdm from .. import __version__ from ..model import Spotiflow -from ..utils import infer_n_tiles, str2bool, write_coords_csv +from ..utils import estimate_fwhm, infer_n_tiles, str2bool log = logging.getLogger(__name__) log.setLevel(logging.INFO) @@ -61,7 +63,7 @@ def get_args(): "--out-dir", type=Path, required=False, - default='spotiflow_results', + default="spotiflow_results", help="Output directory to write the CSV(s). If not provided, will create a 'spotiflow_results' subfolder in the current folder.", ) @@ -134,6 +136,12 @@ def get_args(): choices=["fast", "skimage"], help="Peak detection mode (can be either 'skimage' or 'fast', which is a faster custom C++ implementation). Defaults to 'fast'.", ) + predict.add_argument( + "--estimate-fwhm", + type=str2bool, + default=False, + help="Estimate FWHM of detected spots by Gaussian fitting. Defaults to False.", + ) predict.add_argument( "-norm", "--normalizer", @@ -257,7 +265,7 @@ def main(): img = _imread_wrapped(f) if not _check_valid_input_shape(img.shape, model.config): raise ValueError( - f"image {f} has invalid shape {img.shape} for model with is_3d={model.config.is_3d} and {model.config.in_channels} input channels" + f"image {f} has invalid shape {img.shape} for model with is_3d={model.config.is_3d} and {model.config.in_channels} input channels. The image shape should be either (Y,X,[C]) for a 2D model or (Z,Y,X,[C]) for a 3D model, where the [C] dimension is optional for single-channel inputs." ) images.append(img) @@ -265,14 +273,17 @@ def main(): zip(images, image_files), desc="Predicting", total=len(images) ): if args.n_tiles is None: - n_tiles = infer_n_tiles(img.shape[:2], args.max_tile_size) + n_tiles = infer_n_tiles( + img.shape[:2] if not model.config.is_3d else img.shape[:3], + args.max_tile_size, + ) else: n_tiles = tuple(args.n_tiles) if args.verbose: log.info(f"Predicting spots in {fname} with {n_tiles=}") - spots, _ = model.predict( + spots, details = model.predict( img, prob_thresh=args.probability_threshold, n_tiles=n_tiles, @@ -285,8 +296,21 @@ def main(): verbose=args.verbose, device=args.device, ) - write_coords_csv(spots, out_dir / f"{fname.stem}.csv") - + csv_columns = ("y", "x") + if spots.shape[1] == 3: + csv_columns = ("z",) + csv_columns + df = pd.DataFrame(np.round(spots, 4), columns=csv_columns) + df["intensity"] = np.round(details.intens, 2) + df["probability"] = np.round(details.prob, 3) + if args.estimate_fwhm: + if spots.shape[1] == 3: + log.warning( + "Estimating FWHM is not supported for 3D images yet. Skipping." + ) + else: + fwhm = estimate_fwhm(img, spots) + df["fwhm"] = np.round(fwhm, 3) + df.to_csv(out_dir / f"{fname.stem}.csv", index=False) return 0