diff --git a/spotiflow/cli/predict.py b/spotiflow/cli/predict.py index 60bb2bc..3d93a92 100644 --- a/spotiflow/cli/predict.py +++ b/spotiflow/cli/predict.py @@ -12,7 +12,7 @@ from .. import __version__ from ..model import Spotiflow -from ..utils import estimate_params, infer_n_tiles, str2bool +from ..utils import infer_n_tiles, str2bool from ..utils.fitting import signal_to_background log = logging.getLogger(__name__) @@ -138,10 +138,10 @@ def get_args(): help="Peak detection mode (can be either 'skimage' or 'fast', which is a faster custom C++ implementation). Defaults to 'fast'.", ) predict.add_argument( - "--estimate-fwhm", + "--estimate-params", type=str2bool, default=False, - help="Estimate FWHM of detected spots by Gaussian fitting. Defaults to False.", + help="Estimate fit parameters of detected spots by Gaussian fitting (eg FWHM, intensity). Defaults to False.", ) predict.add_argument( "-norm", @@ -296,6 +296,7 @@ def main(): normalizer=args.normalizer, verbose=args.verbose, device=args.device, + fit_params=args.estimate_params, ) csv_columns = ("y", "x") if spots.shape[1] == 3: @@ -303,17 +304,11 @@ def main(): df = pd.DataFrame(np.round(spots, 4), columns=csv_columns) df["intensity"] = np.round(details.intens, 2) df["probability"] = np.round(details.prob, 3) - if args.estimate_fwhm: - if spots.shape[1] == 3: - log.warning( - "Estimating FWHM is not supported for 3D images yet. Skipping." - ) - else: - params = estimate_params(img, spots) - df['fwhm'] = np.round(params.fwhm, 3) - df['intens_A'] = np.round(params.intens_A, 3) - df['intens_B'] = np.round(params.intens_B, 3) - df['snb'] = np.round(signal_to_background(params), 3) + if args.estimate_params: + df['fwhm'] = np.round(details.fit_params.fwhm, 3) + df['intens_A'] = np.round(details.fit_params.intens_A, 3) + df['intens_B'] = np.round(details.fit_params.intens_B, 3) + df['snb'] = np.round(signal_to_background(details.fit_params), 3) df.to_csv(out_dir / f"{fname.stem}.csv", index=False) return 0 diff --git a/spotiflow/lib/spotflow3d.cpp b/spotiflow/lib/spotflow3d.cpp index 6c346ae..1b58262 100644 --- a/spotiflow/lib/spotflow3d.cpp +++ b/spotiflow/lib/spotflow3d.cpp @@ -160,11 +160,12 @@ static PyObject *c_gaussian3d(PyObject *self, PyObject *args) PyArrayObject *points = NULL; PyArrayObject *dst = NULL; + PyArrayObject *sigmas = NULL; + PyArrayObject *probs = NULL; int shape_z, shape_y, shape_x; int grid_z, grid_y, grid_x; - float sigma; - if (!PyArg_ParseTuple(args, "O!iiiiiif", &PyArray_Type, &points, &shape_z, &shape_y, &shape_x, &grid_z, &grid_y, &grid_x, &sigma)) + if (!PyArg_ParseTuple(args, "O!O!O!iiiiii", &PyArray_Type, &points, &PyArray_Type, &probs, &PyArray_Type, &sigmas, &shape_z, &shape_y, &shape_x, &grid_z, &grid_y, &grid_x)) return NULL; npy_intp *dims = PyArray_DIMS(points); @@ -204,8 +205,6 @@ static PyObject *c_gaussian3d(PyObject *self, PyObject *args) index.buildIndex(); - const float sigma_denom = 2 * sigma * sigma / cbrt(grid_z * grid_y * grid_x); - #ifdef __APPLE__ #pragma omp parallel for #else @@ -237,8 +236,12 @@ static PyObject *c_gaussian3d(PyObject *self, PyObject *args) const float r2 = x * x + y * y + z * z; + const float prob = *(float *)PyArray_GETPTR1(probs, ret_index); + const float sigma = *(float *)PyArray_GETPTR1(sigmas, ret_index); + const float sigma_denom = 2 * sigma * sigma / cbrt(grid_z * grid_y * grid_x); + // the gaussian value - const float val = exp(-r2 / sigma_denom); + const float val = prob * exp(-r2 / sigma_denom); *(float *)PyArray_GETPTR3(dst, i, j, k) = val; } diff --git a/spotiflow/model/spotiflow.py b/spotiflow/model/spotiflow.py index 470c2fb..29e727f 100644 --- a/spotiflow/model/spotiflow.py +++ b/spotiflow/model/spotiflow.py @@ -35,6 +35,7 @@ prob_to_points, subpixel_offset, trilinear_interp_points, + estimate_params ) from ..utils import ( tile_iterator as parallel_tile_iterator, @@ -714,6 +715,7 @@ def predict( Union[torch.device, Literal["auto", "cpu", "cuda", "mps"]] ] = None, distributed_params: Optional[dict] = None, + fit_params: bool = False, ) -> Tuple[np.ndarray, SimpleNamespace]: """Predict spots in an image. @@ -730,7 +732,7 @@ def predict( verbose (bool, optional): Whether to print logs and progress. Defaults to True. progress_bar_wrapper (Optional[callable], optional): Progress bar wrapper to use. Defaults to None. device (Optional[Union[torch.device, Literal["auto", "cpu", "cuda", "mps"]]], optional): computing device to use. If None, will infer from model location. If "auto", will infer from available hardware. Defaults to None. - + fit_params (bool, optional): Whether to fit the model parameters to the input image. Defaults to False. Returns: Tuple[np.ndarray, SimpleNamespace]: Tuple of (points, details). Points are the coordinates of the spots. Details is a namespace containing the spot-wise probabilities (`prob`), the heatmap (`heatmap`), the stereographic flow (`flow`), the 2D local offset vector field (`subpix`) and the spot intensities (`intens`). """ @@ -1098,6 +1100,11 @@ def predict( _subpix = None flow = None + if not skip_details and fit_params: + fit_params = estimate_params(img[...,0], pts) + else: + fit_params = None + if verbose: log.info(f"Found {len(pts)} spots") @@ -1119,8 +1126,9 @@ def predict( ) intens = img[tuple(pts.round().astype(int).T)] details = SimpleNamespace( - prob=probs, heatmap=y, subpix=_subpix, flow=flow, intens=intens - ) + prob=probs, heatmap=y, subpix=_subpix, flow=flow, intens=intens, + fit_params=fit_params + ) return pts, details def predict_dataset( diff --git a/spotiflow/utils/fitting.py b/spotiflow/utils/fitting.py index 424bcff..507d4eb 100644 --- a/spotiflow/utils/fitting.py +++ b/spotiflow/utils/fitting.py @@ -6,6 +6,7 @@ import numpy as np from scipy.optimize import curve_fit + from tqdm.auto import tqdm from dataclasses import dataclass @@ -25,17 +26,31 @@ def _gaussian_2d(yx, y0, x0, sigma, A, B): y, x = yx return A * np.exp(-((y - y0) ** 2 + (x - x0) ** 2) / (2 * sigma**2)) + B +def _gaussian_3d(zyx, z0, y0, x0, sigma, A, B): + z, y, x = zyx + return A * np.exp(-((z - z0) ** 2 + (y - y0) ** 2 + (x - x0) ** 2) / (2 * sigma**2)) + B @dataclass -class SpotParams: +class FitParams2D: + fwhm: Union[float, np.ndarray] + offset_y: Union[float, np.ndarray] + offset_x: Union[float, np.ndarray] + intens_A: Union[float, np.ndarray] + intens_B: Union[float, np.ndarray] + r_squared: Union[float, np.ndarray] + +@dataclass +class FitParams3D: fwhm: Union[float, np.ndarray] + offset_z: Union[float, np.ndarray] offset_y: Union[float, np.ndarray] offset_x: Union[float, np.ndarray] intens_A: Union[float, np.ndarray] intens_B: Union[float, np.ndarray] + r_squared: Union[float, np.ndarray] -def signal_to_background(params: SpotParams) -> np.ndarray: +def signal_to_background(params: FitParams2D) -> np.ndarray: """Calculates the signal to background ratio of the spots. Given a Gaussian fit of the form A*exp(...) + B, the signal to background ratio is computed as A/B. @@ -52,13 +67,35 @@ def signal_to_background(params: SpotParams) -> np.ndarray: return snb +def _r_squared(y_true, y_pred): + y_true, y_pred = np.array(y_true).ravel(), np.array(y_pred).ravel() + ss_res = np.sum((y_true - y_pred)**2) + ss_tot = np.sum((y_true - np.mean(y_true))**2) + r2 = 1 - (ss_res / ss_tot) + return r2 + def _estimate_params_single( center: np.ndarray, image: np.ndarray, window: int, refine_centers: bool, verbose: bool, -) -> SpotParams: +) -> Union[FitParams2D, FitParams3D]: + + if image.ndim == 2: + return _estimate_params_single2(center, image, window, refine_centers, verbose) + elif image.ndim == 3: + return _estimate_params_single3(center, image, window, refine_centers, verbose) + else: + raise ValueError("Image must have 2 or 3 dimensions") + +def _estimate_params_single2( + center: np.ndarray, + image: np.ndarray, + window: int, + refine_centers: bool, + verbose: bool, +) -> FitParams2D: x_range = np.arange(-window, window + 1) y_range = np.arange(-window, window + 1) y, x = np.meshgrid(y_range, x_range, indexing="ij") @@ -89,20 +126,81 @@ def _estimate_params_single( p0=initial_guess, bounds=(lower_bounds, upper_bounds), ) + + pred = _gaussian_2d((y.ravel(), x.ravel()), *popt) + r_squared = _r_squared(region.ravel(), pred) + except Exception as _: if verbose: log.warning("Gaussian fit failed. Returning NaN") mi, ma = np.nan, np.nan popt = np.full(5, np.nan) + r_squared = 0 - return SpotParams( + return FitParams2D( fwhm=FWHM_CONSTANT * popt[2], offset_y=popt[0], offset_x=popt[1], intens_A=(popt[3]+popt[4])*(ma - mi), intens_B=popt[4] * (ma - mi) + mi, + r_squared=r_squared ) +def _estimate_params_single3( + center: np.ndarray, + image: np.ndarray, + window: int, + refine_centers: bool, + verbose: bool, +) -> FitParams3D: + z,y,x = np.meshgrid(*((np.arange(-window, window + 1),)*3), indexing="ij") + + # Crop around the spot + region = image[ + center[0] - window : center[0] + window + 1, + center[1] - window : center[1] + window + 1, + center[2] - window : center[2] + window + 1, + ] + + try: + mi, ma = np.min(region), np.max(region) + region = (region - mi) / (ma - mi) + initial_guess = (0, 0, 0, 1.5, 1, 0) # z0, y0, x0, sigma, A, B + + if refine_centers: + lower_bounds = (-.5, -.5, -.5, 0.1, 0.5, -0.5) # y0, x0, sigma, A, B + upper_bounds = (.5, .5, .5, 10, 1.5, 0.5) # y0, x0, sigma, A, B + else: + lower_bounds = (-1e-6, -1e-6, -1e-6, 0.1, 0.5, -0.5) + upper_bounds = ( 1e-6, 1e-6, 1e-6, 10, 1.5, 0.5) + + popt, _ = curve_fit( + _gaussian_3d, + (z.ravel(), y.ravel(), x.ravel()), + region.ravel(), + p0=initial_guess, + bounds=(lower_bounds, upper_bounds), + ) + + pred = _gaussian_3d((z.ravel(), y.ravel(), x.ravel()), *popt) + r_squared = _r_squared(region.ravel(), pred) + + except Exception as _: + if verbose: + log.warning("Gaussian fit failed. Returning NaN") + mi, ma = np.nan, np.nan + popt = np.full(6, np.nan) + r_squared = 0 + + return FitParams3D( + fwhm=FWHM_CONSTANT * popt[3], + offset_z=popt[0], + offset_y=popt[1], + offset_x=popt[2], + intens_A=(popt[4]+popt[5])*(ma - mi), + intens_B=popt[5] * (ma - mi) + mi, + r_squared=r_squared + ) def estimate_params( img: np.ndarray, @@ -158,9 +256,17 @@ def estimate_params( ) ) - keys = SpotParams.__dataclass_fields__.keys() - - params = SpotParams( - **dict((k, np.array([getattr(p, k) for p in params])) for k in keys) - ) + if img.ndim == 2: + keys = FitParams2D.__dataclass_fields__.keys() + params = FitParams2D( + **dict((k, np.array([getattr(p, k) for p in params])) for k in keys) + ) + elif img.ndim == 3: + keys = FitParams3D.__dataclass_fields__.keys() + params = FitParams3D( + **dict((k, np.array([getattr(p, k) for p in params])) for k in keys) + ) + else: + raise ValueError("Image must have 2 or 3 dimensions") + return params diff --git a/spotiflow/utils/peaks.py b/spotiflow/utils/peaks.py index b36f6f5..6847153 100644 --- a/spotiflow/utils/peaks.py +++ b/spotiflow/utils/peaks.py @@ -1,5 +1,5 @@ from numbers import Number -from typing import Tuple, Union +from typing import Literal, Tuple, Union import numpy as np import numpy as np @@ -166,7 +166,10 @@ def points_to_prob(points, shape, sigma: Union[np.ndarray, float]=1.5, val:Union else: raise ValueError("Wrong dimension of points!") -def points_to_prob2d(points, shape, sigma: Union[np.ndarray, float]=1.5, val: Union[np.ndarray, float]=1., mode:str="max") -> np.ndarray: +def points_to_prob2d(points, shape, + sigma: Union[np.ndarray, float]=1.5, + val: Union[np.ndarray, float]=1., + mode:Literal["max","sum"]="max") -> np.ndarray: """ Create a 2D probability map from a set of points @@ -213,12 +216,21 @@ def points_to_prob2d(points, shape, sigma: Union[np.ndarray, float]=1.5, val: Un np.int32(shape[0]), np.int32(shape[1]), ) + elif mode == "sum": + x = np.zeros(shape, np.float32) + Y, X = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing="ij") + for p, s, v in zip(points, sigma, val): + x += v * np.exp(-((Y - p[0]) ** 2 + (X - p[1]) ** 2) / (2 * s ** 2)) else: raise ValueError(mode) return x -def points_to_prob3d(points, shape, sigma=1.5, mode="max", grid: Union[int, Tuple[int,int,int]]=None): +def points_to_prob3d(points, shape, + sigma: Union[np.ndarray, float]=1.5, + val: Union[np.ndarray, float]=1., + mode:Literal["max","sum"]="max", + grid: Union[int, Tuple[int,int,int]]=None): """points are in (z,y,x) order""" ndim=len(shape) @@ -233,18 +245,36 @@ def points_to_prob3d(points, shape, sigma=1.5, mode="max", grid: Union[int, Tupl if len(points) == 0: return x + + if isinstance(sigma, Number): + sigma = np.ones(len(points), np.float32) * sigma + else: + sigma = np.asarray(sigma, np.float32) + + if isinstance(val, Number): + val = np.ones(len(points), np.float32) * val + else: + val = np.asarray(val, np.float32) + if mode == "max": x = c_gaussian3d( points.astype(np.float32, copy=False), + val.astype(np.float32, copy=False), + sigma.astype(np.float32, copy=False), np.int32(shape[0]), np.int32(shape[1]), np.int32(shape[2]), np.int32(grid[0]), np.int32(grid[1]), np.int32(grid[2]), - np.float32(sigma), ) + elif mode == "sum": + x = np.zeros(shape, np.float32) + Xs = np.stack(np.meshgrid(*(np.arange(s) for s in shape), indexing="ij")) + for p, s, v in zip(points, sigma, val): + x += v * np.exp(- np.sum((Xs - p[:,None,None,None]) ** 2,axis=0) / (2 * s ** 2)) + else: raise ValueError(mode) diff --git a/tests/test_fit.py b/tests/test_fit.py new file mode 100644 index 0000000..9b18ef5 --- /dev/null +++ b/tests/test_fit.py @@ -0,0 +1,47 @@ + +import numpy as np +from spotiflow.utils import points_to_prob, estimate_params +from spotiflow.model import Spotiflow + +def test_fit2d(): + np.random.seed(42) + + n_points=64 + points = np.random.randint(20,245-20, (n_points,2)) + sigmas = np.random.uniform(1, 5, n_points) + + x = points_to_prob(points, (256,256), sigma=sigmas, mode='sum') + + x += .2+0.05*np.random.normal(0, 1, x.shape) + + params = estimate_params(x, points) + + return x, sigmas, params + +def test_fit3d(): + + np.random.seed(42) + ndim=3 + + n_points=64 + points = np.random.randint(20,128-20, (n_points,ndim)) + sigmas = np.random.uniform(1, 5, n_points) + + x = points_to_prob(points, (128,)*ndim, sigma=sigmas, mode='sum') + + x += .2+0.05*np.random.normal(0, 1, x.shape) + + params = estimate_params(x, points) + return x, sigmas, params + +if __name__ == "__main__": + + + x, sigmas, params = test_fit3d() + + + model = Spotiflow.from_pretrained("synth_3d") + + img = np.clip(200*x, 0,255).astype(np.uint8) + + points, details = model.predict(img, fit_params=True) \ No newline at end of file