Skip to content

Commit

Permalink
add outputs to CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
AlbertDominguez committed Oct 17, 2024
1 parent 9e28e88 commit d26b6b0
Showing 1 changed file with 31 additions and 7 deletions.
38 changes: 31 additions & 7 deletions spotiflow/cli/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
from itertools import chain
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from skimage.io import imread
from tqdm.auto import tqdm

from .. import __version__
from ..model import Spotiflow
from ..utils import infer_n_tiles, str2bool, write_coords_csv
from ..utils import estimate_fwhm, infer_n_tiles, str2bool

log = logging.getLogger(__name__)
log.setLevel(logging.INFO)
Expand Down Expand Up @@ -61,7 +63,7 @@ def get_args():
"--out-dir",
type=Path,
required=False,
default='spotiflow_results',
default="spotiflow_results",
help="Output directory to write the CSV(s). If not provided, will create a 'spotiflow_results' subfolder in the current folder.",
)

Expand Down Expand Up @@ -134,6 +136,12 @@ def get_args():
choices=["fast", "skimage"],
help="Peak detection mode (can be either 'skimage' or 'fast', which is a faster custom C++ implementation). Defaults to 'fast'.",
)
predict.add_argument(
"--estimate-fwhm",
type=str2bool,
default=False,
help="Estimate FWHM of detected spots by Gaussian fitting. Defaults to False.",
)
predict.add_argument(
"-norm",
"--normalizer",
Expand Down Expand Up @@ -257,22 +265,25 @@ def main():
img = _imread_wrapped(f)
if not _check_valid_input_shape(img.shape, model.config):
raise ValueError(
f"image {f} has invalid shape {img.shape} for model with is_3d={model.config.is_3d} and {model.config.in_channels} input channels"
f"image {f} has invalid shape {img.shape} for model with is_3d={model.config.is_3d} and {model.config.in_channels} input channels. The image shape should be either (Y,X,[C]) for a 2D model or (Z,Y,X,[C]) for a 3D model, where the [C] dimension is optional for single-channel inputs."
)
images.append(img)

for img, fname in tqdm(
zip(images, image_files), desc="Predicting", total=len(images)
):
if args.n_tiles is None:
n_tiles = infer_n_tiles(img.shape[:2], args.max_tile_size)
n_tiles = infer_n_tiles(
img.shape[:2] if not model.config.is_3d else img.shape[:3],
args.max_tile_size,
)
else:
n_tiles = tuple(args.n_tiles)

if args.verbose:
log.info(f"Predicting spots in {fname} with {n_tiles=}")

spots, _ = model.predict(
spots, details = model.predict(
img,
prob_thresh=args.probability_threshold,
n_tiles=n_tiles,
Expand All @@ -285,8 +296,21 @@ def main():
verbose=args.verbose,
device=args.device,
)
write_coords_csv(spots, out_dir / f"{fname.stem}.csv")

csv_columns = ("y", "x")
if spots.shape[1] == 3:
csv_columns = ("z",) + csv_columns
df = pd.DataFrame(np.round(spots, 4), columns=csv_columns)
df["intensity"] = np.round(details.intens, 2)
df["probability"] = np.round(details.prob, 3)
if args.estimate_fwhm:
if spots.shape[1] == 3:
log.warning(
"Estimating FWHM is not supported for 3D images yet. Skipping."
)
else:
fwhm = estimate_fwhm(img, spots)
df["fwhm"] = np.round(fwhm, 3)
df.to_csv(out_dir / f"{fname.stem}.csv", index=False)
return 0


Expand Down

0 comments on commit d26b6b0

Please sign in to comment.