Skip to content

Commit

Permalink
Merge pull request PyLops#549 from mrava87/patch-slidingdtype
Browse files Browse the repository at this point in the history
fix: ensure sliding ops work with fp32
  • Loading branch information
mrava87 authored Nov 22, 2023
2 parents 99f5aa8 + 01568da commit cc52ae8
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
6 changes: 4 additions & 2 deletions pylops/signalprocessing/sliding1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def Sliding1D(

# create tapers
if tapertype is not None:
tap = taper(nwin, nover, tapertype=tapertype)
tap = taper(nwin, nover, tapertype=tapertype).astype(Op.dtype)
tapin = tap.copy()
tapin[:nover] = 1
tapend = tap.copy()
Expand All @@ -172,7 +172,9 @@ def Sliding1D(
if tapertype is None:
OOp = BlockDiag([Op for _ in range(nwins)])
else:
OOp = BlockDiag([Diagonal(taps[itap].ravel()) * Op for itap in range(nwins)])
OOp = BlockDiag(
[Diagonal(taps[itap].ravel(), dtype=Op.dtype) * Op for itap in range(nwins)]
)

combining = HStack(
[
Expand Down
6 changes: 4 additions & 2 deletions pylops/signalprocessing/sliding2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def Sliding2D(

# create tapers
if tapertype is not None:
tap = taper2d(dimsd[1], nwin, nover, tapertype=tapertype)
tap = taper2d(dimsd[1], nwin, nover, tapertype=tapertype).astype(Op.dtype)
tapin = tap.copy()
tapin[:nover] = 1
tapend = tap.copy()
Expand All @@ -206,7 +206,9 @@ def Sliding2D(
if tapertype is None:
OOp = BlockDiag([Op for _ in range(nwins)])
else:
OOp = BlockDiag([Diagonal(taps[itap].ravel()) * Op for itap in range(nwins)])
OOp = BlockDiag(
[Diagonal(taps[itap].ravel(), dtype=Op.dtype) * Op for itap in range(nwins)]
)

combining = HStack(
[
Expand Down
7 changes: 5 additions & 2 deletions pylops/signalprocessing/sliding3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,16 @@ def Sliding3D(

# create tapers
if tapertype is not None:
tap = taper3d(dimsd[2], nwin, nover, tapertype=tapertype)
tap = taper3d(dimsd[2], nwin, nover, tapertype=tapertype).astype(Op.dtype)

# transform to apply
if tapertype is None:
OOp = BlockDiag([Op for _ in range(nwins)], nproc=nproc)
else:
OOp = BlockDiag([Diagonal(tap.ravel()) * Op for _ in range(nwins)], nproc=nproc)
OOp = BlockDiag(
[Diagonal(tap.ravel(), dtype=Op.dtype) * Op for _ in range(nwins)],
nproc=nproc,
)

hstack = HStack(
[
Expand Down

0 comments on commit cc52ae8

Please sign in to comment.