From ae77c86e4ac93b5017f5d3d24640809a647fd5d5 Mon Sep 17 00:00:00 2001 From: mrava87 Date: Sat, 16 Mar 2024 21:46:55 +0300 Subject: [PATCH] doc: added safe typing to dtcwt --- pylops/signalprocessing/dtcwt.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/pylops/signalprocessing/dtcwt.py b/pylops/signalprocessing/dtcwt.py index 4c592f9a..da149585 100644 --- a/pylops/signalprocessing/dtcwt.py +++ b/pylops/signalprocessing/dtcwt.py @@ -1,6 +1,6 @@ __all__ = ["DTCWT"] -from typing import Union +from typing import Any, Union import numpy as np @@ -15,6 +15,10 @@ if dtcwt_message is None: import dtcwt + pyramid_type = dtcwt.numpy.common.Pyramid +else: + pyramid_type = Any + class DTCWT(LinearOperator): r"""Dual-Tree Complex Wavelet Transform @@ -122,7 +126,11 @@ def __init__( name=name, ) - def _interpret_coeffs(self, dims, axis): + def _interpret_coeffs( + self, + dims: Union[int, InputDimsLike], + axis: int, + ) -> None: x = np.ones(dims[axis]) pyr = self._transform.forward( x, nlevels=self.level, include_scale=self.include_scale @@ -134,16 +142,20 @@ def _interpret_coeffs(self, dims, axis): self.highpass_sizes.append(_h.size) self.coeff_array_size += _h.size - def _nd_to_2d(self, arr_nd): + def _nd_to_2d(self, arr_nd: NDArray) -> NDArray: arr_2d = arr_nd.reshape(self.dims[self.axis], -1).squeeze() return arr_2d - def _coeff_to_array(self, pyr): # cannot use dtcwt types as it may not be installed + def _coeff_to_array( + self, pyr: pyramid_type + ) -> NDArray: # cannot use dtcwt types as it may not be installed highpass_coeffs = np.vstack([h for h in pyr.highpasses]) coeffs = np.concatenate((highpass_coeffs, pyr.lowpass), axis=0) return coeffs - def _array_to_coeff(self, X): # cannot use dtcwt types as it may not be installed + def _array_to_coeff( + self, X: NDArray + ) -> pyramid_type: # cannot use dtcwt types as it may not be installed lowpass = (X[-self.lowpass_size :].real).reshape((-1, self.otherdims)) _ptr = 0 highpasses = () @@ -154,7 +166,9 @@ def _array_to_coeff(self, X): # cannot use dtcwt types as it may not be install highpasses += (_h,) return dtcwt.Pyramid(lowpass, highpasses) - def get_pyramid(self, x): # cannot use dtcwt types as it may not be installed + def get_pyramid( + self, x: NDArray + ) -> pyramid_type: # cannot use dtcwt types as it may not be installed """Return Pyramid object from flat real-valued array""" return self._array_to_coeff(x[0] + 1j * x[1])