Skip to content

Commit

Permalink
extend fitting function to 3D and add to model predict
Browse files Browse the repository at this point in the history
  • Loading branch information
maweigert authored and AlbertDominguez committed Oct 24, 2024
1 parent 11b1291 commit 58d0fcc
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 35 deletions.
23 changes: 9 additions & 14 deletions spotiflow/cli/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .. import __version__
from ..model import Spotiflow
from ..utils import estimate_params, infer_n_tiles, str2bool
from ..utils import infer_n_tiles, str2bool
from ..utils.fitting import signal_to_background

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -138,10 +138,10 @@ def get_args():
help="Peak detection mode (can be either 'skimage' or 'fast', which is a faster custom C++ implementation). Defaults to 'fast'.",
)
predict.add_argument(
"--estimate-fwhm",
"--estimate-params",
type=str2bool,
default=False,
help="Estimate FWHM of detected spots by Gaussian fitting. Defaults to False.",
help="Estimate fit parameters of detected spots by Gaussian fitting (eg FWHM, intensity). Defaults to False.",
)
predict.add_argument(
"-norm",
Expand Down Expand Up @@ -296,24 +296,19 @@ def main():
normalizer=args.normalizer,
verbose=args.verbose,
device=args.device,
fit_params=args.estimate_params,
)
csv_columns = ("y", "x")
if spots.shape[1] == 3:
csv_columns = ("z",) + csv_columns
df = pd.DataFrame(np.round(spots, 4), columns=csv_columns)
df["intensity"] = np.round(details.intens, 2)
df["probability"] = np.round(details.prob, 3)
if args.estimate_fwhm:
if spots.shape[1] == 3:
log.warning(
"Estimating FWHM is not supported for 3D images yet. Skipping."
)
else:
params = estimate_params(img, spots)
df['fwhm'] = np.round(params.fwhm, 3)
df['intens_A'] = np.round(params.intens_A, 3)
df['intens_B'] = np.round(params.intens_B, 3)
df['snb'] = np.round(signal_to_background(params), 3)
if args.estimate_params:
df['fwhm'] = np.round(details.fit_params.fwhm, 3)
df['intens_A'] = np.round(details.fit_params.intens_A, 3)
df['intens_B'] = np.round(details.fit_params.intens_B, 3)
df['snb'] = np.round(signal_to_background(details.fit_params), 3)

df.to_csv(out_dir / f"{fname.stem}.csv", index=False)
return 0
Expand Down
13 changes: 8 additions & 5 deletions spotiflow/lib/spotflow3d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,12 @@ static PyObject *c_gaussian3d(PyObject *self, PyObject *args)

PyArrayObject *points = NULL;
PyArrayObject *dst = NULL;
PyArrayObject *sigmas = NULL;
PyArrayObject *probs = NULL;
int shape_z, shape_y, shape_x;
int grid_z, grid_y, grid_x;
float sigma;

if (!PyArg_ParseTuple(args, "O!iiiiiif", &PyArray_Type, &points, &shape_z, &shape_y, &shape_x, &grid_z, &grid_y, &grid_x, &sigma))
if (!PyArg_ParseTuple(args, "O!O!O!iiiiii", &PyArray_Type, &points, &PyArray_Type, &probs, &PyArray_Type, &sigmas, &shape_z, &shape_y, &shape_x, &grid_z, &grid_y, &grid_x))
return NULL;

npy_intp *dims = PyArray_DIMS(points);
Expand Down Expand Up @@ -204,8 +205,6 @@ static PyObject *c_gaussian3d(PyObject *self, PyObject *args)
index.buildIndex();


const float sigma_denom = 2 * sigma * sigma / cbrt(grid_z * grid_y * grid_x);

#ifdef __APPLE__
#pragma omp parallel for
#else
Expand Down Expand Up @@ -237,8 +236,12 @@ static PyObject *c_gaussian3d(PyObject *self, PyObject *args)

const float r2 = x * x + y * y + z * z;

const float prob = *(float *)PyArray_GETPTR1(probs, ret_index);
const float sigma = *(float *)PyArray_GETPTR1(sigmas, ret_index);
const float sigma_denom = 2 * sigma * sigma / cbrt(grid_z * grid_y * grid_x);

// the gaussian value
const float val = exp(-r2 / sigma_denom);
const float val = prob * exp(-r2 / sigma_denom);

*(float *)PyArray_GETPTR3(dst, i, j, k) = val;
}
Expand Down
14 changes: 11 additions & 3 deletions spotiflow/model/spotiflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
prob_to_points,
subpixel_offset,
trilinear_interp_points,
estimate_params
)
from ..utils import (
tile_iterator as parallel_tile_iterator,
Expand Down Expand Up @@ -714,6 +715,7 @@ def predict(
Union[torch.device, Literal["auto", "cpu", "cuda", "mps"]]
] = None,
distributed_params: Optional[dict] = None,
fit_params: bool = False,
) -> Tuple[np.ndarray, SimpleNamespace]:
"""Predict spots in an image.
Expand All @@ -730,7 +732,7 @@ def predict(
verbose (bool, optional): Whether to print logs and progress. Defaults to True.
progress_bar_wrapper (Optional[callable], optional): Progress bar wrapper to use. Defaults to None.
device (Optional[Union[torch.device, Literal["auto", "cpu", "cuda", "mps"]]], optional): computing device to use. If None, will infer from model location. If "auto", will infer from available hardware. Defaults to None.
fit_params (bool, optional): Whether to fit the model parameters to the input image. Defaults to False.
Returns:
Tuple[np.ndarray, SimpleNamespace]: Tuple of (points, details). Points are the coordinates of the spots. Details is a namespace containing the spot-wise probabilities (`prob`), the heatmap (`heatmap`), the stereographic flow (`flow`), the 2D local offset vector field (`subpix`) and the spot intensities (`intens`).
"""
Expand Down Expand Up @@ -1098,6 +1100,11 @@ def predict(
_subpix = None
flow = None

if not skip_details and fit_params:
fit_params = estimate_params(img[...,0], pts)
else:
fit_params = None

if verbose:
log.info(f"Found {len(pts)} spots")

Expand All @@ -1119,8 +1126,9 @@ def predict(
)
intens = img[tuple(pts.round().astype(int).T)]
details = SimpleNamespace(
prob=probs, heatmap=y, subpix=_subpix, flow=flow, intens=intens
)
prob=probs, heatmap=y, subpix=_subpix, flow=flow, intens=intens,
fit_params=fit_params
)
return pts, details

def predict_dataset(
Expand Down
124 changes: 115 additions & 9 deletions spotiflow/utils/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
from scipy.optimize import curve_fit

from tqdm.auto import tqdm
from dataclasses import dataclass

Expand All @@ -25,17 +26,31 @@ def _gaussian_2d(yx, y0, x0, sigma, A, B):
y, x = yx
return A * np.exp(-((y - y0) ** 2 + (x - x0) ** 2) / (2 * sigma**2)) + B

def _gaussian_3d(zyx, z0, y0, x0, sigma, A, B):
z, y, x = zyx
return A * np.exp(-((z - z0) ** 2 + (y - y0) ** 2 + (x - x0) ** 2) / (2 * sigma**2)) + B

@dataclass
class SpotParams:
class FitParams2D:
fwhm: Union[float, np.ndarray]
offset_y: Union[float, np.ndarray]
offset_x: Union[float, np.ndarray]
intens_A: Union[float, np.ndarray]
intens_B: Union[float, np.ndarray]
r_squared: Union[float, np.ndarray]

@dataclass
class FitParams3D:
fwhm: Union[float, np.ndarray]
offset_z: Union[float, np.ndarray]
offset_y: Union[float, np.ndarray]
offset_x: Union[float, np.ndarray]
intens_A: Union[float, np.ndarray]
intens_B: Union[float, np.ndarray]
r_squared: Union[float, np.ndarray]


def signal_to_background(params: SpotParams) -> np.ndarray:
def signal_to_background(params: FitParams2D) -> np.ndarray:
"""Calculates the signal to background ratio of the spots. Given a Gaussian fit
of the form A*exp(...) + B, the signal to background
ratio is computed as A/B.
Expand All @@ -52,13 +67,35 @@ def signal_to_background(params: SpotParams) -> np.ndarray:
return snb


def _r_squared(y_true, y_pred):
y_true, y_pred = np.array(y_true).ravel(), np.array(y_pred).ravel()
ss_res = np.sum((y_true - y_pred)**2)
ss_tot = np.sum((y_true - np.mean(y_true))**2)
r2 = 1 - (ss_res / ss_tot)
return r2

def _estimate_params_single(
center: np.ndarray,
image: np.ndarray,
window: int,
refine_centers: bool,
verbose: bool,
) -> SpotParams:
) -> Union[FitParams2D, FitParams3D]:

if image.ndim == 2:
return _estimate_params_single2(center, image, window, refine_centers, verbose)
elif image.ndim == 3:
return _estimate_params_single3(center, image, window, refine_centers, verbose)
else:
raise ValueError("Image must have 2 or 3 dimensions")

def _estimate_params_single2(
center: np.ndarray,
image: np.ndarray,
window: int,
refine_centers: bool,
verbose: bool,
) -> FitParams2D:
x_range = np.arange(-window, window + 1)
y_range = np.arange(-window, window + 1)
y, x = np.meshgrid(y_range, x_range, indexing="ij")
Expand Down Expand Up @@ -89,20 +126,81 @@ def _estimate_params_single(
p0=initial_guess,
bounds=(lower_bounds, upper_bounds),
)

pred = _gaussian_2d((y.ravel(), x.ravel()), *popt)
r_squared = _r_squared(region.ravel(), pred)

except Exception as _:
if verbose:
log.warning("Gaussian fit failed. Returning NaN")
mi, ma = np.nan, np.nan
popt = np.full(5, np.nan)
r_squared = 0

return SpotParams(
return FitParams2D(
fwhm=FWHM_CONSTANT * popt[2],
offset_y=popt[0],
offset_x=popt[1],
intens_A=(popt[3]+popt[4])*(ma - mi),
intens_B=popt[4] * (ma - mi) + mi,
r_squared=r_squared
)

def _estimate_params_single3(
center: np.ndarray,
image: np.ndarray,
window: int,
refine_centers: bool,
verbose: bool,
) -> FitParams3D:
z,y,x = np.meshgrid(*((np.arange(-window, window + 1),)*3), indexing="ij")

# Crop around the spot
region = image[
center[0] - window : center[0] + window + 1,
center[1] - window : center[1] + window + 1,
center[2] - window : center[2] + window + 1,
]

try:
mi, ma = np.min(region), np.max(region)
region = (region - mi) / (ma - mi)
initial_guess = (0, 0, 0, 1.5, 1, 0) # z0, y0, x0, sigma, A, B

if refine_centers:
lower_bounds = (-.5, -.5, -.5, 0.1, 0.5, -0.5) # y0, x0, sigma, A, B
upper_bounds = (.5, .5, .5, 10, 1.5, 0.5) # y0, x0, sigma, A, B
else:
lower_bounds = (-1e-6, -1e-6, -1e-6, 0.1, 0.5, -0.5)
upper_bounds = ( 1e-6, 1e-6, 1e-6, 10, 1.5, 0.5)

popt, _ = curve_fit(
_gaussian_3d,
(z.ravel(), y.ravel(), x.ravel()),
region.ravel(),
p0=initial_guess,
bounds=(lower_bounds, upper_bounds),
)

pred = _gaussian_3d((z.ravel(), y.ravel(), x.ravel()), *popt)
r_squared = _r_squared(region.ravel(), pred)

except Exception as _:
if verbose:
log.warning("Gaussian fit failed. Returning NaN")
mi, ma = np.nan, np.nan
popt = np.full(6, np.nan)
r_squared = 0

return FitParams3D(
fwhm=FWHM_CONSTANT * popt[3],
offset_z=popt[0],
offset_y=popt[1],
offset_x=popt[2],
intens_A=(popt[4]+popt[5])*(ma - mi),
intens_B=popt[5] * (ma - mi) + mi,
r_squared=r_squared
)

def estimate_params(
img: np.ndarray,
Expand Down Expand Up @@ -158,9 +256,17 @@ def estimate_params(
)
)

keys = SpotParams.__dataclass_fields__.keys()

params = SpotParams(
**dict((k, np.array([getattr(p, k) for p in params])) for k in keys)
)
if img.ndim == 2:
keys = FitParams2D.__dataclass_fields__.keys()
params = FitParams2D(
**dict((k, np.array([getattr(p, k) for p in params])) for k in keys)
)
elif img.ndim == 3:
keys = FitParams3D.__dataclass_fields__.keys()
params = FitParams3D(
**dict((k, np.array([getattr(p, k) for p in params])) for k in keys)
)
else:
raise ValueError("Image must have 2 or 3 dimensions")

return params
Loading

0 comments on commit 58d0fcc

Please sign in to comment.