diff --git a/pylops/utils/tapers.py b/pylops/utils/tapers.py index ea109a00..52c95e8e 100644 --- a/pylops/utils/tapers.py +++ b/pylops/utils/tapers.py @@ -7,7 +7,7 @@ "tapernd", ] -from typing import Tuple, Union +from typing import Optional, Tuple, Union import numpy as np import numpy.typing as npt @@ -59,6 +59,7 @@ def cosinetaper( nmask: int, ntap: int, square: bool = False, + exponent: Optional[float] = None, ) -> npt.ArrayLike: r"""1D Cosine or Cosine square taper @@ -71,8 +72,10 @@ def cosinetaper( Number of samples of mask ntap : :obj:`int` Number of samples of hanning tapering at edges - square : :obj:`bool` - Cosine square taper (``True``)or Cosine taper (``False``) + square : :obj:`bool`, optional + Cosine square taper (``True``) or Cosine taper (``False``) + exponent : :obj:`float`, optional + Exponent to apply to Cosine taper. If provided, takes precedence over ``square`` Returns ------- @@ -81,7 +84,8 @@ def cosinetaper( """ ntap = 0 if ntap == 1 else ntap - exponent = 1 if not square else 2 + if exponent is None: + exponent = 1 if not square else 2 cos_win = ( 0.5 * ( @@ -123,7 +127,8 @@ def taper( ntap : :obj:`int` Number of samples of hanning tapering at edges tapertype : :obj:`str`, optional - Type of taper (``hanning``, ``cosine``, ``cosinesquare`` or ``None``) + Type of taper (``hanning``, ``cosine``, + ``cosinesquare``, ``cosinesqrt`` or ``None``) Returns ------- @@ -137,6 +142,8 @@ def taper( tpr_1d = cosinetaper(nmask, ntap, False) elif tapertype == "cosinesquare": tpr_1d = cosinetaper(nmask, ntap, True) + elif tapertype == "cosinesqrt": + tpr_1d = cosinetaper(nmask, ntap, False, 0.5) else: tpr_1d = np.ones(nmask) return tpr_1d @@ -214,7 +221,7 @@ def taper3d( Number of samples of tapering at edges of first and second dimensions tapertype : :obj:`int` Type of taper (``hanning``, ``cosine``, - ``cosinesquare`` or ``None``) + ``cosinesquare``, ``cosinesqrt`` or ``None``) Returns ------- @@ -236,6 +243,9 @@ def taper3d( elif tapertype == "cosinesquare": tpr_y = cosinetaper(nmasky, ntapy, True) tpr_x = cosinetaper(nmaskx, ntapx, True) + elif tapertype == "cosinesqrt": + tpr_y = cosinetaper(nmasky, ntapy, False, 0.5) + tpr_x = cosinetaper(nmaskx, ntapx, False, 0.5) else: tpr_y = np.ones(nmasky) tpr_x = np.ones(nmaskx) @@ -266,7 +276,7 @@ def tapernd( Number of samples of tapering at edges of every dimension tapertype : :obj:`int` Type of taper (``hanning``, ``cosine``, - ``cosinesquare`` or ``None``) + ``cosinesquare``, ``cosinesqrt`` or ``None``) Returns ------- @@ -282,6 +292,8 @@ def tapernd( tpr = [cosinetaper(nm, nt, False) for nm, nt in zip(nmask, ntap)] elif tapertype == "cosinesquare": tpr = [cosinetaper(nm, nt, True) for nm, nt in zip(nmask, ntap)] + elif tapertype == "cosinesqrt": + tpr = [cosinetaper(nm, nt, False, 0.5) for nm, nt in zip(nmask, ntap)] else: tpr = [np.ones(nm) for nm in nmask] diff --git a/pytests/test_tapers.py b/pytests/test_tapers.py index 823097a4..320af932 100755 --- a/pytests/test_tapers.py +++ b/pytests/test_tapers.py @@ -40,9 +40,23 @@ "ntap": (4, 6), "tapertype": "cosinesquare", } # cosinesquare, even samples and taper +par7 = { + "nt": 21, + "nspat": (11, 13), + "ntap": (3, 5), + "tapertype": "cosinesqrt", +} # cosinesqrt, odd samples and taper +par8 = { + "nt": 20, + "nspat": (12, 16), + "ntap": (4, 6), + "tapertype": "cosinesqrt", +} # cosinesqrt, even samples and taper -@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) +@pytest.mark.parametrize( + "par", [(par1), (par2), (par3), (par4), (par5), (par6), (par7), (par8)] +) def test_taper2d(par): """Create taper wavelet and check size and values""" tap = taper2d(par["nt"], par["nspat"][0], par["ntap"][0], par["tapertype"]) @@ -54,7 +68,9 @@ def test_taper2d(par): assert_array_equal(tap[par["nspat"][0] // 2], np.ones(par["nt"])) -@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)]) +@pytest.mark.parametrize( + "par", [(par1), (par2), (par3), (par4), (par5), (par6), (par7), (par8)] +) def test_taper3d(par): """Create taper wavelet and check size and values""" tap = taper3d(par["nt"], par["nspat"], par["ntap"], par["tapertype"])