Skip to content

Commit

Permalink
fix: force dtype for shift operator inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
mrava87 committed Nov 29, 2023
1 parent 294f951 commit f130a96
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions pylops/signalprocessing/shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def Shift(
shift = _value_or_sized_to_array(shift)

if shift.size == 1:
shift = np.exp(-1j * 2 * np.pi * Fop.f * shift)
shift = np.exp(-1j * 2 * np.pi * Fop.f * shift).astype(Fop.cdtype)
Sop = Diagonal(shift, dims=dimsdiag, axis=axis, dtype=Fop.cdtype)
else:
# add dimensions to shift to match dimensions of model and data
Expand All @@ -120,7 +120,7 @@ def Shift(
sdims = np.ones(shift.ndim + 1, dtype=int)
sdims[:axis] = shift.shape[:axis]
sdims[axis + 1 :] = shift.shape[axis:]
shift = np.exp(-1j * 2 * np.pi * f * shift.reshape(sdims))
shift = np.exp(-1j * 2 * np.pi * f * shift.reshape(sdims)).astype(Fop.cdtype)
Sop = Diagonal(shift, dtype=Fop.cdtype)
Op = Fop.H * Sop * Fop
Op.dims = Op.dimsd = Fop.dims
Expand Down
2 changes: 1 addition & 1 deletion pylops/waveeqprocessing/blending.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def __init__(
# Define shift operator
self.shifts = (times // self.dt).astype(np.int32)
diff = (times / self.dt - self.shifts) * self.dt
diff = np.repeat(diff[:, np.newaxis], self.nr, axis=1)
diff = np.repeat(diff[:, np.newaxis], self.nr, axis=1).astype(self.dtype)
self.ShiftOp = Shift(
(self.ns, self.nr, self.nt + 1),
diff,
Expand Down

0 comments on commit f130a96

Please sign in to comment.