Skip to content

Commit

Permalink
fix flow correction in tiled inference, simplify intensity interpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
AlbertDominguez committed Nov 12, 2024
1 parent 56dd759 commit 6e472e3
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 74 deletions.
17 changes: 7 additions & 10 deletions spotiflow/model/spotiflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from ..augmentations.pipeline import Pipeline as AugmentationPipeline
from ..data import Spots3DDataset, SpotsDataset
from ..utils import (
bilinear_interp_points,
center_crop,
center_pad,
filter_shape,
Expand All @@ -34,8 +33,9 @@
normalize_dask,
points_matching_dataset,
prob_to_points,
spline_interp_points_2d,
spline_interp_points_3d,
subpixel_offset,
trilinear_interp_points,
estimate_params
)
from ..utils import (
Expand Down Expand Up @@ -1037,7 +1037,7 @@ def predict(

y_tile_sub = y_tile[s_src_corr[:actual_n_dims]]
probs += y_tile_sub[tuple(p.astype(int).T)].tolist()

p_flow = p + np.array([s.start for s in s_src_corr[:actual_n_dims]])[None]
# add global offset
p += np.array([s.start for s in s_dst_corr[:actual_n_dims]])[None]
if not skip_details:
Expand All @@ -1058,7 +1058,7 @@ def predict(
# Cartesian coordinates
subpix_tile = flow_to_vector(flow_tile, sigma=self.config.sigma)
_offset = subpixel_offset(
p, subpix_tile, y_tile, radius=subpix_radius
p_flow, subpix_tile, y_tile, radius=subpix_radius
)

p = p + _offset
Expand Down Expand Up @@ -1119,14 +1119,11 @@ def predict(
intens = img[tuple(pts.astype(int).T)]
else:
try:
intens = (
bilinear_interp_points(img, pts)
if not self.config.is_3d
else trilinear_interp_points(img, pts)
)
_interp_fun = spline_interp_points_2d if not self.config.is_3d else spline_interp_points_3d
intens = _interp_fun(img, pts)
except Exception as _:
log.warning(
"Bilinear interpolation failed to retrive spot intensities. Will use nearest neighbour interpolation instead."
"Spline interpolation failed to retrieve spot intensities. Will use nearest neighbour interpolation instead."
)
intens = img[tuple(pts.round().astype(int).T)]
details = SimpleNamespace(
Expand Down
114 changes: 50 additions & 64 deletions spotiflow/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,8 @@ def subpixel_offset_2d(
and subpix.shape[2] == 2
and prob.ndim == 2
)
subpix = np.clip(subpix, -1, 1)
subpix_cp = subpix.copy()
subpix_cp[np.linalg.norm(subpix_cp, axis=-1) > np.sqrt(2)] = 0
n, _ = pts.shape
_weight = np.zeros((n, 1), np.float32)
_add = np.zeros((n, 2), np.float32)
Expand All @@ -528,7 +529,7 @@ def subpixel_offset_2d(
_w[mask] = prob[_p][:, None]

_correct = np.zeros((n, 2), np.float32)
_correct[mask] = subpix[_p] + dp
_correct[mask] = subpix_cp[_p] + dp

_weight += _w
_add += _w * _correct
Expand Down Expand Up @@ -564,7 +565,8 @@ def subpixel_offset_3d(
and subpix.shape[3] == 3
and prob.ndim == 3
)
subpix = np.clip(subpix, -1, 1)
subpix_cp = subpix.copy()
subpix_cp[np.linalg.norm(subpix, axis=-1) > np.sqrt(3)] = 0
n, _ = pts.shape
_weight = np.zeros((n, 1), np.float32)
_add = np.zeros((n, 3), np.float32)
Expand Down Expand Up @@ -646,62 +648,54 @@ def read_npz_dataset(fname: Union[Path, str]) -> Tuple[np.ndarray, ...]:
return ret_data


def bilinear_interp_points(
img: np.ndarray, pts: np.ndarray, eps: float = 1e-9
def spline_interp_points_2d(
img: np.ndarray, pts: np.ndarray, order: int = 1
) -> np.ndarray:
"""Return the bilinearly interpolated iamge intensities at each (subpixel) location.
"""Return the spline-interpolated 2D image intensities at each (subpixel) location.
Args:
img (np.ndarray): image in YX or YXC format.
pts (np.ndarray): spot locations to interpolate the intensities from. Array shape should be (N,2).
eps (float, optional): will clip spot locations to SHAPE-eps to avoid numerical issues at image border. Defaults to 1e-9.
order (int, optional): order of the spline interpolation. Defaults to 1 (linear).
Returns:
np.ndarray: array of shape (N,C) containing intensities for each spot
"""
assert img.ndim in (2, 3), "Expected YX or YXC image for interpolating intensities."
assert img.ndim in (2,3), "Expected YX or YXC image for interpolating intensities."
assert (
pts.shape[1] == 2
), "Point coordinates to be interpolated should be an (N,2) array"

if img.ndim == 2:
img = img[..., None]

if pts.shape[0] == 0:
return np.zeros((0, img.shape[-1]), dtype=img.dtype)
ys, xs = pts[:, 0], pts[:, 1]

# Avoid out of bounds coordinates
ys.clip(0, img.shape[0] - 1 - eps, out=ys)
xs.clip(0, img.shape[1] - 1 - eps, out=xs)

pys = np.floor(ys).astype(int)
pxs = np.floor(xs).astype(int)

# Differences to floored coordinates
dys = ys - pys
dxs = xs - pxs
wxs, wys = 1.0 - dxs, 1.0 - dys
if img.ndim == 2:
out_shape = (0,)
else:
out_shape = (0, img.shape[-1])
return np.zeros(out_shape, dtype=img.dtype)

# Interpolate
weights = np.multiply(img[pys, pxs, :].T, wxs * wys).T
weights += np.multiply(img[pys, pxs + 1, :].T, dxs * wys).T
weights += np.multiply(img[pys + 1, pxs, :].T, wxs * dys).T
weights += np.multiply(img[pys + 1, pxs + 1, :].T, dxs * dys).T
return weights
y_coords = pts[:,0]
x_coords = pts[:,1]
if img.ndim == 3:
intensities = np.stack([
ndi.map_coordinates(img[..., c], [y_coords, x_coords], order=order, mode='reflect', prefilter=False)
for c in range(img.shape[2])
], axis=-1)
else:
intensities = ndi.map_coordinates(img, [y_coords, x_coords], order=order, mode='reflect', prefilter=False)
return intensities


def trilinear_interp_points(
img: np.ndarray, pts: np.ndarray, eps: float = 1e-9
def spline_interp_points_3d(
img: np.ndarray, pts: np.ndarray, order: int = 1
) -> np.ndarray:
"""Return the trilinearly interpolated iamge intensities at each (subpixel) location.
"""Return the spline-interpolated 3D image intensities at each (subpixel) location.
Args:
img (np.ndarray): image in ZYX or ZYXC format.
pts (np.ndarray): spot locations to interpolate the intensities from. Array shape should be (N,3).
eps (float, optional): will clip spot locations to SHAPE-eps to avoid numerical issues at image border. Defaults to 1e-9.
order (int, optional): order of the spline interpolation. Defaults to 1 (linear).
Returns:
np.ndarray: array of shape (N,C) containing intensities for each spot
Expand All @@ -714,35 +708,27 @@ def trilinear_interp_points(
pts.shape[1] == 3
), "Point coordinates to be interpolated should be an (N,3) array"

if img.ndim == 3:
img = img[..., None]

if pts.shape[0] == 0:
return np.zeros((0, img.shape[-1]), dtype=img.dtype)
if img.ndim == 3:
out_shape = (0,)
else:
out_shape = (0, img.shape[-1])
return np.zeros(out_shape, dtype=img.dtype)

zs, ys, xs = pts[:, 0], pts[:, 1], pts[:, 2]

# Avoid out of bounds coordinates
zs.clip(0, img.shape[0] - 1 - eps, out=zs)
ys.clip(0, img.shape[1] - 1 - eps, out=ys)
xs.clip(0, img.shape[2] - 1 - eps, out=xs)

pzs = np.floor(zs).astype(int)
pys = np.floor(ys).astype(int)
pxs = np.floor(xs).astype(int)

# Differences to floored coordinates
dzs = zs - pzs
dys = ys - pys
dxs = xs - pxs
wzx, wzy, wys = 1.0 - dxs, 1.0 - dys, 1.0 - dzs

# Interpolate
weights = np.multiply(img[pzs, pys, pxs, :].T, wzx * wzy * wys).T
weights += np.multiply(img[pzs, pys, pxs + 1, :].T, dxs * wzy * wys).T
weights += np.multiply(img[pzs, pys + 1, pxs, :].T, wzx * dys * wys).T
weights += np.multiply(img[pzs, pys + 1, pxs + 1, :].T, dxs * dys * wys).T
weights += np.multiply(img[pzs + 1, pys, pxs, :].T, wzx * wzy * dzs).T
weights += np.multiply(img[pzs + 1, pys, pxs + 1, :].T, dxs * wzy * dzs).T
weights += np.multiply(img[pzs + 1, pys + 1, pxs, :].T, wzx * dys * dzs).T
weights += np.multiply(img[pzs + 1, pys + 1, pxs + 1, :].T, dxs * dys * dzs).T
return weights
if img.ndim == 4:
intensities = np.stack(
[
ndi.map_coordinates(
img[..., c], [zs, ys, xs], order=order, mode="reflect", prefilter=False
)
for c in range(img.shape[3])
],
axis=-1,
)
else:
intensities = ndi.map_coordinates(
img, [zs, ys, xs], order=1, mode="reflect", prefilter=False
)
return intensities

0 comments on commit 6e472e3

Please sign in to comment.