Skip to content

Commit

Permalink
typing
Browse files Browse the repository at this point in the history
  • Loading branch information
hanjinliu committed Jan 9, 2022
1 parent b100171 commit 49146db
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 72 deletions.
22 changes: 11 additions & 11 deletions impy/arrays/bases/overload.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ...utils.axesop import del_axis
from ...collections import DataList

def safe_set_info(out, img, history, new_axes):
def safe_set_info(out: MetaArray, img: MetaArray, history, new_axes):
if isinstance(img, HistoryArray):
out._set_info(img, history, new_axes=new_axes)
else:
Expand All @@ -22,14 +22,14 @@ def safe_set_info(out, img, history, new_axes):


@MetaArray.implements(np.squeeze)
def _(img:MetaArray):
def _(img: MetaArray):
out = np.squeeze(img.value).view(img.__class__)
new_axes = "".join(a for a in img.axes if img.sizeof(a) > 1)
safe_set_info(out, img, "squeeze", new_axes)
return out

@MetaArray.implements(np.take)
def _(a:MetaArray, indices, axis=None, out=None, mode="raise"):
def _(a: MetaArray, indices, axis=None, out=None, mode="raise"):
new_axes = del_axis(a.axes, axis)
if isinstance(axis, str):
axis = a.axes.find(axis)
Expand All @@ -39,7 +39,7 @@ def _(a:MetaArray, indices, axis=None, out=None, mode="raise"):
return out

@MetaArray.implements(np.stack)
def _(imgs:list[MetaArray], axis="c", dtype=None):
def _(imgs: list[MetaArray], axis="c", dtype=None):
"""
Create stack image from list of images.
Expand Down Expand Up @@ -97,7 +97,7 @@ def _(imgs:list[MetaArray], axis="c", dtype=None):
return out

@MetaArray.implements(np.concatenate)
def _(imgs, axis="c", dtype=None, casting="same_kind"):
def _(imgs: list[MetaArray], axis="c", dtype=None, casting="same_kind"):
if not isinstance(axis, (int, str)):
raise TypeError(f"`axis` must be int or str, but got {type(axis)}")
axis = imgs[0].axisof(axis)
Expand All @@ -107,7 +107,7 @@ def _(imgs, axis="c", dtype=None, casting="same_kind"):
return out

@MetaArray.implements(np.block)
def _(imgs):
def _(imgs: list[MetaArray]):
def _recursive_view(obj):
if isinstance(obj, MetaArray):
return obj.value
Expand All @@ -130,23 +130,23 @@ def _recursive_get0(obj):


@MetaArray.implements(np.zeros_like)
def _(img, name:str=None):
def _(img: MetaArray, name: str = None):
out = np.zeros_like(img.value).view(img.__class__)
out._set_info(img, new_axes=img.axes)
if isinstance(name, str):
out.name = name
return out

@MetaArray.implements(np.empty_like)
def _(img, name:str=None):
def _(img: MetaArray, name:str = None):
out = np.empty_like(img.value).view(img.__class__)
out._set_info(img, new_axes=img.axes)
if isinstance(name, str):
out.name = name
return out

@MetaArray.implements(np.expand_dims)
def _(img, axis):
def _(img: MetaArray, axis):
if isinstance(axis, str):
new_axes = Axes(axis + str(img.axes))
new_axes.sort()
Expand All @@ -160,11 +160,11 @@ def _(img, axis):
return out

@MetaArray.implements(np.transpose)
def _(img, axes):
def _(img: MetaArray, axes: str):
return img.transpose(axes)

@MetaArray.implements(np.split)
def _(img, indices_or_sections, axis=0):
def _(img: MetaArray, indices_or_sections, axis=0):
if not isinstance(axis, (int, str)):
raise TypeError(f"`axis` must be int or str, but got {type(axis)}")
axis = img.axisof(axis)
Expand Down
93 changes: 47 additions & 46 deletions impy/arrays/imgarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1591,7 +1591,6 @@ def log_filter(self, sigma: nDFloat = 1, *, dims: Dims = None) -> ImgArray:
args=(sigma,)
)


@_docs.write_docs
@dims_to_spatial_axes
@same_dtype(asfloat=True)
Expand Down Expand Up @@ -1644,8 +1643,8 @@ def rolling_ball(self, radius: float = 30, prefilter:str="mean", *, return_bg:bo
@dims_to_spatial_axes
@same_dtype(asfloat=True)
@record
def rof_filter(self, lmd:float=0.05, tol:float=1e-4, max_iter:int=50, *, dims: Dims = None,
update: bool = False) -> ImgArray:
def rof_filter(self, lmd: float = 0.05, tol: float = 1e-4, max_iter: int = 50, *,
dims: Dims = None, update: bool = False) -> ImgArray:
"""
Rudin-Osher-Fatemi's total variation denoising.
Expand Down Expand Up @@ -2765,41 +2764,8 @@ def local_dft(self, key: str = "", upsample_factor: nDInt = 1, *, double_precisi
key += f";{a}=0:{self.sizeof(a)}"
if key.startswith(";"):
key = key[1:]
slices = axis_targeted_slicing(np.empty((1,)*ndim), dims, key)

def wave_num(sl: slice, s: int, uf: int) -> xp.ndarray:
"""
A function that makes wave number vertical vector. Returned vector will
be [k/s, (k + 1/uf)/s, (k + 2/uf)/s, ...] (where k = sl.start)
Parameters
----------
sl : slice
Slice that specify which part of the image will be transformed.
s : int
Size along certain dimension.
uf : int
Up-sampling factor of certain dimension.
"""
start = 0 if sl.start is None else sl.start
stop = s if sl.stop is None else sl.stop

if sl.start and sl.stop and start < 0 and stop > 0:
# like "x=-5:5"
pass
else:
if -s < start < 0:
start += s
elif not 0 <= start < s:
raise ValueError(f"Invalid value encountered in key {key}.")
if -s < stop <= 0:
stop += s
elif not 0 < stop <= s:
raise ValueError(f"Invalid value encountered in key {key}.")

n = stop - start
return xp.linspace(start, stop, n*uf, endpoint=False)[:, xp.newaxis]

slices = axis_targeted_slicing(np.empty((1,)*ndim), dims, key)
dtype = np.complex128 if double_precision else np.complex64

# Calculate exp(-ikx)
Expand Down Expand Up @@ -3242,7 +3208,8 @@ def skeletonize(self, radius: float = 0, *, dims: Dims = None, update=False) ->
@_docs.write_docs
@dims_to_spatial_axes
@record(only_binary=True)
def count_neighbors(self, *, connectivity:int=None, mask:bool=True, dims: Dims = None) -> ImgArray:
def count_neighbors(self, *, connectivity: int = None, mask: bool = True,
dims: Dims = None) -> ImgArray:
"""
Count the number or neighbors of binary images. This function can be used for cross section
or branch detection. Only works for binary images.
Expand Down Expand Up @@ -3282,7 +3249,7 @@ def count_neighbors(self, *, connectivity:int=None, mask:bool=True, dims: Dims =
@_docs.write_docs
@dims_to_spatial_axes
@record(only_binary=True)
def remove_skeleton_structure(self, structure:str="tip", *, connectivity:int=None,
def remove_skeleton_structure(self, structure: str = "tip", *, connectivity: int = None,
dims: Dims = None, update: bool = False) -> ImgArray:
"""
Remove certain structure from skeletonized images.
Expand Down Expand Up @@ -3412,8 +3379,8 @@ def lineprops(self, src: Coords, dst: Coords, func: str|Callable = "mean", *,

@dims_to_spatial_axes
@record(need_labels=True)
def watershed(self, coords:MarkerFrame=None, *, connectivity:int=1, input:str="distance",
min_distance:float=2, dims: Dims = None) -> Label:
def watershed(self, coords: MarkerFrame = None, *, connectivity: int = 1, input: str = "distance",
min_distance: float = 2, dims: Dims = None) -> Label:
"""
Label segmentation using watershed algorithm.
Expand Down Expand Up @@ -3502,7 +3469,7 @@ def random_walker(self, beta: float = 130, mode: str = "cg_j", tol: float = 1e-3
self.labels._set_info(self, "random_walker")
return self.labels

def label_threshold(self, thr: float|str = "otsu", *, dims: Dims = None, **kwargs) -> Label:
def label_threshold(self, thr: float | str = "otsu", *, dims: Dims = None, **kwargs) -> Label:
"""
Make labels with threshold(). Be sure that keyword argument ``dims`` can be
different (in most cases for >4D images) between threshold() and label().
Expand Down Expand Up @@ -3589,8 +3556,8 @@ def pathprops(self, paths: PathFrame, properties: str|Callable|Iterable[str|Call
return out

@record(append_history=False, need_labels=True)
def regionprops(self, properties: tuple[str,...]|str = ("mean_intensity",), *,
extra_properties: Iterable[Callable]|None = None) -> DataDict[str, PropArray]:
def regionprops(self, properties: Iterable[str] | str = ("mean_intensity",), *,
extra_properties: Iterable[Callable] | None = None) -> DataDict[str, PropArray]:
"""
Run skimage's regionprops() function and return the results as PropArray, so
that you can access using flexible slicing. For example, if a tcyx-image is
Expand Down Expand Up @@ -3976,7 +3943,7 @@ def drift_correction(self, shift: Coords = None, ref: ImgArray = None, *, zero_a
@_docs.write_docs
@dims_to_spatial_axes
@record(append_history=False)
def estimate_sigma(self, *, squeeze: bool = True, dims: Dims = None) -> PropArray|float:
def estimate_sigma(self, *, squeeze: bool = True, dims: Dims = None) -> PropArray | float:
"""
Wavelet-based estimation of Gaussian noise.
Expand Down Expand Up @@ -4377,4 +4344,38 @@ def check_filter_func(f):
f = lambda x: True
elif not callable(f):
raise TypeError("`filt` must be callable.")
return f
return f


def wave_num(sl: slice, s: int, uf: int) -> xp.ndarray:
"""
A function that makes wave number vertical vector. Returned vector will
be [k/s, (k + 1/uf)/s, (k + 2/uf)/s, ...] (where k = sl.start)
Parameters
----------
sl : slice
Slice that specify which part of the image will be transformed.
s : int
Size along certain dimension.
uf : int
Up-sampling factor of certain dimension.
"""
start = 0 if sl.start is None else sl.start
stop = s if sl.stop is None else sl.stop

if sl.start and sl.stop and start < 0 and stop > 0:
# like "x=-5:5"
pass
else:
if -s < start < 0:
start += s
elif not 0 <= start < s:
raise ValueError(f"Invalid value encountered in slice {sl}.")
if -s < stop <= 0:
stop += s
elif not 0 < stop <= s:
raise ValueError(f"Invalid value encountered in slice {sl}.")

n = stop - start
return xp.linspace(start, stop, n*uf, endpoint=False)[:, xp.newaxis]
Loading

0 comments on commit 49146db

Please sign in to comment.