From 6e472e3ed0d682668ab289f3b41282e983c2911c Mon Sep 17 00:00:00 2001 From: AlbertDominguez Date: Tue, 12 Nov 2024 14:09:19 +0100 Subject: [PATCH] fix flow correction in tiled inference, simplify intensity interpolation --- spotiflow/model/spotiflow.py | 17 +++--- spotiflow/utils/utils.py | 114 +++++++++++++++-------------------- 2 files changed, 57 insertions(+), 74 deletions(-) diff --git a/spotiflow/model/spotiflow.py b/spotiflow/model/spotiflow.py index 9d1bb0d..f19cfb3 100644 --- a/spotiflow/model/spotiflow.py +++ b/spotiflow/model/spotiflow.py @@ -24,7 +24,6 @@ from ..augmentations.pipeline import Pipeline as AugmentationPipeline from ..data import Spots3DDataset, SpotsDataset from ..utils import ( - bilinear_interp_points, center_crop, center_pad, filter_shape, @@ -34,8 +33,9 @@ normalize_dask, points_matching_dataset, prob_to_points, + spline_interp_points_2d, + spline_interp_points_3d, subpixel_offset, - trilinear_interp_points, estimate_params ) from ..utils import ( @@ -1037,7 +1037,7 @@ def predict( y_tile_sub = y_tile[s_src_corr[:actual_n_dims]] probs += y_tile_sub[tuple(p.astype(int).T)].tolist() - + p_flow = p + np.array([s.start for s in s_src_corr[:actual_n_dims]])[None] # add global offset p += np.array([s.start for s in s_dst_corr[:actual_n_dims]])[None] if not skip_details: @@ -1058,7 +1058,7 @@ def predict( # Cartesian coordinates subpix_tile = flow_to_vector(flow_tile, sigma=self.config.sigma) _offset = subpixel_offset( - p, subpix_tile, y_tile, radius=subpix_radius + p_flow, subpix_tile, y_tile, radius=subpix_radius ) p = p + _offset @@ -1119,14 +1119,11 @@ def predict( intens = img[tuple(pts.astype(int).T)] else: try: - intens = ( - bilinear_interp_points(img, pts) - if not self.config.is_3d - else trilinear_interp_points(img, pts) - ) + _interp_fun = spline_interp_points_2d if not self.config.is_3d else spline_interp_points_3d + intens = _interp_fun(img, pts) except Exception as _: log.warning( - "Bilinear interpolation failed to retrive spot intensities. Will use nearest neighbour interpolation instead." + "Spline interpolation failed to retrieve spot intensities. Will use nearest neighbour interpolation instead." ) intens = img[tuple(pts.round().astype(int).T)] details = SimpleNamespace( diff --git a/spotiflow/utils/utils.py b/spotiflow/utils/utils.py index 61dd7f2..bf7f49c 100644 --- a/spotiflow/utils/utils.py +++ b/spotiflow/utils/utils.py @@ -513,7 +513,8 @@ def subpixel_offset_2d( and subpix.shape[2] == 2 and prob.ndim == 2 ) - subpix = np.clip(subpix, -1, 1) + subpix_cp = subpix.copy() + subpix_cp[np.linalg.norm(subpix_cp, axis=-1) > np.sqrt(2)] = 0 n, _ = pts.shape _weight = np.zeros((n, 1), np.float32) _add = np.zeros((n, 2), np.float32) @@ -528,7 +529,7 @@ def subpixel_offset_2d( _w[mask] = prob[_p][:, None] _correct = np.zeros((n, 2), np.float32) - _correct[mask] = subpix[_p] + dp + _correct[mask] = subpix_cp[_p] + dp _weight += _w _add += _w * _correct @@ -564,7 +565,8 @@ def subpixel_offset_3d( and subpix.shape[3] == 3 and prob.ndim == 3 ) - subpix = np.clip(subpix, -1, 1) + subpix_cp = subpix.copy() + subpix_cp[np.linalg.norm(subpix, axis=-1) > np.sqrt(3)] = 0 n, _ = pts.shape _weight = np.zeros((n, 1), np.float32) _add = np.zeros((n, 3), np.float32) @@ -646,62 +648,54 @@ def read_npz_dataset(fname: Union[Path, str]) -> Tuple[np.ndarray, ...]: return ret_data -def bilinear_interp_points( - img: np.ndarray, pts: np.ndarray, eps: float = 1e-9 +def spline_interp_points_2d( + img: np.ndarray, pts: np.ndarray, order: int = 1 ) -> np.ndarray: - """Return the bilinearly interpolated iamge intensities at each (subpixel) location. + """Return the spline-interpolated 2D image intensities at each (subpixel) location. Args: img (np.ndarray): image in YX or YXC format. pts (np.ndarray): spot locations to interpolate the intensities from. Array shape should be (N,2). - eps (float, optional): will clip spot locations to SHAPE-eps to avoid numerical issues at image border. Defaults to 1e-9. + order (int, optional): order of the spline interpolation. Defaults to 1 (linear). Returns: np.ndarray: array of shape (N,C) containing intensities for each spot """ - assert img.ndim in (2, 3), "Expected YX or YXC image for interpolating intensities." + assert img.ndim in (2,3), "Expected YX or YXC image for interpolating intensities." assert ( pts.shape[1] == 2 ), "Point coordinates to be interpolated should be an (N,2) array" - if img.ndim == 2: - img = img[..., None] - if pts.shape[0] == 0: - return np.zeros((0, img.shape[-1]), dtype=img.dtype) - ys, xs = pts[:, 0], pts[:, 1] - - # Avoid out of bounds coordinates - ys.clip(0, img.shape[0] - 1 - eps, out=ys) - xs.clip(0, img.shape[1] - 1 - eps, out=xs) - - pys = np.floor(ys).astype(int) - pxs = np.floor(xs).astype(int) - - # Differences to floored coordinates - dys = ys - pys - dxs = xs - pxs - wxs, wys = 1.0 - dxs, 1.0 - dys + if img.ndim == 2: + out_shape = (0,) + else: + out_shape = (0, img.shape[-1]) + return np.zeros(out_shape, dtype=img.dtype) - # Interpolate - weights = np.multiply(img[pys, pxs, :].T, wxs * wys).T - weights += np.multiply(img[pys, pxs + 1, :].T, dxs * wys).T - weights += np.multiply(img[pys + 1, pxs, :].T, wxs * dys).T - weights += np.multiply(img[pys + 1, pxs + 1, :].T, dxs * dys).T - return weights + y_coords = pts[:,0] + x_coords = pts[:,1] + if img.ndim == 3: + intensities = np.stack([ + ndi.map_coordinates(img[..., c], [y_coords, x_coords], order=order, mode='reflect', prefilter=False) + for c in range(img.shape[2]) + ], axis=-1) + else: + intensities = ndi.map_coordinates(img, [y_coords, x_coords], order=order, mode='reflect', prefilter=False) + return intensities -def trilinear_interp_points( - img: np.ndarray, pts: np.ndarray, eps: float = 1e-9 +def spline_interp_points_3d( + img: np.ndarray, pts: np.ndarray, order: int = 1 ) -> np.ndarray: - """Return the trilinearly interpolated iamge intensities at each (subpixel) location. + """Return the spline-interpolated 3D image intensities at each (subpixel) location. Args: img (np.ndarray): image in ZYX or ZYXC format. pts (np.ndarray): spot locations to interpolate the intensities from. Array shape should be (N,3). - eps (float, optional): will clip spot locations to SHAPE-eps to avoid numerical issues at image border. Defaults to 1e-9. + order (int, optional): order of the spline interpolation. Defaults to 1 (linear). Returns: np.ndarray: array of shape (N,C) containing intensities for each spot @@ -714,35 +708,27 @@ def trilinear_interp_points( pts.shape[1] == 3 ), "Point coordinates to be interpolated should be an (N,3) array" - if img.ndim == 3: - img = img[..., None] - if pts.shape[0] == 0: - return np.zeros((0, img.shape[-1]), dtype=img.dtype) + if img.ndim == 3: + out_shape = (0,) + else: + out_shape = (0, img.shape[-1]) + return np.zeros(out_shape, dtype=img.dtype) + zs, ys, xs = pts[:, 0], pts[:, 1], pts[:, 2] - # Avoid out of bounds coordinates - zs.clip(0, img.shape[0] - 1 - eps, out=zs) - ys.clip(0, img.shape[1] - 1 - eps, out=ys) - xs.clip(0, img.shape[2] - 1 - eps, out=xs) - - pzs = np.floor(zs).astype(int) - pys = np.floor(ys).astype(int) - pxs = np.floor(xs).astype(int) - - # Differences to floored coordinates - dzs = zs - pzs - dys = ys - pys - dxs = xs - pxs - wzx, wzy, wys = 1.0 - dxs, 1.0 - dys, 1.0 - dzs - - # Interpolate - weights = np.multiply(img[pzs, pys, pxs, :].T, wzx * wzy * wys).T - weights += np.multiply(img[pzs, pys, pxs + 1, :].T, dxs * wzy * wys).T - weights += np.multiply(img[pzs, pys + 1, pxs, :].T, wzx * dys * wys).T - weights += np.multiply(img[pzs, pys + 1, pxs + 1, :].T, dxs * dys * wys).T - weights += np.multiply(img[pzs + 1, pys, pxs, :].T, wzx * wzy * dzs).T - weights += np.multiply(img[pzs + 1, pys, pxs + 1, :].T, dxs * wzy * dzs).T - weights += np.multiply(img[pzs + 1, pys + 1, pxs, :].T, wzx * dys * dzs).T - weights += np.multiply(img[pzs + 1, pys + 1, pxs + 1, :].T, dxs * dys * dzs).T - return weights + if img.ndim == 4: + intensities = np.stack( + [ + ndi.map_coordinates( + img[..., c], [zs, ys, xs], order=order, mode="reflect", prefilter=False + ) + for c in range(img.shape[3]) + ], + axis=-1, + ) + else: + intensities = ndi.map_coordinates( + img, [zs, ys, xs], order=1, mode="reflect", prefilter=False + ) + return intensities