From 49146db5cfd4220e7158d07c11015497804ef60c Mon Sep 17 00:00:00 2001 From: hanjinliu Date: Sun, 9 Jan 2022 16:58:53 +0900 Subject: [PATCH] typing --- impy/arrays/bases/overload.py | 22 ++++----- impy/arrays/imgarray.py | 93 ++++++++++++++++++----------------- impy/correlation.py | 68 +++++++++++++++++++------ 3 files changed, 111 insertions(+), 72 deletions(-) diff --git a/impy/arrays/bases/overload.py b/impy/arrays/bases/overload.py index 8d43980d..63f48999 100644 --- a/impy/arrays/bases/overload.py +++ b/impy/arrays/bases/overload.py @@ -6,7 +6,7 @@ from ...utils.axesop import del_axis from ...collections import DataList -def safe_set_info(out, img, history, new_axes): +def safe_set_info(out: MetaArray, img: MetaArray, history, new_axes): if isinstance(img, HistoryArray): out._set_info(img, history, new_axes=new_axes) else: @@ -22,14 +22,14 @@ def safe_set_info(out, img, history, new_axes): @MetaArray.implements(np.squeeze) -def _(img:MetaArray): +def _(img: MetaArray): out = np.squeeze(img.value).view(img.__class__) new_axes = "".join(a for a in img.axes if img.sizeof(a) > 1) safe_set_info(out, img, "squeeze", new_axes) return out @MetaArray.implements(np.take) -def _(a:MetaArray, indices, axis=None, out=None, mode="raise"): +def _(a: MetaArray, indices, axis=None, out=None, mode="raise"): new_axes = del_axis(a.axes, axis) if isinstance(axis, str): axis = a.axes.find(axis) @@ -39,7 +39,7 @@ def _(a:MetaArray, indices, axis=None, out=None, mode="raise"): return out @MetaArray.implements(np.stack) -def _(imgs:list[MetaArray], axis="c", dtype=None): +def _(imgs: list[MetaArray], axis="c", dtype=None): """ Create stack image from list of images. @@ -97,7 +97,7 @@ def _(imgs:list[MetaArray], axis="c", dtype=None): return out @MetaArray.implements(np.concatenate) -def _(imgs, axis="c", dtype=None, casting="same_kind"): +def _(imgs: list[MetaArray], axis="c", dtype=None, casting="same_kind"): if not isinstance(axis, (int, str)): raise TypeError(f"`axis` must be int or str, but got {type(axis)}") axis = imgs[0].axisof(axis) @@ -107,7 +107,7 @@ def _(imgs, axis="c", dtype=None, casting="same_kind"): return out @MetaArray.implements(np.block) -def _(imgs): +def _(imgs: list[MetaArray]): def _recursive_view(obj): if isinstance(obj, MetaArray): return obj.value @@ -130,7 +130,7 @@ def _recursive_get0(obj): @MetaArray.implements(np.zeros_like) -def _(img, name:str=None): +def _(img: MetaArray, name: str = None): out = np.zeros_like(img.value).view(img.__class__) out._set_info(img, new_axes=img.axes) if isinstance(name, str): @@ -138,7 +138,7 @@ def _(img, name:str=None): return out @MetaArray.implements(np.empty_like) -def _(img, name:str=None): +def _(img: MetaArray, name:str = None): out = np.empty_like(img.value).view(img.__class__) out._set_info(img, new_axes=img.axes) if isinstance(name, str): @@ -146,7 +146,7 @@ def _(img, name:str=None): return out @MetaArray.implements(np.expand_dims) -def _(img, axis): +def _(img: MetaArray, axis): if isinstance(axis, str): new_axes = Axes(axis + str(img.axes)) new_axes.sort() @@ -160,11 +160,11 @@ def _(img, axis): return out @MetaArray.implements(np.transpose) -def _(img, axes): +def _(img: MetaArray, axes: str): return img.transpose(axes) @MetaArray.implements(np.split) -def _(img, indices_or_sections, axis=0): +def _(img: MetaArray, indices_or_sections, axis=0): if not isinstance(axis, (int, str)): raise TypeError(f"`axis` must be int or str, but got {type(axis)}") axis = img.axisof(axis) diff --git a/impy/arrays/imgarray.py b/impy/arrays/imgarray.py index e11be595..7da01348 100644 --- a/impy/arrays/imgarray.py +++ b/impy/arrays/imgarray.py @@ -1591,7 +1591,6 @@ def log_filter(self, sigma: nDFloat = 1, *, dims: Dims = None) -> ImgArray: args=(sigma,) ) - @_docs.write_docs @dims_to_spatial_axes @same_dtype(asfloat=True) @@ -1644,8 +1643,8 @@ def rolling_ball(self, radius: float = 30, prefilter:str="mean", *, return_bg:bo @dims_to_spatial_axes @same_dtype(asfloat=True) @record - def rof_filter(self, lmd:float=0.05, tol:float=1e-4, max_iter:int=50, *, dims: Dims = None, - update: bool = False) -> ImgArray: + def rof_filter(self, lmd: float = 0.05, tol: float = 1e-4, max_iter: int = 50, *, + dims: Dims = None, update: bool = False) -> ImgArray: """ Rudin-Osher-Fatemi's total variation denoising. @@ -2765,41 +2764,8 @@ def local_dft(self, key: str = "", upsample_factor: nDInt = 1, *, double_precisi key += f";{a}=0:{self.sizeof(a)}" if key.startswith(";"): key = key[1:] - slices = axis_targeted_slicing(np.empty((1,)*ndim), dims, key) - - def wave_num(sl: slice, s: int, uf: int) -> xp.ndarray: - """ - A function that makes wave number vertical vector. Returned vector will - be [k/s, (k + 1/uf)/s, (k + 2/uf)/s, ...] (where k = sl.start) - - Parameters - ---------- - sl : slice - Slice that specify which part of the image will be transformed. - s : int - Size along certain dimension. - uf : int - Up-sampling factor of certain dimension. - """ - start = 0 if sl.start is None else sl.start - stop = s if sl.stop is None else sl.stop - if sl.start and sl.stop and start < 0 and stop > 0: - # like "x=-5:5" - pass - else: - if -s < start < 0: - start += s - elif not 0 <= start < s: - raise ValueError(f"Invalid value encountered in key {key}.") - if -s < stop <= 0: - stop += s - elif not 0 < stop <= s: - raise ValueError(f"Invalid value encountered in key {key}.") - - n = stop - start - return xp.linspace(start, stop, n*uf, endpoint=False)[:, xp.newaxis] - + slices = axis_targeted_slicing(np.empty((1,)*ndim), dims, key) dtype = np.complex128 if double_precision else np.complex64 # Calculate exp(-ikx) @@ -3242,7 +3208,8 @@ def skeletonize(self, radius: float = 0, *, dims: Dims = None, update=False) -> @_docs.write_docs @dims_to_spatial_axes @record(only_binary=True) - def count_neighbors(self, *, connectivity:int=None, mask:bool=True, dims: Dims = None) -> ImgArray: + def count_neighbors(self, *, connectivity: int = None, mask: bool = True, + dims: Dims = None) -> ImgArray: """ Count the number or neighbors of binary images. This function can be used for cross section or branch detection. Only works for binary images. @@ -3282,7 +3249,7 @@ def count_neighbors(self, *, connectivity:int=None, mask:bool=True, dims: Dims = @_docs.write_docs @dims_to_spatial_axes @record(only_binary=True) - def remove_skeleton_structure(self, structure:str="tip", *, connectivity:int=None, + def remove_skeleton_structure(self, structure: str = "tip", *, connectivity: int = None, dims: Dims = None, update: bool = False) -> ImgArray: """ Remove certain structure from skeletonized images. @@ -3412,8 +3379,8 @@ def lineprops(self, src: Coords, dst: Coords, func: str|Callable = "mean", *, @dims_to_spatial_axes @record(need_labels=True) - def watershed(self, coords:MarkerFrame=None, *, connectivity:int=1, input:str="distance", - min_distance:float=2, dims: Dims = None) -> Label: + def watershed(self, coords: MarkerFrame = None, *, connectivity: int = 1, input: str = "distance", + min_distance: float = 2, dims: Dims = None) -> Label: """ Label segmentation using watershed algorithm. @@ -3502,7 +3469,7 @@ def random_walker(self, beta: float = 130, mode: str = "cg_j", tol: float = 1e-3 self.labels._set_info(self, "random_walker") return self.labels - def label_threshold(self, thr: float|str = "otsu", *, dims: Dims = None, **kwargs) -> Label: + def label_threshold(self, thr: float | str = "otsu", *, dims: Dims = None, **kwargs) -> Label: """ Make labels with threshold(). Be sure that keyword argument ``dims`` can be different (in most cases for >4D images) between threshold() and label(). @@ -3589,8 +3556,8 @@ def pathprops(self, paths: PathFrame, properties: str|Callable|Iterable[str|Call return out @record(append_history=False, need_labels=True) - def regionprops(self, properties: tuple[str,...]|str = ("mean_intensity",), *, - extra_properties: Iterable[Callable]|None = None) -> DataDict[str, PropArray]: + def regionprops(self, properties: Iterable[str] | str = ("mean_intensity",), *, + extra_properties: Iterable[Callable] | None = None) -> DataDict[str, PropArray]: """ Run skimage's regionprops() function and return the results as PropArray, so that you can access using flexible slicing. For example, if a tcyx-image is @@ -3976,7 +3943,7 @@ def drift_correction(self, shift: Coords = None, ref: ImgArray = None, *, zero_a @_docs.write_docs @dims_to_spatial_axes @record(append_history=False) - def estimate_sigma(self, *, squeeze: bool = True, dims: Dims = None) -> PropArray|float: + def estimate_sigma(self, *, squeeze: bool = True, dims: Dims = None) -> PropArray | float: """ Wavelet-based estimation of Gaussian noise. @@ -4377,4 +4344,38 @@ def check_filter_func(f): f = lambda x: True elif not callable(f): raise TypeError("`filt` must be callable.") - return f \ No newline at end of file + return f + + +def wave_num(sl: slice, s: int, uf: int) -> xp.ndarray: + """ + A function that makes wave number vertical vector. Returned vector will + be [k/s, (k + 1/uf)/s, (k + 2/uf)/s, ...] (where k = sl.start) + + Parameters + ---------- + sl : slice + Slice that specify which part of the image will be transformed. + s : int + Size along certain dimension. + uf : int + Up-sampling factor of certain dimension. + """ + start = 0 if sl.start is None else sl.start + stop = s if sl.stop is None else sl.stop + + if sl.start and sl.stop and start < 0 and stop > 0: + # like "x=-5:5" + pass + else: + if -s < start < 0: + start += s + elif not 0 <= start < s: + raise ValueError(f"Invalid value encountered in slice {sl}.") + if -s < stop <= 0: + stop += s + elif not 0 < stop <= s: + raise ValueError(f"Invalid value encountered in slice {sl}.") + + n = stop - start + return xp.linspace(start, stop, n*uf, endpoint=False)[:, xp.newaxis] \ No newline at end of file diff --git a/impy/correlation.py b/impy/correlation.py index 9424f462..bb83ba62 100644 --- a/impy/correlation.py +++ b/impy/correlation.py @@ -16,8 +16,13 @@ @_docs.write_docs @dims_to_spatial_axes -def fsc(img0: ImgArray, img1: ImgArray, nbin: int = 32, r_max: float = None, *, - squeeze: bool = True, dims: Dims = None) -> PropArray: +def fsc(img0: ImgArray, + img1: ImgArray, + nbin: int = 32, + r_max: float = None, + *, + squeeze: bool = True, + dims: Dims = None) -> PropArray: r""" Calculate Fourier Shell Correlation (FSC; or Fourier Ring Correlation, FRC, for 2-D images) between two images. FSC is defined as: @@ -109,6 +114,7 @@ def _ncc(img0: ImgArray, img1: ImgArray, dims: Dims): xp.std(img0, axis=dims)*xp.std(img1, axis=dims)) / n return asnumpy(corr) + def _masked_ncc(img0: ImgArray, img1: ImgArray, dims: Dims, mask: ImgArray): if mask.ndim < img0.ndim: mask = add_axes(img0.axes, img0.shape, mask, mask.axes) @@ -119,6 +125,7 @@ def _masked_ncc(img0: ImgArray, img1: ImgArray, dims: Dims, mask: ImgArray): return np.ma.sum(img0ma * img1ma, axis=axis) / ( np.ma.std(img0ma, axis=axis)*np.ma.std(img1ma, axis=axis)) / n + def _zncc(img0: ImgArray, img1: ImgArray, dims: Dims): # Basic Zero-Normalized Cross Correlation with batch processing. # Inputs must be already zero-normalized. @@ -130,6 +137,7 @@ def _zncc(img0: ImgArray, img1: ImgArray, dims: Dims): xp.sqrt(xp.sum(img0**2, axis=dims)*xp.sum(img1**2, axis=dims))) return asnumpy(corr) + def _masked_zncc(img0: ImgArray, img1: ImgArray, dims: Dims, mask: ImgArray): if mask.ndim < img0.ndim: mask = add_axes(img0.axes, img0.shape, mask, mask.axes) @@ -141,7 +149,11 @@ def _masked_zncc(img0: ImgArray, img1: ImgArray, dims: Dims, mask: ImgArray): @_docs.write_docs @dims_to_spatial_axes -def ncc(img0: ImgArray, img1: ImgArray, mask: ImgArray | None = None, squeeze: bool = True, *, +def ncc(img0: ImgArray, + img1: ImgArray, + mask: ImgArray | None = None, + squeeze: bool = True, + *, dims: Dims = None) -> PropArray | float: """ Normalized Cross Correlation. @@ -170,7 +182,11 @@ def ncc(img0: ImgArray, img1: ImgArray, mask: ImgArray | None = None, squeeze: b @_docs.write_docs @dims_to_spatial_axes -def zncc(img0: ImgArray, img1: ImgArray, mask: ImgArray | None = None, squeeze: bool = True, *, +def zncc(img0: ImgArray, + img1: ImgArray, + mask: ImgArray | None = None, + squeeze: bool = True, + *, dims: Dims = None) -> PropArray | float: """ Zero-Normalized Cross Correlation. @@ -204,8 +220,13 @@ def zncc(img0: ImgArray, img1: ImgArray, mask: ImgArray | None = None, squeeze: @_docs.write_docs @dims_to_spatial_axes -def nmi(img0: ImgArray, img1: ImgArray, mask: ImgArray | None = None, bins: int = 100, - squeeze: bool = True, *, dims: Dims = None) -> PropArray | float: +def nmi(img0: ImgArray, + img1: ImgArray, + mask: ImgArray | None = None, + bins: int = 100, + squeeze: bool = True, + *, + dims: Dims = None) -> PropArray | float: r""" Normalized Mutual Information. @@ -229,7 +250,8 @@ def nmi(img0: ImgArray, img1: ImgArray, mask: ImgArray | None = None, bins: int PropArray or float Correlation value(s). """ - entropy = scipy.stats.entropy + from scipy.stats import entropy + img0, img1 = _check_inputs(img0, img1) c_axes = complement_axes(dims, img0.axes) out = np.empty(img0.sizesof(c_axes), dtype=np.float32) @@ -248,7 +270,11 @@ def nmi(img0: ImgArray, img1: ImgArray, mask: ImgArray | None = None, bins: int @_docs.write_docs @dims_to_spatial_axes -def fourier_ncc(img0: ImgArray, img1: ImgArray, mask: ImgArray | None = None, squeeze: bool = True, *, +def fourier_ncc(img0: ImgArray, + img1: ImgArray, + mask: ImgArray | None = None, + squeeze: bool = True, + *, dims: Dims = None) -> PropArray | float: """ Normalized Cross Correlation in Fourier space. @@ -279,8 +305,12 @@ def fourier_ncc(img0: ImgArray, img1: ImgArray, mask: ImgArray | None = None, sq @_docs.write_docs @dims_to_spatial_axes -def fourier_zncc(img0: ImgArray, img1: ImgArray, mask: ImgArray | None = None, squeeze: bool = True, - *, dims: Dims = None) -> PropArray | float: +def fourier_zncc(img0: ImgArray, + img1: ImgArray, + mask: ImgArray | None = None, + squeeze: bool = True, + *, + dims: Dims = None) -> PropArray | float: """ Zero-Normalized Cross Correlation in Fourier space. @@ -311,7 +341,9 @@ def fourier_zncc(img0: ImgArray, img1: ImgArray, mask: ImgArray | None = None, s return _make_corr_output(corr, img0, "fourier_zncc", squeeze, dims) @_docs.write_docs -def pcc_maximum(img0: ImgArray, img1: ImgArray, mask: ImgArray | None = None, +def pcc_maximum(img0: ImgArray, + img1: ImgArray, + mask: ImgArray | None = None, upsample_factor: int = 10) -> np.ndarray: """ Calculate lateral shift between two images. Same as ``skimage.registration.phase_cross_correlation``. @@ -339,7 +371,9 @@ def pcc_maximum(img0: ImgArray, img1: ImgArray, mask: ImgArray | None = None, return asnumpy(shift) @_docs.write_docs -def ft_pcc_maximum(img0: ImgArray, img1: ImgArray, mask: ImgArray | None = None, +def ft_pcc_maximum(img0: ImgArray, + img1: ImgArray, + mask: ImgArray | None = None, upsample_factor: int = 10) -> np.ndarray: """ Calculate lateral shift between two images. @@ -367,7 +401,11 @@ def ft_pcc_maximum(img0: ImgArray, img1: ImgArray, mask: ImgArray | None = None, @_docs.write_docs @dims_to_spatial_axes -def manders_coloc(img0: ImgArray, img1: np.ndarray, *, squeeze: bool = True, dims: Dims = None) -> PropArray|float: +def manders_coloc(img0: ImgArray, + img1: np.ndarray, + *, + squeeze: bool = True, + dims: Dims = None) -> PropArray | float: r""" Manders' correlation coefficient. This is defined as following: @@ -404,12 +442,12 @@ def manders_coloc(img0: ImgArray, img1: np.ndarray, *, squeeze: bool = True, dim return _make_corr_output(coeff, img0, "manders_coloc", squeeze, dims) -def iter2(img0:ImgArray, img1:ImgArray, axes:str, israw:bool=False, exclude:str=""): +def iter2(img0: ImgArray, img1: ImgArray, axes: str, israw: bool = False, exclude: str = ""): for (sl, i0), (sl, i1) in zip(img0.iter(axes, israw=israw, exclude=exclude), img1.iter(axes, israw=israw, exclude=exclude)): yield sl, i0, i1 -def _check_inputs(img0:ImgArray, img1:ImgArray): +def _check_inputs(img0: ImgArray, img1: ImgArray): _check_dimensions(img0, img1) img0 = img0.as_float()